diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index a181f3881a61a7caebec668da0ca192cde2d3a3e..158cfe158c4f1c8d82d157301adcfbe0351c55df 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' +# In the official 'readme', tstid is the flag of the test data and trnid is +# the flag of the train data. However, the tstid split is larger than the trnid +# split, so we swap them: tstid is used for training and trnid for testing. +TRAIN_FLAG = 'tstid' +TEST_FLAG = 'trnid' +VALID_FLAG = 'valid' def default_mapper(sample): @@ -64,7 +70,7 @@ def reader_creator(data_file, dataset_name, mapper=default_mapper, buffered_size=1024, - useXmap=True): + use_xmap=True): ''' 1. read images from tar file and merge images into batch files in 102flowers.tgz_batch/ @@ -106,13 +112,13 @@ def reader_creator(data_file, for sample, label in itertools.izip(data, batch['label']): yield sample, int(label) - if useXmap: + if use_xmap: return xmap_readers(mapper, reader, cpu_count(), buffered_size) else: return map_readers(mapper, reader) -def train(mapper=default_mapper, buffered_size=1024, useXmap=True): +def train(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' Create flowers training set reader. 
It returns a reader, each sample in the reader is @@ -131,11 +137,11 @@ def train(mapper=default_mapper, buffered_size=1024, useXmap=True): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, - buffered_size, useXmap) + download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper, + buffered_size, use_xmap) -def test(mapper=default_mapper, buffered_size=1024, useXmap=True): +def test(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' Create flowers test set reader. It returns a reader, each sample in the reader is @@ -154,11 +160,11 @@ def test(mapper=default_mapper, buffered_size=1024, useXmap=True): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, - buffered_size, useXmap) + download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper, + buffered_size, use_xmap) -def valid(mapper=default_mapper, buffered_size=1024, useXmap=True): +def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True): ''' Create flowers validation set reader. It returns a reader, each sample in the reader is @@ -177,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024, useXmap=True): return reader_creator( download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5), - download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, - buffered_size, useXmap) + download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper, + buffered_size, use_xmap) def fetch():