未验证 提交 0fad5ef2 编写于 作者: W wawltor 提交者: GitHub

fix sign op input error check (#20473)

fix sign op input error check
test=release/1.6
上级 ac56b467
......@@ -16335,9 +16335,12 @@ def sign(x):
"The type of 'x' in sign_op must be Variable or numpy.ndarray, but received %s."
% (type(x)))
if convert_dtype(x.dtype) not in ['float32', 'float64']:
if convert_dtype(x.dtype) in ['float16']:
warnings.warn(
"The data type of 'x' in sign_op only support float16 in GPU now.")
if convert_dtype(x.dtype) not in ['float16', 'float32', 'float64']:
raise TypeError(
"The data type of 'x' in sign_op must be float32 or float64, but received %s."
"The data type of 'x' in sign_op must be float16, float32 or float64, but received %s."
% (convert_dtype(x.dtype)))
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
......@@ -42,10 +42,13 @@ class TestSignOpError(OpTest):
# The input type of sign_op must be Variable or numpy.ndarray.
input1 = 12
self.assertRaises(TypeError, fluid.layers.sign, input1)
# The input dtype of sign_op must be float32, float64.
# The input dtype of sign_op must be float16, float32, float64.
input2 = fluid.layers.data(
name='input2', shape=[12, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.sign, input2)
input3 = fluid.layers.data(
name='input3', shape=[4], dtype="float16")
fluid.layers.sign(input3)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册