未验证 提交 77395619 编写于 作者: H houj04 提交者: GitHub

[XPU] add int64 support for slice and subtract. (#47409)

* [XPU] add int64 support for slice and subtract. test=kunlun

* try to fix xpu compile. test=kunlun

* try to fix xpu compile. test=kunlun

* try to fix xpu compile. test=kunlun

* remove unnecessary modification. test=kunlun
上级 75b73400
...@@ -178,7 +178,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -178,7 +178,8 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_sub", {"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"elementwise_mod", {"elementwise_mod",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()),
...@@ -497,7 +498,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -497,7 +498,8 @@ XPUOpMap& get_kl2_ops() {
{"slice", {"slice",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})}, pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"softmax", {"softmax",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
......
...@@ -342,7 +342,8 @@ PD_REGISTER_KERNEL(subtract, ...@@ -342,7 +342,8 @@ PD_REGISTER_KERNEL(subtract,
ALL_LAYOUT, ALL_LAYOUT,
phi::SubtractKernel, phi::SubtractKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16,
int64_t) {}
#endif #endif
#if defined PADDLE_WITH_XPU #if defined PADDLE_WITH_XPU
......
...@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(subtract_raw, ...@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(subtract_raw,
ALL_LAYOUT, ALL_LAYOUT,
phi::SubtractRawKernel, phi::SubtractRawKernel,
float, float,
phi::dtype::float16) {} phi::dtype::float16,
int64_t) {}
...@@ -113,4 +113,5 @@ PD_REGISTER_KERNEL(slice, ...@@ -113,4 +113,5 @@ PD_REGISTER_KERNEL(slice,
phi::SliceRawKernel, phi::SliceRawKernel,
float, float,
int, int,
phi::dtype::float16) {} phi::dtype::float16,
int64_t) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册