Files
ieemoo-ai-searchv2/nts/core/dataset.py
2022-11-22 15:32:06 +08:00

78 lines
3.3 KiB
Python

import numpy as np
import scipy.misc
import os
from PIL import Image
from torchvision import transforms
from config import INPUT_SIZE
class CUB():
    """In-memory image classification dataset read from a CUB-style directory.

    The directory ``root`` is expected to contain three whitespace-separated
    index files whose rows line up one-to-one, plus an ``images`` folder:

    * ``images.txt``             -- last column is the image path relative
      to ``root/images``.
    * ``image_class_labels.txt`` -- last column is the 1-based class id
      (stored 0-based on the instance).
    * ``train_test_split.txt``   -- last column is non-zero for training
      rows and zero for test rows.

    All images of the selected split are decoded eagerly in ``__init__`` and
    kept as numpy arrays; ``__getitem__`` applies the torchvision transforms.

    Args:
        root: Dataset directory.
        is_train: Load the training split when truthy, the test split
            otherwise.
        data_len: Optional cap on the number of samples kept (``None`` keeps
            every sample of the split).

    Attributes:
        train_img / train_label: Present only when ``is_train`` is truthy.
        test_img / test_label: Present only when ``is_train`` is falsy.
    """

    def __init__(self, root, is_train=True, data_len=None):
        self.root = root
        self.is_train = is_train

        img_name_list = self._read_last_column(
            os.path.join(self.root, 'images.txt'))
        # Class ids are 1-based on disk; shift them to 0-based.
        label_list = [
            int(value) - 1
            for value in self._read_last_column(
                os.path.join(self.root, 'image_class_labels.txt'))
        ]
        train_test_list = [
            int(value)
            for value in self._read_last_column(
                os.path.join(self.root, 'train_test_split.txt'))
        ]

        # One keep/skip flag per row for the requested split.
        if self.is_train:
            keep_flags = [bool(flag) for flag in train_test_list]
        else:
            keep_flags = [not flag for flag in train_test_list]

        file_list = [
            name for keep, name in zip(keep_flags, img_name_list) if keep
        ][:data_len]
        labels = [
            label for keep, label in zip(keep_flags, label_list) if keep
        ][:data_len]
        images = [
            self._load_image(os.path.join(self.root, 'images', name))
            for name in file_list
        ]

        if self.is_train:
            self.train_img = images
            self.train_label = labels
        else:
            self.test_img = images
            self.test_label = labels

    @staticmethod
    def _read_last_column(path):
        """Return the last whitespace-separated field of every non-blank line.

        The file is closed on exit, and a final line without a trailing
        newline is parsed intact (slicing off the last character would
        truncate it).
        """
        with open(path) as handle:
            return [line.split()[-1] for line in handle if line.strip()]

    @staticmethod
    def _load_image(path):
        """Decode the image at ``path`` into a numpy array using PIL.

        Replaces ``scipy.misc.imread``, which no longer exists in current
        SciPy releases.
        """
        with Image.open(path) as img:
            # np.array forces the pixel data to be read before the file closes.
            return np.array(img)

    def __getitem__(self, index):
        """Return ``(tensor, label)`` for sample ``index`` of the loaded split.

        Raises:
            IndexError: If ``index`` is out of range; this is also what ends
                plain ``for sample in dataset`` iteration.
        """
        if self.is_train:
            img, target = self.train_img[index], self.train_label[index]
        else:
            img, target = self.test_img[index], self.test_label[index]

        # Single-channel images are replicated to three channels.
        if len(img.shape) == 2:
            img = np.stack([img] * 3, 2)
        img = Image.fromarray(img, mode='RGB')
        img = transforms.Resize((600, 600), Image.BILINEAR)(img)
        if self.is_train:
            # Random crop + flip for training samples only.
            img = transforms.RandomCrop(INPUT_SIZE)(img)
            img = transforms.RandomHorizontalFlip()(img)
        else:
            # Deterministic crop for evaluation samples.
            img = transforms.CenterCrop(INPUT_SIZE)(img)
        img = transforms.ToTensor()(img)
        img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(img)
        return img, target

    def __len__(self):
        """Return the number of samples in the loaded split."""
        if self.is_train:
            return len(self.train_label)
        return len(self.test_label)
if __name__ == '__main__':
    # Manual smoke test: load the training split, then the test split, and
    # for each one print the image count, the label count, and every
    # sample's tensor size alongside its label.
    for split_name, train_flag in (('train', True), ('test', False)):
        dataset = CUB(root='./CUB_200_2011', is_train=train_flag)
        print(len(getattr(dataset, split_name + '_img')))
        print(len(getattr(dataset, split_name + '_label')))
        for sample in dataset:
            print(sample[0].size(), sample[1])