未验证 提交 4851c642 编写于 作者: Y ykkk2333 提交者: GitHub

[XPU] add pool3dgrad special dim support (#51727)

* add xpu tile and concat kernel int64, test=kunlun

* fix previous xpu dataoader bug, and add maxpool3dgrad special dim support, test=kunlun
上级 711f0c9b
......@@ -60,6 +60,8 @@ struct XPUContext::Impl {
return false;
}
std::string cur_thread_name = phi::GetCurrentThreadName();
VLOG(3) << "XPU Dataloader: current thread at Get Context = "
<< phi::GetCurrentThreadName();
bool is_dataloader_thread = (cur_thread_name.substr(0, 10) == "Dataloader");
return is_dataloader_thread;
}
......@@ -93,6 +95,7 @@ struct XPUContext::Impl {
xpu::destroy_context(ctx);
ctx = nullptr;
}
xdl_context_map_.clear();
}
}
......@@ -100,8 +103,7 @@ struct XPUContext::Impl {
XPUStream stream() const {
if (IsDataloader()) {
std::string cur_thread_name = phi::GetCurrentThreadName();
xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
xpu::Context* ctx_t = GetXdlCtx();
return ctx_t->xpu_stream;
}
return context_->xpu_stream;
......@@ -120,12 +122,9 @@ struct XPUContext::Impl {
// Overload GetXContext function to set and get
// contexts of XPU Dataloader threads, and keep old GetXContext Method
xpu::Context* GetXContext() {
std::string cur_thread_name = phi::GetCurrentThreadName();
VLOG(3) << "XPU Dataloader: current thread at Get Context = "
<< phi::GetCurrentThreadName();
if (IsDataloader()) {
SetXdlCtx(cur_thread_name);
xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
SetXdlCtx();
xpu::Context* ctx_t = GetXdlCtx();
PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
return ctx_t;
}
......@@ -135,20 +134,15 @@ struct XPUContext::Impl {
}
void Wait() const {
backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
xpu_wait(context_->xpu_stream);
}
// Overload Wait for xpu wait on XPU Dataloader threads streams
void Wait() {
if (IsDataloader()) {
std::string cur_thread_name = phi::GetCurrentThreadName();
SetXdlCtx(cur_thread_name);
xpu::Context* ctx_t = GetXdlCtx(cur_thread_name);
PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
xpu_wait(GetXdlCtx(cur_thread_name)->xpu_stream);
xpu::Context* ctx_t = GetXdlCtx();
if (ctx_t) {
PD_CHECK(ctx_t != nullptr, "the xpu dataloader context is nullptr.");
xpu_wait(ctx_t->xpu_stream);
}
return;
}
backends::xpu::XPUDeviceGuard guard(place_.GetDeviceId());
PD_CHECK(context_ != nullptr, "the xpu context is nullptr.");
xpu_wait(context_->xpu_stream);
......@@ -191,22 +185,24 @@ struct XPUContext::Impl {
for (const auto& tp : thread_map) {
std::string t_name = tp.second;
if (t_name.substr(0, 10) == "Dataloader") {
SetXdlCtx(t_name);
SetXdlCtx();
}
}
}
void SetXdlCtx(std::string thread_name) {
if (xdl_context_map_.find(thread_name) == xdl_context_map_.end()) {
void SetXdlCtx() {
auto pid = phi::GetProcessId();
if (xdl_context_map_.find(pid) == xdl_context_map_.end()) {
xpu::Context* ctx_t = xpu::create_context();
xdl_context_map_[thread_name] = ctx_t;
xdl_context_map_[pid] = ctx_t;
}
}
xpu::Context* GetXdlCtx(const std::string thread_name) const {
return (xdl_context_map_.find(thread_name) == xdl_context_map_.end())
xpu::Context* GetXdlCtx() const {
auto pid = phi::GetProcessId();
return (xdl_context_map_.find(pid) == xdl_context_map_.end())
? nullptr
: xdl_context_map_.find(thread_name)->second;
: xdl_context_map_.find(pid)->second;
}
std::vector<xpu::Context*> GetAllXdlCtxs() {
......@@ -221,7 +217,7 @@ struct XPUContext::Impl {
Place place_;
backends::xpu::XPUVersion xpu_version_;
xpu::Context* context_{nullptr};
std::unordered_map<std::string, xpu::Context*> xdl_context_map_;
std::unordered_map<uint32_t, xpu::Context*> xdl_context_map_;
// NOTE: Distributed communicator, distributed framework manages its
// resources, XPUContext only holds references.
......
......@@ -340,22 +340,41 @@ void Pool3dGradKernel(const Context& ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool3d_grad");
} else {
if (pooling_type == "max") {
r = xpu::max_pool3d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(out.data<T>()),
index_data,
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_d,
in_h,
in_w,
kernel_size,
strides,
paddings,
!channel_last);
if (kernel_size[0] == 1 && kernel_size.size() == 3 &&
strides.size() == 3 && paddings.size() == 6) {
r = xpu::max_pool2d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(out.data<T>()),
index_data,
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c * in_d,
in_h,
in_w,
{kernel_size[1], kernel_size[2]},
{strides[1], strides[2]},
{paddings[2], paddings[3], paddings[4], paddings[5]},
!channel_last);
} else {
r = xpu::max_pool3d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(out.data<T>()),
index_data,
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_d,
in_h,
in_w,
kernel_size,
strides,
paddings,
!channel_last);
}
} else if (pooling_type == "avg") {
r = xpu::avg_pool3d_grad<XPUType>(
ctx.x_context(),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册