From 1bfbcfaf5ad9e737b19556cdbfa2b068ec9439b2 Mon Sep 17 00:00:00 2001 From: Fan Zhang Date: Thu, 2 Jun 2022 15:54:10 +0800 Subject: [PATCH] [XPUPS] modify BKCL comm op register (#43028) * Adapt XPUPS - 1st version - 3.24 * Adapt XPUPS - update XPU PushSparse - 2nd version - 3.24 * Adapt XPUPS - add XPU PullSparseOp - 3nd version - 3.25 * refactor heter comm kernel * update. test=develop * Adapt XPUPS - modify by compilation - 4th version - 3.27 * update calc_shard_offset. test=develop * update xpu kernel. test=develop * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * heter_comm update * heter_comm update * update calc_shard_offset. test=develop * heter_comm update * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * fix. test=develop * update. test=develop * update. test=develop * update optimizer kernel * Adapt XPUPS - use WITH_XPU_KP and modify wrapper kernel function - 5th version - 3.30 * update. test=develop * update pslib.cmake * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * Adapt XPUPS - modify by kp compilation - 6th version - 3.30 * update. test=develop * update. test=develop * update. test=develop * update optimizer kernel * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * fix. test=develop * fix. test=develop * used by minxu * update heter_comm_inl * fix. test=develop * Adapt XPUPS - modify by kp compilation - 7th version - 3.30 * fix. test=develop * add optimizer kernel. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 3.31 update * Adapt XPUPS - update kp compilation path - 8th version - 3.31 * add optimizer kernel. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update heter_comm_kernel.kps 3.31 * fix. test=develop * fix. test=develop * update heter_comm_kernel.kps 3.31 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update heter_comm.h 3.31 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update hashtable. test=develop * update. test=develop * Adapt XPUPS - update by kp compilation - 9th version - 4.1 * update hashtable. test=develop * fix. test=develop * update hashtable 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 10th version - 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * update. test=develop * modify by compilation 4.1 * update. test=develop * update. test=develop * fix. test=develop * modify by compilation 4.1 * update. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * modify by compilation 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * modify by compilation 4.1 19:30 * fix. test=develop * update ps_gpu_wrapper.kps 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. 
test=develop * Adapt XPUPS - update by kp compilation - 11th version - 4.1 * fix. test=develop * Adapt XPUPS - update by kp compilation - 12nd version - 4.2 * fix. test=develop * fix. test=develop * modify by compilation 4.2 * 4.2 update * fix. test=develop * template init. test=develop * update 4.6 * fix. test=develop * template init. test=develop * 4.6 modify by compilation * hashtable template init. test=develop * hashtable template init. test=develop * fix. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 13nd version - 4.7 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 4.11 update * fix. test=develop * fix. test=develop * 4.11 update * update by pre-commit * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 4.12 update * fix. test=develop * Adapt XPUPS - update by kp compilation - 14th version - 4.13 * 4.13 update * 4.14 update * 4.14 update * 4.14 update * 4.14 modify by merged latest compilation * retry CI 4.14 * 4.15 pass static check * 4.15 modify by gpups CI * 3.16 update by gpups CI - modify ps_gpu_wrapper.h * 4.16 update * 4.16 pass xpu compile * 4.16 retry CI * 4.16 update * Adapt XPUPS - adapt BKCL comm for XPUPS - 4.24 * update by compilation * Adapt XPUPS - register PSGPUTrainer for XPUPS - 4.25 * update device_worker_factory * Adapt XPUPS - split heter_ps into .cu and .cc - 4.27 * Adapt XPUPS - register pull_box_sparse op under XPU_KP - 4.28 * update * 5.7 modify ps_gpu_wrapper pull_sparse * 5.11 update ps_gpu_wrapper CopyKeysKernel * 5.13 modify calc_shard_offset_kernel & fill_shard_key_kernel * modify fill_dvals_kernel & PullCopy & c_sync_calc_stream - 5.18 * modify PushCopy & fill_shard_grads_kernel & register push_box_sparse - 5.19 * Adapt XPUPS - modify BKCL comm op register - 5.26 * Adapt XPUPS - modify BKCL comm op register - 5.27 * Adapt XPUPS - modify BKCL comm op register - 5.27v2 * Adapt XPUPS - modify BKCL comm op register - 5.27v3 * Adapt XPUPS - modify c_comm_init_all_op to adapt BKCL init - 5.30 * Adapt XPUPS - modify c_comm_init_all_op to adapt BKCL init v2 - 5.30 * Adapt XPUPS - modify c_comm_init_all_op to adapt BKCL init v3 - 5.30 * Adapt XPUPS - modify c_comm_init_all_op to adapt BKCL init v4 - 5.31 Co-authored-by: zmxdream --- ...e_sum_op_xpu.cc => c_allreduce_sum_op.kps} | 27 +++++- .../collective/c_comm_init_all_op.cc | 58 ++++++++++-- .../operators/collective/c_comm_init_op.cc | 9 -- .../collective/c_sync_calc_stream_op.cc | 1 - .../collective/c_sync_calc_stream_op.kps | 42 +++++++++ .../collective/c_sync_comm_stream_op.cc | 73 +-------------- .../collective/c_sync_comm_stream_op.h | 88 +++++++++++++++++++ .../collective/c_sync_comm_stream_op.kps | 42 +++++++++ .../platform/device/xpu/xpu_op_kpfirst_list.h | 6 ++ python/paddle/fluid/transpiler/collective.py | 34 +++---- 10 files changed, 266 insertions(+), 114 deletions(-) rename paddle/fluid/operators/collective/{c_allreduce_sum_op_xpu.cc => c_allreduce_sum_op.kps} (58%) create mode 100644 paddle/fluid/operators/collective/c_sync_calc_stream_op.kps create mode 100644 paddle/fluid/operators/collective/c_sync_comm_stream_op.h create mode 100644 
paddle/fluid/operators/collective/c_sync_comm_stream_op.kps diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op.kps similarity index 58% rename from paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc rename to paddle/fluid/operators/collective/c_allreduce_sum_op.kps index d23572e6d6..3230d2c9ec 100644 --- a/paddle/fluid/operators/collective/c_allreduce_sum_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op.kps @@ -1,4 +1,4 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,10 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef PADDLE_WITH_XPU_KP + +// Please do not modify the following code +#if defined(__CUDA_ARCH__) +#undef __CUDA_ARCH__ +#endif + +#if defined(__CUDACC__) +#undef __CUDACC__ +#endif + +#if defined(__CUDA__) +#undef __CUDA__ +#endif + +#if defined(__NVCC__) +#undef __NVCC__ +#endif + #include "paddle/fluid/operators/collective/c_allreduce_op.h" namespace ops = paddle::operators; namespace plat = paddle::platform; -REGISTER_OP_XPU_KERNEL(c_allreduce_sum, - ops::CAllReduceOpXPUKernel<ops::kRedSum, float>) +REGISTER_OP_KERNEL(c_allreduce_sum, KP, plat::XPUPlace, + ops::CAllReduceOpXPUKernel<ops::kRedSum, float>); + +#endif diff --git a/paddle/fluid/operators/collective/c_comm_init_all_op.cc b/paddle/fluid/operators/collective/c_comm_init_all_op.cc index 5820bd318d..ce2da1f22f 100644 --- a/paddle/fluid/operators/collective/c_comm_init_all_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_all_op.cc @@ -17,11 +17,16 @@ limitations under the License.
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/collective_helper.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #include "paddle/fluid/platform/device/gpu/nccl_helper.h" #endif +#if defined(PADDLE_WITH_XPU_BKCL) +#include "paddle/fluid/platform/device/xpu/bkcl_helper.h" +#endif + namespace paddle { namespace framework { class InferShapeContext; @@ -48,9 +53,9 @@ class CCommInitAllOp : public framework::OperatorBase { void RunImpl(const framework::Scope& scope, const platform::Place& place) const override { - PADDLE_ENFORCE_EQ(platform::is_gpu_place(place), true, - platform::errors::PreconditionNotMet( - "CCommInitAllOp can run on gpu place only")); +// PADDLE_ENFORCE_EQ(platform::is_gpu_place(place), true, +// platform::errors::PreconditionNotMet( +// "CCommInitAllOp can run on gpu place only")); #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) std::vector<int> devices = Attr<std::vector<int>>("devices"); @@ -61,9 +66,52 @@ class CCommInitAllOp : public framework::OperatorBase { int rid = Attr<int>("ring_id"); platform::NCCLCommContext::Instance().CreateAllNCCLComms(devices, rid); + +#elif defined(PADDLE_WITH_XPU_BKCL) + std::vector<int> devices = Attr<std::vector<int>>("devices"); + int ring_id = Attr<int>("ring_id"); + + if (devices.empty()) { + int count = platform::GetXPUDeviceCount(); + for (int i = 0; i < count; ++i) { + devices.push_back(i); + } + } + + if (devices.size() > 1) { + std::vector<platform::Place> place_list_; + for (size_t i = 0; i < devices.size(); ++i) { + auto p = platform::XPUPlace(devices[i]); + place_list_.push_back(p); + } + + // create pthread to bkcl_init_rank on all devices + auto ptr = new platform::BKCLContextMap(place_list_); + ptr->init(); + + for (size_t i = 0; i < devices.size(); ++i) { + platform::BKCLCommContext::Instance().AssignBKCLComm( + ptr->contexts_.at(devices[i]).comm_, devices.size(), devices[i], + devices[i], ring_id); + + VLOG(0) << "bkcl communicator of rank " << devices[i] << " in ring " + << ring_id << " has been created on device " << devices[i]; + + // TODO(WorgenZhang): need release comm_map_ when quit + // std::call_once(once_flag_, []() { + // std::atexit([]() { + // platform::BKCLCommContext::Instance().ReleaseBKCLComms(); }); + // }); + } + + VLOG(0) << "done bkcl_init_rank on all devices"; + } else { + VLOG(0) + << "bkcl_init_rank doesn't support on one device, skip init process"; + } #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should compile with GPU.")); + "PaddlePaddle should compile with GPU or XPU.")); #endif } }; diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc index 82d3b1b1db..490747520d 100644 --- a/paddle/fluid/operators/collective/c_comm_init_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_op.cc @@ -97,18 +97,9 @@ class CCommInitOp : public framework::OperatorBase { if (Attr<int>("device_id") >= 0) { device_id = Attr<int>("device_id"); } - -#if defined(PADDLE_WITH_XPU_BKCL) && defined(PADDLE_WITH_HETERPS) && \ - defined(PADDLE_WITH_PSLIB) - // XPUPS rank_id only equals 0, so replace rank_id with device_id - CommContext::Instance().CreateComm(comm_id, nranks, device_id, device_id, - rid); -#else int rank_id = Attr<int>("rank"); CommContext::Instance().CreateComm(comm_id, nranks, rank_id, device_id, rid); -#endif - #endif } }; diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc index 6ad22ff8b1..bf7434686b 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc @@ -23,7 +23,6 @@ class CSyncCalcStreamOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("Out", "(Tensor) Dependency of the variable need to sync"); AddComment(R"DOC( CSyncCalcStream Operator - Call calculation stream synchronization. )DOC"); } diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.kps b/paddle/fluid/operators/collective/c_sync_calc_stream_op.kps new file mode 100644 index 0000000000..65126f416c --- /dev/null +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.kps @@ -0,0 +1,42 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU_KP + +// Please do not modify the following code +#if defined(__CUDA_ARCH__) +#undef __CUDA_ARCH__ +#endif + +#if defined(__CUDACC__) +#undef __CUDACC__ +#endif + +#if defined(__CUDA__) +#undef __CUDA__ +#endif + +#if defined(__NVCC__) +#undef __NVCC__ +#endif + +#include "paddle/fluid/operators/collective/c_sync_calc_stream_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_KERNEL(c_sync_calc_stream, KP, plat::XPUPlace, + ops::CSyncCalcStreamKernel<float>); + +#endif diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 5a9a00aa8e..a3717459a2 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -11,25 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include <string> - -#include "paddle/fluid/framework/op_registry.h" - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#endif - -#if defined(PADDLE_WITH_ASCEND_CL) -#include "paddle/fluid/platform/device/npu/hccl_helper.h" -#endif - -#if defined(PADDLE_WITH_CNCL) -#include "paddle/fluid/platform/device/mlu/cncl_helper.h" -#endif - -#if defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) -#include "paddle/fluid/platform/collective_helper.h" -#endif +#include "paddle/fluid/operators/collective/c_sync_comm_stream_op.h" namespace paddle { namespace operators { @@ -58,62 +40,11 @@ class CSyncCommStreamOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr<int>("ring_id", "(int default 0) ring id.").SetDefault(0); AddComment(R"DOC( CSyncCommStream Operator - Call communication stream synchronization.
)DOC"); } }; -template <typename T> -class CSyncCommStreamKernel : public framework::OpKernel<T> { - public: - void Compute(const framework::ExecutionContext& ctx) const override { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto place = ctx.GetPlace(); - int ring_id = ctx.Attr<int>("ring_id"); - auto stream = - platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); - - platform::GpuStreamSync(stream); - -#elif defined(PADDLE_WITH_ASCEND_CL) - auto place = ctx.GetPlace(); - PADDLE_ENFORCE_EQ(platform::is_npu_place(place), true, - platform::errors::PreconditionNotMet( - "Sync comm stream op can run on npu place only for " - "now, but we got %s, please check the environment.", - place.DebugString())); - int ring_id = ctx.Attr<int>("ring_id"); - auto stream = - platform::HCCLCommContext::Instance().Get(ring_id, place)->stream(); - platform::NPUStreamSync(stream); - -#elif defined(PADDLE_WITH_CNCL) - auto place = ctx.GetPlace(); - PADDLE_ENFORCE_EQ(platform::is_mlu_place(place), true, - platform::errors::PreconditionNotMet( - "Sync stream op can run on mlu place only for now.")); - int ring_id = ctx.Attr<int>("ring_id"); - auto stream = - platform::CNCLCommContext::Instance().Get(ring_id, place)->stream(); - platform::MLUStreamSync(stream); -#elif defined(PADDLE_WITH_XPU_BKCL) - auto place = ctx.GetPlace(); - PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), true, - platform::errors::PreconditionNotMet( - "Sync stream op can run on xpu place only for now.")); - int ring_id = ctx.Attr<int>("ring_id"); - auto comm_dev_ctx = platform::BKCLCommContext::Instance() - .Get(ring_id, place) - ->dev_context(); - comm_dev_ctx->Wait(); -#else - PADDLE_THROW(platform::errors::PreconditionNotMet( - "PaddlePaddle should compile with GPU.")); -#endif - } -}; - } // namespace operators } // namespace paddle @@ -127,5 +58,3 @@ REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>); REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>); REGISTER_OP_MLU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>); - -REGISTER_OP_XPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>); diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.h b/paddle/fluid/operators/collective/c_sync_comm_stream_op.h new file mode 100644 index 0000000000..f9dec93037 --- /dev/null +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.h @@ -0,0 +1,88 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ +#include <string> + +#include "paddle/fluid/framework/op_registry.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/fluid/platform/device/gpu/nccl_helper.h" +#endif + +#if defined(PADDLE_WITH_ASCEND_CL) +#include "paddle/fluid/platform/device/npu/hccl_helper.h" +#endif + +#if defined(PADDLE_WITH_CNCL) +#include "paddle/fluid/platform/device/mlu/cncl_helper.h" +#endif + +#if defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) +#include "paddle/fluid/platform/collective_helper.h" +#endif + +namespace paddle { +namespace operators { + +template <typename T> +class CSyncCommStreamKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext& ctx) const override { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto place = ctx.GetPlace(); + int ring_id = ctx.Attr<int>("ring_id"); + auto stream = + platform::NCCLCommContext::Instance().Get(ring_id, place)->stream(); + + platform::GpuStreamSync(stream); + +#elif defined(PADDLE_WITH_ASCEND_CL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_npu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync comm stream op can run on npu place only for " + "now, but we got %s, please check the environment.", + place.DebugString())); + int ring_id = ctx.Attr<int>("ring_id"); + auto stream = + platform::HCCLCommContext::Instance().Get(ring_id, place)->stream(); + platform::NPUStreamSync(stream); + +#elif defined(PADDLE_WITH_CNCL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_mlu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on mlu place only for now.")); + int ring_id = ctx.Attr<int>("ring_id"); + auto stream = + platform::CNCLCommContext::Instance().Get(ring_id, place)->stream(); + platform::MLUStreamSync(stream); +#elif defined(PADDLE_WITH_XPU_BKCL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on xpu place only for now.")); + int ring_id = ctx.Attr<int>("ring_id"); + auto comm_dev_ctx = platform::BKCLCommContext::Instance() + .Get(ring_id, place) + ->dev_context(); + comm_dev_ctx->Wait(); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU.")); +#endif + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.kps b/paddle/fluid/operators/collective/c_sync_comm_stream_op.kps new file mode 100644 index 0000000000..bfac7bf5c5 --- /dev/null +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.kps @@ -0,0 +1,42 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
*/ + +#ifdef PADDLE_WITH_XPU_KP + +// Please do not modify the following code +#if defined(__CUDA_ARCH__) +#undef __CUDA_ARCH__ +#endif + +#if defined(__CUDACC__) +#undef __CUDACC__ +#endif + +#if defined(__CUDA__) +#undef __CUDA__ +#endif + +#if defined(__NVCC__) +#undef __NVCC__ +#endif + +#include "paddle/fluid/operators/collective/c_sync_comm_stream_op.h" + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_KERNEL(c_sync_comm_stream, KP, plat::XPUPlace, + ops::CSyncCommStreamKernel<float>); + +#endif diff --git a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h index 778c18146d..452f388f03 100644 --- a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h +++ b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h @@ -113,6 +113,12 @@ XPUOpMap& get_kp_ops() { XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_amax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_amin", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"c_sync_calc_stream", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"c_sync_comm_stream", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, + {"c_allreduce_sum", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, }; return s_xpu_kp_kernels; diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index 95ab446e1d..1ddebad286 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -473,33 +473,14 @@ class MultiThread(GradAllReduce): print( "begin to _transpile_startup_program for single-node in XPU") block = self.startup_program.global_block() - comm_id_var = block.create_var( - name=unique_name.generate('comm_id'), - persistable=True, - type=core.VarDesc.VarType.RAW) block.append_op( - type='c_gen_bkcl_id', - inputs={}, - outputs={'Out': comm_id_var}, + type='c_comm_init_all', attrs={ - 'rank': self.rank, - 'endpoint': self.current_endpoint, - 'other_endpoints': self.other_endpoints, - 'ring_id': 0, - self.op_role_key: OpRole.Forward + 'devices': list( + map(int, + os.getenv("FLAGS_selected_gpus").split(","))), + 'ring_id': 0 }) - block.append_op( - type='c_comm_init', - inputs={'X': comm_id_var}, - outputs={}, - attrs={ - 'nranks': - len(os.getenv("FLAGS_selected_gpus").split(",")), - 'rank': self.rank, - 'ring_id': 0, - self.op_role_key: OpRole.Forward - }) - else: print("begin to _transpile_startup_program for single-node") block = self.startup_program.global_block() @@ -515,6 +496,11 @@ class MultiThread(GradAllReduce): elif self.trans_mode == "fuse_all_reduce": print("begin to transpile in fuse all-reduce mode") self._insert_fuse_allreduce_ops() + elif self.trans_mode == "all_reduce_xpu" and len( + os.getenv("FLAGS_selected_gpus").split(",")) == 1: + print( + "skip transpile in all-reduce-xpu mode when number of devices is only one" + ) else: print("begin to transpile in all-reduce mode") self._insert_allreduce_ops() -- GitLab
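Reviewer note on the new BKCL path: on XPU, c_comm_init_all now sets up every communicator of a single-node job inside one process, replacing the per-rank c_gen_bkcl_id/c_comm_init pair that the Python transpiler used to emit. The sketch below condenses the RunImpl branch added in c_comm_init_all_op.cc; it assumes a PADDLE_WITH_XPU_BKCL build and only uses helpers that already appear in this patch (platform::BKCLContextMap, platform::BKCLCommContext), so treat it as an illustration of the flow rather than standalone compilable code.

// Minimal sketch of the XPU branch of CCommInitAllOp::RunImpl for two devices.
std::vector<int> devices = {0, 1};  // e.g. parsed from FLAGS_selected_gpus="0,1"
int ring_id = 0;

std::vector<platform::Place> places;
for (int dev : devices) {
  places.push_back(platform::XPUPlace(dev));  // one place per selected XPU card
}

// BKCLContextMap starts one thread per place and calls bkcl_init_rank on each;
// the object is kept alive for the whole process (releasing it is still a TODO
// in this patch).
auto* ctx_map = new platform::BKCLContextMap(places);
ctx_map->init();

for (size_t i = 0; i < devices.size(); ++i) {
  // In the single-node XPUPS setup the rank of a communicator equals its
  // device id, which is why the device id is passed for both rank and dev_id.
  platform::BKCLCommContext::Instance().AssignBKCLComm(
      ctx_map->contexts_.at(devices[i]).comm_, devices.size(), devices[i],
      devices[i], ring_id);
}

When only one device is selected, the op skips initialization entirely, which matches the new all_reduce_xpu early-out added to python/paddle/fluid/transpiler/collective.py.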
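Registration pattern for the KP kernels: each of the three collective ops touched here follows the same three steps, condensed below from the hunks above. The kernel lives in a .kps translation unit guarded by PADDLE_WITH_XPU_KP (with the CUDA macros undefined first), it is registered for the KP library on XPUPlace, and the op name is added to get_kp_ops() in xpu_op_kpfirst_list.h so the KP kernel is preferred on XPU. The FP32-only kernel set mirrors what this patch registers and is not a general statement about these ops' dtype support.

#ifdef PADDLE_WITH_XPU_KP
// CUDA macro undefs omitted here; see the "Please do not modify" block above.
#include "paddle/fluid/operators/collective/c_allreduce_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Step 2: register the kernel for the KP library on the XPU place.
REGISTER_OP_KERNEL(c_allreduce_sum, KP, plat::XPUPlace,
                   ops::CAllReduceOpXPUKernel<ops::kRedSum, float>);
#endif

// Step 3: xpu_op_kpfirst_list.h entry so the KP kernel takes precedence on XPU.
// {"c_allreduce_sum",
//  XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},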