Unverified commit c006a609, authored by duanboqiang, committed by GitHub

fix lars optimizer bug (#40892)

* fix lars optimizer bug

* Update optimizer.py
Parent 9ab3c76b
@@ -204,7 +204,7 @@ __forceinline__ __device__ void MomentumUpdate(
     const bool is_amp) {
   const MT lr = learning_rate[0];
   MT local_lr = lr;
-  if (lars_weight_decay > static_cast<MT>(0)) {
+  if (param_norm > static_cast<MT>(0) && grad_norm > static_cast<MT>(0)) {
     local_lr = lr * lars_coeff * param_norm /
                (fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
   }
...
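For context, the guard above decides when the kernel rescales the global learning rate by the LARS trust ratio. Below is a minimal NumPy sketch of that formula (the helper name lars_local_lr and its arguments are illustrative, not Paddle's API). It shows why the condition belongs on the parameter and gradient norms rather than on lars_weight_decay: the trust ratio is still well defined when weight decay is zero, but it degenerates when either norm is zero.

    import numpy as np

    def lars_local_lr(lr, param, grad, lars_coeff=0.001,
                      lars_weight_decay=0.0, epsilon=0.0):
        # Illustrative sketch of the trust ratio computed by MomentumUpdate,
        # mirroring the fixed condition: rescale only when both norms are > 0.
        param_norm = np.linalg.norm(param)
        grad_norm = np.linalg.norm(grad)
        local_lr = lr
        if param_norm > 0 and grad_norm > 0:
            # local_lr = lr * lars_coeff * ||w|| / (wd * ||w|| + ||g|| + eps)
            local_lr = (lr * lars_coeff * param_norm /
                        (lars_weight_decay * param_norm + grad_norm + epsilon))
        return local_lr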
@@ -2156,27 +2156,16 @@ class LarsMomentumOptimizer(Optimizer):
             outputs["MasterParamOut"] = master_weight
         if framework._non_static_mode():
-            if _lars_weight_decay != 0.0:
-                tmp, tmp2 = _C_ops.lars_momentum(
-                    [param_and_grad[0]], [param_and_grad[1]], [velocity_acc],
-                    [lr], [param_and_grad[0]], [velocity_acc], "mu",
-                    self._momentum, "lars_coeff", self._lars_coeff,
-                    "lars_weight_decay", [_lars_weight_decay],
-                    "multi_precision", find_master, "epsilon", self._epsilon,
-                    "rescale_grad", self._rescale_grad)
-            else:
-                _C_ops.momentum(param_and_grad[0], param_and_grad[1],
-                                velocity_acc, lr, master_weight,
-                                param_and_grad[0], velocity_acc, master_weight,
-                                "mu", self._momentum, "lars_coeff",
-                                self._lars_coeff, "lars_weight_decay",
-                                [_lars_weight_decay], "multi_precision",
-                                find_master, "epsilon", self._epsilon,
-                                "rescale_grad", self._rescale_grad)
+            tmp, tmp2 = _C_ops.lars_momentum(
+                [param_and_grad[0]], [param_and_grad[1]], [velocity_acc], [lr],
+                [param_and_grad[0]], [velocity_acc], "mu", self._momentum,
+                "lars_coeff", self._lars_coeff, "lars_weight_decay",
+                [_lars_weight_decay], "multi_precision", find_master, "epsilon",
+                self._epsilon, "rescale_grad", self._rescale_grad)
         else:
             # create the momentum optimize op
             momentum_op = block.append_op(
-                type=self.type if _lars_weight_decay != 0.0 else 'momentum',
+                type=self.type,
                 inputs=inputs,
                 outputs=outputs,
                 attrs=attrs,
...
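With this hunk, dynamic-graph mode always dispatches to the lars_momentum op; previously lars_weight_decay=0.0 fell back to plain momentum (and to the 'momentum' op type in static graph). A rough usage sketch of that case follows; the fluid dygraph workflow and argument names (parameter_list, minimize) are assumptions and may differ across Paddle versions.

    import paddle
    from paddle.fluid.optimizer import LarsMomentumOptimizer

    # Dygraph (non-static) mode: a tiny layer and one training step.
    linear = paddle.nn.Linear(10, 10)
    out = linear(paddle.rand([4, 10], dtype="float32"))
    loss = paddle.mean(out)
    loss.backward()

    # With lars_weight_decay=0.0 this now still runs the lars_momentum op
    # instead of silently degrading to plain momentum.
    opt = LarsMomentumOptimizer(
        learning_rate=0.001,
        momentum=0.9,
        lars_weight_decay=0.0,
        parameter_list=linear.parameters())
    opt.minimize(loss)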