update
This commit is contained in:
@ -101,7 +101,7 @@ def get_loader(args):
|
||||
testset = INat2017(args.data_root, 'val', test_transform)
|
||||
elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4':
|
||||
train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.RandomCrop((320, 320)),
|
||||
transforms.RandomCrop((448, 448)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
@ -109,7 +109,7 @@ def get_loader(args):
|
||||
# transforms.CenterCrop((448, 448)),
|
||||
# transforms.ToTensor(),
|
||||
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR),
|
||||
test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
|
||||
|
@ -5,7 +5,7 @@ from os.path import join
|
||||
import numpy as np
|
||||
import scipy
|
||||
from scipy import io
|
||||
import imageio
|
||||
import scipy.misc
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
@ -16,7 +16,7 @@ from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.folder import default_loader
|
||||
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
|
||||
|
||||
#对各种数据集的底层读取
|
||||
|
||||
class emptyJudge():
|
||||
def __init__(self, root, is_train=True, data_len=None, transform=None):
|
||||
self.root = root
|
||||
@ -37,12 +37,12 @@ class emptyJudge():
|
||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||
if self.is_train:
|
||||
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
train_file_list[:data_len]]
|
||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||
self.train_imgname = [x for x in train_file_list[:data_len]]
|
||||
if not self.is_train:
|
||||
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
test_file_list[:data_len]]
|
||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||
self.test_imgname = [x for x in test_file_list[:data_len]]
|
||||
@ -51,7 +51,7 @@ class emptyJudge():
|
||||
if self.is_train:
|
||||
img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
|
||||
if len(img.shape) == 2:
|
||||
img = np.stack([img] * 3, 2) #拼接为三维数组,[3,width,highth]
|
||||
img = np.stack([img] * 3, 2)
|
||||
img = Image.fromarray(img, mode='RGB')
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
@ -91,12 +91,12 @@ class CUB():
|
||||
train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
|
||||
test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
|
||||
if self.is_train:
|
||||
self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
self.train_img = [scipy.misc.imread(os.path.join(self.root, 'images', train_file)) for train_file in
|
||||
train_file_list[:data_len]]
|
||||
self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
|
||||
self.train_imgname = [x for x in train_file_list[:data_len]]
|
||||
if not self.is_train:
|
||||
self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
self.test_img = [scipy.misc.imread(os.path.join(self.root, 'images', test_file)) for test_file in
|
||||
test_file_list[:data_len]]
|
||||
self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
|
||||
self.test_imgname = [x for x in test_file_list[:data_len]]
|
||||
|
Reference in New Issue
Block a user