44 lines
1.3 KiB
Python
44 lines
1.3 KiB
Python
import os
|
|
# Set CUDA_VISIBLE_DEVICES to '0' before torch is imported further down;
# presumably this restricts the process to the first GPU -- confirm for the
# target machine.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
|
|
|
import time
|
|
import numpy as np
|
|
import cv2
|
|
import torch
|
|
import segmentation_models_pytorch as smp
|
|
import imageio
|
|
|
|
def predict(best_model=None):
    """Run the segmentation model over every image in the test directory.

    Each readable image in ``./data/CamVid/abc`` is resized to 512x512,
    passed through the model, and the rounded mask is scaled back to that
    image's own resolution and written to ``f_test_out.png``.

    Args:
        best_model: Object exposing ``predict(tensor)``. When ``None`` (the
            default) the checkpoint at ``./module/best_model.pth`` is
            loaded. A caller-supplied model is used as-is; inputs are moved
            to ``'cuda'``, so it presumably has to live on that device
            already -- confirm with callers.

    Returns:
        None. The result is written to disk and timing/shape information is
        printed.
    """
    DATA_DIR = './data/CamVid/'
    x_test_dir = os.path.join(DATA_DIR, 'abc')
    DEVICE = 'cuda'

    # Load the best model only when the caller did not provide one
    # (previously the argument was silently overwritten).
    if best_model is None:
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        best_model = torch.load('./module/best_model.pth', map_location=DEVICE)

    # Sorted so the processing order (and therefore which mask ends up in
    # the output file) does not depend on the filesystem's listing order.
    for imgpath in sorted(os.listdir(x_test_dir)):
        img = cv2.imread(os.path.join(x_test_dir, imgpath))
        if img is None:
            # cv2.imread returns None for missing or non-image files.
            print('>>>> skipping unreadable file: {}'.format(imgpath))
            continue

        # Remember the source resolution so the mask can be scaled back to
        # the image it belongs to rather than to a hard-coded reference.
        height, width = img.shape[:2]

        img = cv2.resize(img, (512, 512))
        img = img.transpose((2, 0, 1))  # HWC -> CHW

        T1 = time.time()
        x_tensor = torch.from_numpy(img).float().to(DEVICE).unsqueeze(0)
        pr_mask = best_model.predict(x_tensor)
        T2 = time.time()
        print('>>>> {}'.format(T2 - T1))

        pr_mask = pr_mask.squeeze().cpu().numpy().round()
        print('>>>>>>>>> pr_mask{}'.format(pr_mask.shape))

        # squeeze() already yields a 2-D array for a single-channel output;
        # only pick the first channel when several remain.
        if pr_mask.ndim == 3:
            pr_mask = pr_mask[0]

        pr_mask = cv2.resize(pr_mask, (width, height))
        # NOTE(review): every iteration writes the same file, so only the
        # last processed image's mask survives -- confirm this is intended.
        imageio.imwrite('f_test_out.png', pr_mask)
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run inference with the default arguments.
    predict()
|