From 5a8c6a5d2ec95112d0af1669c1f743a3677b1ed6 Mon Sep 17 00:00:00 2001 From: Brainway Date: Wed, 26 Oct 2022 01:24:48 +0000 Subject: [PATCH] update utils/dataset.py. --- utils/dataset.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/utils/dataset.py b/utils/dataset.py index 7a06567..0381dc3 100755 --- a/utils/dataset.py +++ b/utils/dataset.py @@ -5,7 +5,7 @@ from os.path import join import numpy as np import scipy from scipy import io -import scipy.misc +import imageio from PIL import Image import pandas as pd import matplotlib.pyplot as plt @@ -16,7 +16,7 @@ from torchvision.datasets import VisionDataset from torchvision.datasets.folder import default_loader from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg - +#对各种数据集的底层读取 class emptyJudge(): def __init__(self, root, is_train=True, data_len=None, transform=None): self.root = root @@ -37,12 +37,12 @@ class emptyJudge(): train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i] test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i] if self.is_train: - self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in + self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in train_file_list[:data_len]] self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len] self.train_imgname = [x for x in train_file_list[:data_len]] if not self.is_train: - self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in + self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in test_file_list[:data_len]] self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len] self.test_imgname = [x for x in test_file_list[:data_len]] @@ -51,7 +51,7 @@ class emptyJudge(): if self.is_train: img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index] if len(img.shape) == 2: - img = np.stack([img] * 3, 2) + img = np.stack([img] * 3, 2) #拼接为三维数组,[3,width,highth] img = Image.fromarray(img, mode='RGB') if self.transform is not None: img = self.transform(img) @@ -91,12 +91,12 @@ class CUB(): train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i] test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i] if self.is_train: - self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in + self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in train_file_list[:data_len]] self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len] self.train_imgname = [x for x in train_file_list[:data_len]] if not self.is_train: - self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in + self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in test_file_list[:data_len]] self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len] self.test_imgname = [x for x in test_file_list[:data_len]]