"""Dataset wrappers for fine-grained image classification benchmarks.

Provides loaders for CUB-200-2011 style folders (``CUB``, ``emptyJudge``),
Stanford Cars (``CarsDataset``), Stanford Dogs (``dogs``), NABirds
(``NABirds``) and iNaturalist 2017 (``INat2017``).
"""
import collections
import itertools
import json
import os
import tarfile
import xml.etree.ElementTree as ET
from os.path import join

import numpy as np
import scipy
from scipy import io

try:
    # No longer used here: scipy.misc.imread was removed in SciPy 1.2 and the
    # scipy.misc module is deprecated in recent SciPy releases. The import is
    # kept (guarded) only for code that relied on this module importing it.
    import scipy.misc  # noqa: F401
except ImportError:
    pass
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
from torchvision.datasets import VisionDataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url, list_dir, check_integrity, extract_archive, verify_str_arg


def _read_last_column(path, convert=str):
    """Return ``convert(last space-separated token)`` for every non-blank line of ``path``."""
    values = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if line:  # tolerate blank lines, e.g. a trailing empty line
                values.append(convert(line.split(' ')[-1]))
    return values


def _imread(path):
    """Load an image file into a numpy array (stand-in for the removed ``scipy.misc.imread``).

    Bilevel (``'1'``) images become 2-D ``'L'`` arrays and grayscale images stay
    2-D; every other mode (palette, RGBA, CMYK, ...) is converted to RGB. The
    result is therefore either ``(H, W)`` or ``(H, W, 3)`` with dtype uint8.
    """
    with Image.open(path) as im:
        if im.mode == '1':
            im = im.convert('L')
        elif im.mode not in ('L', 'RGB'):
            im = im.convert('RGB')
        return np.array(im)


def _array_to_rgb_image(img):
    """Convert an ``(H, W)`` or ``(H, W, C)`` array into an RGB PIL image."""
    if img.ndim == 2:
        # Grayscale: replicate the single channel three times.
        img = np.stack([img] * 3, axis=2)
    image = Image.fromarray(img)
    return image if image.mode == 'RGB' else image.convert('RGB')


class _CUBStyleDataset(object):
    """In-memory dataset for a CUB-200-2011 style directory.

    ``root`` must contain ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` (``<img_id> <value>`` per line; the three files
    are matched line by line) plus an ``images/`` folder. Labels are shifted
    from 1-based to 0-based. Every image of the selected split is decoded
    eagerly in ``__init__``.

    Args:
        root (string): Dataset root directory.
        is_train (bool): If true, load the training split (split flag != 0),
            otherwise the test split (split flag == 0).
        data_len (int, optional): Keep only the first ``data_len`` samples of
            the split; ``None`` keeps all of them.
        transform (callable, optional): Applied to the RGB PIL image.
    """

    def __init__(self, root, is_train=True, data_len=None, transform=None):
        self.root = root
        self.is_train = is_train
        self.transform = transform

        img_names = _read_last_column(os.path.join(self.root, 'images.txt'))
        # Labels are stored 1-based on disk; shift them to 0-based.
        labels = [label - 1 for label in
                  _read_last_column(os.path.join(self.root, 'image_class_labels.txt'), int)]
        split_flags = _read_last_column(os.path.join(self.root, 'train_test_split.txt'), int)

        want_train = bool(self.is_train)
        selected = [(name, label)
                    for flag, name, label in zip(split_flags, img_names, labels)
                    if bool(flag) == want_train][:data_len]
        names = [name for name, _ in selected]
        targets = [label for _, label in selected]
        images = [_imread(os.path.join(self.root, 'images', name)) for name in names]

        if self.is_train:
            self.train_img, self.train_label, self.train_imgname = images, targets, names
        else:
            self.test_img, self.test_label, self.test_imgname = images, targets, names

    def __getitem__(self, index):
        """Return ``(image, target)``; ``image`` is an RGB PIL image unless ``transform`` changes it."""
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
        else:
            img, target = self.test_img[index], self.test_label[index]
        img = _array_to_rgb_image(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self):
        return len(self.train_label) if self.is_train else len(self.test_label)


class emptyJudge(_CUBStyleDataset):
    """CUB-200-2011 style dataset; same directory layout and behavior as ``CUB``."""


class CUB(_CUBStyleDataset):
    """CUB-200-2011 dataset loaded from its standard directory layout."""


class CarsDataset(Dataset):
    """Stanford Cars dataset driven by MATLAB annotation files.

    Each annotation record is read from the end: ``[-1][0]`` is the image file
    name and ``[-2][0][0]`` the 1-based class id.
    """

    def __init__(self, mat_anno, data_dir, car_names, cleaned=None, transform=None):
        """
        Args:
            mat_anno (string): Path to the MATLAB annotation file.
            data_dir (string): Directory with all the images.
            car_names (string): Path to the MATLAB file whose ``class_names``
                entry holds the class names.
            cleaned (string, optional): Path to a text file listing the image
                file names to keep; annotations of other images are dropped.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.full_data_set = io.loadmat(mat_anno)
        self.car_annotations = self.full_data_set['annotations'][0]

        if cleaned is not None:
            print("Cleaning up data set (only take pics with rgb chans)...")
            # A set gives O(1) membership instead of scanning the array per annotation.
            clean_files = set(np.atleast_1d(np.loadtxt(cleaned, dtype=str)).ravel().tolist())
            self.car_annotations = [c for c in self.car_annotations if c[-1][0] in clean_files]

        self.car_names = np.array(scipy.io.loadmat(car_names)['class_names'][0])
        self.data_dir = data_dir
        self.transform = transform

    def __len__(self):
        return len(self.car_annotations)

    def __getitem__(self, idx):
        """Return ``(image, car_class)`` where ``car_class`` is a 0-dim 0-based long tensor."""
        annotation = self.car_annotations[idx]
        img_name = os.path.join(self.data_dir, annotation[-1][0])
        with Image.open(img_name) as raw:
            image = raw.convert('RGB')

        # Class ids are 1-based in the annotation file; shift to 0-based.
        car_class = torch.tensor(int(annotation[-2][0][0]) - 1, dtype=torch.long)
        num_classes = len(self.car_names)
        if not 0 <= int(car_class) < num_classes:
            raise ValueError('class index %d out of range [0, %d) for %s'
                             % (int(car_class), num_classes, img_name))

        if self.transform:
            image = self.transform(image)

        return image, car_class

    def map_class(self, id):
        """Return the class name for the 1-based class id ``id``."""
        index = np.ravel(id) - 1
        return self.car_names[index][0][0]

    def show_batch(self, img_batch, class_batch):
        """Plot a batch of channel-first images side by side, titled with their class names."""
        batch_size = img_batch.shape[0]
        for i in range(batch_size):
            ax = plt.subplot(1, batch_size, i + 1)
            # NOTE(review): __getitem__ yields 0-based labels while map_class expects
            # 1-based ids; if class_batch comes straight from this dataset the titles
            # are shifted by one class -- confirm with callers.
            title_str = self.map_class(int(class_batch[i]))
            img = np.transpose(img_batch[i, ...], (1, 2, 0))
            ax.imshow(img)
            ax.set_title(str(title_str), {'fontsize': 5})
        plt.tight_layout()


def make_dataset(dir, image_ids, targets):
    """Pair each image id with its target as ``(<dir>/data/images/<id>.jpg, target)``.

    Raises:
        ValueError: if ``image_ids`` and ``targets`` differ in length.
    """
    if len(image_ids) != len(targets):
        raise ValueError('image_ids and targets must have the same length (%d != %d)'
                         % (len(image_ids), len(targets)))
    root = os.path.expanduser(dir)
    return [(os.path.join(root, 'data', 'images', '%s.jpg' % image_id), target)
            for image_id, target in zip(image_ids, targets)]


def find_classes(classes_file):
    """Parse a ``<image_id> <class name ...>`` file.

    Returns:
        tuple: ``(image_ids, targets, classes, class_to_idx)``. ``classes`` holds the
        sorted unique class names (without the trailing newline), ``class_to_idx``
        maps each name to its index, and ``targets`` gives that index per image.
    """
    image_ids = []
    targets = []
    with open(classes_file, 'r') as f:
        for line in f:
            line = line.rstrip('\r\n')
            if not line:
                continue
            split_line = line.split(' ')
            image_ids.append(split_line[0])
            targets.append(' '.join(split_line[1:]))

    # Index class names.
    classes = np.unique(targets)
    class_to_idx = {name: i for i, name in enumerate(classes)}
    targets = [class_to_idx[c] for c in targets]

    return (image_ids, targets, classes, class_to_idx)


class dogs(Dataset):
    """Stanford Dogs Dataset (http://vision.stanford.edu/aditya86/ImageNetDogs).

    Args:
        root (string): Root directory of the dataset; it should contain
            ``Images/``, ``Annotation/``, ``train_list.mat`` and ``test_list.mat``.
        train (bool, optional): If true, use the training split, otherwise the test split.
        cropped (bool, optional): If true, the images will be cropped into the bounding box
            specified in the annotations (one sample per box).
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset tar files from the internet and
            puts it in root directory. If the tar files are already downloaded, they are not
            downloaded again.
    """
    folder = 'dog'
    download_url_prefix = 'http://vision.stanford.edu/aditya86/ImageNetDogs'

    def __init__(self, root, train=True, cropped=False, transform=None, target_transform=None, download=False):
        self.root = root
        self.train = train
        self.cropped = cropped
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        split = self.load_split()

        self.images_folder = join(self.root, 'Images')
        self.annotations_folder = join(self.root, 'Annotation')
        self._breeds = list_dir(self.images_folder)

        if self.cropped:
            # One entry per bounding box, so an image with several boxes yields several samples.
            self._breed_annotations = [
                [(annotation, box, idx) for box in self.get_boxes(join(self.annotations_folder, annotation))]
                for annotation, idx in split]
            self._flat_breed_annotations = list(itertools.chain.from_iterable(self._breed_annotations))
            self._flat_breed_images = [(annotation + '.jpg', idx)
                                       for annotation, _, idx in self._flat_breed_annotations]
        else:
            self._breed_images = [(annotation + '.jpg', idx) for annotation, idx in split]
            self._flat_breed_images = self._breed_images

        self.classes = ["Chihuaha",
                        "Japanese Spaniel",
                        "Maltese Dog",
                        "Pekinese",
                        "Shih-Tzu",
                        "Blenheim Spaniel",
                        "Papillon",
                        "Toy Terrier",
                        "Rhodesian Ridgeback",
                        "Afghan Hound",
                        "Basset Hound",
                        "Beagle",
                        "Bloodhound",
                        "Bluetick",
                        "Black-and-tan Coonhound",
                        "Walker Hound",
                        "English Foxhound",
                        "Redbone",
                        "Borzoi",
                        "Irish Wolfhound",
                        "Italian Greyhound",
                        "Whippet",
                        "Ibizian Hound",
                        "Norwegian Elkhound",
                        "Otterhound",
                        "Saluki",
                        "Scottish Deerhound",
                        "Weimaraner",
                        "Staffordshire Bullterrier",
                        "American Staffordshire Terrier",
                        "Bedlington Terrier",
                        "Border Terrier",
                        "Kerry Blue Terrier",
                        "Irish Terrier",
                        "Norfolk Terrier",
                        "Norwich Terrier",
                        "Yorkshire Terrier",
                        "Wirehaired Fox Terrier",
                        "Lakeland Terrier",
                        "Sealyham Terrier",
                        "Airedale",
                        "Cairn",
                        "Australian Terrier",
                        "Dandi Dinmont",
                        "Boston Bull",
                        "Miniature Schnauzer",
                        "Giant Schnauzer",
                        "Standard Schnauzer",
                        "Scotch Terrier",
                        "Tibetan Terrier",
                        "Silky Terrier",
                        "Soft-coated Wheaten Terrier",
                        "West Highland White Terrier",
                        "Lhasa",
                        "Flat-coated Retriever",
                        "Curly-coater Retriever",
                        "Golden Retriever",
                        "Labrador Retriever",
                        "Chesapeake Bay Retriever",
                        "German Short-haired Pointer",
                        "Vizsla",
                        "English Setter",
                        "Irish Setter",
                        "Gordon Setter",
                        "Brittany",
                        "Clumber",
                        "English Springer Spaniel",
                        "Welsh Springer Spaniel",
                        "Cocker Spaniel",
                        "Sussex Spaniel",
                        "Irish Water Spaniel",
                        "Kuvasz",
                        "Schipperke",
                        "Groenendael",
                        "Malinois",
                        "Briard",
                        "Kelpie",
                        "Komondor",
                        "Old English Sheepdog",
                        "Shetland Sheepdog",
                        "Collie",
                        "Border Collie",
                        "Bouvier des Flandres",
                        "Rottweiler",
                        "German Shepard",
                        "Doberman",
                        "Miniature Pinscher",
                        "Greater Swiss Mountain Dog",
                        "Bernese Mountain Dog",
                        "Appenzeller",
                        "EntleBucher",
                        "Boxer",
                        "Bull Mastiff",
                        "Tibetan Mastiff",
                        "French Bulldog",
                        "Great Dane",
                        "Saint Bernard",
                        "Eskimo Dog",
                        "Malamute",
                        "Siberian Husky",
                        "Affenpinscher",
                        "Basenji",
                        "Pug",
                        "Leonberg",
                        "Newfoundland",
                        "Great Pyrenees",
                        "Samoyed",
                        "Pomeranian",
                        "Chow",
                        "Keeshond",
                        "Brabancon Griffon",
                        "Pembroke",
                        "Cardigan",
                        "Toy Poodle",
                        "Miniature Poodle",
                        "Standard Poodle",
                        "Mexican Hairless",
                        "Dingo",
                        "Dhole",
                        "African Hunting Dog"]

    def __len__(self):
        return len(self._flat_breed_images)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where target is the 0-based breed index.
        """
        image_name, target_class = self._flat_breed_images[index]
        image_path = join(self.images_folder, image_name)
        with Image.open(image_path) as raw:
            image = raw.convert('RGB')

        if self.cropped:
            image = image.crop(self._flat_breed_annotations[index][1])

        if self.transform:
            image = self.transform(image)

        if self.target_transform:
            target_class = self.target_transform(target_class)

        return image, target_class

    def download(self):
        """Download and extract the images, annotation and list archives into ``root``."""
        images_dir = join(self.root, 'Images')
        annotations_dir = join(self.root, 'Annotation')
        if os.path.exists(images_dir) and os.path.exists(annotations_dir):
            if len(os.listdir(images_dir)) == len(os.listdir(annotations_dir)) == 120:
                print('Files already downloaded and verified')
                return

        for filename in ['images', 'annotation', 'lists']:
            tar_filename = filename + '.tar'
            url = self.download_url_prefix + '/' + tar_filename
            download_url(url, self.root, tar_filename, None)
            tar_path = join(self.root, tar_filename)
            print('Extracting downloaded file: ' + tar_path)
            with tarfile.open(tar_path, 'r') as tar_file:
                # The archives are fetched over plain HTTP: when this Python's tarfile
                # supports extraction filters, use 'data' to reject absolute paths,
                # '..' members and links escaping root.
                if hasattr(tarfile, 'data_filter'):
                    tar_file.extractall(self.root, filter='data')
                else:
                    tar_file.extractall(self.root)
            os.remove(tar_path)

    @staticmethod
    def get_boxes(path):
        """Return ``[xmin, ymin, xmax, ymax]`` for each ``<object>`` element of the XML file at ``path``."""
        root = ET.parse(path).getroot()
        boxes = []
        for obj in root.iter('object'):
            bndbox = obj.find('bndbox')
            boxes.append([int(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
        return boxes

    def load_split(self):
        """Return ``[(annotation_name, 0-based label), ...]`` for the selected split."""
        mat_name = 'train_list.mat' if self.train else 'test_list.mat'
        mat = scipy.io.loadmat(join(self.root, mat_name))  # load once, read both keys
        split = [item[0][0] for item in mat['annotation_list']]
        labels = [item[0] - 1 for item in mat['labels']]
        return list(zip(split, labels))

    def stats(self):
        """Print a sample/class summary and return ``{target_class: count}``."""
        counts = collections.Counter(target for _, target in self._flat_breed_images)
        num_samples = len(self._flat_breed_images)
        num_classes = len(counts)
        avg = float(num_samples) / float(num_classes) if num_classes else 0.0
        print("%d samples spanning %d classes (avg %f per class)" % (num_samples, num_classes, avg))
        return dict(counts)


class NABirds(Dataset):
    """NABirds Dataset.

    Args:
        root (string): Directory containing the ``nabirds`` dataset folder.
        train (bool, optional): If True, creates dataset from training set, otherwise
            creates from test set.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
    """
    base_folder = 'nabirds/images'

    def __init__(self, root, train=True, transform=None):
        dataset_path = os.path.join(root, 'nabirds')
        self.root = root
        self.loader = default_loader
        self.train = train
        self.transform = transform

        image_paths = pd.read_csv(os.path.join(dataset_path, 'images.txt'),
                                  sep=' ', names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(dataset_path, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        # Since the raw labels are non-continuous, map them to new ones.
        self.label_map = get_continuous_class_map(image_class_labels['target'])
        train_test_split = pd.read_csv(os.path.join(dataset_path, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])
        data = image_paths.merge(image_class_labels, on='img_id')
        self.data = data.merge(train_test_split, on='img_id')

        # Keep only the requested train / test split.
        wanted_flag = 1 if self.train else 0
        self.data = self.data[self.data.is_training_img == wanted_flag]

        # Load in the class data.
        self.class_names = load_class_names(dataset_path)
        self.class_hierarchy = load_hierarchy(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` where ``target`` is the remapped contiguous label."""
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = self.label_map[sample.target]
        img = self.loader(path)

        if self.transform is not None:
            img = self.transform(img)
        return img, target


def get_continuous_class_map(class_labels):
    """Map each distinct label to a contiguous index ``0..n-1`` in ascending label order.

    Sorting makes the mapping independent of set iteration order.
    """
    return {label: i for i, label in enumerate(sorted(set(class_labels)))}


def load_class_names(dataset_path=''):
    """Read ``classes.txt`` into ``{class_id (str): class name}``, skipping blank lines."""
    names = {}
    with open(os.path.join(dataset_path, 'classes.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            if not pieces:
                continue
            class_id = pieces[0]
            names[class_id] = ' '.join(pieces[1:])
    return names


def load_hierarchy(dataset_path=''):
    """Read ``hierarchy.txt`` into ``{child_id (str): parent_id (str)}``, skipping blank lines."""
    parents = {}
    with open(os.path.join(dataset_path, 'hierarchy.txt')) as f:
        for line in f:
            pieces = line.strip().split()
            if not pieces:
                continue
            child_id, parent_id = pieces
            parents[child_id] = parent_id
    return parents


class INat2017(VisionDataset):
    """iNaturalist 2017 Dataset.

    Args:
        root (string): Root directory of the dataset.
        split (string, optional): The dataset split, supports ``train``, or ``val``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """
    base_folder = 'train_val_images/'
    file_list = {
        'imgs': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val_images.tar.gz',
                 'train_val_images.tar.gz',
                 '7c784ea5e424efaec655bd392f87301f'),
        'annos': ('https://storage.googleapis.com/asia_inat_data/train_val/train_val2017.zip',
                  'train_val2017.zip',
                  '444c835f6459867ad69fcb36478786e7')
    }

    def __init__(self, root, split='train', transform=None, target_transform=None, download=False):
        super(INat2017, self).__init__(root, transform=transform, target_transform=target_transform)
        self.loader = default_loader
        self.split = verify_str_arg(split, "split", ("train", "val",))

        if self._check_exists():
            print('Files already downloaded and verified.')
        elif download:
            if not (os.path.exists(os.path.join(self.root, self.file_list['imgs'][1]))
                    and os.path.exists(os.path.join(self.root, self.file_list['annos'][1]))):
                print('Downloading...')
                self._download()
            print('Extracting...')
            extract_archive(os.path.join(self.root, self.file_list['imgs'][1]))
            extract_archive(os.path.join(self.root, self.file_list['annos'][1]))
        else:
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it.')

        anno_filename = self.split + '2017.json'
        with open(os.path.join(self.root, anno_filename), 'r') as fp:
            all_annos = json.load(fp)

        self.annos = all_annos['annotations']
        self.images = all_annos['images']
        # Resolve each image's category through its id instead of assuming the
        # 'annotations' and 'images' lists are stored in the same order.
        category_by_image = {anno['image_id']: anno['category_id'] for anno in self.annos}
        self._targets = [category_by_image[image['id']] for image in self.images]

    def __getitem__(self, index):
        """Return ``(image, target)`` where ``target`` is the image's ``category_id``."""
        path = os.path.join(self.root, self.images[index]['file_name'])
        target = self._targets[index]

        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return image, target

    def __len__(self):
        return len(self.images)

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.base_folder))

    def _download(self):
        for url, filename, md5 in self.file_list.values():
            download_url(url, root=self.root, filename=filename)
            if not check_integrity(os.path.join(self.root, filename), md5):
                raise RuntimeError("File not found or corrupted.")