未验证 提交 7b648ad1 编写于 作者: Z zhupengyang 提交者: GitHub

Op(relu) error message enhancement (#23510)

上级 5d970b58
......@@ -8201,6 +8201,8 @@ def relu(x, name=None):
if in_dygraph_mode():
return core.ops.relu(x)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'relu')
inputs = {'X': [x]}
helper = LayerHelper('relu', **locals())
dtype = helper.input_dtype(input_param_name='x')
......
......@@ -431,6 +431,20 @@ class TestRelu(TestActivation):
self.check_grad(['X'], 'Out')
class TestReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.sqrt, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.relu, x_int32)
# support the input dtype is float16
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float16')
fluid.layers.relu(x_fp16)
class TestLeakyRelu(TestActivation):
def setUp(self):
self.op_type = "leaky_relu"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册