diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index ba9742a98d01816d8f140fe9212d51ff06718224..87fe9401bafe5f1e7446feada76ae8d44278af12 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -187,7 +187,8 @@ XPUOpMap& get_kl2_ops() { {"depthwise_conv2d", XPUKernelSet({phi::DataType::FLOAT32})}, {"depthwise_conv2d_transpose_grad", XPUKernelSet({phi::DataType::FLOAT32})}, - {"depthwise_conv2d_transpose", XPUKernelSet({phi::DataType::FLOAT32})}, + {"depthwise_conv2d_transpose", + XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"diag_v2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, @@ -709,7 +710,8 @@ XPUOpMap& get_kl2_ops() { {"split", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, - phi::DataType::INT32})}, + phi::DataType::INT32, + phi::DataType::INT64})}, {"split_with_num", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, diff --git a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc index f658f06a9908d91bec35d491fbadb377e6c45701..f6166ff61f7233ef3d9bf0b077e2024cf20f6c27 100644 --- a/paddle/phi/kernels/xpu/conv_transpose_kernel.cc +++ b/paddle/phi/kernels/xpu/conv_transpose_kernel.cc @@ -177,7 +177,8 @@ PD_REGISTER_KERNEL(depthwise_conv2d_transpose, XPU, ALL_LAYOUT, phi::DepthwiseConv2dTransposeKernel, - float) {} + float, + phi::dtype::float16) {} PD_REGISTER_KERNEL(conv2d_transpose, XPU, diff --git a/paddle/phi/kernels/xpu/split_kernel.cc b/paddle/phi/kernels/xpu/split_kernel.cc index 6a9bf54876c6a3b9ac22469b9d829ae32dc89831..f94dbaac037c70671291c24c0a6ab9d986c2fcaf 100644 --- a/paddle/phi/kernels/xpu/split_kernel.cc +++ b/paddle/phi/kernels/xpu/split_kernel.cc @@ -66,9 +66,14 @@ void SplitWithNumKernel(const Context& dev_ctx, } // namespace phi -PD_REGISTER_KERNEL( - split, XPU, ALL_LAYOUT, phi::SplitKernel, float, int, phi::dtype::float16) { -} +PD_REGISTER_KERNEL(split, + XPU, + ALL_LAYOUT, + phi::SplitKernel, + float, + int64_t, + int, + phi::dtype::float16) {} PD_REGISTER_KERNEL(split_with_num, XPU, ALL_LAYOUT,