diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc index 0b432cab281fcf07423e60c264668b1719dfa72b..24157f1c64a6cd9b094896ceb0c918a98a4405f2 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op_xpu.cc @@ -17,5 +17,12 @@ limitations under the License. */ namespace ops = paddle::operators; namespace plat = paddle::platform; -PD_REGISTER_STRUCT_KERNEL( - c_sync_calc_stream, XPU, ALL_LAYOUT, ops::CSyncCalcStreamKernel, float) {} +PD_REGISTER_STRUCT_KERNEL(c_sync_calc_stream, + XPU, + ALL_LAYOUT, + ops::CSyncCalcStreamKernel, + float, + double, + int, + int64_t, + plat::float16) {} diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 43461e696da81695dc7f4e8d80f1da29ba182b90..303963450b19797ee40416a7008dcf80b73bac3a 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -115,7 +115,12 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32, phi::DataType::INT32})}, - {"c_sync_calc_stream", XPUKernelSet({phi::DataType::FLOAT32})}, + {"c_sync_calc_stream", + XPUKernelSet({phi::DataType::FLOAT16, + phi::DataType::FLOAT32, + phi::DataType::FLOAT64, + phi::DataType::INT32, + phi::DataType::INT64})}, {"c_sync_comm_stream", XPUKernelSet({phi::DataType::FLOAT32})}, {"cast", XPUKernelSet({phi::DataType::FLOAT32,