first push
This commit is contained in:
75
RAFT/demo.py
Executable file
75
RAFT/demo.py
Executable file
@ -0,0 +1,75 @@
|
||||
import sys
|
||||
sys.path.append('core')
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import cv2
|
||||
import glob
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from raft import RAFT
|
||||
from utils import flow_viz
|
||||
from utils.utils import InputPadder
|
||||
|
||||
|
||||
|
||||
DEVICE = 'cuda'
|
||||
|
||||
def load_image(imfile):
|
||||
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
||||
return img[None].to(DEVICE)
|
||||
|
||||
|
||||
def viz(img, flo):
|
||||
img = img[0].permute(1,2,0).cpu().numpy()
|
||||
flo = flo[0].permute(1,2,0).cpu().numpy()
|
||||
|
||||
# map flow to rgb image
|
||||
flo = flow_viz.flow_to_image(flo)
|
||||
img_flo = np.concatenate([img, flo], axis=0)
|
||||
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.imshow(img_flo / 255.0)
|
||||
# plt.show()
|
||||
|
||||
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
|
||||
cv2.waitKey()
|
||||
|
||||
|
||||
def demo(args):
|
||||
model = torch.nn.DataParallel(RAFT(args))
|
||||
model.load_state_dict(torch.load(args.model))
|
||||
|
||||
model = model.module
|
||||
model.to(DEVICE)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
images = glob.glob(os.path.join(args.path, '*.png')) + \
|
||||
glob.glob(os.path.join(args.path, '*.jpg'))
|
||||
|
||||
images = sorted(images)
|
||||
for imfile1, imfile2 in zip(images[:-1], images[1:]):
|
||||
image1 = load_image(imfile1)
|
||||
image2 = load_image(imfile2)
|
||||
|
||||
padder = InputPadder(image1.shape)
|
||||
image1, image2 = padder.pad(image1, image2)
|
||||
|
||||
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
||||
viz(image1, flow_up)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', help="restore checkpoint")
|
||||
parser.add_argument('--path', help="dataset for evaluation")
|
||||
parser.add_argument('--small', action='store_true', help='use small model')
|
||||
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
||||
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
||||
args = parser.parse_args()
|
||||
|
||||
demo(args)
|
Reference in New Issue
Block a user