From 8d40e02f8ec53b910c1a2bfdbe4ba392642c8966 Mon Sep 17 00:00:00 2001
From: ykkk2333 <77383312+ykkk2333@users.noreply.github.com>
Date: Fri, 10 Mar 2023 13:45:53 +0800
Subject: [PATCH] xpu supports multi-thread dataloader, test=kunlun (#51351)

---
Review notes (this region is not part of the commit message):

* Layout: line breaks and leading indentation were reconstructed from a
  whitespace-collapsed copy following the file's 2-space style; line counts
  agree with every hunk header below, but verify context-line indentation
  against the base blob before applying.
* stream() const dereferences the result of GetXdlCtx() without a null
  check. GetXdlCtx() returns nullptr for a thread name that is not in
  xdl_context_map_, and in this patch entries are only created by
  InitializeXdlContexts() and by the non-const GetXContext()/Wait().
* xdl_context_map_ is modified by SetXdlCtx(), which the non-const
  GetXContext()/Wait() call when the current thread is a Dataloader thread.
  No lock is visible in this patch -- confirm that callers serialize access.
* In ~Impl(), `ctx = nullptr;` assigns to the loop's copy of the pointer and
  has no effect on xdl_context_map_.

 paddle/phi/backends/xpu/xpu_context.cc | 108 ++++++++++++++++++++++++-
 paddle/phi/backends/xpu/xpu_context.h  |   4 +
 2 files changed, 110 insertions(+), 2 deletions(-)

diff --git a/paddle/phi/backends/xpu/xpu_context.cc b/paddle/phi/backends/xpu/xpu_context.cc
index b0261b6f6f6..eec65d0d6a3 100644
--- a/paddle/phi/backends/xpu/xpu_context.cc
+++ b/paddle/phi/backends/xpu/xpu_context.cc
@@ -19,6 +19,7 @@
 #include "paddle/phi/api/ext/exception.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/common/place.h"
+#include "paddle/phi/core/os_info.h"
 #include "xpu/runtime.h"
 #include "xpu/runtime_ex.h"
 #include "xpu/xdnn.h"
@@ -54,6 +55,15 @@ struct XPUContext::Impl {
     }
   }
 
+  bool IsDataloader() const {
+    if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") == nullptr) {
+      return false;
+    }
+    std::string cur_thread_name = phi::GetCurrentThreadName();
+    bool is_dataloader_thread = (cur_thread_name.substr(0, 10) == "Dataloader");
+    return is_dataloader_thread;
+  }
+
   Impl() : place_(XPUPlace()) {}
 
   explicit Impl(const Place& place) : place_(place) {}
@@ -71,11 +81,31 @@ struct XPUContext::Impl {
       xpu::destroy_context(context_);
       context_ = nullptr;
     }
+    if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") != nullptr) {
+      // destroy all XPU Dataloader threads if exist
+      backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
+      for (auto ctx : GetAllXdlCtxs()) {
+        xpu_wait(ctx->xpu_stream);
+        if (ctx->xpu_stream) {
+          xpu_stream_destroy(ctx->xpu_stream);
+          ctx->xpu_stream = nullptr;
+        }
+        xpu::destroy_context(ctx);
+        ctx = nullptr;
+      }
+    }
   }
 
   const Place& GetPlace() const { return place_; }
 
-  XPUStream stream() const { return context_->xpu_stream; }
+  XPUStream stream() const {
+    if (IsDataloader()) {
+      std::string cur_thread_name = phi::GetCurrentThreadName();
+      xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
+      return ctx_t->xpu_stream;
+    }
+    return context_->xpu_stream;
+  }
 
   xpu::Context* GetXContext() const {
     PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
@@ -87,18 +117,53 @@ struct XPUContext::Impl {
     return bkcl_context_;
   }
 
+  // Overload GetXContext function to set and get
+  // contexts of XPU Dataloader threads, and keep old GetXContext Method
+  xpu::Context* GetXContext() {
+    std::string cur_thread_name = phi::GetCurrentThreadName();
+    VLOG(3) << "XPU Dataloader: current thread at Get Context = "
+            << phi::GetCurrentThreadName();
+    if (IsDataloader()) {
+      SetXdlCtx(cur_thread_name);
+      xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
+      PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
+      return ctx_t;
+    }
+
+    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
+    return context_;
+  }
+
   void Wait() const {
     backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
     PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
     xpu_wait(context_->xpu_stream);
   }
 
+  // Overload Wait for xpu wait on XPU Dataloader threads streams
+  void Wait() {
+    if (IsDataloader()) {
+      std::string cur_thread_name = phi::GetCurrentThreadName();
+      SetXdlCtx(cur_thread_name);
+      xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
+      PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
+      xpu_wait(GetXdlCtx(cur_thread_name)->xpu_stream);
+    }
+    backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
+    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
+    xpu_wait(context_->xpu_stream);
+  }
+
   void Init() {
     owned_ = true;
     backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
     LOG_FIRST_N(WARNING, 1) << "Please NOTE: xpu device: "
                             << static_cast<int>(place_.device);
     context_ = xpu::create_context();
+    if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") != nullptr) {
+      // Initialize XPU Dataloader threads contexts map
+      InitializeXdlContexts();
+    }
     xpu_version_ = backends::xpu::get_xpu_version(place_.device);
     SetL3Cache();
   }
@@ -115,10 +180,48 @@ struct XPUContext::Impl {
     PADDLE_ENFORCE_XPU_SUCCESS(xpu_stream_create(&context_->xpu_stream));
   }
 
+  // Methods of XPU Dataloader threads contexts map,
+  // currently, need set 'export XPU_PADDLE_XDL_CONTEXTS=1'
+  // to open XPU Dataloader context map
+  void InitializeXdlContexts() {
+    if (std::getenv("XPU_PADDLE_XDL_CONTEXTS") == nullptr) {
+      return;
+    }
+    auto thread_map = phi::GetAllThreadNames();
+    for (const auto& tp : thread_map) {
+      std::string t_name = tp.second;
+      if (t_name.substr(0, 10) == "Dataloader") {
+        SetXdlCtx(t_name);
+      }
+    }
+  }
+
+  void SetXdlCtx(std::string thread_name) {
+    if (xdl_context_map_.find(thread_name) == xdl_context_map_.end()) {
+      xpu::Context* ctx_t = xpu::create_context();
+      xdl_context_map_[thread_name] = ctx_t;
+    }
+  }
+
+  xpu::Context* GetXdlCtx(const std::string thread_name) const {
+    return (xdl_context_map_.find(thread_name) == xdl_context_map_.end())
+               ? nullptr
+               : xdl_context_map_.find(thread_name)->second;
+  }
+
+  std::vector<xpu::Context*> GetAllXdlCtxs() {
+    std::vector<xpu::Context*> ctxs;
+    for (const auto& it : xdl_context_map_) {
+      ctxs.emplace_back(it.second);
+    }
+    return ctxs;
+  }
+
   bool owned_{false};
   Place place_;
   backends::xpu::XPUVersion xpu_version_;
   xpu::Context* context_{nullptr};
+  std::unordered_map<std::string, xpu::Context*> xdl_context_map_;
 
   // NOTE: Distributed communicator, distributed framework manages its
   // resources, XPUContext only holds references.
@@ -158,6 +261,8 @@ void XPUContext::SetXContext(xpu::Context* context) {
 
 void XPUContext::SetL3Cache(int l3_size) { impl_->SetL3Cache(l3_size); }
 
+bool XPUContext::IsDataloader() const { return impl_->IsDataloader(); }
+
 void XPUContext::SetBkclContext(xpu::BKCLContext_t context) {
   impl_->SetBkclContext(context);
 }
@@ -165,5 +270,4 @@ void XPUContext::SetBkclContext(xpu::BKCLContext_t context) {
 void XPUContext::CreateStream() { impl_->CreateStream(); }
 
 void XPUContext::Init() { impl_->Init(); }
-
 }  // namespace phi
diff --git a/paddle/phi/backends/xpu/xpu_context.h b/paddle/phi/backends/xpu/xpu_context.h
index 79b8dc2f04b..349118f3336 100644
--- a/paddle/phi/backends/xpu/xpu_context.h
+++ b/paddle/phi/backends/xpu/xpu_context.h
@@ -47,6 +47,10 @@ class XPUContext : public DeviceContext,
 
   xpu::Context* x_context() const;
 
+  // For multi-thread dataloader,
+  // check if the current thread is Dataloader thread
+  bool IsDataloader() const;
+
   // Return bkcl context.
   xpu::BKCLContext_t bkcl_context() const;
   void SetBkclContext(xpu::BKCLContext_t context);
-- 
GitLab