From 134d809e23d1cf08b2d57f1cf3899d7db06a0ed1 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Wed, 9 Oct 2019 23:15:28 +0800 Subject: [PATCH] fix softmax input error check on float16 (#20273) test=develop --- python/paddle/fluid/layers/nn.py | 8 ++++++-- python/paddle/fluid/tests/unittests/test_softmax_op.py | 4 +++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 07d4cd123d..7271bc6f90 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -2380,9 +2380,13 @@ def softmax(input, use_cudnn=False, name=None, axis=-1): raise TypeError( "The type of 'input' in softmax must be Variable, but received %s" % (type(input))) - if convert_dtype(input.dtype) not in ['float32', 'float64']: + if convert_dtype(input.dtype) in ['float16']: + warnings.warn( + "The data type of 'input' in softmax only support float16 in GPU now." + ) + if convert_dtype(input.dtype) not in ['float16', 'float32', 'float64']: raise TypeError( - "The data type of 'input' in softmax must be float32 or float64, but received %s." + "The data type of 'input' in softmax must be float16, float32 or float64, but received %s." % (convert_dtype(input.dtype))) dtype = helper.input_dtype() diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index f6770bdd1d..50b29ba1e5 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -83,9 +83,11 @@ class TestSoftmaxOpError(OpTest): x1 = fluid.create_lod_tensor( np.array([[-1]]), [[1]], fluid.CPUPlace()) self.assertRaises(TypeError, fluid.layers.softmax, x1) - # The input dtype of softmax_op must be float32 or float64. + # The input dtype of softmax_op must be float16, float32 or float64. x2 = fluid.layers.data(name='x2', shape=[4], dtype="int32") self.assertRaises(TypeError, fluid.layers.softmax, x2) + x3 = fluid.layers.data(name='x3', shape=[4], dtype="float16") + fluid.layers.softmax(x3) class TestSoftmaxOp2(TestSoftmaxOp): -- GitLab