提交 01f44bff 编写于 作者: W wanghaoshuang

rename args and add comments

1. rename 'useXmap' to 'use_xmap'
2. add comments about exchanging train data and test data
上级 fc5972ba
...@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' ...@@ -46,6 +46,12 @@ SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In official 'readme', tstid is the flag of test data
# and trnid is the flag of train data. But test data is more than train data.
# So we exchange the train data and test data.
TRAIN_FLAG = 'tstid'
TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample): def default_mapper(sample):
...@@ -64,7 +70,7 @@ def reader_creator(data_file, ...@@ -64,7 +70,7 @@ def reader_creator(data_file,
dataset_name, dataset_name,
mapper=default_mapper, mapper=default_mapper,
buffered_size=1024, buffered_size=1024,
useXmap=True): use_xmap=True):
''' '''
1. read images from tar file and 1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/ merge images into batch files in 102flowers.tgz_batch/
...@@ -106,13 +112,13 @@ def reader_creator(data_file, ...@@ -106,13 +112,13 @@ def reader_creator(data_file,
for sample, label in itertools.izip(data, batch['label']): for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) yield sample, int(label)
if useXmap: if use_xmap:
return xmap_readers(mapper, reader, cpu_count(), buffered_size) return xmap_readers(mapper, reader, cpu_count(), buffered_size)
else: else:
return map_readers(mapper, reader) return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024, useXmap=True): def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers training set reader. Create flowers training set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -131,11 +137,11 @@ def train(mapper=default_mapper, buffered_size=1024, useXmap=True): ...@@ -131,11 +137,11 @@ def train(mapper=default_mapper, buffered_size=1024, useXmap=True):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper, download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper,
buffered_size, useXmap) buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024, useXmap=True): def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers test set reader. Create flowers test set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -154,11 +160,11 @@ def test(mapper=default_mapper, buffered_size=1024, useXmap=True): ...@@ -154,11 +160,11 @@ def test(mapper=default_mapper, buffered_size=1024, useXmap=True):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper, download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper,
buffered_size, useXmap) buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024, useXmap=True): def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
''' '''
Create flowers validation set reader. Create flowers validation set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -177,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024, useXmap=True): ...@@ -177,8 +183,8 @@ def valid(mapper=default_mapper, buffered_size=1024, useXmap=True):
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper, download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper,
buffered_size, useXmap) buffered_size, use_xmap)
def fetch(): def fetch():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册