未验证 提交 2d9e103e 编写于 作者: C chenjian 提交者: GitHub

[Prim] add pow composite rule (#51070)

* add pow composite rule

* fix test

* fix unit test

* update test

* fix test

* update
上级 9045b882
......@@ -2914,6 +2914,7 @@ class TestSquareBF16(OpTest):
class TestPow(TestActivation):
def setUp(self):
self.op_type = "pow"
self.prim_op_type = "comp"
self.python_api = paddle.pow
self.init_dtype()
self.init_shape()
......@@ -2927,23 +2928,28 @@ class TestPow(TestActivation):
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
self.check_output(check_prim=True)
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_prim=True)
class TestPow_ZeroDim(TestPow):
def init_shape(self):
self.shape = []
def setUp(self):
super(TestPow_ZeroDim, self).setUp()
self.enable_cinn = False
class TestPow_factor_tensor(TestActivation):
def setUp(self):
self.op_type = "pow"
self.python_api = paddle.pow
self.enable_cinn = False
self.init_dtype()
np.random.seed(1024)
......@@ -3826,7 +3832,7 @@ else:
create_test_act_fp16_class(TestLog10, atol=5e-2)
create_test_act_fp16_class(TestLog1p, grad_atol=0.9)
create_test_act_fp16_class(TestSquare)
create_test_act_fp16_class(TestPow, atol=5e-2)
create_test_act_fp16_class(TestPow, check_prim=True, atol=5e-2)
create_test_act_fp16_class(TestPow_factor_tensor, atol=5e-2)
create_test_act_fp16_class(TestSTanh, grad_atol=0.9)
create_test_act_fp16_class(TestSoftplus)
......
......@@ -401,6 +401,18 @@ def fill_any_like(x, fill_value, dtype, place=None):
return val
@REGISTER_COMPOSITE('pow')
def pow_composite(x, y):
"""
define composite rule of op pow
res = x^y
"""
if isinstance(y, (int, float)):
y = full([1], y, x.dtype)
res = pow(x, y)
return res
@REGISTER_COMPOSITE('relu')
def relu_composite(x):
"""define composite rule of op relu."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册