未验证 提交 f6a8125f 编写于 作者: A akshatvishu 提交者: GitHub

make loss cond. on label while using Stirling approximation (#56992)

上级 c34853af
...@@ -1594,9 +1594,9 @@ def poisson_nll_loss( ...@@ -1594,9 +1594,9 @@ def poisson_nll_loss(
+ 0.5 * paddle.log(2 * math.pi * label) + 0.5 * paddle.log(2 * math.pi * label)
) )
loss_out += paddle.where( loss_out += paddle.where(
stirling_approx <= 1, label > 1,
paddle.zeros_like(stirling_approx),
stirling_approx, stirling_approx,
paddle.zeros_like(stirling_approx),
) )
if reduction == 'mean': if reduction == 'mean':
loss_out = paddle.mean(loss_out) loss_out = paddle.mean(loss_out)
......
...@@ -51,7 +51,9 @@ def ref_poisson_nll_loss( ...@@ -51,7 +51,9 @@ def ref_poisson_nll_loss(
stirling_approx = ( stirling_approx = (
label * np.log(label) - label + 0.5 * np.log(2 * np.pi * label) label * np.log(label) - label + 0.5 * np.log(2 * np.pi * label)
) )
loss_out += np.where(stirling_approx <= 1, 0, stirling_approx) loss_out += np.where(
label > 1, stirling_approx, np.zeros_like(stirling_approx)
)
if reduction == 'none': if reduction == 'none':
return loss_out return loss_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册