未验证 提交 8d40e02f 编写于 作者: Y ykkk2333 提交者: GitHub

xpu supports multi-thread dataloader, test=kunlun (#51351)

上级 1cffb1ff
......@@ -19,6 +19,7 @@
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/os_info.h"
#include "xpu/runtime.h"
#include "xpu/runtime_ex.h"
#include "xpu/xdnn.h"
......@@ -54,6 +55,15 @@ struct XPUContext::Impl {
}
}
// Reports whether the calling thread should use a per-thread dataloader
// context: the XPU_PADDLE_XDL_CONTEXTS environment variable must be set and
// the current thread's name must begin with "Dataloader".
bool IsDataloader() const {
  if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") != nullptr) {
    const std::string thread_name = phi::GetCurrentThreadName();
    // Prefix test on the first 10 characters, i.e. the length of "Dataloader".
    return thread_name.compare(0, 10, "Dataloader") == 0;
  }
  return false;
}
// Default constructor: binds this Impl to a default-constructed XPUPlace.
Impl() : place_(XPUPlace()) {}
// Binds this Impl to the given place. No XPU context is created here;
// context_ keeps its nullptr default.
explicit Impl(const Place& place) : place_(place) {}
......@@ -71,11 +81,31 @@ struct XPUContext::Impl {
xpu::destroy_context(context_);
context_ = nullptr;
}
if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") != nullptr) {
// destroy all XPU Dataloader threads if exist
backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
for (auto ctx : GetAllXdlCtxs()) {
xpu_wait(ctx->xpu_stream);
if (ctx->xpu_stream) {
xpu_stream_destroy(ctx->xpu_stream);
ctx->xpu_stream = nullptr;
}
xpu::destroy_context(ctx);
ctx = nullptr;
}
}
}
// Returns a const reference to the place this Impl was constructed with.
const Place& GetPlace() const { return place_; }
// Returns the stream stored on the main context_. context_ is dereferenced
// without a null check.
XPUStream stream() const { return context_->xpu_stream; }
// Returns the stream the calling thread should use.
//
// In a Dataloader thread (see IsDataloader) this is the stream of the
// context registered for the current thread name; otherwise it is the stream
// of the main context_.
//
// Raises (via PD_CHECK) when called from a Dataloader thread that has no
// registered context.
XPUStream stream() const {
  if (IsDataloader()) {
    const std::string cur_thread_name = phi::GetCurrentThreadName();
    xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
    // This const accessor cannot create the per-thread context, and GetXdlCtx
    // returns nullptr for an unregistered thread name. Fail with a clear
    // message instead of dereferencing a null pointer.
    PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
    return ctx_t->xpu_stream;
  }
  return context_->xpu_stream;
}
xpu::Context* GetXContext() const {
PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
......@@ -87,18 +117,53 @@ struct XPUContext::Impl {
return bkcl_context_;
}
// Overload GetXContext function to set and get
// contexts of XPU Dataloader threads, and keep old GetXContext Method
xpu::Context* GetXContext() {
  const std::string thread_name = phi::GetCurrentThreadName();
  VLOG(3) << "XPU Dataloader: current thread at Get Context = "
          << phi::GetCurrentThreadName();

  // Common path: any thread that is not a Dataloader thread uses the main
  // context.
  if (!IsDataloader()) {
    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
    return context_;
  }

  // Dataloader thread: register a context for this thread name if none
  // exists yet, then return it.
  SetXdlCtx(thread_name);
  xpu::Context* const dataloader_ctx = GetXdlCtx(thread_name);
  PD_CHECK(dataloader_ctx != nullptr, "the xpu dataloader context is nullptr.");
  return dataloader_ctx;
}
// Const overload: waits on the main context's stream only.
void Wait() const {
// Switch to this context's device for the duration of the call.
backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
xpu_wait(context_->xpu_stream);
}
// Overload Wait for xpu wait on XPU Dataloader threads streams
void Wait() {
  // Take the device guard before any runtime call. Init() and the destructor
  // both select the device before creating contexts or waiting on dataloader
  // streams; the dataloader branch below creates a context (SetXdlCtx) and
  // calls xpu_wait, so it needs the guard as well.
  backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());

  if (IsDataloader()) {
    const std::string cur_thread_name = phi::GetCurrentThreadName();
    // Register a context for this thread name if none exists yet.
    SetXdlCtx(cur_thread_name);
    xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
    PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
    // Wait on the context that was just checked rather than looking it up
    // a second time.
    xpu_wait(ctx_t->xpu_stream);
  }

  // NOTE(review): a Dataloader thread falls through and also waits on the
  // main context's stream, as in the original code - confirm this is intended.
  PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
  xpu_wait(context_->xpu_stream);
}
// Creates the main XPU context for place_ and marks it as owned by this Impl.
// When XPU_PADDLE_XDL_CONTEXTS is set, also creates contexts for the
// Dataloader threads that are already known by name.
void Init() {
owned_ = true;
// Select the device before creating any context on it.
backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
LOG_FIRST_N(WARNING, 1)
<< "Please NOTE: xpu device: " << static_cast<int>(place_.device);
context_ = xpu::create_context();
if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") != nullptr) {
// Initialize XPU Dataloader threads contexts map
InitializeXdlContexts();
}
xpu_version_ = backends::xpu::get_xpu_version(place_.device);
SetL3Cache();
}
......@@ -115,10 +180,48 @@ struct XPUContext::Impl {
PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&context_->xpu_stream));
}
// Methods for the map of XPU Dataloader thread contexts.
// Currently the map is only enabled when the environment variable is set,
// e.g. 'export XPU_PADDLE_XDL_CONTEXTS=1'.
void InitializeXdlContexts() {
  // Nothing to do unless the feature is switched on through the environment.
  const bool xdl_enabled = std::getenv("XPU_PADDLE_XDL_CONTEXTS") != nullptr;
  if (!xdl_enabled) {
    return;
  }

  // Create a context for every known thread whose name starts with
  // "Dataloader" (10 characters).
  const auto known_threads = phi::GetAllThreadNames();
  for (const auto& entry : known_threads) {
    const std::string name = entry.second;
    if (name.compare(0, 10, "Dataloader") == 0) {
      SetXdlCtx(name);
    }
  }
}
void SetXdlCtx(std::string thread_name) {
if (xdl_context_map_.find(thread_name) == xdl_context_map_.end()) {
xpu::Context* ctx_t = xpu::create_context();
xdl_context_map_[thread_name] = ctx_t;
}
}
// Looks up the context registered for `thread_name`.
// Returns nullptr when the name is not in xdl_context_map_.
xpu::Context* GetXdlCtx(const std::string thread_name) const {
  const auto found = xdl_context_map_.find(thread_name);
  if (found == xdl_context_map_.end()) {
    return nullptr;
  }
  return found->second;
}
std::vector<xpu::Context*> GetAllXdlCtxs() {
std::vector<xpu::Context*> ctxs;
for (const auto& it : xdl_context_map_) {
ctxs.emplace_back(it.second);
}
return ctxs;
}
// Set to true by Init(), which creates context_ itself.
bool owned_{false};
// Place this Impl is bound to; its device id is used for the device guards.
Place place_;
// Assigned in Init() from backends::xpu::get_xpu_version.
backends::xpu::XPUVersion xpu_version_;
// Main XPU context; nullptr until Init() creates it.
xpu::Context* context_{nullptr};
// Per-thread dataloader contexts, keyed by thread name (see SetXdlCtx).
// NOTE(review): no synchronization is visible around this map although it is
// accessed from Dataloader threads - confirm callers serialize access.
std::unordered_map<std::string, xpu::Context*> xdl_context_map_;
// NOTE: Distributed communicator, distributed framework manages its
// resources, XPUContext only holds references.
......@@ -158,6 +261,8 @@ void XPUContext::SetXContext(xpu::Context* context) {
// Forwards to Impl::SetL3Cache with the requested size.
void XPUContext::SetL3Cache(int l3_size) { impl_->SetL3Cache(l3_size); }
// Forwards to Impl::IsDataloader for the calling thread.
bool XPUContext::IsDataloader() const { return impl_->IsDataloader(); }
// Forwards the given BKCL context to Impl::SetBkclContext.
void XPUContext::SetBkclContext(xpu::BKCLContext_t context) {
impl_->SetBkclContext(context);
}
......@@ -165,5 +270,4 @@ void XPUContext::SetBkclContext(xpu::BKCLContext_t context) {
// Forwards to Impl::CreateStream.
void XPUContext::CreateStream() { impl_->CreateStream(); }
// Forwards to Impl::Init, which creates the main XPU context.
void XPUContext::Init() { impl_->Init(); }
} // namespace phi
......@@ -47,6 +47,10 @@ class XPUContext : public DeviceContext,
xpu::Context* x_context() const;
// For the multi-thread dataloader:
// checks whether the current thread is a Dataloader thread.
bool IsDataloader() const;
// Return bkcl context.
xpu::BKCLContext_t bkcl_context() const;
void SetBkclContext(xpu::BKCLContext_t context);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册