diff --git a/python/paddle/optimizer/momentum.py b/python/paddle/optimizer/momentum.py index b9d05eb8a72e781953603a7b5aafb4e06ec358ce..bfcd2bc038b6f44bfb7e969600d113c84739c73b 100644 --- a/python/paddle/optimizer/momentum.py +++ b/python/paddle/optimizer/momentum.py @@ -198,10 +198,6 @@ class Momentum(Optimizer): velocity_acc = self._get_accumulator(self._velocity_acc_str, param_and_grad[0]) - find_master = self._multi_precision and param_and_grad[ - 0].dtype == core.VarDesc.VarType.FP16 - master_weight = (self._master_weights[param_and_grad[0].name] - if find_master else None) lr = self._create_param_lr(param_and_grad) if framework.in_dygraph_mode(): @@ -213,6 +209,11 @@ class Momentum(Optimizer): self._regularization_coeff) return None + find_master = self._multi_precision and param_and_grad[ + 0].dtype == core.VarDesc.VarType.FP16 + master_weight = (self._master_weights[param_and_grad[0].name] + if find_master else None) + attrs = { "mu": self._momentum, "use_nesterov": self._use_nesterov,