From 9c1eb98ad6c979ec9f9b1d6e170843b858c60272 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Thu, 27 Apr 2023 16:07:36 +0800 Subject: [PATCH] [XPU] c_sync_calc_stream support more types (#53389) --- .../operators/collective/c_sync_calc_stream_op_xpu.cc | 11 +++++++++-- paddle/phi/backends/xpu/xpu2_op_list.cc | 7 ++++++- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc index 0b432cab281..24157f1c64a 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc @@ -17,5 +17,12 @@ limitations under the License. */ namespace ops = paddle::operators; namespace plat = paddle::platform; -PD_REGISTER_STRUCT_KERNEL( - c_sync_calc_stream, XPU, ALL_LAYOUT, ops::CSyncCalcStreamKernel, float) {} +PD_REGISTER_STRUCT_KERNEL(c_sync_calc_stream, + XPU, + ALL_LAYOUT, + ops::CSyncCalcStreamKernel, + float, + double, + int, + int64_t, + plat::float16) {} diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 43461e696da..303963450b1 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -115,7 +115,12 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32, phi::DataType::INT32})}, - {"c_sync_calc_stream", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_sync_calc_stream", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, {"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})}, {"cast", XPUKernelSet({phi::DataType::FLOAT32, -- GitLab