From b9c7c66ea52952553a797ac86ff9045f0db9f3fd Mon Sep 17 00:00:00 2001 From: LielinJiang <50691816+LielinJiang@users.noreply.github.com> Date: Fri, 9 Oct 2020 11:39:16 +0800 Subject: [PATCH] add type promotion (#27756) --- .../fluid/tests/unittests/test_kldiv_loss_op.py | 15 +++++++++++++++ python/paddle/nn/functional/loss.py | 10 ++++++++++ 2 files changed, 25 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py index 3a3b7071e04..aaba571e1a6 100644 --- a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py @@ -115,5 +115,20 @@ class TestKLDivLossDygraph(unittest.TestCase): pred_loss = paddle.nn.functional.kl_div(input, label) +class TestKLDivLossTypePromotion(unittest.TestCase): + def test_kl_div_promotion(self): + + with paddle.fluid.dygraph.guard(): + x1 = paddle.rand([5, 20], dtype='float32') + target1 = paddle.rand([5, 20], dtype='float64') + + kldiv_criterion = paddle.nn.KLDivLoss() + pred_loss1 = kldiv_criterion(x1, target1) + + x2 = paddle.rand([5, 20], dtype='float64') + target2 = paddle.rand([5, 20], dtype='float32') + pred_loss2 = paddle.nn.functional.kl_div(x2, target2) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 76722f26007..05daf24ca24 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -800,6 +800,16 @@ def kl_div(input, label, reduction='mean', name=None): # shape=[5, 20] """ + # ugly type promotion + if fluid.data_feeder.convert_dtype( + input.dtype) == 'float32' and fluid.data_feeder.convert_dtype( + label.dtype) == 'float64': + input = fluid.layers.cast(input, 'float64') + elif fluid.data_feeder.convert_dtype( + input.dtype) == 'float64' and fluid.data_feeder.convert_dtype( + label.dtype) == 'float32': + label = fluid.layers.cast(label, 'float64') + if paddle.in_dynamic_mode(): out = core.ops.kldiv_loss(input, label, 'reduction', reduction) return out -- GitLab