未验证 提交 bffb5aaf 编写于 作者: W wawltor 提交者: GitHub

Fix the api input type and data dtype check, cherry-pick from develop(#20138) (#20429)

Fix the api input type and data dtype check in the op of sign
test=release/1.6
上级 f8ef714b
......@@ -15400,7 +15400,17 @@ def sign(x):
helper = LayerHelper("sign", **locals())
if not isinstance(x, Variable):
x = assign(x)
if isinstance(x, np.ndarray):
x = assign(x)
else:
raise TypeError(
"The type of 'x' in sign_op must be Variable or numpy.ndarray, but received %s."
% (type(x)))
if convert_dtype(x.dtype) not in ['float32', 'float64']:
raise TypeError(
"The data type of 'x' in sign_op must be float32 or float64, but received %s."
% (convert_dtype(x.dtype)))
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
......@@ -17,6 +17,8 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestSignOp(OpTest):
......@@ -34,5 +36,17 @@ class TestSignOp(OpTest):
self.check_grad(['X'], 'Out')
class TestSignOpError(OpTest):
def test_errors(self):
with program_guard(Program(), Program()):
# The input type of sign_op must be Variable or numpy.ndarray.
input1 = 12
self.assertRaises(TypeError, fluid.layers.sign, input1)
# The input dtype of sign_op must be float32, float64.
input2 = fluid.layers.data(
name='input2', shape=[12, 10], dtype="int32")
self.assertRaises(TypeError, fluid.layers.sign, input2)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册