update utils/dataset.py.

This commit is contained in:
Brainway
2022-10-26 01:24:48 +00:00
committed by Gitee
parent a94d0f19e3
commit 5a8c6a5d2e

View File

@ -5,7 +5,7 @@ from os.path import join
import numpy as np import numpy as np
import scipy import scipy
from scipy import io from scipy import io
import scipy.misc import imageio
from PIL import Image from PIL import Image
import pandas as pd import pandas as pd
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
@ -16,7 +16,7 @@ from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
#对各种数据集的底层读取
class emptyJudge(): class emptyJudge():
def __init__(self, root, is_train=True, data_len=None, transform=None): def __init__(self, root, is_train=True, data_len=None, transform=None):
self.root = root self.root = root
@ -37,12 +37,12 @@ class emptyJudge():
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i] train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i] test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
if self.is_train: if self.is_train:
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
train_file_list[:data_len]] train_file_list[:data_len]]
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len] self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
self.train_imgname = [x for x in train_file_list[:data_len]] self.train_imgname = [x for x in train_file_list[:data_len]]
if not self.is_train: if not self.is_train:
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
test_file_list[:data_len]] test_file_list[:data_len]]
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len] self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
self.test_imgname = [x for x in test_file_list[:data_len]] self.test_imgname = [x for x in test_file_list[:data_len]]
@ -51,7 +51,7 @@ class emptyJudge():
if self.is_train: if self.is_train:
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index] img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
if len(img.shape) == 2: if len(img.shape) == 2:
img = np.stack([img] * 3, 2) img = np.stack([img] * 3, 2) #拼接为三维数组,[3,width,highth]
img = Image.fromarray(img, mode='RGB') img = Image.fromarray(img, mode='RGB')
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
@ -91,12 +91,12 @@ class CUB():
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i] train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i] test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
if self.is_train: if self.is_train:
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
train_file_list[:data_len]] train_file_list[:data_len]]
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len] self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
self.train_imgname = [x for x in train_file_list[:data_len]] self.train_imgname = [x for x in train_file_list[:data_len]]
if not self.is_train: if not self.is_train:
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
test_file_list[:data_len]] test_file_list[:data_len]]
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len] self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
self.test_imgname = [x for x in test_file_list[:data_len]] self.test_imgname = [x for x in test_file_list[:data_len]]