Unverified commit 41bdf41d, authored by zhangxiaoci, committed by GitHub

change to async mode for xpu multi-card training in static graph mode, test=kunlun (#45024)

* change to async mode for xpu multi-card training in static graph mode

* minor bugfix

* irrelevant. move to another pr

* move change to other pr

* fix stream issue

* fix 'stream not meet with current context' error

* fix branch diverge, test=kunlun
Parent 041ef22c
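At a high level, the diff below switches the XPU all-reduce path to asynchronous execution: bkcl_all_reduce is launched on each device's own comm stream instead of NULL, grouped calls are bracketed by an RAII BKCLGroupGuard, and a new SyncBKCLAllReduce step synchronizes every device's stream afterwards. As a rough, standalone illustration of that enqueue-then-sync shape only (none of the names below are Paddle or BKCL APIs):

#include <future>
#include <iostream>
#include <vector>

// Stand-in for one device's asynchronously launched all-reduce kernel.
void fake_all_reduce(int dev_id) {
  std::cout << "device " << dev_id << ": all-reduce enqueued\n";
}

int main() {
  // Launch phase: one asynchronous "kernel" per device; nothing is awaited yet.
  std::vector<std::future<void>> streams;
  for (int dev = 0; dev < 4; ++dev) {
    streams.push_back(std::async(std::launch::async, fake_all_reduce, dev));
  }

  // Sync phase: the counterpart of SyncBKCLAllReduce, which walks places_
  // and calls platform::XPUStreamSync on each device's comm stream.
  for (auto& s : streams) {
    s.get();
  }
  std::cout << "all device streams synchronized\n";
  return 0;
}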
@@ -276,19 +276,26 @@ void AllReduceOpHandle::BKCLAllReduceFunc(
if (all_reduce_calls.size() == 1UL) {
all_reduce_calls[0]();
} else {
PADDLE_ENFORCE_EQ(
bkcl_group_start(),
BKCL_SUCCESS,
platform::errors::PreconditionNotMet("bkcl_group_start failed"));
platform::BKCLGroupGuard guard;
for (auto &call : all_reduce_calls) {
call();
}
PADDLE_ENFORCE_EQ(
bkcl_group_end(),
BKCL_SUCCESS,
platform::errors::PreconditionNotMet("bkcl_group_end failed"));
}
});
SyncBKCLAllReduce();
}
void AllReduceOpHandle::SyncBKCLAllReduce() {
// bkcl always use async kernel
for (auto &p : places_) {
int dev_id = p.device;
platform::SetXPUDeviceId(dev_id);
auto *bkcl_ctxs =
bkcl_ctxs_->GetRunEnvBKCLCtx(run_order_, use_hierarchical_allreduce_);
auto &bkcl_ctx = bkcl_ctxs->at(dev_id);
platform::XPUStreamSync(bkcl_ctx.stream());
}
}
#endif
......
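The hunk above replaces the hand-written bkcl_group_start()/bkcl_group_end() pair with an RAII guard, so the group is always closed and the static mutex released even if one of the queued calls throws. A standalone sketch of that guard technique, with plain stand-in functions instead of the real BKCL group calls:

#include <iostream>
#include <mutex>

// Stand-ins for bkcl_group_start()/bkcl_group_end(); only the RAII shape
// and the mutex-based serialization are the point here.
void fake_group_start() { std::cout << "group start\n"; }
void fake_group_end() { std::cout << "group end\n"; }

class GroupGuard {
 public:
  GroupGuard() : lock_(Mutex()) { fake_group_start(); }
  ~GroupGuard() { fake_group_end(); }  // runs on any scope exit

 private:
  static std::mutex& Mutex() {
    static std::mutex m;  // one process-wide lock, like BKCLGroupGuard::BKCLMutex
    return m;
  }
  std::lock_guard<std::mutex> lock_;
};

int main() {
  {
    GroupGuard guard;  // start the group and take the lock
    std::cout << "enqueue per-device all-reduce calls here\n";
  }  // group ended and lock released automatically
  return 0;
}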
@@ -94,6 +94,8 @@ class AllReduceOpHandle : public OpHandleBase {
#if defined(PADDLE_WITH_XPU_BKCL)
void BKCLAllReduceFunc(
const std::vector<std::function<void()>> &all_reduce_calls);
void SyncBKCLAllReduce();
#endif
void AllReduceImpl(const std::vector<VarHandle *> &in_var_handles,
......
@@ -92,7 +92,9 @@ class BKCLOpHandleBase : public OpHandleBase {
"The argument run_order_ must be >= 0, but got %d.", run_order_));
auto flat_bkcl_ctxs = bkcl_ctxs_->GetFlatCtx(run_order_);
int dev_id = place.device;
platform::SetXPUDeviceId(dev_id);
auto& bkcl_ctx = flat_bkcl_ctxs->at(dev_id);
auto stream = bkcl_ctx.stream();
auto comm = bkcl_ctx.comm_;
VLOG(10) << "before all reduce buffer:" << sendbuff << ", numel:" << count
@@ -100,7 +102,7 @@ class BKCLOpHandleBase : public OpHandleBase {
<< ", place:" << place;
PADDLE_ENFORCE_EQ(
bkcl_all_reduce(comm, sendbuff, recvbuff, count, datatype, op, NULL),
bkcl_all_reduce(comm, sendbuff, recvbuff, count, datatype, op, stream),
BKCL_SUCCESS,
platform::errors::PreconditionNotMet("bckl all reduce failed"));
}
......
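In the two hunks above, the handle now binds the collective to a specific card: SetXPUDeviceId(dev_id) selects the device before its context is touched, and bkcl_all_reduce is passed that context's own stream() instead of NULL, so the collective is ordered after the compute already queued on that stream. A toy sketch of the lookup pattern (the types and functions are placeholders, not Paddle or BKCL APIs):

#include <cstdio>
#include <map>

// Placeholder for a per-device communicator context holding its own stream.
struct FakeCtx {
  int device_id;
  int stream_handle;
};

// Placeholder launch: the real code calls bkcl_all_reduce(..., stream).
void fake_all_reduce_on(const FakeCtx& ctx) {
  std::printf("all-reduce on device %d using stream %d\n",
              ctx.device_id, ctx.stream_handle);
}

int main() {
  // One context per device id, analogous to flat_bkcl_ctxs->at(dev_id).
  std::map<int, FakeCtx> ctxs = {{0, {0, 100}}, {1, {1, 101}}};

  for (const auto& kv : ctxs) {
    // Select the device first, then launch on that context's own stream
    // rather than a NULL/default stream.
    fake_all_reduce_on(kv.second);
  }
  return 0;
}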
@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/xpu/enforce_xpu.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
@@ -57,6 +58,24 @@ inline BKCLDataType ToBKCLDataType(framework::proto::VarType::Type type) {
}
}
class BKCLGroupGuard {
public:
static std::mutex &BKCLMutex() {
static std::mutex mtx;
return mtx;
}
inline BKCLGroupGuard() {
BKCLMutex().lock();
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_start());
}
inline ~BKCLGroupGuard() PADDLE_MAY_THROW {
PADDLE_ENFORCE_XPU_SUCCESS(bkcl_group_end());
BKCLMutex().unlock();
}
};
struct BKCLContext {
std::unique_ptr<platform::XPUDeviceContext> ctx_;
BKCLContext_t comm_;
@@ -65,6 +84,7 @@ struct BKCLContext {
: ctx_(new platform::XPUDeviceContext(XPUPlace(dev_id))),
comm_{nullptr} {}
XPUStream stream() const { return ctx_->stream(); }
BKCLContext_t comm() const { return comm_; }
int device_id() const { return ctx_->GetPlace().device; }
@@ -258,19 +278,33 @@ class BKCLCommunicator {
ptr->init();
VLOG(1) << "init local trainer";
flat_ctxs_.emplace_back(ptr);
return;
} else {
PADDLE_ENFORCE_EQ(bkcl_ids.size(),
1,
platform::errors::Unimplemented(
"Multi-all-reduce-ring is not support for XPU"));
for (size_t i = 0; i < bkcl_ids.size(); i++) {
auto ptr = new platform::BKCLContextMap(
places, bkcl_ids[i], trainers_num, trainer_id);
ptr->init();
VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
flat_ctxs_.emplace_back(ptr);
}
}
PADDLE_ENFORCE_EQ(bkcl_ids.size(),
1,
platform::errors::Unimplemented(
"Multi-all-reduce-ring is not support for XPU"));
for (size_t i = 0; i < bkcl_ids.size(); i++) {
auto ptr = new platform::BKCLContextMap(
places, bkcl_ids[i], trainers_num, trainer_id);
ptr->init();
VLOG(1) << "init trainer_id:" << trainer_id << ", comm no:" << i;
flat_ctxs_.emplace_back(ptr);
// as Executor have no way to use BKCLComm created by ParallelExecutor,
// we assign all flatten contexts to BKCLCommContext to fix.
int nranks = static_cast<int>(trainers_num * places.size());
int nrings = static_cast<int>(flat_ctxs_.size());
for (int ring_id = 0; ring_id < nrings; ++ring_id) {
for (size_t p = 0; p < places.size(); ++p) {
int rank = trainer_id * places.size() + p;
int dev_id = places[p].device;
platform::SetXPUDeviceId(dev_id);
auto &ctx = flat_ctxs_[ring_id]->contexts_.at(dev_id);
BKCLCommContext::Instance().AssignBKCLComm(
ctx.comm_, nranks, rank, dev_id, ring_id);
}
}
}
......
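The last part of the hunk above registers every flat context with BKCLCommContext so the Executor can reuse the communicators created by ParallelExecutor; the global rank of card p on a given trainer is trainer_id * places.size() + p, and nranks = trainers_num * places.size(). A small worked example of that arithmetic (the trainer and card counts are made up for illustration):

#include <cstdio>
#include <vector>

int main() {
  // Assumed setup for illustration: 2 trainers, 4 XPU cards per trainer.
  const int trainers_num = 2;
  const std::vector<int> places = {0, 1, 2, 3};  // local device ids

  const int nranks = trainers_num * static_cast<int>(places.size());  // 8

  for (int trainer_id = 0; trainer_id < trainers_num; ++trainer_id) {
    for (size_t p = 0; p < places.size(); ++p) {
      // Same formula as the diff: global rank = trainer_id * places.size() + p.
      int rank = trainer_id * static_cast<int>(places.size()) +
                 static_cast<int>(p);
      std::printf("trainer %d, card %zu -> rank %d of %d\n",
                  trainer_id, p, rank, nranks);
    }
  }
  return 0;
}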
@@ -68,6 +68,12 @@ struct XPUContext::Impl {
void SetStream(XPUStream stream) { context_->xpu_stream = stream; }
XPUStream stream() const {
auto s = context_->xpu_stream;
PD_CHECK(s != nullptr, "the xpu stream is nullptr.");
return s;
}
xpu::Context* GetXContext() const {
PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
return context_;
@@ -119,6 +125,8 @@ const Place& XPUContext::GetPlace() const { return impl_->GetPlace(); }
void XPUContext::SetXPUStream(XPUStream stream) { impl_->SetStream(stream); }
XPUStream XPUContext::stream() const { return impl_->stream(); }
backends::xpu::XPUVersion XPUContext::xpu_version() const {
return impl_->xpu_version_;
}
......
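The hunk above adds a null check so XPUContext::stream() fails fast with a clear message instead of returning a null stream to callers. A minimal sketch of that checked-accessor pattern in plain C++ (PD_CHECK replaced by a simple abort helper; all types here are stand-ins):

#include <cstdio>
#include <cstdlib>

// Stand-in for the real XPU stream handle type.
using FakeStream = void*;

struct FakeContext {
  FakeStream xpu_stream = nullptr;
};

// Checked accessor: abort with a clear message instead of returning nullptr,
// mirroring the PD_CHECK added to XPUContext::Impl::stream().
FakeStream checked_stream(const FakeContext& ctx) {
  if (ctx.xpu_stream == nullptr) {
    std::fprintf(stderr, "the xpu stream is nullptr.\n");
    std::abort();
  }
  return ctx.xpu_stream;
}

int main() {
  FakeContext ctx;
  int dummy = 0;
  ctx.xpu_stream = &dummy;  // pretend a real stream has been created
  std::printf("got stream %p\n", checked_stream(ctx));
  return 0;
}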
@@ -63,6 +63,8 @@ class XPUContext : public DeviceContext {
void SetXPUStream(XPUStream stream);
XPUStream stream() const;
private:
struct Impl;
std::unique_ptr<Impl> impl_;
......