Unverified commit 468c17ff authored by ccrrong, committed by GitHub

Add stack composite rule (#50861)

* add stack composite rule

* add float16 datatype test
Parent 39a9abaa
@@ -1207,6 +1207,7 @@ set(TEST_CINN_OPS
test_expand_v2_op
test_reduce_op
test_slice_op
test_stack_op
test_activation_op
test_full_like_op
test_fill_any_like_op
......
@@ -44,6 +44,7 @@ class TestStackOpBase(OpTest):
self.initDefaultParameters()
self.initParameters()
self.op_type = 'stack'
self.prim_op_type = "comp"
self.python_api = paddle.stack
self.x = []
for i in range(self.num_inputs):
@@ -61,10 +62,12 @@ class TestStackOpBase(OpTest):
self.attrs = {'axis': self.axis}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=True, check_prim=True)
def test_check_grad(self):
self.check_grad(self.get_x_names(), 'Y', check_eager=True)
self.check_grad(
self.get_x_names(), 'Y', check_eager=True, check_prim=True
)
class TestStackOp1(TestStackOpBase):
@@ -100,6 +103,7 @@ class TestStackOp6(TestStackOpBase):
class TestStackOp_ZeroDim(TestStackOpBase):
def initParameters(self):
self.input_dim = ()
self.enable_cinn = False
class TestStackBF16Op(OpTest):
@@ -122,6 +126,8 @@ class TestStackBF16Op(OpTest):
self.initDefaultParameters()
self.initParameters()
self.op_type = 'stack'
self.prim_op_type = "comp"
self.enable_cinn = False
self.python_api = paddle.stack
self.x = []
for i in range(self.num_inputs):
@@ -141,9 +147,10 @@ class TestStackBF16Op(OpTest):
self.attrs = {'axis': self.axis}
def test_check_output(self):
self.check_output(check_eager=True)
self.check_output(check_eager=True, check_prim=True)
def test_check_grad(self):
# concat_grad does not support bfloat16 dtype, so skip check_prim
self.check_grad(self.get_x_names(), 'Y', check_eager=True)
......
@@ -192,6 +192,20 @@ def mean_composite(x, axis, keepdim):
return divide(sum_x, norm)
@REGISTER_COMPOSITE('stack')
def stack_composite(x, axis):
"""
define composite rule of op stack:
unsqueeze each input along the stack axis (via reshape), then concat the results
"""
x_shape = x[0].shape
if axis < 0:
axis += len(x_shape) + 1
out_shape = x_shape[:axis] + (1,) + x_shape[axis:]
out = concat([reshape(item, out_shape) for item in x], axis)
return out
@REGISTER_COMPOSITE('flatten_contiguous_range')
def flatten_contiguous_range_composite(x, start_axis, stop_axis):
"""
......
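The stack_composite rule above rewrites stack as a reshape of each input that inserts a length-1 axis at the stack position, followed by a concat along that axis. Below is a minimal NumPy sketch of the same decomposition (not part of this patch; the helper name is illustrative), checked against np.stack:

import numpy as np

def stack_via_reshape_concat(xs, axis=0):
    # Same negative-axis normalization as stack_composite.
    x_shape = xs[0].shape
    if axis < 0:
        axis += len(x_shape) + 1
    # Insert a length-1 axis into each input, then concatenate along it.
    out_shape = x_shape[:axis] + (1,) + x_shape[axis:]
    return np.concatenate([x.reshape(out_shape) for x in xs], axis=axis)

a = np.arange(6).reshape(2, 3)
b = np.arange(6, 12).reshape(2, 3)
assert np.array_equal(
    stack_via_reshape_concat([a, b], axis=1), np.stack([a, b], axis=1)
)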
@@ -22,6 +22,7 @@ from paddle.tensor import atan # noqa: F401
from paddle.tensor import atanh # noqa: F401
from paddle.tensor import broadcast_shape # noqa: F401
from paddle.tensor import broadcast_to # noqa: F401
from paddle.tensor import concat # noqa: F401
from paddle.tensor import cos # noqa: F401
from paddle.tensor import cosh # noqa: F401
from paddle.tensor import cumprod # noqa: F401
@@ -122,6 +123,7 @@ others = [
'fill_constant',
'reshape',
'full',
'concat',
'uniform',
'greater_equal',
]
......
@@ -1855,7 +1855,14 @@ def stack(x, axis=0, name=None):
check_variable_and_dtype(
i,
'x',
['float16', 'float32', 'float64', 'int32', 'int64'],
[
'float16',
'float32',
'float64',
'int32',
'int64',
'uint16',
],
'stack',
)
......
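With 'uint16' (Paddle's storage dtype for bfloat16) added to the static-graph dtype check above, stacking bfloat16 tensors now passes validation. A minimal usage sketch, assuming a Paddle build with bfloat16 support (shapes and tensor names are illustrative):

import paddle

# Cast float32 tensors to bfloat16, then stack them along a new leading axis.
x = paddle.ones([2, 3], dtype='float32').cast('bfloat16')
y = paddle.zeros([2, 3], dtype='float32').cast('bfloat16')
out = paddle.stack([x, y], axis=0)  # expected shape: [2, 2, 3]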