diff --git a/paddle/fluid/operators/optimizers/lamb_op.h b/paddle/fluid/operators/optimizers/lamb_op.h index 4fbe486cb1d68f9e569db5774e917e6b5ec9f8bb..72a9093859cdf5d678681e0f24648142c04d4be0 100644 --- a/paddle/fluid/operators/optimizers/lamb_op.h +++ b/paddle/fluid/operators/optimizers/lamb_op.h @@ -174,10 +174,11 @@ struct LambParamUpateFunctor { inline HOSTDEVICE void operator()(size_t i) const { T lr = *lr_; - T p_norm = *param_norm_; - T tr_div_norm = *trust_ratio_div_norm_; + T p = *param_norm_; + T t = *trust_ratio_div_norm_; - lr *= p_norm / tr_div_norm; + T r = (p > 0 && t > 0) ? p / t : 1.0; + lr *= r; param_out_[i] = param_[i] - lr * trust_ratio_div_[i]; } };