From 0fad5ef2bfe96f8227ad0135da3f33d763e19499 Mon Sep 17 00:00:00 2001 From: wawltor Date: Sat, 12 Oct 2019 15:03:46 +0800 Subject: [PATCH] fix sign op input error check (#20473) fix sign op input error check test=release/1.6 --- python/paddle/fluid/layers/nn.py | 7 +++++-- python/paddle/fluid/tests/unittests/test_sign_op.py | 5 ++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 1dc3afe1f6d..1fdccdb285b 100755 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -16335,9 +16335,12 @@ def sign(x): "The type of 'x' in sign_op must be Variable or numpy.ndarray, but received %s." % (type(x))) - if convert_dtype(x.dtype) not in ['float32', 'float64']: + if convert_dtype(x.dtype) in ['float16']: + warnings.warn( + "The data type of 'x' in sign_op only support float16 in GPU now.") + if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']: raise TypeError( - "The data type of 'x' in sign_op must be float32 or float64, but received %s." + "The data type of 'x' in sign_op must be float16, float32 or float64, but received %s." % (convert_dtype(x.dtype))) out = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/python/paddle/fluid/tests/unittests/test_sign_op.py b/python/paddle/fluid/tests/unittests/test_sign_op.py index 00ac43b9ba5..a5412439fac 100644 --- a/python/paddle/fluid/tests/unittests/test_sign_op.py +++ b/python/paddle/fluid/tests/unittests/test_sign_op.py @@ -42,10 +42,13 @@ class TestSignOpError(OpTest): # The input type of sign_op must be Variable or numpy.ndarray. input1 = 12 self.assertRaises(TypeError, fluid.layers.sign, input1) - # The input dtype of sign_op must be float32, float64. + # The input dtype of sign_op must be float16, float32, float64. input2 = fluid.layers.data( name='input2', shape=[12, 10], dtype="int32") self.assertRaises(TypeError, fluid.layers.sign, input2) + input3 = fluid.layers.data( + name='input3', shape=[4], dtype="float16") + fluid.layers.sign(input3) if __name__ == "__main__": -- GitLab