Unverified commit 4851c642, authored by ykkk2333, committed by GitHub

[XPU] add pool3dgrad special dim support (#51727)

* add xpu tile and concat kernel int64, test=kunlun

* fix previous xpu dataloader bug, and add maxpool3dgrad special dim support, test=kunlun
Parent 711f0c9b
@@ -60,6 +60,8 @@ struct XPUContext::Impl {
       return false;
     }
     std::string cur_thread_name = phi::GetCurrentThreadName();
+    VLOG(3) << "XPU Dataloader: current thread at Get Context = "
+            << phi::GetCurrentThreadName();
     bool is_dataloader_thread = (cur_thread_name.substr(0, 10) == "Dataloader");
     return is_dataloader_thread;
   }
@@ -93,6 +95,7 @@ struct XPUContext::Impl {
         xpu::destroy_context(ctx);
         ctx = nullptr;
       }
+      xdl_context_map_.clear();
     }
   }
@@ -100,8 +103,7 @@ struct XPUContext::Impl {
   XPUStream stream() const {
     if (IsDataloader()) {
-      std::string cur_thread_name = phi::GetCurrentThreadName();
-      xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
+      xpu::Context* ctx_t = GetXdlCtx();
       return ctx_t->xpu_stream;
     }
     return context_->xpu_stream;
@@ -120,12 +122,9 @@ struct XPUContext::Impl {
   // Overload GetXContext function to set and get
   // contexts of XPU Dataloader threads, and keep old GetXContext Method
   xpu::Context* GetXContext() {
-    std::string cur_thread_name = phi::GetCurrentThreadName();
-    VLOG(3) << "XPU Dataloader: current thread at Get Context = "
-            << phi::GetCurrentThreadName();
     if (IsDataloader()) {
-      SetXdlCtx(cur_thread_name);
-      xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
+      SetXdlCtx();
+      xpu::Context* ctx_t = GetXdlCtx();
       PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
       return ctx_t;
     }
@@ -135,20 +134,15 @@ struct XPUContext::Impl {
   }
   void Wait() const {
-    backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
-    PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
-    xpu_wait(context_->xpu_stream);
-  }
-  // Overload Wait for xpu wait on XPU Dataloader threads streams
-  void Wait() {
     if (IsDataloader()) {
-      std::string cur_thread_name = phi::GetCurrentThreadName();
-      SetXdlCtx(cur_thread_name);
-      xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
-      PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
-      xpu_wait(GetXdlCtx(cur_thread_name)->xpu_stream);
+      xpu::Context* ctx_t = GetXdlCtx();
+      if (ctx_t) {
+        PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
+        xpu_wait(ctx_t->xpu_stream);
+      }
       return;
     }
     backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
     PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
     xpu_wait(context_->xpu_stream);
@@ -191,22 +185,24 @@ struct XPUContext::Impl {
     for (const auto& tp : thread_map) {
       std::string t_name = tp.second;
       if (t_name.substr(0, 10) == "Dataloader") {
-        SetXdlCtx(t_name);
+        SetXdlCtx();
       }
     }
   }
-  void SetXdlCtx(std::string thread_name) {
-    if (xdl_context_map_.find(thread_name) == xdl_context_map_.end()) {
+  void SetXdlCtx() {
+    auto pid = phi::GetProcessId();
+    if (xdl_context_map_.find(pid) == xdl_context_map_.end()) {
       xpu::Context* ctx_t = xpu::create_context();
-      xdl_context_map_[thread_name] = ctx_t;
+      xdl_context_map_[pid] = ctx_t;
     }
   }
-  xpu::Context* GetXdlCtx(const std::string thread_name) const {
-    return (xdl_context_map_.find(thread_name) == xdl_context_map_.end())
+  xpu::Context* GetXdlCtx() const {
+    auto pid = phi::GetProcessId();
+    return (xdl_context_map_.find(pid) == xdl_context_map_.end())
                ? nullptr
-               : xdl_context_map_.find(thread_name)->second;
+               : xdl_context_map_.find(pid)->second;
   }
   std::vector<xpu::Context*> GetAllXdlCtxs() {
@@ -221,7 +217,7 @@ struct XPUContext::Impl {
   Place place_;
   backends::xpu::XPUVersion xpu_version_;
   xpu::Context* context_{nullptr};
-  std::unordered_map<std::string, xpu::Context*> xdl_context_map_;
+  std::unordered_map<uint32_t, xpu::Context*> xdl_context_map_;
   // NOTE: Distributed communicator, distributed framework manages its
   // resources, XPUContext only holds references.
......
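For orientation only, here is a minimal standalone sketch of the context-map pattern used above: dataloader XPU contexts are now cached per process id instead of per thread name. The Context struct, create_context(), and GetProcessId() below are hypothetical stand-ins (GetProcessId is emulated with POSIX getpid()), not the real xpu/phi APIs, so the snippet compiles on its own and only illustrates the lookup logic.

// Hedged sketch: placeholder types stand in for xpu::Context and phi APIs.
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <unistd.h>  // getpid(), a stand-in for phi::GetProcessId()

struct Context {};  // placeholder for xpu::Context
static Context* create_context() { return new Context(); }  // placeholder for xpu::create_context
static uint32_t GetProcessId() { return static_cast<uint32_t>(::getpid()); }

class Impl {
 public:
  // Create the dataloader context for the current process if it is missing.
  void SetXdlCtx() {
    auto pid = GetProcessId();
    if (xdl_context_map_.find(pid) == xdl_context_map_.end()) {
      xdl_context_map_[pid] = create_context();
    }
  }
  // Look up the dataloader context for the current process; nullptr if absent.
  Context* GetXdlCtx() const {
    auto pid = GetProcessId();
    auto it = xdl_context_map_.find(pid);
    return it == xdl_context_map_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<uint32_t, Context*> xdl_context_map_;
};

int main() {
  Impl impl;
  std::cout << (impl.GetXdlCtx() == nullptr) << "\n";  // 1: nothing cached yet
  impl.SetXdlCtx();
  std::cout << (impl.GetXdlCtx() != nullptr) << "\n";  // 1: one context per process id
  return 0;
}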
@@ -340,22 +340,41 @@ void Pool3dGradKernel(const Context& ctx,
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool3d_grad");
     } else {
       if (pooling_type == "max") {
-        r = xpu::max_pool3d_grad<XPUType>(
-            ctx.x_context(),
-            reinterpret_cast<const XPUType*>(x.data<T>()),
-            reinterpret_cast<const XPUType*>(out.data<T>()),
-            index_data,
-            reinterpret_cast<const XPUType*>(dout.data<T>()),
-            reinterpret_cast<XPUType*>(dx->data<T>()),
-            n,
-            c,
-            in_d,
-            in_h,
-            in_w,
-            kernel_size,
-            strides,
-            paddings,
-            !channel_last);
+        if (kernel_size[0] == 1 && kernel_size.size() == 3 &&
+            strides.size() == 3 && paddings.size() == 6) {
+          r = xpu::max_pool2d_grad<XPUType>(
+              ctx.x_context(),
+              reinterpret_cast<const XPUType*>(x.data<T>()),
+              reinterpret_cast<const XPUType*>(out.data<T>()),
+              index_data,
+              reinterpret_cast<const XPUType*>(dout.data<T>()),
+              reinterpret_cast<XPUType*>(dx->data<T>()),
+              n,
+              c * in_d,
+              in_h,
+              in_w,
+              {kernel_size[1], kernel_size[2]},
+              {strides[1], strides[2]},
+              {paddings[2], paddings[3], paddings[4], paddings[5]},
+              !channel_last);
+        } else {
+          r = xpu::max_pool3d_grad<XPUType>(
+              ctx.x_context(),
+              reinterpret_cast<const XPUType*>(x.data<T>()),
+              reinterpret_cast<const XPUType*>(out.data<T>()),
+              index_data,
+              reinterpret_cast<const XPUType*>(dout.data<T>()),
+              reinterpret_cast<XPUType*>(dx->data<T>()),
+              n,
+              c,
+              in_d,
+              in_h,
+              in_w,
+              kernel_size,
+              strides,
+              paddings,
+              !channel_last);
+        }
       } else if (pooling_type == "avg") {
         r = xpu::avg_pool3d_grad<XPUType>(
             ctx.x_context(),
......
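For intuition, here is a rough CPU check (not the XPU kernel) of the shape trick behind the new max-pool branch: when the pooling kernel depth is 1 and, as additionally assumed in this sketch, the depth stride is 1 with no depth padding, a 3-D max pool over an NCDHW tensor produces exactly the same buffer as a 2-D max pool over the same data viewed as N x (C*D) x H x W, which is why the call can be forwarded to a 2-D kernel with c * in_d channels. The naive pool3d/pool2d helpers below are illustrative stand-ins, not Paddle or XDNN APIs.

// Hedged sketch: naive reference pools, not the XPU/XDNN kernels.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <vector>

// Naive max pool over an N x C x H x W buffer (no padding).
std::vector<float> pool2d(const std::vector<float>& x, int n, int c, int h,
                          int w, int kh, int kw, int sh, int sw) {
  int oh = (h - kh) / sh + 1, ow = (w - kw) / sw + 1;
  std::vector<float> y(static_cast<size_t>(n) * c * oh * ow, -1e30f);
  for (int i = 0; i < n * c; ++i)  // fused batch*channel planes
    for (int p = 0; p < oh; ++p)
      for (int q = 0; q < ow; ++q)
        for (int a = 0; a < kh; ++a)
          for (int b = 0; b < kw; ++b) {
            float v = x[(static_cast<size_t>(i) * h + p * sh + a) * w + q * sw + b];
            float& o = y[(static_cast<size_t>(i) * oh + p) * ow + q];
            o = std::max(o, v);
          }
  return y;
}

// Naive max pool over an N x C x D x H x W buffer (no padding).
std::vector<float> pool3d(const std::vector<float>& x, int n, int c, int d,
                          int h, int w, int kd, int kh, int kw, int sd, int sh,
                          int sw) {
  int od = (d - kd) / sd + 1, oh = (h - kh) / sh + 1, ow = (w - kw) / sw + 1;
  std::vector<float> y(static_cast<size_t>(n) * c * od * oh * ow, -1e30f);
  for (int i = 0; i < n * c; ++i)
    for (int z = 0; z < od; ++z)
      for (int p = 0; p < oh; ++p)
        for (int q = 0; q < ow; ++q)
          for (int e = 0; e < kd; ++e)
            for (int a = 0; a < kh; ++a)
              for (int b = 0; b < kw; ++b) {
                float v = x[((static_cast<size_t>(i) * d + z * sd + e) * h +
                             p * sh + a) * w + q * sw + b];
                float& o = y[((static_cast<size_t>(i) * od + z) * oh + p) * ow + q];
                o = std::max(o, v);
              }
  return y;
}

int main() {
  int n = 2, c = 3, d = 4, h = 6, w = 6;
  std::vector<float> x(static_cast<size_t>(n) * c * d * h * w);
  for (auto& v : x) v = static_cast<float>(std::rand()) / RAND_MAX;
  // Kernel depth 1, depth stride 1, no depth padding: depth is never reduced.
  auto y3 = pool3d(x, n, c, d, h, w, /*kd=*/1, 2, 2, /*sd=*/1, 2, 2);
  // The same buffer viewed as N x (C*D) x H x W and pooled in 2-D.
  auto y2 = pool2d(x, n, c * d, h, w, 2, 2, 2, 2);
  assert(y3 == y2);  // identical layout and values
  std::cout << "pool3d(kd=1) matches pool2d over c*d channels\n";
  return 0;
}

Because the forward outputs and their argmax positions coincide under this reshaping, the same collapse carries over to the max-pool gradient, which is what the new max_pool2d_grad branch exploits.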