From a4f397fb681fd6b0bc28216c3d9bae3b660b50a8 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 20 Mar 2018 20:50:03 +0800 Subject: [PATCH] add an unittest --- python/paddle/fluid/layers/io.py | 3 +- .../tests/unittests/test_multi_pass_reader.py | 66 +++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 python/paddle/fluid/tests/unittests/test_multi_pass_reader.py diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index e0c4cffa2d..4ff5cd9bf5 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -21,7 +21,8 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', - 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader' + 'read_file', 'create_shuffle_reader', 'create_double_buffer_reader', + 'create_multi_pass_reader' ] diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py new file mode 100644 index 0000000000..17374aec1b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py @@ -0,0 +1,66 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist + + +class TestMultipleReader(unittest.TestCase): + def setUp(self): + self.batch_size = 64 + self.pass_num = 3 + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_file = paddle.batch(mnist.train(), batch_size=self.batch_size) + feeder = fluid.DataFeeder( + feed_list=[ + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist.recordio', data_file, feeder) + + def test_main(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_file = fluid.layers.open_recordio_file( + filename='./mnist.recordio', + shapes=[(-1, 784), (-1, 1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + data_file = fluid.layers.create_multi_pass_reader( + reader=data_file, pass_num=self.pass_num) + img, label = fluid.layers.read_file(data_file) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + batch_count = 0 + while not data_file.eof(): + img_val, = exe.run(fetch_list=[img]) + batch_count += 1 + self.assertLessEqual(img_val.shape[0], self.batch_size) + print(batch_count) + data_file.reset() + self.assertEqual(batch_count, self.num_batch * self.pass_num) -- GitLab