181 lines
5.4 KiB
Python
181 lines
5.4 KiB
Python
import os
|
||
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
|
||
|
||
import time
|
||
import numpy as np
|
||
import cv2
|
||
import matplotlib.pyplot as plt
|
||
import albumentations as albu
|
||
import torch
|
||
import segmentation_models_pytorch as smp
|
||
from torch.utils.data import Dataset as BaseDataset
|
||
import imageio
|
||
|
||
|
||
# ---------------------------------------------------------------
|
||
### Dataloader
|
||
|
||
class Dataset(BaseDataset):
|
||
"""CamVid数据集。进行图像读取,图像增强增强和图像预处理.
|
||
|
||
Args:
|
||
images_dir (str): 图像文件夹所在路径
|
||
masks_dir (str): 图像分割的标签图像所在路径
|
||
class_values (list): 用于图像分割的所有类别数
|
||
augmentation (albumentations.Compose): 数据传输管道
|
||
preprocessing (albumentations.Compose): 数据预处理
|
||
"""
|
||
# CamVid数据集中用于图像分割的所有标签类别
|
||
#CLASSES = ['sky', 'building', 'pole', 'road', 'pavement',
|
||
# 'tree', 'signsymbol', 'fence', 'car',
|
||
# 'pedestrian', 'bicyclist', 'unlabelled']
|
||
CLASSES = ['front']
|
||
|
||
def __init__(
|
||
self,
|
||
images_dir,
|
||
# masks_dir,
|
||
classes=None,
|
||
augmentation=None,
|
||
preprocessing=None,
|
||
):
|
||
self.ids = os.listdir(images_dir)
|
||
self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
|
||
|
||
# convert str names to class values on masks
|
||
self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
|
||
|
||
self.augmentation = augmentation
|
||
self.preprocessing = preprocessing
|
||
|
||
def __getitem__(self, i):
|
||
|
||
# read data
|
||
image = cv2.imread(self.images_fps[i])
|
||
image = cv2.resize(image, (512, 512)) # 改变图片分辨率
|
||
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||
|
||
# 图像增强应用
|
||
if self.augmentation:
|
||
sample = self.augmentation(image=image)
|
||
image = sample['image']
|
||
|
||
# 图像预处理应用
|
||
if self.preprocessing:
|
||
sample = self.preprocessing(image=image)
|
||
image = sample['image']
|
||
|
||
return image
|
||
|
||
def __len__(self):
|
||
return len(self.ids)
|
||
|
||
# ---------------------------------------------------------------
|
||
|
||
def get_validation_augmentation():
|
||
"""调整图像使得图片的分辨率长宽能被32整除"""
|
||
test_transform = [
|
||
albu.PadIfNeeded(384, 480)
|
||
]
|
||
return albu.Compose(test_transform)
|
||
|
||
|
||
def to_tensor(x, **kwargs):
|
||
return x.transpose(2, 0, 1).astype('float32')
|
||
|
||
|
||
def get_preprocessing(preprocessing_fn):
|
||
"""进行图像预处理操作
|
||
|
||
Args:
|
||
preprocessing_fn (callbale): 数据规范化的函数
|
||
(针对每种预训练的神经网络)
|
||
Return:
|
||
transform: albumentations.Compose
|
||
"""
|
||
|
||
_transform = [
|
||
albu.Lambda(image=preprocessing_fn),
|
||
albu.Lambda(image=to_tensor),
|
||
]
|
||
return albu.Compose(_transform)
|
||
|
||
|
||
# 图像分割结果的可视化展示
|
||
def visualize(**images):
|
||
"""PLot images in one row."""
|
||
n = len(images)
|
||
plt.figure(figsize=(16, 5))
|
||
for i, (name, image) in enumerate(images.items()):
|
||
plt.subplot(1, n, i + 1)
|
||
plt.xticks([])
|
||
plt.yticks([])
|
||
plt.title(' '.join(name.split('_')).title())
|
||
plt.imshow(image)
|
||
plt.show()
|
||
|
||
|
||
# ---------------------------------------------------------------
|
||
if __name__ == '__main__':
|
||
|
||
DATA_DIR = './data/CamVid/'
|
||
|
||
x_test_dir = os.path.join(DATA_DIR, 'abc')
|
||
|
||
img_test = cv2.imread('data/CamVid/abc/pic_unscan_front.jpg')
|
||
height = img_test.shape[0]
|
||
weight = img_test.shape[1]
|
||
|
||
ENCODER = 'resnet18'
|
||
ENCODER_WEIGHTS = 'imagenet'
|
||
CLASSES = ['front']
|
||
ACTIVATION = 'sigmoid' # could be None for logits or 'softmax2d' for multiclass segmentation
|
||
DEVICE = 'cuda'
|
||
|
||
# 按照权重预训练的相同方法准备数据
|
||
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)
|
||
|
||
# 加载最佳模型
|
||
best_model = torch.load('./module/best_model.pth')
|
||
|
||
# 创建检测数据集
|
||
predict_dataset = Dataset(
|
||
x_test_dir,
|
||
augmentation=get_validation_augmentation(),
|
||
preprocessing=get_preprocessing(preprocessing_fn),
|
||
classes=CLASSES,
|
||
)
|
||
|
||
# # 对检测图像进行图像分割并进行图像可视化展示
|
||
# predict_dataset_vis = Dataset(
|
||
# x_test_dir,
|
||
# classes=CLASSES,
|
||
# )
|
||
|
||
for i in range(len(predict_dataset)):
|
||
# 原始图像image_vis
|
||
#image_vis = predict_dataset_vis[i].astype('uint8')
|
||
image = predict_dataset[i]
|
||
print('>>>>> {}>>>size{}'.format(type(image), image.shape))
|
||
|
||
# 通过图像分割得到的0-1图像pr_mask
|
||
T1 = time.time()
|
||
x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0)
|
||
pr_mask = best_model.predict(x_tensor)
|
||
T2 = time.time()
|
||
print('>>>>>> {}'.format(T2-T1))
|
||
pr_mask = (pr_mask.squeeze().cpu().numpy().round())
|
||
|
||
# 恢复图片原来的分辨率
|
||
#image_vis = cv2.resize(image_vis, (weight, height))
|
||
#pr_mask = cv2.resize(pr_mask, (weight, height))
|
||
pr_mask = cv2.resize(pr_mask[0,:,:], (weight, height))
|
||
# 保存图像分割后的黑白结果图像
|
||
imageio.imwrite('f_test_out.png', pr_mask)
|
||
# 原始图像和图像分割结果的可视化展示
|
||
# visualize(
|
||
# image=image_vis,
|
||
# predicted_mask=pr_mask
|
||
# )
|
||
|