From 44d5f42a7e2400074790171f77920473aabb7eba Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Tue, 3 Apr 2018 14:13:04 +0800 Subject: [PATCH] update reader --- .../reader/create_batch_reader_op.cc | 7 ++-- .../reader/create_double_buffer_reader_op.cc | 8 +++-- .../reader/create_multi_pass_reader_op.cc | 9 +++-- .../reader/create_shuffle_reader_op.cc | 8 +++-- python/paddle/fluid/framework.py | 15 ++++++++ python/paddle/fluid/layers/io.py | 35 ++++++++++++++----- 6 files changed, 64 insertions(+), 18 deletions(-) diff --git a/paddle/fluid/operators/reader/create_batch_reader_op.cc b/paddle/fluid/operators/reader/create_batch_reader_op.cc index 277f2856c0..04c5872bef 100644 --- a/paddle/fluid/operators/reader/create_batch_reader_op.cc +++ b/paddle/fluid/operators/reader/create_batch_reader_op.cc @@ -39,10 +39,13 @@ class CreateBatchReaderOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { - const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) - ->Get(); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); + if (out->Get() != nullptr) { + return; + } + const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) + ->Get(); out->Reset( new BatchReader(underlying_reader.Get(), Attr("batch_size"))); } diff --git a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc index f9a8058f2a..82bb668e9f 100644 --- a/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc +++ b/paddle/fluid/operators/reader/create_double_buffer_reader_op.cc @@ -12,7 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include "paddle/fluid/framework/channel.h" #include "paddle/fluid/operators/reader/reader_op_registry.h" @@ -98,10 +97,13 @@ class CreateDoubleBufferReaderOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { - const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) - ->Get(); auto* out = scope.FindVar(Output("Out")) ->template GetMutable(); + if (out->Get() != nullptr) { + return; + } + const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) + ->Get(); auto place_str = Attr("place"); platform::Place place; diff --git a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc index 47d9989bc8..b72ccc77a3 100644 --- a/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc +++ b/paddle/fluid/operators/reader/create_multi_pass_reader_op.cc @@ -62,12 +62,15 @@ class CreateMultiPassReaderOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { + auto* out = detail::Ref(scope.FindVar(Output("Out"))) + .GetMutable(); + if (out->Get() != nullptr) { + return; + } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - auto& out = detail::Ref(scope.FindVar(Output("Out"))); int pass_num = Attr("pass_num"); - out.GetMutable()->Reset( - new MultiPassReader(underlying_reader.Get(), pass_num)); + out->Reset(new MultiPassReader(underlying_reader.Get(), pass_num)); } }; diff --git a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc index 3a1f3805a0..b164ce232d 100644 --- a/paddle/fluid/operators/reader/create_shuffle_reader_op.cc +++ b/paddle/fluid/operators/reader/create_shuffle_reader_op.cc @@ -80,10 +80,14 @@ class CreateShuffleReaderOp : public framework::OperatorBase { private: void RunImpl(const framework::Scope& scope, const platform::Place& dev_place) const override { + auto* out = detail::Ref(scope.FindVar(Output("Out"))) + .GetMutable(); + if (out->Get() != nullptr) { + return; + } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); - auto& var = detail::Ref(scope.FindVar(Output("Out"))); - var.GetMutable()->Reset( + out->Reset( new ShuffleReader(underlying_reader.Get(), static_cast(Attr("buffer_size")))); } diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 3e78788f47..2f943457f5 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -640,6 +640,21 @@ class Operator(object): """ return self.desc.block_attr(name) + @property + def attrs(self): + """ + Get the attribute dict + Returns(dict): The Operator's attribute dict + """ + attr_names = self.attr_names + attr_map = {} + for n in attr_names: + if n == 'sub_block': + attr_map[n] = self.block_attr(n) + else: + attr_map[n] = self.attr(n) + return attr_map + class Block(object): def __init__(self, program, idx): diff --git a/python/paddle/fluid/layers/io.py b/python/paddle/fluid/layers/io.py index bd7e9c30fe..f6bd3c7d0a 100644 --- a/python/paddle/fluid/layers/io.py +++ b/python/paddle/fluid/layers/io.py @@ -255,7 +255,22 @@ def _copy_reader_var_(block, var): new_var.desc.set_shapes(var.desc.shapes()) new_var.desc.set_dtypes(var.desc.dtypes()) new_var.persistable = True - return monkey_patch_reader_methods(new_var) + return new_var + + +def _copy_reader_create_op_(block, op): + def _find_vars_(block, name_list): + res = {} + for n in name_list: + var = block.var(n) + res[n] = var + return res + + input_map = _find_vars_(block, op.input_names) + output_map = _find_vars_(block, op.output_names) + new_op = block.append_op( + type=op.type, inputs=input_map, outputs=output_map, attrs=op.attrs) + return new_op def open_recordio_file(filename, shapes, lod_levels, dtypes): @@ -283,8 +298,9 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes): startup_var.desc.set_dtypes(dtypes) startup_var.persistable = True - return _copy_reader_var_(default_main_program().current_block(), - startup_var) + main_prog_var = _copy_reader_var_(default_main_program().current_block(), + startup_var) + return monkey_patch_reader_methods(main_prog_var) def open_files(filenames, thread_num, shapes, lod_levels, dtypes): @@ -313,22 +329,25 @@ def open_files(filenames, thread_num, shapes, lod_levels, dtypes): startup_var.desc.set_dtypes(dtypes) startup_var.persistable = True - return _copy_reader_var_(default_main_program().current_block(), - startup_var) + main_prog_var = _copy_reader_var_(default_main_program().current_block(), + startup_var) + return monkey_patch_reader_methods(main_prog_var) def __create_decorated_reader__(op_type, reader, attrs): var_name = unique_name(op_type) startup_blk = default_startup_program().current_block() startup_var = startup_blk.create_var(name=var_name) - startup_blk.append_op( + startop_op = startup_blk.append_op( type=op_type, inputs={'UnderlyingReader': reader}, outputs={'Out': [startup_var]}, attrs=attrs) startup_var.persistable = True - return _copy_reader_var_(default_main_program().current_block(), - startup_var) + main_prog_block = default_main_program().current_block() + main_prog_var = _copy_reader_var_(main_prog_block, startup_var) + _copy_reader_create_op_(main_prog_block, startop_op) + return monkey_patch_reader_methods(main_prog_var) def create_shuffle_reader(reader, buffer_size): -- GitLab