diff --git a/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc index 80cf88b3ceb7f54ef1598941f856098cdf71f77c..17f74cd3743bd766a55384add1289fef06c0e7cc 100644 --- a/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc @@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(meshgrid_grad, GPU, ALL_LAYOUT, phi::MeshgridGradKernel, + phi::dtype::float16, float, double, int, diff --git a/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc b/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc index c863550979444e26d3f882635e27c5875960a88e..73120c1391642fc24cbc79077e010298ead2c06c 100644 --- a/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc @@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(meshgrid, GPU, ALL_LAYOUT, phi::MeshgridKernel, + phi::dtype::float16, float, double, int, diff --git a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py index 60af417ebc5458e53e0d2bd2908c4fc2092f3b34..0039d4ee422e88d1e3fb6313693ef33dbebc781b 100644 --- a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py +++ b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py @@ -76,6 +76,14 @@ class TestMeshgridOp2(TestMeshgridOp): return [100, 300] +class TestMeshgridOp2Fp16(TestMeshgridOp): + def get_x_shape(self): + return [100, 300] + + def get_dtype(self): + return np.float16 + + class TestMeshgridOp3(unittest.TestCase): def test_api(self): x = paddle.static.data(shape=[100], dtype='int32', name='x') diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 186eda03e74d846d87389e4e1ea5993204e9eb8e..c57fceeeb85252f1691416bccf25aa4b99a5ccbd 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1512,7 +1512,7 @@ def meshgrid(*args, **kwargs): Args: *args(Tensor|list of Tensor) : tensors (tuple(list) of tensor): the shapes of input k tensors are (N1,), - (N2,),..., (Nk,). Support data types: ``float64``, ``float32``, ``int32``, ``int64``. + (N2,),..., (Nk,). Support data types: ``float64``, ``float16``, ``float32``, ``int32``, ``int64``. **kwargs (optional): Currently, only accept name in **kwargs The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.