update
This commit is contained in:
0
utils/__init__.py
Executable file
0
utils/__init__.py
Executable file
204
utils/autoaugment.py
Executable file
204
utils/autoaugment.py
Executable file
@ -0,0 +1,204 @@
|
||||
"""
|
||||
Copy from https://github.com/DeepVoltaire/AutoAugment/blob/master/autoaugment.py
|
||||
"""
|
||||
|
||||
from PIL import Image, ImageEnhance, ImageOps
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
__all__ = ['AutoAugImageNetPolicy', 'AutoAugCIFAR10Policy', 'AutoAugSVHNPolicy']
|
||||
|
||||
|
||||
class AutoAugImageNetPolicy(object):
    """Randomly applies one of 24 ImageNet AutoAugment sub-policies.

    Each sub-policy chains two image operations, each fired with its own
    probability at a fixed magnitude (policies from the AutoAugment paper's
    ImageNet search).
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each spec is (prob1, op1, magnitude_idx1, prob2, op2, magnitude_idx2).
        policy_specs = [
            (0.4, "posterize", 8, 0.6, "rotate", 9),
            (0.6, "solarize", 5, 0.6, "autocontrast", 5),
            (0.8, "equalize", 8, 0.6, "equalize", 3),
            (0.6, "posterize", 7, 0.6, "posterize", 6),
            (0.4, "equalize", 7, 0.2, "solarize", 4),

            (0.4, "equalize", 4, 0.8, "rotate", 8),
            (0.6, "solarize", 3, 0.6, "equalize", 7),
            (0.8, "posterize", 5, 1.0, "equalize", 2),
            (0.2, "rotate", 3, 0.6, "solarize", 8),
            (0.6, "equalize", 8, 0.4, "posterize", 6),

            (0.8, "rotate", 8, 0.4, "color", 0),
            (0.4, "rotate", 9, 0.6, "equalize", 2),
            (0.0, "equalize", 7, 0.8, "equalize", 8),
            (0.6, "invert", 4, 1.0, "equalize", 8),
            (0.6, "color", 4, 1.0, "contrast", 8),

            (0.8, "rotate", 8, 1.0, "color", 2),
            (0.8, "color", 8, 0.8, "solarize", 7),
            (0.4, "sharpness", 7, 0.6, "invert", 8),
            (0.6, "shearX", 5, 1.0, "equalize", 9),
            (0.4, "color", 0, 0.6, "equalize", 3),

            (0.4, "equalize", 7, 0.2, "solarize", 4),
            (0.6, "solarize", 5, 0.6, "autocontrast", 5),
            (0.6, "invert", 4, 1.0, "equalize", 8),
            (0.6, "color", 4, 1.0, "contrast", 8),
        ]
        self.policies = [SubPolicy(*spec, fillcolor) for spec in policy_specs]

    def __call__(self, img):
        """Apply one uniformly chosen sub-policy to *img* and return the result."""
        return random.choice(self.policies)(img)

    def __repr__(self):
        return "AutoAugment ImageNet Policy"
||||
|
||||
|
||||
class AutoAugCIFAR10Policy(object):
    """Randomly applies one of 25 CIFAR-10 AutoAugment sub-policies.

    Each sub-policy chains two image operations, each fired with its own
    probability at a fixed magnitude (policies from the AutoAugment paper's
    CIFAR-10 search).
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each spec is (prob1, op1, magnitude_idx1, prob2, op2, magnitude_idx2).
        policy_specs = [
            (0.1, "invert", 7, 0.2, "contrast", 6),
            (0.7, "rotate", 2, 0.3, "translateX", 9),
            (0.8, "sharpness", 1, 0.9, "sharpness", 3),
            (0.5, "shearY", 8, 0.7, "translateY", 9),
            (0.5, "autocontrast", 8, 0.9, "equalize", 2),

            (0.2, "shearY", 7, 0.3, "posterize", 7),
            (0.4, "color", 3, 0.6, "brightness", 7),
            (0.3, "sharpness", 9, 0.7, "brightness", 9),
            (0.6, "equalize", 5, 0.5, "equalize", 1),
            (0.6, "contrast", 7, 0.6, "sharpness", 5),

            (0.7, "color", 7, 0.5, "translateX", 8),
            (0.3, "equalize", 7, 0.4, "autocontrast", 8),
            (0.4, "translateY", 3, 0.2, "sharpness", 6),
            (0.9, "brightness", 6, 0.2, "color", 8),
            (0.5, "solarize", 2, 0.0, "invert", 3),

            (0.2, "equalize", 0, 0.6, "autocontrast", 0),
            (0.2, "equalize", 8, 0.8, "equalize", 4),
            (0.9, "color", 9, 0.6, "equalize", 6),
            (0.8, "autocontrast", 4, 0.2, "solarize", 8),
            (0.1, "brightness", 3, 0.7, "color", 0),

            (0.4, "solarize", 5, 0.9, "autocontrast", 3),
            (0.9, "translateY", 9, 0.7, "translateY", 9),
            (0.9, "autocontrast", 2, 0.8, "solarize", 3),
            (0.8, "equalize", 8, 0.1, "invert", 3),
            (0.7, "translateY", 9, 0.9, "autocontrast", 1),
        ]
        self.policies = [SubPolicy(*spec, fillcolor) for spec in policy_specs]

    def __call__(self, img):
        """Apply one uniformly chosen sub-policy to *img* and return the result."""
        return random.choice(self.policies)(img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"
||||
|
||||
|
||||
class AutoAugSVHNPolicy(object):
    """Randomly applies one of 25 SVHN AutoAugment sub-policies.

    Each sub-policy chains two image operations, each fired with its own
    probability at a fixed magnitude (policies from the AutoAugment paper's
    SVHN search).
    """

    def __init__(self, fillcolor=(128, 128, 128)):
        # Each spec is (prob1, op1, magnitude_idx1, prob2, op2, magnitude_idx2).
        policy_specs = [
            (0.9, "shearX", 4, 0.2, "invert", 3),
            (0.9, "shearY", 8, 0.7, "invert", 5),
            (0.6, "equalize", 5, 0.6, "solarize", 6),
            (0.9, "invert", 3, 0.6, "equalize", 3),
            (0.6, "equalize", 1, 0.9, "rotate", 3),

            (0.9, "shearX", 4, 0.8, "autocontrast", 3),
            (0.9, "shearY", 8, 0.4, "invert", 5),
            (0.9, "shearY", 5, 0.2, "solarize", 6),
            (0.9, "invert", 6, 0.8, "autocontrast", 1),
            (0.6, "equalize", 3, 0.9, "rotate", 3),

            (0.9, "shearX", 4, 0.3, "solarize", 3),
            (0.8, "shearY", 8, 0.7, "invert", 4),
            (0.9, "equalize", 5, 0.6, "translateY", 6),
            (0.9, "invert", 4, 0.6, "equalize", 7),
            (0.3, "contrast", 3, 0.8, "rotate", 4),

            (0.8, "invert", 5, 0.0, "translateY", 2),
            (0.7, "shearY", 6, 0.4, "solarize", 8),
            (0.6, "invert", 4, 0.8, "rotate", 4),
            (0.3, "shearY", 7, 0.9, "translateX", 3),
            (0.1, "shearX", 6, 0.6, "invert", 5),

            (0.7, "solarize", 2, 0.6, "translateY", 7),
            (0.8, "shearY", 4, 0.8, "invert", 8),
            (0.7, "shearX", 9, 0.8, "translateY", 3),
            (0.8, "shearY", 5, 0.7, "autocontrast", 3),
            (0.7, "shearX", 2, 0.1, "invert", 5),
        ]
        self.policies = [SubPolicy(*spec, fillcolor) for spec in policy_specs]

    def __call__(self, img):
        """Apply one uniformly chosen sub-policy to *img* and return the result."""
        return random.choice(self.policies)(img)

    def __repr__(self):
        return "AutoAugment SVHN Policy"
||||
|
||||
|
||||
class SubPolicy(object):
    """One AutoAugment sub-policy: two chained PIL image operations.

    Each operation fires independently with probability ``p1``/``p2`` and,
    when it fires, is applied at a fixed strength taken from a 10-step
    magnitude table indexed by ``magnitude_idx1``/``magnitude_idx2``
    (0 = weakest, 9 = strongest).
    """

    def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)):
        """
        Args:
            p1, p2 (float): Probability of applying operation 1 / operation 2.
            operation1, operation2 (str): Operation names (keys of ``func`` below).
            magnitude_idx1, magnitude_idx2 (int): Index 0-9 into the magnitude table.
            fillcolor (tuple): RGB fill for the area exposed by affine transforms.
        """
        # Magnitude lookup tables: 10 evenly spaced strengths per operation.
        ranges = {
            "shearX": np.linspace(0, 0.3, 10),
            "shearY": np.linspace(0, 0.3, 10),
            "translateX": np.linspace(0, 150 / 331, 10),
            "translateY": np.linspace(0, 150 / 331, 10),
            "rotate": np.linspace(0, 30, 10),
            "color": np.linspace(0.0, 0.9, 10),
            # BUGFIX: use the builtin ``int`` -- ``np.int`` was deprecated in
            # NumPy 1.20 and removed in 1.24, where it raises AttributeError.
            "posterize": np.round(np.linspace(8, 4, 10), 0).astype(int),
            "solarize": np.linspace(256, 0, 10),
            "contrast": np.linspace(0.0, 0.9, 10),
            "sharpness": np.linspace(0.0, 0.9, 10),
            "brightness": np.linspace(0.0, 0.9, 10),
            # These three ops ignore magnitude; dummy zeros keep indexing uniform.
            "autocontrast": [0] * 10,
            "equalize": [0] * 10,
            "invert": [0] * 10
        }

        def rotate_with_fill(img, magnitude):
            # Rotate and composite over a grey (128) background so the exposed
            # corners are grey instead of black.
            rot = img.convert("RGBA").rotate(magnitude)
            return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode)

        # Dispatch table: operation name -> callable(img, magnitude) -> img.
        # Sign of shear/translate/enhance magnitudes is randomized per call.
        func = {
            "shearX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "shearY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0),
                Image.BICUBIC, fillcolor=fillcolor),
            "translateX": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0),
                fillcolor=fillcolor),
            "translateY": lambda img, magnitude: img.transform(
                img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])),
                fillcolor=fillcolor),
            "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude),
            # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])),
            "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])),
            "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude),
            "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude),
            "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance(
                1 + magnitude * random.choice([-1, 1])),
            "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img),
            "equalize": lambda img, magnitude: ImageOps.equalize(img),
            "invert": lambda img, magnitude: ImageOps.invert(img)
        }

        # self.name = "{}_{:.2f}_and_{}_{:.2f}".format(
        #     operation1, ranges[operation1][magnitude_idx1],
        #     operation2, ranges[operation2][magnitude_idx2])
        self.p1 = p1
        self.operation1 = func[operation1]
        self.magnitude1 = ranges[operation1][magnitude_idx1]
        self.p2 = p2
        self.operation2 = func[operation2]
        self.magnitude2 = ranges[operation2][magnitude_idx2]

    def __call__(self, img):
        """Apply operation 1 then operation 2 to *img*, each with its own probability."""
        if random.random() < self.p1:
            img = self.operation1(img, self.magnitude1)
        if random.random() < self.p2:
            img = self.operation2(img, self.magnitude2)
        return img
|
135
utils/data_utils.py
Executable file
135
utils/data_utils.py
Executable file
@ -0,0 +1,135 @@
|
||||
import logging
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler
|
||||
|
||||
from .dataset import CUB, CarsDataset, NABirds, dogs, INat2017, emptyJudge
|
||||
from .autoaugment import AutoAugImageNetPolicy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_loader(args):
    """Build the train/test DataLoaders selected by ``args.dataset``.

    Supported datasets: ``CUB_200_2011``, ``car`` (Stanford Cars), ``dog``
    (Stanford Dogs), ``nabirds``, ``INat2017``, ``emptyJudge5``/``emptyJudge4``.
    Each branch builds dataset-specific train/test transforms (resize,
    crop, flip, ImageNet-normalize -- plus AutoAugment for car/INat2017
    training).

    Args:
        args: Namespace with ``dataset``, ``data_root``, ``local_rank``,
            ``train_batch_size`` and ``eval_batch_size``.

    Returns:
        ``(train_loader, test_loader)`` tuple of DataLoaders.

    Raises:
        ValueError: if ``args.dataset`` is not one of the names above.
    """
    # In distributed mode, let rank 0 prepare the data first while the
    # other ranks wait at the barrier (and vice versa below).
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()

    if args.dataset == 'CUB_200_2011':
        train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                              transforms.RandomCrop((448, 448)),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                             transforms.CenterCrop((448, 448)),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = CUB(root=args.data_root, is_train=True, transform=train_transform)
        testset = CUB(root=args.data_root, is_train=False, transform=test_transform)
    elif args.dataset == 'car':
        trainset = CarsDataset(os.path.join(args.data_root, 'devkit/cars_train_annos.mat'),
                               os.path.join(args.data_root, 'cars_train'),
                               os.path.join(args.data_root, 'devkit/cars_meta.mat'),
                               # cleaned=os.path.join(data_dir,'cleaned.dat'),
                               transform=transforms.Compose([
                                   transforms.Resize((600, 600), Image.BILINEAR),
                                   transforms.RandomCrop((448, 448)),
                                   transforms.RandomHorizontalFlip(),
                                   AutoAugImageNetPolicy(),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
                               )
        testset = CarsDataset(os.path.join(args.data_root, 'cars_test_annos_withlabels.mat'),
                              os.path.join(args.data_root, 'cars_test'),
                              os.path.join(args.data_root, 'devkit/cars_meta.mat'),
                              # cleaned=os.path.join(data_dir,'cleaned_test.dat'),
                              transform=transforms.Compose([
                                  transforms.Resize((600, 600), Image.BILINEAR),
                                  transforms.CenterCrop((448, 448)),
                                  transforms.ToTensor(),
                                  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
                              )
    elif args.dataset == 'dog':
        train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                              transforms.RandomCrop((448, 448)),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                             transforms.CenterCrop((448, 448)),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = dogs(root=args.data_root,
                        train=True,
                        cropped=False,
                        transform=train_transform,
                        download=False
                        )
        testset = dogs(root=args.data_root,
                       train=False,
                       cropped=False,
                       transform=test_transform,
                       download=False
                       )
    elif args.dataset == 'nabirds':
        train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                              transforms.RandomCrop((448, 448)),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                             transforms.CenterCrop((448, 448)),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = NABirds(root=args.data_root, train=True, transform=train_transform)
        testset = NABirds(root=args.data_root, train=False, transform=test_transform)
    elif args.dataset == 'INat2017':
        # INat2017 uses smaller crops (304) than the 448 used elsewhere.
        train_transform = transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
                                              transforms.RandomCrop((304, 304)),
                                              transforms.RandomHorizontalFlip(),
                                              AutoAugImageNetPolicy(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        test_transform = transforms.Compose([transforms.Resize((400, 400), Image.BILINEAR),
                                             transforms.CenterCrop((304, 304)),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = INat2017(args.data_root, 'train', train_transform)
        testset = INat2017(args.data_root, 'val', test_transform)
    elif args.dataset == 'emptyJudge5' or args.dataset == 'emptyJudge4':
        train_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
                                              transforms.RandomCrop((448, 448)),
                                              transforms.RandomHorizontalFlip(),
                                              transforms.ToTensor(),
                                              transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        # test_transform = transforms.Compose([transforms.Resize((600, 600), Image.BILINEAR),
        #                                      transforms.CenterCrop((448, 448)),
        #                                      transforms.ToTensor(),
        #                                      transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        # Eval resizes straight to 448x448 with no center crop (unlike CUB).
        test_transform = transforms.Compose([transforms.Resize((448, 448), Image.BILINEAR),
                                             transforms.ToTensor(),
                                             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        trainset = emptyJudge(root=args.data_root, is_train=True, transform=train_transform)
        testset = emptyJudge(root=args.data_root, is_train=False, transform=test_transform)
    else:
        # BUGFIX: the original fell through every branch for an unknown name
        # and crashed later with a NameError on ``trainset``; fail fast instead.
        raise ValueError("Unsupported dataset: {}".format(args.dataset))

    if args.local_rank == 0:
        torch.distributed.barrier()

    # Distributed training shards the data across ranks; single-process
    # training shuffles with RandomSampler and evaluates sequentially.
    train_sampler = RandomSampler(trainset) if args.local_rank == -1 else DistributedSampler(trainset)
    test_sampler = SequentialSampler(testset) if args.local_rank == -1 else DistributedSampler(testset)
    train_loader = DataLoader(trainset,
                              sampler=train_sampler,
                              batch_size=args.train_batch_size,
                              num_workers=4,
                              drop_last=True,
                              pin_memory=True)
    test_loader = DataLoader(testset,
                             sampler=test_sampler,
                             batch_size=args.eval_batch_size,
                             num_workers=4,
                             pin_memory=True) if testset is not None else None

    return train_loader, test_loader
|
629
utils/dataset.py
Executable file
629
utils/dataset.py
Executable file
@ -0,0 +1,629 @@
|
||||
import os
|
||||
import json
|
||||
from os.path import join
|
||||
|
||||
import numpy as np
|
||||
import scipy
|
||||
from scipy import io
|
||||
import scipy.misc
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.datasets.folder import default_loader
|
||||
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg
|
||||
|
||||
|
||||
class emptyJudge():
    """CUB-style "empty judge" classification dataset.

    Expects ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` under *root* (each line: ``"<id> <value>"``),
    with the image files under ``<root>/images``.  All images of the
    selected split are decoded eagerly into memory as uint8 ndarrays.
    """

    def __init__(self, root, is_train=True, data_len=None, transform=None):
        """
        Args:
            root (string): Dataset root directory.
            is_train (bool): Select the train (True) or test (False) split.
            data_len (int, optional): Keep only the first ``data_len`` samples.
            transform (callable, optional): Transform applied to each PIL image.
        """
        self.root = root
        self.is_train = is_train
        self.transform = transform
        # Use ``with`` so the metadata file handles are closed promptly
        # (the original leaked all three handles).
        with open(os.path.join(self.root, 'images.txt')) as img_txt_file:
            img_name_list = [line[:-1].split(' ')[-1] for line in img_txt_file]
        with open(os.path.join(self.root, 'image_class_labels.txt')) as label_txt_file:
            # Labels in the file are 1-based; store them 0-based.
            label_list = [int(line[:-1].split(' ')[-1]) - 1 for line in label_txt_file]
        with open(os.path.join(self.root, 'train_test_split.txt')) as train_val_file:
            train_test_list = [int(line[:-1].split(' ')[-1]) for line in train_val_file]
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        # BUGFIX: scipy.misc.imread was removed in SciPy 1.2; PIL + np.asarray
        # yields the same HxW / HxWxC uint8 arrays for greyscale/RGB files.
        if self.is_train:
            self.train_img = [np.asarray(Image.open(os.path.join(self.root, 'images', train_file)))
                              for train_file in train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
            self.train_imgname = [x for x in train_file_list[:data_len]]
        if not self.is_train:
            self.test_img = [np.asarray(Image.open(os.path.join(self.root, 'images', test_file)))
                             for test_file in test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
            self.test_imgname = [x for x in test_file_list[:data_len]]

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index* in the selected split."""
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
        else:
            img, target = self.test_img[index], self.test_label[index]
        # Greyscale images are replicated to 3 channels before RGB conversion.
        if len(img.shape) == 2:
            img = np.stack([img] * 3, 2)
        img = Image.fromarray(img, mode='RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Number of samples in the selected split."""
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)
||||
|
||||
|
||||
class CUB():
    """CUB-200-2011 birds dataset loaded from the standard metadata files.

    Expects ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` under *root* (each line: ``"<id> <value>"``),
    with the image files under ``<root>/images``.  All images of the
    selected split are decoded eagerly into memory as uint8 ndarrays.
    """

    def __init__(self, root, is_train=True, data_len=None, transform=None):
        """
        Args:
            root (string): Dataset root directory.
            is_train (bool): Select the train (True) or test (False) split.
            data_len (int, optional): Keep only the first ``data_len`` samples.
            transform (callable, optional): Transform applied to each PIL image.
        """
        self.root = root
        self.is_train = is_train
        self.transform = transform
        # Use ``with`` so the metadata file handles are closed promptly
        # (the original leaked all three handles).
        with open(os.path.join(self.root, 'images.txt')) as img_txt_file:
            img_name_list = [line[:-1].split(' ')[-1] for line in img_txt_file]
        with open(os.path.join(self.root, 'image_class_labels.txt')) as label_txt_file:
            # Labels in the file are 1-based; store them 0-based.
            label_list = [int(line[:-1].split(' ')[-1]) - 1 for line in label_txt_file]
        with open(os.path.join(self.root, 'train_test_split.txt')) as train_val_file:
            train_test_list = [int(line[:-1].split(' ')[-1]) for line in train_val_file]
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        # BUGFIX: scipy.misc.imread was removed in SciPy 1.2; PIL + np.asarray
        # yields the same HxW / HxWxC uint8 arrays for greyscale/RGB files.
        if self.is_train:
            self.train_img = [np.asarray(Image.open(os.path.join(self.root, 'images', train_file)))
                              for train_file in train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
            self.train_imgname = [x for x in train_file_list[:data_len]]
        if not self.is_train:
            self.test_img = [np.asarray(Image.open(os.path.join(self.root, 'images', test_file)))
                             for test_file in test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
            self.test_imgname = [x for x in test_file_list[:data_len]]

    def __getitem__(self, index):
        """Return ``(image, label)`` for *index* in the selected split."""
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
        else:
            img, target = self.test_img[index], self.test_label[index]
        # Greyscale images are replicated to 3 channels before RGB conversion.
        if len(img.shape) == 2:
            img = np.stack([img] * 3, 2)
        img = Image.fromarray(img, mode='RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Number of samples in the selected split."""
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)
||||
|
||||
|
||||
class CarsDataset(Dataset):
    """Stanford Cars dataset backed by the devkit MATLAB annotation files."""

    def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None):
        """
        Args:
            mat_anno (string): Path to the MATLAB annotation file.
            data_dir (string): Directory with all the images.
            car_names (string): Path to the ``cars_meta.mat`` class-name file.
            cleaned (string, optional): Path to a text file listing the image
                file names to keep (used to drop non-RGB images).
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.full_data_set = io.loadmat(mat_anno)
        # annotations is a 1xN MATLAB struct array; take the single row.
        self.car_annotations = self.full_data_set['annotations']
        self.car_annotations = self.car_annotations[0]

        if cleaned is not None:
            # Keep only annotations whose file name appears in the cleaned list.
            cleaned_annos = []
            print("Cleaning up data set (only take pics with rgb chans)...")
            clean_files = np.loadtxt(cleaned, dtype=str)
            for c in self.car_annotations:
                # c[-1][0] is the image file name field of the struct entry.
                if c[-1][0] in clean_files:
                    cleaned_annos.append(c)
            self.car_annotations = cleaned_annos

        self.car_names = scipy.io.loadmat(car_names)['class_names']
        self.car_names = np.array(self.car_names[0])

        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        # One sample per (possibly cleaned) annotation entry.
        return len(self.car_annotations)

    def __getitem__(self, idx):
        """Return ``(image, class)`` where class is a 0-based long tensor."""
        img_name = os.path.join(self.data_dir, self.car_annotations[idx][-1][0])
        image = Image.open(img_name).convert('RGB')
        # c[-2] holds the 1-based class id; convert to 0-based long tensor.
        car_class = self.car_annotations[idx][-2][0][0]
        car_class = torch.from_numpy(np.array(car_class.astype(np.float32))).long() - 1
        # NOTE(review): sanity check against the 196 Stanford Cars classes;
        # ``assert`` is stripped under ``python -O``.
        assert car_class < 196

        if self.transform:
            image = self.transform(image)

        # return image, car_class, img_name
        return image, car_class

    def map_class(self, id):
        """Map a 1-based class id to its human-readable name."""
        id = np.ravel(id)
        ret = self.car_names[id - 1][0][0]
        return ret

    def show_batch(self, img_batch, class_batch):
        """Plot a batch of CHW images side by side, titled with class names."""
        for i in range(img_batch.shape[0]):
            ax = plt.subplot(1, img_batch.shape[0], i + 1)
            title_str = self.map_class(int(class_batch[i]))
            # CHW -> HWC for imshow.
            img = np.transpose(img_batch[i, ...], (1, 2, 0))
            ax.imshow(img)
            ax.set_title(title_str.__str__(), {'fontsize': 5})
        plt.tight_layout()
||||
|
||||
def make_dataset(dir, image_ids, targets):
    """Pair each image id with its target as ``(path, target)`` tuples.

    Paths are built as ``<dir>/data/images/<image_id>.jpg`` after expanding
    a leading ``~`` in *dir*.
    """
    assert len(image_ids) == len(targets)
    base = os.path.expanduser(dir)
    return [
        (os.path.join(base, 'data', 'images', '%s.jpg' % image_id), target)
        for image_id, target in zip(image_ids, targets)
    ]
||||
|
||||
|
||||
def find_classes(classes_file):
    """Parse a ``"<image_id> <class name>"`` per-line file into indexed labels.

    Args:
        classes_file (string): Path to the class listing file.

    Returns:
        Tuple ``(image_ids, targets, classes, class_to_idx)`` where
        ``targets`` are integer indices into the sorted unique ``classes``.
    """
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    # BUGFIX: ``with`` guarantees the handle is closed even if parsing raises
    # (the original only closed it on the success path).
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.split(' ')
            image_ids.append(split_line[0])
            # NOTE(review): keeps any trailing newline inside the class name
            # (original behaviour); indexing stays consistent either way.
            targets.append(' '.join(split_line[1:]))

    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]
    return (image_ids, targets, classes, class_to_idx)
|
||||
|
||||
|
||||
class dogs(Dataset):
|
||||
"""`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.
|
||||
Args:
|
||||
root (string): Root directory of dataset where directory
|
||||
``omniglot-py`` exists.
|
||||
cropped (bool, optional): If true, the images will be cropped into the bounding box specified
|
||||
in the annotations
|
||||
transform (callable, optional): A function/transform that takes in an PIL image
|
||||
and returns a transformed version. E.g, ``transforms.RandomCrop``
|
||||
target_transform (callable, optional): A function/transform that takes in the
|
||||
target and transforms it.
|
||||
download (bool, optional): If true, downloads the dataset tar files from the internet and
|
||||
puts it in root directory. If the tar files are already downloaded, they are not
|
||||
downloaded again.
|
||||
"""
|
||||
folder = 'dog'
|
||||
download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'
|
||||
|
||||
def __init__(self,
|
||||
root,
|
||||
train=True,
|
||||
cropped=False,
|
||||
transform=None,
|
||||
target_transform=None,
|
||||
download=False):
|
||||
|
||||
# self.root = join(os.path.expanduser(root), self.folder)
|
||||
self.root = root
|
||||
self.train = train
|
||||
self.cropped = cropped
|
||||
self.transform = transform
|
||||
self.target_transform = target_transform
|
||||
|
||||
if download:
|
||||
self.download()
|
||||
|
||||
split = self.load_split()
|
||||
|
||||
self.images_folder = join(self.root, 'Images')
|
||||
self.annotations_folder = join(self.root, 'Annotation')
|
||||
self._breeds = list_dir(self.images_folder)
|
||||
|
||||
if self.cropped:
|
||||
self._breed_annotations = [[(annotation, box, idx)
|
||||
for box in self.get_boxes(join(self.annotations_folder, annotation))]
|
||||
for annotation, idx in split]
|
||||
self._flat_breed_annotations = sum(self._breed_annotations, [])
|
||||
|
||||
self._flat_breed_images = [(annotation+'.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]
|
||||
else:
|
||||
self._breed_images = [(annotation+'.jpg', idx) for annotation, idx in split]
|
||||
|
||||
self._flat_breed_images = self._breed_images
|
||||
|
||||
self.classes = ["Chihuaha",
|
||||
"Japanese Spaniel",
|
||||
"Maltese Dog",
|
||||
"Pekinese",
|
||||
"Shih-Tzu",
|
||||
"Blenheim Spaniel",
|
||||
"Papillon",
|
||||
"Toy Terrier",
|
||||
"Rhodesian Ridgeback",
|
||||
"Afghan Hound",
|
||||
"Basset Hound",
|
||||
"Beagle",
|
||||
"Bloodhound",
|
||||
"Bluetick",
|
||||
"Black-and-tan Coonhound",
|
||||
"Walker Hound",
|
||||
"English Foxhound",
|
||||
"Redbone",
|
||||
"Borzoi",
|
||||
"Irish Wolfhound",
|
||||
"Italian Greyhound",
|
||||
"Whippet",
|
||||
"Ibizian Hound",
|
||||
"Norwegian Elkhound",
|
||||
"Otterhound",
|
||||
"Saluki",
|
||||
"Scottish Deerhound",
|
||||
"Weimaraner",
|
||||
"Staffordshire Bullterrier",
|
||||
"American Staffordshire Terrier",
|
||||
"Bedlington Terrier",
|
||||
"Border Terrier",
|
||||
"Kerry Blue Terrier",
|
||||
"Irish Terrier",
|
||||
"Norfolk Terrier",
|
||||
"Norwich Terrier",
|
||||
"Yorkshire Terrier",
|
||||
"Wirehaired Fox Terrier",
|
||||
"Lakeland Terrier",
|
||||
"Sealyham Terrier",
|
||||
"Airedale",
|
||||
"Cairn",
|
||||
"Australian Terrier",
|
||||
"Dandi Dinmont",
|
||||
"Boston Bull",
|
||||
"Miniature Schnauzer",
|
||||
"Giant Schnauzer",
|
||||
"Standard Schnauzer",
|
||||
"Scotch Terrier",
|
||||
"Tibetan Terrier",
|
||||
"Silky Terrier",
|
||||
"Soft-coated Wheaten Terrier",
|
||||
"West Highland White Terrier",
|
||||
"Lhasa",
|
||||
"Flat-coated Retriever",
|
||||
"Curly-coater Retriever",
|
||||
"Golden Retriever",
|
||||
"Labrador Retriever",
|
||||
"Chesapeake Bay Retriever",
|
||||
"German Short-haired Pointer",
|
||||
"Vizsla",
|
||||
"English Setter",
|
||||
"Irish Setter",
|
||||
"Gordon Setter",
|
||||
"Brittany",
|
||||
"Clumber",
|
||||
"English Springer Spaniel",
|
||||
"Welsh Springer Spaniel",
|
||||
"Cocker Spaniel",
|
||||
"Sussex Spaniel",
|
||||
"Irish Water Spaniel",
|
||||
"Kuvasz",
|
||||
"Schipperke",
|
||||
"Groenendael",
|
||||
"Malinois",
|
||||
"Briard",
|
||||
"Kelpie",
|
||||
"Komondor",
|
||||
"Old English Sheepdog",
|
||||
"Shetland Sheepdog",
|
||||
"Collie",
|
||||
"Border Collie",
|
||||
"Bouvier des Flandres",
|
||||
"Rottweiler",
|
||||
"German Shepard",
|
||||
"Doberman",
|
||||
"Miniature Pinscher",
|
||||
"Greater Swiss Mountain Dog",
|
||||
"Bernese Mountain Dog",
|
||||
"Appenzeller",
|
||||
"EntleBucher",
|
||||
"Boxer",
|
||||
"Bull Mastiff",
|
||||
"Tibetan Mastiff",
|
||||
"French Bulldog",
|
||||
"Great Dane",
|
||||
"Saint Bernard",
|
||||
"Eskimo Dog",
|
||||
"Malamute",
|
||||
"Siberian Husky",
|
||||
"Affenpinscher",
|
||||
"Basenji",
|
||||
"Pug",
|
||||
"Leonberg",
|
||||
"Newfoundland",
|
||||
"Great Pyrenees",
|
||||
"Samoyed",
|
||||
"Pomeranian",
|
||||
"Chow",
|
||||
"Keeshond",
|
||||
"Brabancon Griffon",
|
||||
"Pembroke",
|
||||
"Cardigan",
|
||||
"Toy Poodle",
|
||||
"Miniature Poodle",
|
||||
"Standard Poodle",
|
||||
"Mexican Hairless",
|
||||
"Dingo",
|
||||
"Dhole",
|
||||
"African Hunting Dog"]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._flat_breed_images)
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Args:
|
||||
index (int): Index
|
||||
Returns:
|
||||
tuple: (image, target) where target is index of the target character class.
|
||||
"""
|
||||
image_name, target_class = self._flat_breed_images[index]
|
||||
image_path = join(self.images_folder, image_name)
|
||||
image = Image.open(image_path).convert('RGB')
|
||||
|
||||
if self.cropped:
|
||||
image = image.crop(self._flat_breed_annotations[index][1])
|
||||
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
|
||||
if self.target_transform:
|
||||
target_class = self.target_transform(target_class)
|
||||
|
||||
return image, target_class
|
||||
|
||||
def download(self):
|
||||
import tarfile
|
||||
|
||||
if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
|
||||
if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
|
||||
print('Files already downloaded and verified')
|
||||
return
|
||||
|
||||
for filename in ['images', 'annotation', 'lists']:
|
||||
tar_filename = filename + '.tar'
|
||||
url = self.download_url_prefix + '/' + tar_filename
|
||||
download_url(url, self.root, tar_filename, None)
|
||||
print('Extracting downloaded file: ' + join(self.root, tar_filename))
|
||||
with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
|
||||
tar_file.extractall(self.root)
|
||||
os.remove(join(self.root, tar_filename))
|
||||
|
||||
@staticmethod
|
||||
def get_boxes(path):
|
||||
import xml.etree.ElementTree
|
||||
e = xml.etree.ElementTree.parse(path).getroot()
|
||||
boxes = []
|
||||
for objs in e.iter('object'):
|
||||
boxes.append([int(objs.find('bndbox').find('xmin').text),
|
||||
int(objs.find('bndbox').find('ymin').text),
|
||||
int(objs.find('bndbox').find('xmax').text),
|
||||
int(objs.find('bndbox').find('ymax').text)])
|
||||
return boxes
|
||||
|
||||
def load_split(self):
    """Load the train/test split list and return (image_name, label) pairs.

    Returns:
        list[tuple]: one ``(relative_image_path, zero_based_label)`` per sample.
    """
    # Load the .mat file once instead of twice (the original re-read the
    # same file for 'annotation_list' and for 'labels').
    list_name = 'train_list.mat' if self.train else 'test_list.mat'
    mat = scipy.io.loadmat(join(self.root, list_name))

    split = [item[0][0] for item in mat['annotation_list']]
    # Raw labels are 1-based in the .mat files; shift to 0-based.
    labels = [item[0] - 1 for item in mat['labels']]
    return list(zip(split, labels))
|
||||
|
||||
def stats(self):
    """Print and return the per-class sample counts of the loaded split.

    Returns:
        dict: mapping ``target_class -> number of images`` (a ``Counter``,
        which compares equal to a plain dict with the same contents).
    """
    from collections import Counter

    # Counter replaces the manual "if key not in dict" accumulation idiom.
    counts = Counter(target for _, target in self._flat_breed_images)
    num_samples = len(self._flat_breed_images)
    num_classes = len(counts)
    print("%d samples spanning %d classes (avg %f per class)"
          % (num_samples, num_classes, float(num_samples) / float(num_classes)))
    return counts
|
||||
|
||||
|
||||
class NABirds(Dataset):
    """`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        train (bool, optional): Select the training split when True, the
            test split otherwise.
        transform (callable, optional): A function/transform that takes in a
            PIL image and returns a transformed version,
            e.g. ``transforms.RandomCrop``.
    """
    base_folder = 'nabirds/images'

    def __init__(self, root, train=True, transform=None):
        dataset_path = os.path.join(root, 'nabirds')
        self.root = root
        self.loader = default_loader
        self.train = train
        self.transform = transform

        paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                            sep=' ', names=['img_id', 'filepath'])
        targets = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
                              sep=' ', names=['img_id', 'target'])
        # Raw class ids are non-continuous; remap them to a dense range.
        self.label_map = get_continuous_class_map(targets['target'])
        split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
                            sep=' ', names=['img_id', 'is_training_img'])

        merged = paths.merge(targets, on='img_id')
        self.data = merged.merge(split, on='img_id')
        # Keep only the rows belonging to the requested split.
        wanted_flag = 1 if self.train else 0
        self.data = self.data[self.data.is_training_img == wanted_flag]

        # Auxiliary class metadata shipped with the dataset.
        self.class_names = load_class_names(dataset_path)
        self.class_hierarchy = load_hierarchy(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = self.label_map[sample.target]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target
|
||||
|
||||
|
||||
def get_continuous_class_map(class_labels):
    """Map raw (possibly non-contiguous) labels to continuous ids 0..K-1.

    The unique labels are sorted before enumeration so the mapping is
    deterministic across runs; plain ``set`` iteration order is not
    guaranteed, which would make label ids unstable between processes.

    Args:
        class_labels: iterable of hashable, orderable raw labels.

    Returns:
        dict: ``raw_label -> continuous_index``.
    """
    return {label: idx for idx, label in enumerate(sorted(set(class_labels)))}
|
||||
|
||||
|
||||
def load_class_names(dataset_path=''):
    """Read ``classes.txt`` and return a ``{class_id: class_name}`` dict.

    Each line holds a class id followed by a possibly multi-word name.
    """
    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        rows = [line.strip().split() for line in f]
    return {row[0]: ' '.join(row[1:]) for row in rows}
|
||||
|
||||
|
||||
def load_hierarchy(dataset_path=''):
    """Read ``hierarchy.txt`` and return a ``{child_id: parent_id}`` mapping.

    Each line holds exactly two whitespace-separated ids: child then parent.
    """
    with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
        return {child: parent
                for child, parent in (line.strip().split() for line in f)}
|
||||
|
||||
|
||||
class INat2017(VisionDataset):
    """`iNaturalist 2017 <https://github.com/visipedia/inat_comp/blob/master/2017/README.md>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform applied to each
            PIL image, e.g. ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform applied to
            each target.
        download (bool, optional): If true, downloads the dataset from the
            internet and puts it in the root directory. If the dataset is
            already downloaded, it is not downloaded again.
    """
    base_folder = 'train_val_images/'
    # (url, archive filename, md5) for each required download.
    file_list = {
        'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
                 'train_val_images.tar.gz',
                 '7c784ea5e424efaec655bd392f87301f'),
        'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
                  'train_val2017.zip',
                  '444c835f6459867ad69fcb36478786e7')
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            archives_present = (
                os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
                and os.path.exists(os.path.join(self.root, self.file_list['annos'][1])))
            if not archives_present:
                print('Downloading...')
                self._download()
            print('Extracting...')
            for key in ('imgs', 'annos'):
                extract_archive(os.path.join(self.root, self.file_list[key][1]))
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')

        # Per-split annotation file, e.g. train2017.json / val2017.json.
        anno_filename = split + '2017.json'
        with open(os.path.join(self.root, anno_filename), 'r') as fp:
            all_annos = json.load(fp)

        self.annos = all_annos['annotations']
        self.images = all_annos['images']

    def __getitem__(self, index):
        record = self.images[index]
        image = self.loader(os.path.join(self.root, record['file_name']))
        target = self.annos[index]['category_id']

        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.images)

    def _check_exists(self):
        # Extracted image tree present -> nothing to download.
        return os.path.exists(os.path.join(self.root, self.base_folder))

    def _download(self):
        for url, filename, md5 in self.file_list.values():
            download_url(url, root=self.root, filename=filename)
            if not check_integrity(os.path.join(self.root, filename), md5):
                raise RuntimeError("File not found or corrupted.")
|
30
utils/dist_util.py
Executable file
30
utils/dist_util.py
Executable file
@ -0,0 +1,30 @@
|
||||
import torch.distributed as dist
|
||||
|
||||
def get_rank():
    """Return this process's distributed rank, or 0 when not distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0
|
||||
|
||||
def get_world_size():
    """Return the number of distributed processes, or 1 when not distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1
|
||||
|
||||
def is_main_process():
    """True only on the rank-0 (main) process."""
    return not get_rank()
|
||||
|
||||
def format_step(step):
    """Render a (epoch, iteration, val_iteration) sequence as a log prefix.

    Strings pass through unchanged; missing trailing elements are omitted.
    """
    if isinstance(step, str):
        return step
    templates = ("Training Epoch: {} ",
                 "Training Iteration: {} ",
                 "Validation Iteration: {} ")
    # zip truncates to the shorter sequence, mirroring the len() checks.
    return "".join(fmt.format(value) for fmt, value in zip(templates, step))
|
63
utils/scheduler.py
Executable file
63
utils/scheduler.py
Executable file
@ -0,0 +1,63 @@
|
||||
import logging
|
||||
import math
|
||||
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ConstantLRSchedule(LambdaLR):
    """Schedule that keeps the optimizer's base learning rate unchanged."""

    def __init__(self, optimizer, last_epoch=-1):
        # A multiplier of 1.0 at every step leaves the base LR as-is.
        super(ConstantLRSchedule, self).__init__(optimizer, lambda step: 1.0, last_epoch=last_epoch)
|
||||
|
||||
|
||||
class WarmupConstantSchedule(LambdaLR):
    """Linear warmup, then constant.

    The LR multiplier grows linearly from 0 to 1 over ``warmup_steps``
    training steps and stays at 1 afterwards.
    """

    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        # warmup_steps must be set before super().__init__, which immediately
        # evaluates lr_lambda via an initial step().
        self.warmup_steps = warmup_steps
        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step >= self.warmup_steps:
            return 1.
        return float(step) / float(max(1.0, self.warmup_steps))
|
||||
|
||||
|
||||
class WarmupLinearSchedule(LambdaLR):
    """Linear warmup, then linear decay.

    The LR multiplier grows linearly from 0 to 1 over ``warmup_steps``
    training steps, then decays linearly from 1 to 0 over the remaining
    ``t_total - warmup_steps`` steps.
    """

    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        # Attributes must exist before super().__init__, which immediately
        # evaluates lr_lambda via an initial step().
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return step / max(1, self.warmup_steps)
        remaining = self.t_total - step
        decay_span = max(1.0, self.t_total - self.warmup_steps)
        # Clamp at 0 once t_total is exceeded.
        return max(0.0, remaining / decay_span)
|
||||
|
||||
|
||||
class WarmupCosineSchedule(LambdaLR):
    """Linear warmup, then cosine decay.

    The LR multiplier grows linearly from 0 to 1 over ``warmup_steps``
    training steps, then follows a cosine curve from 1 down to 0 over the
    remaining ``t_total - warmup_steps`` steps. With ``cycles`` different
    from the default 0.5, more (or less) of the cosine period is traversed.
    """

    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        # Attributes must exist before super().__init__, which immediately
        # evaluates lr_lambda via an initial step().
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return step / max(1.0, self.warmup_steps)
        # Fraction of the post-warmup phase completed, in [0, 1].
        progress = (step - self.warmup_steps) / max(1, self.t_total - self.warmup_steps)
        cosine = math.cos(math.pi * float(self.cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1. + cosine))
|
Reference in New Issue
Block a user