diff --git a/paddle/fluid/framework/details/fuse_optimizer_op_pass.cc b/paddle/fluid/framework/details/fuse_optimizer_op_pass.cc
index 262d968c3b79df6b8eea5f74bfe3bbb41bdcba1c..312fc89470f0ee212f07536c6d9eb55fb70e64ec 100644
--- a/paddle/fluid/framework/details/fuse_optimizer_op_pass.cc
+++ b/paddle/fluid/framework/details/fuse_optimizer_op_pass.cc
@@ -147,6 +147,18 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
       vars.emplace(node->Var()->Name(), node);
     }
   }
+
+  // Set Gradients as Persistable to prevent this var becoming reusable.
+  for (auto &grad_var_name : grads) {
+    auto iter = vars.find(grad_var_name);
+    PADDLE_ENFORCE(iter != vars.end());
+    PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
+    PADDLE_ENFORCE(iter->second->Var()->GetType() == proto::VarType::LOD_TENSOR,
+                   "Currently the gradient type only should be LoDTensor when "
+                   "fusing optimizer ops.");
+    iter->second->Var()->SetPersistable(true);
+  }
+
   // Init Grads
   for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) {
     auto &scope = *it;
@@ -154,13 +166,10 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads(
     PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr,
                    "%s has existed in scope.", fused_grad_name);
     scope->Var(fused_grad_name)->GetMutable<LoDTensor>();
-
     for (auto &grad_var_name : grads) {
       auto iter = vars.find(grad_var_name);
       PADDLE_ENFORCE(iter != vars.end());
       PADDLE_ENFORCE_NOT_NULL(iter->second->Var());
-      PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(),
-                        proto::VarType::LOD_TENSOR);
       scope->Var(grad_var_name)->GetMutable<LoDTensor>();
     }
   }
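
The added block above boils down to: validate each gradient variable node (it exists in the graph and is a LoDTensor) and mark it persistable, so that later variable-reuse passes will not recycle its memory before the fused gradient space is set up. The following is a minimal, standalone C++ sketch of that pattern under simplified assumptions; VarType, VarDesc, MarkGradsPersistable, and the example gradient name are hypothetical stand-ins for illustration, not the Paddle classes or the pass itself.

// A minimal sketch of the "mark gradients persistable" idea, with mock types.
// VarType, VarDesc, and MarkGradsPersistable are hypothetical stand-ins, not
// the Paddle classes used by the pass.
#include <cassert>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

enum class VarType { LOD_TENSOR, SELECTED_ROWS };

struct VarDesc {
  std::string name;
  VarType type = VarType::LOD_TENSOR;
  bool persistable = false;
  VarType GetType() const { return type; }
  void SetPersistable(bool p) { persistable = p; }
};

// Mirrors the added block: every gradient must be present in the graph and be
// a LoDTensor; it is then marked persistable so a memory-reuse pass skips it.
void MarkGradsPersistable(const std::vector<std::string> &grads,
                          std::unordered_map<std::string, VarDesc *> *vars) {
  for (const auto &grad_var_name : grads) {
    auto iter = vars->find(grad_var_name);
    assert(iter != vars->end() && "gradient var not found in graph");
    assert(iter->second != nullptr);
    assert(iter->second->GetType() == VarType::LOD_TENSOR &&
           "only LoDTensor gradients are supported when fusing optimizer ops");
    iter->second->SetPersistable(true);
  }
}

int main() {
  VarDesc grad;
  grad.name = "fc_0.w_0@GRAD";  // example gradient name, for illustration only
  std::unordered_map<std::string, VarDesc *> vars{{grad.name, &grad}};
  MarkGradsPersistable({grad.name}, &vars);
  std::cout << grad.name << " persistable: " << std::boolalpha
            << grad.persistable << "\n";  // fc_0.w_0@GRAD persistable: true
  return 0;
}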