From 9127cc3c969cd3c94c91cac90a0fb6c164c29d9a Mon Sep 17 00:00:00 2001 From: denglianbin <112610123+denglianbin@users.noreply.github.com> Date: Wed, 26 Apr 2023 11:15:24 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Hackathon=20No.48=E3=80=91=E4=B8=BA=20?= =?UTF-8?q?Paddle=20meshgrid=20=E7=AE=97=E5=AD=90=E5=AE=9E=E7=8E=B0=20floa?= =?UTF-8?q?t16=20=E6=95=B0=E6=8D=AE=E7=B1=BB=E5=9E=8B=E6=94=AF=E6=8C=81=20?= =?UTF-8?q?(#53284)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc | 1 + paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc | 1 + python/paddle/fluid/tests/unittests/test_meshgrid_op.py | 8 ++++++++ python/paddle/tensor/creation.py | 2 +- 4 files changed, 11 insertions(+), 1 deletion(-) diff --git a/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc b/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc index 80cf88b3ceb..17f74cd3743 100644 --- a/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/meshgrid_grad_kernel.cu.cc @@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(meshgrid_grad, GPU, ALL_LAYOUT, phi::MeshgridGradKernel, + phi::dtype::float16, float, double, int, diff --git a/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc b/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc index c8635509794..73120c13916 100644 --- a/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc +++ b/paddle/phi/kernels/gpu/meshgrid_kernel.cu.cc @@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(meshgrid, GPU, ALL_LAYOUT, phi::MeshgridKernel, + phi::dtype::float16, float, double, int, diff --git a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py index 60af417ebc5..0039d4ee422 100644 --- a/python/paddle/fluid/tests/unittests/test_meshgrid_op.py +++ b/python/paddle/fluid/tests/unittests/test_meshgrid_op.py @@ -76,6 +76,14 @@ class TestMeshgridOp2(TestMeshgridOp): return [100, 300] +class TestMeshgridOp2Fp16(TestMeshgridOp): + def get_x_shape(self): + return [100, 300] + + def get_dtype(self): + return np.float16 + + class TestMeshgridOp3(unittest.TestCase): def test_api(self): x = paddle.static.data(shape=[100], dtype='int32', name='x') diff --git a/python/paddle/tensor/creation.py b/python/paddle/tensor/creation.py index 186eda03e74..c57fceeeb85 100644 --- a/python/paddle/tensor/creation.py +++ b/python/paddle/tensor/creation.py @@ -1512,7 +1512,7 @@ def meshgrid(*args, **kwargs): Args: *args(Tensor|list of Tensor) : tensors (tuple(list) of tensor): the shapes of input k tensors are (N1,), - (N2,),..., (Nk,). Support data types: ``float64``, ``float32``, ``int32``, ``int64``. + (N2,),..., (Nk,). Support data types: ``float64``, ``float16``, ``float32``, ``int32``, ``int64``. **kwargs (optional): Currently, only accept name in **kwargs The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`. -- GitLab