diff --git a/paddle/fluid/operators/reader/CMakeLists.txt b/paddle/fluid/operators/reader/CMakeLists.txt index 4a43fc02d2189ec2eb4bff12769c11cb8dae8193..6fa0195b9ae103418beb56cc4b0fa9ab59e93108 100644 --- a/paddle/fluid/operators/reader/CMakeLists.txt +++ b/paddle/fluid/operators/reader/CMakeLists.txt @@ -21,5 +21,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc) reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc) reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc) reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc) +reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc) # Export local libraries to parent set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE) diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4d4e9fb909eafea5328491a4097276577f28a5ba --- /dev/null +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -0,0 +1,101 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/reader/reader_op_registry.h" + +namespace paddle { +namespace operators { +namespace reader { + +class MultiPassReader : public framework::DecoratedReader { + public: + MultiPassReader(ReaderBase* reader, int pass_num) + : DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {} + + void ReadNext(std::vector* out) override { + if (!HasNext()) { + PADDLE_THROW("There is no next data!"); + } + reader_->ReadNext(out); + } + + bool HasNext() const override { + if (reader_->HasNext()) { + return true; + } else { + ++pass_count_; + if (pass_count_ >= pass_num_) { + return false; + } else { + reader_->ReInit(); + return true; + } + } + } + + void ReInit() override { + pass_count_ = 0; + reader_->ReInit(); + } + + private: + int pass_num_; + mutable int pass_count_; +}; + +class CreateMultiPassReaderOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + private: + void RunImpl(const framework::Scope& scope, + const platform::Place& dev_place) const override { + const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) + ->Get(); + auto& out = detail::Ref(scope.FindVar(Output("Out"))); + int pass_num = Attr("pass_num"); + out.GetMutable()->Reset( + new MultiPassReader(underlying_reader.Get(), pass_num)); + } +}; + +class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase { + public: + CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker) + : DecoratedReaderMakerBase(op_proto, op_checker) { + AddAttr("pass_num", "The number of pass to run.").GreaterThan(0); + AddComment(R"DOC( + CreateMultiPassReader Operator + + This operator creates a multi-pass reader. A multi-pass reader + is used to yield data for several pass training continuously. + It takes the the number of pass to run as one of its attributes + ('pass_num'), and maintains a pass counter to record how many + passes it has completed. When the underlying reader reach the EOF, + the multi-pass reader checks whether it has completed training + of the given number of pass. If not, the underlying reader will + be re-initialized and starts a new pass automatically. + )DOC"); + } +}; + +} // namespace reader +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators::reader; +REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader, + ops::CreateMultiPassReaderOp, + ops::CreateMultiPassReaderOpMaker); diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index f169642eaa44eadcef8ff0bc6745183a9fec20e8..bc5e291ad811315ddc9d101853d69c7f5ab5082d 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -22,7 +22,7 @@ from ..executor import global_scope __all__ = [ 'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file', 'open_files', 'read_file', 'create_shuffle_reader', - 'create_double_buffer_reader' + 'create_double_buffer_reader', 'create_multi_pass_reader' ] @@ -345,6 +345,11 @@ def create_double_buffer_reader(reader, place=None): attrs) +def create_multi_pass_reader(reader, pass_num): + return __create_decorated_reader__('create_multi_pass_reader', reader, + {'pass_num': int(pass_num)}) + + def read_file(file_obj): helper = LayerHelper('read_file') out = [ diff --git a/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..8add353303e3626bbce68199a100306d4858766a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_multi_pass_reader.py @@ -0,0 +1,65 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import paddle.fluid as fluid +import paddle.v2 as paddle +import paddle.v2.dataset.mnist as mnist + + +class TestMultipleReader(unittest.TestCase): + def setUp(self): + self.batch_size = 64 + self.pass_num = 3 + # Convert mnist to recordio file + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_file = paddle.batch(mnist.train(), batch_size=self.batch_size) + feeder = fluid.DataFeeder( + feed_list=[ + fluid.layers.data( + name='image', shape=[784]), + fluid.layers.data( + name='label', shape=[1], dtype='int64'), + ], + place=fluid.CPUPlace()) + self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file( + './mnist.recordio', data_file, feeder) + + def test_main(self): + with fluid.program_guard(fluid.Program(), fluid.Program()): + data_file = fluid.layers.open_recordio_file( + filename='./mnist.recordio', + shapes=[(-1, 784), (-1, 1)], + lod_levels=[0, 0], + dtypes=['float32', 'int64']) + data_file = fluid.layers.create_multi_pass_reader( + reader=data_file, pass_num=self.pass_num) + img, label = fluid.layers.read_file(data_file) + + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + else: + place = fluid.CPUPlace() + + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + + batch_count = 0 + while not data_file.eof(): + img_val, = exe.run(fetch_list=[img]) + batch_count += 1 + self.assertLessEqual(img_val.shape[0], self.batch_size) + data_file.reset() + self.assertEqual(batch_count, self.num_batch * self.pass_num)