未验证 提交 19902a12 编写于 作者: H houj04 提交者: GitHub

unsqueeze2 support fp16. test=kunlun (#44142)

上级 7be637a7
......@@ -482,7 +482,8 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"unsqueeze2",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
......@@ -490,7 +491,8 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::UINT8, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"where_index",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
......
......@@ -69,7 +69,7 @@ class XPUTestUnsqueeze2Op(XPUOpTestWrapper):
def test_check_grad(self):
place = paddle.XPUPlace(0)
if self.dtype in [np.float32, np.float64]:
if self.dtype in [np.float32, np.float64, np.float16]:
self.check_grad_with_place(place, ['X'], 'Out')
elif self.dtype == np.bool_:
return
......@@ -147,7 +147,7 @@ class XPUTestUnsqueeze2Op(XPUOpTestWrapper):
def test_check_grad(self):
place = paddle.XPUPlace(0)
if self.dtype in [np.float32, np.float64]:
if self.dtype in [np.float32, np.float64, np.float16]:
self.check_grad_with_place(place, ['X'], 'Out')
else:
return
......@@ -217,7 +217,7 @@ class XPUTestUnsqueeze2Op(XPUOpTestWrapper):
def test_check_grad(self):
place = paddle.XPUPlace(0)
if self.dtype in [np.float32, np.float64]:
if self.dtype in [np.float32, np.float64, np.float16]:
self.check_grad_with_place(place, ['X'], 'Out')
else:
return
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册