diff --git a/paddle/fluid/operators/pool_op_mlu.cc b/paddle/fluid/operators/pool_op_mlu.cc index 4f9d1343c8395ab0d37f9802551f72bdc644f1f7..5eaf8dbff88ab600223a980066c5c45d4264430b 100644 --- a/paddle/fluid/operators/pool_op_mlu.cc +++ b/paddle/fluid/operators/pool_op_mlu.cc @@ -100,6 +100,25 @@ class MLUPoolOpKernel : public framework::OpKernel { cnnlPoolingMode_t pool_mode = ToCnnlPoolingMode(pooling_type, exclusive, adaptive); + // transpose NCHW to NHWC since cnnl pool2d has worse performance in that + // layout. + framework::Tensor trans_in_x; + framework::Tensor trans_out; + if (channel_last) { + trans_in_x = *in_x; + trans_out = *out; + } else { + std::vector perm{0, 2, 3, 1}; + TransposeFromMLUTensor( + ctx, perm, in_x, &trans_in_x, true /*need_reshape_or_alloc*/); + trans_out = ctx.AllocateTmpTensor( + {out_dims[0], out_dims[2], out_dims[3], out_dims[1]}, dev_ctx); + } + MLUCnnlTensorDesc trans_in_x_desc( + trans_in_x, CNNL_LAYOUT_NHWC, ToCnnlDataType()); + MLUCnnlTensorDesc trans_out_desc( + trans_out, CNNL_LAYOUT_NHWC, ToCnnlDataType()); + if (!adaptive) { MLUCnnlPoolingDesc pool_desc(pool_mode, CNNL_NOT_PROPAGATE_NAN, @@ -128,8 +147,8 @@ class MLUPoolOpKernel : public framework::OpKernel { {static_cast(extra_input_size)}, cpu_ctx); cnnlInitPoolingExtraInput(handle, pool_desc.get(), - in_x_desc.get(), - out_desc.get(), + trans_in_x_desc.get(), + trans_out_desc.get(), GetBasePtr(&extra_host_tensor)); framework::Tensor extra_device_tensor = ctx.AllocateTmpTensor( @@ -151,12 +170,12 @@ class MLUPoolOpKernel : public framework::OpKernel { out_w, pool_desc.get(), nullptr /*alpha*/, - in_x_desc.get(), - GetBasePtr(in_x), + trans_in_x_desc.get(), + GetBasePtr(&trans_in_x), nullptr /*beta*/, GetBasePtr(&extra_device_tensor) /*params_shape_ptr*/, - out_desc.get(), - GetBasePtr(out)); + trans_out_desc.get(), + GetBasePtr(&trans_out)); } else { MLUCnnl::PoolingForward(ctx, pool_mode, @@ -164,31 +183,14 @@ class MLUPoolOpKernel : public framework::OpKernel { out_w, pool_desc.get(), nullptr /*alpha*/, - in_x_desc.get(), - GetBasePtr(in_x), + trans_in_x_desc.get(), + GetBasePtr(&trans_in_x), nullptr /*beta*/, nullptr /*params_shape_ptr*/, - out_desc.get(), - GetBasePtr(out)); + trans_out_desc.get(), + GetBasePtr(&trans_out)); } } else { - // cnnl Adaptive pooling only support NHWC layout - framework::Tensor trans_in_x; - framework::Tensor trans_out; - if (channel_last) { - trans_in_x = *in_x; - trans_out = *out; - } else { - std::vector perm{0, 2, 3, 1}; - TransposeFromMLUTensor( - ctx, perm, in_x, &trans_in_x, true /*need_reshape_or_alloc*/); - trans_out = ctx.AllocateTmpTensor( - {out_dims[0], out_dims[2], out_dims[3], out_dims[1]}, dev_ctx); - } - MLUCnnlTensorDesc trans_in_x_desc( - trans_in_x, CNNL_LAYOUT_NHWC, ToCnnlDataType()); - MLUCnnlTensorDesc trans_out_desc( - trans_out, CNNL_LAYOUT_NHWC, ToCnnlDataType()); MLUCnnl::AdaptivePoolingForward(ctx, pool_mode, trans_in_x_desc.get(), @@ -197,11 +199,11 @@ class MLUPoolOpKernel : public framework::OpKernel { GetBasePtr(&trans_out), nullptr, nullptr); - if (!channel_last) { - std::vector perm{0, 3, 1, 2}; - TransposeFromMLUTensor( - ctx, perm, &trans_out, out, false /*need_reshape_or_alloc*/); - } + } + if (!channel_last) { + std::vector perm{0, 3, 1, 2}; + TransposeFromMLUTensor( + ctx, perm, &trans_out, out, false /*need_reshape_or_alloc*/); } } };