From b6aed0a940c92532b140eade54265b53f3b4e52c Mon Sep 17 00:00:00 2001
From: lilong12
Date: Thu, 16 May 2019 10:01:05 +0800
Subject: [PATCH] modify single queue to a queue per process for data
 processing (#2136)

---
 .../fast_imagenet/torchvision_reader.py | 63 +++++++++++++------
 1 file changed, 44 insertions(+), 19 deletions(-)

diff --git a/PaddleCV/image_classification/fast_imagenet/torchvision_reader.py b/PaddleCV/image_classification/fast_imagenet/torchvision_reader.py
index c1b0fb95..20500816 100644
--- a/PaddleCV/image_classification/fast_imagenet/torchvision_reader.py
+++ b/PaddleCV/image_classification/fast_imagenet/torchvision_reader.py
@@ -1,3 +1,4 @@
+from __future__ import division
 import os
 
 import numpy as np
@@ -17,28 +18,37 @@ import time
 import multiprocessing
 
 FINISH_EVENT = "FINISH_EVENT"
+
+
 class PaddleDataLoader(object):
-    def __init__(self, torch_dataset, indices=None, concurrent=24, queue_size=3072, shuffle=True, shuffle_seed=0):
+    def __init__(self,
+                 torch_dataset,
+                 indices=None,
+                 concurrent=24,
+                 queue_size=3072,
+                 shuffle=True,
+                 shuffle_seed=0):
         self.torch_dataset = torch_dataset
-        self.data_queue = multiprocessing.Queue(queue_size)
         self.indices = indices
         self.concurrent = concurrent
         self.shuffle = shuffle
-        self.shuffle_seed=shuffle_seed
+        self.shuffle_seed = shuffle_seed
+        self.queue_size = queue_size // self.concurrent
 
-    def _worker_loop(self, dataset, worker_indices, worker_id):
+    def _worker_loop(self, queue, worker_indices, worker_id):
         cnt = 0
         for idx in worker_indices:
             cnt += 1
             img, label = self.torch_dataset[idx]
             img = np.array(img).astype('uint8').transpose((2, 0, 1))
-            self.data_queue.put((img, label))
+            queue.put((img, label))
         print("worker: [%d] read [%d] samples. " % (worker_id, cnt))
-        self.data_queue.put(FINISH_EVENT)
+        queue.put(FINISH_EVENT)
 
     def reader(self):
         def _reader_creator():
             worker_processes = []
+            index_queues = []
             total_img = len(self.torch_dataset)
             print("total image: ", total_img)
             if self.shuffle:
@@ -50,19 +60,25 @@ class PaddleDataLoader(object):
             imgs_per_worker = int(math.ceil(total_img / self.concurrent))
             for i in xrange(self.concurrent):
                 start = i * imgs_per_worker
-                end = (i + 1) * imgs_per_worker if i != self.concurrent - 1 else None
+                end = (i + 1
+                       ) * imgs_per_worker if i != self.concurrent - 1 else None
                 sliced_indices = self.indices[start:end]
+                index_queue = multiprocessing.Queue(self.queue_size)
                 w = multiprocessing.Process(
                     target=self._worker_loop,
-                    args=(self.torch_dataset, sliced_indices, i)
-                )
+                    args=(index_queue, sliced_indices, i))
                 w.daemon = True
                 w.start()
                 worker_processes.append(w)
+                index_queues.append(index_queue)
             finish_workers = 0
             worker_cnt = len(worker_processes)
+            recv_index = 0
             while finish_workers < worker_cnt:
-                sample = self.data_queue.get()
+                while (index_queues[recv_index].empty()):
+                    recv_index = (recv_index + 1) % self.concurrent
+                sample = index_queues[recv_index].get()
+                recv_index = (recv_index + 1) % self.concurrent
                 if sample == FINISH_EVENT:
                     finish_workers += 1
                 else:
@@ -70,25 +86,30 @@ class PaddleDataLoader(object):
 
         return _reader_creator
 
+
 def train(traindir, sz, min_scale=0.08, shuffle_seed=0):
     train_tfms = [
-        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
-        transforms.RandomHorizontalFlip()
+        transforms.RandomResizedCrop(
+            sz, scale=(min_scale, 1.0)), transforms.RandomHorizontalFlip()
     ]
-    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
+    train_dataset = datasets.ImageFolder(traindir,
+                                         transforms.Compose(train_tfms))
     return PaddleDataLoader(train_dataset, shuffle_seed=shuffle_seed).reader()
 
+
 def test(valdir, bs, sz, rect_val=False):
     if rect_val:
         idx_ar_sorted = sort_ar(valdir)
         idx_sorted, _ = zip(*idx_ar_sorted)
         idx2ar = map_idx2ar(idx_ar_sorted, bs)
 
-        ar_tfms = [transforms.Resize(int(sz* 1.14)), CropArTfm(idx2ar, sz)]
+        ar_tfms = [transforms.Resize(int(sz * 1.14)), CropArTfm(idx2ar, sz)]
         val_dataset = ValDataset(valdir, transform=ar_tfms)
-        return PaddleDataLoader(val_dataset, concurrent=1, indices=idx_sorted, shuffle=False).reader()
+        return PaddleDataLoader(
+            val_dataset, concurrent=1, indices=idx_sorted,
+            shuffle=False).reader()
 
-    val_tfms = [transforms.Resize(int(sz* 1.14)), transforms.CenterCrop(sz)]
+    val_tfms = [transforms.Resize(int(sz * 1.14)), transforms.CenterCrop(sz)]
     val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))
     return PaddleDataLoader(val_dataset).reader()
 
@@ -132,15 +153,19 @@ def sort_ar(valdir):
     idx2ar_file = valdir + '/../sorted_idxar.p'
     if os.path.isfile(idx2ar_file):
         return pickle.load(open(idx2ar_file, 'rb'))
-    print('Creating AR indexes. Please be patient this may take a couple minutes...')
-    val_dataset = datasets.ImageFolder(valdir)  # AS: TODO: use Image.open instead of looping through dataset
+    print(
+        'Creating AR indexes. Please be patient this may take a couple minutes...'
+    )
+    val_dataset = datasets.ImageFolder(
+        valdir)  # AS: TODO: use Image.open instead of looping through dataset
     sizes = [img[0].size for img in tqdm(val_dataset, total=len(val_dataset))]
-    idx_ar = [(i, round(s[0] * 1.0/ s[1], 5)) for i, s in enumerate(sizes)]
+    idx_ar = [(i, round(s[0] * 1.0 / s[1], 5)) for i, s in enumerate(sizes)]
     sorted_idxar = sorted(idx_ar, key=lambda x: x[1])
     pickle.dump(sorted_idxar, open(idx2ar_file, 'wb'))
     print('Done')
     return sorted_idxar
 
+
 def chunks(l, n):
     n = max(1, n)
     return (l[i:i + n] for i in range(0, len(l), n))
-- 
GitLab
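
Reviewer note, not part of the patch: below is a minimal, self-contained sketch of the queue-per-worker pattern this diff introduces, stripped of the PaddleDataLoader plumbing. The dataset is a stand-in list, and names such as _worker_loop, reader, concurrent, and queue_size merely echo the patched file; they are illustrative, not the project's API.

import math
import multiprocessing

FINISH_EVENT = "FINISH_EVENT"


def _worker_loop(queue, worker_indices, dataset):
    # Each worker owns its queue outright, so producers never contend
    # for the lock of one shared queue.
    for idx in worker_indices:
        queue.put(dataset[idx])
    queue.put(FINISH_EVENT)


def reader(dataset, concurrent=4, queue_size=64):
    indices = list(range(len(dataset)))
    per_worker = int(math.ceil(len(indices) / float(concurrent)))
    queues, workers = [], []
    for i in range(concurrent):
        sliced = indices[i * per_worker:(i + 1) * per_worker]
        # Split the total buffer budget evenly across workers, as the
        # patch does with queue_size // self.concurrent.
        q = multiprocessing.Queue(queue_size // concurrent)
        w = multiprocessing.Process(
            target=_worker_loop, args=(q, sliced, dataset))
        w.daemon = True
        w.start()
        queues.append(q)
        workers.append(w)
    # Round-robin consumer, mirroring the patch: spin past queues that
    # are momentarily empty and stop once every worker has delivered
    # its FINISH_EVENT sentinel.
    finished, recv = 0, 0
    while finished < concurrent:
        while queues[recv].empty():
            recv = (recv + 1) % concurrent
        sample = queues[recv].get()
        recv = (recv + 1) % concurrent
        if sample == FINISH_EVENT:
            finished += 1
        else:
            yield sample


if __name__ == '__main__':
    data = list(range(100))  # stand-in for the torchvision dataset
    assert sorted(reader(data)) == data

One caveat of the pattern worth knowing: multiprocessing.Queue.empty() is documented as unreliable, and the inner while loop busy-polls, so the consumer can burn CPU while all workers are briefly idle. The patch accepts that cost in exchange for removing contention on a single shared queue among its default of 24 producer processes.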