Files
ieemoo-ai-isempty/utils/data_utils.py
2022-09-27 02:12:39 +00:00

170 lines
9.7 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from PIL import Image
import os
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler
from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017, emptyJudge
from .autoaugment import AutoAugImageNetPolicy
logger = logging.getLogger(__name__)


def get_loader_new(*, root='./emptyJudge5', batch_size=8, num_workers=4,
                   train_resize=600, crop_size=320):
    """Build train/test DataLoaders for the emptyJudge dataset without an ``args`` object.

    Every keyword argument defaults to the value that used to be hard-coded here,
    so calling ``get_loader_new()`` with no arguments behaves as before.

    Args:
        root: dataset root directory passed to ``emptyJudge``.
        batch_size: batch size used by both loaders.
        num_workers: number of DataLoader worker processes per loader.
        train_resize: side length training images are resized to before cropping.
        crop_size: side length of the random training crop; test images are
            resized directly to this size.

    Returns:
        ``(train_loader, test_loader)``. The train loader samples randomly and
        drops the last incomplete batch; the test loader iterates sequentially.
    """
    train_transform = transforms.Compose([
        transforms.Resize((train_resize, train_resize), Image.BILINEAR),
        transforms.RandomCrop((crop_size, crop_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # NOTE(review): training crops a crop_size patch out of a train_resize image,
    # while testing squeezes the whole image down to crop_size, so object scale
    # differs between train and test -- confirm this is intended.
    test_transform = transforms.Compose([
        transforms.Resize((crop_size, crop_size), Image.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    trainset = emptyJudge(root=root, is_train=True, transform=train_transform)
    testset = emptyJudge(root=root, is_train=False, transform=test_transform)
    train_loader = DataLoader(trainset,
                              sampler=RandomSampler(trainset),
                              batch_size=batch_size,
                              num_workers=num_workers,
                              drop_last=True,
                              pin_memory=True)
    test_loader = DataLoader(testset,
                             sampler=SequentialSampler(testset),
                             batch_size=batch_size,
                             num_workers=num_workers,
                             pin_memory=True)
    print('emptyJudge5 getdataloader ok!')
    return train_loader, test_loader
#根据不同数据集获取data_loader
def get_loader(args):
if args.local_rank not in [-1, 0]:
torch.distributed.barrier()
if args.dataset == 'CUB_200_2011':
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.CenterCrop((448, 448)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = CUB(root=args.data_root, is_train=True, transform=train_transform)
testset = CUB(root=args.data_root, is_train=False, transform=test_transform)
elif args.dataset == 'car':
trainset = CarsDataset(os.path.join(args.data_root,'devkit/cars_train_annos.mat'),
os.path.join(args.data_root,'cars_train'),
os.path.join(args.data_root,'devkit/cars_meta.mat'),
# cleaned=os.path.join(data_dir,'cleaned.dat'),
transform=transforms.Compose([
transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
AutoAugImageNetPolicy(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
)
testset = CarsDataset(os.path.join(args.data_root,'cars_test_annos_withlabels.mat'),
os.path.join(args.data_root,'cars_test'),
os.path.join(args.data_root,'devkit/cars_meta.mat'),
# cleaned=os.path.join(data_dir,'cleaned_test.dat'),
transform=transforms.Compose([
transforms.Resize((600, 600), Image.BILINEAR),
transforms.CenterCrop((448, 448)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
)
elif args.dataset == 'dog':
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.CenterCrop((448, 448)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = dogs(root=args.data_root,
train=True,
cropped=False,
transform=train_transform,
download=False
)
testset = dogs(root=args.data_root,
train=False,
cropped=False,
transform=test_transform,
download=False
)
elif args.dataset == 'nabirds':
train_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((448, 448)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform=transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.CenterCrop((448, 448)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = NABirds(root=args.data_root, train=True, transform=train_transform)
testset = NABirds(root=args.data_root, train=False, transform=test_transform)
elif args.dataset == 'INat2017':
train_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
transforms.RandomCrop((304, 304)),
transforms.RandomHorizontalFlip(),
AutoAugImageNetPolicy(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform=transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
transforms.CenterCrop((304, 304)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = INat2017(args.data_root, 'train', train_transform)
testset = INat2017(args.data_root, 'val', test_transform)
elif args.dataset == 'emptyJudge5':
train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.RandomCrop((600, 600)), #448
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
# transforms.CenterCrop((448, 448)),
# transforms.ToTensor(),
# transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
testset = emptyJudge(root=args.data_root, is_train=False, transform=test_transform)
if args.local_rank == 0:
torch.distributed.barrier()
train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset)
test_sampler = SequentialSampler(testset) if args.local_rank == -1 else DistributedSampler(testset)
train_loader = DataLoader(trainset,
sampler=train_sampler,
batch_size=args.train_batch_size,
num_workers=4,
drop_last=True,
pin_memory=True)
test_loader = DataLoader(testset,
sampler=test_sampler,
batch_size=args.eval_batch_size,
num_workers=4,
pin_memory=True) if testset is not None else None
print('emptyJudge5 getdataloader ok!')
return train_loader, test_loader