From 8fd724a5026e9b5da3a68225566ea4861338d9e2 Mon Sep 17 00:00:00 2001
From: Baibaifan <39549453+Baibaifan@users.noreply.github.com>
Date: Fri, 30 Apr 2021 11:24:43 +0800
Subject: [PATCH] add_c_sync_npu_kernel (#32687)

---
 paddle/fluid/operators/collective/c_sync_calc_stream_op.cc | 7 ++++---
 .../operators/collective/c_sync_calc_stream_op_npu_test.cc | 2 +-
 paddle/fluid/operators/collective/c_sync_comm_stream_op.cc | 7 ++++---
 .../operators/collective/c_sync_comm_stream_op_npu_test.cc | 2 +-
 4 files changed, 10 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
index 83da712bee9..71ab25a7b0f 100644
--- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
+++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
@@ -46,7 +46,7 @@ Call calculation stream synchronization.
 };
 
 template <typename T>
-class CSyncCalcStreamCudaKernel : public framework::OpKernel<T> {
+class CSyncCalcStreamKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
 #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
@@ -86,5 +86,6 @@ namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream, ops::CSyncCalcStreamOp,
                              ops::CSyncCalcStreamOpMaker);
 
-REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream,
-                        ops::CSyncCalcStreamCudaKernel<float>);
+REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
+
+REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op_npu_test.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op_npu_test.cc
index 4b1f7bb3401..45613715b82 100644
--- a/paddle/fluid/operators/collective/c_sync_calc_stream_op_npu_test.cc
+++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op_npu_test.cc
@@ -35,7 +35,7 @@ namespace m = paddle::operators::math;
 
 USE_OP(elementwise_add);
 USE_OP_DEVICE_KERNEL(elementwise_add, NPU);
-USE_NO_KERNEL_OP(c_sync_calc_stream);
+USE_OP_DEVICE_KERNEL(c_sync_calc_stream, NPU);
 
 template <typename T>
 void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
index 772122bb58d..71fda2cd01c 100644
--- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
+++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
@@ -58,7 +58,7 @@ Call communication stream synchronization.
 };
 
 template <typename T>
-class CSyncCommStreamCudaKernel : public framework::OpKernel<T> {
+class CSyncCommStreamKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto place = ctx.GetPlace();
@@ -97,5 +97,6 @@ namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(c_sync_comm_stream, ops::CSyncCommStreamOp,
                              ops::CSyncCommStreamOpMaker);
 
-REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream,
-                        ops::CSyncCommStreamCudaKernel<float>);
+REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
+
+REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc
index 3915ec4fa35..6c5a6db6148 100644
--- a/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc
+++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc
@@ -43,7 +43,7 @@ namespace p = paddle::platform;
 namespace m = paddle::operators::math;
 
 USE_OP(c_broadcast);
-USE_NO_KERNEL_OP(c_sync_comm_stream);
+USE_OP_DEVICE_KERNEL(c_sync_comm_stream, NPU);
 USE_NO_KERNEL_OP(c_gen_hccl_id);
 USE_NO_KERNEL_OP(c_comm_init_hccl);
 USE_OP_DEVICE_KERNEL(c_broadcast, NPU);
-- 
GitLab