提交 9b1c04f7 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #2631 from wanghaoshuang/fix_xmap

fix xmap_readers and refine flowers dataset
......@@ -25,8 +25,9 @@ import uci_housing
import sentiment
import wmt14
import mq2007
import flowers
__all__ = [
'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
'uci_housing', 'wmt14', 'mq2007'
'uci_housing', 'wmt14', 'mq2007', 'flowers'
]
......@@ -34,9 +34,9 @@ from common import download
import tarfile
import scipy.io as scio
from paddle.v2.image import *
from paddle.v2.reader import *
import os
import numpy as np
import paddle.v2 as paddle
from multiprocessing import cpu_count
__all__ = ['train', 'test', 'valid']
......@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In official 'readme', tstid is the flag of test data
# and trnid is the flag of train data. But test data is more than train data.
# So we exchange the train data and test data.
TRAIN_FLAG = 'tstid'
TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample):
......@@ -53,8 +59,8 @@ def default_mapper(sample):
map image bytes data to type needed by model input layer
'''
img, label = sample
img = paddle.image.load_image_bytes(img)
img = paddle.image.simple_transform(img, 256, 224, True)
img = load_image_bytes(img)
img = simple_transform(img, 256, 224, True)
return img.flatten().astype('float32'), label
......@@ -63,7 +69,8 @@ def reader_creator(data_file,
setid_file,
dataset_name,
mapper=default_mapper,
buffered_size=1024):
buffered_size=1024,
use_xmap=True):
'''
1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/
......@@ -105,11 +112,13 @@ def reader_creator(data_file,
for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label)
return paddle.reader.xmap_readers(mapper, reader,
cpu_count(), buffered_size)
if use_xmap:
return xmap_readers(mapper, reader, cpu_count(), buffered_size)
else:
return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024):
def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
......@@ -128,11 +137,11 @@ def train(mapper=default_mapper, buffered_size=1024):
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper,
buffered_size)
download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024):
def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
......@@ -151,11 +160,11 @@ def test(mapper=default_mapper, buffered_size=1024):
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper,
buffered_size)
download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024):
def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
......@@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024):
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper,
buffered_size)
download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
buffered_size, use_xmap)
def fetch():
......
......@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase):
def test_train(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.train())
self.assertEqual(instances, 1020)
self.assertEqual(instances, 6149)
self.assertEqual(max_label_value, 102)
def test_test(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.test())
self.assertEqual(instances, 6149)
self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102)
def test_valid(self):
......
......@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
:rtype: callable
"""
end = XmapEndSignal()
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue):
......@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_order += 1
in_queue.put(end)
# start a read worker in a thread
target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True
t.start()
# define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue
def handle_worker(in_queue, out_queue, mapper):
......@@ -298,6 +289,15 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue.put(end)
out_queue.put(end)
def xreader():
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# start a read worker in a thread
target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True
t.start()
# start several handle_workers
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
......@@ -310,7 +310,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
for w in workers:
w.start()
def xreader():
sample = out_queue.get()
while not isinstance(sample, XmapEndSignal):
yield sample
......
......@@ -132,10 +132,12 @@ class TestXmap(unittest.TestCase):
for order in orders:
for tNum in thread_nums:
for size in buffered_size:
result = []
for i in paddle.v2.reader.xmap_readers(mapper,
reader = paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(0),
tNum, size, order)():
tNum, size, order)
for n in xrange(3):
result = []
for i in reader():
result.append(i)
if not order:
result.sort()
......
......@@ -15,7 +15,8 @@ setup_requires=["requests",
"protobuf==3.1",
"recordio",
"matplotlib",
"rarfile"]
"rarfile",
"scipy>=0.19.0"]
if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
setup_requires+=["opencv-python"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册