未验证 提交 9127cc3c 编写于 作者: D denglianbin 提交者: GitHub

【Hackathon No.48】为 Paddle meshgrid 算子实现 float16 数据类型支持 (#53284)

上级 1d549400
......@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(meshgrid_grad,
GPU,
ALL_LAYOUT,
phi::MeshgridGradKernel,
phi::dtype::float16,
float,
double,
int,
......
......@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(meshgrid,
GPU,
ALL_LAYOUT,
phi::MeshgridKernel,
phi::dtype::float16,
float,
double,
int,
......
......@@ -76,6 +76,14 @@ class TestMeshgridOp2(TestMeshgridOp):
return [100, 300]
class TestMeshgridOp2Fp16(TestMeshgridOp):
def get_x_shape(self):
return [100, 300]
def get_dtype(self):
return np.float16
class TestMeshgridOp3(unittest.TestCase):
def test_api(self):
x = paddle.static.data(shape=[100], dtype='int32', name='x')
......
......@@ -1512,7 +1512,7 @@ def meshgrid(*args, **kwargs):
Args:
*args(Tensor|list of Tensor) : tensors (tuple(list) of tensor): the shapes of input k tensors are (N1,),
(N2,),..., (Nk,). Support data types: ``float64``, ``float32``, ``int32``, ``int64``.
(N2,),..., (Nk,). Support data types: ``float64``, ``float16``, ``float32``, ``int32``, ``int64``.
**kwargs (optional): Currently, only accept name in **kwargs
The default value is None. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册