import os

import numpy as np
import scipy.misc  # not used below; kept so the module's import surface is unchanged
from PIL import Image
from torchvision import transforms

from config import INPUT_SIZE


def _read_last_column(path):
    """Return the last whitespace-separated field of every non-blank line.

    The file is closed before returning.  Splitting on whitespace (instead of
    chopping the final character) keeps the value intact when the last line
    has no trailing newline or when the file uses CRLF line endings.
    """
    with open(path) as handle:
        return [line.split()[-1] for line in handle if line.strip()]


def _load_image(path):
    """Read the image at *path* into a numpy array using Pillow.

    Stands in for ``scipy.misc.imread``, which no longer exists in current
    SciPy releases.  The mode handling follows what that function did with
    its default arguments: palette images are expanded to RGB (or RGBA when
    the file carries transparency info) and 1-bit images become 8-bit
    grayscale; every other mode is converted to an array as-is.
    """
    with Image.open(path) as im:
        if im.mode == 'P':
            im = im.convert('RGBA' if 'transparency' in im.info else 'RGB')
        elif im.mode == '1':
            im = im.convert('L')
        return np.array(im)


class CUB():
    """In-memory image dataset read from a CUB-style directory.

    ``root`` must contain ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` (one record per line, value in the last column)
    plus an ``images`` directory.  All selected images are decoded eagerly in
    ``__init__``.

    Attributes set by ``__init__``:
        root, is_train: the constructor arguments.
        train_img / train_label: arrays and labels, only when ``is_train``.
        test_img / test_label: arrays and labels, only when not ``is_train``.
    Labels are the integers from the label file minus one.
    """

    def __init__(self, root, is_train=True, data_len=None):
        """Load the requested split.

        Args:
            root: dataset directory.
            is_train: load entries whose split flag is non-zero when true,
                entries whose flag is zero otherwise.
            data_len: keep only the first ``data_len`` entries of the split;
                ``None`` keeps all of them.
        """
        self.root = root
        self.is_train = is_train

        img_name_list = _read_last_column(os.path.join(self.root, 'images.txt'))
        label_list = [
            int(value) - 1
            for value in _read_last_column(
                os.path.join(self.root, 'image_class_labels.txt'))
        ]
        train_test_list = [
            int(value)
            for value in _read_last_column(
                os.path.join(self.root, 'train_test_split.txt'))
        ]

        # Pick the entries whose split flag matches the requested split.
        wanted = bool(self.is_train)
        file_list = [
            name for flag, name in zip(train_test_list, img_name_list)
            if bool(flag) == wanted
        ][:data_len]
        labels = [
            label for flag, label in zip(train_test_list, label_list)
            if bool(flag) == wanted
        ][:data_len]
        images = [
            _load_image(os.path.join(self.root, 'images', name))
            for name in file_list
        ]

        if self.is_train:
            self.train_img = images
            self.train_label = labels
        else:
            self.test_img = images
            self.test_label = labels

        # Build the preprocessing pipeline once instead of on every lookup.
        # Training uses a random crop plus random horizontal flip; evaluation
        # uses a centre crop.  Everything else is shared.
        if self.is_train:
            crop_steps = [
                transforms.RandomCrop(INPUT_SIZE),
                transforms.RandomHorizontalFlip(),
            ]
        else:
            crop_steps = [transforms.CenterCrop(INPUT_SIZE)]
        self._transform = transforms.Compose(
            [transforms.Resize((600, 600), Image.BILINEAR)]
            + crop_steps
            + [
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225]),
            ]
        )

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the entry at *index*.

        Raises ``IndexError`` for an out-of-range index (from the underlying
        list lookup), which is also what ends plain ``for`` iteration over
        the dataset.
        """
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
        else:
            img, target = self.test_img[index], self.test_label[index]

        # Two-dimensional (single-channel) arrays are replicated to three
        # channels so they can be wrapped as an RGB image.
        if img.ndim == 2:
            img = np.stack([img] * 3, 2)
        img = Image.fromarray(img, mode='RGB')
        return self._transform(img), target

    def __len__(self):
        """Return the number of entries in the loaded split."""
        if self.is_train:
            return len(self.train_label)
        return len(self.test_label)


if __name__ == '__main__':
    # Smoke test: load each split and print every sample's tensor size and label.
    dataset = CUB(root='./CUB_200_2011')
    print(len(dataset.train_img))
    print(len(dataset.train_label))
    for data in dataset:
        print(data[0].size(), data[1])
    dataset = CUB(root='./CUB_200_2011', is_train=False)
    print(len(dataset.test_img))
    print(len(dataset.test_label))
    for data in dataset:
        print(data[0].size(), data[1])