From 4b85e5db98644fcf7638a00cfa506b8adbe98f50 Mon Sep 17 00:00:00 2001 From: wz1qqx <55830058+wz1qqx@users.noreply.github.com> Date: Fri, 19 May 2023 16:08:08 +0800 Subject: [PATCH] [XPU] fix fallback (#53801) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 6 ++++-- paddle/phi/kernels/xpu/conv_transpose_kernel.cc | 3 ++- paddle/phi/kernels/xpu/split_kernel.cc | 11 ++++++++--- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index ba9742a98d0..87fe9401baf 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -187,7 +187,8 @@ XPUOpMap& get_kl2_ops() { {"depthwise_conv2d", XPUKernelSet({phi::DataType::FLOAT32})}, {"depthwise_conv2d_transpose_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"depthwise_conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32})}, + {"depthwise_conv2d_transpose", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"diag_v2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, @@ -709,7 +710,8 @@ XPUOpMap& get_kl2_ops() { {"split", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, - phi::DataType::INT32})}, + phi::DataType::INT32, + phi::DataType::INT64})}, {"split_with_num", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, diff --git a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc index f658f06a990..f6166ff61f7 100644 --- a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc @@ -177,7 +177,8 @@ PD_REGISTER_KERNEL(depthwise_conv2d_transpose, XPU, ALL_LAYOUT, phi::DepthwiseConv2dTransposeKernel, - float) {} + float, + phi::dtype::float16) {} PD_REGISTER_KERNEL(conv2d_transpose, XPU, diff --git a/paddle/phi/kernels/xpu/split_kernel.cc b/paddle/phi/kernels/xpu/split_kernel.cc index 6a9bf54876c..f94dbaac037 100644 --- a/paddle/phi/kernels/xpu/split_kernel.cc +++ b/paddle/phi/kernels/xpu/split_kernel.cc @@ -66,9 +66,14 @@ void SplitWithNumKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - split, XPU, ALL_LAYOUT, phi::SplitKernel, float, int, phi::dtype::float16) { -} +PD_REGISTER_KERNEL(split, + XPU, + ALL_LAYOUT, + phi::SplitKernel, + float, + int64_t, + int, + phi::dtype::float16) {} PD_REGISTER_KERNEL(split_with_num, XPU, ALL_LAYOUT, -- GitLab