未验证 提交 77395619 编写于 作者: H houj04 提交者: GitHub

[XPU] add int64 support for slice and subtract. (#47409)

* [XPU] add int64 support for slice and subtract. test=kunlun

* try to fix xpu compile. test=kunlun

* try to fix xpu compile. test=kunlun

* try to fix xpu compile. test=kunlun

* remove unnecessary modification. test=kunlun
上级 75b73400
......@@ -178,7 +178,8 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_sub",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"elementwise_mod",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
......@@ -497,7 +498,8 @@ XPUOpMap& get_kl2_ops() {
{"slice",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"softmax",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
......
......@@ -342,7 +342,8 @@ PD_REGISTER_KERNEL(subtract,
ALL_LAYOUT,
phi::SubtractKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
int64_t) {}
#endif
#if defined PADDLE_WITH_XPU
......
......@@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(subtract_raw,
ALL_LAYOUT,
phi::SubtractRawKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
int64_t) {}
......@@ -113,4 +113,5 @@ PD_REGISTER_KERNEL(slice,
phi::SliceRawKernel,
float,
int,
phi::dtype::float16) {}
phi::dtype::float16,
int64_t) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册