提交 c0a9aebe 编写于 作者: L Liu Yiqun

Remove the clone of program in C++ Executor.Run().

上级 cbfd15f9
...@@ -106,10 +106,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, ...@@ -106,10 +106,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
// and feed_holder_name. Raise exception when any mismatch is found. // and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info. // Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators( static bool has_feed_operators(
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets, const BlockDesc& block,
std::map<std::string, const LoDTensor*>& feed_targets,
const std::string& feed_holder_name) { const std::string& feed_holder_name) {
size_t feed_count = 0; size_t feed_count = 0;
for (auto* op : block->AllOps()) { for (auto* op : block.AllOps()) {
if (op->Type() == kFeedOpType) { if (op->Type() == kFeedOpType) {
feed_count++; feed_count++;
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name, PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
...@@ -128,7 +129,7 @@ static bool has_feed_operators( ...@@ -128,7 +129,7 @@ static bool has_feed_operators(
"The number of feed operators should match 'feed_targets'"); "The number of feed operators should match 'feed_targets'");
// When feed operator are present, so should be feed_holder // When feed operator are present, so should be feed_holder
auto var = block->FindVar(feed_holder_name); auto var = block.FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name); feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH, PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
...@@ -146,10 +147,10 @@ static bool has_feed_operators( ...@@ -146,10 +147,10 @@ static bool has_feed_operators(
// and fetch_holder_name. Raise exception when any mismatch is found. // and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info. // Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators( static bool has_fetch_operators(
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets, const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& fetch_holder_name) { const std::string& fetch_holder_name) {
size_t fetch_count = 0; size_t fetch_count = 0;
for (auto* op : block->AllOps()) { for (auto* op : block.AllOps()) {
if (op->Type() == kFetchOpType) { if (op->Type() == kFetchOpType) {
fetch_count++; fetch_count++;
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name, PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
...@@ -168,7 +169,7 @@ static bool has_fetch_operators( ...@@ -168,7 +169,7 @@ static bool has_fetch_operators(
"The number of fetch operators should match 'fetch_targets'"); "The number of fetch operators should match 'fetch_targets'");
// When fetch operator are present, so should be fetch_holder // When fetch operator are present, so should be fetch_holder
auto var = block->FindVar(fetch_holder_name); auto var = block.FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable", PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name); fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST, PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
...@@ -184,10 +185,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -184,10 +185,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, LoDTensor*>& fetch_targets, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name, const std::string& feed_holder_name,
const std::string& fetch_holder_name) { const std::string& fetch_holder_name) {
auto* copy_program = new ProgramDesc(program); bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
bool has_fetch_ops =
has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);
ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
if (!has_feed_ops || !has_fetch_ops) {
copy_program = std::unique_ptr<ProgramDesc>(new ProgramDesc(program)).get();
}
auto* global_block = copy_program->MutableBlock(0); auto* global_block = copy_program->MutableBlock(0);
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) { if (!has_feed_ops) {
// create feed_holder variable // create feed_holder variable
auto* feed_holder = global_block->Var(feed_holder_name); auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarType::FEED_MINIBATCH); feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
...@@ -220,7 +230,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -220,7 +230,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
} }
} }
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) { if (!has_fetch_ops) {
// create fetch_holder variable // create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name); auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarType::FETCH_LIST); fetch_holder->SetType(proto::VarType::FETCH_LIST);
...@@ -254,8 +264,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -254,8 +264,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
GetFetchVariable(*scope, fetch_holder_name, idx); GetFetchVariable(*scope, fetch_holder_name, idx);
} }
} }
delete copy_program;
} }
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program, ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
...@@ -305,9 +313,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, ...@@ -305,9 +313,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} // if (create_vars) } // if (create_vars)
for (auto& op : ctx->ops_) { for (auto& op : ctx->ops_) {
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope); VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
if (FLAGS_benchmark) { if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: " VLOG(2) << "Memory used after operator " + op->Type() + " running: "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册