diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc index 83da712bee90881120ee09fc6fad56f7a6a2615a..71ab25a7b0ff8a490d7de0022f810009a58482d4 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc @@ -46,7 +46,7 @@ Call calculation stream synchronization. }; template -class CSyncCalcStreamCudaKernel : public framework::OpKernel { +class CSyncCalcStreamKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32) @@ -86,5 +86,6 @@ namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp, ops::CSyncCalcStreamOpMaker); -REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, - ops::CSyncCalcStreamCudaKernel); +REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); + +REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op_npu_test.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op_npu_test.cc index 4b1f7bb340178748d302f9ec5a5c987a25dae2e3..45613715b8260c3f38968e5cd91f245cd9f524d5 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op_npu_test.cc @@ -35,7 +35,7 @@ namespace m = paddle::operators::math; USE_OP(elementwise_add); USE_OP_DEVICE_KERNEL(elementwise_add, NPU); -USE_NO_KERNEL_OP(c_sync_calc_stream); +USE_OP_DEVICE_KERNEL(c_sync_calc_stream, NPU); template void Compare(f::Scope* scope, const p::DeviceContext& ctx) { diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 772122bb58d60f14f2b28d0b2483f75ec4a0dd8d..71fda2cd01c8d6007cab19ebeea365467e8e7a99 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -58,7 +58,7 @@ Call communication stream synchronization. }; template -class CSyncCommStreamCudaKernel : public framework::OpKernel { +class CSyncCommStreamKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto place = ctx.GetPlace(); @@ -97,5 +97,6 @@ namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp, ops::CSyncCommStreamOpMaker); -REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, - ops::CSyncCommStreamCudaKernel); +REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); + +REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc index 3915ec4fa35e8bfbf77095e5afff102d2d924d4d..6c5a6db61483dcd7e3578ded6a12a8a421ca1933 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc @@ -43,7 +43,7 @@ namespace p = paddle::platform; namespace m = paddle::operators::math; USE_OP(c_broadcast); -USE_NO_KERNEL_OP(c_sync_comm_stream); +USE_OP_DEVICE_KERNEL(c_sync_comm_stream, NPU); USE_NO_KERNEL_OP(c_gen_hccl_id); USE_NO_KERNEL_OP(c_comm_init_hccl); USE_OP_DEVICE_KERNEL(c_broadcast, NPU);