未验证 提交 b3f6eedb 编写于 作者: Z Zeng Jinle 提交者: GitHub

refine LarsOptimizer (#36351)

上级 f77083bb
......@@ -2047,11 +2047,15 @@ class LarsMomentumOptimizer(Optimizer):
def _append_optimize_op(self, block, param_and_grad):
assert isinstance(block, framework.Block)
_lars_weight_decay = self._lars_weight_decay
_lars_coeff = self._lars_coeff
param_name = param_and_grad[0].name
is_excluded = False
if len(self._exclude_from_weight_decay) > 0:
for name in self._exclude_from_weight_decay:
if name in param_name:
_lars_weight_decay = 0.0
_lars_coeff = 0.0
is_excluded = True
break
velocity_acc = self._get_accumulator(self._velocity_acc_str,
......@@ -2065,7 +2069,7 @@ class LarsMomentumOptimizer(Optimizer):
attrs = {
"mu": self._momentum,
"lars_coeff": self._lars_coeff,
"lars_coeff": _lars_coeff,
"lars_weight_decay": _lars_weight_decay,
"multi_precision": find_master,
"rescale_grad": self._rescale_grad
......@@ -2086,7 +2090,7 @@ class LarsMomentumOptimizer(Optimizer):
# create the momentum optimize op
momentum_op = block.append_op(
type=self.type,
type='momentum' if is_excluded else self.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册