Unverified commit c006a609, authored by duanboqiang, committed by GitHub

fix lars optimizer bug (#40892)

* fix lars optimizer bug

* Update optimizer.py
Parent 9ab3c76b
@@ -204,7 +204,7 @@ __forceinline__ __device__ void MomentumUpdate(
     const bool is_amp) {
   const MT lr = learning_rate[0];
   MT local_lr = lr;
-  if (lars_weight_decay > static_cast<MT>(0)) {
+  if (param_norm > static_cast<MT>(0) && grad_norm > static_cast<MT>(0)) {
     local_lr = lr * lars_coeff * param_norm /
                (fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
   }
...
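The hunk above changes the guard on the LARS local learning rate: the trust ratio is now applied whenever both the parameter norm and the gradient norm are positive, rather than only when `lars_weight_decay > 0`, so a zero weight decay no longer disables LARS scaling and zero norms no longer feed the division. For reference, a minimal NumPy sketch of the same computation (the standalone function and its name are illustrative, not part of the patch):

```python
import numpy as np

def lars_local_lr(param, grad, lr, lars_coeff, lars_weight_decay, epsilon=0.0):
    """Reference sketch of the local learning rate computed in MomentumUpdate.

    The LARS scaling is applied only when both the parameter norm and the
    gradient norm are positive; otherwise the global learning rate is used.
    """
    param_norm = np.linalg.norm(param)
    grad_norm = np.linalg.norm(grad)
    if param_norm > 0 and grad_norm > 0:
        return lr * lars_coeff * param_norm / (
            lars_weight_decay * param_norm + grad_norm + epsilon)
    return lr
```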
@@ -2156,27 +2156,16 @@ class LarsMomentumOptimizer(Optimizer):
             outputs["MasterParamOut"] = master_weight
 
         if framework._non_static_mode():
-            if _lars_weight_decay != 0.0:
-                tmp, tmp2 = _C_ops.lars_momentum(
-                    [param_and_grad[0]], [param_and_grad[1]], [velocity_acc],
-                    [lr], [param_and_grad[0]], [velocity_acc], "mu",
-                    self._momentum, "lars_coeff", self._lars_coeff,
-                    "lars_weight_decay", [_lars_weight_decay],
-                    "multi_precision", find_master, "epsilon", self._epsilon,
-                    "rescale_grad", self._rescale_grad)
-            else:
-                _C_ops.momentum(param_and_grad[0], param_and_grad[1],
-                                velocity_acc, lr, master_weight,
-                                param_and_grad[0], velocity_acc, master_weight,
-                                "mu", self._momentum, "lars_coeff",
-                                self._lars_coeff, "lars_weight_decay",
-                                [_lars_weight_decay], "multi_precision",
-                                find_master, "epsilon", self._epsilon,
-                                "rescale_grad", self._rescale_grad)
+            tmp, tmp2 = _C_ops.lars_momentum(
+                [param_and_grad[0]], [param_and_grad[1]], [velocity_acc], [lr],
+                [param_and_grad[0]], [velocity_acc], "mu", self._momentum,
+                "lars_coeff", self._lars_coeff, "lars_weight_decay",
+                [_lars_weight_decay], "multi_precision", find_master, "epsilon",
+                self._epsilon, "rescale_grad", self._rescale_grad)
         else:
             # create the momentum optimize op
             momentum_op = block.append_op(
-                type=self.type if _lars_weight_decay != 0.0 else 'momentum',
+                type=self.type,
                 inputs=inputs,
                 outputs=outputs,
                 attrs=attrs,
...
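Since the kernel now handles the `lars_weight_decay == 0` case itself, the Python layer no longer falls back to a plain `momentum` op and always appends a `lars_momentum` op. A hypothetical usage sketch under the Paddle 2.x static-graph API (the toy network and names are illustrative, not taken from this patch):

```python
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[None, 8], dtype='float32')
    loss = paddle.mean(paddle.static.nn.fc(x, size=1))
    # After this patch, lars_weight_decay=0.0 still appends a lars_momentum op
    # (previously the optimizer silently fell back to a plain momentum op).
    opt = paddle.fluid.optimizer.LarsMomentumOptimizer(
        learning_rate=0.01, momentum=0.9, lars_weight_decay=0.0)
    opt.minimize(loss)
```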