update utils/dataset.py.
This commit is contained in:
@ -5,7 +5,7 @@ from os.path import join
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy
|
import scipy
|
||||||
from scipy import io
|
from scipy import io
|
||||||
import scipy.misc
|
import imageio
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -16,7 +16,7 @@ from torchvision.datasets import VisionDataset
|
|||||||
from torchvision.datasets.folder import default_loader
|
from torchvision.datasets.folder import default_loader
|
||||||
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
|
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
|
||||||
|
|
||||||
|
#对各种数据集的底层读取
|
||||||
class emptyJudge():
|
class emptyJudge():
|
||||||
def __init__(self, root, is_train=True, data_len=None, transform=None):
|
def __init__(self, root, is_train=True, data_len=None, transform=None):
|
||||||
self.root = root
|
self.root = root
|
||||||
@ -37,12 +37,12 @@ class emptyJudge():
|
|||||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||||
train_file_list[:data_len]]
|
train_file_list[:data_len]]
|
||||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||||
self.train_imgname = [x for x in train_file_list[:data_len]]
|
self.train_imgname = [x for x in train_file_list[:data_len]]
|
||||||
if not self.is_train:
|
if not self.is_train:
|
||||||
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||||
test_file_list[:data_len]]
|
test_file_list[:data_len]]
|
||||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||||
self.test_imgname = [x for x in test_file_list[:data_len]]
|
self.test_imgname = [x for x in test_file_list[:data_len]]
|
||||||
@ -51,7 +51,7 @@ class emptyJudge():
|
|||||||
if self.is_train:
|
if self.is_train:
|
||||||
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
|
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
|
||||||
if len(img.shape) == 2:
|
if len(img.shape) == 2:
|
||||||
img = np.stack([img] * 3, 2)
|
img = np.stack([img] * 3, 2) #拼接为三维数组,[3,width,highth]
|
||||||
img = Image.fromarray(img, mode='RGB')
|
img = Image.fromarray(img, mode='RGB')
|
||||||
if self.transform is not None:
|
if self.transform is not None:
|
||||||
img = self.transform(img)
|
img = self.transform(img)
|
||||||
@ -91,12 +91,12 @@ class CUB():
|
|||||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||||
if self.is_train:
|
if self.is_train:
|
||||||
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||||
train_file_list[:data_len]]
|
train_file_list[:data_len]]
|
||||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||||
self.train_imgname = [x for x in train_file_list[:data_len]]
|
self.train_imgname = [x for x in train_file_list[:data_len]]
|
||||||
if not self.is_train:
|
if not self.is_train:
|
||||||
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||||
test_file_list[:data_len]]
|
test_file_list[:data_len]]
|
||||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||||
self.test_imgname = [x for x in test_file_list[:data_len]]
|
self.test_imgname = [x for x in test_file_list[:data_len]]
|
||||||
|
Reference in New Issue
Block a user