未验证 提交 53bb883d 编写于 作者: C chenjian 提交者: GitHub

[Prim] add meshgrid composite rule (#51061)

* add meshgrid composite rule

* add meshgrid composite rule

* update

* add into CMakeLists

* fix

* update

* update

* optimize code

* fix meshgrid op

* update test
上级 33897a95
...@@ -2119,7 +2119,11 @@ void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs, ...@@ -2119,7 +2119,11 @@ void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
auto out_shape = std::vector<int>(inputs_num); auto out_shape = std::vector<int>(inputs_num);
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
out_shape[i] = inputs[i]->dims()[0]; if (inputs[i]->dims().size() == 0) {
out_shape[i] = 1;
} else {
out_shape[i] = inputs[i]->dims()[0];
}
} }
auto out_dims = phi::make_ddim(std::vector<int>(out_shape)); auto out_dims = phi::make_ddim(std::vector<int>(out_shape));
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
......
...@@ -1214,6 +1214,7 @@ set(TEST_CINN_OPS ...@@ -1214,6 +1214,7 @@ set(TEST_CINN_OPS
test_reshape_op test_reshape_op
test_mean_op test_mean_op
test_unsqueeze2_op test_unsqueeze2_op
test_meshgrid_op
test_gather_op) test_gather_op)
foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) foreach(TEST_CINN_OPS ${TEST_CINN_OPS})
......
...@@ -28,7 +28,9 @@ def meshgrid_wrapper(x): ...@@ -28,7 +28,9 @@ def meshgrid_wrapper(x):
class TestMeshgridOp(OpTest): class TestMeshgridOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "meshgrid" self.op_type = "meshgrid"
self.prim_op_type = "comp"
self.python_api = meshgrid_wrapper self.python_api = meshgrid_wrapper
self.public_python_api = meshgrid_wrapper
self.dtype = self.get_dtype() self.dtype = self.get_dtype()
ins, outs = self.init_test_data() ins, outs = self.init_test_data()
self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]} self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]}
...@@ -36,15 +38,16 @@ class TestMeshgridOp(OpTest): ...@@ -36,15 +38,16 @@ class TestMeshgridOp(OpTest):
'Out': [('out%d' % i, outs[i]) for i in range(len(outs))] 'Out': [('out%d' % i, outs[i]) for i in range(len(outs))]
} }
self.python_out_sig = ['out0', 'out1'] self.python_out_sig = ['out0', 'out1']
self.if_enable_cinn()
def get_dtype(self): def get_dtype(self):
return "float64" return "float64"
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output(check_prim=True)
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['x0'], ['out0', 'out1']) self.check_grad(['x0'], ['out0', 'out1'], check_prim=True)
def init_test_data(self): def init_test_data(self):
self.shape = self.get_x_shape() self.shape = self.get_x_shape()
...@@ -63,6 +66,9 @@ class TestMeshgridOp(OpTest): ...@@ -63,6 +66,9 @@ class TestMeshgridOp(OpTest):
def get_x_shape(self): def get_x_shape(self):
return [100, 200] return [100, 200]
def if_enable_cinn(self):
self.enable_cinn = True
class TestMeshgridOp2(TestMeshgridOp): class TestMeshgridOp2(TestMeshgridOp):
def get_x_shape(self): def get_x_shape(self):
...@@ -257,6 +263,28 @@ class TestMeshgridOp8(unittest.TestCase): ...@@ -257,6 +263,28 @@ class TestMeshgridOp8(unittest.TestCase):
assert np.array_equal(res_4.shape, [100, 200]) assert np.array_equal(res_4.shape, [100, 200])
class TestMeshGrid_ZeroDim(TestMeshgridOp):
def init_test_data(self):
self.shape = self.get_x_shape()
ins = []
outs = []
ins.append(np.random.random(([])).astype(self.dtype))
ins.append(np.random.random([2]).astype(self.dtype))
ins.append(np.random.random([3]).astype(self.dtype))
for i in range(len(self.shape)):
out_reshape = [1] * len(self.shape)
out_reshape[i] = self.shape[i]
out_temp = np.reshape(ins[i], out_reshape)
outs.append(np.broadcast_to(out_temp, self.shape))
return ins, outs
def get_x_shape(self):
return [1, 2, 3]
def if_enable_cinn(self):
self.enable_cinn = False
class TestMeshgridEager(unittest.TestCase): class TestMeshgridEager(unittest.TestCase):
def test_dygraph_api(self): def test_dygraph_api(self):
input_1 = np.random.randint( input_1 = np.random.randint(
......
...@@ -439,6 +439,33 @@ def silu_composite(x): ...@@ -439,6 +439,33 @@ def silu_composite(x):
return res return res
@REGISTER_COMPOSITE('meshgrid')
def meshgrid_composite(inputs):
"""
define composite rule of op meshgrid
If the input has N tensors of size S_0, ... S_n-1, then the output will also have N tensors, where
each tensor is of shape (S_0, ..., S_n-1).
E.g. a1 is Tensor [1,2,3]
b1 is Tensor [4,5]
r1, r2 = paddle.meshgrid([a1, b1])
r1 is Tensor [[1,1], [2,2], [3,3]]
r2 is Tensor [[4,5], [4,5], [4,5]]
"""
size = len(inputs)
shape = [1] * size
for i in range(size):
dim = inputs[i].dim()
assert dim == 0 or dim == 1
if dim == 1:
shape[i] = inputs[i].shape[0]
outputs = []
for i in range(size):
view_shape = [1] * size
view_shape[i] = shape[i]
outputs.append(inputs[i].reshape(view_shape).broadcast_to(shape))
return outputs
@REGISTER_COMPOSITE('fill_any_like') @REGISTER_COMPOSITE('fill_any_like')
def fill_any_like(x, fill_value, dtype, place=None): def fill_any_like(x, fill_value, dtype, place=None):
"""define composite rule of op full_like.""" """define composite rule of op full_like."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册