import os import numpy as np import math import random import torchvision import torchvision.transforms as transforms import torchvision.datasets as datasets import pickle from tqdm import tqdm import time import multiprocessing TRAINER_NUMS = int(os.getenv("PADDLE_TRAINER_NUM", "1")) TRAINER_ID = int(os.getenv("PADDLE_TRAINER_ID", "0")) epoch = 0 class ImageFolder(object): def __init__(self, root, transforms): pass FINISH_EVENT = "FINISH_EVENT" class PaddleDataLoader(object): def __init__(self, torch_dataset, indices=None, concurrent=16, queue_size=3072): self.torch_dataset = torch_dataset self.data_queue = multiprocessing.Queue(queue_size) self.indices = indices self.concurrent = concurrent def _worker_loop(self, dataset, worker_indices, worker_id): cnt = 0 for idx in worker_indices: cnt += 1 img, label = self.torch_dataset[idx] img = np.array(img).astype('uint8').transpose((2, 0, 1)) self.data_queue.put((img, label)) print("worker: [%d] read [%d] samples. " % (worker_id, cnt)) self.data_queue.put(FINISH_EVENT) def reader(self): def _reader_creator(): worker_processes = [] total_img = len(self.torch_dataset) print("total image: ", total_img) if self.indices is None: self.indices = [i for i in xrange(total_img)] random.seed(time.time()) random.shuffle(self.indices) print("shuffle indices: %s ..." % self.indices[:10]) imgs_per_worker = int(math.ceil(total_img / self.concurrent)) for i in xrange(self.concurrent): start = i * imgs_per_worker end = (i + 1) * imgs_per_worker if i != self.concurrent - 1 else None sliced_indices = self.indices[start:end] w = multiprocessing.Process( target=self._worker_loop, args=(self.torch_dataset, sliced_indices, i) ) w.daemon = True w.start() worker_processes.append(w) finish_workers = 0 worker_cnt = len(worker_processes) while finish_workers < worker_cnt: sample = self.data_queue.get() if sample == FINISH_EVENT: finish_workers += 1 else: yield sample return _reader_creator def train(traindir, sz, min_scale=0.08): train_tfms = [ transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)), transforms.RandomHorizontalFlip() ] train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms)) return PaddleDataLoader(train_dataset).reader() def test(valdir, bs, sz, rect_val=False): if rect_val: idx_ar_sorted = sort_ar(valdir) idx_sorted, _ = zip(*idx_ar_sorted) idx2ar = map_idx2ar(idx_ar_sorted, bs) ar_tfms = [transforms.Resize(int(sz* 1.14)), CropArTfm(idx2ar, sz)] val_dataset = ValDataset(valdir, transform=ar_tfms) return PaddleDataLoader(val_dataset, concurrent=1, indices=idx_sorted).reader() val_tfms = [transforms.Resize(int(sz* 1.14)), transforms.CenterCrop(sz)] val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms)) return PaddleDataLoader(val_dataset).reader() class ValDataset(datasets.ImageFolder): def __init__(self, root, transform=None, target_transform=None): super(ValDataset, self).__init__(root, transform, target_transform) def __getitem__(self, index): path, target = self.imgs[index] sample = self.loader(path) if self.transform is not None: for tfm in self.transform: if isinstance(tfm, CropArTfm): sample = tfm(sample, index) else: sample = tfm(sample) if self.target_transform is not None: target = self.target_transform(target) return sample, target class CropArTfm(object): def __init__(self, idx2ar, target_size): self.idx2ar, self.target_size = idx2ar, target_size def __call__(self, img, idx): target_ar = self.idx2ar[idx] if target_ar < 1: w = int(self.target_size / target_ar) size = (w // 8 * 8, self.target_size) else: h = int(self.target_size * target_ar) size = (self.target_size, h // 8 * 8) return transforms.functional.center_crop(img, size) def sort_ar(valdir): idx2ar_file = valdir + '/../sorted_idxar.p' if os.path.isfile(idx2ar_file): return pickle.load(open(idx2ar_file, 'rb')) print('Creating AR indexes. Please be patient this may take a couple minutes...') val_dataset = datasets.ImageFolder(valdir) # AS: TODO: use Image.open instead of looping through dataset sizes = [img[0].size for img in tqdm(val_dataset, total=len(val_dataset))] idx_ar = [(i, round(s[0] * 1.0/ s[1], 5)) for i, s in enumerate(sizes)] sorted_idxar = sorted(idx_ar, key=lambda x: x[1]) pickle.dump(sorted_idxar, open(idx2ar_file, 'wb')) print('Done') return sorted_idxar def chunks(l, n): n = max(1, n) return (l[i:i + n] for i in range(0, len(l), n)) def map_idx2ar(idx_ar_sorted, batch_size): ar_chunks = list(chunks(idx_ar_sorted, batch_size)) idx2ar = {} for chunk in ar_chunks: idxs, ars = list(zip(*chunk)) mean = round(np.mean(ars), 5) for idx in idxs: idx2ar[idx] = mean return idx2ar if __name__ == "__main__": #ds, sampler = create_validation_set("/data/imagenet/validation", 128, 288, True, True) #for item in sampler: # for idx in item: # ds[idx] import time test_reader = test(valdir="/data/imagenet/validation", bs=50, sz=288, rect_val=True) start_ts = time.time() for idx, data in enumerate(test_reader()): print(idx, data[0].shape, data[1]) if idx == 10: break if (idx + 1) % 1000 == 0: cost = (time.time() - start_ts) print("%d samples per second" % (1000 / cost)) start_ts = time.time()