import os
import json
from os.path import join

import numpy as np
import scipy
from scipy import io
import imageio
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg


# Low-level readers for the various fine-grained datasets.
class emptyJudge():
    # Identical in behaviour to the CUB reader below; reads the CUB-200-2011 index files.
    def __init__(self, root, is_train=True, data_len=None, transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform
        img_txt_file = open(os.path.join(self.root, 'images.txt'))
        label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
        train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
        img_name_list = []
        for line in img_txt_file:
            img_name_list.append(line[:-1].split(' ')[-1])
        label_list = []
        for line in label_txt_file:
            label_list.append(int(line[:-1].split(' ')[-1]) - 1)
        train_test_list = []
        for line in train_val_file:
            train_test_list.append(int(line[:-1].split(' ')[-1]))
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        if self.is_train:
            self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
                              train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
            self.train_imgname = [x for x in train_file_list[:data_len]]
        if not self.is_train:
            self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
                             test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
            self.test_imgname = [x for x in test_file_list[:data_len]]

    def __getitem__(self, index):
        if self.is_train:
            img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
            if len(img.shape) == 2:
                # Replicate the grayscale channel three times -> array of shape [H, W, 3].
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            if self.transform is not None:
                img = self.transform(img)
        else:
            img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            if self.transform is not None:
                img = self.transform(img)
        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)


class CUB():
    def __init__(self, root, is_train=True, data_len=None, transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform
        img_txt_file = open(os.path.join(self.root, 'images.txt'))
        label_txt_file = open(os.path.join(self.root, 'image_class_labels.txt'))
        train_val_file = open(os.path.join(self.root, 'train_test_split.txt'))
        img_name_list = []
        for line in img_txt_file:
            img_name_list.append(line[:-1].split(' ')[-1])
        label_list = []
        for line in label_txt_file:
            label_list.append(int(line[:-1].split(' ')[-1]) - 1)
        train_test_list = []
        for line in train_val_file:
            train_test_list.append(int(line[:-1].split(' ')[-1]))
        train_file_list = [x for i, x in zip(train_test_list, img_name_list) if i]
        test_file_list = [x for i, x in zip(train_test_list, img_name_list) if not i]
        if self.is_train:
            self.train_img = [imageio.imread(os.path.join(self.root, 'images', train_file)) for train_file in
                              train_file_list[:data_len]]
            self.train_label = [x for i, x in zip(train_test_list, label_list) if i][:data_len]
            self.train_imgname = [x for x in train_file_list[:data_len]]
        if not self.is_train:
            self.test_img = [imageio.imread(os.path.join(self.root, 'images', test_file)) for test_file in
                             test_file_list[:data_len]]
            self.test_label = [x for i, x in zip(train_test_list, label_list) if not i][:data_len]
            self.test_imgname = [x for x in test_file_list[:data_len]]

    def __getitem__(self, index):
        if self.is_train:
            img, target, imgname = self.train_img[index], self.train_label[index], self.train_imgname[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            if self.transform is not None:
                img = self.transform(img)
        else:
            img, target, imgname = self.test_img[index], self.test_label[index], self.test_imgname[index]
            if len(img.shape) == 2:
                img = np.stack([img] * 3, 2)
            img = Image.fromarray(img, mode='RGB')
            if self.transform is not None:
                img = self.transform(img)

        return img, target

    def __len__(self):
        if self.is_train:
            return len(self.train_label)
        else:
            return len(self.test_label)
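

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# Shows how the CUB reader above is typically wrapped in a torch DataLoader.
# The root path, image size and batch size are placeholder assumptions.
def _example_cub_usage(root='./CUB_200_2011'):
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([transforms.Resize((448, 448)),
                                    transforms.ToTensor()])
    trainset = CUB(root=root, is_train=True, transform=transform)
    loader = DataLoader(trainset, batch_size=16, shuffle=True)
    images, labels = next(iter(loader))
    # images: float tensor [16, 3, 448, 448]; labels: int64 tensor [16]
    return images.shape, labels.shape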


class CarsDataset(Dataset):

    def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None):
        """
        Args:
            mat_anno (string): Path to the MATLAB annotation file.
            data_dir (string): Directory with all the images.
            car_names (string): Path to the MATLAB file that lists the class names.
            cleaned (string, optional): Optional text file listing the image files to keep
                (used to drop non-RGB images).
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """

        self.full_data_set = io.loadmat(mat_anno)
        self.car_annotations = self.full_data_set['annotations']
        self.car_annotations = self.car_annotations[0]

        if cleaned is not None:
            cleaned_annos = []
            print("Cleaning up data set (only keeping images with RGB channels)...")
            clean_files = np.loadtxt(cleaned, dtype=str)
            for c in self.car_annotations:
                if c[-1][0] in clean_files:
                    cleaned_annos.append(c)
            self.car_annotations = cleaned_annos

        self.car_names = scipy.io.loadmat(car_names)['class_names']
        self.car_names = np.array(self.car_names[0])

        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.car_annotations)

    def __getitem__(self, idx):
        img_name = os.path.join(self.data_dir, self.car_annotations[idx][-1][0])
        image = Image.open(img_name).convert('RGB')
        car_class = self.car_annotations[idx][-2][0][0]
        car_class = torch.from_numpy(np.array(car_class.astype(np.float32))).long() - 1
        assert car_class < 196

        if self.transform:
            image = self.transform(image)

        # return image, car_class, img_name
        return image, car_class

    def map_class(self, id):
        id = np.ravel(id)
        ret = self.car_names[id - 1][0][0]
        return ret

    def show_batch(self, img_batch, class_batch):
        for i in range(img_batch.shape[0]):
            ax = plt.subplot(1, img_batch.shape[0], i + 1)
            title_str = self.map_class(int(class_batch[i]))
            img = np.transpose(img_batch[i, ...], (1, 2, 0))
            ax.imshow(img)
            ax.set_title(str(title_str), {'fontsize': 5})
        plt.tight_layout()
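

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# Instantiates CarsDataset with the usual Stanford Cars devkit layout; the exact
# paths below (devkit/cars_train_annos.mat, cars_train/, devkit/cars_meta.mat)
# are assumptions about how the data was unpacked, not files shipped with this repo.
def _example_cars_usage(data_root='./stanford_cars'):
    from torchvision import transforms

    transform = transforms.Compose([transforms.Resize((448, 448)),
                                    transforms.ToTensor()])
    trainset = CarsDataset(mat_anno=os.path.join(data_root, 'devkit', 'cars_train_annos.mat'),
                           data_dir=os.path.join(data_root, 'cars_train'),
                           car_names=os.path.join(data_root, 'devkit', 'cars_meta.mat'),
                           transform=transform)
    image, car_class = trainset[0]
    # car_class is a 0-based torch scalar; map_class indexes car_names with id - 1,
    # so pass the original 1-based id to recover the class name.
    print(trainset.map_class(int(car_class) + 1))
    return image.shape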


def make_dataset(dir, image_ids, targets):
    assert len(image_ids) == len(targets)
    images = []
    dir = os.path.expanduser(dir)
    for i in range(len(image_ids)):
        item = (os.path.join(dir, 'data', 'images', '%s.jpg' % image_ids[i]), targets[i])
        images.append(item)
    return images


def find_classes(classes_file):
    # read classes file, separating out image IDs and class names
    image_ids = []
    targets = []
    f = open(classes_file, 'r')
    for line in f:
        split_line = line.split(' ')
        image_ids.append(split_line[0])
        targets.append(' '.join(split_line[1:]))
    f.close()

    # index class names
    classes = np.unique(targets)
    class_to_idx = {classes[i]: i for i in range(len(classes))}
    targets = [class_to_idx[c] for c in targets]
    return (image_ids, targets, classes, class_to_idx)
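

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# make_dataset/find_classes follow an FGVC-Aircraft-style index file: one
# "<image_id> <class name>" pair per line, with images under <root>/data/images/.
# The file name below (images_variant_train.txt) is an assumption about that layout.
def _example_aircraft_index(root='./fgvc-aircraft-2013b'):
    classes_file = os.path.join(root, 'data', 'images_variant_train.txt')
    image_ids, targets, classes, class_to_idx = find_classes(classes_file)
    samples = make_dataset(root, image_ids, targets)
    # samples is a list of (image path, integer class index) tuples
    return samples[:3], len(classes)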


class dogs(Dataset):
    """`Stanford Dogs <http://vision.stanford.edu/aditya86/ImageNetDogs/>`_ Dataset.

    Args:
        root (string): Root directory of the dataset, containing the ``Images`` and
            ``Annotation`` folders and the ``train_list.mat`` / ``test_list.mat`` split files.
        train (bool, optional): If True, creates the dataset from the training split,
            otherwise from the test split.
        cropped (bool, optional): If true, the images will be cropped into the bounding box specified
            in the annotations.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset tar files from the internet and
            puts them in the root directory. If the tar files are already downloaded, they are not
            downloaded again.
    """
    folder = 'dog'
    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'

    def __init__(self,
                 root,
                 train=True,
                 cropped=False,
                 transform=None,
                 target_transform=None,
                 download=False):

        # self.root = join(os.path.expanduser(root), self.folder)
        self.root = root
        self.train = train
        self.cropped = cropped
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        split = self.load_split()

        self.images_folder = join(self.root, 'Images')
        self.annotations_folder = join(self.root, 'Annotation')
        self._breeds = list_dir(self.images_folder)

        if self.cropped:
            self._breed_annotations = [[(annotation, box, idx)
                                        for box in self.get_boxes(join(self.annotations_folder, annotation))]
                                       for annotation, idx in split]
            self._flat_breed_annotations = sum(self._breed_annotations, [])

            self._flat_breed_images = [(annotation + '.jpg', idx) for annotation, box, idx in self._flat_breed_annotations]
        else:
            self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]

            self._flat_breed_images = self._breed_images

        self.classes = [
            "Chihuaha", "Japanese Spaniel", "Maltese Dog", "Pekinese", "Shih-Tzu",
            "Blenheim Spaniel", "Papillon", "Toy Terrier", "Rhodesian Ridgeback", "Afghan Hound",
            "Basset Hound", "Beagle", "Bloodhound", "Bluetick", "Black-and-tan Coonhound",
            "Walker Hound", "English Foxhound", "Redbone", "Borzoi", "Irish Wolfhound",
            "Italian Greyhound", "Whippet", "Ibizian Hound", "Norwegian Elkhound", "Otterhound",
            "Saluki", "Scottish Deerhound", "Weimaraner", "Staffordshire Bullterrier", "American Staffordshire Terrier",
            "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", "Norfolk Terrier",
            "Norwich Terrier", "Yorkshire Terrier", "Wirehaired Fox Terrier", "Lakeland Terrier", "Sealyham Terrier",
            "Airedale", "Cairn", "Australian Terrier", "Dandi Dinmont", "Boston Bull",
            "Miniature Schnauzer", "Giant Schnauzer", "Standard Schnauzer", "Scotch Terrier", "Tibetan Terrier",
            "Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", "Lhasa", "Flat-coated Retriever",
            "Curly-coater Retriever", "Golden Retriever", "Labrador Retriever", "Chesapeake Bay Retriever", "German Short-haired Pointer",
            "Vizsla", "English Setter", "Irish Setter", "Gordon Setter", "Brittany",
            "Clumber", "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel",
            "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael", "Malinois",
            "Briard", "Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog",
            "Collie", "Border Collie", "Bouvier des Flandres", "Rottweiler", "German Shepard",
            "Doberman", "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", "Appenzeller",
            "EntleBucher", "Boxer", "Bull Mastiff", "Tibetan Mastiff", "French Bulldog",
            "Great Dane", "Saint Bernard", "Eskimo Dog", "Malamute", "Siberian Husky",
            "Affenpinscher", "Basenji", "Pug", "Leonberg", "Newfoundland",
            "Great Pyrenees", "Samoyed", "Pomeranian", "Chow", "Keeshond",
            "Brabancon Griffon", "Pembroke", "Cardigan", "Toy Poodle", "Miniature Poodle",
            "Standard Poodle", "Mexican Hairless", "Dingo", "Dhole", "African Hunting Dog"]

    def __len__(self):
        return len(self._flat_breed_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the index of the target class.
        """
        image_name, target_class = self._flat_breed_images[index]
        image_path = join(self.images_folder, image_name)
        image = Image.open(image_path).convert('RGB')

        if self.cropped:
            image = image.crop(self._flat_breed_annotations[index][1])

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            target_class = self.target_transform(target_class)

        return image, target_class

    def download(self):
        import tarfile

        if os.path.exists(join(self.root, 'Images')) and os.path.exists(join(self.root, 'Annotation')):
            if len(os.listdir(join(self.root, 'Images'))) == len(os.listdir(join(self.root, 'Annotation'))) == 120:
                print('Files already downloaded and verified')
                return

        for filename in ['images', 'annotation', 'lists']:
            tar_filename = filename + '.tar'
            url = self.download_url_prefix + '/' + tar_filename
            download_url(url, self.root, tar_filename, None)
            print('Extracting downloaded file: ' + join(self.root, tar_filename))
            with tarfile.open(join(self.root, tar_filename), 'r') as tar_file:
                tar_file.extractall(self.root)
            os.remove(join(self.root, tar_filename))

    @staticmethod
    def get_boxes(path):
        import xml.etree.ElementTree
        e = xml.etree.ElementTree.parse(path).getroot()
        boxes = []
        for objs in e.iter('object'):
            boxes.append([int(objs.find('bndbox').find('xmin').text),
                          int(objs.find('bndbox').find('ymin').text),
                          int(objs.find('bndbox').find('xmax').text),
                          int(objs.find('bndbox').find('ymax').text)])
        return boxes

    def load_split(self):
        if self.train:
            split = scipy.io.loadmat(join(self.root, 'train_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'train_list.mat'))['labels']
        else:
            split = scipy.io.loadmat(join(self.root, 'test_list.mat'))['annotation_list']
            labels = scipy.io.loadmat(join(self.root, 'test_list.mat'))['labels']

        split = [item[0][0] for item in split]
        labels = [item[0] - 1 for item in labels]
        return list(zip(split, labels))

    def stats(self):
        counts = {}
        for index in range(len(self._flat_breed_images)):
            image_name, target_class = self._flat_breed_images[index]
            if target_class not in counts.keys():
                counts[target_class] = 1
            else:
                counts[target_class] += 1

        print("%d samples spanning %d classes (avg %f per class)" % (len(self._flat_breed_images),
                                                                     len(counts.keys()),
                                                                     float(len(self._flat_breed_images)) / float(len(counts.keys()))))
        return counts
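

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# Builds the Stanford Dogs training split; the root layout (Images/, Annotation/,
# train_list.mat, test_list.mat) and the image size are assumptions.
def _example_dogs_usage(root='./stanford_dogs'):
    from torchvision import transforms

    transform = transforms.Compose([transforms.Resize((224, 224)),
                                    transforms.ToTensor()])
    trainset = dogs(root=root, train=True, cropped=False,
                    transform=transform, download=False)
    trainset.stats()  # prints sample/class count summary
    image, target = trainset[0]
    return image.shape, target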


class NABirds(Dataset):
    """`NABirds <https://dl.allaboutbirds.org/nabirds>`_ Dataset.

    Args:
        root (string): Root directory of the dataset (containing the ``nabirds`` folder).
        train (bool, optional): If True, creates dataset from the training set, otherwise
            from the test set.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
    """
    base_folder = 'nabirds/images'

    def __init__(self, root, train=True, transform=None):
        dataset_path = os.path.join(root, 'nabirds')
        self.root = root
        self.loader = default_loader
        self.train = train
        self.transform = transform

        image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                                  sep=' ', names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        # Since the raw labels are non-continuous, map them to new ones
        self.label_map = get_continuous_class_map(image_class_labels['target'])
        train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])
        data = image_paths.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')
        # Load in the train / test split
        if self.train:
            self.data = self.data[self.data.is_training_img == 1]
        else:
            self.data = self.data[self.data.is_training_img == 0]

        # Load in the class data
        self.class_names = load_class_names(dataset_path)
        self.class_hierarchy = load_hierarchy(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = self.label_map[sample.target]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        return img, target
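

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# Assumes the NABirds archive was unpacked to <root>/nabirds/ so that images.txt,
# image_class_labels.txt, train_test_split.txt, classes.txt and hierarchy.txt exist there.
def _example_nabirds_usage(root='./data'):
    from torchvision import transforms

    transform = transforms.Compose([transforms.Resize((224, 224)),
                                    transforms.ToTensor()])
    trainset = NABirds(root=root, train=True, transform=transform)
    img, target = trainset[0]
    # class_names maps raw class ids (strings) to human-readable names
    return img.shape, target, len(trainset.class_names)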


def get_continuous_class_map(class_labels):
    label_set = set(class_labels)
    return {k: i for i, k in enumerate(label_set)}


def load_class_names(dataset_path=''):
    names = {}

    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])

    return names


def load_hierarchy(dataset_path=''):
    parents = {}

    with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            child_id, parent_id = pieces
            parents[child_id] = parent_id

    return parents


class INat2017(VisionDataset):
    """`iNaturalist 2017 <https://github.com/visipedia/inat_comp/blob/master/2017/README.md>`_ Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train`` or ``val``.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in the root directory. If the dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'train_val_images/'
    file_list = {
        'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
                 'train_val_images.tar.gz',
                 '7c784ea5e424efaec655bd392f87301f'),
        'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
                  'train_val2017.zip',
                  '444c835f6459867ad69fcb36478786e7')
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
                    and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
                print('Downloading...')
                self._download()
            print('Extracting...')
            extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
            extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')
        anno_filename = split + '2017.json'
        with open(os.path.join(self.root, anno_filename), 'r') as fp:
            all_annos = json.load(fp)

        self.annos = all_annos['annotations']
        self.images = all_annos['images']

    def __getitem__(self, index):
        path = os.path.join(self.root, self.images[index]['file_name'])
        target = self.annos[index]['category_id']

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return image, target

    def __len__(self):
        return len(self.images)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.base_folder))

    def _download(self):
        for url, filename, md5 in self.file_list.values():
            download_url(url, root=self.root, filename=filename)
            if not check_integrity(os.path.join(self.root, filename), md5):
                raise RuntimeError("File not found or corrupted.")
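

# --- Hedged usage sketch (added for illustration; not part of the original file) ---
# Loads the iNaturalist 2017 validation split; assumes the archives were already
# downloaded and extracted into root (or pass download=True). Image size is a placeholder.
def _example_inat2017_usage(root='./inat2017'):
    from torchvision import transforms

    transform = transforms.Compose([transforms.Resize((304, 304)),
                                    transforms.CenterCrop(299),
                                    transforms.ToTensor()])
    valset = INat2017(root=root, split='val', transform=transform, download=False)
    image, target = valset[0]
    return image.shape, target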