提交 8ecad985 编写于 作者: L Liu Yiqun

Add the bool variable to decide whether to have a copy of the program in ExecutorPrepareContext.

上级 c0a9aebe
......@@ -35,11 +35,26 @@ namespace paddle {
namespace framework {
struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id)
: prog_(prog), block_id_(block_id) {}
ExecutorPrepareContext(const framework::ProgramDesc* prog, size_t block_id,
bool own_program = true)
: block_id_(block_id), own_program_(own_program) {
if (own_program_) {
prog_ = new ProgramDesc(*prog);
} else {
// If own_program_ is false, we can avoid a clone of the program.
prog_ = prog;
}
}
~ExecutorPrepareContext() {
if (own_program_) {
delete prog_;
}
}
framework::ProgramDesc prog_;
const framework::ProgramDesc* prog_;
size_t block_id_;
bool own_program_;
std::vector<std::unique_ptr<OperatorBase>> ops_;
};
......@@ -94,7 +109,7 @@ static void CheckTensorNANOrInf(const std::string& name,
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
auto* ctx = Prepare(pdesc, block_id);
auto* ctx = Prepare(pdesc, block_id, false);
RunPreparedContext(ctx, scope, create_local_scope, create_vars);
delete ctx;
}
......@@ -267,8 +282,8 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
}
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
int block_id) {
auto* ctx = new ExecutorPrepareContext(program, block_id);
int block_id, bool own_program) {
auto* ctx = new ExecutorPrepareContext(&program, block_id, own_program);
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) {
......@@ -279,7 +294,7 @@ ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) {
auto& block = ctx->prog_.Block(ctx->block_id_);
auto& block = ctx->prog_->Block(ctx->block_id_);
Scope* local_scope = scope;
if (create_vars) {
......
......@@ -48,7 +48,7 @@ class Executor {
const std::string& fetch_holder_name = "fetch");
static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
int block_id);
int block_id, bool own_program = true);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册