update
This commit is contained in:
135
utils/data_utils.py
Executable file
135
utils/data_utils.py
Executable file
@ -0,0 +1,135 @@
|
||||
import logging
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler
|
||||
|
||||
from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017, emptyJudge
|
||||
from .autoaugment import AutoAugImageNetPolicy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_loader(args):
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier()
|
||||
|
||||
if args.dataset == 'CUB_200_2011':
|
||||
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.RandomCrop((448, 448)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.CenterCrop((448, 448)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
trainset = CUB(root=args.data_root, is_train=True, transform=train_transform)
|
||||
testset = CUB(root=args.data_root, is_train=False, transform=test_transform)
|
||||
elif args.dataset == 'car':
|
||||
trainset = CarsDataset(os.path.join(args.data_root,'devkit/cars_train_annos.mat'),
|
||||
os.path.join(args.data_root,'cars_train'),
|
||||
os.path.join(args.data_root,'devkit/cars_meta.mat'),
|
||||
# cleaned=os.path.join(data_dir,'cleaned.dat'),
|
||||
transform=transforms.Compose([
|
||||
transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.RandomCrop((448, 448)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
AutoAugImageNetPolicy(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
)
|
||||
testset = CarsDataset(os.path.join(args.data_root,'cars_test_annos_withlabels.mat'),
|
||||
os.path.join(args.data_root,'cars_test'),
|
||||
os.path.join(args.data_root,'devkit/cars_meta.mat'),
|
||||
# cleaned=os.path.join(data_dir,'cleaned_test.dat'),
|
||||
transform=transforms.Compose([
|
||||
transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.CenterCrop((448, 448)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
)
|
||||
elif args.dataset == 'dog':
|
||||
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.RandomCrop((448, 448)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.CenterCrop((448, 448)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
trainset = dogs(root=args.data_root,
|
||||
train=True,
|
||||
cropped=False,
|
||||
transform=train_transform,
|
||||
download=False
|
||||
)
|
||||
testset = dogs(root=args.data_root,
|
||||
train=False,
|
||||
cropped=False,
|
||||
transform=test_transform,
|
||||
download=False
|
||||
)
|
||||
elif args.dataset == 'nabirds':
|
||||
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.RandomCrop((448, 448)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.CenterCrop((448, 448)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
trainset = NABirds(root=args.data_root, train=True, transform=train_transform)
|
||||
testset = NABirds(root=args.data_root, train=False, transform=test_transform)
|
||||
elif args.dataset == 'INat2017':
|
||||
train_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
|
||||
transforms.RandomCrop((304, 304)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
AutoAugImageNetPolicy(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
test_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
|
||||
transforms.CenterCrop((304, 304)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
trainset = INat2017(args.data_root, 'train', train_transform)
|
||||
testset = INat2017(args.data_root, 'val', test_transform)
|
||||
elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4':
|
||||
train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
transforms.RandomCrop((448, 448)),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
# test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
|
||||
# transforms.CenterCrop((448, 448)),
|
||||
# transforms.ToTensor(),
|
||||
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
|
||||
trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
|
||||
testset = emptyJudge(root=args.data_root, is_train=False, transform=test_transform)
|
||||
|
||||
if args.local_rank == 0:
|
||||
torch.distributed.barrier()
|
||||
|
||||
train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset)
|
||||
test_sampler = SequentialSampler(testset) if args.local_rank == -1 else DistributedSampler(testset)
|
||||
train_loader = DataLoader(trainset,
|
||||
sampler=train_sampler,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=4,
|
||||
drop_last=True,
|
||||
pin_memory=True)
|
||||
test_loader = DataLoader(testset,
|
||||
sampler=test_sampler,
|
||||
batch_size=args.eval_batch_size,
|
||||
num_workers=4,
|
||||
pin_memory=True) if testset is not None else None
|
||||
|
||||
return train_loader, test_loader
|
Reference in New Issue
Block a user