提交 4594f186 编写于 作者: W Wojciech Uss

added cycling the dataset for cifar and flowers

上级 26ae6111
...@@ -43,7 +43,7 @@ CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz' ...@@ -43,7 +43,7 @@ CIFAR100_URL = URL_PREFIX + 'cifar-100-python.tar.gz'
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def reader_creator(filename, sub_name): def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch): def read_batch(batch):
data = batch['data'] data = batch['data']
labels = batch.get('labels', batch.get('fine_labels', None)) labels = batch.get('labels', batch.get('fine_labels', None))
...@@ -56,10 +56,13 @@ def reader_creator(filename, sub_name): ...@@ -56,10 +56,13 @@ def reader_creator(filename, sub_name):
names = (each_item.name for each_item in f names = (each_item.name for each_item in f
if sub_name in each_item.name) if sub_name in each_item.name)
for name in names: while True:
batch = cPickle.load(f.extractfile(name)) for name in names:
for item in read_batch(batch): batch = cPickle.load(f.extractfile(name))
yield item for item in read_batch(batch):
yield item
if not cycle:
break
return reader return reader
...@@ -94,34 +97,40 @@ def test100(): ...@@ -94,34 +97,40 @@ def test100():
'test') 'test')
def train10(): def train10(cycle=False):
""" """
CIFAR-10 training set creator. CIFAR-10 training set creator.
It returns a reader creator, each sample in the reader is image pixels in It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9]. [0, 1] and label in [0, 9].
:param cycle: whether to cycle through the dataset
:type cycle: bool
:return: Training reader creator :return: Training reader creator
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch') 'data_batch',
cycle=cycle)
def test10(): def test10(cycle=False):
""" """
CIFAR-10 test set creator. CIFAR-10 test set creator.
It returns a reader creator, each sample in the reader is image pixels in It returns a reader creator, each sample in the reader is image pixels in
[0, 1] and label in [0, 9]. [0, 1] and label in [0, 9].
:param cycle: whether to cycle through the dataset
:type cycle: bool
:return: Test reader creator. :return: Test reader creator.
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch') 'test_batch',
cycle=cycle)
def fetch(): def fetch():
......
...@@ -76,7 +76,8 @@ def reader_creator(data_file, ...@@ -76,7 +76,8 @@ def reader_creator(data_file,
dataset_name, dataset_name,
mapper, mapper,
buffered_size=1024, buffered_size=1024,
use_xmap=True): use_xmap=True,
cycle=False):
''' '''
1. read images from tar file and 1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/ merge images into batch files in 102flowers.tgz_batch/
...@@ -96,6 +97,8 @@ def reader_creator(data_file, ...@@ -96,6 +97,8 @@ def reader_creator(data_file,
:type mapper: callable :type mapper: callable
:param buffered_size: the size of buffer used to process images :param buffered_size: the size of buffer used to process images
:type buffered_size: int :type buffered_size: int
:param cycle: whether to cycle through the dataset
:type cycle: bool
:return: data reader :return: data reader
:rtype: callable :rtype: callable
''' '''
...@@ -108,15 +111,18 @@ def reader_creator(data_file, ...@@ -108,15 +111,18 @@ def reader_creator(data_file,
file_list = batch_images_from_tar(data_file, dataset_name, img2label) file_list = batch_images_from_tar(data_file, dataset_name, img2label)
def reader(): def reader():
for file in open(file_list): while True:
file = file.strip() for file in open(file_list):
batch = None file = file.strip()
with open(file, 'r') as f: batch = None
batch = cPickle.load(f) with open(file, 'r') as f:
data = batch['data'] batch = cPickle.load(f)
labels = batch['label'] data = batch['data']
for sample, label in itertools.izip(data, batch['label']): labels = batch['label']
yield sample, int(label) - 1 for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label) - 1
if not cycle:
break
if use_xmap: if use_xmap:
cpu_num = int(os.environ.get('CPU_NUM', cpu_count())) cpu_num = int(os.environ.get('CPU_NUM', cpu_count()))
...@@ -125,7 +131,7 @@ def reader_creator(data_file, ...@@ -125,7 +131,7 @@ def reader_creator(data_file,
return map_readers(mapper, reader) return map_readers(mapper, reader)
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True): def train(mapper=train_mapper, buffered_size=1024, use_xmap=True, cycle=False):
''' '''
Create flowers training set reader. Create flowers training set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -138,17 +144,23 @@ def train(mapper=train_mapper, buffered_size=1024, use_xmap=True): ...@@ -138,17 +144,23 @@ def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
:type mapper: callable :type mapper: callable
:param buffered_size: the size of buffer used to process images :param buffered_size: the size of buffer used to process images
:type buffered_size: int :type buffered_size: int
:param cycle: whether to cycle through the dataset
:type cycle: bool
:return: train data reader :return: train data reader
:rtype: callable :rtype: callable
''' '''
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG, mapper, download(SETID_URL, 'flowers', SETID_MD5),
buffered_size, use_xmap) TRAIN_FLAG,
mapper,
buffered_size,
use_xmap,
cycle=cycle)
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True): def test(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False):
''' '''
Create flowers test set reader. Create flowers test set reader.
It returns a reader, each sample in the reader is It returns a reader, each sample in the reader is
...@@ -161,14 +173,20 @@ def test(mapper=test_mapper, buffered_size=1024, use_xmap=True): ...@@ -161,14 +173,20 @@ def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
:type mapper: callable :type mapper: callable
:param buffered_size: the size of buffer used to process images :param buffered_size: the size of buffer used to process images
:type buffered_size: int :type buffered_size: int
:param cycle: whether to cycle through the dataset
:type cycle: bool
:return: test data reader :return: test data reader
:rtype: callable :rtype: callable
''' '''
return reader_creator( return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5), download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG, mapper, download(SETID_URL, 'flowers', SETID_MD5),
buffered_size, use_xmap) TEST_FLAG,
mapper,
buffered_size,
use_xmap,
cycle=cycle)
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True): def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册