未验证 提交 17bee1d9 编写于 作者: Z zhupengyang 提交者: GitHub

Op(brelu) error message enhancement (#23606)

上级 df538439
@@ -9179,6 +9179,8 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
        #[[ 1.  6.]
        #[ 1. 10.]]
    """
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'brelu')
    helper = LayerHelper('brelu', **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
......
@@ -520,6 +520,20 @@ class TestBRelu(TestActivation):
        self.check_grad(['X'], 'Out')
class TestBReluOpError(unittest.TestCase):
    """Verify the input-validation (error-message enhancement) path of brelu."""

    def test_errors(self):
        with program_guard(Program()):
            # Passing a plain Python int instead of a Variable must raise.
            self.assertRaises(TypeError, fluid.layers.brelu, 1)

            # An int32 tensor is outside the supported dtype set
            # (float16 / float32 / float64), so it must raise as well.
            x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
            self.assertRaises(TypeError, fluid.layers.brelu, x_int32)

            # float16 is an accepted dtype: this call must not raise.
            x_fp16 = fluid.layers.data(
                name='x_fp16', shape=[12, 10], dtype='float16')
            fluid.layers.brelu(x_fp16)
class TestRelu6(TestActivation):
    def setUp(self):
        self.op_type = "relu6"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册