This commit is contained in:
2023-02-14 15:36:00 +08:00
parent 3014078e68
commit 9f35445a4f
3 changed files with 10 additions and 8 deletions

View File

@ -1,6 +1,7 @@
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import time
import numpy as np
import cv2
import matplotlib.pyplot as plt
@ -127,8 +128,8 @@ if __name__ == '__main__':
print(type(img_test))
print('>>>>>>shape {}'.format(img_test.shape))
#ENCODER = 'resnet18'
ENCODER = 'mobilenet_v2'
ENCODER = 'resnet18'
#ENCODER = 'mobilenet_v2'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['front']
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
@ -160,11 +161,12 @@ if __name__ == '__main__':
image = predict_dataset[i]
# 通过图像分割得到的0-1图像pr_mask
T1 = time.time()
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
pr_mask = best_model.predict(x_tensor)
T2 = time.time()
print('>>>>>> {}'.format(T2-T1))
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
print('>>>>>>> pr_mask{}'.format(pr_mask.shape))
print('>>>>>>{} {}'.format(height, weight))
# 恢复图片原来的分辨率
#image_vis = cv2.resize(image_vis, (weight, height))