Unverified commit 0c68ae0c, authored by Fan Zhang, committed by GitHub

[cherry-pick] Adapt BKCL comm for XPUPS (#42266)

* XPUPS Adaptation (#40991)

* Adapt XPUPS - 1st version - 3.24

* Adapt XPUPS - update XPU PushSparse -  2nd version - 3.24

* Adapt XPUPS - add XPU PullSparseOp - 3rd version - 3.25

* refactor heter comm kernel

* update. test=develop

* Adapt XPUPS - modify by compilation - 4th version - 3.27

* update calc_shard_offset. test=develop

* update xpu kernel. test=develop

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* heter_comm update

* heter_comm update

* update calc_shard_offset. test=develop

* heter_comm update

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* Adapt XPUPS - use WITH_XPU_KP and modify wrapper kernel function - 5th version - 3.30

* update. test=develop

* update pslib.cmake

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* Adapt XPUPS - modify by kp compilation  - 6th version - 3.30

* update. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* used by minxu

* update heter_comm_inl

* fix. test=develop

* Adapt XPUPS - modify by kp compilation  - 7th version - 3.30

* fix. test=develop

* add optimizer kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 3.31 update

* Adapt XPUPS - update kp compilation path  - 8th version - 3.31

* add optimizer kernel. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update heter_comm_kernel.kps 3.31

* fix. test=develop

* fix. test=develop

* update heter_comm_kernel.kps 3.31

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update heter_comm.h 3.31

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update hashtable. test=develop

* update. test=develop

* Adapt XPUPS - update by kp compilation  - 9th version - 4.1

* update hashtable. test=develop

* fix. test=develop

* update hashtable 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 10th version - 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update. test=develop

* modify by compilation 4.1

* update. test=develop

* update. test=develop

* fix. test=develop

* modify by compilation 4.1

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* modify by compilation 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* modify by compilation 4.1 19:30

* fix. test=develop

* update ps_gpu_wrapper.kps 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 11th version - 4.1

* fix. test=develop

* Adapt XPUPS - update by kp compilation - 12th version - 4.2

* fix. test=develop

* fix. test=develop

* modify by compilation 4.2

* 4.2 update

* fix. test=develop

* template init. test=develop

* update 4.6

* fix. test=develop

* template init. test=develop

* 4.6 modify by compilation

* hashtable template init. test=develop

* hashtable template init. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation - 13th version - 4.7

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 4.11 update

* fix. test=develop

* fix. test=develop

* 4.11 update

* update by pre-commit

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 4.12 update

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 14th version - 4.13

* 4.13 update

* 4.14 update

* 4.14 update

* 4.14 update

* 4.14 modify by merged latest compilation

* retry CI 4.14

* 4.15 pass static check

* 4.15 modify by gpups CI

* 4.16 update by gpups CI - modify ps_gpu_wrapper.h

* 4.16 update

* 4.16 pass xpu compile

* 4.16 retry CI

* 4.16 update
Co-authored-by: zmxdream <zhangminxu01@baidu.com>

* modify ps_gpu_wrapper.cc

* update

* Adapt BKCL comm for XPUPS (#42168)

* Adapt XPUPS - 1st version - 3.24

* Adapt XPUPS - update XPU PushSparse -  2nd version - 3.24

* Adapt XPUPS - add XPU PullSparseOp - 3rd version - 3.25

* refactor heter comm kernel

* update. test=develop

* Adapt XPUPS - modify by compilation - 4th version - 3.27

* update calc_shard_offset. test=develop

* update xpu kernel. test=develop

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* heter_comm update

* heter_comm update

* update calc_shard_offset. test=develop

* heter_comm update

* update args of calc_shard_offset

* update. test=develop

* remove customGradMerger

* update. test=develop

* fix. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* Adapt XPUPS - use WITH_XPU_KP and modify wrapper kernel function - 5th version - 3.30

* update. test=develop

* update pslib.cmake

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* Adapt XPUPS - modify by kp compilation  - 6th version - 3.30

* update. test=develop

* update. test=develop

* update. test=develop

* update optimizer kernel

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* update. test=develop

* fix. test=develop

* fix. test=develop

* used by minxu

* update heter_comm_inl

* fix. test=develop

* Adapt XPUPS - modify by kp compilation  - 7th version - 3.30

* fix. test=develop

* add optimizer kernel. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 3.31 update

* Adapt XPUPS - update kp compilation path  - 8th version - 3.31

* add optimizer kernel. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix kunlun not support size_t. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update heter_comm_kernel.kps 3.31

* fix. test=develop

* fix. test=develop

* update heter_comm_kernel.kps 3.31

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update heter_comm.h 3.31

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update hashtable. test=develop

* update. test=develop

* Adapt XPUPS - update by kp compilation  - 9th version - 4.1

* update hashtable. test=develop

* fix. test=develop

* update hashtable 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 10th version - 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* update. test=develop

* modify by compilation 4.1

* update. test=develop

* update. test=develop

* fix. test=develop

* modify by compilation 4.1

* update. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* modify by compilation 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* modify by compilation 4.1 19:30

* fix. test=develop

* update ps_gpu_wrapper.kps 4.1

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 11th version - 4.1

* fix. test=develop

* Adapt XPUPS - update by kp compilation - 12th version - 4.2

* fix. test=develop

* fix. test=develop

* modify by compilation 4.2

* 4.2 update

* fix. test=develop

* template init. test=develop

* update 4.6

* fix. test=develop

* template init. test=develop

* 4.6 modify by compilation

* hashtable template init. test=develop

* hashtable template init. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* Adapt XPUPS - update by kp compilation - 13th version - 4.7

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 4.11 update

* fix. test=develop

* fix. test=develop

* 4.11 update

* update by pre-commit

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* 4.12 update

* fix. test=develop

* Adapt XPUPS - update by kp compilation  - 14th version - 4.13

* 4.13 update

* 4.14 update

* 4.14 update

* 4.14 update

* 4.14 modify by merged latest compilation

* retry CI 4.14

* 4.15 pass static check

* 4.15 modify by gpups CI

* 4.16 update by gpups CI - modify ps_gpu_wrapper.h

* 4.16 update

* 4.16 pass xpu compile

* 4.16 retry CI

* 4.16 update

* Adapt XPUPS - adapt BKCL comm for XPUPS - 4.24

* update by compilation

* Adapt XPUPS - register PSGPUTrainer for XPUPS - 4.25

* update device_worker_factory
Co-authored-by: zmxdream <zhangminxu01@baidu.com>

* update

* update CMakeLists
Co-authored-by: zmxdream <zhangminxu01@baidu.com>
Parent 6366e0a9
...
@@ -522,7 +522,8 @@ class HeterCpuWorker : public HogwildWorker {
 };
 #endif

-#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \
+     defined PADDLE_WITH_XPU_BKCL) && \
     (defined PADDLE_WITH_PSLIB)
 class PSGPUWorker : public HogwildWorker {
  public:
@@ -537,8 +538,10 @@ class PSGPUWorker : public HogwildWorker {
     new (&program_) ProgramDesc(main_program);
   }
   void ProduceTasks() override;
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   virtual void SetStream(const gpuStream_t stream) { copy_stream_ = stream; }
   virtual void SetEvent(const gpuEvent_t event) { event_ = event; }
+#endif
   void ResetStat();

  protected:
@@ -588,8 +591,10 @@ class PSGPUWorker : public HogwildWorker {
   std::unordered_map<uint64_t, std::unordered_set<uint64_t>> feasign_set_;
   paddle::framework::Channel<std::shared_ptr<HeterTask>> pull_queue_;
   paddle::framework::Channel<std::shared_ptr<HeterTask>> push_queue_;
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
   gpuEvent_t event_;
   gpuStream_t copy_stream_;
+#endif
   int batch_cnt_{0};
   std::atomic<int> done_cnt_{0};
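The hunk above widens the build guard so PSGPUWorker is also compiled for BKCL (Kunlun XPU) builds, while the stream/event plumbing stays CUDA/HIP-only, since the gpuStream_t/gpuEvent_t handles are not defined in an XPU build. A minimal, self-contained sketch of that fencing pattern follows; the macro names match the diff, but the class and its members are illustrative stand-ins, not Paddle's real declarations.

// Sketch only: compiles under any combination of the macros.
#include <atomic>

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
using gpuStream_t = void*;  // stand-in for the real CUDA/HIP stream handle
#endif

#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \
     defined PADDLE_WITH_XPU_BKCL) && \
    (defined PADDLE_WITH_PSLIB)
class PSWorkerSketch {
 public:
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // Only CUDA/HIP builds have a copy stream to configure.
  void SetStream(gpuStream_t stream) { copy_stream_ = stream; }
#endif

 private:
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  gpuStream_t copy_stream_{nullptr};
#endif
  std::atomic<int> batch_cnt_{0};  // backend-independent state
};
#endif

int main() { return 0; }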
...
@@ -75,7 +75,8 @@ REGISTER_DEVICE_WORKER_CLASS(HeterSectionWorker);
 REGISTER_DEVICE_WORKER_CLASS(HeterCpuWorker);
 #endif
-#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \
+     defined PADDLE_WITH_XPU_BKCL) && \
     (defined PADDLE_WITH_PSLIB)
 REGISTER_DEVICE_WORKER_CLASS(PSGPUWorker);
 #endif
...
@@ -23,7 +23,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/device_worker_factory.h"
 #include "paddle/fluid/framework/fleet/ps_gpu_wrapper.h"
 #include "paddle/fluid/framework/trainer.h"
-#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \
+     defined PADDLE_WITH_XPU_BKCL) && \
     (defined PADDLE_WITH_PSLIB)
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cuda_device_guard.h"
...
@@ -18,7 +18,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/lodtensor_printer.h"
 #include "paddle/fluid/string/string_helper.h"

-#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \
+     defined PADDLE_WITH_XPU_BKCL) && \
     (defined PADDLE_WITH_PSLIB)
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cuda_device_guard.h"
@@ -131,6 +132,11 @@ void PSGPUWorker::TrainFiles() {
   device_reader_->Start();
   int cur_batch;
   int batch_cnt = 0;
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  platform::SetDeviceId(thread_id_);
+#elif defined(PADDLE_WITH_XPU_BKCL)
+  platform::SetXPUDeviceId(thread_id_);
+#endif
   while ((cur_batch = device_reader_->Next()) > 0) {
     total_ins_num += cur_batch;
     for (auto& op : ops_) {
@@ -227,6 +233,11 @@ void PSGPUWorker::TrainFilesWithProfiler() {
   int total_ins_num = 0;
   int cur_batch;
   timeline.Start();
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  platform::SetDeviceId(thread_id_);
+#elif defined(PADDLE_WITH_XPU_BKCL)
+  platform::SetXPUDeviceId(thread_id_);
+#endif
   while ((cur_batch = device_reader_->Next()) > 0) {
     total_ins_num += cur_batch;
     timeline.Pause();
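Both training loops above now bind the worker's device on the worker thread itself, before the first batch is read, because device selection (platform::SetDeviceId on CUDA/HIP, platform::SetXPUDeviceId on BKCL) installs thread-local runtime state. A hedged sketch of that thread-per-device pattern, with set_device() as a stand-in for the real platform call and thread ids mapping 1:1 onto cards:

#include <iostream>
#include <thread>
#include <vector>

void set_device(int device_id) {
  // The real call stores the binding in thread-local runtime state,
  // which is why it must run on the worker thread itself.
  std::cout << "worker thread bound to device " << device_id << "\n";
}

void train_files(int thread_id) {
  set_device(thread_id);  // bind before the first batch is read
  // while ((cur_batch = device_reader->Next()) > 0) { ... }
}

int main() {
  std::vector<std::thread> workers;
  for (int i = 0; i < 4; ++i) workers.emplace_back(train_files, i);
  for (auto& t : workers) t.join();
  return 0;
}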
...
@@ -248,7 +248,8 @@ class HeterXpuTrainer : public TrainerBase {
 #endif

-#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \
+     defined PADDLE_WITH_XPU_BKCL) && \
     (defined PADDLE_WITH_PSLIB)
 class PSGPUTrainer : public TrainerBase {
  public:
...
@@ -76,7 +76,8 @@ REGISTER_TRAINER_CLASS(HeterPipelineTrainer);
     (defined PADDLE_WITH_PSLIB) && (!defined(PADDLE_WITH_HETERPS))
 REGISTER_TRAINER_CLASS(HeterXpuTrainer);
 #endif
-#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL) && \
+#if (defined PADDLE_WITH_NCCL || defined PADDLE_WITH_RCCL || \
+     defined PADDLE_WITH_XPU_BKCL) && \
     (defined PADDLE_WITH_PSLIB)
 REGISTER_TRAINER_CLASS(PSGPUTrainer);
 #endif
...
@@ -83,7 +83,6 @@ class CCommInitOp : public framework::OperatorBase {
     UniqueId* comm_id = var->GetMutable<UniqueId>();

     int nranks = Attr<int>("nranks");
-    int rank_id = Attr<int>("rank");
     int rid = Attr<int>("ring_id");

 #if defined(PADDLE_WITH_XPU_BKCL)
@@ -98,8 +97,18 @@ class CCommInitOp : public framework::OperatorBase {
     if (Attr<int>("device_id") >= 0) {
       device_id = Attr<int>("device_id");
     }
+
+#if defined(PADDLE_WITH_XPU_BKCL) && defined(PADDLE_WITH_HETERPS) && \
+    defined(PADDLE_WITH_PSLIB)
+    // XPUPS rank_id only equals 0, so replace rank_id with device_id
+    CommContext::Instance().CreateComm(comm_id, nranks, device_id, device_id,
+                                       rid);
+#else
+    int rank_id = Attr<int>("rank");
     CommContext::Instance().CreateComm(comm_id, nranks, rank_id, device_id,
                                        rid);
+#endif
 #endif
   }
 };
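In the XPUPS deployment a single process drives every local card, so the op's "rank" attribute is always 0 and cannot distinguish cards; the CCommInitOp hunk above therefore feeds device_id into CreateComm as the communicator rank, as its comment states. A small sketch of that selection logic, where the heterps_build flag stands in for the PADDLE_WITH_XPU_BKCL && PADDLE_WITH_HETERPS && PADDLE_WITH_PSLIB compile-time condition:

#include <cassert>

int CommRank(bool heterps_build, int rank_attr, int device_id) {
  if (heterps_build) {
    // XPUPS: one process per node drives all cards, the "rank" attribute
    // is always 0, so each card's communicator rank is its device index.
    return device_id;
  }
  // Multi-process collective mode: rank comes straight from the attribute.
  return rank_attr;
}

int main() {
  assert(CommRank(/*heterps_build=*/true, /*rank_attr=*/0, /*device_id=*/3) == 3);
  assert(CommRank(/*heterps_build=*/false, /*rank_attr=*/1, /*device_id=*/3) == 1);
  return 0;
}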
...
@@ -76,7 +76,15 @@ class CSyncCalcStreamKernel : public framework::OpKernel<T> {
     auto dev_ctx = static_cast<platform::MLUDeviceContext*>(
         platform::DeviceContextPool::Instance().Get(place));
     platform::MLUStreamSync(dev_ctx->stream());
+#elif defined(PADDLE_WITH_XPU_BKCL)
+    auto place = ctx.GetPlace();
+    PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "Sync stream op can run on xpu place only for now."));
+    auto dev_ctx = static_cast<platform::XPUDeviceContext*>(
+        platform::DeviceContextPool::Instance().Get(place));
+    dev_ctx->Wait();
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
         "PaddlePaddle should compile with GPU."));
@@ -97,3 +105,5 @@ REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
 REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);

 REGISTER_OP_MLU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
+
+REGISTER_OP_XPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
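The XPU branch above synchronizes by waiting on the whole XPUDeviceContext rather than on a single stream: a coarser barrier than a per-stream sync, but sufficient to guarantee that all queued compute has finished before the op returns. A minimal sketch with a stand-in context type (not Paddle's real class):

struct DeviceContextSketch {
  void Wait() { /* block until every queue owned by this context drains */ }
};

// Mirrors the XPU branch: no per-stream handle is exposed on this path,
// so the kernel waits on the whole context instead of one stream.
void SyncCalcStream(DeviceContextSketch* dev_ctx) { dev_ctx->Wait(); }

int main() {
  DeviceContextSketch ctx;
  SyncCalcStream(&ctx);
  return 0;
}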
...
@@ -20,7 +20,6 @@ limitations under the License. */
 #endif

 #if defined(PADDLE_WITH_ASCEND_CL)
-#include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/npu/hccl_helper.h"
 #endif
@@ -28,6 +27,10 @@ limitations under the License. */
 #include "paddle/fluid/platform/device/mlu/cncl_helper.h"
 #endif

+#if defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_ASCEND_CL)
+#include "paddle/fluid/platform/collective_helper.h"
+#endif
+
 namespace paddle {
 namespace operators {
@@ -94,7 +97,16 @@ class CSyncCommStreamKernel : public framework::OpKernel<T> {
     auto stream =
         platform::CNCLCommContext::Instance().Get(ring_id, place)->stream();
     platform::MLUStreamSync(stream);
+#elif defined(PADDLE_WITH_XPU_BKCL)
+    auto place = ctx.GetPlace();
+    PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), true,
+                      platform::errors::PreconditionNotMet(
+                          "Sync stream op can run on xpu place only for now."));
+    int ring_id = ctx.Attr<int>("ring_id");
+    auto comm_dev_ctx = platform::BKCLCommContext::Instance()
+                            .Get(ring_id, place)
+                            ->dev_context();
+    comm_dev_ctx->Wait();
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
         "PaddlePaddle should compile with GPU."));
@@ -115,3 +127,5 @@ REGISTER_OP_CUDA_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
 REGISTER_OP_NPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);

 REGISTER_OP_MLU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
+
+REGISTER_OP_XPU_KERNEL(c_sync_comm_stream, ops::CSyncCommStreamKernel<float>);
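For the comm-stream variant, the XPU branch resolves the communicator through BKCLCommContext by the op's ring_id attribute and then waits on that communicator's device context. A hedged sketch of that registry-lookup-then-wait shape, using illustrative stand-in types rather than Paddle's real classes:

#include <map>
#include <memory>

struct DeviceContextSketch {
  void Wait() { /* block until queued XPU work on this context completes */ }
};

// A singleton registry keyed by ring_id, mirroring how the kernel above
// resolves the communicator's device context before waiting on it.
class CommRegistrySketch {
 public:
  static CommRegistrySketch& Instance() {
    static CommRegistrySketch registry;
    return registry;
  }
  DeviceContextSketch* Get(int ring_id) {
    auto& ctx = contexts_[ring_id];
    if (!ctx) ctx.reset(new DeviceContextSketch);
    return ctx.get();
  }

 private:
  std::map<int, std::unique_ptr<DeviceContextSketch>> contexts_;
};

int main() {
  int ring_id = 0;  // matches the op's "ring_id" attribute
  CommRegistrySketch::Instance().Get(ring_id)->Wait();
  return 0;
}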
...
@@ -1139,10 +1139,11 @@ class DownpourOptimizer(DistributedOptimizer):
             from paddle.fluid.transpiler.collective import MultiThread
             # check start program
             if program_mode not in [
-                    "all_reduce", "fuse_all_reduce", "all_gather"
+                    "all_reduce", "fuse_all_reduce", "all_gather",
+                    "all_reduce_xpu"
             ]:
                 raise ValueError("You should set program_mode in [ all_reduce, \
-                    fuse_all_reduce, all_gather ]")
+                    fuse_all_reduce, all_gather, all_reduce_xpu ]")
             env = self.get_dist_env()
             if not isinstance(losses, list):
                 startup_programs = [startup_programs]
...
@@ -42,6 +42,7 @@ class Collective(object):
         self.nrings = nrings
         self.endpoints = None
         self.current_endpoint = None
+        self.other_endpoints = None
         self.nranks = None
         self.rank = None
         self.startup_program = None
@@ -79,6 +80,12 @@ class Collective(object):
         self.endpoints = endpoints
         self.current_endpoint = current_endpoint

+        if current_endpoint:
+            nranks = len(endpoints)
+            other_endpoints = endpoints[:]
+            other_endpoints.remove(current_endpoint)
+            self.other_endpoints = other_endpoints
+
         self.wait_port = wait_port

         self.startup_program._origin_program = self.startup_program.clone()
@@ -462,9 +469,41 @@ class MultiThread(GradAllReduce):
                     self.rank, ring_id, self.wait_port, True)

         else:
-            print("begin to _transpile_startup_program for single-node")
-            block = self.startup_program.global_block()
-            block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})
+            if "xpu" in self.trans_mode:
+                print(
+                    "begin to _transpile_startup_program for single-node in XPU")
+                block = self.startup_program.global_block()
+                comm_id_var = block.create_var(
+                    name=unique_name.generate('comm_id'),
+                    persistable=True,
+                    type=core.VarDesc.VarType.RAW)
+                block.append_op(
+                    type='c_gen_bkcl_id',
+                    inputs={},
+                    outputs={'Out': comm_id_var},
+                    attrs={
+                        'rank': self.rank,
+                        'endpoint': self.current_endpoint,
+                        'other_endpoints': self.other_endpoints,
+                        'ring_id': 0,
+                        self.op_role_key: OpRole.Forward
+                    })
+                block.append_op(
+                    type='c_comm_init',
+                    inputs={'X': comm_id_var},
+                    outputs={},
+                    attrs={
+                        'nranks':
+                        len(os.getenv("FLAGS_selected_gpus").split(",")),
+                        'rank': self.rank,
+                        'ring_id': 0,
+                        self.op_role_key: OpRole.Forward
+                    })
+            else:
+                print("begin to _transpile_startup_program for single-node")
+                block = self.startup_program.global_block()
+                block.append_op(type='c_comm_init_all', attrs={'ring_id': 0})

     def _transpile_main_program(self):
         self._insert_scale_loss_grad_ops()
...