Unverified commit 794a1958, authored by chengduo, committed by GitHub

fix fuse optimizer ops (#17102)

test=develop
Parent 258e000b
@@ -147,6 +147,18 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
      vars.emplace(node->Var()->Name(), node);
    }
  }
  // Set Gradients as Persistable to prevent this var becoming reusable.
  for (auto &grad_var_name : grads) {
    auto iter = vars.find(grad_var_name);
    PADDLE_ENFORCE(iter != vars.end());
    PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
    PADDLE_ENFORCE(iter->second->Var()->GetType() == proto::VarType::LOD_TENSOR,
                   "Currently the gradient type should only be LoDTensor when "
                   "fusing optimizer ops.");
    iter->second->Var()->SetPersistable(true);
  }
  // Init Grads
  for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
    auto &scope = *it;
@@ -154,13 +166,10 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
    PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr,
                   "%s already exists in scope.", fused_grad_name);
    scope->Var(fused_grad_name)->GetMutable<LoDTensor>();
    for (auto &grad_var_name : grads) {
      auto iter = vars.find(grad_var_name);
      PADDLE_ENFORCE(iter != vars.end());
      PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
      PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
                        proto::VarType::LOD_TENSOR);
      scope->Var(grad_var_name)->GetMutable<LoDTensor>();
    }
  }
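
The substance of this fix is the SetPersistable(true) call in the first hunk: memory-optimization passes treat non-persistable variables as candidates whose buffers may be recycled once they look dead, while the fused optimizer reads all gradients through one continuous fused buffer, so recycling an individual gradient's storage would corrupt that read. The code comment ("prevent this var becoming reusable") states exactly this. Below is a minimal standalone sketch of that interaction; VarDesc and CollectReusable here are hypothetical stand-ins for illustration, not the real PaddlePaddle types:

// Toy model (not PaddlePaddle source): why persistable vars escape memory reuse.
#include <iostream>
#include <string>
#include <vector>

struct VarDesc {             // hypothetical stand-in for a graph var descriptor
  std::string name;
  bool persistable = false;  // persistable vars keep dedicated storage
};

// A toy memory-reuse pass: every non-persistable var is a reuse candidate.
std::vector<std::string> CollectReusable(const std::vector<VarDesc> &vars) {
  std::vector<std::string> reusable;
  for (const auto &v : vars) {
    if (!v.persistable) reusable.push_back(v.name);
  }
  return reusable;
}

int main() {
  std::vector<VarDesc> grads = {{"fc_0.w_0@GRAD"}, {"fc_0.b_0@GRAD"}};
  // Before the fix: both gradients could have their buffers reassigned.
  std::cout << CollectReusable(grads).size() << " reusable before\n";  // 2
  // After the fix: marking them persistable removes them from the pool.
  for (auto &g : grads) g.persistable = true;
  std::cout << CollectReusable(grads).size() << " reusable after\n";   // 0
}

The second hunk complements this: the per-gradient existence and type checks now run once over the graph (first hunk), so the loop over local_scopes is left only to create the fused gradient variable and each individual gradient's LoDTensor in every scope.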