update

utils/dataset.py (new executable file, 629 lines)
@@ -0,0 +1,629 @@
import os
import json
from os.path import join

import numpy as np
import scipy
from scipy import io
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg

# NOTE: the original `import scipy.misc` was dropped: scipy.misc.imread was
# removed in SciPy 1.2, so images below are loaded with PIL instead.


class emptyJudge:
    """Identical in structure to the CUB loader below; reads CUB-style metadata files."""

    def __init__(self, root, is_train=True, data_len=None, transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform

        # Parse the metadata files: image paths, 0-based class labels,
        # and the train/test split flags.
        with open(os.path.join(self.root, 'images.txt')) as img_txt_file:
            img_name_list = [line.rstrip('\n').split(' ')[-1] for line in img_txt_file]
        with open(os.path.join(self.root, 'image_class_labels.txt')) as label_txt_file:
            label_list = [int(line.rstrip('\n').split(' ')[-1]) - 1 for line in label_txt_file]
        with open(os.path.join(self.root, 'train_test_split.txt')) as train_val_file:
            train_test_list = [int(line.rstrip('\n').split(' ')[-1]) for line in train_val_file]

        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        if self.is_train:
            # scipy.misc.imread was removed in SciPy >= 1.2; load with PIL instead.
            self.train_img = [np.array(Image.open(os.path.join(self.root, 'images', train_file)))
                              for train_file in train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
            self.train_imgname = train_file_list[:data_len]
        if not self.is_train:
            self.test_img = [np.array(Image.open(os.path.join(self.root, 'images', test_file)))
                             for test_file in test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
            self.test_imgname = test_file_list[:data_len]

    def __getitem__(self, index):
        if self.is_train:
            img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
        else:
            img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
        # Promote grayscale images to 3 channels before converting to PIL.
        if len(img.shape) == 2:
            img = np.stack([img] * 3, 2)
        img = Image.fromarray(img, mode='RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


class CUB:
    """CUB-200-2011 loader; reads the metadata text files shipped with the dataset."""

    def __init__(self, root, is_train=True, data_len=None, transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform

        # Parse the metadata files: image paths, 0-based class labels,
        # and the train/test split flags.
        with open(os.path.join(self.root, 'images.txt')) as img_txt_file:
            img_name_list = [line.rstrip('\n').split(' ')[-1] for line in img_txt_file]
        with open(os.path.join(self.root, 'image_class_labels.txt')) as label_txt_file:
            label_list = [int(line.rstrip('\n').split(' ')[-1]) - 1 for line in label_txt_file]
        with open(os.path.join(self.root, 'train_test_split.txt')) as train_val_file:
            train_test_list = [int(line.rstrip('\n').split(' ')[-1]) for line in train_val_file]

        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        if self.is_train:
            # scipy.misc.imread was removed in SciPy >= 1.2; load with PIL instead.
            self.train_img = [np.array(Image.open(os.path.join(self.root, 'images', train_file)))
                              for train_file in train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
            self.train_imgname = train_file_list[:data_len]
        if not self.is_train:
            self.test_img = [np.array(Image.open(os.path.join(self.root, 'images', test_file)))
                             for test_file in test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
            self.test_imgname = test_file_list[:data_len]

    def __getitem__(self, index):
        if self.is_train:
            img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
        else:
            img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
        # Promote grayscale images to 3 channels before converting to PIL.
        if len(img.shape) == 2:
            img = np.stack([img] * 3, 2)
        img = Image.fromarray(img, mode='RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


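# A minimal usage sketch (not part of the original file): how CUB is typically
# wired into a DataLoader. The root path below is an assumption; it should
# point at an extracted CUB_200_2011 directory containing images.txt,
# image_class_labels.txt, train_test_split.txt and an images/ folder.
def _cub_usage_example(root='./CUB_200_2011'):
    from torchvision import transforms
    from torch.utils.data import DataLoader

    tf = transforms.Compose([
        transforms.Resize((448, 448)),
        transforms.ToTensor(),
    ])
    trainset = CUB(root=root, is_train=True, transform=tf)
    loader = DataLoader(trainset, batch_size=16, shuffle=True, num_workers=4)
    images, labels = next(iter(loader))  # images: (16, 3, 448, 448), labels: (16,)
    return images, labels

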
class CarsDataset(Dataset):
    def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None):
        """
        Args:
            mat_anno (string): Path to the MATLAB annotation file.
            data_dir (string): Directory with all the images.
            car_names (string): Path to the MATLAB file holding the class names.
            cleaned (string, optional): Path to a text file listing the images
                known to have RGB channels; all other images are dropped.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.full_data_set = io.loadmat(mat_anno)
        self.car_annotations = self.full_data_set['annotations']
        self.car_annotations = self.car_annotations[0]

        if cleaned is not None:
            cleaned_annos = []
            print("Cleaning up dataset (keeping only images with RGB channels)...")
            clean_files = np.loadtxt(cleaned, dtype=str)
            for c in self.car_annotations:
                if c[-1][0] in clean_files:
                    cleaned_annos.append(c)
            self.car_annotations = cleaned_annos

        self.car_names = scipy.io.loadmat(car_names)['class_names']
        self.car_names = np.array(self.car_names[0])

        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.car_annotations)

    def __getitem__(self, idx):
        img_name = os.path.join(self.data_dir, self.car_annotations[idx][-1][0])
        image = Image.open(img_name).convert('RGB')
        car_class = self.car_annotations[idx][-2][0][0]
        # Labels in the .mat file are 1-based; shift to 0-based.
        car_class = torch.from_numpy(np.array(car_class.astype(np.float32))).long() - 1
        assert car_class < 196

        if self.transform:
            image = self.transform(image)

        # return image, car_class, img_name
        return image, car_class

    def map_class(self, class_id):
        # `class_id` is 1-based, matching the raw .mat annotations.
        class_id = np.ravel(class_id)
        ret = self.car_names[class_id - 1][0][0]
        return ret

    def show_batch(self, img_batch, class_batch):
        for i in range(img_batch.shape[0]):
            ax = plt.subplot(1, img_batch.shape[0], i + 1)
            title_str = self.map_class(int(class_batch[i]))
            img = np.transpose(img_batch[i, ...], (1, 2, 0))
            ax.imshow(img)
            ax.set_title(str(title_str), {'fontsize': 5})
        plt.tight_layout()


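# A usage sketch for CarsDataset (an illustration, not from the original file).
# The .mat file names follow the Stanford Cars devkit layout, but treat the
# paths as assumptions and adjust them to the local copy of the dataset.
def _cars_usage_example(devkit='./devkit', data_dir='./cars_train'):
    from torchvision import transforms

    dataset = CarsDataset(
        mat_anno=os.path.join(devkit, 'cars_train_annos.mat'),
        data_dir=data_dir,
        car_names=os.path.join(devkit, 'cars_meta.mat'),
        transform=transforms.Compose([transforms.Resize((448, 448)),
                                      transforms.ToTensor()]),
    )
    image, car_class = dataset[0]  # car_class is a 0-based tensor label
    # map_class expects the raw 1-based id, so shift back up by one.
    print(dataset.map_class(int(car_class) + 1))
    return image, car_class

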
def make_dataset(directory, image_ids, targets):
    assert len(image_ids) == len(targets)
    images = []
    directory = os.path.expanduser(directory)
    for i in range(len(image_ids)):
        item = (os.path.join(directory, 'data', 'images', '%s.jpg' % image_ids[i]), targets[i])
        images.append(item)
    return images


def find_classes(classes_file):
    # Read the classes file, separating out image IDs and class names.
    image_ids = []
    targets = []
    with open(classes_file, 'r') as f:
        for line in f:
            split_line = line.split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]))

    # Index the class names.
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]
    return image_ids, targets, classes, class_to_idx


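# Sketch showing how find_classes and make_dataset compose (illustration only).
# Assumes a classes file of "image_id class name" lines and images stored
# under <root>/data/images/<image_id>.jpg, per the two functions above.
def _filelist_usage_example(root='./dataset_root', classes_file='classes.txt'):
    image_ids, targets, classes, class_to_idx = find_classes(os.path.join(root, classes_file))
    samples = make_dataset(root, image_ids, targets)
    # samples is a list of (image_path, class_index) pairs.
    return samples, classes

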
class dogs(Dataset):
    """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.
    Args:
        root (string): Root directory of the dataset, where the ``Images``
            and ``Annotation`` directories exist.
        cropped (bool, optional): If true, the images will be cropped into the bounding box specified
            in the annotations.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset tar files from the internet and
            puts them in the root directory. If the tar files are already downloaded, they are not
            downloaded again.
    """
    folder = 'dog'
    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'

    def __init__(self,
                 root,
                 train=True,
                 cropped=False,
                 transform=None,
                 target_transform=None,
                 download=False):
        # self.root = join(os.path.expanduser(root), self.folder)
        self.root = root
        self.train = train
        self.cropped = cropped
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        split = self.load_split()

        self.images_folder = join(self.root, 'Images')
        self.annotations_folder = join(self.root, 'Annotation')
        self._breeds = list_dir(self.images_folder)

        if self.cropped:
            # One entry per bounding box, so images with several boxes yield several samples.
            self._breed_annotations = [[(annotation, box, idx)
                                        for box in self.get_boxes(join(self.annotations_folder, annotation))]
                                       for annotation, idx in split]
            self._flat_breed_annotations = sum(self._breed_annotations, [])

            self._flat_breed_images = [(annotation + '.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]
        else:
            self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]

            self._flat_breed_images = self._breed_images

        self.classes = ["Chihuahua",
                        "Japanese Spaniel",
                        "Maltese Dog",
                        "Pekinese",
                        "Shih-Tzu",
                        "Blenheim Spaniel",
                        "Papillon",
                        "Toy Terrier",
                        "Rhodesian Ridgeback",
                        "Afghan Hound",
                        "Basset Hound",
                        "Beagle",
                        "Bloodhound",
                        "Bluetick",
                        "Black-and-tan Coonhound",
                        "Walker Hound",
                        "English Foxhound",
                        "Redbone",
                        "Borzoi",
                        "Irish Wolfhound",
                        "Italian Greyhound",
                        "Whippet",
                        "Ibizan Hound",
                        "Norwegian Elkhound",
                        "Otterhound",
                        "Saluki",
                        "Scottish Deerhound",
                        "Weimaraner",
                        "Staffordshire Bullterrier",
                        "American Staffordshire Terrier",
                        "Bedlington Terrier",
                        "Border Terrier",
                        "Kerry Blue Terrier",
                        "Irish Terrier",
                        "Norfolk Terrier",
                        "Norwich Terrier",
                        "Yorkshire Terrier",
                        "Wirehaired Fox Terrier",
                        "Lakeland Terrier",
                        "Sealyham Terrier",
                        "Airedale",
                        "Cairn",
                        "Australian Terrier",
                        "Dandie Dinmont",
                        "Boston Bull",
                        "Miniature Schnauzer",
                        "Giant Schnauzer",
                        "Standard Schnauzer",
                        "Scotch Terrier",
                        "Tibetan Terrier",
                        "Silky Terrier",
                        "Soft-coated Wheaten Terrier",
                        "West Highland White Terrier",
                        "Lhasa",
                        "Flat-coated Retriever",
                        "Curly-coated Retriever",
                        "Golden Retriever",
                        "Labrador Retriever",
                        "Chesapeake Bay Retriever",
                        "German Short-haired Pointer",
                        "Vizsla",
                        "English Setter",
                        "Irish Setter",
                        "Gordon Setter",
                        "Brittany",
                        "Clumber",
                        "English Springer Spaniel",
                        "Welsh Springer Spaniel",
                        "Cocker Spaniel",
                        "Sussex Spaniel",
                        "Irish Water Spaniel",
                        "Kuvasz",
                        "Schipperke",
                        "Groenendael",
                        "Malinois",
                        "Briard",
                        "Kelpie",
                        "Komondor",
                        "Old English Sheepdog",
                        "Shetland Sheepdog",
                        "Collie",
                        "Border Collie",
                        "Bouvier des Flandres",
                        "Rottweiler",
                        "German Shepherd",
                        "Doberman",
                        "Miniature Pinscher",
                        "Greater Swiss Mountain Dog",
                        "Bernese Mountain Dog",
                        "Appenzeller",
                        "EntleBucher",
                        "Boxer",
                        "Bull Mastiff",
                        "Tibetan Mastiff",
                        "French Bulldog",
                        "Great Dane",
                        "Saint Bernard",
                        "Eskimo Dog",
                        "Malamute",
                        "Siberian Husky",
                        "Affenpinscher",
                        "Basenji",
                        "Pug",
                        "Leonberg",
                        "Newfoundland",
                        "Great Pyrenees",
                        "Samoyed",
                        "Pomeranian",
                        "Chow",
                        "Keeshond",
                        "Brabancon Griffon",
                        "Pembroke",
                        "Cardigan",
                        "Toy Poodle",
                        "Miniature Poodle",
                        "Standard Poodle",
                        "Mexican Hairless",
                        "Dingo",
                        "Dhole",
                        "African Hunting Dog"]

    def __len__(self):
        return len(self._flat_breed_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the index of the target breed class.
        """
        image_name, target_class = self._flat_breed_images[index]
        image_path = join(self.images_folder, image_name)
        image = Image.open(image_path).convert('RGB')

        if self.cropped:
            image = image.crop(self._flat_breed_annotations[index][1])

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            target_class = self.target_transform(target_class)

        return image, target_class

    def download(self):
        import tarfile

        if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
            if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
                print('Files already downloaded and verified')
                return

        for filename in ['images', 'annotation', 'lists']:
            tar_filename = filename + '.tar'
            url = self.download_url_prefix + '/' + tar_filename
            download_url(url, self.root, tar_filename, None)
            print('Extracting downloaded file: ' + join(self.root, tar_filename))
            with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
                tar_file.extractall(self.root)
            os.remove(join(self.root, tar_filename))

    @staticmethod
    def get_boxes(path):
        import xml.etree.ElementTree
        e = xml.etree.ElementTree.parse(path).getroot()
        boxes = []
        for objs in e.iter('object'):
            boxes.append([int(objs.find('bndbox').find('xmin').text),
                          int(objs.find('bndbox').find('ymin').text),
                          int(objs.find('bndbox').find('xmax').text),
                          int(objs.find('bndbox').find('ymax').text)])
        return boxes

    def load_split(self):
        if self.train:
            split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
        else:
            split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']

        split = [item[0][0] for item in split]
        labels = [item[0] - 1 for item in labels]
        return list(zip(split, labels))

    def stats(self):
        counts = {}
        for index in range(len(self._flat_breed_images)):
            image_name, target_class = self._flat_breed_images[index]
            if target_class not in counts.keys():
                counts[target_class] = 1
            else:
                counts[target_class] += 1

        print("%d samples spanning %d classes (avg %f per class)" % (len(self._flat_breed_images),
                                                                     len(counts.keys()),
                                                                     float(len(self._flat_breed_images)) / float(len(counts.keys()))))
        return counts


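# Usage sketch for the Stanford Dogs loader above (illustration; the root path
# is an assumption and should point at the extracted Images/Annotation folders).
def _dogs_usage_example(root='./stanford_dogs'):
    from torchvision import transforms

    dataset = dogs(root=root,
                   train=True,
                   cropped=True,  # one sample per annotated bounding box
                   transform=transforms.ToTensor(),
                   download=False)  # set True to fetch the tar files first
    dataset.stats()  # prints a samples-per-class summary
    image, target = dataset[0]
    return image, target

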
class NABirds(Dataset):
    """`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.

    Args:
        root (string): Root directory of the dataset (containing a ``nabirds``
            folder with the extracted files).
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
    """
    base_folder = 'nabirds/images'

    def __init__(self, root, train=True, transform=None):
        dataset_path = os.path.join(root, 'nabirds')
        self.root = root
        self.loader = default_loader
        self.train = train
        self.transform = transform

        image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                                  sep=' ', names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        # Since the raw labels are non-continuous, map them to new ones
        self.label_map = get_continuous_class_map(image_class_labels['target'])
        train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])
        data = image_paths.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')
        # Load in the train / test split
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

        # Load in the class data
        self.class_names = load_class_names(dataset_path)
        self.class_hierarchy = load_hierarchy(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = self.label_map[sample.target]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        return img, target


def get_continuous_class_map(class_labels):
    # Sort so the raw-label -> continuous-label mapping is deterministic across runs.
    label_set = sorted(set(class_labels))
    return {k: i for i, k in enumerate(label_set)}


def load_class_names(dataset_path=''):
    names = {}

    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])

    return names


def load_hierarchy(dataset_path=''):
    parents = {}

    with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            child_id, parent_id = pieces
            parents[child_id] = parent_id

    return parents


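# Usage sketch for NABirds (illustration; assumes <root>/nabirds holds the
# extracted metadata files and <root>/nabirds/images the image folders).
def _nabirds_usage_example(root='./data'):
    from torchvision import transforms

    dataset = NABirds(root=root, train=True, transform=transforms.ToTensor())
    img, target = dataset[0]  # target is the remapped, continuous label
    return img, target

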
class INat2017(VisionDataset):
    """`iNaturalist 2017 <https://github.com/visipedia/inat_comp/blob/master/2017/README.md>`_ Dataset.
    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g., ``transforms.RandomCrop``.
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'train_val_images/'
    file_list = {
        'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
                 'train_val_images.tar.gz',
                 '7c784ea5e424efaec655bd392f87301f'),
        'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
                  'train_val2017.zip',
                  '444c835f6459867ad69fcb36478786e7')
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
                    and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
                print('Downloading...')
                self._download()
            print('Extracting...')
            extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
            extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')
        anno_filename = split + '2017.json'
        with open(os.path.join(self.root, anno_filename), 'r') as fp:
            all_annos = json.load(fp)

        self.annos = all_annos['annotations']
        self.images = all_annos['images']

    def __getitem__(self, index):
        path = os.path.join(self.root, self.images[index]['file_name'])
        target = self.annos[index]['category_id']

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.images)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.base_folder))

    def _download(self):
        for url, filename, md5 in self.file_list.values():
            download_url(url, root=self.root, filename=filename)
            if not check_integrity(os.path.join(self.root, filename), md5):
                raise RuntimeError("File not found or corrupted.")
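

# Usage sketch for INat2017 (illustration; the root is an assumed path where
# the train_val archives have been, or will be, extracted).
def _inat_usage_example(root='./inat2017'):
    from torchvision import transforms

    dataset = INat2017(root=root, split='train',
                       transform=transforms.ToTensor(),
                       download=False)  # set True to fetch the (very large) archives
    image, target = dataset[0]  # target is the iNat category_id
    return image, target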