From 53bb883d81c17c959604cd8e2bb0ebda452b8f5b Mon Sep 17 00:00:00 2001 From: chenjian Date: Thu, 23 Mar 2023 14:52:12 +0800 Subject: [PATCH] [Prim] add meshgrid composite rule (#51061) * add meshgrid composite rule * add meshgrid composite rule * update * add into CMakeLists * fix * update * update * optimize code * fix meshgrid op * update test --- paddle/phi/infermeta/multiary.cc | 6 +++- .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_meshgrid_op.py | 32 +++++++++++++++++-- .../incubate/autograd/composite_rules.py | 27 ++++++++++++++++ 4 files changed, 63 insertions(+), 3 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 309362ebf30..3f15fdb4424 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -2119,7 +2119,11 @@ void MeshgridInferMeta(const std::vector& inputs, auto out_shape = std::vector(inputs_num); for (size_t i = 0; i < inputs.size(); i++) { - out_shape[i] = inputs[i]->dims()[0]; + if (inputs[i]->dims().size() == 0) { + out_shape[i] = 1; + } else { + out_shape[i] = inputs[i]->dims()[0]; + } } auto out_dims = phi::make_ddim(std::vector(out_shape)); for (size_t i = 0; i < outputs.size(); ++i) { diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index b5b1e74a942..4103c0d0227 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -1214,6 +1214,7 @@ set(TEST_CINN_OPS test_reshape_op test_mean_op test_unsqueeze2_op + test_meshgrid_op test_gather_op) foreach(TEST_CINN_OPS ${TEST_CINN_OPS}) diff --git a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py index 1c08d0bc83d..e9479539408 100644 --- a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py +++ b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py @@ -28,7 +28,9 @@ def meshgrid_wrapper(x): class TestMeshgridOp(OpTest): def setUp(self): self.op_type = "meshgrid" + self.prim_op_type = "comp" self.python_api = meshgrid_wrapper + self.public_python_api = meshgrid_wrapper self.dtype = self.get_dtype() ins, outs = self.init_test_data() self.inputs = {'X': [('x%d' % i, ins[i]) for i in range(len(ins))]} @@ -36,15 +38,16 @@ class TestMeshgridOp(OpTest): 'Out': [('out%d' % i, outs[i]) for i in range(len(outs))] } self.python_out_sig = ['out0', 'out1'] + self.if_enable_cinn() def get_dtype(self): return "float64" def test_check_output(self): - self.check_output() + self.check_output(check_prim=True) def test_check_grad(self): - self.check_grad(['x0'], ['out0', 'out1']) + self.check_grad(['x0'], ['out0', 'out1'], check_prim=True) def init_test_data(self): self.shape = self.get_x_shape() @@ -63,6 +66,9 @@ class TestMeshgridOp(OpTest): def get_x_shape(self): return [100, 200] + def if_enable_cinn(self): + self.enable_cinn = True + class TestMeshgridOp2(TestMeshgridOp): def get_x_shape(self): @@ -257,6 +263,28 @@ class TestMeshgridOp8(unittest.TestCase): assert np.array_equal(res_4.shape, [100, 200]) +class TestMeshGrid_ZeroDim(TestMeshgridOp): + def init_test_data(self): + self.shape = self.get_x_shape() + ins = [] + outs = [] + ins.append(np.random.random(([])).astype(self.dtype)) + ins.append(np.random.random([2]).astype(self.dtype)) + ins.append(np.random.random([3]).astype(self.dtype)) + for i in range(len(self.shape)): + out_reshape = [1] * len(self.shape) + out_reshape[i] = self.shape[i] + out_temp = np.reshape(ins[i], out_reshape) + outs.append(np.broadcast_to(out_temp, self.shape)) + return ins, outs + + def get_x_shape(self): + return [1, 2, 3] + + def if_enable_cinn(self): + self.enable_cinn = False + + class TestMeshgridEager(unittest.TestCase): def test_dygraph_api(self): input_1 = np.random.randint( diff --git a/python/paddle/incubate/autograd/composite_rules.py b/python/paddle/incubate/autograd/composite_rules.py index a6974303060..8e84ab74d0d 100644 --- a/python/paddle/incubate/autograd/composite_rules.py +++ b/python/paddle/incubate/autograd/composite_rules.py @@ -439,6 +439,33 @@ def silu_composite(x): return res +@REGISTER_COMPOSITE('meshgrid') +def meshgrid_composite(inputs): + """ + define composite rule of op meshgrid + If the input has N tensors of size S_0, ... S_n-1, then the output will also have N tensors, where + each tensor is of shape (S_0, ..., S_n-1). + E.g. a1 is Tensor [1,2,3] + b1 is Tensor [4,5] + r1, r2 = paddle.meshgrid([a1, b1]) + r1 is Tensor [[1,1], [2,2], [3,3]] + r2 is Tensor [[4,5], [4,5], [4,5]] + """ + size = len(inputs) + shape = [1] * size + for i in range(size): + dim = inputs[i].dim() + assert dim == 0 or dim == 1 + if dim == 1: + shape[i] = inputs[i].shape[0] + outputs = [] + for i in range(size): + view_shape = [1] * size + view_shape[i] = shape[i] + outputs.append(inputs[i].reshape(view_shape).broadcast_to(shape)) + return outputs + + @REGISTER_COMPOSITE('fill_any_like') def fill_any_like(x, fill_value, dtype, place=None): """define composite rule of op full_like.""" -- GitLab