diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index dbcc4f0c05fdade4eb1fcbb26fd6c7752a4474f0..f0411d096dee47c5ee977b36d7e7e51f9ada6d04 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1594,9 +1594,9 @@ def poisson_nll_loss( + 0.5 * paddle.log(2 * math.pi * label) ) loss_out += paddle.where( - stirling_approx <= 1, - paddle.zeros_like(stirling_approx), + label > 1, stirling_approx, + paddle.zeros_like(stirling_approx), ) if reduction == 'mean': loss_out = paddle.mean(loss_out) diff --git a/test/legacy_test/test_poisson_nll_loss.py b/test/legacy_test/test_poisson_nll_loss.py index 096018a6e2bf0b5e395a0ec37ea7b035ca57f846..14ad3755199140d4bc2b092b32eb82b7faf6c055 100644 --- a/test/legacy_test/test_poisson_nll_loss.py +++ b/test/legacy_test/test_poisson_nll_loss.py @@ -51,7 +51,9 @@ def ref_poisson_nll_loss( stirling_approx = ( label * np.log(label) - label + 0.5 * np.log(2 * np.pi * label) ) - loss_out += np.where(stirling_approx <= 1, 0, stirling_approx) + loss_out += np.where( + label > 1, stirling_approx, np.zeros_like(stirling_approx) + ) if reduction == 'none': return loss_out