From e4e9d3624f25dfaae2516b5e57708ddb9f90ccd3 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 24 May 2018 12:55:03 +0800 Subject: [PATCH] fix a potential bug --- .../reader/create_custom_reader_op.cc | 20 +++++++------------ 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/operators/reader/create_custom_reader_op.cc b/paddle/fluid/operators/reader/create_custom_reader_op.cc index 74e6b79a2a3..f03b3473ad3 100644 --- a/paddle/fluid/operators/reader/create_custom_reader_op.cc +++ b/paddle/fluid/operators/reader/create_custom_reader_op.cc @@ -21,25 +21,22 @@ namespace reader { class CustomReader : public framework::DecoratedReader { public: - CustomReader(ReaderBase* reader, const framework::BlockDesc* sub_block, + CustomReader(ReaderBase* reader, const framework::BlockDesc& sub_block, const platform::Place& dev_place, const std::vector& source_var_names, const std::vector& sink_var_names) : DecoratedReader(reader), - sub_block_(sub_block), + program_(*sub_block.Program()), + sub_block_id_(sub_block.ID()), exe_(framework::Executor(dev_place)), source_var_names_(source_var_names), sink_var_names_(sink_var_names) {} void ReadNext(std::vector* out) override; - void UpdateBlockAndScope(const framework::BlockDesc* sub_block, - const framework::Scope* scope) { - sub_block_ = sub_block; - } - private: - const framework::BlockDesc* sub_block_; + const framework::ProgramDesc program_; + int sub_block_id_; framework::Executor exe_; std::vector source_var_names_; @@ -57,14 +54,12 @@ class CreateCustomReaderOp : public framework::OperatorBase { ->template GetMutable(); auto* sub_block = Attr("sub_block"); if (out->Get() != nullptr) { - auto* custom_reader = reinterpret_cast(out->Get()); - custom_reader->UpdateBlockAndScope(sub_block, &scope); return; } const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader")) ->Get(); out->Reset( - new CustomReader(underlying_reader.Get(), sub_block, dev_place, + new CustomReader(underlying_reader.Get(), *sub_block, dev_place, Attr>("source_var_names"), Attr>("sink_var_names"))); } @@ -159,8 +154,7 @@ void CustomReader::ReadNext(std::vector* out) { tensor->set_lod(underlying_outs[i].lod()); } // 2. Run the sub-block. - framework::ProgramDesc* program = sub_block_->Program(); - exe_.Run(*program, scope, sub_block_->ID(), false, true); + exe_.Run(program_, scope, sub_block_id_, false, true); // 3. Copy LoDTensors from sink variables to out. out->resize(sink_var_names_.size()); for (size_t i = 0; i < sink_var_names_.size(); ++i) { -- GitLab