From a94d0f19e3646b2d6a9f7cf1b699492439531430 Mon Sep 17 00:00:00 2001 From: Brainway Date: Wed, 26 Oct 2022 01:24:21 +0000 Subject: [PATCH] update utils/data_utils.py. --- utils/data_utils.py | 40 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 37 insertions(+), 3 deletions(-) diff --git a/utils/data_utils.py b/utils/data_utils.py index 9ecfe0c..4da4c0c 100755 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -13,6 +13,39 @@ from .autoaugment import AutoAugImageNetPolicy logger = logging.getLogger(__name__) +def get_loader_new(): + + train_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), + transforms.RandomCrop((320, 320)), #448 + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) + trainset = emptyJudge(root='../emptyJudge2', is_train=True, transform=train_transform) + testset = emptyJudge(root='../emptyJudge2', is_train=False, transform=test_transform) + + + train_sampler = RandomSampler(trainset) + test_sampler = SequentialSampler(testset) + train_loader = DataLoader(trainset, + sampler=train_sampler, + batch_size=8, + num_workers=4, + drop_last=True, + pin_memory=True) + test_loader = DataLoader(testset, + sampler=test_sampler, + batch_size=8, + num_workers=4, + pin_memory=True) if testset is not None else None + + print('emptyJudge5 getdataloader ok!') + return train_loader, test_loader + + +#根据不同数据集,获取data_loader def get_loader(args): if args.local_rank not in [-1, 0]: torch.distributed.barrier() @@ -99,9 +132,9 @@ def get_loader(args): transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) trainset = INat2017(args.data_root, 'train', train_transform) testset = INat2017(args.data_root, 'val', test_transform) - elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4': + elif args.dataset == 'emptyJudge2': train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), - transforms.RandomCrop((320, 320)), + transforms.RandomCrop((600, 600)), #448 transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) @@ -109,7 +142,7 @@ def get_loader(args): # transforms.CenterCrop((448, 448)), # transforms.ToTensor(), # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) - test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), + test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform) @@ -132,4 +165,5 @@ def get_loader(args): num_workers=4, pin_memory=True) if testset is not None else None + print('emptyJudge2 getdataloader ok!') return train_loader, test_loader