提交 c0a9aebe 编写于 作者: L Liu Yiqun

Remove the clone of program in C++ Executor.Run().

上级 cbfd15f9
......@@ -106,10 +106,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
// and feed_holder_name. Raise exception when any mismatch is found.
// Return true if the block has feed operators and holder of matching info.
static bool has_feed_operators(
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets,
const BlockDesc& block,
std::map<std::string, const LoDTensor*>& feed_targets,
const std::string& feed_holder_name) {
size_t feed_count = 0;
for (auto* op : block->AllOps()) {
for (auto* op : block.AllOps()) {
if (op->Type() == kFeedOpType) {
feed_count++;
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
......@@ -128,7 +129,7 @@ static bool has_feed_operators(
"The number of feed operators should match 'feed_targets'");
// When feed operator are present, so should be feed_holder
auto var = block->FindVar(feed_holder_name);
auto var = block.FindVar(feed_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
feed_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
......@@ -146,10 +147,10 @@ static bool has_feed_operators(
// and fetch_holder_name. Raise exception when any mismatch is found.
// Return true if the block has fetch operators and holder of matching info.
static bool has_fetch_operators(
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets,
const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& fetch_holder_name) {
size_t fetch_count = 0;
for (auto* op : block->AllOps()) {
for (auto* op : block.AllOps()) {
if (op->Type() == kFetchOpType) {
fetch_count++;
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
......@@ -168,7 +169,7 @@ static bool has_fetch_operators(
"The number of fetch operators should match 'fetch_targets'");
// When fetch operator are present, so should be fetch_holder
auto var = block->FindVar(fetch_holder_name);
auto var = block.FindVar(fetch_holder_name);
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
fetch_holder_name);
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
......@@ -184,10 +185,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
std::map<std::string, LoDTensor*>& fetch_targets,
const std::string& feed_holder_name,
const std::string& fetch_holder_name) {
auto* copy_program = new ProgramDesc(program);
bool has_feed_ops =
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
bool has_fetch_ops =
has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);
ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
if (!has_feed_ops || !has_fetch_ops) {
copy_program = std::unique_ptr<ProgramDesc>(new ProgramDesc(program)).get();
}
auto* global_block = copy_program->MutableBlock(0);
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
if (!has_feed_ops) {
// create feed_holder variable
auto* feed_holder = global_block->Var(feed_holder_name);
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
......@@ -220,7 +230,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
}
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
if (!has_fetch_ops) {
// create fetch_holder variable
auto* fetch_holder = global_block->Var(fetch_holder_name);
fetch_holder->SetType(proto::VarType::FETCH_LIST);
......@@ -254,8 +264,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
GetFetchVariable(*scope, fetch_holder_name, idx);
}
}
delete copy_program;
}
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
......@@ -305,9 +313,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
} // if (create_vars)
for (auto& op : ctx->ops_) {
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
op->Run(*local_scope, place_);
if (FLAGS_benchmark) {
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册