提交 9b1c04f7 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #2631 from wanghaoshuang/fix_xmap

fix xmap_readers and refine flowers dataset
...@@ -25,8 +25,9 @@ import uci_housing ...@@ -25,8 +25,9 @@ import uci_housing
import sentiment import sentiment
import wmt14 import wmt14
import mq2007 import mq2007
import flowers
__all__ = [ __all__ = [
'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment' 'mnist', 'imikolov', 'imdb', 'cifar', 'movielens', 'conll05', 'sentiment'
'uci_housing', 'wmt14', 'mq2007' 'uci_housing', 'wmt14', 'mq2007', 'flowers'
] ]
...@@ -13,18 +13,18 @@ ...@@ -13,18 +13,18 @@
# limitations under the License. # limitations under the License.
""" """
This module will download dataset from This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse train/test set intopaddle reader creators. and parse train/test set intopaddle reader creators.
This set contains images of flowers belonging to 102 different categories. This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category. minimum of 40 images for each category.
The database was used in: The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes.Proceedings of the Indian Conference on Computer Vision, number of classes.Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008) Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}. http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
""" """
...@@ -34,9 +34,9 @@ from common import download ...@@ -34,9 +34,9 @@ from common import download
import tarfile import tarfile
import scipy.io as scio import scipy.io as scio
from paddle.v2.image import * from paddle.v2.image import *
from paddle.v2.reader import *
import os import os
import numpy as np import numpy as np
import paddle.v2 as paddle
from multiprocessing import cpu_count from multiprocessing import cpu_count
__all__ = ['train', 'test', 'valid'] __all__ = ['train', 'test', 'valid']
...@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' ...@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In official 'readme', tstid is the flag of test data
# and trnid is the flag of train data. But test data is more than train data.
# So we exchange the train data and test data.
TRAIN_FLAG = 'tstid'
TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample): def default_mapper(sample):
...@@ -53,8 +59,8 @@ def default_mapper(sample): ...@@ -53,8 +59,8 @@ def default_mapper(sample):
map image bytes data to type needed by model input layer map image bytes data to type needed by model input layer
''' '''
img, label = sample img, label = sample
img = paddle.image.load_image_bytes(img) img = load_image_bytes(img)
img = paddle.image.simple_transform(img, 256, 224, True) img = simple_transform(img, 256, 224, True)
return img.flatten().astype('float32'), label return img.flatten().astype('float32'), label
...@@ -63,22 +69,23 @@ def reader_creator(data_file, ...@@ -63,22 +69,23 @@ def reader_creator(data_file,
setid_file, setid_file,
dataset_name, dataset_name,
mapper=default_mapper, mapper=default_mapper,
buffered_size=1024): buffered_size=1024,
use_xmap=True):
''' '''
1. read images from tar file and 1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/ merge images into batch files in 102flowers.tgz_batch/
2. get a reader to read sample from batch file 2. get a reader to read sample from batch file
:param data_file: downloaded data file :param data_file: downloaded data file
:type data_file: string :type data_file: string
:param label_file: downloaded label file :param label_file: downloaded label file
:type label_file: string :type label_file: string
:param setid_file: downloaded setid file containing information :param setid_file: downloaded setid file containing information
about how to split dataset about how to split dataset
:type setid_file: string :type setid_file: string
:param dataset_name: data set name (tstid|trnid|valid) :param dataset_name: data set name (tstid|trnid|valid)
:type dataset_name: string :type dataset_name: string
:param mapper: a function to map image bytes data to type :param mapper: a function to map image bytes data to type
needed by model input layer needed by model input layer
:type mapper: callable :type mapper: callable
:param buffered_size: the size of buffer used to process images :param buffered_size: the size of buffer used to process images
...@@ -105,15 +112,17 @@ def reader_creator(data_file, ...@@ -105,15 +112,17 @@ def reader_creator(data_file,
for sample, label in itertools.izip(data, batch['label']): for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) yield sample, int(label)
return paddle.reader.xmap_readers(mapper, reader, if use_xmap:
cpu_count(), buffered_size) return xmap_readers(mapper, reader, cpu_count(), buffered_size)
else:
return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024): def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers training set reader. Create flowers training set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102] image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps: translated from original color image by steps:
1. resize to 256*256 1. resize to 256*256
2. random crop to 224*224 2. random crop to 224*224
...@@ -128,15 +137,15 @@ def train(mapper=default_mapper, buffered_size=1024): ...@@ -128,15 +137,15 @@ def train(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024): def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers test set reader. Create flowers test set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102] image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps: translated from original color image by steps:
1. resize to 256*256 1. resize to 256*256
2. random crop to 224*224 2. random crop to 224*224
...@@ -151,15 +160,15 @@ def test(mapper=default_mapper, buffered_size=1024): ...@@ -151,15 +160,15 @@ def test(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024): def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers validation set reader. Create flowers validation set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102] image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps: translated from original color image by steps:
1. resize to 256*256 1. resize to 256*256
2. random crop to 224*224 2. random crop to 224*224
...@@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024): ...@@ -174,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
buffered_size) buffered_size, use_xmap)
def fetch(): def fetch():
......
...@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase): ...@@ -31,13 +31,13 @@ class TestFlowers(unittest.TestCase):
def test_train(self): def test_train(self):
instances, max_label_value = self.check_reader( instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.train()) paddle.v2.dataset.flowers.train())
self.assertEqual(instances, 1020) self.assertEqual(instances, 6149)
self.assertEqual(max_label_value, 102) self.assertEqual(max_label_value, 102)
def test_test(self): def test_test(self):
instances, max_label_value = self.check_reader( instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.test()) paddle.v2.dataset.flowers.test())
self.assertEqual(instances, 6149) self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102) self.assertEqual(max_label_value, 102)
def test_valid(self): def test_valid(self):
......
...@@ -166,12 +166,12 @@ def buffered(reader, size): ...@@ -166,12 +166,12 @@ def buffered(reader, size):
The buffered data reader will read and save data entries into a The buffered data reader will read and save data entries into a
buffer. Reading from the buffered data reader will proceed as long buffer. Reading from the buffered data reader will proceed as long
as the buffer is not empty. as the buffer is not empty.
:param reader: the data reader to read from. :param reader: the data reader to read from.
:type reader: callable :type reader: callable
:param size: max buffer size. :param size: max buffer size.
:type size: int :type size: int
:returns: the buffered data reader. :returns: the buffered data reader.
""" """
...@@ -238,7 +238,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -238,7 +238,7 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
:type mapper: callable :type mapper: callable
:param reader: the data reader to read from :param reader: the data reader to read from
:type reader: callable :type reader: callable
:param process_num: process number to handle original sample :param process_num: process number to handle original sample
:type process_num: int :type process_num: int
:param buffer_size: max buffer size :param buffer_size: max buffer size
:type buffer_size: int :type buffer_size: int
...@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -248,9 +248,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
:rtype: callable :rtype: callable
""" """
end = XmapEndSignal() end = XmapEndSignal()
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# define a worker to read samples from reader to in_queue # define a worker to read samples from reader to in_queue
def read_worker(reader, in_queue): def read_worker(reader, in_queue):
...@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -266,12 +263,6 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_order += 1 in_order += 1
in_queue.put(end) in_queue.put(end)
# start a read worker in a thread
target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True
t.start()
# define a worker to handle samples from in_queue by mapper # define a worker to handle samples from in_queue by mapper
# and put mapped samples into out_queue # and put mapped samples into out_queue
def handle_worker(in_queue, out_queue, mapper): def handle_worker(in_queue, out_queue, mapper):
...@@ -298,19 +289,27 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False): ...@@ -298,19 +289,27 @@ def xmap_readers(mapper, reader, process_num, buffer_size, order=False):
in_queue.put(end) in_queue.put(end)
out_queue.put(end) out_queue.put(end)
# start several handle_workers
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = []
for i in xrange(process_num):
worker = Thread(target=target, args=args)
worker.daemon = True
workers.append(worker)
for w in workers:
w.start()
def xreader(): def xreader():
in_queue = Queue(buffer_size)
out_queue = Queue(buffer_size)
out_order = [0]
# start a read worker in a thread
target = order_read_worker if order else read_worker
t = Thread(target=target, args=(reader, in_queue))
t.daemon = True
t.start()
# start several handle_workers
target = order_handle_worker if order else handle_worker
args = (in_queue, out_queue, mapper, out_order) if order else (
in_queue, out_queue, mapper)
workers = []
for i in xrange(process_num):
worker = Thread(target=target, args=args)
worker.daemon = True
workers.append(worker)
for w in workers:
w.start()
sample = out_queue.get() sample = out_queue.get()
while not isinstance(sample, XmapEndSignal): while not isinstance(sample, XmapEndSignal):
yield sample yield sample
......
...@@ -132,15 +132,17 @@ class TestXmap(unittest.TestCase): ...@@ -132,15 +132,17 @@ class TestXmap(unittest.TestCase):
for order in orders: for order in orders:
for tNum in thread_nums: for tNum in thread_nums:
for size in buffered_size: for size in buffered_size:
result = [] reader = paddle.v2.reader.xmap_readers(mapper,
for i in paddle.v2.reader.xmap_readers(mapper,
reader_creator_10(0), reader_creator_10(0),
tNum, size, order)(): tNum, size, order)
result.append(i) for n in xrange(3):
if not order: result = []
result.sort() for i in reader():
for idx, e in enumerate(result): result.append(i)
self.assertEqual(e, mapper(idx)) if not order:
result.sort()
for idx, e in enumerate(result):
self.assertEqual(e, mapper(idx))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -15,7 +15,8 @@ setup_requires=["requests", ...@@ -15,7 +15,8 @@ setup_requires=["requests",
"protobuf==3.1", "protobuf==3.1",
"recordio", "recordio",
"matplotlib", "matplotlib",
"rarfile"] "rarfile",
"scipy>=0.19.0"]
if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
setup_requires+=["opencv-python"] setup_requires+=["opencv-python"]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册