未验证 提交 a7db9acc 编写于 作者: H hong19860320 提交者: GitHub

Add the constraint for the scale of SELU/selu (#26686)

上级 5f524efe
......@@ -130,6 +130,11 @@ class TestSeluAPI(unittest.TestCase):
# The input dtype must be float16, float32, float64.
x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.selu, x_int32)
# The scale must be greater than 1.0
x_fp32 = paddle.data(name='x_fp32', shape=[12, 10], dtype='float32')
self.assertRaises(ValueError, F.selu, x_fp32, -1.0)
# The alpha must be no less than 0
self.assertRaises(ValueError, F.selu, x_fp32, 1.6, -1.0)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
F.selu(x_fp16)
......
......@@ -652,8 +652,8 @@ def selu(x,
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
scale (float, optional): The value of scale for selu. Default is 1.0507009873554804934193349852946
alpha (float, optional): The value of alpha for selu. Default is 1.6732632423543772848170429916717
scale (float, optional): The value of scale(must be greater than 1.0) for selu. Default is 1.0507009873554804934193349852946
alpha (float, optional): The value of alpha(must be no less than zero) for selu. Default is 1.6732632423543772848170429916717
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
......@@ -672,6 +672,14 @@ def selu(x,
x = paddle.to_tensor(np.array([[0.0, 1.0],[2.0, 3.0]]))
out = F.selu(x) # [[0, 1.050701],[2.101402, 3.152103]]
"""
if scale <= 1.0:
raise ValueError(
"The scale must be greater than 1.0. Received: {}.".format(scale))
if alpha < 0:
raise ValueError(
"The alpha must be no less than zero. Received: {}.".format(alpha))
if in_dygraph_mode():
return core.ops.selu(x, 'scale', scale, 'alpha', alpha)
......
......@@ -559,8 +559,8 @@ class SELU(layers.Layer):
\\end{cases}
Parameters:
scale (float, optional): The value of scale for SELU. Default is 1.0507009873554804934193349852946
alpha (float, optional): The value of alpha for SELU. Default is 1.6732632423543772848170429916717
scale (float, optional): The value of scale(must be greater than 1.0) for SELU. Default is 1.0507009873554804934193349852946
alpha (float, optional): The value of alpha(must be no less than zero) for SELU. Default is 1.6732632423543772848170429916717
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册