diff --git a/paddle/fluid/operators/optimizers/lars_momentum_op.cu b/paddle/fluid/operators/optimizers/lars_momentum_op.cu
index fe5cd066864b82c734614e33869dff1734bee6d0..5b883a11e57335114cb90b34b9e77ca0e07e209d 100644
--- a/paddle/fluid/operators/optimizers/lars_momentum_op.cu
+++ b/paddle/fluid/operators/optimizers/lars_momentum_op.cu
@@ -204,7 +204,7 @@ __forceinline__ __device__ void MomentumUpdate(
     const bool is_amp) {
   const MT lr = learning_rate[0];
   MT local_lr = lr;
-  if (lars_weight_decay > static_cast<MT>(0)) {
+  if (param_norm > static_cast<MT>(0) && grad_norm > static_cast<MT>(0)) {
     local_lr = lr * lars_coeff * param_norm /
                (fma(lars_weight_decay, param_norm, grad_norm) + epsilon);
   }
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 24b41a91341ee85b058e4f6226c7c41f06c90bc6..7bf4608de89c9cb57e7f2309e6165dbb636ce41a 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -2156,27 +2156,16 @@ class LarsMomentumOptimizer(Optimizer):
             outputs["MasterParamOut"] = master_weight
 
         if framework._non_static_mode():
-            if _lars_weight_decay != 0.0:
-                tmp, tmp2 = _C_ops.lars_momentum(
-                    [param_and_grad[0]], [param_and_grad[1]], [velocity_acc],
-                    [lr], [param_and_grad[0]], [velocity_acc], "mu",
-                    self._momentum, "lars_coeff", self._lars_coeff,
-                    "lars_weight_decay", [_lars_weight_decay],
-                    "multi_precision", find_master, "epsilon", self._epsilon,
-                    "rescale_grad", self._rescale_grad)
-            else:
-                _C_ops.momentum(param_and_grad[0], param_and_grad[1],
-                                velocity_acc, lr, master_weight,
-                                param_and_grad[0], velocity_acc, master_weight,
-                                "mu", self._momentum, "lars_coeff",
-                                self._lars_coeff, "lars_weight_decay",
-                                [_lars_weight_decay], "multi_precision",
-                                find_master, "epsilon", self._epsilon,
-                                "rescale_grad", self._rescale_grad)
+            tmp, tmp2 = _C_ops.lars_momentum(
+                [param_and_grad[0]], [param_and_grad[1]], [velocity_acc], [lr],
+                [param_and_grad[0]], [velocity_acc], "mu", self._momentum,
+                "lars_coeff", self._lars_coeff, "lars_weight_decay",
+                [_lars_weight_decay], "multi_precision", find_master, "epsilon",
+                self._epsilon, "rescale_grad", self._rescale_grad)
         else:
             # create the momentum optimize op
             momentum_op = block.append_op(
-                type=self.type if _lars_weight_decay != 0.0 else 'momentum',
+                type=self.type,
                 inputs=inputs,
                 outputs=outputs,
                 attrs=attrs,
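
For reviewers, a minimal Python sketch of the local learning-rate rule the kernel change implements (the `lars_local_lr` function name and plain-float arguments are illustrative only, not part of the patch). The point of the change: the LARS scaling is now gated on both norms being positive rather than on `lars_weight_decay`, so the zero-weight-decay case no longer has to be dispatched to the plain `momentum` op.

```python
def lars_local_lr(lr, lars_coeff, lars_weight_decay,
                  param_norm, grad_norm, epsilon):
    # After this change, LARS scaling applies whenever both norms are
    # positive, regardless of lars_weight_decay; otherwise the global lr
    # is used unchanged (covers zero-initialized params and zero grads).
    if param_norm > 0.0 and grad_norm > 0.0:
        # Denominator mirrors the kernel's
        # fma(lars_weight_decay, param_norm, grad_norm) + epsilon.
        return (lr * lars_coeff * param_norm /
                (lars_weight_decay * param_norm + grad_norm + epsilon))
    return lr
```

For example, `lars_local_lr(0.1, 0.001, 0.0, 2.5, 0.3, 1e-8)` still applies the LARS scaling even though weight decay is zero, whereas the old `lars_weight_decay > 0` guard would have fallen back to the raw global learning rate.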