未验证 提交 f6a8125f 编写于 作者: A akshatvishu 提交者: GitHub

make loss cond. on label while using Stirling approximation (#56992)

上级 c34853af
......@@ -1594,9 +1594,9 @@ def poisson_nll_loss(
+ 0.5 * paddle.log(2 * math.pi * label)
)
loss_out += paddle.where(
stirling_approx <= 1,
paddle.zeros_like(stirling_approx),
label > 1,
stirling_approx,
paddle.zeros_like(stirling_approx),
)
if reduction == 'mean':
loss_out = paddle.mean(loss_out)
......
......@@ -51,7 +51,9 @@ def ref_poisson_nll_loss(
stirling_approx = (
label * np.log(label) - label + 0.5 * np.log(2 * np.pi * label)
)
loss_out += np.where(stirling_approx <= 1, 0, stirling_approx)
loss_out += np.where(
label > 1, stirling_approx, np.zeros_like(stirling_approx)
)
if reduction == 'none':
return loss_out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册