未验证 提交 559de39a 编写于 作者: X xysheng-baidu 提交者: GitHub

Add expand composite rule (#50810)

* Add expand composite rule

* reshape x when dim_in less than dim_out

* add tile op for expand

* remove rensor shape case when comp prim

* enable cinn case

* dim_out can't be 0

* update test case for prim type
上级 fcab331d
......@@ -179,7 +179,7 @@ class TestExpandV2OpBoolean(OpTest):
self.check_output()
# Situation 56: input x is Integer
# Situation 6: input x is Integer
class TestExpandV2OpInt64_t(OpTest):
def setUp(self):
self.op_type = "expand_v2"
......@@ -332,6 +332,108 @@ class TestExpandTripleGradCheck(unittest.TestCase):
self.func(p)
# Situation 7: comp case, shape is a list(without tensor)
class TestExpandV2CompOpRank1(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "comp"
self.init_data()
self.python_api = paddle.expand
self.inputs = {'X': np.random.random(self.ori_shape).astype("float64")}
self.attrs = {'shape': self.shape}
output = np.tile(self.inputs['X'], self.expand_times)
self.outputs = {'Out': output}
self.enable_cinn = True
def init_data(self):
self.ori_shape = [100]
self.shape = [100]
self.expand_times = [1]
def test_check_output(self):
self.check_output(check_prim=True)
def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)
class TestExpandV2OpCompRank2_DimExpanding(TestExpandV2CompOpRank1):
def init_data(self):
self.ori_shape = [120]
self.shape = [2, 120]
self.expand_times = [2, 1]
class TestExpandV2CompOpRank2(TestExpandV2CompOpRank1):
def init_data(self):
self.ori_shape = [1, 140]
self.shape = [12, 140]
self.expand_times = [12, 1]
class TestExpandV2CompOpRank3_Corner(TestExpandV2CompOpRank1):
def init_data(self):
self.ori_shape = (2, 10, 5)
self.shape = (2, 10, 5)
self.expand_times = (1, 1, 1)
class TestExpandV2CompOpRank4(TestExpandV2CompOpRank1):
def init_data(self):
self.ori_shape = (2, 4, 5, 7)
self.shape = (-1, -1, -1, -1)
self.expand_times = (1, 1, 1, 1)
# Situation 8: comp case, input x is Integer
class TestExpandV2CompOpInteger(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand
self.inputs = {
'X': np.random.randint(10, size=(2, 4, 5)).astype("int32")
}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output(check_prim=True)
# Situation 9: comp case, input x is Bool
class TestExpandV2CompOpBoolean(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand
self.inputs = {'X': np.random.randint(2, size=(2, 4, 5)).astype("bool")}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output(check_prim=True)
# Situation 10: comp case, input x is Integer
class TestExpandV2CompOpInt64_t(OpTest):
def setUp(self):
self.op_type = "expand_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand
self.inputs = {
'X': np.random.randint(10, size=(2, 4, 5)).astype("int64")
}
self.attrs = {'shape': [2, 4, 5]}
output = np.tile(self.inputs['X'], (1, 1, 1))
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output(check_prim=True)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -192,6 +192,40 @@ def mean_composite(x, axis, keepdim):
return divide(sum_x, norm)
@REGISTER_COMPOSITE('expand_v2')
def expand_v2_composite(x, shape):
"""
define composite rule of op expnad_v2, expand_v2->expand
repeat_times = shape / x.shape
out = tile(x, repeat_times = repeat_times)
"""
shape_in = x.shape
dim_out = len(shape)
dim_in = len(shape_in)
assert dim_in <= dim_out and dim_out >= 0
repeat_times = []
for i in range(dim_out):
offset = dim_out - i
dim = dim_in - offset
size_in = shape_in[dim] if dim >= 0 else 1
size_out = shape[i]
if size_out == -1:
assert dim >= 0
repeat = 1
else:
assert size_out % size_in == 0
repeat = int(size_out / size_in)
repeat_times.append(repeat)
if dim_in < dim_out:
shape_in_expand = []
for i in range(dim_out - dim_in):
shape_in_expand.append(1)
shape_in_expand.extend(shape_in)
x_reshape = reshape(x, shape_in_expand)
return tile(x_reshape, repeat_times=repeat_times)
return tile(x, repeat_times=repeat_times)
@REGISTER_COMPOSITE('stack')
def stack_composite(x, axis):
"""
......
......@@ -57,6 +57,7 @@ from paddle.tensor import subtract # noqa: F401
from paddle.tensor import sum # noqa: F401
from paddle.tensor import tan # noqa: F401
from paddle.tensor import tanh # noqa: F401
from paddle.tensor import tile # noqa: F401
from paddle.tensor import uniform # noqa: F401
from paddle.tensor import zeros # noqa: F401
from paddle.tensor.creation import assign # noqa: F401
......@@ -124,6 +125,7 @@ others = [
'fill_constant',
'reshape',
'full',
'tile',
'concat',
'uniform',
'greater_equal',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册