From 41bdf41da74bc88db0d3a64fc41249bf38ac8a66 Mon Sep 17 00:00:00 2001
From: zhangxiaoci
Date: Thu, 18 Aug 2022 14:20:41 +0800
Subject: [PATCH] change to async mode for xpu multi-card training in static graph mode, test=kunlun (#45024)

* change to async mode for xpu multi-card training in static graph mode

* minor bugfix

* irrelevant. move to another pr

* move change to other pr

* fix stream issue

* fix 'stream not meet with current context' error

* fix branch diverge, test=kunlun
---
 .../framework/details/all_reduce_op_handle.cc | 23 +++++---
 .../framework/details/all_reduce_op_handle.h  |  2 +
 .../fluid/framework/details/bkcl_op_handle.h  |  4 +-
 .../fluid/platform/device/xpu/bkcl_helper.h   | 56 +++++++++++++++----
 paddle/phi/backends/xpu/xpu_context.cc        |  8 +++
 paddle/phi/backends/xpu/xpu_context.h         |  2 +
 6 files changed, 75 insertions(+), 20 deletions(-)

diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index 380ae841cd0..01f707eb9ba 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -276,19 +276,26 @@ void AllReduceOpHandle::BKCLAllReduceFunc(
     if (all_reduce_calls.size() == 1UL) {
       all_reduce_calls[0]();
     } else {
-      PADDLE_ENFORCE_EQ(
-          bkcl_group_start(),
-          BKCL_SUCCESS,
-          platform::errors::PreconditionNotMet("bkcl_group_start failed"));
+      platform::BKCLGroupGuard guard;
       for (auto &call : all_reduce_calls) {
         call();
       }
-      PADDLE_ENFORCE_EQ(
-          bkcl_group_end(),
-          BKCL_SUCCESS,
-          platform::errors::PreconditionNotMet("bkcl_group_end failed"));
     }
   });
+
+  SyncBKCLAllReduce();
+}
+
+void AllReduceOpHandle::SyncBKCLAllReduce() {
+  // bkcl always use async kernel
+  for (auto &p : places_) {
+    int dev_id = p.device;
+    platform::SetXPUDeviceId(dev_id);
+    auto *bkcl_ctxs =
+        bkcl_ctxs_->GetRunEnvBKCLCtx(run_order_, use_hierarchical_allreduce_);
+    auto &bkcl_ctx = bkcl_ctxs->at(dev_id);
+    platform::XPUStreamSync(bkcl_ctx.stream());
+  }
 }
 #endif
 
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.h b/paddle/fluid/framework/details/all_reduce_op_handle.h
index d628a3b1ee1..685ab0b957a 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.h
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.h
@@ -94,6 +94,8 @@ class AllReduceOpHandle : public OpHandleBase {
 #if defined(PADDLE_WITH_XPU_BKCL)
   void BKCLAllReduceFunc(
       const std::vector<std::function<void()>> &all_reduce_calls);
+
+  void SyncBKCLAllReduce();
 #endif
 
   void AllReduceImpl(const std::vector<VarHandle *> &in_var_handles,
diff --git a/paddle/fluid/framework/details/bkcl_op_handle.h b/paddle/fluid/framework/details/bkcl_op_handle.h
index 650e69d0cbf..4ca8bf4cb58 100644
--- a/paddle/fluid/framework/details/bkcl_op_handle.h
+++ b/paddle/fluid/framework/details/bkcl_op_handle.h
@@ -92,7 +92,9 @@ class BKCLOpHandleBase : public OpHandleBase {
             "The argument run_order_ must be >= 0, but got %d.", run_order_));
     auto flat_bkcl_ctxs = bkcl_ctxs_->GetFlatCtx(run_order_);
     int dev_id = place.device;
+    platform::SetXPUDeviceId(dev_id);
     auto& bkcl_ctx = flat_bkcl_ctxs->at(dev_id);
+    auto stream = bkcl_ctx.stream();
     auto comm = bkcl_ctx.comm_;
 
     VLOG(10) << "before all reduce buffer:" << sendbuff << ", numel:" << count
@@ -100,7 +102,7 @@
              << ", dtype:" << datatype << ", op:" << ToBKCLRedType(op)
              << ", place:" << place;
     PADDLE_ENFORCE_EQ(
-        bkcl_all_reduce(comm, sendbuff, recvbuff, count, datatype, op, NULL),
+        bkcl_all_reduce(comm, sendbuff, recvbuff, count, datatype, op, stream),
         BKCL_SUCCESS,
         platform::errors::PreconditionNotMet("bckl all reduce failed"));
   }
diff --git a/paddle/fluid/platform/device/xpu/bkcl_helper.h b/paddle/fluid/platform/device/xpu/bkcl_helper.h
index c1882eb7520..7bd1b67efd6 100644
--- a/paddle/fluid/platform/device/xpu/bkcl_helper.h
+++ b/paddle/fluid/platform/device/xpu/bkcl_helper.h
@@ -27,6 +27,7 @@
 #include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
@@ -57,6 +58,24 @@ inline BKCLDataType ToBKCLDataType(framework::proto::VarType::Type type) {
   }
 }
 
+class BKCLGroupGuard {
+ public:
+  static std::mutex &BKCLMutex() {
+    static std::mutex mtx;
+    return mtx;
+  }
+
+  inline BKCLGroupGuard() {
+    BKCLMutex().lock();
+    PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
+  }
+
+  inline ~BKCLGroupGuard() PADDLE_MAY_THROW {
+    PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end());
+    BKCLMutex().unlock();
+  }
+};
+
 struct BKCLContext {
   std::unique_ptr<platform::XPUDeviceContext> ctx_;
   BKCLContext_t comm_;
@@ -65,6 +84,7 @@ struct BKCLContext {
       : ctx_(new platform::XPUDeviceContext(XPUPlace(dev_id))),
         comm_{nullptr} {}
 
+  XPUStream stream() const { return ctx_->stream(); }
   BKCLContext_t comm() const { return comm_; }
 
   int device_id() const { return ctx_->GetPlace().device; }
@@ -258,19 +278,33 @@ class BKCLCommunicator {
       ptr->init();
       VLOG(1) << "init local trainer";
       flat_ctxs_.emplace_back(ptr);
-      return;
+    } else {
+      PADDLE_ENFORCE_EQ(bkcl_ids.size(),
+                        1,
+                        platform::errors::Unimplemented(
+                            "Multi-all-reduce-ring is not support for XPU"));
+      for (size_t i = 0; i < bkcl_ids.size(); i++) {
+        auto ptr = new platform::BKCLContextMap(
+            places, bkcl_ids[i], trainers_num, trainer_id);
+        ptr->init();
+        VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
+        flat_ctxs_.emplace_back(ptr);
+      }
     }
-
-    PADDLE_ENFORCE_EQ(bkcl_ids.size(),
-                      1,
-                      platform::errors::Unimplemented(
-                          "Multi-all-reduce-ring is not support for XPU"));
-    for (size_t i = 0; i < bkcl_ids.size(); i++) {
-      auto ptr = new platform::BKCLContextMap(
-          places, bkcl_ids[i], trainers_num, trainer_id);
-      ptr->init();
-      VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
-      flat_ctxs_.emplace_back(ptr);
+
+    // as Executor have no way to use BKCLComm created by ParallelExecutor,
+    // we assign all flatten contexts to BKCLCommContext to fix.
+    int nranks = static_cast<int>(trainers_num * places.size());
+    int nrings = static_cast<int>(flat_ctxs_.size());
+    for (int ring_id = 0; ring_id < nrings; ++ring_id) {
+      for (size_t p = 0; p < places.size(); ++p) {
+        int rank = trainer_id * places.size() + p;
+        int dev_id = places[p].device;
+        platform::SetXPUDeviceId(dev_id);
+        auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
+        BKCLCommContext::Instance().AssignBKCLComm(
+            ctx.comm_, nranks, rank, dev_id, ring_id);
+      }
     }
   }
 
diff --git a/paddle/phi/backends/xpu/xpu_context.cc b/paddle/phi/backends/xpu/xpu_context.cc
index fe0dda2d3db..2735e2a4208 100644
--- a/paddle/phi/backends/xpu/xpu_context.cc
+++ b/paddle/phi/backends/xpu/xpu_context.cc
@@ -68,6 +68,12 @@ struct XPUContext::Impl {
 
   void SetStream(XPUStream stream) { context_->xpu_stream = stream; }
 
+  XPUStream stream() const {
+    auto s = context_->xpu_stream;
+    PD_CHECK(s != nullptr, "the xpu stream is nullptr.");
+    return s;
+  }
+
   xpu::Context* GetXContext() const {
     PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
     return context_;
@@ -119,6 +125,8 @@ const Place& XPUContext::GetPlace() const { return impl_->GetPlace(); }
 
 void XPUContext::SetXPUStream(XPUStream stream) { impl_->SetStream(stream); }
 
+XPUStream XPUContext::stream() const { return impl_->stream(); }
+
 backends::xpu::XPUVersion XPUContext::xpu_version() const {
   return impl_->xpu_version_;
 }
diff --git a/paddle/phi/backends/xpu/xpu_context.h b/paddle/phi/backends/xpu/xpu_context.h
index d20a1ad4e1e..90fc7c97b78 100644
--- a/paddle/phi/backends/xpu/xpu_context.h
+++ b/paddle/phi/backends/xpu/xpu_context.h
@@ -63,6 +63,8 @@ class XPUContext : public DeviceContext {
 
   void SetXPUStream(XPUStream stream);
 
+  XPUStream stream() const;
+
  private:
   struct Impl;
   std::unique_ptr<Impl> impl_;
--
GitLab
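A minimal usage sketch of the RAII grouping pattern this patch introduces (BKCLGroupGuard replacing the explicit bkcl_group_start()/bkcl_group_end() pairs with PADDLE_ENFORCE checks). It is a standalone illustration under assumed names, not Paddle code: fake_bkcl_group_start, fake_bkcl_group_end, and the local lambdas are hypothetical stand-ins for the real BKCL calls and for the per-device all-reduce closures that AllReduceOpHandle collects.

#include <functional>
#include <iostream>
#include <mutex>
#include <vector>

// Hypothetical stand-ins for bkcl_group_start()/bkcl_group_end(); the real
// calls return a BKCL status that the patch checks with
// PADDLE_ENFORCE_XPU_SUCCESS.
void fake_bkcl_group_start() { std::cout << "group start\n"; }
void fake_bkcl_group_end() { std::cout << "group end\n"; }

// Same shape as the BKCLGroupGuard added in bkcl_helper.h: serialize group
// sections with a process-wide mutex and pair start/end in ctor/dtor.
class GroupGuard {
 public:
  static std::mutex &Mutex() {
    static std::mutex mtx;
    return mtx;
  }
  GroupGuard() {
    Mutex().lock();
    fake_bkcl_group_start();
  }
  ~GroupGuard() {
    fake_bkcl_group_end();
    Mutex().unlock();
  }
};

int main() {
  // Stand-in for the per-device all-reduce calls gathered by the op handle.
  std::vector<std::function<void()>> all_reduce_calls = {
      [] { std::cout << "all_reduce on dev 0\n"; },
      [] { std::cout << "all_reduce on dev 1\n"; },
  };

  {
    GroupGuard guard;  // replaces explicit group_start/group_end bookkeeping
    for (auto &call : all_reduce_calls) {
      call();
    }
  }  // group_end and the unlock run here, on every path out of the scope
  return 0;
}

The design point the sketch shows: with the guard, the group is closed and the mutex released whenever the scope exits, so the caller no longer has to repeat the start/end/error-check boilerplate around each batched set of collective calls.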