diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc index 733157ea05ed39434b9a750e3a94ea548f512ce6..48e37796e1b4190e50602421106a105e4d4f6d74 100644 --- a/paddle/fluid/operators/while_op.cc +++ b/paddle/fluid/operators/while_op.cc @@ -57,12 +57,12 @@ class WhileOp : public framework::OperatorBase { PADDLE_ENFORCE(platform::is_cpu_place(cond.place()), "Condition of while op must in CPU memory."); + + auto ctx = executor.Prepare(*program, block->ID()); while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); step_scopes->push_back(¤t_scope); - - executor.Run(*program, ¤t_scope, block->ID(), - false /*create_local_scope*/); + executor.RunPreparedContext(ctx.get(), ¤t_scope, false); } } }; @@ -109,6 +109,7 @@ class WhileGradOp : public framework::OperatorBase { framework::Executor executor(dev_place); auto *block = Attr(kStepBlock); auto *program = block->Program(); + auto ctx = executor.Prepare(*program, block->ID()); auto *step_scopes = scope.FindVar(Input(kStepScopes))->GetMutable(); @@ -161,8 +162,7 @@ class WhileGradOp : public framework::OperatorBase { } } } - - executor.Run(*program, *cur_scope_iter, block->ID(), false); + executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false); auto &pg_names = Outputs(kXGRAD); auto &p_names = Inputs(kX);