未验证 提交 a6e9ff85 编写于 作者: L LielinJiang 提交者: GitHub

Fix the slow running speed of kl_div when option 'reduction' is set (#37283)

* Fix the slow running speed of kl_div when option reduction is set

* fix unittest coverage
上级 99909520
...@@ -112,7 +112,9 @@ class TestKLDivLossDygraph(unittest.TestCase): ...@@ -112,7 +112,9 @@ class TestKLDivLossDygraph(unittest.TestCase):
input = paddle.fluid.data(name='input', shape=[5, 20]) input = paddle.fluid.data(name='input', shape=[5, 20])
label = paddle.fluid.data(name='label', shape=[5, 20]) label = paddle.fluid.data(name='label', shape=[5, 20])
pred_loss = paddle.nn.functional.kl_div(input, label) paddle.nn.functional.kl_div(input, label)
paddle.nn.functional.kl_div(input, label, 'sum')
paddle.nn.functional.kl_div(input, label, 'batchmean')
class TestKLDivLossTypePromotion(unittest.TestCase): class TestKLDivLossTypePromotion(unittest.TestCase):
......
...@@ -903,7 +903,15 @@ def kl_div(input, label, reduction='mean', name=None): ...@@ -903,7 +903,15 @@ def kl_div(input, label, reduction='mean', name=None):
label = paddle.cast(label, 'float64') label = paddle.cast(label, 'float64')
if paddle.in_dynamic_mode(): if paddle.in_dynamic_mode():
out = _C_ops.kldiv_loss(input, label, 'reduction', reduction) out = _C_ops.kldiv_loss(input, label, 'reduction', 'none')
if reduction == 'mean':
out = paddle.mean(out)
elif reduction == 'sum':
out = paddle.sum(out)
elif reduction == 'batchmean':
if len(input.shape) > 0:
batch_size = input.shape[0]
out = paddle.sum(out) / batch_size
return out return out
helper = LayerHelper('kl_div', **locals()) helper = LayerHelper('kl_div', **locals())
...@@ -920,7 +928,15 @@ def kl_div(input, label, reduction='mean', name=None): ...@@ -920,7 +928,15 @@ def kl_div(input, label, reduction='mean', name=None):
inputs={'X': input, inputs={'X': input,
'Target': label}, 'Target': label},
outputs={'Loss': loss}, outputs={'Loss': loss},
attrs={'reduction': reduction}) attrs={'reduction': 'none'})
if reduction == 'mean':
loss = paddle.mean(loss)
elif reduction == 'sum':
loss = paddle.sum(loss)
elif reduction == 'batchmean':
batch_size = paddle.shape(input)[0]
loss = paddle.sum(loss) / batch_size
return loss return loss
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册