提交 b94307a9 编写于 作者: Y Yiqun Liu 提交者: ceci3

Revert "Optimize while_op when is_test is true. (#15811)" (#15968)

test=develop
上级 eeb70edd
...@@ -19,10 +19,6 @@ namespace framework { ...@@ -19,10 +19,6 @@ namespace framework {
void LoDRankTable::Reset(const LoD& lod, size_t level) { void LoDRankTable::Reset(const LoD& lod, size_t level) {
this->coarse_lod_.clear(); this->coarse_lod_.clear();
this->items_.clear(); this->items_.clear();
if (lod.size() == 0) {
// Reset to a empty rank table.
return;
}
PADDLE_ENFORCE(level < lod.size(), PADDLE_ENFORCE(level < lod.size(),
"Cannot rank lod since the level %d is less than lod size %d", "Cannot rank lod since the level %d is less than lod size %d",
level, lod.size()); level, lod.size());
......
...@@ -58,7 +58,6 @@ class WhileOp : public framework::OperatorBase { ...@@ -58,7 +58,6 @@ class WhileOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override { const platform::Place &dev_place) const override {
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition))); PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kCondition)));
auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>(); auto &cond = scope.FindVar(Input(kCondition))->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
...@@ -78,35 +77,15 @@ class WhileOp : public framework::OperatorBase { ...@@ -78,35 +77,15 @@ class WhileOp : public framework::OperatorBase {
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
auto ctx = executor.Prepare(*program, block->ID(), skip_vars); auto ctx = executor.Prepare(*program, block->ID(), skip_vars);
if (!is_test) {
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, executor.RunPreparedContext(ctx.get(), &current_scope, false, true, true);
true); if (is_test) {
}
} else {
auto &current_scope = scope.NewScope();
executor.CreateVariables(*program, &current_scope, block->ID());
while (cond.data<bool>()[0]) {
for (auto &name : current_scope.LocalVarNames()) {
auto *var = current_scope.Var(name);
framework::LoD empty_lod;
if (var->IsType<framework::LoDTensor>()) {
// Clear all lod information for all lod_tensors.
auto *t = var->GetMutable<framework::LoDTensor>();
t->set_lod(empty_lod);
} else if (var->IsType<framework::LoDRankTable>()) {
auto *t = var->GetMutable<framework::LoDRankTable>();
t->Reset(empty_lod, 0);
}
}
executor.RunPreparedContext(ctx.get(), &current_scope, false, false,
false);
}
scope.DeleteScope(&current_scope); scope.DeleteScope(&current_scope);
} }
} }
}
}; };
class WhileOpMaker : public framework::OpProtoAndCheckerMaker { class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册