提交 8ecad985 编写于 作者: L Liu Yiqun

Add the bool variable to decide whether to have a copy of the program in ExecutorPrepareContext.

上级 c0a9aebe
...@@ -35,11 +35,26 @@ namespace paddle { ...@@ -35,11 +35,26 @@ namespace paddle {
namespace framework { namespace framework {
struct ExecutorPrepareContext { struct ExecutorPrepareContext {
ExecutorPrepareContext(const framework::ProgramDesc& prog, size_t block_id) ExecutorPrepareContext(const framework::ProgramDesc* prog, size_t block_id,
: prog_(prog), block_id_(block_id) {} bool own_program = true)
: block_id_(block_id), own_program_(own_program) {
if (own_program_) {
prog_ = new ProgramDesc(*prog);
} else {
// If own_program_ is false, we can avoid a clone of the program.
prog_ = prog;
}
}
~ExecutorPrepareContext() {
if (own_program_) {
delete prog_;
}
}
framework::ProgramDesc prog_; const framework::ProgramDesc* prog_;
size_t block_id_; size_t block_id_;
bool own_program_;
std::vector<std::unique_ptr<OperatorBase>> ops_; std::vector<std::unique_ptr<OperatorBase>> ops_;
}; };
...@@ -94,7 +109,7 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -94,7 +109,7 @@ static void CheckTensorNANOrInf(const std::string& name,
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars) {
auto* ctx = Prepare(pdesc, block_id); auto* ctx = Prepare(pdesc, block_id, false);
RunPreparedContext(ctx, scope, create_local_scope, create_vars); RunPreparedContext(ctx, scope, create_local_scope, create_vars);
delete ctx; delete ctx;
} }
...@@ -267,8 +282,8 @@ void Executor::Run(const ProgramDesc& program, Scope* scope, ...@@ -267,8 +282,8 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
} }
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program, ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
int block_id) { int block_id, bool own_program) {
auto* ctx = new ExecutorPrepareContext(program, block_id); auto* ctx = new ExecutorPrepareContext(&program, block_id, own_program);
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size()); PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
auto& block = program.Block(block_id); auto& block = program.Block(block_id);
for (auto& op_desc : block.AllOps()) { for (auto& op_desc : block.AllOps()) {
...@@ -279,7 +294,7 @@ ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program, ...@@ -279,7 +294,7 @@ ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope, bool create_vars) { bool create_local_scope, bool create_vars) {
auto& block = ctx->prog_.Block(ctx->block_id_); auto& block = ctx->prog_->Block(ctx->block_id_);
Scope* local_scope = scope; Scope* local_scope = scope;
if (create_vars) { if (create_vars) {
......
...@@ -48,7 +48,7 @@ class Executor { ...@@ -48,7 +48,7 @@ class Executor {
const std::string& fetch_holder_name = "fetch"); const std::string& fetch_holder_name = "fetch");
static ExecutorPrepareContext* Prepare(const ProgramDesc& program, static ExecutorPrepareContext* Prepare(const ProgramDesc& program,
int block_id); int block_id, bool own_program = true);
void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope, void RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
bool create_local_scope = true, bool create_local_scope = true,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册