未验证 提交 2d9e103e 编写于 作者: C chenjian 提交者: GitHub

[Prim] add pow composite rule (#51070)

* add pow composite rule

* fix test

* fix unit test

* update test

* fix test

* update
上级 9045b882
...@@ -2914,6 +2914,7 @@ class TestSquareBF16(OpTest): ...@@ -2914,6 +2914,7 @@ class TestSquareBF16(OpTest):
class TestPow(TestActivation): class TestPow(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "pow" self.op_type = "pow"
self.prim_op_type = "comp"
self.python_api = paddle.pow self.python_api = paddle.pow
self.init_dtype() self.init_dtype()
self.init_shape() self.init_shape()
...@@ -2927,23 +2928,28 @@ class TestPow(TestActivation): ...@@ -2927,23 +2928,28 @@ class TestPow(TestActivation):
self.outputs = {'Out': out} self.outputs = {'Out': out}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_prim=True)
def test_check_grad(self): def test_check_grad(self):
if self.dtype == np.float16: if self.dtype == np.float16:
return return
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out', check_prim=True)
class TestPow_ZeroDim(TestPow): class TestPow_ZeroDim(TestPow):
def init_shape(self): def init_shape(self):
self.shape = [] self.shape = []
def setUp(self):
super(TestPow_ZeroDim, self).setUp()
self.enable_cinn = False
class TestPow_factor_tensor(TestActivation): class TestPow_factor_tensor(TestActivation):
def setUp(self): def setUp(self):
self.op_type = "pow" self.op_type = "pow"
self.python_api = paddle.pow self.python_api = paddle.pow
self.enable_cinn = False
self.init_dtype() self.init_dtype()
np.random.seed(1024) np.random.seed(1024)
...@@ -3826,7 +3832,7 @@ else: ...@@ -3826,7 +3832,7 @@ else:
create_test_act_fp16_class(TestLog10, atol=5e-2) create_test_act_fp16_class(TestLog10, atol=5e-2)
create_test_act_fp16_class(TestLog1p, grad_atol=0.9) create_test_act_fp16_class(TestLog1p, grad_atol=0.9)
create_test_act_fp16_class(TestSquare) create_test_act_fp16_class(TestSquare)
create_test_act_fp16_class(TestPow, atol=5e-2) create_test_act_fp16_class(TestPow, check_prim=True, atol=5e-2)
create_test_act_fp16_class(TestPow_factor_tensor, atol=5e-2) create_test_act_fp16_class(TestPow_factor_tensor, atol=5e-2)
create_test_act_fp16_class(TestSTanh, grad_atol=0.9) create_test_act_fp16_class(TestSTanh, grad_atol=0.9)
create_test_act_fp16_class(TestSoftplus) create_test_act_fp16_class(TestSoftplus)
......
...@@ -401,6 +401,18 @@ def fill_any_like(x, fill_value, dtype, place=None): ...@@ -401,6 +401,18 @@ def fill_any_like(x, fill_value, dtype, place=None):
return val return val
@REGISTER_COMPOSITE('pow')
def pow_composite(x, y):
"""
define composite rule of op pow
res = x^y
"""
if isinstance(y, (int, float)):
y = full([1], y, x.dtype)
res = pow(x, y)
return res
@REGISTER_COMPOSITE('relu') @REGISTER_COMPOSITE('relu')
def relu_composite(x): def relu_composite(x):
"""define composite rule of op relu.""" """define composite rule of op relu."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册