diff --git a/paddle/fluid/framework/tensor_util.h b/paddle/fluid/framework/tensor_util.h
index d1dc5e45c2d8c77322b4764334a36d99e1419888..3466b33c828d71df1e350afa22d295c327c41b96 100644
--- a/paddle/fluid/framework/tensor_util.h
+++ b/paddle/fluid/framework/tensor_util.h
@@ -189,6 +189,11 @@ void TensorFromArray(const T* src,
         size,
         reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
   }
+#endif
+#ifdef PADDLE_WITH_XPU
+  else if (platform::is_xpu_place(dst_place)) {  // NOLINT
+    memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+  }
 #endif
   else {  // NOLINT
     PADDLE_THROW(platform::errors::Unimplemented(
diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
index 01fd563f8e3c59f129b5d1e935bd9ec51fe5d898..8f487cf6cd72eedd41f5b27dc112361171092151 100644
--- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h
@@ -471,6 +471,7 @@ XPUOpMap& get_kl2_ops() {
                    pOpKernelType(vartype::INT64, XPUPlace())})},
     {"scatter",
      XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
+                   pOpKernelType(vartype::INT32, XPUPlace()),
                    pOpKernelType(vartype::FP32, XPUPlace())})},
     {"sampling_id",
      XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
diff --git a/paddle/phi/kernels/assign_kernel.cc b/paddle/phi/kernels/assign_kernel.cc
index 77b9fbc0e1628ac031441dda67246e7e92948ef0..e9f2c7547d4cff85fe280fa970cf86fb617cc62d 100644
--- a/paddle/phi/kernels/assign_kernel.cc
+++ b/paddle/phi/kernels/assign_kernel.cc
@@ -174,4 +174,12 @@ PD_REGISTER_GENERAL_KERNEL(assign_array,
                            ALL_LAYOUT,
                            phi::AssignArrayKernel,
                            ALL_DTYPE) {}
+PD_REGISTER_KERNEL(assign_value,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::AssignValueKernel,
+                   bool,
+                   int,
+                   float,
+                   int64_t) {}
 #endif
diff --git a/paddle/phi/kernels/batch_norm_kernel.cc b/paddle/phi/kernels/batch_norm_kernel.cc
index a0de7842b9e0d417296ab7e965397a691041a679..623b4c1cc745b3b2b214675913a6ff9a59ac1df0 100644
--- a/paddle/phi/kernels/batch_norm_kernel.cc
+++ b/paddle/phi/kernels/batch_norm_kernel.cc
@@ -88,3 +88,7 @@ PD_REGISTER_KERNEL(batch_norm_infer,
                    float,
                    phi::dtype::float16) {}
 #endif
+#ifdef PADDLE_WITH_XPU
+PD_REGISTER_KERNEL(
+    batch_norm_infer, XPU, ALL_LAYOUT, phi::BatchNormInferKernel, float) {}
+#endif
diff --git a/paddle/phi/kernels/empty_kernel.cc b/paddle/phi/kernels/empty_kernel.cc
index 01b07c438a5270d0290a7cccd3a71401e8311388..6f7500b41f5668ad02c187dcb084e6c7677ea9a2 100644
--- a/paddle/phi/kernels/empty_kernel.cc
+++ b/paddle/phi/kernels/empty_kernel.cc
@@ -126,4 +126,19 @@ PD_REGISTER_KERNEL(empty,
                    int64_t,
                    bool,
                    phi::dtype::float16) {}
+PD_REGISTER_KERNEL(empty_like,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::EmptyLikeKernel,
+                   float,
+                   double,
+                   int8_t,
+                   uint8_t,
+                   int16_t,
+                   int,
+                   int64_t,
+                   bool,
+                   phi::dtype::float16) {
+  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
 #endif
diff --git a/paddle/phi/kernels/xpu/stack_grad_kernel.cc b/paddle/phi/kernels/xpu/stack_grad_kernel.cc
index 59319d3e7624ce3ca4fc29531e8f21e3414a6162..719aabae373968120ff1f415ab5d5487be9c38b3 100644
--- a/paddle/phi/kernels/xpu/stack_grad_kernel.cc
+++ b/paddle/phi/kernels/xpu/stack_grad_kernel.cc
@@ -28,7 +28,7 @@ void StackGradKernel(const Context& dev_ctx,
   auto outs = x_grad;
   auto dy_dims = out.dims();
 
-  if (axis < 0) axis += dy_dims.size() + 1;
+  if (axis < 0) axis += dy_dims.size();
   auto dy_shape = phi::vectorize<int>(dy_dims);
 
   std::vector<int> dx_dims_list(x_grad.size(), 1);
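The stack_grad change above fixes an off-by-one in negative-axis normalization: the backward kernel receives the stacked output dy, whose rank already includes the dimension added by stack, so a negative axis must be shifted by dy_dims.size() rather than dy_dims.size() + 1 (the "+ 1" is only correct when normalizing against the rank of the inputs, as the forward kernel does). A minimal standalone sketch of that invariant, using a hypothetical helper rather than any Paddle API:

#include <cassert>

// Hypothetical standalone check of the normalization fixed above. For a
// stack over inputs of rank R, the output dy has rank R + 1, so in the
// backward pass a negative axis is brought into range by adding the output
// rank itself; adding "+ 1" on top (the pre-fix behavior) would produce an
// out-of-range axis.
int NormalizeStackGradAxis(int axis, int dy_rank) {
  if (axis < 0) axis += dy_rank;
  assert(axis >= 0 && axis < dy_rank);
  return axis;
}

int main() {
  // Stacking rank-2 inputs along axis -1 produces a rank-3 dy;
  // -1 must normalize to 2 (the last axis of dy), not 3.
  assert(NormalizeStackGradAxis(-1, 3) == 2);
  return 0;
}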
diff --git a/paddle/phi/kernels/xpu/top_k_kernel.cc b/paddle/phi/kernels/xpu/top_k_kernel.cc
index d68ff8df8c02e622795335c988d715f7de34ea07..411b74928d0a2a3fd2fffe4a45ddf324eaac7e0e 100644
--- a/paddle/phi/kernels/xpu/top_k_kernel.cc
+++ b/paddle/phi/kernels/xpu/top_k_kernel.cc
@@ -135,11 +135,11 @@ void TopkKernel(const Context& dev_ctx,
     // Transpose back to original dims
     std::vector<int> trans_back_axes;
     for (int i = 0; i < axis; i++) {
-      trans_axes.emplace_back(i);
+      trans_back_axes.emplace_back(i);
     }
-    trans_axes.emplace_back(trans_out_dims.size() - 1);
+    trans_back_axes.emplace_back(trans_out_dims.size() - 1);
     for (int i = axis; i < trans_out_dims.size() - 1; i++) {
-      trans_axes.emplace_back(i);
+      trans_back_axes.emplace_back(i);
     }
 
     std::vector<int> trans_out_shape_host(trans_back_axes.size(), 0);
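The top_k change corrects the construction of the inverse permutation used to transpose results back after topk runs on the last axis: the loops were appending to trans_axes (the forward permutation) instead of trans_back_axes, leaving the back-permutation empty. A minimal sketch of the corrected logic as a standalone function, with a hypothetical name and signature rather than the kernel's actual interface:

#include <cassert>
#include <vector>

// Hypothetical standalone version of the loops fixed above: build the
// permutation that undoes "move `axis` to the last position" for a tensor
// of rank `rank`. Pre-fix, the emplace_back calls targeted trans_axes, so
// trans_back_axes stayed empty.
std::vector<int> BuildTransBackAxes(int axis, int rank) {
  std::vector<int> trans_back_axes;
  for (int i = 0; i < axis; i++) {
    trans_back_axes.emplace_back(i);
  }
  trans_back_axes.emplace_back(rank - 1);  // last dim returns to `axis`
  for (int i = axis; i < rank - 1; i++) {
    trans_back_axes.emplace_back(i);
  }
  return trans_back_axes;
}

int main() {
  // topk along axis 1 of a rank-3 tensor: the forward transpose is {0, 2, 1},
  // which is its own inverse, so the back-permutation is also {0, 2, 1}.
  const std::vector<int> expected{0, 2, 1};
  assert(BuildTransBackAxes(1, 3) == expected);
  return 0;
}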