From 0c68ae0c774a710a83f8a730f3d8d482b2090c80 Mon Sep 17 00:00:00 2001 From: Fan Zhang Date: Tue, 26 Apr 2022 20:19:49 +0800 Subject: [PATCH] [cherry-pick] Adapt BKCL comm for XPUPS (#42266) * XPUPS Adaptation (#40991) * Adapt XPUPS - 1st version - 3.24 * Adapt XPUPS - update XPU PushSparse - 2nd version - 3.24 * Adapt XPUPS - add XPU PullSparseOp - 3nd version - 3.25 * refactor heter comm kernel * update. test=develop * Adapt XPUPS - modify by compilation - 4th version - 3.27 * update calc_shard_offset. test=develop * update xpu kernel. test=develop * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * heter_comm update * heter_comm update * update calc_shard_offset. test=develop * heter_comm update * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * fix. test=develop * update. test=develop * update. test=develop * update optimizer kernel * Adapt XPUPS - use WITH_XPU_KP and modify wrapper kernel function - 5th version - 3.30 * update. test=develop * update pslib.cmake * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * Adapt XPUPS - modify by kp compilation - 6th version - 3.30 * update. test=develop * update. test=develop * update. test=develop * update optimizer kernel * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * fix. test=develop * fix. test=develop * used by minxu * update heter_comm_inl * fix. test=develop * Adapt XPUPS - modify by kp compilation - 7th version - 3.30 * fix. test=develop * add optimizer kernel. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 3.31 update * Adapt XPUPS - update kp compilation path - 8th version - 3.31 * add optimizer kernel. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update heter_comm_kernel.kps 3.31 * fix. test=develop * fix. test=develop * update heter_comm_kernel.kps 3.31 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update heter_comm.h 3.31 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update hashtable. test=develop * update. test=develop * Adapt XPUPS - update by kp compilation - 9th version - 4.1 * update hashtable. test=develop * fix. test=develop * update hashtable 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 10th version - 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * update. test=develop * modify by compilation 4.1 * update. test=develop * update. test=develop * fix. test=develop * modify by compilation 4.1 * update. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * modify by compilation 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * modify by compilation 4.1 19:30 * fix. test=develop * update ps_gpu_wrapper.kps 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 11th version - 4.1 * fix. test=develop * Adapt XPUPS - update by kp compilation - 12nd version - 4.2 * fix. test=develop * fix. test=develop * modify by compilation 4.2 * 4.2 update * fix. test=develop * template init. test=develop * update 4.6 * fix. test=develop * template init. test=develop * 4.6 modify by compilation * hashtable template init. test=develop * hashtable template init. test=develop * fix. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 13nd version - 4.7 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 4.11 update * fix. test=develop * fix. test=develop * 4.11 update * update by pre-commit * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 4.12 update * fix. test=develop * Adapt XPUPS - update by kp compilation - 14th version - 4.13 * 4.13 update * 4.14 update * 4.14 update * 4.14 update * 4.14 modify by merged latest compilation * retry CI 4.14 * 4.15 pass static check * 4.15 modify by gpups CI * 3.16 update by gpups CI - modify ps_gpu_wrapper.h * 4.16 update * 4.16 pass xpu compile * 4.16 retry CI * 4.16 update Co-authored-by: zmxdream * modify ps_gpu_wrapper.cc * update * Adapt BKCL comm for XPUPS (#42168) * Adapt XPUPS - 1st version - 3.24 * Adapt XPUPS - update XPU PushSparse - 2nd version - 3.24 * Adapt XPUPS - add XPU PullSparseOp - 3nd version - 3.25 * refactor heter comm kernel * update. test=develop * Adapt XPUPS - modify by compilation - 4th version - 3.27 * update calc_shard_offset. test=develop * update xpu kernel. test=develop * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * heter_comm update * heter_comm update * update calc_shard_offset. test=develop * heter_comm update * update args of calc_shard_offset * update. test=develop * remove customGradMerger * update. test=develop * fix. test=develop * update. test=develop * update. test=develop * update optimizer kernel * Adapt XPUPS - use WITH_XPU_KP and modify wrapper kernel function - 5th version - 3.30 * update. test=develop * update pslib.cmake * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * Adapt XPUPS - modify by kp compilation - 6th version - 3.30 * update. test=develop * update. test=develop * update. test=develop * update optimizer kernel * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * update. test=develop * fix. test=develop * fix. test=develop * used by minxu * update heter_comm_inl * fix. test=develop * Adapt XPUPS - modify by kp compilation - 7th version - 3.30 * fix. test=develop * add optimizer kernel. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 3.31 update * Adapt XPUPS - update kp compilation path - 8th version - 3.31 * add optimizer kernel. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix kunlun not support size_t. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update heter_comm_kernel.kps 3.31 * fix. test=develop * fix. test=develop * update heter_comm_kernel.kps 3.31 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update heter_comm.h 3.31 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * update hashtable. test=develop * update. test=develop * Adapt XPUPS - update by kp compilation - 9th version - 4.1 * update hashtable. test=develop * fix. test=develop * update hashtable 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 10th version - 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * update. test=develop * modify by compilation 4.1 * update. test=develop * update. test=develop * fix. test=develop * modify by compilation 4.1 * update. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * modify by compilation 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * modify by compilation 4.1 19:30 * fix. test=develop * update ps_gpu_wrapper.kps 4.1 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 11th version - 4.1 * fix. test=develop * Adapt XPUPS - update by kp compilation - 12nd version - 4.2 * fix. test=develop * fix. test=develop * modify by compilation 4.2 * 4.2 update * fix. test=develop * template init. test=develop * update 4.6 * fix. test=develop * template init. test=develop * 4.6 modify by compilation * hashtable template init. test=develop * hashtable template init. test=develop * fix. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=devlop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * Adapt XPUPS - update by kp compilation - 13nd version - 4.7 * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 4.11 update * fix. test=develop * fix. test=develop * 4.11 update * update by pre-commit * fix. test=develop * fix. test=develop * fix. test=develop * fix. test=develop * 4.12 update * fix. test=develop * Adapt XPUPS - update by kp compilation - 14th version - 4.13 * 4.13 update * 4.14 update * 4.14 update * 4.14 update * 4.14 modify by merged latest compilation * retry CI 4.14 * 4.15 pass static check * 4.15 modify by gpups CI * 3.16 update by gpups CI - modify ps_gpu_wrapper.h * 4.16 update * 4.16 pass xpu compile * 4.16 retry CI * 4.16 update * Adapt XPUPS - adapt BKCL comm for XPUPS - 4.24 * update by compilation * Adapt XPUPS - register PSGPUTrainer for XPUPS - 4.25 * update device_worker_factory Co-authored-by: zmxdream * update * update CMakeLists Co-authored-by: zmxdream --- paddle/fluid/framework/device_worker.h | 7 ++- .../fluid/framework/device_worker_factory.cc | 3 +- paddle/fluid/framework/ps_gpu_trainer.cc | 3 +- paddle/fluid/framework/ps_gpu_worker.cc | 13 +++++- paddle/fluid/framework/trainer.h | 3 +- paddle/fluid/framework/trainer_factory.cc | 3 +- .../operators/collective/c_comm_init_op.cc | 11 ++++- .../collective/c_sync_calc_stream_op.cc | 10 +++++ .../collective/c_sync_comm_stream_op.cc | 18 +++++++- .../fleet/parameter_server/pslib/__init__.py | 5 ++- python/paddle/fluid/transpiler/collective.py | 45 +++++++++++++++++-- 11 files changed, 107 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index e1a1c1fab5..895e459a37 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -522,7 +522,8 @@ class HeterCpuWorker : public HogwildWorker { }; #endif -#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \ +#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \ + defined PADDLE_WITH_XPU_BKCL) && \ (defined PADDLE_WITH_PSLIB) class PSGPUWorker : public HogwildWorker { public: @@ -537,8 +538,10 @@ class PSGPUWorker : public HogwildWorker { new (&program_) ProgramDesc(main_program); } void ProduceTasks() override; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) virtual void SetStream(const gpuStream_t stream) { copy_stream_ = stream; } virtual void SetEvent(const gpuEvent_t event) { event_ = event; } +#endif void ResetStat(); protected: @@ -588,8 +591,10 @@ class PSGPUWorker : public HogwildWorker { std::unordered_map> feasign_set_; paddle::framework::Channel> pull_queue_; paddle::framework::Channel> push_queue_; +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) gpuEvent_t event_; gpuStream_t copy_stream_; +#endif int batch_cnt_{0}; std::atomic done_cnt_{0}; diff --git a/paddle/fluid/framework/device_worker_factory.cc b/paddle/fluid/framework/device_worker_factory.cc index 9c418b2f78..e6635a2f94 100644 --- a/paddle/fluid/framework/device_worker_factory.cc +++ b/paddle/fluid/framework/device_worker_factory.cc @@ -75,7 +75,8 @@ REGISTER_DEVICE_WORKER_CLASS(HeterSectionWorker); REGISTER_DEVICE_WORKER_CLASS(HeterCpuWorker); #endif -#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \ +#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \ + defined PADDLE_WITH_XPU_BKCL) && \ (defined PADDLE_WITH_PSLIB) REGISTER_DEVICE_WORKER_CLASS(PSGPUWorker); #endif diff --git a/paddle/fluid/framework/ps_gpu_trainer.cc b/paddle/fluid/framework/ps_gpu_trainer.cc index e4004c2fbf..9b12870a2b 100644 --- a/paddle/fluid/framework/ps_gpu_trainer.cc +++ b/paddle/fluid/framework/ps_gpu_trainer.cc @@ -23,7 +23,8 @@ limitations under the License. */ #include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h" #include "paddle/fluid/framework/trainer.h" -#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \ +#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \ + defined PADDLE_WITH_XPU_BKCL) && \ (defined PADDLE_WITH_PSLIB) #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_device_guard.h" diff --git a/paddle/fluid/framework/ps_gpu_worker.cc b/paddle/fluid/framework/ps_gpu_worker.cc index 323b2c4803..e46bdf6d87 100644 --- a/paddle/fluid/framework/ps_gpu_worker.cc +++ b/paddle/fluid/framework/ps_gpu_worker.cc @@ -18,7 +18,8 @@ limitations under the License. */ #include "paddle/fluid/platform/lodtensor_printer.h" #include "paddle/fluid/string/string_helper.h" -#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \ +#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \ + defined PADDLE_WITH_XPU_BKCL) && \ (defined PADDLE_WITH_PSLIB) #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/platform/cuda_device_guard.h" @@ -131,6 +132,11 @@ void PSGPUWorker::TrainFiles() { device_reader_->Start(); int cur_batch; int batch_cnt = 0; +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + platform::SetDeviceId(thread_id_); +#elif defined(PADDLE_WITH_XPU_BKCL) + platform::SetXPUDeviceId(thread_id_); +#endif while ((cur_batch = device_reader_->Next()) > 0) { total_ins_num += cur_batch; for (auto& op : ops_) { @@ -227,6 +233,11 @@ void PSGPUWorker::TrainFilesWithProfiler() { int total_ins_num = 0; int cur_batch; timeline.Start(); +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + platform::SetDeviceId(thread_id_); +#elif defined(PADDLE_WITH_XPU_BKCL) + platform::SetXPUDeviceId(thread_id_); +#endif while ((cur_batch = device_reader_->Next()) > 0) { total_ins_num += cur_batch; timeline.Pause(); diff --git a/paddle/fluid/framework/trainer.h b/paddle/fluid/framework/trainer.h index 8a11775702..2496d4d040 100644 --- a/paddle/fluid/framework/trainer.h +++ b/paddle/fluid/framework/trainer.h @@ -248,7 +248,8 @@ class HeterXpuTrainer : public TrainerBase { #endif -#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \ +#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \ + defined PADDLE_WITH_XPU_BKCL) && \ (defined PADDLE_WITH_PSLIB) class PSGPUTrainer : public TrainerBase { public: diff --git a/paddle/fluid/framework/trainer_factory.cc b/paddle/fluid/framework/trainer_factory.cc index f189d0213d..1f1122d32f 100644 --- a/paddle/fluid/framework/trainer_factory.cc +++ b/paddle/fluid/framework/trainer_factory.cc @@ -76,7 +76,8 @@ REGISTER_TRAINER_CLASS(HeterPipelineTrainer); (defined PADDLE_WITH_PSLIB) && (!defined(PADDLE_WITH_HETERPS)) REGISTER_TRAINER_CLASS(HeterXpuTrainer); #endif -#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \ +#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \ + defined PADDLE_WITH_XPU_BKCL) && \ (defined PADDLE_WITH_PSLIB) REGISTER_TRAINER_CLASS(PSGPUTrainer); #endif diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc index 39acb50d4e..82d3b1b1db 100644 --- a/paddle/fluid/operators/collective/c_comm_init_op.cc +++ b/paddle/fluid/operators/collective/c_comm_init_op.cc @@ -83,7 +83,6 @@ class CCommInitOp : public framework::OperatorBase { UniqueId* comm_id = var->GetMutable(); int nranks = Attr("nranks"); - int rank_id = Attr("rank"); int rid = Attr("ring_id"); #if defined(PADDLE_WITH_XPU_BKCL) @@ -98,8 +97,18 @@ class CCommInitOp : public framework::OperatorBase { if (Attr("device_id") >= 0) { device_id = Attr("device_id"); } + +#if defined(PADDLE_WITH_XPU_BKCL) && defined(PADDLE_WITH_HETERPS) && \ + defined(PADDLE_WITH_PSLIB) + // XPUPS rank_id only equals 0, so replace rank_id with device_id + CommContext::Instance().CreateComm(comm_id, nranks, device_id, device_id, + rid); +#else + int rank_id = Attr("rank"); CommContext::Instance().CreateComm(comm_id, nranks, rank_id, device_id, rid); +#endif + #endif } }; diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc index 42584948e0..088366dbc8 100644 --- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc @@ -76,7 +76,15 @@ class CSyncCalcStreamKernel : public framework::OpKernel { auto dev_ctx = static_cast( platform::DeviceContextPool::Instance().Get(place)); platform::MLUStreamSync(dev_ctx->stream()); +#elif defined(PADDLE_WITH_XPU_BKCL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on xpu place only for now.")); + auto dev_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(place)); + dev_ctx->Wait(); #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with GPU.")); @@ -97,3 +105,5 @@ REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); REGISTER_OP_MLU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); + +REGISTER_OP_XPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel); diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc index 37ce4ef7ee..5a9a00aa8e 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc @@ -20,7 +20,6 @@ limitations under the License. */ #endif #if defined(PADDLE_WITH_ASCEND_CL) -#include "paddle/fluid/platform/collective_helper.h" #include "paddle/fluid/platform/device/npu/hccl_helper.h" #endif @@ -28,6 +27,10 @@ limitations under the License. */ #include "paddle/fluid/platform/device/mlu/cncl_helper.h" #endif +#if defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL) +#include "paddle/fluid/platform/collective_helper.h" +#endif + namespace paddle { namespace operators { @@ -94,7 +97,16 @@ class CSyncCommStreamKernel : public framework::OpKernel { auto stream = platform::CNCLCommContext::Instance().Get(ring_id, place)->stream(); platform::MLUStreamSync(stream); - +#elif defined(PADDLE_WITH_XPU_BKCL) + auto place = ctx.GetPlace(); + PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), true, + platform::errors::PreconditionNotMet( + "Sync stream op can run on xpu place only for now.")); + int ring_id = ctx.Attr("ring_id"); + auto comm_dev_ctx = platform::BKCLCommContext::Instance() + .Get(ring_id, place) + ->dev_context(); + comm_dev_ctx->Wait(); #else PADDLE_THROW(platform::errors::PreconditionNotMet( "PaddlePaddle should compile with GPU.")); @@ -115,3 +127,5 @@ REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); REGISTER_OP_MLU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); + +REGISTER_OP_XPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel); diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py index 8d803c0d5b..40ff41fe89 100644 --- a/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py +++ b/python/paddle/fluid/incubate/fleet/parameter_server/pslib/__init__.py @@ -1139,10 +1139,11 @@ class DownpourOptimizer(DistributedOptimizer): from paddle.fluid.transpiler.collective import MultiThread # check start program if program_mode not in [ - "all_reduce", "fuse_all_reduce", "all_gather" + "all_reduce", "fuse_all_reduce", "all_gather", + "all_reduce_xpu" ]: raise ValueError("You should set program_mode in [ all_reduce, \ - fuse_all_reduce, all_gather ]") + fuse_all_reduce, all_gather, all_reduce_xpu ]") env = self.get_dist_env() if not isinstance(losses, list): startup_programs = [startup_programs] diff --git a/python/paddle/fluid/transpiler/collective.py b/python/paddle/fluid/transpiler/collective.py index ea88a89e68..95ab446e1d 100644 --- a/python/paddle/fluid/transpiler/collective.py +++ b/python/paddle/fluid/transpiler/collective.py @@ -42,6 +42,7 @@ class Collective(object): self.nrings = nrings self.endpoints = None self.current_endpoint = None + self.other_endpoints = None self.nranks = None self.rank = None self.startup_program = None @@ -79,6 +80,12 @@ class Collective(object): self.endpoints = endpoints self.current_endpoint = current_endpoint + if current_endpoint: + nranks = len(endpoints) + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + self.other_endpoints = other_endpoints + self.wait_port = wait_port self.startup_program._origin_program = self.startup_program.clone() @@ -462,9 +469,41 @@ class MultiThread(GradAllReduce): self.rank, ring_id, self.wait_port, True) else: - print("begin to _transpile_startup_program for single-node") - block = self.startup_program.global_block() - block.append_op(type='c_comm_init_all', attrs={'ring_id': 0}) + if "xpu" in self.trans_mode: + print( + "begin to _transpile_startup_program for single-node in XPU") + block = self.startup_program.global_block() + comm_id_var = block.create_var( + name=unique_name.generate('comm_id'), + persistable=True, + type=core.VarDesc.VarType.RAW) + block.append_op( + type='c_gen_bkcl_id', + inputs={}, + outputs={'Out': comm_id_var}, + attrs={ + 'rank': self.rank, + 'endpoint': self.current_endpoint, + 'other_endpoints': self.other_endpoints, + 'ring_id': 0, + self.op_role_key: OpRole.Forward + }) + block.append_op( + type='c_comm_init', + inputs={'X': comm_id_var}, + outputs={}, + attrs={ + 'nranks': + len(os.getenv("FLAGS_selected_gpus").split(",")), + 'rank': self.rank, + 'ring_id': 0, + self.op_role_key: OpRole.Forward + }) + + else: + print("begin to _transpile_startup_program for single-node") + block = self.startup_program.global_block() + block.append_op(type='c_comm_init_all', attrs={'ring_id': 0}) def _transpile_main_program(self): self._insert_scale_loss_grad_ops() -- GitLab