diff --git a/utils/data_utils.py b/utils/data_utils.py index 0507615..9ecfe0c 100755 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -13,39 +13,6 @@ from .autoaugment import AutoAugImageNetPolicy logger = logging.getLogger(__name__) -def get_loader_new(): - - train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), - transforms.RandomCrop((320, 320)), #448 - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) - test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) - trainset = emptyJudge(root='./emptyJudge5', is_train=True, transform=train_transform) - testset = emptyJudge(root='./emptyJudge5', is_train=False, transform=test_transform) - - - train_sampler = RandomSampler(trainset) - test_sampler = SequentialSampler(testset) - train_loader = DataLoader(trainset, - sampler=train_sampler, - batch_size=8, - num_workers=4, - drop_last=True, - pin_memory=True) - test_loader = DataLoader(testset, - sampler=test_sampler, - batch_size=8, - num_workers=4, - pin_memory=True) if testset is not None else None - - print('emptyJudge5 getdataloader ok!') - return train_loader, test_loader - - -#根据不同数据集,获取data_loader def get_loader(args): if args.local_rank not in [-1, 0]: torch.distributed.barrier() @@ -132,9 +99,9 @@ def get_loader(args): transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) trainset = INat2017(args.data_root, 'train', train_transform) testset = INat2017(args.data_root, 'val', test_transform) - elif args.dataset == 'emptyJudge5': + elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4': train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), - transforms.RandomCrop((600, 600)), #448 + transforms.RandomCrop((320, 320)), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) @@ -142,7 +109,7 @@ def get_loader(args): # transforms.CenterCrop((448, 448)), # transforms.ToTensor(), # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) - test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR), + test_transform = transforms.Compose([transforms.Resize((320, 320), Image.BILINEAR), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform) @@ -165,5 +132,4 @@ def get_loader(args): num_workers=4, pin_memory=True) if testset is not None else None - print('emptyJudge5 getdataloader ok!') return train_loader, test_loader