diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 5cae38b2a857b2037f0e5ae4da50d1591da0c11a..82f75ab741a4689bbbdc7f5d8c33f27d1da26181 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -106,10 +106,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, // and feed_holder_name. Raise exception when any mismatch is found. // Return true if the block has feed operators and holder of matching info. static bool has_feed_operators( - BlockDesc* block, std::map& feed_targets, + const BlockDesc& block, + std::map& feed_targets, const std::string& feed_holder_name) { size_t feed_count = 0; - for (auto* op : block->AllOps()) { + for (auto* op : block.AllOps()) { if (op->Type() == kFeedOpType) { feed_count++; PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, @@ -128,7 +129,7 @@ static bool has_feed_operators( "The number of feed operators should match 'feed_targets'"); // When feed operator are present, so should be feed_holder - auto var = block->FindVar(feed_holder_name); + auto var = block.FindVar(feed_holder_name); PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", feed_holder_name); PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, @@ -146,10 +147,10 @@ static bool has_feed_operators( // and fetch_holder_name. Raise exception when any mismatch is found. // Return true if the block has fetch operators and holder of matching info. static bool has_fetch_operators( - BlockDesc* block, std::map& fetch_targets, + const BlockDesc& block, std::map& fetch_targets, const std::string& fetch_holder_name) { size_t fetch_count = 0; - for (auto* op : block->AllOps()) { + for (auto* op : block.AllOps()) { if (op->Type() == kFetchOpType) { fetch_count++; PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, @@ -168,7 +169,7 @@ static bool has_fetch_operators( "The number of fetch operators should match 'fetch_targets'"); // When fetch operator are present, so should be fetch_holder - auto var = block->FindVar(fetch_holder_name); + auto var = block.FindVar(fetch_holder_name); PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", fetch_holder_name); PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, @@ -184,10 +185,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, std::map& fetch_targets, const std::string& feed_holder_name, const std::string& fetch_holder_name) { - auto* copy_program = new ProgramDesc(program); + bool has_feed_ops = + has_feed_operators(program.Block(0), feed_targets, feed_holder_name); + bool has_fetch_ops = + has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name); + + ProgramDesc* copy_program = const_cast(&program); + if (!has_feed_ops || !has_fetch_ops) { + copy_program = std::unique_ptr(new ProgramDesc(program)).get(); + } + auto* global_block = copy_program->MutableBlock(0); - if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { + if (!has_feed_ops) { // create feed_holder variable auto* feed_holder = global_block->Var(feed_holder_name); feed_holder->SetType(proto::VarType::FEED_MINIBATCH); @@ -220,7 +230,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, } } - if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { + if (!has_fetch_ops) { // create fetch_holder variable auto* fetch_holder = global_block->Var(fetch_holder_name); fetch_holder->SetType(proto::VarType::FETCH_LIST); @@ -254,8 +264,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, GetFetchVariable(*scope, fetch_holder_name, idx); } } - - delete copy_program; } ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program, @@ -305,9 +313,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, } // if (create_vars) for (auto& op : ctx->ops_) { - VLOG(4) << place_ << " " << op->DebugStringEx(local_scope); - op->Run(*local_scope, place_); VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); + op->Run(*local_scope, place_); if (FLAGS_benchmark) { VLOG(2) << "Memory used after operator " + op->Type() + " running: "