From 19902a1291c3f07f9671324ee9e6b42f56f64fbc Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Fri, 8 Jul 2022 14:16:03 +0800 Subject: [PATCH] unsqueeze2 support fp16. test=kunlun (#44142) --- paddle/fluid/platform/device/xpu/xpu2_op_list.h | 6 ++++-- .../fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 204cb001504..2fa287b80f4 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -482,7 +482,8 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::UINT8, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, + pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"unsqueeze2", XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()), @@ -490,7 +491,8 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::BOOL, XPUPlace()), pOpKernelType(vartype::INT8, XPUPlace()), pOpKernelType(vartype::UINT8, XPUPlace()), - pOpKernelType(vartype::FP32, XPUPlace())})}, + pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP16, XPUPlace())})}, {"where_index", XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::BOOL, XPUPlace()), diff --git a/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py index e9fc66ca4fc..8ba7f681888 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_unsqueeze2_op_xpu.py @@ -69,7 +69,7 @@ class XPUTestUnsqueeze2Op(XPUOpTestWrapper): def test_check_grad(self): place = paddle.XPUPlace(0) - if self.dtype in [np.float32, np.float64]: + if self.dtype in [np.float32, np.float64, np.float16]: self.check_grad_with_place(place, ['X'], 'Out') elif self.dtype == np.bool_: return @@ -147,7 +147,7 @@ class XPUTestUnsqueeze2Op(XPUOpTestWrapper): def test_check_grad(self): place = paddle.XPUPlace(0) - if self.dtype in [np.float32, np.float64]: + if self.dtype in [np.float32, np.float64, np.float16]: self.check_grad_with_place(place, ['X'], 'Out') else: return @@ -217,7 +217,7 @@ class XPUTestUnsqueeze2Op(XPUOpTestWrapper): def test_check_grad(self): place = paddle.XPUPlace(0) - if self.dtype in [np.float32, np.float64]: + if self.dtype in [np.float32, np.float64, np.float16]: self.check_grad_with_place(place, ['X'], 'Out') else: return -- GitLab