From 88e6dc4ac5a5f0a4ed0c54365e4210528da6f3ab Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Tue, 5 Jan 2021 15:11:07 +0800 Subject: [PATCH] optimize momentum to speedup dygraph, a little, test=develop (#30099) --- python/paddle/optimizer/momentum.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index b9d05eb8a72..bfcd2bc038b 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -198,10 +198,6 @@ class Momentum(Optimizer): velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) - find_master = self._multi_precision and param_and_grad[ - 0].dtype == core.VarDesc.VarType.FP16 - master_weight = (self._master_weights[param_and_grad[0].name] - if find_master else None) lr = self._create_param_lr(param_and_grad) if framework.in_dygraph_mode(): @@ -213,6 +209,11 @@ class Momentum(Optimizer): self._regularization_coeff) return None + find_master = self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight = (self._master_weights[param_and_grad[0].name] + if find_master else None) + attrs = { "mu": self._momentum, "use_nesterov": self._use_nesterov, -- GitLab