未验证 提交 74bf3bed 编写于 作者: Q QingshuChen 提交者: GitHub

support global pooling for kunlun (#29293)

* test=kunlun
上级 b10ecd9d
......@@ -43,16 +43,18 @@ class PoolXPUKernel : public framework::OpKernel<T> {
bool exclusive = context.Attr<bool>("exclusive");
bool is_test = context.Attr<bool>("is_test");
bool adaptive = context.Attr<bool>("adaptive");
PADDLE_ENFORCE_EQ(
!adaptive, true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support adaptive == true!"));
PADDLE_ENFORCE_EQ(
ksize.size(), 2,
platform::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 dimension pooling!"));
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1), true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
int* index_data = nullptr;
if (context.Attr<bool>("global_pooling")) {
bool global_pooling = context.Attr<bool>("global_pooling") ||
(adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
......@@ -104,16 +106,18 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
bool exclusive = context.Attr<bool>("exclusive");
bool adaptive = context.Attr<bool>("adaptive");
const int* index_data = nullptr;
PADDLE_ENFORCE_EQ(
!adaptive, true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support adaptive == true!"));
PADDLE_ENFORCE_EQ(ksize.size(), 2, platform::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 "
"dimension pooling!, but received "
"%d-dimension pool kernel size",
ksize.size()));
if (context.Attr<bool>("global_pooling")) {
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1), true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
bool global_pooling = context.Attr<bool>("global_pooling") ||
(adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
paddings[i] = 0;
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册