未验证 提交 468c17ff 编写于 作者: C ccrrong 提交者: GitHub

Add stack composite rule (#50861)

* add stack composite rule

* add float16 datatype test
上级 39a9abaa
...@@ -1207,6 +1207,7 @@ set(TEST_CINN_OPS ...@@ -1207,6 +1207,7 @@ set(TEST_CINN_OPS
test_expand_v2_op test_expand_v2_op
test_reduce_op test_reduce_op
test_slice_op test_slice_op
test_stack_op
test_activation_op test_activation_op
test_full_like_op test_full_like_op
test_fill_any_like_op test_fill_any_like_op
......
...@@ -44,6 +44,7 @@ class TestStackOpBase(OpTest): ...@@ -44,6 +44,7 @@ class TestStackOpBase(OpTest):
self.initDefaultParameters() self.initDefaultParameters()
self.initParameters() self.initParameters()
self.op_type = 'stack' self.op_type = 'stack'
self.prim_op_type = "comp"
self.python_api = paddle.stack self.python_api = paddle.stack
self.x = [] self.x = []
for i in range(self.num_inputs): for i in range(self.num_inputs):
...@@ -61,10 +62,12 @@ class TestStackOpBase(OpTest): ...@@ -61,10 +62,12 @@ class TestStackOpBase(OpTest):
self.attrs = {'axis': self.axis} self.attrs = {'axis': self.axis}
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True, check_prim=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(self.get_x_names(), 'Y', check_eager=True) self.check_grad(
self.get_x_names(), 'Y', check_eager=True, check_prim=True
)
class TestStackOp1(TestStackOpBase): class TestStackOp1(TestStackOpBase):
...@@ -100,6 +103,7 @@ class TestStackOp6(TestStackOpBase): ...@@ -100,6 +103,7 @@ class TestStackOp6(TestStackOpBase):
class TestStackOp_ZeroDim(TestStackOpBase): class TestStackOp_ZeroDim(TestStackOpBase):
def initParameters(self): def initParameters(self):
self.input_dim = () self.input_dim = ()
self.enable_cinn = False
class TestStackBF16Op(OpTest): class TestStackBF16Op(OpTest):
...@@ -122,6 +126,8 @@ class TestStackBF16Op(OpTest): ...@@ -122,6 +126,8 @@ class TestStackBF16Op(OpTest):
self.initDefaultParameters() self.initDefaultParameters()
self.initParameters() self.initParameters()
self.op_type = 'stack' self.op_type = 'stack'
self.prim_op_type = "comp"
self.enable_cinn = False
self.python_api = paddle.stack self.python_api = paddle.stack
self.x = [] self.x = []
for i in range(self.num_inputs): for i in range(self.num_inputs):
...@@ -141,9 +147,10 @@ class TestStackBF16Op(OpTest): ...@@ -141,9 +147,10 @@ class TestStackBF16Op(OpTest):
self.attrs = {'axis': self.axis} self.attrs = {'axis': self.axis}
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True, check_prim=True)
def test_check_grad(self): def test_check_grad(self):
# concat_grad unspport bfloat16 dtype, skip check_prim
self.check_grad(self.get_x_names(), 'Y', check_eager=True) self.check_grad(self.get_x_names(), 'Y', check_eager=True)
......
...@@ -192,6 +192,20 @@ def mean_composite(x, axis, keepdim): ...@@ -192,6 +192,20 @@ def mean_composite(x, axis, keepdim):
return divide(sum_x, norm) return divide(sum_x, norm)
@REGISTER_COMPOSITE('stack')
def stack_composite(x, axis):
"""
define composite rule of op stack
unsqueeze each dimension of the input (use reshape), and then concat
"""
x_shape = x[0].shape
if axis < 0:
axis += len(x_shape) + 1
out_shape = x_shape[:axis] + (1,) + x_shape[axis:]
out = concat([reshape(item, out_shape) for item in x], axis)
return out
@REGISTER_COMPOSITE('flatten_contiguous_range') @REGISTER_COMPOSITE('flatten_contiguous_range')
def flatten_contiguous_range_composite(x, start_axis, stop_axis): def flatten_contiguous_range_composite(x, start_axis, stop_axis):
""" """
......
...@@ -22,6 +22,7 @@ from paddle.tensor import atan # noqa: F401 ...@@ -22,6 +22,7 @@ from paddle.tensor import atan # noqa: F401
from paddle.tensor import atanh # noqa: F401 from paddle.tensor import atanh # noqa: F401
from paddle.tensor import broadcast_shape # noqa: F401 from paddle.tensor import broadcast_shape # noqa: F401
from paddle.tensor import broadcast_to # noqa: F401 from paddle.tensor import broadcast_to # noqa: F401
from paddle.tensor import concat # noqa: F401
from paddle.tensor import cos # noqa: F401 from paddle.tensor import cos # noqa: F401
from paddle.tensor import cosh # noqa: F401 from paddle.tensor import cosh # noqa: F401
from paddle.tensor import cumprod # noqa: F401 from paddle.tensor import cumprod # noqa: F401
...@@ -122,6 +123,7 @@ others = [ ...@@ -122,6 +123,7 @@ others = [
'fill_constant', 'fill_constant',
'reshape', 'reshape',
'full', 'full',
'concat',
'uniform', 'uniform',
'greater_equal', 'greater_equal',
] ]
......
...@@ -1855,7 +1855,14 @@ def stack(x, axis=0, name=None): ...@@ -1855,7 +1855,14 @@ def stack(x, axis=0, name=None):
check_variable_and_dtype( check_variable_and_dtype(
i, i,
'x', 'x',
['float16', 'float32', 'float64', 'int32', 'int64'], [
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'stack', 'stack',
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册