From f6a8125fd4f766c0be562b5f1ef5088d41ed5ce6 Mon Sep 17 00:00:00 2001 From: akshatvishu <33392262+akshatvishu@users.noreply.github.com> Date: Wed, 6 Sep 2023 09:10:09 +0530 Subject: [PATCH] make loss cond. on label while using Stirling approximation (#56992) --- python/paddle/nn/functional/loss.py | 4 ++-- test/legacy_test/test_poisson_nll_loss.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index dbcc4f0c05f..f0411d096de 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1594,9 +1594,9 @@ def poisson_nll_loss( + 0.5 * paddle.log(2 * math.pi * label) ) loss_out += paddle.where( - stirling_approx <= 1, - paddle.zeros_like(stirling_approx), + label > 1, stirling_approx, + paddle.zeros_like(stirling_approx), ) if reduction == 'mean': loss_out = paddle.mean(loss_out) diff --git a/test/legacy_test/test_poisson_nll_loss.py b/test/legacy_test/test_poisson_nll_loss.py index 096018a6e2b..14ad3755199 100644 --- a/test/legacy_test/test_poisson_nll_loss.py +++ b/test/legacy_test/test_poisson_nll_loss.py @@ -51,7 +51,9 @@ def ref_poisson_nll_loss( stirling_approx = ( label * np.log(label) - label + 0.5 * np.log(2 * np.pi * label) ) - loss_out += np.where(stirling_approx <= 1, 0, stirling_approx) + loss_out += np.where( + label > 1, stirling_approx, np.zeros_like(stirling_approx) + ) if reduction == 'none': return loss_out -- GitLab