Unverified · Commit bf6d65fc authored by zhangbo9674, committed by GitHub

Fix multi tensor momentum regular bug (#38344)

* fix merged_momentum regular bug

* fix bug
Parent 8fe1cb72
@@ -192,7 +192,7 @@ class Momentum(Optimizer):
     def _update_regularization(self, weight_decay):
         reg_method = ""
-        reg_coeff = 0
+        reg_coeff = 0.0
         if (isinstance(weight_decay, L2DecayRegularizer)):
             reg_method = "l2_decay"
@@ -306,7 +306,7 @@ class Momentum(Optimizer):
         # the param's regularization has been done before, we avoid do l2decay in momentum.
         elif param.regularizer is not None:
             regularization_method = ""
-            regularization_coeff = 0
+            regularization_coeff = 0.0
         find_master = self._multi_precision and param_and_grad[
             0].dtype == core.VarDesc.VarType.FP16
@@ -380,7 +380,7 @@
             if isinstance(param.regularizer, L2DecayRegularizer):
                 regularization_method = "l2_decay"
                 regularization_coeff = param.regularizer._regularization_coeff
-            else:
+            elif param.regularizer is not None:
                 regularization_method = ""
                 regularization_coeff = 0.0
             if param.dtype == paddle.float32:
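For context, the per-parameter selection logic that this patch corrects can be summarized in a standalone sketch. This is an illustrative reconstruction, not Paddle's actual code: the helper select_regularization and the stripped-down L2DecayRegularizer class below are assumptions made for the example. The fix has two parts: a parameter carrying a regularizer other than L2DecayRegularizer must disable in-op decay (its regularization was already applied to the gradient before the momentum op runs), and the disabled coefficient must be the float 0.0 rather than the int 0, since the op expects a float-typed attribute.

# Illustrative sketch only; not Paddle's actual implementation.

class L2DecayRegularizer:
    # Simplified stand-in for paddle.regularizer.L2Decay.
    def __init__(self, regularization_coeff=0.0):
        self._regularization_coeff = regularization_coeff

def select_regularization(param_regularizer, default_method, default_coeff):
    """Return (method, coeff) for one parameter.

    param_regularizer: the parameter's own regularizer, or None.
    default_method/default_coeff: the optimizer-wide weight decay.
    """
    if isinstance(param_regularizer, L2DecayRegularizer):
        # L2 decay can be fused into the momentum op, so pass it through.
        return "l2_decay", param_regularizer._regularization_coeff
    elif param_regularizer is not None:
        # Another regularizer was already applied to the gradient, so
        # in-op decay is disabled. The coefficient must be the float 0.0
        # (not the int 0), matching the op's float attribute.
        return "", 0.0
    # No per-parameter regularizer: fall back to the optimizer default.
    return default_method, default_coeff

# A parameter with its own L2 decay overrides the optimizer default:
print(select_regularization(L2DecayRegularizer(1e-4), "l2_decay", 1e-2))
# ('l2_decay', 0.0001)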