From 77395619769b734fbb001a6340ef16ac0e8beeea Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Wed, 2 Nov 2022 10:34:33 +0800 Subject: [PATCH] [XPU] add int64 support for slice and subtract. (#47409) * [XPU] add int64 support for slice and subtract. test=kunlun * try to fix xpu compile. test=kunlun * try to fix xpu compile. test=kunlun * try to fix xpu compile. test=kunlun * remove unnecessary modification. test=kunlun --- paddle/fluid/platform/device/xpu/xpu2_op_list.h | 6 ++++-- paddle/phi/kernels/elementwise_kernel.cc | 3 ++- paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc | 3 ++- paddle/phi/kernels/xpu/slice_kernel.cc | 3 ++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 8f487cf6cd7..73898354dc1 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -178,7 +178,8 @@ XPUOpMap& get_kl2_ops() { pOpKernelType(vartype::FP16, XPUPlace())})}, {"elementwise_sub", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), - pOpKernelType(vartype::FP16, XPUPlace())})}, + pOpKernelType(vartype::FP16, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"elementwise_mod", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), @@ -497,7 +498,8 @@ XPUOpMap& get_kl2_ops() { {"slice", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace()), - pOpKernelType(vartype::INT32, XPUPlace())})}, + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP16, XPUPlace())})}, diff --git a/paddle/phi/kernels/elementwise_kernel.cc b/paddle/phi/kernels/elementwise_kernel.cc index ba58bae0035..88551b34109 100644 --- a/paddle/phi/kernels/elementwise_kernel.cc +++ b/paddle/phi/kernels/elementwise_kernel.cc @@ -342,7 +342,8 @@ PD_REGISTER_KERNEL(subtract, ALL_LAYOUT, phi::SubtractKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int64_t) {} #endif #if defined PADDLE_WITH_XPU diff --git a/paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc b/paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc index 299b5f80d7d..4e18264d713 100644 --- a/paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc +++ b/paddle/phi/kernels/xpu/elementwise_subtract_kernel.cc @@ -36,4 +36,5 @@ PD_REGISTER_KERNEL(subtract_raw, ALL_LAYOUT, phi::SubtractRawKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + int64_t) {} diff --git a/paddle/phi/kernels/xpu/slice_kernel.cc b/paddle/phi/kernels/xpu/slice_kernel.cc index 3d01fae33e1..b30c6908357 100644 --- a/paddle/phi/kernels/xpu/slice_kernel.cc +++ b/paddle/phi/kernels/xpu/slice_kernel.cc @@ -113,4 +113,5 @@ PD_REGISTER_KERNEL(slice, phi::SliceRawKernel, float, int, - phi::dtype::float16) {} + phi::dtype::float16, + int64_t) {} -- GitLab