From 794a1958815afaf64ea1642be8a376d78c93bfcf Mon Sep 17 00:00:00 2001 From: chengduo Date: Fri, 26 Apr 2019 10:44:41 +0800 Subject: [PATCH] fix fuse optimizer ops (#17102) test=develop --- .../framework/details/fuse_optimizer_op_pass.cc | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/details/fuse_optimizer_op_pass.cc b/paddle/fluid/framework/details/fuse_optimizer_op_pass.cc index 262d968c3b..312fc89470 100644 --- a/paddle/fluid/framework/details/fuse_optimizer_op_pass.cc +++ b/paddle/fluid/framework/details/fuse_optimizer_op_pass.cc @@ -147,6 +147,18 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads( vars.emplace(node->Var()->Name(), node); } } + + // Set Gradients as Persistable to prevent this var becoming reusable. + for (auto &grad_var_name : grads) { + auto iter = vars.find(grad_var_name); + PADDLE_ENFORCE(iter != vars.end()); + PADDLE_ENFORCE_NOT_NULL(iter->second->Var()); + PADDLE_ENFORCE(iter->second->Var()->GetType() == proto::VarType::LOD_TENSOR, + "Currently the gradient type only should be LoDTensor when " + "fusing optimizer ops."); + iter->second->Var()->SetPersistable(true); + } + // Init Grads for (auto it = local_scopes.rbegin(); it != local_scopes.rend(); ++it) { auto &scope = *it; @@ -154,13 +166,10 @@ void FuseOptimizerOpPass::InitFusedGradsAndAllocSpaceForGrads( PADDLE_ENFORCE(scope->FindVar(fused_grad_name) == nullptr, "%s has existed in scope.", fused_grad_name); scope->Var(fused_grad_name)->GetMutable(); - for (auto &grad_var_name : grads) { auto iter = vars.find(grad_var_name); PADDLE_ENFORCE(iter != vars.end()); PADDLE_ENFORCE_NOT_NULL(iter->second->Var()); - PADDLE_ENFORCE_EQ(iter->second->Var()->GetType(), - proto::VarType::LOD_TENSOR); scope->Var(grad_var_name)->GetMutable(); } } -- GitLab