From 1b10a7843c416d499ddaf2fd76df57b360b880ce Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Wed, 27 Feb 2019 19:31:00 +0800 Subject: [PATCH] Optimize while_op when is_test is true. (#15811) test=develop --- paddle/fluid/framework/lod_rank_table.cc | 4 +++ .../fluid/operators/controlflow/while_op.cc | 31 ++++++++++++++++--- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/framework/lod_rank_table.cc b/paddle/fluid/framework/lod_rank_table.cc index 6bc795b642b..12536ec60b7 100644 --- a/paddle/fluid/framework/lod_rank_table.cc +++ b/paddle/fluid/framework/lod_rank_table.cc @@ -19,6 +19,10 @@ namespace framework { void LoDRankTable::Reset(const LoD& lod, size_t level) { this->coarse_lod_.clear(); this->items_.clear(); + if (lod.size() == 0) { + // Reset to a empty rank table. + return; + } PADDLE_ENFORCE(level < lod.size(), "Cannot rank lod since the level %d is less than lod size %d", level, lod.size()); diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 0360cf52735..77fdcf41a7e 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -58,6 +58,7 @@ class WhileOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); + auto &cond = scope.FindVar(Input(kCondition))->Get(); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); @@ -77,13 +78,33 @@ class WhileOp : public framework::OperatorBase { VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars); - while (cond.data()[0]) { + if (!is_test) { + while (cond.data()[0]) { + auto ¤t_scope = scope.NewScope(); + step_scopes->push_back(¤t_scope); + executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, + true); + } + } else { auto ¤t_scope = scope.NewScope(); - step_scopes->push_back(¤t_scope); - executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, true); - if (is_test) { - scope.DeleteScope(¤t_scope); + executor.CreateVariables(*program, ¤t_scope, block->ID()); + while (cond.data()[0]) { + for (auto &name : current_scope.LocalVarNames()) { + auto *var = current_scope.Var(name); + framework::LoD empty_lod; + if (var->IsType()) { + // Clear all lod information for all lod_tensors. + auto *t = var->GetMutable(); + t->set_lod(empty_lod); + } else if (var->IsType()) { + auto *t = var->GetMutable(); + t->Reset(empty_lod, 0); + } + } + executor.RunPreparedContext(ctx.get(), ¤t_scope, false, false, + false); } + scope.DeleteScope(¤t_scope); } } }; -- GitLab