update utils/data_utils.py.

This commit is contained in:
Brainway
2022-10-18 03:36:28 +00:00
committed by Gitee
parent 0c2e0dccac
commit 8de11a03d8

View File

@ -13,39 +13,6 @@ from .autoaugment import AutoAugImageNetPolicy
logger = logging.getLogger(__name__)
def get_loader_new():
train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((320, 320)), #448
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = emptyJudge(root='./emptyJudge5', is_train=True, transform=train_transform)
testset = emptyJudge(root='./emptyJudge5', is_train=False, transform=test_transform)
train_sampler = RandomSampler(trainset)
test_sampler = SequentialSampler(testset)
train_loader = DataLoader(trainset,
sampler=train_sampler,
batch_size=8,
num_workers=4,
drop_last=True,
pin_memory=True)
test_loader = DataLoader(testset,
sampler=test_sampler,
batch_size=8,
num_workers=4,
pin_memory=True) if testset is not None else None
print('emptyJudge5 getdataloader ok!')
return train_loader, test_loader
#根据不同数据集获取data_loader
def get_loader(args):
if args.local_rank not in [-1, 0]:
torch.distributed.barrier()
@ -132,9 +99,9 @@ def get_loader(args):
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = INat2017(args.data_root, 'train', train_transform)
testset = INat2017(args.data_root, 'val', test_transform)
elif args.dataset == 'emptyJudge5':
elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4':
train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((600, 600)), #448
transforms.RandomCrop((320, 320)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
@ -142,7 +109,7 @@ def get_loader(args):
# transforms.CenterCrop((448, 448)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
@ -165,5 +132,4 @@ def get_loader(args):
num_workers=4,
pin_memory=True) if testset is not None else None
print('emptyJudge5 getdataloader ok!')
return train_loader, test_loader