diff --git a/paddle/phi/backends/xpu/xpu_context.cc b/paddle/phi/backends/xpu/xpu_context.cc index eec65d0d6a3eb5efc6424257e4127c16c22e16d3..acb8ae8db3b3a3d1f28e9a43b74e3557d08479ca 100644 --- a/paddle/phi/backends/xpu/xpu_context.cc +++ b/paddle/phi/backends/xpu/xpu_context.cc @@ -60,6 +60,8 @@ struct XPUContext::Impl { return false; } std::string cur_thread_name = phi::GetCurrentThreadName(); + VLOG(3) << "XPU Dataloader: current thread at Get Context = " + << phi::GetCurrentThreadName(); bool is_dataloader_thread = (cur_thread_name.substr(0, 10) == "Dataloader"); return is_dataloader_thread; } @@ -93,6 +95,7 @@ struct XPUContext::Impl { xpu::destroy_context(ctx); ctx = nullptr; } + xdl_context_map_.clear(); } } @@ -100,8 +103,7 @@ struct XPUContext::Impl { XPUStream stream() const { if (IsDataloader()) { - std::string cur_thread_name = phi::GetCurrentThreadName(); - xpu::Context* ctx_t = GetXdlCtx(cur_thread_name); + xpu::Context* ctx_t = GetXdlCtx(); return ctx_t->xpu_stream; } return context_->xpu_stream; @@ -120,12 +122,9 @@ struct XPUContext::Impl { // Overload GetXContext function to set and get // contexts of XPU Dataloader threads, and keep old GetXContext Method xpu::Context* GetXContext() { - std::string cur_thread_name = phi::GetCurrentThreadName(); - VLOG(3) << "XPU Dataloader: current thread at Get Context = " - << phi::GetCurrentThreadName(); if (IsDataloader()) { - SetXdlCtx(cur_thread_name); - xpu::Context* ctx_t = GetXdlCtx(cur_thread_name); + SetXdlCtx(); + xpu::Context* ctx_t = GetXdlCtx(); PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr."); return ctx_t; } @@ -135,20 +134,15 @@ struct XPUContext::Impl { } void Wait() const { - backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId()); - PD_CHECK(context_ != nullptr, "the xpu context is nullptr."); - xpu_wait(context_->xpu_stream); - } - - // Overload Wait for xpu wait on XPU Dataloader threads streams - void Wait() { if (IsDataloader()) { - std::string cur_thread_name = phi::GetCurrentThreadName(); - SetXdlCtx(cur_thread_name); - xpu::Context* ctx_t = GetXdlCtx(cur_thread_name); - PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr."); - xpu_wait(GetXdlCtx(cur_thread_name)->xpu_stream); + xpu::Context* ctx_t = GetXdlCtx(); + if (ctx_t) { + PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr."); + xpu_wait(ctx_t->xpu_stream); + } + return; } + backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId()); PD_CHECK(context_ != nullptr, "the xpu context is nullptr."); xpu_wait(context_->xpu_stream); @@ -191,22 +185,24 @@ struct XPUContext::Impl { for (const auto& tp : thread_map) { std::string t_name = tp.second; if (t_name.substr(0, 10) == "Dataloader") { - SetXdlCtx(t_name); + SetXdlCtx(); } } } - void SetXdlCtx(std::string thread_name) { - if (xdl_context_map_.find(thread_name) == xdl_context_map_.end()) { + void SetXdlCtx() { + auto pid = phi::GetProcessId(); + if (xdl_context_map_.find(pid) == xdl_context_map_.end()) { xpu::Context* ctx_t = xpu::create_context(); - xdl_context_map_[thread_name] = ctx_t; + xdl_context_map_[pid] = ctx_t; } } - xpu::Context* GetXdlCtx(const std::string thread_name) const { - return (xdl_context_map_.find(thread_name) == xdl_context_map_.end()) + xpu::Context* GetXdlCtx() const { + auto pid = phi::GetProcessId(); + return (xdl_context_map_.find(pid) == xdl_context_map_.end()) ? nullptr - : xdl_context_map_.find(thread_name)->second; + : xdl_context_map_.find(pid)->second; } std::vector GetAllXdlCtxs() { @@ -221,7 +217,7 @@ struct XPUContext::Impl { Place place_; backends::xpu::XPUVersion xpu_version_; xpu::Context* context_{nullptr}; - std::unordered_map xdl_context_map_; + std::unordered_map xdl_context_map_; // NOTE: Distributed communicator, distributed framework manages its // resources, XPUContext only holds references. diff --git a/paddle/phi/kernels/xpu/pool_grad_kernel.cc b/paddle/phi/kernels/xpu/pool_grad_kernel.cc index 6f937b93e1976248288ff0da9bb967cf1596adee..dfea57231560ad370b3c90e1198cfe698bbfc118 100644 --- a/paddle/phi/kernels/xpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/xpu/pool_grad_kernel.cc @@ -340,22 +340,41 @@ void Pool3dGradKernel(const Context& ctx, PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool3d_grad"); } else { if (pooling_type == "max") { - r = xpu::max_pool3d_grad( - ctx.x_context(), - reinterpret_cast(x.data()), - reinterpret_cast(out.data()), - index_data, - reinterpret_cast(dout.data()), - reinterpret_cast(dx->data()), - n, - c, - in_d, - in_h, - in_w, - kernel_size, - strides, - paddings, - !channel_last); + if (kernel_size[0] == 1 && kernel_size.size() == 3 && + strides.size() == 3 && paddings.size() == 6) { + r = xpu::max_pool2d_grad( + ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + index_data, + reinterpret_cast(dout.data()), + reinterpret_cast(dx->data()), + n, + c * in_d, + in_h, + in_w, + {kernel_size[1], kernel_size[2]}, + {strides[1], strides[2]}, + {paddings[2], paddings[3], paddings[4], paddings[5]}, + !channel_last); + } else { + r = xpu::max_pool3d_grad( + ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out.data()), + index_data, + reinterpret_cast(dout.data()), + reinterpret_cast(dx->data()), + n, + c, + in_d, + in_h, + in_w, + kernel_size, + strides, + paddings, + !channel_last); + } } else if (pooling_type == "avg") { r = xpu::avg_pool3d_grad( ctx.x_context(),