/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <string>
#include <vector>

#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"

namespace paddle {
namespace operators {

namespace {

// Maps Paddle's pool2d attributes onto a CNNL pooling mode.
//
// pooling_type: "max" or "avg"; any other value raises InvalidArgument.
// exclusive:    only consulted for "avg". When true, the mode is
//               CNNL_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING, otherwise
//               CNNL_POOLING_AVERAGE_COUNT_INCLUDE_PADDING.
cnnlPoolingMode_t ToCnnlPoolingMode(const std::string &pooling_type,
                                    bool exclusive) {
  cnnlPoolingMode_t pooling_mode;
  if (pooling_type == "max") {
    pooling_mode = CNNL_POOLING_MAX;
  } else if (pooling_type == "avg") {
    pooling_mode = exclusive ? CNNL_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING
                             : CNNL_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument("Unknown pooling_type: %s",
                                                   pooling_type));
  }
  return pooling_mode;
}

}  // namespace

// Forward pool2d kernel for MLU devices, implemented on top of CNNL.
//
// Accepts 4-D inputs in NCHW (default) or NHWC layout, as selected by the
// "data_format" attribute. Adaptive pooling is rejected. Paddings (and, for
// global pooling, the window size) are resolved with UpdatePadding /
// UpdateKsize before the CNNL pooling descriptor is built.
template <typename T>
class MLUPoolOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
    const Tensor *in_x = ctx.Input<Tensor>("X");
    Tensor *out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");
    std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    std::string data_format = ctx.Attr<std::string>("data_format");

    bool global_pooling = ctx.Attr<bool>("global_pooling");
    bool ceil_mode = ctx.Attr<bool>("ceil_mode");
    bool exclusive = ctx.Attr<bool>("exclusive");
    bool adaptive = ctx.Attr<bool>("adaptive");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    PADDLE_ENFORCE_EQ(in_x->dims().size(), 4,
                      platform::errors::InvalidArgument(
                          "Only support 4-dims for mlu pool2d kernel."));

    PADDLE_ENFORCE_EQ(adaptive, false,
                      platform::errors::InvalidArgument(
                          "Not support adaptive for mlu pool2d kernel."));

    // Defaults describe NCHW; they are overridden below for NHWC.
    cnnlTensorLayout_t cnnl_layout = CNNL_LAYOUT_NCHW;
    auto out_dims = out->dims();
    int64_t out_h = out_dims[2];
    int64_t out_w = out_dims[3];
    auto in_x_dims = in_x->dims();
    // Spatial extents of the input (dims 2..3 for NCHW, 1..2 for NHWC).
    framework::DDim data_dims =
        framework::slice_ddim(in_x_dims, 2, in_x_dims.size());

    const bool channel_last = data_format == "NHWC";
    if (channel_last) {
      cnnl_layout = CNNL_LAYOUT_NHWC;
      out_h = out_dims[1];
      out_w = out_dims[2];
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    }

    // Resolve the effective paddings from padding_algorithm; for global
    // pooling, UpdateKsize presumably widens ksize to the spatial dims.
    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

    MLUCnnlTensorDesc in_x_desc(*in_x, cnnl_layout, ToCnnlDataType<T>());
    MLUCnnlTensorDesc out_desc(*out, cnnl_layout, ToCnnlDataType<T>());

    cnnlPoolingMode_t pool_mode = ToCnnlPoolingMode(pooling_type, exclusive);
    MLUCnnlPoolingDesc pool_desc(
        pool_mode, CNNL_NOT_PROPAGATE_NAN, ksize[0], ksize[1], paddings[0],
        paddings[1], paddings[2], paddings[3], strides[0], strides[1],
        1 /*row_dilation*/, 1 /*col_dilation*/, ceil_mode);

    // CNNL may require an auxiliary "extra input" buffer. Its size depends on
    // pool_mode and the output extent, and may be zero. Note the
    // (out_w, out_h) argument order here versus (out_h, out_w) for
    // PoolingForward below. Both raw CNNL calls return a status, which is
    // checked so that a failure is reported instead of silently ignored.
    size_t extra_input_size = 0;
    cnnlHandle_t handle = dev_ctx.cnnl_handle();
    PADDLE_ENFORCE_MLU_SUCCESS(cnnlGetPoolingExtraInputSize(
        handle, pool_mode, out_w, out_h, &extra_input_size));

    if (extra_input_size > 0) {
      // The extra input is initialized in a host byte buffer, then copied to
      // a device byte buffer of the same size.
      paddle::platform::CPUDeviceContext cpu_ctx;
      framework::Tensor extra_host_tensor =
          ctx.AllocateTmpTensor<int8_t, platform::CPUDeviceContext>(
              {static_cast<int64_t>(extra_input_size)}, cpu_ctx);
      PADDLE_ENFORCE_MLU_SUCCESS(cnnlInitPoolingExtraInput(
          handle, pool_desc.get(), in_x_desc.get(), out_desc.get(),
          GetBasePtr(&extra_host_tensor)));
      framework::Tensor extra_device_tensor =
          ctx.AllocateTmpTensor<int8_t, platform::MLUDeviceContext>(
              {static_cast<int64_t>(extra_input_size)}, dev_ctx);
      // TODO(fwg): use Async copy, and add a callback to stream that free host
      // memory.
      framework::TensorCopySync(extra_host_tensor, ctx.GetPlace(),
                                &extra_device_tensor);
      MLUCnnl::PoolingForward(
          ctx, pool_mode, out_h, out_w, pool_desc.get(), nullptr /*alpha*/,
          in_x_desc.get(), GetBasePtr(in_x), nullptr /*beta*/,
          GetBasePtr(&extra_device_tensor) /*params_shape_ptr*/,
          out_desc.get(), GetBasePtr(out));
    } else {
      MLUCnnl::PoolingForward(
          ctx, pool_mode, out_h, out_w, pool_desc.get(), nullptr /*alpha*/,
          in_x_desc.get(), GetBasePtr(in_x), nullptr /*beta*/,
          nullptr /*params_shape_ptr*/, out_desc.get(), GetBasePtr(out));
    }
  }
};

// Backward pool2d kernel for MLU devices, implemented on top of CNNL.
//
// T is the data type of X / Out / their gradients. IDX_T is the element type
// of the temporary index tensor used for max-pooling backward. The CNNL
// backward calls below are always given NHWC descriptors, so NCHW tensors are
// transposed to NHWC first and the resulting gradient is transposed back.
// Like the forward kernel, only 4-D, non-adaptive pooling is accepted.
template <typename T, typename IDX_T>
class MLUPoolGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto &dev_ctx = ctx.template device_context<platform::MLUDeviceContext>();
    const Tensor *in_x = ctx.Input<Tensor>("X");
    const Tensor *out = ctx.Input<Tensor>("Out");
    const Tensor *out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    Tensor *in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    in_x_grad->mutable_data<T>(ctx.GetPlace());

    std::string pooling_type = ctx.Attr<std::string>("pooling_type");
    std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
    bool ceil_mode = ctx.Attr<bool>("ceil_mode");
    bool exclusive = ctx.Attr<bool>("exclusive");
    bool adaptive = ctx.Attr<bool>("adaptive");
    std::string data_format = ctx.Attr<std::string>("data_format");
    bool global_pooling = ctx.Attr<bool>("global_pooling");
    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");

    // Same restrictions as the forward kernel. The transposes and the
    // in_x_grad_dims[0..3] indexing below assume a 4-D tensor, and adaptive
    // pooling is not modelled by the CNNL descriptor built here.
    PADDLE_ENFORCE_EQ(in_x->dims().size(), 4,
                      platform::errors::InvalidArgument(
                          "Only support 4-dims for mlu pool2d_grad kernel."));
    PADDLE_ENFORCE_EQ(adaptive, false,
                      platform::errors::InvalidArgument(
                          "Not support adaptive for mlu pool2d_grad kernel."));

    const bool channel_last = data_format == "NHWC";

    auto in_x_dims = in_x->dims();
    framework::DDim data_dims =
        framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    }

    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

    // Bring every operand into NHWC layout.
    framework::Tensor trans_in_x;
    framework::Tensor trans_out;
    framework::Tensor trans_out_grad;
    framework::Tensor trans_in_x_grad;
    if (channel_last) {
      // Already NHWC. Tensor assignment is relied on to alias the same
      // allocation, so the backward result lands directly in in_x_grad.
      trans_in_x = *in_x;
      trans_out = *out;
      trans_out_grad = *out_grad;
      trans_in_x_grad = *in_x_grad;
    } else {
      std::vector<int> perm{0, 2, 3, 1};  // NCHW -> NHWC
      TransposeFromMLUTensor<T>(ctx, perm, in_x, &trans_in_x,
                                true /*need_reshape_or_alloc*/);
      TransposeFromMLUTensor<T>(ctx, perm, out, &trans_out,
                                true /*need_reshape_or_alloc*/);
      TransposeFromMLUTensor<T>(ctx, perm, out_grad, &trans_out_grad,
                                true /*need_reshape_or_alloc*/);
      auto in_x_grad_dims = in_x_grad->dims();
      trans_in_x_grad =
          ctx.AllocateTmpTensor<T, platform::MLUDeviceContext>(
              {in_x_grad_dims[0], in_x_grad_dims[2], in_x_grad_dims[3],
               in_x_grad_dims[1]},
              dev_ctx);
    }
    MLUCnnlTensorDesc trans_in_x_desc(trans_in_x, CNNL_LAYOUT_NHWC,
                                      ToCnnlDataType<T>());
    MLUCnnlTensorDesc trans_out_desc(trans_out, CNNL_LAYOUT_NHWC,
                                     ToCnnlDataType<T>());
    MLUCnnlTensorDesc trans_out_grad_desc(trans_out_grad, CNNL_LAYOUT_NHWC,
                                          ToCnnlDataType<T>());
    MLUCnnlTensorDesc trans_in_x_grad_desc(trans_in_x_grad, CNNL_LAYOUT_NHWC,
                                           ToCnnlDataType<T>());

    cnnlPoolingMode_t pool_mode = ToCnnlPoolingMode(pooling_type, exclusive);
    MLUCnnlPoolingDesc pool_desc(
        pool_mode, CNNL_NOT_PROPAGATE_NAN, ksize[0], ksize[1], paddings[0],
        paddings[1], paddings[2], paddings[3], strides[0], strides[1],
        1 /*row_dilation*/, 1 /*col_dilation*/, ceil_mode);

    if (pooling_type == "max") {
      // Max pooling: first compute the index tensor (shaped like out_grad)
      // from the input with PoolingIndex, then feed it to PoolingBackward.
      framework::Tensor index_tensor =
          ctx.AllocateTmpTensor<IDX_T, platform::MLUDeviceContext>(
              trans_out_grad.dims(), dev_ctx);
      MLUCnnlTensorDesc index_tensor_desc(index_tensor, CNNL_LAYOUT_NHWC,
                                          ToCnnlDataType<IDX_T>());
      MLUCnnl::PoolingIndex(ctx, pool_desc.get(), trans_in_x_desc.get(),
                            GetBasePtr(&trans_in_x), index_tensor_desc.get(),
                            GetBasePtr(&index_tensor));
      MLUCnnl::PoolingBackward(
          ctx, pool_desc.get(), nullptr /*alpha*/, index_tensor_desc.get(),
          GetBasePtr(&index_tensor), trans_out_grad_desc.get(),
          GetBasePtr(&trans_out_grad), trans_in_x_desc.get(),
          GetBasePtr(&trans_in_x), nullptr /*beta*/,
          trans_in_x_grad_desc.get(), GetBasePtr(&trans_in_x_grad));
    } else {
      // Average pooling: no index tensor and no forward input are passed;
      // only the output gradient is supplied.
      MLUCnnl::PoolingBackward(ctx, pool_desc.get(), nullptr /*alpha*/,
                               nullptr, nullptr, trans_out_grad_desc.get(),
                               GetBasePtr(&trans_out_grad), nullptr, nullptr,
                               nullptr /*beta*/, trans_in_x_grad_desc.get(),
                               GetBasePtr(&trans_in_x_grad));
    }
    if (!channel_last) {
      std::vector<int> perm{0, 3, 1, 2};  // NHWC -> NCHW
      TransposeFromMLUTensor<T>(ctx, perm, &trans_in_x_grad, in_x_grad,
                                false /*need_reshape_or_alloc*/);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_MLU_KERNEL(pool2d, ops::MLUPoolOpKernel<float>,
                       ops::MLUPoolOpKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(pool2d_grad, ops::MLUPoolGradOpKernel<float, int>,
                       ops::MLUPoolGradOpKernel<plat::float16, int16_t>);