未验证 提交 b9c7c66e 编写于 作者: L LielinJiang 提交者: GitHub

add type promotion (#27756)

上级 9089841b
...@@ -115,5 +115,20 @@ class TestKLDivLossDygraph(unittest.TestCase): ...@@ -115,5 +115,20 @@ class TestKLDivLossDygraph(unittest.TestCase):
pred_loss = paddle.nn.functional.kl_div(input, label) pred_loss = paddle.nn.functional.kl_div(input, label)
class TestKLDivLossTypePromotion(unittest.TestCase):
def test_kl_div_promotion(self):
with paddle.fluid.dygraph.guard():
x1 = paddle.rand([5, 20], dtype='float32')
target1 = paddle.rand([5, 20], dtype='float64')
kldiv_criterion = paddle.nn.KLDivLoss()
pred_loss1 = kldiv_criterion(x1, target1)
x2 = paddle.rand([5, 20], dtype='float64')
target2 = paddle.rand([5, 20], dtype='float32')
pred_loss2 = paddle.nn.functional.kl_div(x2, target2)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -800,6 +800,16 @@ def kl_div(input, label, reduction='mean', name=None): ...@@ -800,6 +800,16 @@ def kl_div(input, label, reduction='mean', name=None):
# shape=[5, 20] # shape=[5, 20]
""" """
# ugly type promotion
if fluid.data_feeder.convert_dtype(
input.dtype) == 'float32' and fluid.data_feeder.convert_dtype(
label.dtype) == 'float64':
input = fluid.layers.cast(input, 'float64')
elif fluid.data_feeder.convert_dtype(
input.dtype) == 'float64' and fluid.data_feeder.convert_dtype(
label.dtype) == 'float32':
label = fluid.layers.cast(label, 'float64')
if paddle.in_dynamic_mode(): if paddle.in_dynamic_mode():
out = core.ops.kldiv_loss(input, label, 'reduction', reduction) out = core.ops.kldiv_loss(input, label, 'reduction', reduction)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册