未验证 提交 aba9c4d4 编写于 作者: M mhy-666 提交者: GitHub

Add sqrt composite rule (#51080)

* add sqrt composite rule/test

* add sqrt composite rule/test

* fix ops/sqrt, add cinn test

* fix sqrt_comp

* fix sqrt_comp

* fix sqrt_comp

* fix

* fix codestyle

* fix codestyle

* add fp16 test

* add ops/sqrt

* fix

* fix

* fix unitest

* fix

* fix

* fix
上级 2f2b1f23
......@@ -1235,6 +1235,58 @@ class TestSqrtBF16(OpTest):
self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
class TestSqrtComp(TestActivation, TestParameter):
def setUp(self):
self.op_type = "sqrt"
self.prim_op_type = "comp"
self.python_api = paddle.sqrt
self.init_dtype()
self.init_shape()
np.random.seed(1023)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.sqrt(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.enable_cinn = True
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_dygraph=True, check_prim=True)
def test_check_output(self):
self.check_output(check_dygraph=True, check_prim=True)
class TestSqrtCompFp32(TestActivation):
def setUp(self):
self.op_type = "sqrt"
self.prim_op_type = "comp"
self.python_api = paddle.sqrt
self.init_dtype()
self.init_shape()
np.random.seed(1023)
x = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
out = np.sqrt(x)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
self.enable_cinn = True
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out', check_dygraph=True, check_prim=True)
def test_check_output(self):
self.check_output(check_dygraph=True, check_prim=True)
def init_dtype(self):
self.dtype = np.float32
class TestRsqrt(TestActivation):
def setUp(self):
self.op_type = "rsqrt"
......@@ -3813,6 +3865,7 @@ create_test_act_fp16_class(TestTanhshrink)
create_test_act_fp16_class(TestHardShrink)
create_test_act_fp16_class(TestSoftshrink)
create_test_act_fp16_class(TestSqrt)
create_test_act_fp16_class(TestSqrtComp, check_prim=True)
create_test_act_fp16_class(TestAbs, check_prim=True)
create_test_act_fp16_class(TestCeil, grad_check=False)
create_test_act_fp16_class(TestFloor, check_prim=True, grad_check=False)
......
......@@ -433,6 +433,17 @@ def fill_any_like(x, fill_value, dtype, place=None):
return val
@REGISTER_COMPOSITE('sqrt')
def sqrt_composite(x):
"""
define composite rule of op sqrt
res = pow(x, 0.5)
"""
y = full(x.shape, 0.5, x.dtype)
res = pow(x, y)
return res
@REGISTER_COMPOSITE('pow')
def pow_composite(x, y):
"""
......
......@@ -920,7 +920,10 @@ def sqrt(x, name=None):
return _C_ops.sqrt(x)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'uint16', 'float32', 'float64'], 'sqrt'
x,
'x',
['float16', 'uint16', 'float32', 'float64'],
'sqrt',
)
helper = LayerHelper('sqrt', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册