From 01f44bff669442ffdb67a5baac14aa693cba08c6 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 28 Jun 2017 23:12:19 +0800 Subject: [PATCH] rename args and add comments 1. rename 'useXmap' to 'use_xmap' 2. add comments about exchanging train data and test data --- python/paddle/v2/dataset/flowers.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index a181f3881..158cfe158 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' +# In official 'readme', tstid is the flag of test data +# and trnid is the flag of train data. But test data is more than train data. +# So we exchange the train data and test data. +TRAIN_FLAG = 'tstid' +TEST_FLAG = 'trnid' +VALID_FLAG = 'valid' def default_mapper(sample): @@ -64,7 +70,7 @@ def reader_creator(data_file, dataset_name, mapper=default_mapper, buffered_size=1024, - useXmap=True): + use_xmap=True): ''' 1. read images from tar file and merge images into batch files in 102flowers.tgz_batch/ @@ -106,13 +112,13 @@ def reader_creator(data_file, for sample, label in itertools.izip(data, batch['label']): yield sample, int(label) - if useXmap: + if use_xmap: return xmap_readers(mapper, reader, cpu_count(), buffered_size) else: return map_readers(mapper, reader) -def train(mapper=default_mapper, buffered_size=1024, useXmap=True): +def train(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' Create flowers training set reader. It returns a reader, each sample in the reader is @@ -131,11 +137,11 @@ def train(mapper=default_mapper, buffered_size=1024, useXmap=True): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, - buffered_size, useXmap) + download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper, + buffered_size, use_xmap) -def test(mapper=default_mapper, buffered_size=1024, useXmap=True): +def test(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' Create flowers test set reader. It returns a reader, each sample in the reader is @@ -154,11 +160,11 @@ def test(mapper=default_mapper, buffered_size=1024, useXmap=True): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, - buffered_size, useXmap) + download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper, + buffered_size, use_xmap) -def valid(mapper=default_mapper, buffered_size=1024, useXmap=True): +def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' Create flowers validation set reader. It returns a reader, each sample in the reader is @@ -177,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024, useXmap=True): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, - buffered_size, useXmap) + download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper, + buffered_size, use_xmap) def fetch(): -- GitLab