Commit b6aed0a9 authored by lilong12, committed by guru4elephant

modify single queue to a queue per process for data processing (#2136)

Parent 63fe2970
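
This commit replaces the loader's single shared multiprocessing.Queue with one queue per worker process: each _worker_loop now writes to its own queue and posts FINISH_EVENT when its index slice is exhausted, and the parent drains the queues round-robin. A minimal, self-contained sketch of that producer/consumer pattern follows; the _produce/consume names, queue capacity, and sample data are illustrative, not the commit's code. Unlike the commit's reader, which polls Queue.empty() to skip idle queues, this sketch retires a queue from the rotation once its sentinel arrives, so a blocking get() never stalls on a finished worker.

    import multiprocessing

    FINISH_EVENT = "FINISH_EVENT"  # sentinel, mirroring the commit

    def _produce(queue, items, worker_id):
        # Stand-in for _worker_loop (same signature shape): the worker
        # writes samples to its own queue, then posts the sentinel so
        # the consumer can retire that queue.
        for item in items:
            queue.put(item)
        queue.put(FINISH_EVENT)

    def consume(num_workers=4, items_per_worker=10):
        queues = []
        for i in range(num_workers):
            q = multiprocessing.Queue(8)  # small per-worker capacity
            w = multiprocessing.Process(
                target=_produce,
                args=(q, list(range(i * items_per_worker,
                                    (i + 1) * items_per_worker)), i))
            w.daemon = True
            w.start()
            queues.append(q)
        active = list(range(num_workers))  # queues still producing
        idx = 0
        while active:
            qi = active[idx % len(active)]
            sample = queues[qi].get()  # blocking get is safe: qi is active
            if sample == FINISH_EVENT:
                active.remove(qi)  # drop this queue from the rotation
            else:
                idx += 1
                yield sample

    if __name__ == '__main__':
        print(sum(consume()))  # 4 workers x 10 ints (0..39) -> 780
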
 from __future__ import division
 import os
 import numpy as np
@@ -17,28 +18,37 @@ import time
 import multiprocessing
 FINISH_EVENT = "FINISH_EVENT"
 class PaddleDataLoader(object):
-    def __init__(self, torch_dataset, indices=None, concurrent=24, queue_size=3072, shuffle=True, shuffle_seed=0):
+    def __init__(self,
+                 torch_dataset,
+                 indices=None,
+                 concurrent=24,
+                 queue_size=3072,
+                 shuffle=True,
+                 shuffle_seed=0):
         self.torch_dataset = torch_dataset
         self.data_queue = multiprocessing.Queue(queue_size)
         self.indices = indices
         self.concurrent = concurrent
         self.shuffle = shuffle
-        self.shuffle_seed=shuffle_seed
+        self.shuffle_seed = shuffle_seed
+        self.queue_size = queue_size // self.concurrent
-    def _worker_loop(self, dataset, worker_indices, worker_id):
+    def _worker_loop(self, queue, worker_indices, worker_id):
         cnt = 0
         for idx in worker_indices:
             cnt += 1
             img, label = self.torch_dataset[idx]
             img = np.array(img).astype('uint8').transpose((2, 0, 1))
-            self.data_queue.put((img, label))
+            queue.put((img, label))
         print("worker: [%d] read [%d] samples. " % (worker_id, cnt))
-        self.data_queue.put(FINISH_EVENT)
+        queue.put(FINISH_EVENT)
     def reader(self):
         def _reader_creator():
             worker_processes = []
+            index_queues = []
             total_img = len(self.torch_dataset)
             print("total image: ", total_img)
             if self.shuffle:
@@ -50,19 +60,25 @@ class PaddleDataLoader(object):
             imgs_per_worker = int(math.ceil(total_img / self.concurrent))
             for i in xrange(self.concurrent):
                 start = i * imgs_per_worker
-                end = (i + 1) * imgs_per_worker if i != self.concurrent - 1 else None
+                end = (i + 1
+                       ) * imgs_per_worker if i != self.concurrent - 1 else None
                 sliced_indices = self.indices[start:end]
+                index_queue = multiprocessing.Queue(self.queue_size)
                 w = multiprocessing.Process(
                     target=self._worker_loop,
-                    args=(self.torch_dataset, sliced_indices, i)
-                )
+                    args=(index_queue, sliced_indices, i))
                 w.daemon = True
                 w.start()
                 worker_processes.append(w)
+                index_queues.append(index_queue)
             finish_workers = 0
             worker_cnt = len(worker_processes)
+            recv_index = 0
             while finish_workers < worker_cnt:
-                sample = self.data_queue.get()
+                while (index_queues[recv_index].empty()):
+                    recv_index = (recv_index + 1) % self.concurrent
+                sample = index_queues[recv_index].get()
+                recv_index = (recv_index + 1) % self.concurrent
                 if sample == FINISH_EVENT:
                     finish_workers += 1
                 else:
@@ -70,25 +86,30 @@ class PaddleDataLoader(object):
         return _reader_creator
 def train(traindir, sz, min_scale=0.08, shuffle_seed=0):
     train_tfms = [
-        transforms.RandomResizedCrop(sz, scale=(min_scale, 1.0)),
-        transforms.RandomHorizontalFlip()
+        transforms.RandomResizedCrop(
+            sz, scale=(min_scale, 1.0)), transforms.RandomHorizontalFlip()
     ]
-    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_tfms))
+    train_dataset = datasets.ImageFolder(traindir,
+                                         transforms.Compose(train_tfms))
     return PaddleDataLoader(train_dataset, shuffle_seed=shuffle_seed).reader()
 def test(valdir, bs, sz, rect_val=False):
     if rect_val:
         idx_ar_sorted = sort_ar(valdir)
         idx_sorted, _ = zip(*idx_ar_sorted)
         idx2ar = map_idx2ar(idx_ar_sorted, bs)
-        ar_tfms = [transforms.Resize(int(sz* 1.14)), CropArTfm(idx2ar, sz)]
+        ar_tfms = [transforms.Resize(int(sz * 1.14)), CropArTfm(idx2ar, sz)]
         val_dataset = ValDataset(valdir, transform=ar_tfms)
-        return PaddleDataLoader(val_dataset, concurrent=1, indices=idx_sorted, shuffle=False).reader()
+        return PaddleDataLoader(
+            val_dataset, concurrent=1, indices=idx_sorted,
+            shuffle=False).reader()
-    val_tfms = [transforms.Resize(int(sz* 1.14)), transforms.CenterCrop(sz)]
+    val_tfms = [transforms.Resize(int(sz * 1.14)), transforms.CenterCrop(sz)]
     val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_tfms))
     return PaddleDataLoader(val_dataset).reader()
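
For orientation, each function above returns a reader creator (the inner _reader_creator, whose yielding body is collapsed in this diff); consuming it means calling the creator and iterating the resulting generator. A hedged usage sketch, with a placeholder dataset path and crop size that are not values from the commit:

    # Hypothetical consumption of the reader creators defined above.
    train_reader = train('/data/imagenet/train', sz=224)
    for img, label in train_reader():
        # img: (3, H, W) uint8 numpy array (PIL HWC transposed to CHW)
        # label: integer class id assigned by ImageFolder
        print(img.shape, label)
        break
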
@@ -132,15 +153,19 @@ def sort_ar(valdir):
     idx2ar_file = valdir + '/../sorted_idxar.p'
     if os.path.isfile(idx2ar_file):
         return pickle.load(open(idx2ar_file, 'rb'))
-    print('Creating AR indexes. Please be patient this may take a couple minutes...')
-    val_dataset = datasets.ImageFolder(valdir) # AS: TODO: use Image.open instead of looping through dataset
+    print(
+        'Creating AR indexes. Please be patient this may take a couple minutes...'
+    )
+    val_dataset = datasets.ImageFolder(
+        valdir)  # AS: TODO: use Image.open instead of looping through dataset
     sizes = [img[0].size for img in tqdm(val_dataset, total=len(val_dataset))]
-    idx_ar = [(i, round(s[0] * 1.0/ s[1], 5)) for i, s in enumerate(sizes)]
+    idx_ar = [(i, round(s[0] * 1.0 / s[1], 5)) for i, s in enumerate(sizes)]
     sorted_idxar = sorted(idx_ar, key=lambda x: x[1])
     pickle.dump(sorted_idxar, open(idx2ar_file, 'wb'))
     print('Done')
     return sorted_idxar
 def chunks(l, n):
     n = max(1, n)
     return (l[i:i + n] for i in range(0, len(l), n))
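
The chunks helper lazily yields successive length-n slices of a list, with the last slice possibly shorter and n clamped to at least 1. For example:

    # chunks() usage: successive slices of length n, last one possibly shorter.
    print(list(chunks([1, 2, 3, 4, 5], 2)))  # [[1, 2], [3, 4], [5]]
    print(list(chunks([1, 2, 3], 0)))        # n clamped to 1: [[1], [2], [3]]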