From 0dff1e4140987441e52e3994ad5351bc902776d2 Mon Sep 17 00:00:00 2001 From: JiayiFeng Date: Wed, 18 Apr 2018 09:37:06 +0000 Subject: [PATCH] Fix bugs --- python/paddle/fluid/layers/io.py | 7 ++++++- .../paddle/fluid/tests/unittests/test_multi_pass_reader.py | 6 ++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index 7c19144ea4..ffadda6595 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,7 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'open_files', 'read_file', 'shuffle', 'double_buffer' + 'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer' ] @@ -469,6 +469,11 @@ def shuffle(reader, buffer_size): 'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)}) +def batch(reader, batch_size): + return __create_unshared_decorated_reader__( + 'create_batch_reader', reader, {'batch_size': int(batch_size)}) + + def double_buffer(reader, place=None): attrs = dict() if place is not None: diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py index 1471843ded..52e7cc1ffb 100644 --- a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py +++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py @@ -43,9 +43,8 @@ class TestMultipleReader(unittest.TestCase): filename='./mnist.recordio', shapes=[(-1, 784), (-1, 1)], lod_levels=[0, 0], - dtypes=['float32', 'int64']) - data_file = fluid.layers.io.multi_pass( - reader=data_file, pass_num=self.pass_num) + dtypes=['float32', 'int64'], + pass_num=self.pass_num) img, label = fluid.layers.read_file(data_file) if fluid.core.is_compiled_with_cuda(): @@ -65,5 +64,4 @@ class TestMultipleReader(unittest.TestCase): break batch_count += 1 self.assertLessEqual(img_val.shape[0], self.batch_size) - data_file.reset() self.assertEqual(batch_count, self.num_batch * self.pass_num) -- GitLab