Commit b6aed0a9 authored by lilong12, committed by guru4elephant

modify single queue to a queue per process for data processing (#2136)

Parent 63fe2970
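The change replaces the loader's single shared multiprocessing.Queue with one queue per worker process: each worker fills only its own queue, and the reader drains the queues round-robin, skipping any queue that is momentarily empty rather than blocking on one shared queue. Below is a minimal, self-contained sketch of that pattern; the names (producer, consume, FINISH) are illustrative and not taken from the patch.

import multiprocessing

FINISH = "FINISH"  # sentinel, mirrors FINISH_EVENT in the patch


def producer(queue, items, worker_id):
    # Each worker writes only to its own queue, then signals completion.
    for item in items:
        queue.put((worker_id, item))
    queue.put(FINISH)


def consume(num_workers=4, items_per_worker=5):
    # One bounded queue per worker instead of a single shared queue.
    queues = [multiprocessing.Queue(8) for _ in range(num_workers)]
    for i in range(num_workers):
        w = multiprocessing.Process(
            target=producer, args=(queues[i], range(items_per_worker), i))
        w.daemon = True
        w.start()

    finished, recv = 0, 0
    while finished < num_workers:
        # Round-robin over the queues, skipping ones that are momentarily
        # empty, as the patched reader does.
        while queues[recv].empty():
            recv = (recv + 1) % num_workers
        sample = queues[recv].get()
        recv = (recv + 1) % num_workers
        if sample == FINISH:
            finished += 1
        else:
            yield sample


if __name__ == '__main__':
    for sample in consume():
        print(sample)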
+from __future__ import division
 import os
 import numpy as np
@@ -17,28 +18,37 @@ import time
 import multiprocessing

 FINISH_EVENT = "FINISH_EVENT"


 class PaddleDataLoader(object):
-    def __init__(self, torch_dataset, indices=None, concurrent=24, queue_size=3072, shuffle=True, shuffle_seed=0):
+    def __init__(self,
+                 torch_dataset,
+                 indices=None,
+                 concurrent=24,
+                 queue_size=3072,
+                 shuffle=True,
+                 shuffle_seed=0):
         self.torch_dataset = torch_dataset
-        self.data_queue = multiprocessing.Queue(queue_size)
         self.indices = indices
         self.concurrent = concurrent
         self.shuffle = shuffle
-        self.shuffle_seed=shuffle_seed
+        self.shuffle_seed = shuffle_seed
+        self.queue_size = queue_size // self.concurrent

-    def _worker_loop(self, dataset, worker_indices, worker_id):
+    def _worker_loop(self, queue, worker_indices, worker_id):
         cnt = 0
         for idx in worker_indices:
             cnt += 1
             img, label = self.torch_dataset[idx]
             img = np.array(img).astype('uint8').transpose((2, 0, 1))
-            self.data_queue.put((img, label))
+            queue.put((img, label))
         print("worker: [%d] read [%d] samples. " % (worker_id, cnt))
-        self.data_queue.put(FINISH_EVENT)
+        queue.put(FINISH_EVENT)

     def reader(self):
         def _reader_creator():
             worker_processes = []
+            index_queues = []
             total_img = len(self.torch_dataset)
             print("total image: ", total_img)
             if self.shuffle:
@@ -50,19 +60,25 @@ class PaddleDataLoader(object):
             imgs_per_worker = int(math.ceil(total_img / self.concurrent))
             for i in xrange(self.concurrent):
                 start = i * imgs_per_worker
-                end = (i + 1) * imgs_per_worker if i != self.concurrent - 1 else None
+                end = (i + 1
+                       ) * imgs_per_worker if i != self.concurrent - 1 else None
                 sliced_indices = self.indices[start:end]
+                index_queue = multiprocessing.Queue(self.queue_size)
                 w = multiprocessing.Process(
                     target=self._worker_loop,
-                    args=(self.torch_dataset, sliced_indices, i)
-                )
+                    args=(index_queue, sliced_indices, i))
                 w.daemon = True
                 w.start()
                 worker_processes.append(w)
+                index_queues.append(index_queue)
             finish_workers = 0
             worker_cnt = len(worker_processes)
+            recv_index = 0
             while finish_workers < worker_cnt:
-                sample = self.data_queue.get()
+                while (index_queues[recv_index].empty()):
+                    recv_index = (recv_index + 1) % self.concurrent
+                sample = index_queues[recv_index].get()
+                recv_index = (recv_index + 1) % self.concurrent
                 if sample == FINISH_EVENT:
                     finish_workers += 1
                 else:
@@ -70,25 +86,30 @@ class PaddleDataLoader(object):
         return _reader_creator


 def train(traindir, sz, min_scale=0.08, shuffle_seed=0):
     train_tfms = [
-        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
-        transforms.RandomHorizontalFlip()
+        transforms.RandomResizedCrop(
+            sz, scale=(min_scale, 1.0)), transforms.RandomHorizontalFlip()
     ]
-    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
+    train_dataset = datasets.ImageFolder(traindir,
+                                         transforms.Compose(train_tfms))
     return PaddleDataLoader(train_dataset, shuffle_seed=shuffle_seed).reader()


 def test(valdir, bs, sz, rect_val=False):
     if rect_val:
         idx_ar_sorted = sort_ar(valdir)
         idx_sorted, _ = zip(*idx_ar_sorted)
         idx2ar = map_idx2ar(idx_ar_sorted, bs)

-        ar_tfms = [transforms.Resize(int(sz* 1.14)), CropArTfm(idx2ar, sz)]
+        ar_tfms = [transforms.Resize(int(sz * 1.14)), CropArTfm(idx2ar, sz)]
         val_dataset = ValDataset(valdir, transform=ar_tfms)
-        return PaddleDataLoader(val_dataset, concurrent=1, indices=idx_sorted, shuffle=False).reader()
+        return PaddleDataLoader(
+            val_dataset, concurrent=1, indices=idx_sorted,
+            shuffle=False).reader()

-    val_tfms = [transforms.Resize(int(sz* 1.14)), transforms.CenterCrop(sz)]
+    val_tfms = [transforms.Resize(int(sz * 1.14)), transforms.CenterCrop(sz)]
     val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))
     return PaddleDataLoader(val_dataset).reader()
@@ -132,15 +153,19 @@ def sort_ar(valdir):
     idx2ar_file = valdir + '/../sorted_idxar.p'
     if os.path.isfile(idx2ar_file):
         return pickle.load(open(idx2ar_file, 'rb'))
-    print('Creating AR indexes. Please be patient this may take a couple minutes...')
-    val_dataset = datasets.ImageFolder(valdir)  # AS: TODO: use Image.open instead of looping through dataset
+    print(
+        'Creating AR indexes. Please be patient this may take a couple minutes...'
+    )
+    val_dataset = datasets.ImageFolder(
+        valdir)  # AS: TODO: use Image.open instead of looping through dataset
     sizes = [img[0].size for img in tqdm(val_dataset, total=len(val_dataset))]
-    idx_ar = [(i, round(s[0] * 1.0/ s[1], 5)) for i, s in enumerate(sizes)]
+    idx_ar = [(i, round(s[0] * 1.0 / s[1], 5)) for i, s in enumerate(sizes)]
     sorted_idxar = sorted(idx_ar, key=lambda x: x[1])
     pickle.dump(sorted_idxar, open(idx2ar_file, 'wb'))
     print('Done')
     return sorted_idxar


 def chunks(l, n):
     n = max(1, n)
     return (l[i:i + n] for i in range(0, len(l), n))
...
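For reference, train() and test() return the _reader_creator function produced by reader(); a caller invokes it and iterates the resulting generator. A hypothetical consumption sketch follows; the yield statement itself sits in a hunk elided above, so treating the creator as a generator function is an assumption, and the directory path is a placeholder.

# Hypothetical usage; assumes _reader_creator yields each (img, label)
# sample (its yield falls in a hunk elided from this diff).
train_reader = train('/path/to/imagenet/train', sz=224)
for img, label in train_reader():
    print(img.shape, label)
    break  # in practice, feed each sample into the training loop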