diff --git a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py index 3a3b7071e04dced7fc33c4ac56a9018bb82afbf4..aaba571e1a6b9e1abc1974b5f4fef81279843820 100644 --- a/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py +++ b/python/paddle/fluid/tests/unittests/test_kldiv_loss_op.py @@ -115,5 +115,20 @@ class TestKLDivLossDygraph(unittest.TestCase): pred_loss = paddle.nn.functional.kl_div(input, label) +class TestKLDivLossTypePromotion(unittest.TestCase): + def test_kl_div_promotion(self): + + with paddle.fluid.dygraph.guard(): + x1 = paddle.rand([5, 20], dtype='float32') + target1 = paddle.rand([5, 20], dtype='float64') + + kldiv_criterion = paddle.nn.KLDivLoss() + pred_loss1 = kldiv_criterion(x1, target1) + + x2 = paddle.rand([5, 20], dtype='float64') + target2 = paddle.rand([5, 20], dtype='float32') + pred_loss2 = paddle.nn.functional.kl_div(x2, target2) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 76722f26007c4aa9d673929e2065ddbfed644256..05daf24ca24ab886da834296f4e355715313ea5c 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -800,6 +800,16 @@ def kl_div(input, label, reduction='mean', name=None): # shape=[5, 20] """ + # ugly type promotion + if fluid.data_feeder.convert_dtype( + input.dtype) == 'float32' and fluid.data_feeder.convert_dtype( + label.dtype) == 'float64': + input = fluid.layers.cast(input, 'float64') + elif fluid.data_feeder.convert_dtype( + input.dtype) == 'float64' and fluid.data_feeder.convert_dtype( + label.dtype) == 'float32': + label = fluid.layers.cast(label, 'float64') + if paddle.in_dynamic_mode(): out = core.ops.kldiv_loss(input, label, 'reduction', reduction) return out