update utils/dataset.py.
This commit is contained in:
@ -5,7 +5,7 @@ from os.path import join
|
||||
import numpy as np
|
||||
import scipy
|
||||
from scipy import io
|
||||
import scipy.misc
|
||||
import imageio
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
@ -16,7 +16,7 @@ from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.folder import default_loader
|
||||
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
|
||||
|
||||
|
||||
#对各种数据集的底层读取
|
||||
class emptyJudge():
|
||||
def __init__(self, root, is_train=True, data_len=None, transform=None):
|
||||
self.root = root
|
||||
@ -37,12 +37,12 @@ class emptyJudge():
|
||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||
if self.is_train:
|
||||
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
train_file_list[:data_len]]
|
||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||
self.train_imgname = [x for x in train_file_list[:data_len]]
|
||||
if not self.is_train:
|
||||
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
test_file_list[:data_len]]
|
||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||
self.test_imgname = [x for x in test_file_list[:data_len]]
|
||||
@ -51,7 +51,7 @@ class emptyJudge():
|
||||
if self.is_train:
|
||||
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
|
||||
if len(img.shape) == 2:
|
||||
img = np.stack([img] * 3, 2)
|
||||
img = np.stack([img] * 3, 2) #拼接为三维数组,[3,width,highth]
|
||||
img = Image.fromarray(img, mode='RGB')
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
@ -91,12 +91,12 @@ class CUB():
|
||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||
if self.is_train:
|
||||
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
train_file_list[:data_len]]
|
||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||
self.train_imgname = [x for x in train_file_list[:data_len]]
|
||||
if not self.is_train:
|
||||
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
test_file_list[:data_len]]
|
||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||
self.test_imgname = [x for x in test_file_list[:data_len]]
|
||||
|
Reference in New Issue
Block a user