未验证 提交 66fbfba8 编写于 作者: warrentdrew's avatar warrentdrew 提交者: GitHub

add leaky relu composite rule (#52909)

* add leaky relu composite rule

* add public python api

* unset default negative slope

* fix unittest case
上级 23e96bde
...@@ -1933,6 +1933,8 @@ class TestLeakyRelu(TestActivation): ...@@ -1933,6 +1933,8 @@ class TestLeakyRelu(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "leaky_relu" self.op_type = "leaky_relu"
self.python_api = paddle.nn.functional.leaky_relu self.python_api = paddle.nn.functional.leaky_relu
self.public_python_api = paddle.nn.functional.leaky_relu
self.prim_op_type = "comp"
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
alpha = self.get_alpha() alpha = self.get_alpha()
...@@ -1948,10 +1950,13 @@ class TestLeakyRelu(TestActivation): ...@@ -1948,10 +1950,13 @@ class TestLeakyRelu(TestActivation):
self.attrs = {'alpha': alpha} self.attrs = {'alpha': alpha}
self.convert_input_output() self.convert_input_output()
def test_check_output(self):
self.check_output(check_prim=True)
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
class TestLeakyReluAlpha1(TestLeakyRelu): class TestLeakyReluAlpha1(TestLeakyRelu):
...@@ -1973,6 +1978,26 @@ class TestLeakyRelu_ZeroDim(TestLeakyRelu): ...@@ -1973,6 +1978,26 @@ class TestLeakyRelu_ZeroDim(TestLeakyRelu):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def setUp(self):
self.op_type = "leaky_relu"
self.prim_op_type = "comp"
self.enable_cinn = False
self.python_api = paddle.nn.functional.leaky_relu
self.public_python_api = paddle.nn.functional.relu
self.init_dtype()
self.init_shape()
alpha = self.get_alpha()
np.random.seed(1024)
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.05
out = ref_leaky_relu(x, alpha)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'alpha': alpha}
class TestLeakyReluAPI(unittest.TestCase): class TestLeakyReluAPI(unittest.TestCase):
# test paddle.nn.LeakyReLU, paddle.nn.functional.leaky_relu, # test paddle.nn.LeakyReLU, paddle.nn.functional.leaky_relu,
...@@ -4031,11 +4056,13 @@ create_test_act_fp16_class(TestHardSigmoid) ...@@ -4031,11 +4056,13 @@ create_test_act_fp16_class(TestHardSigmoid)
create_test_act_fp16_class(TestSwish) create_test_act_fp16_class(TestSwish)
create_test_act_fp16_class(TestHardSwish, check_prim=True) create_test_act_fp16_class(TestHardSwish, check_prim=True)
create_test_act_fp16_class(TestMish) create_test_act_fp16_class(TestMish)
create_test_act_fp16_class(TestLeakyRelu) create_test_act_fp16_class(TestLeakyRelu, check_prim=True)
create_test_act_fp16_class(TestLeakyReluAlpha1) create_test_act_fp16_class(TestLeakyReluAlpha1, check_prim=True)
create_test_act_fp16_class(TestLeakyReluAlpha2) create_test_act_fp16_class(TestLeakyReluAlpha2, check_prim=True)
create_test_act_fp16_class(TestLeakyReluAlpha3) create_test_act_fp16_class(TestLeakyReluAlpha3, check_prim=True)
create_test_act_fp16_class(TestLeakyRelu_ZeroDim) create_test_act_fp16_class(
TestLeakyRelu_ZeroDim, check_prim=True, enable_cinn=False
)
create_test_act_fp16_class(TestRsqrt) create_test_act_fp16_class(TestRsqrt)
...@@ -4142,11 +4169,19 @@ create_test_act_bf16_class(TestHardSigmoid) ...@@ -4142,11 +4169,19 @@ create_test_act_bf16_class(TestHardSigmoid)
create_test_act_bf16_class(TestSwish) create_test_act_bf16_class(TestSwish)
create_test_act_bf16_class(TestHardSwish, check_prim=True) create_test_act_bf16_class(TestHardSwish, check_prim=True)
create_test_act_bf16_class(TestMish) create_test_act_bf16_class(TestMish)
create_test_act_bf16_class(TestLeakyRelu) create_test_act_bf16_class(TestLeakyRelu, check_prim=True, enable_cinn=False)
create_test_act_bf16_class(TestLeakyReluAlpha1) create_test_act_bf16_class(
create_test_act_bf16_class(TestLeakyReluAlpha2) TestLeakyReluAlpha1, check_prim=True, enable_cinn=False
create_test_act_bf16_class(TestLeakyReluAlpha3) )
create_test_act_bf16_class(TestLeakyRelu_ZeroDim) create_test_act_bf16_class(
TestLeakyReluAlpha2, check_prim=True, enable_cinn=False
)
create_test_act_bf16_class(
TestLeakyReluAlpha3, check_prim=True, enable_cinn=False
)
create_test_act_bf16_class(
TestLeakyRelu_ZeroDim, check_prim=True, enable_cinn=False
)
create_test_act_bf16_class(TestRsqrt) create_test_act_bf16_class(TestRsqrt)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -677,3 +677,12 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout): ...@@ -677,3 +677,12 @@ def group_norm_composite(x, scale, bias, epsilon, groups, data_layout):
if is_amp: if is_amp:
out = cast(out, "float16") out = cast(out, "float16")
return out, ret_mean_, ret_var_ return out, ret_mean_, ret_var_
@REGISTER_COMPOSITE('leaky_relu')
def leaky_relu_composite(x, negative_slope):
"""define composite rule of op leaky_relu."""
if negative_slope < 1.0:
return maximum(x, negative_slope * x)
else:
return minimum(x, negative_slope * x)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册