From a6e9ff851c9b780f63f17ab4cd9c9863db3cd939 Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Thu, 18 Nov 2021 10:56:39 +0800 Subject: [PATCH] Fix the slow running speed of kl_div when option 'reduction' is set (#37283) * Fix the slow running speed of kl_div when option reduction is set * fix unittest coverage --- .../tests/unittests/test_kldiv_loss_op.py | 4 +++- python/paddle/nn/functional/loss.py | 20 +++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py index aaba571e1a6..a301748ed7b 100644 --- a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py @@ -112,7 +112,9 @@ class TestKLDivLossDygraph(unittest.TestCase): input = paddle.fluid.data(name='input', shape=[5, 20]) label = paddle.fluid.data(name='label', shape=[5, 20]) - pred_loss = paddle.nn.functional.kl_div(input, label) + paddle.nn.functional.kl_div(input, label) + paddle.nn.functional.kl_div(input, label, 'sum') + paddle.nn.functional.kl_div(input, label, 'batchmean') class TestKLDivLossTypePromotion(unittest.TestCase): diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 87eb564f60e..2332c14b2d9 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -903,7 +903,15 @@ def kl_div(input, label, reduction='mean', name=None): label = paddle.cast(label, 'float64') if paddle.in_dynamic_mode(): - out = _C_ops.kldiv_loss(input, label, 'reduction', reduction) + out = _C_ops.kldiv_loss(input, label, 'reduction', 'none') + if reduction == 'mean': + out = paddle.mean(out) + elif reduction == 'sum': + out = paddle.sum(out) + elif reduction == 'batchmean': + if len(input.shape) > 0: + batch_size = input.shape[0] + out = paddle.sum(out) / batch_size return out helper = LayerHelper('kl_div', **locals()) @@ -920,7 +928,15 @@ def kl_div(input, label, reduction='mean', name=None): inputs={'X': input, 'Target': label}, outputs={'Loss': loss}, - attrs={'reduction': reduction}) + attrs={'reduction': 'none'}) + + if reduction == 'mean': + loss = paddle.mean(loss) + elif reduction == 'sum': + loss = paddle.sum(loss) + elif reduction == 'batchmean': + batch_size = paddle.shape(input)[0] + loss = paddle.sum(loss) / batch_size return loss -- GitLab