diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index 380ae841cd054416f267b23079515dcec58303b6..01f707eb9baaf4d9cc659b4df67e71e45742f3dc 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -276,19 +276,26 @@ void AllReduceOpHandle::BKCLAllReduceFunc(
     if (all_reduce_calls.size() == 1UL) {
       all_reduce_calls[0]();
     } else {
-      PADDLE_ENFORCE_EQ(
-          bkcl_group_start(),
-          BKCL_SUCCESS,
-          platform::errors::PreconditionNotMet("bkcl_group_start failed"));
+      platform::BKCLGroupGuard guard;
       for (auto &call : all_reduce_calls) {
         call();
       }
-      PADDLE_ENFORCE_EQ(
-          bkcl_group_end(),
-          BKCL_SUCCESS,
-          platform::errors::PreconditionNotMet("bkcl_group_end failed"));
     }
   });
+
+  SyncBKCLAllReduce();
+}
+
+void AllReduceOpHandle::SyncBKCLAllReduce() {
+  // BKCL kernels are always asynchronous, so wait on every device stream.
+  for (auto &p : places_) {
+    int dev_id = p.device;
+    platform::SetXPUDeviceId(dev_id);
+    auto *bkcl_ctxs =
+        bkcl_ctxs_->GetRunEnvBKCLCtx(run_order_, use_hierarchical_allreduce_);
+    auto &bkcl_ctx = bkcl_ctxs->at(dev_id);
+    platform::XPUStreamSync(bkcl_ctx.stream());
+  }
 }
 #endif
 
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h
index d628a3b1ee18165c3894c6253208c26c2a0718e8..685ab0b957a448de3a3be8fe109fc531736a4740 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.h
@@ -94,6 +94,8 @@ class AllReduceOpHandle : public OpHandleBase {
 #if defined(PADDLE_WITH_XPU_BKCL)
   void BKCLAllReduceFunc(
       const std::vector<std::function<void()>> &all_reduce_calls);
+
+  void SyncBKCLAllReduce();
 #endif
 
   void AllReduceImpl(const std::vector<VarHandle *> &in_var_handles,
diff --git a/paddle/fluid/framework/details/bkcl_op_handle.h b/paddle/fluid/framework/details/bkcl_op_handle.h
index 650e69d0cbfd264bba47f6d89f00e19622e84e2c..4ca8bf4cb587491bacb1fdb47616f9a9020d60f4 100644
--- a/paddle/fluid/framework/details/bkcl_op_handle.h
+++ b/paddle/fluid/framework/details/bkcl_op_handle.h
@@ -92,7 +92,9 @@ class BKCLOpHandleBase : public OpHandleBase {
                           "The argument run_order_ must be >= 0, but got %d.",
                           run_order_));
     auto flat_bkcl_ctxs = bkcl_ctxs_->GetFlatCtx(run_order_);
     int dev_id = place.device;
+    platform::SetXPUDeviceId(dev_id);
     auto& bkcl_ctx = flat_bkcl_ctxs->at(dev_id);
+    auto stream = bkcl_ctx.stream();
     auto comm = bkcl_ctx.comm_;
     VLOG(10) << "before all reduce buffer:" << sendbuff << ", numel:" << count
@@ -100,7 +102,7 @@ class BKCLOpHandleBase : public OpHandleBase {
              << ", place:" << place;
 
     PADDLE_ENFORCE_EQ(
-        bkcl_all_reduce(comm, sendbuff, recvbuff, count, datatype, op, NULL),
+        bkcl_all_reduce(comm, sendbuff, recvbuff, count, datatype, op, stream),
         BKCL_SUCCESS,
         platform::errors::PreconditionNotMet("bckl all reduce failed"));
   }
diff --git a/paddle/fluid/platform/device/xpu/bkcl_helper.h b/paddle/fluid/platform/device/xpu/bkcl_helper.h
index c1882eb752080dbf59bff7a09752abcb64ff7891..7bd1b67efd6144df3cb925ec517a123a1d0f7caa 100644
--- a/paddle/fluid/platform/device/xpu/bkcl_helper.h
+++ b/paddle/fluid/platform/device/xpu/bkcl_helper.h
@@ -27,6 +27,7 @@
 
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
@@ -57,6 +58,24 @@ inline BKCLDataType ToBKCLDataType(framework::proto::VarType::Type type) {
   }
 }
 
+class BKCLGroupGuard {
+ public:
+  static std::mutex &BKCLMutex() {
+    static std::mutex mtx;
+    return mtx;
+  }
+
+  inline BKCLGroupGuard() {
+    BKCLMutex().lock();
+    PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
+  }
+
+  inline ~BKCLGroupGuard() PADDLE_MAY_THROW {
+    PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end());
+    BKCLMutex().unlock();
+  }
+};
+
 struct BKCLContext {
   std::unique_ptr<platform::XPUDeviceContext> ctx_;
   BKCLContext_t comm_;
@@ -65,6 +84,7 @@ struct BKCLContext {
       : ctx_(new platform::XPUDeviceContext(XPUPlace(dev_id))),
         comm_{nullptr} {}
 
+  XPUStream stream() const { return ctx_->stream(); }
   BKCLContext_t comm() const { return comm_; }
 
   int device_id() const { return ctx_->GetPlace().device; }
@@ -258,19 +278,33 @@ class BKCLCommunicator {
       ptr->init();
       VLOG(1) << "init local trainer";
       flat_ctxs_.emplace_back(ptr);
-      return;
+    } else {
+      PADDLE_ENFORCE_EQ(bkcl_ids.size(),
+                        1,
+                        platform::errors::Unimplemented(
+                            "Multi-all-reduce-ring is not supported for XPU"));
+      for (size_t i = 0; i < bkcl_ids.size(); i++) {
+        auto ptr = new platform::BKCLContextMap(
+            places, bkcl_ids[i], trainers_num, trainer_id);
+        ptr->init();
+        VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
+        flat_ctxs_.emplace_back(ptr);
+      }
     }
 
-    PADDLE_ENFORCE_EQ(bkcl_ids.size(),
-                      1,
-                      platform::errors::Unimplemented(
-                          "Multi-all-reduce-ring is not support for XPU"));
-    for (size_t i = 0; i < bkcl_ids.size(); i++) {
-      auto ptr = new platform::BKCLContextMap(
-          places, bkcl_ids[i], trainers_num, trainer_id);
-      ptr->init();
-      VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
-      flat_ctxs_.emplace_back(ptr);
+    // As Executor has no way to use a BKCLComm created by ParallelExecutor,
+    // assign all flat contexts to BKCLCommContext as a workaround.
+    int nranks = static_cast<int>(trainers_num * places.size());
+    int nrings = static_cast<int>(flat_ctxs_.size());
+    for (int ring_id = 0; ring_id < nrings; ++ring_id) {
+      for (size_t p = 0; p < places.size(); ++p) {
+        int rank = trainer_id * places.size() + p;
+        int dev_id = places[p].device;
+        platform::SetXPUDeviceId(dev_id);
+        auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
+        BKCLCommContext::Instance().AssignBKCLComm(
+            ctx.comm_, nranks, rank, dev_id, ring_id);
+      }
     }
   }
 
diff --git a/paddle/phi/backends/xpu/xpu_context.cc b/paddle/phi/backends/xpu/xpu_context.cc
index fe0dda2d3dbeb7ed22d1eb069a3ae5b5f5e80fc5..2735e2a4208bafe7b05c439003f26e9d9c8f7f91 100644
--- a/paddle/phi/backends/xpu/xpu_context.cc
+++ b/paddle/phi/backends/xpu/xpu_context.cc
@@ -68,6 +68,12 @@ struct XPUContext::Impl {
   void SetStream(XPUStream stream) { context_->xpu_stream = stream; }
 
+  XPUStream stream() const {
+    auto s = context_->xpu_stream;
+    PD_CHECK(s != nullptr, "the xpu stream is nullptr.");
+    return s;
+  }
+
   xpu::Context* GetXContext() const {
     PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
     return context_;
   }
@@ -119,6 +125,8 @@ const Place& XPUContext::GetPlace() const { return impl_->GetPlace(); }
 
 void XPUContext::SetXPUStream(XPUStream stream) { impl_->SetStream(stream); }
 
+XPUStream XPUContext::stream() const { return impl_->stream(); }
+
 backends::xpu::XPUVersion XPUContext::xpu_version() const {
   return impl_->xpu_version_;
 }
diff --git a/paddle/phi/backends/xpu/xpu_context.h b/paddle/phi/backends/xpu/xpu_context.h
index d20a1ad4e1e4867309eeeefcc78f7ee88bb4f4e6..90fc7c97b785cdb36c8c8cbe57dcf57f8a7ee13b 100644
--- a/paddle/phi/backends/xpu/xpu_context.h
+++ b/paddle/phi/backends/xpu/xpu_context.h
@@ -63,6 +63,8 @@ class XPUContext : public DeviceContext {
 
   void SetXPUStream(XPUStream stream);
 
+  XPUStream stream() const;
+
  private:
   struct Impl;
   std::unique_ptr<Impl> impl_;
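Note on the BKCLGroupGuard change: the patch replaces the manually paired bkcl_group_start()/bkcl_group_end() calls with an RAII guard, so the group is closed on every exit path and concurrent group regions are serialized by a process-wide mutex. The following minimal, self-contained sketch shows the same pattern outside Paddle; GroupGuard, fake_group_start, and fake_group_end are illustrative stand-ins, not Paddle or BKCL APIs:

#include <iostream>
#include <mutex>
#include <stdexcept>

// Hypothetical stand-ins for bkcl_group_start()/bkcl_group_end().
int fake_group_start() { std::cout << "group start\n"; return 0; }
int fake_group_end() { std::cout << "group end\n"; return 0; }

class GroupGuard {
 public:
  // One process-wide mutex: group regions must not interleave across
  // threads, so the guard holds the lock for the lifetime of the region.
  static std::mutex &Mutex() {
    static std::mutex mtx;
    return mtx;
  }

  GroupGuard() : lock_(Mutex()) {
    if (fake_group_start() != 0) throw std::runtime_error("group_start failed");
  }

  // Runs on every exit path, so the group is always closed; noexcept(false)
  // mirrors PADDLE_MAY_THROW on the real destructor.
  ~GroupGuard() noexcept(false) {
    if (fake_group_end() != 0) throw std::runtime_error("group_end failed");
  }

 private:
  std::unique_lock<std::mutex> lock_;  // released after the destructor body
};

int main() {
  {
    GroupGuard guard;  // opens the group and takes the lock
    std::cout << "enqueue all-reduce calls here\n";
  }                    // closes the group and releases the lock
  return 0;
}

Unlike the patch, which calls lock()/unlock() on the raw mutex, the sketch holds a std::unique_lock member so the mutex is released even when the group_end check throws; on the success path the behavior is identical.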
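Why the new SyncBKCLAllReduce() step exists: bkcl_all_reduce() now receives a real stream instead of NULL, and BKCL enqueues its kernels asynchronously on that stream, so the op handle must block on each place's comm stream before the all-reduce can be treated as finished. A loose analogue in portable C++, with std::async futures standing in for per-device XPU streams (an analogy only, not Paddle code):

#include <future>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::future<void>> streams;  // one "stream" per device
  for (int dev_id = 0; dev_id < 4; ++dev_id) {
    streams.emplace_back(std::async(std::launch::async, [dev_id] {
      std::cout << "all-reduce kernel running on device " << dev_id << "\n";
    }));
  }
  // Analogue of platform::XPUStreamSync(bkcl_ctx.stream()) per place:
  for (auto &s : streams) s.wait();
  std::cout << "all devices synchronized\n";
  return 0;
}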
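Finally, on the BKCLCommContext assignment loop in bkcl_helper.h: each (trainer, device) pair maps to a global rank via rank = trainer_id * places.size() + p, with nranks = trainers_num * places.size(). For example, with trainers_num = 2 and 4 places per trainer, nranks is 8 and device index 2 on trainer 1 gets rank 1 * 4 + 2 = 6. This assignment is what lets a plain Executor reuse the BKCLComm instances that ParallelExecutor created, one ring at a time.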