From b3f6eedb77925c28a193eaedb858220b9417c5ca Mon Sep 17 00:00:00 2001
From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com>
Date: Tue, 12 Oct 2021 12:55:02 +0800
Subject: [PATCH] refine LarsOptimizer (#36351)

---
 python/paddle/fluid/optimizer.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 24076e82b03..4625d7ea89b 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -2047,11 +2047,15 @@ class LarsMomentumOptimizer(Optimizer):
     def _append_optimize_op(self, block, param_and_grad):
         assert isinstance(block, framework.Block)
         _lars_weight_decay = self._lars_weight_decay
+        _lars_coeff = self._lars_coeff
         param_name = param_and_grad[0].name
+        is_excluded = False
         if len(self._exclude_from_weight_decay) > 0:
             for name in self._exclude_from_weight_decay:
                 if name in param_name:
                     _lars_weight_decay = 0.0
+                    _lars_coeff = 0.0
+                    is_excluded = True
                     break
 
         velocity_acc = self._get_accumulator(self._velocity_acc_str,
@@ -2065,7 +2069,7 @@ class LarsMomentumOptimizer(Optimizer):
 
         attrs = {
             "mu": self._momentum,
-            "lars_coeff": self._lars_coeff,
+            "lars_coeff": _lars_coeff,
             "lars_weight_decay": _lars_weight_decay,
             "multi_precision": find_master,
             "rescale_grad": self._rescale_grad
@@ -2086,7 +2090,7 @@ class LarsMomentumOptimizer(Optimizer):
 
         # create the momentum optimize op
         momentum_op = block.append_op(
-            type=self.type,
+            type='momentum' if is_excluded else self.type,
             inputs=inputs,
             outputs=outputs,
             attrs=attrs,
--
GitLab
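
For context, a minimal usage sketch of the code path this patch changes. Before the change, a parameter matched by exclude_from_weight_decay still went through the 'lars_momentum' op with only lars_weight_decay zeroed, so it was still LARS trust-ratio scaled; after the change, lars_coeff is zeroed too and the update falls back to a plain 'momentum' op. The network and parameter-name substring below are illustrative assumptions, not part of the patch; the sketch assumes a paddle.fluid static-graph build contemporary with this change.

    # Illustrative sketch (not part of the patch): exercising the
    # exclude list that _append_optimize_op checks above.
    import numpy as np
    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()

    main_prog = fluid.Program()
    startup_prog = fluid.Program()
    with fluid.program_guard(main_prog, startup_prog):
        x = fluid.data(name='x', shape=[None, 8], dtype='float32')
        y = fluid.layers.fc(input=x, size=1)
        loss = fluid.layers.reduce_mean(y)

        # '.b_' is a substring match against parameter names; under the
        # default naming scheme it catches biases such as 'fc_0.b_0'.
        # With this patch those parameters get lars_coeff = 0.0 and
        # lars_weight_decay = 0.0 and are updated by a plain 'momentum'
        # op instead of 'lars_momentum'.
        opt = fluid.optimizer.LarsMomentumOptimizer(
            learning_rate=0.001,
            momentum=0.9,
            lars_coeff=0.001,
            lars_weight_decay=0.0005,
            exclude_from_weight_decay=['.b_'])
        opt.minimize(loss)

    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(startup_prog)
    # One step to exercise both the 'lars_momentum' and 'momentum' paths.
    exe.run(main_prog,
            feed={'x': np.random.rand(4, 8).astype('float32')},
            fetch_list=[loss])

Skipping LARS scaling for excluded parameters (typically biases and normalization weights) matches common LARS practice, since the trust ratio is ill-suited to such small parameter groups.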