diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
index aaba571e1a6b9e1abc1974b5f4fef81279843820..a301748ed7bbb8b814e5ddabe195a5a86ddaaff0 100644
--- a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
+++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py
@@ -112,7 +112,9 @@ class TestKLDivLossDygraph(unittest.TestCase):
         input = paddle.fluid.data(name='input', shape=[5, 20])
         label = paddle.fluid.data(name='label', shape=[5, 20])
 
-        pred_loss = paddle.nn.functional.kl_div(input, label)
+        paddle.nn.functional.kl_div(input, label)
+        paddle.nn.functional.kl_div(input, label, 'sum')
+        paddle.nn.functional.kl_div(input, label, 'batchmean')
 
 
 class TestKLDivLossTypePromotion(unittest.TestCase):
diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py
index 87eb564f60ea6cde466ba5acc21233299bc6173f..2332c14b2d97a3dd44e3541ac1e6b416adb017d3 100755
--- a/python/paddle/nn/functional/loss.py
+++ b/python/paddle/nn/functional/loss.py
@@ -903,7 +903,15 @@ def kl_div(input, label, reduction='mean', name=None):
         label = paddle.cast(label, 'float64')
 
     if paddle.in_dynamic_mode():
-        out = _C_ops.kldiv_loss(input, label, 'reduction', reduction)
+        out = _C_ops.kldiv_loss(input, label, 'reduction', 'none')
+        if reduction == 'mean':
+            out = paddle.mean(out)
+        elif reduction == 'sum':
+            out = paddle.sum(out)
+        elif reduction == 'batchmean':
+            if len(input.shape) > 0:
+                batch_size = input.shape[0]
+                out = paddle.sum(out) / batch_size
         return out
 
     helper = LayerHelper('kl_div', **locals())
@@ -920,7 +928,15 @@ def kl_div(input, label, reduction='mean', name=None):
         inputs={'X': input,
                 'Target': label},
         outputs={'Loss': loss},
-        attrs={'reduction': reduction})
+        attrs={'reduction': 'none'})
+
+    if reduction == 'mean':
+        loss = paddle.mean(loss)
+    elif reduction == 'sum':
+        loss = paddle.sum(loss)
+    elif reduction == 'batchmean':
+        batch_size = paddle.shape(input)[0]
+        loss = paddle.sum(loss) / batch_size
     return loss
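
For reviewers, here is a minimal NumPy sketch of the reduction semantics this patch moves out of the `kldiv_loss` op and into Python. The elementwise term follows the documented formula `loss = label * (log(label) - input)` (with `input` in log-space); the `kl_div_reference` helper and the sample data are hypothetical, written only to illustrate that `batchmean` is the summed loss divided by the dim-0 size, exactly as the patch computes it.

```python
import numpy as np

def kl_div_reference(input, label, reduction='mean'):
    # Elementwise KL term; entries where label == 0 contribute 0.
    safe_label = np.where(label > 0, label, 1.0)
    loss = np.where(label > 0, label * (np.log(safe_label) - input), 0.0)
    if reduction == 'none':
        return loss
    if reduction == 'mean':
        return loss.mean()          # mirrors paddle.mean(out)
    if reduction == 'sum':
        return loss.sum()           # mirrors paddle.sum(out)
    if reduction == 'batchmean':
        # Mirrors the patch: total loss divided by the batch (dim-0) size,
        # falling back to a plain sum for 0-d inputs.
        return loss.sum() / input.shape[0] if input.ndim > 0 else loss.sum()
    raise ValueError(f"unsupported reduction: {reduction}")

rng = np.random.default_rng(0)
x = np.log(np.full((5, 20), 1 / 20))      # uniform log-probabilities
y = rng.dirichlet(np.ones(20), size=5)    # rows sum to 1
assert np.isclose(kl_div_reference(x, y, 'batchmean'),
                  kl_div_reference(x, y, 'sum') / 5)
```

With the patch applied, `paddle.nn.functional.kl_div(input, label, 'batchmean')` should return the same value in dygraph and static graph mode: the op itself now always runs with `reduction='none'`, and the division by `input.shape[0]` (or `paddle.shape(input)[0]` in static mode) happens in Python, matching PyTorch's `batchmean` convention.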