未验证 提交 98802e1f 编写于 作者: Y Yiqun Liu 提交者: GitHub

Optimize the implementation of while_op again, for cases when is_test is true. (#16359)

test=develop
上级 c34b24ed
......@@ -51,6 +51,7 @@ class WhileOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
......@@ -70,13 +71,34 @@ class WhileOp : public framework::OperatorBase {
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
while (cond.data<bool>()[0]) {
if (!is_test) {
while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false, true,
true);
}
} else {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
if (is_test) {
scope.DeleteScope(&current_scope);
executor.CreateVariables(*program, &current_scope, block->ID());
while (cond.data<bool>()[0]) {
for (auto &name : current_scope.LocalVarNames()) {
auto *var = current_scope.Var(name);
if (var->IsType<framework::LoDTensor>()) {
// Clear all lod information for all lod_tensors.
auto *t = var->GetMutable<framework::LoDTensor>();
framework::LoD empty_lod;
t->set_lod(empty_lod);
} else if (var->IsType<framework::LoDTensorArray>()) {
// Clear elements of all tensor arrays.
auto *t = var->GetMutable<framework::LoDTensorArray>();
t->clear();
}
}
executor.RunPreparedContext(ctx.get(), &current_scope, false, false,
false);
}
scope.DeleteScope(&current_scope);
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册