未验证 提交 300b687a 编写于 作者: Q qizhaoaoe 提交者: GitHub

implement expand as using tile (#51577)

上级 7a3d05d9
...@@ -24,6 +24,7 @@ import paddle.fluid as fluid ...@@ -24,6 +24,7 @@ import paddle.fluid as fluid
class TestExpandAsBasic(OpTest): class TestExpandAsBasic(OpTest):
def setUp(self): def setUp(self):
self.op_type = "expand_as_v2" self.op_type = "expand_as_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand_as self.python_api = paddle.expand_as
x = np.random.rand(100).astype("float64") x = np.random.rand(100).astype("float64")
target_tensor = np.random.rand(2, 100).astype("float64") target_tensor = np.random.rand(2, 100).astype("float64")
...@@ -34,15 +35,16 @@ class TestExpandAsBasic(OpTest): ...@@ -34,15 +35,16 @@ class TestExpandAsBasic(OpTest):
self.outputs = {'Out': output} self.outputs = {'Out': output}
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=True) self.check_output(check_eager=True, check_prim=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Out', check_eager=True) self.check_grad(['X'], 'Out', check_eager=True, check_prim=True)
class TestExpandAsOpRank2(TestExpandAsBasic): class TestExpandAsOpRank2(TestExpandAsBasic):
def setUp(self): def setUp(self):
self.op_type = "expand_as_v2" self.op_type = "expand_as_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand_as self.python_api = paddle.expand_as
x = np.random.rand(10, 12).astype("float64") x = np.random.rand(10, 12).astype("float64")
target_tensor = np.random.rand(10, 12).astype("float64") target_tensor = np.random.rand(10, 12).astype("float64")
...@@ -56,6 +58,7 @@ class TestExpandAsOpRank2(TestExpandAsBasic): ...@@ -56,6 +58,7 @@ class TestExpandAsOpRank2(TestExpandAsBasic):
class TestExpandAsOpRank3(TestExpandAsBasic): class TestExpandAsOpRank3(TestExpandAsBasic):
def setUp(self): def setUp(self):
self.op_type = "expand_as_v2" self.op_type = "expand_as_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand_as self.python_api = paddle.expand_as
x = np.random.rand(2, 3, 20).astype("float64") x = np.random.rand(2, 3, 20).astype("float64")
target_tensor = np.random.rand(2, 3, 20).astype("float64") target_tensor = np.random.rand(2, 3, 20).astype("float64")
...@@ -69,6 +72,7 @@ class TestExpandAsOpRank3(TestExpandAsBasic): ...@@ -69,6 +72,7 @@ class TestExpandAsOpRank3(TestExpandAsBasic):
class TestExpandAsOpRank4(TestExpandAsBasic): class TestExpandAsOpRank4(TestExpandAsBasic):
def setUp(self): def setUp(self):
self.op_type = "expand_as_v2" self.op_type = "expand_as_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand_as self.python_api = paddle.expand_as
x = np.random.rand(1, 1, 7, 16).astype("float64") x = np.random.rand(1, 1, 7, 16).astype("float64")
target_tensor = np.random.rand(4, 6, 7, 16).astype("float64") target_tensor = np.random.rand(4, 6, 7, 16).astype("float64")
...@@ -84,6 +88,7 @@ class TestExpandAsOpRank5(TestExpandAsBasic): ...@@ -84,6 +88,7 @@ class TestExpandAsOpRank5(TestExpandAsBasic):
def setUp(self): def setUp(self):
self.op_type = "expand_as_v2" self.op_type = "expand_as_v2"
self.prim_op_type = "comp"
self.python_api = paddle.expand_as self.python_api = paddle.expand_as
x = np.random.rand(1, 1, 7, 16).astype("int64") x = np.random.rand(1, 1, 7, 16).astype("int64")
target_tensor = np.random.rand(4, 6, 7, 16).astype("float64") target_tensor = np.random.rand(4, 6, 7, 16).astype("float64")
......
...@@ -225,6 +225,43 @@ def expand_v2_composite(x, shape): ...@@ -225,6 +225,43 @@ def expand_v2_composite(x, shape):
return tile(x, repeat_times=repeat_times) return tile(x, repeat_times=repeat_times)
@REGISTER_COMPOSITE('expand_as_v2')
def expand_as_v2_composite(x, y, target_shape):
"""
define composite rule of op expnad_as_v2, expand_as_v2->expand_as
repeat_times = target_shape / x.shape
out = tile(x, repeat_times = repeat_times)
"""
shape_in = x.shape
if y is not None:
target_shape = y.shape
assert target_shape is not None
dim_out = len(target_shape)
dim_in = len(shape_in)
assert dim_in <= dim_out and dim_out >= 0
repeat_times = []
for i in range(dim_out):
offset = dim_out - i
dim = dim_in - offset
size_in = shape_in[dim] if dim >= 0 else 1
size_out = target_shape[i]
if size_out == -1:
assert dim >= 0
repeat = 1
else:
assert size_out % size_in == 0
repeat = int(size_out / size_in)
repeat_times.append(repeat)
if dim_in < dim_out:
shape_in_expand = []
for i in range(dim_out - dim_in):
shape_in_expand.append(1)
shape_in_expand.extend(shape_in)
x_reshape = reshape(x, shape_in_expand)
return tile(x_reshape, repeat_times=repeat_times)
return tile(x, repeat_times=repeat_times)
@REGISTER_COMPOSITE('stack') @REGISTER_COMPOSITE('stack')
def stack_composite(x, axis): def stack_composite(x, axis):
""" """
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册