未验证 提交 01d05bc0 编写于 作者: z8hanghuan 提交者: GitHub

fix bug of adaptive pool2d_grad, *test=kunlun (#45031)

* fix bug of adaptive pool2d_grad, *test=kunlun

* fix bug of adaptive pool2d_grad, *test=kunlun

* fix bug of adaptive pool2d_grad, *test=kunlun
上级 84bf5c31
...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so") ...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
set(XPU_BASE_URL_WITHOUT_DATE set(XPU_BASE_URL_WITHOUT_DATE
"https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220802") set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220810")
else() else()
set(XPU_BASE_URL "${XPU_BASE_URL}") set(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
...@@ -19,7 +19,7 @@ endif() ...@@ -19,7 +19,7 @@ endif()
if(NOT DEFINED XPU_XDNN_BASE_URL) if(NOT DEFINED XPU_XDNN_BASE_URL)
set(XPU_XDNN_BASE_URL_WITHOUT_DATE set(XPU_XDNN_BASE_URL_WITHOUT_DATE
"https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220802") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220810")
else() else()
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif() endif()
......
...@@ -44,6 +44,15 @@ class PoolXPUKernel : public framework::OpKernel<T> { ...@@ -44,6 +44,15 @@ class PoolXPUKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 dimension pooling!")); "The Pool2d XPU OP only support 2 dimension pooling!"));
std::string data_format = context.Attr<std::string>("data_format");
PADDLE_ENFORCE_EQ(
data_format,
"NCHW",
platform::errors::InvalidArgument("The Pool2d XPU OP only support"
"data_format is 'NCHW', but received "
"%s",
data_format));
int* index_data = nullptr; int* index_data = nullptr;
bool global_pooling = context.Attr<bool>("global_pooling") || bool global_pooling = context.Attr<bool>("global_pooling") ||
(adaptive && (ksize[0] * ksize[1] == 1)); (adaptive && (ksize[0] * ksize[1] == 1));
...@@ -173,6 +182,16 @@ class PoolGradXPUKernel : public framework::OpKernel<T> { ...@@ -173,6 +182,16 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
bool exclusive = context.Attr<bool>("exclusive"); bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive"); bool adaptive = context.Attr<bool>("adaptive");
bool ceil_mode = context.Attr<bool>("ceil_mode"); bool ceil_mode = context.Attr<bool>("ceil_mode");
std::string data_format = context.Attr<std::string>("data_format");
PADDLE_ENFORCE_EQ(
data_format,
"NCHW",
platform::errors::InvalidArgument("The Pool2d_grad XPU OP only support"
"data_format is 'NCHW', but received "
"%s",
data_format));
std::string padding_algorithm = std::string padding_algorithm =
context.Attr<std::string>("padding_algorithm"); context.Attr<std::string>("padding_algorithm");
const int* index_data = nullptr; const int* index_data = nullptr;
...@@ -202,13 +221,6 @@ class PoolGradXPUKernel : public framework::OpKernel<T> { ...@@ -202,13 +221,6 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
const int out_h = out->dims()[2]; const int out_h = out->dims()[2];
const int out_w = out->dims()[3]; const int out_w = out->dims()[3];
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1) ||
(in_h % out_h == 0 && in_w % out_w == 0),
true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
framework::DDim data_dims; framework::DDim data_dims;
data_dims = phi::slice_ddim(in_x->dims(), 2, in_x->dims().size()); data_dims = phi::slice_ddim(in_x->dims(), 2, in_x->dims().size());
...@@ -234,7 +246,8 @@ class PoolGradXPUKernel : public framework::OpKernel<T> { ...@@ -234,7 +246,8 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
auto input_grad = reinterpret_cast<XPUType*>(in_x_grad->data<T>()); auto input_grad = reinterpret_cast<XPUType*>(in_x_grad->data<T>());
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::Error_t::SUCCESS; int r = xpu::Error_t::SUCCESS;
if (adaptive && in_h % out_h == 0 && in_w % out_w == 0) { if (adaptive) {
// floor for stride
strides = {in_h / out_h, in_w / out_w}; strides = {in_h / out_h, in_w / out_w};
int kh = in_h - (out_h - 1) * strides[0]; int kh = in_h - (out_h - 1) * strides[0];
int kw = in_w - (out_w - 1) * strides[1]; int kw = in_w - (out_w - 1) * strides[1];
...@@ -243,6 +256,7 @@ class PoolGradXPUKernel : public framework::OpKernel<T> { ...@@ -243,6 +256,7 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
} }
if (pooling_type == "max") { if (pooling_type == "max") {
// TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
r = xpu::max_pool2d_grad<XPUType>(dev_ctx.x_context(), r = xpu::max_pool2d_grad<XPUType>(dev_ctx.x_context(),
input, input,
output, output,
......
...@@ -341,6 +341,24 @@ class XPUTestPool2D_Op(XPUOpTestWrapper): ...@@ -341,6 +341,24 @@ class XPUTestPool2D_Op(XPUOpTestWrapper):
def init_adaptive(self): def init_adaptive(self):
self.adaptive = False self.adaptive = False
class TestAvgPoolAdaptive(TestPool2D_Op):
def init_adaptive(self):
self.adaptive = True
class TestAvgPoolAdaptiveAsyOutSize(TestPool2D_Op):
def init_adaptive(self):
self.adaptive = True
def init_shape(self):
self.shape = [8, 3, 6, 6]
def init_test_case(self):
self.ksize = [2, 3]
self.strides = [1, 1]
self.paddings = [0, 0, 0, 0]
class TestCase1(TestPool2D_Op): class TestCase1(TestPool2D_Op):
def init_test_case(self): def init_test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册