From 798925453eefc25dff1e81b68194b12045bfe65b Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Thu, 28 Feb 2019 13:25:05 +0800 Subject: [PATCH] Revert "Optimize while_op when is_test is true. (#15811)" (#15968) test=develop --- paddle/fluid/framework/lod_rank_table.cc | 4 --- .../fluid/operators/controlflow/while_op.cc | 31 +++---------------- 2 files changed, 5 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/framework/lod_rank_table.cc b/paddle/fluid/framework/lod_rank_table.cc index 12536ec60b7..6bc795b642b 100644 --- a/paddle/fluid/framework/lod_rank_table.cc +++ b/paddle/fluid/framework/lod_rank_table.cc @@ -19,10 +19,6 @@ namespace framework { void LoDRankTable::Reset(const LoD& lod, size_t level) { this->coarse_lod_.clear(); this->items_.clear(); - if (lod.size() == 0) { - // Reset to a empty rank table. - return; - } PADDLE_ENFORCE(level < lod.size(), "Cannot rank lod since the level %d is less than lod size %d", level, lod.size()); diff --git a/paddle/fluid/operators/controlflow/while_op.cc b/paddle/fluid/operators/controlflow/while_op.cc index 77fdcf41a7e..0360cf52735 100644 --- a/paddle/fluid/operators/controlflow/while_op.cc +++ b/paddle/fluid/operators/controlflow/while_op.cc @@ -58,7 +58,6 @@ class WhileOp : public framework::OperatorBase { void RunImpl(const framework::Scope &scope, const platform::Place &dev_place) const override { PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); - auto &cond = scope.FindVar(Input(kCondition))->Get(); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); @@ -78,33 +77,13 @@ class WhileOp : public framework::OperatorBase { VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars); - if (!is_test) { - while (cond.data()[0]) { - auto ¤t_scope = scope.NewScope(); - step_scopes->push_back(¤t_scope); - executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, - true); - } - } else { + while (cond.data()[0]) { auto ¤t_scope = scope.NewScope(); - executor.CreateVariables(*program, ¤t_scope, block->ID()); - while (cond.data()[0]) { - for (auto &name : current_scope.LocalVarNames()) { - auto *var = current_scope.Var(name); - framework::LoD empty_lod; - if (var->IsType()) { - // Clear all lod information for all lod_tensors. - auto *t = var->GetMutable(); - t->set_lod(empty_lod); - } else if (var->IsType()) { - auto *t = var->GetMutable(); - t->Reset(empty_lod, 0); - } - } - executor.RunPreparedContext(ctx.get(), ¤t_scope, false, false, - false); + step_scopes->push_back(¤t_scope); + executor.RunPreparedContext(ctx.get(), ¤t_scope, false, true, true); + if (is_test) { + scope.DeleteScope(¤t_scope); } - scope.DeleteScope(¤t_scope); } } }; -- GitLab