未验证 提交 5d970b58 编写于 作者: Z zhupengyang 提交者: GitHub

Op(leaky_relu) error message enhancement (#23627)

上级 06d4aa4e
......@@ -9238,6 +9238,9 @@ def leaky_relu(x, alpha=0.02, name=None):
if in_dygraph_mode():
return core.ops.leaky_relu(x, 'alpha', alpha)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'leaky_relu')
inputs = {'X': [x]}
attrs = {'alpha': alpha}
helper = LayerHelper('leaky_relu', **locals())
......
......@@ -450,6 +450,20 @@ class TestLeakyRelu(TestActivation):
self.check_grad(['X'], 'Out')
class TestLeakyReluOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.leaky_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.leaky_relu, x_int32)
# support the input dtype is float32
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float32')
fluid.layers.leaky_relu(x_fp16)
def gelu(x, approximate):
if approximate:
y_ref = 0.5 * x * (1.0 + np.tanh(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册