update utils/dataset.py.

This commit is contained in:
Brainway
2022-10-26 01:24:48 +00:00
committed by Gitee
parent a94d0f19e3
commit 5a8c6a5d2e

View File

@ -5,7 +5,7 @@ from os.path import join
import numpy as np
import scipy
from scipy import io
import scipy.misc
import imageio
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
@ -16,7 +16,7 @@ from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
#对各种数据集的底层读取
class emptyJudge():
def __init__(self, root, is_train=True, data_len=None, transform=None):
self.root = root
@ -37,12 +37,12 @@ class emptyJudge():
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
if self.is_train:
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
train_file_list[:data_len]]
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
self.train_imgname = [x for x in train_file_list[:data_len]]
if not self.is_train:
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
test_file_list[:data_len]]
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
self.test_imgname = [x for x in test_file_list[:data_len]]
@ -51,7 +51,7 @@ class emptyJudge():
if self.is_train:
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
if len(img.shape) == 2:
img = np.stack([img] * 3, 2)
img = np.stack([img] * 3, 2) #拼接为三维数组,[3,width,highth]
img = Image.fromarray(img, mode='RGB')
if self.transform is not None:
img = self.transform(img)
@ -91,12 +91,12 @@ class CUB():
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
if self.is_train:
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
train_file_list[:data_len]]
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
self.train_imgname = [x for x in train_file_list[:data_len]]
if not self.is_train:
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
test_file_list[:data_len]]
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
self.test_imgname = [x for x in test_file_list[:data_len]]