Files
ieemoo-ai-filtervideo/tools/segpredict.py
2023-02-22 15:39:24 +08:00

44 lines
1.3 KiB
Python

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import numpy as np
import cv2
import torch
import segmentation_models_pytorch as smp
import imageio
def _to_input_tensor(img, input_size, device):
    """Resize an HxWx3 OpenCV image and return it as a (1, 3, h, w) float tensor on *device*.

    ``input_size`` is ``(width, height)``, the ``dsize`` order used by ``cv2.resize``.
    Pixels are passed through as raw 0-255 values in OpenCV's BGR channel order,
    exactly as the original script did.
    NOTE(review): no encoder preprocessing is applied here (e.g.
    ``smp.encoders.get_preprocessing_fn('resnet18', 'imagenet')``). Confirm this
    matches how the checkpoint was trained.
    """
    resized = cv2.resize(img, input_size)
    chw = resized.transpose((2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(chw).float().to(device).unsqueeze(0)


def _mask_to_uint8(pr_mask, width, height):
    """Binarise channel 0 of a batched prediction and resize it to ``(height, width)``.

    *pr_mask* is the tensor returned by ``model.predict``. It is indexed as
    ``[batch, channel, h, w]``, and only the first batch item and first channel
    are used. Returns a uint8 array whose values are 0 or 255.
    """
    mask = pr_mask[0, 0].cpu().numpy()
    # Clip first so that a model emitting logits cannot overflow uint8 below.
    mask = np.clip(mask, 0, 1).round()
    mask = (mask * 255).astype(np.uint8)
    # Nearest-neighbour keeps the mask binary; bilinear would create grey edges.
    return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)


def predict(best_model=None, *, x_test_dir='./data/CamVid/abc',
            model_path='./module/best_model.pth', out_dir='.',
            device='cuda', input_size=(512, 512)):
    """Run the segmentation model on every image in a directory and save the masks.

    Args:
        best_model: A loaded model exposing ``.predict(tensor)``. When None,
            the pickled model at *model_path* is loaded with ``torch.load``.
        x_test_dir: Directory whose image files are segmented.
        model_path: Checkpoint path, used only when *best_model* is None.
        out_dir: Directory for the output masks; it is created if missing.
        device: Device that input tensors (and a freshly loaded model) go to.
        input_size: ``(width, height)`` that each image is resized to before
            inference.

    Returns:
        List of the mask file paths written, in processing order. Each mask is
        named ``f_test_out_<input stem>.png`` and has its input image's size.
    """
    if best_model is None:
        # Load the best model. torch.load unpickles a whole model object, so
        # only point model_path at trusted checkpoints.
        best_model = torch.load(model_path, map_location=device)
    os.makedirs(out_dir, exist_ok=True)

    written = []
    for name in sorted(os.listdir(x_test_dir)):
        img = cv2.imread(os.path.join(x_test_dir, name))
        if img is None:
            # Not a decodable image (e.g. a stray text file); skip instead of
            # crashing inside cv2.resize.
            print('skip unreadable file: {}'.format(name))
            continue
        # Restore each mask to its own image's size, not a fixed reference image's.
        height, width = img.shape[:2]
        x_tensor = _to_input_tensor(img, input_size, device)

        t1 = time.perf_counter()
        pr_mask = best_model.predict(x_tensor)
        if pr_mask.is_cuda:
            # CUDA kernels run asynchronously; wait so the timing is real.
            torch.cuda.synchronize()
        t2 = time.perf_counter()
        print('>>>> {}'.format(t2 - t1))
        print('>>>>>>>>> pr_mask{}'.format(tuple(pr_mask.shape)))

        mask = _mask_to_uint8(pr_mask, width, height)
        stem = os.path.splitext(name)[0]
        # One output per input; a single fixed filename would be overwritten
        # on every iteration.
        out_path = os.path.join(out_dir, 'f_test_out_{}.png'.format(stem))
        imageio.imwrite(out_path, mask)
        written.append(out_path)
    return written
if __name__ == '__main__':
    # Script entry point: segment the default test directory, letting predict()
    # load the checkpoint from disk itself.
    predict(best_model=None)