提交 1b10a784 编写于 作者: Y Yiqun Liu 提交者: ceci3

Optimize while_op when is_test is true. (#15811)

test=develop
上级 91838c32
...@@ -19,6 +19,10 @@ namespace framework { ...@@ -19,6 +19,10 @@ namespace framework {
void LoDRankTable::Reset(const LoD& lod, size_t level) { void LoDRankTable::Reset(const LoD& lod, size_t level) {
this->coarse_lod_.clear(); this->coarse_lod_.clear();
this->items_.clear(); this->items_.clear();
if (lod.size() == 0) {
// Reset to a empty rank table.
return;
}
PADDLE_ENFORCE(level < lod.size(), PADDLE_ENFORCE(level < lod.size(),
"Cannot rank lod since the level %d is less than lod size %d", "Cannot rank lod since the level %d is less than lod size %d",
level, lod.size()); level, lod.size());
......
...@@ -58,6 +58,7 @@ class WhileOp : public framework::OperatorBase { ...@@ -58,6 +58,7 @@ class WhileOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>(); auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
...@@ -77,13 +78,33 @@ class WhileOp : public framework::OperatorBase { ...@@ -77,13 +78,33 @@ class WhileOp : public framework::OperatorBase {
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
auto ctx = executor.Prepare(*program, block->ID(), skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
if (!is_test) {
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true); executor.RunPreparedContext(ctx.get(), &current_scope, false, true,
if (is_test) { true);
scope.DeleteScope(&current_scope); }
} else {
auto &current_scope = scope.NewScope();
executor.CreateVariables(*program, &current_scope, block->ID());
while (cond.data<bool>()[0]) {
for (auto &name : current_scope.LocalVarNames()) {
auto *var = current_scope.Var(name);
framework::LoD empty_lod;
if (var->IsType<framework::LoDTensor>()) {
// Clear all lod information for all lod_tensors.
auto *t = var->GetMutable<framework::LoDTensor>();
t->set_lod(empty_lod);
} else if (var->IsType<framework::LoDRankTable>()) {
auto *t = var->GetMutable<framework::LoDRankTable>();
t->Reset(empty_lod, 0);
}
} }
executor.RunPreparedContext(ctx.get(), &current_scope, false, false,
false);
}
scope.DeleteScope(&current_scope);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册