From 74bf3bed36c438191901801b61bdb278134c2162 Mon Sep 17 00:00:00 2001 From: QingshuChen Date: Fri, 4 Dec 2020 11:22:36 +0800 Subject: [PATCH] support global pooling for kunlun (#29293) * test=kunlun --- paddle/fluid/operators/pool_op_xpu.cc | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/pool_op_xpu.cc b/paddle/fluid/operators/pool_op_xpu.cc index 325b7359389..096a81db9bd 100644 --- a/paddle/fluid/operators/pool_op_xpu.cc +++ b/paddle/fluid/operators/pool_op_xpu.cc @@ -43,16 +43,18 @@ class PoolXPUKernel : public framework::OpKernel { bool exclusive = context.Attr("exclusive"); bool is_test = context.Attr("is_test"); bool adaptive = context.Attr("adaptive"); - PADDLE_ENFORCE_EQ( - !adaptive, true, - platform::errors::InvalidArgument( - "The Pool2d XPU OP does not support adaptive == true!")); PADDLE_ENFORCE_EQ( ksize.size(), 2, platform::errors::InvalidArgument( "The Pool2d XPU OP only support 2 dimension pooling!")); + PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1), true, + platform::errors::InvalidArgument( + "The Pool2d XPU OP does not support (adaptive == " + "true && output_size != 1)")); int* index_data = nullptr; - if (context.Attr("global_pooling")) { + bool global_pooling = context.Attr("global_pooling") || + (adaptive && (ksize[0] * ksize[1] == 1)); + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; ksize[i] = static_cast(in_x->dims()[i + 2]); @@ -104,16 +106,18 @@ class PoolGradXPUKernel : public framework::OpKernel { bool exclusive = context.Attr("exclusive"); bool adaptive = context.Attr("adaptive"); const int* index_data = nullptr; - PADDLE_ENFORCE_EQ( - !adaptive, true, - platform::errors::InvalidArgument( - "The Pool2d XPU OP does not support adaptive == true!")); PADDLE_ENFORCE_EQ(ksize.size(), 2, platform::errors::InvalidArgument( "The Pool2d XPU OP only support 2 " "dimension pooling!, but received " "%d-dimension pool kernel size", ksize.size())); - if (context.Attr("global_pooling")) { + PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1), true, + platform::errors::InvalidArgument( + "The Pool2d XPU OP does not support (adaptive == " + "true && output_size != 1)")); + bool global_pooling = context.Attr("global_pooling") || + (adaptive && (ksize[0] * ksize[1] == 1)); + if (global_pooling) { for (size_t i = 0; i < ksize.size(); ++i) { paddings[i] = 0; ksize[i] = static_cast(in_x->dims()[i + 2]); -- GitLab