提交 a2c0e52f 编写于 作者: X Xin Pan

speed up while_op

上级 509b6835
...@@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase { ...@@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
"Condition of while op must in CPU memory."); "Condition of while op must in CPU memory.");
auto ctx = executor.Prepare(*program, block->ID());
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false);
executor.Run(*program, &current_scope, block->ID(),
false /*create_local_scope*/);
} }
} }
}; };
...@@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase {
framework::Executor executor(dev_place); framework::Executor executor(dev_place);
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto ctx = executor.Prepare(*program, block->ID());
auto *step_scopes = auto *step_scopes =
scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>(); scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
...@@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase {
} }
} }
} }
executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false);
executor.Run(*program, *cur_scope_iter, block->ID(), false);
auto &pg_names = Outputs(kXGRAD); auto &pg_names = Outputs(kXGRAD);
auto &p_names = Inputs(kX); auto &p_names = Inputs(kX);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册