/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. Indicesou may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include "paddle/framework/op_registry.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/pooling.h" #include "paddle/operators/strided_memcpy.h" namespace paddle { namespace operators { template class SppKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* in_x = context.Input("X"); auto* out = context.Output("Out"); int pyramid_height = context.template Attr("pyramid_height"); out->mutable_data(context.GetPlace()); auto out_stride = framework::stride(out->dims()); int input_h = in_x->dims()[2]; int input_w = in_x->dims()[3]; size_t output_offset = 0; for (int p = 0; p < pyramid_height; ++p) { int bins = std::pow(2, p); int ksize_h = std::ceil(input_h / static_cast(bins)); int ksize_w = std::ceil(input_w / static_cast(bins)); int padding_h = (ksize_h * bins - input_h + 1) / 2; int padding_w = (ksize_w * bins - input_w + 1) / 2; std::vector ksize({ksize_h, ksize_w}); std::vector strides({ksize_h, ksize_w}); std::vector paddings({padding_h, padding_w}); // pooling output shape std::vector output_shape_vec({in_x->dims()[0], in_x->dims()[1]}); output_shape_vec.push_back((input_h - ksize_h + 2 * padding_h) / ksize_h + 1); output_shape_vec.push_back((input_w - ksize_w + 2 * padding_w) / ksize_w + 1); framework::DDim output_shape(framework::make_ddim(output_shape_vec)); // flatten pooling output shape int output_flatten_w = in_x->dims()[1] * bins * bins; std::vector output_flatten_shape_vec( {in_x->dims()[0], output_flatten_w}); framework::DDim output_flatten_shape( framework::make_ddim(output_flatten_shape_vec)); framework::Tensor out_level; framework::Tensor out_flatten_level; out_level.mutable_data(output_shape, context.GetPlace()); // pooling math::Pool2dFunctor, T> pool_forward; math::MaxPool max_process; pool_forward(context.device_context(), *in_x, ksize, strides, paddings, max_process, &out_level); out_flatten_level.ShareDataWith(out_level); out_flatten_level.Resize(output_flatten_shape); auto in_stride = framework::stride(out_flatten_level.dims()); const T* src_data = out_flatten_level.data(); StridedMemcpy(context.device_context(), src_data, in_stride, out_flatten_level.dims(), out_stride, out->data() + output_offset); output_offset += out_flatten_level.dims()[1] * in_stride[1]; } } }; template class SppGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const framework::Tensor* in_x = context.Input("X"); const framework::Tensor* out = context.Input("Out"); const framework::Tensor* out_grad = context.Input(framework::GradVarName("Out")); framework::Tensor* in_x_grad = context.Output(framework::GradVarName("X")); auto& device_ctx = context.device_context(); math::SetConstant zero; in_x_grad->mutable_data(context.GetPlace()); zero(device_ctx, in_x_grad, static_cast(0)); int pyramid_height = context.template Attr("pyramid_height"); auto outgrad_stride = framework::stride(out_grad->dims()); auto out_stride = framework::stride(out->dims()); int input_h = in_x->dims()[2]; int input_w = in_x->dims()[3]; size_t out_offset = 0; for (int p = 0; p < pyramid_height; ++p) { int bins = std::pow(2, p); int ksize_h = std::ceil(input_h / static_cast(bins)); int ksize_w = std::ceil(input_w / static_cast(bins)); int padding_h = (ksize_h * bins - input_h + 1) / 2; int padding_w = (ksize_w * bins - input_w + 1) / 2; std::vector ksize({ksize_h, ksize_w}); std::vector strides({ksize_h, ksize_w}); std::vector paddings({padding_h, padding_w}); // split outgrad and get flatten std::vector out_shape_vec({in_x->dims()[0], in_x->dims()[1]}); out_shape_vec.push_back((input_h - ksize_h + 2 * padding_h) / ksize_h + 1); out_shape_vec.push_back((input_w - ksize_w + 2 * padding_w) / ksize_w + 1); framework::DDim out_shape(framework::make_ddim(out_shape_vec)); int out_flatten_w = in_x->dims()[1] * bins * bins; std::vector out_flatten_shape_vec( {in_x->dims()[0], out_flatten_w}); framework::DDim out_flatten_shape( framework::make_ddim(out_flatten_shape_vec)); framework::Tensor out_level; framework::Tensor outgrad_level; framework::Tensor out_flatten_level; framework::Tensor outgrad_flatten_level; out_flatten_level.mutable_data(out_flatten_shape, context.GetPlace()); outgrad_flatten_level.mutable_data(out_flatten_shape, context.GetPlace()); auto flatten_stride = framework::stride(out_flatten_level.dims()); // memcpy StridedMemcpy(context.device_context(), out->data() + out_offset, out_stride, out_flatten_level.dims(), flatten_stride, out_flatten_level.data()); StridedMemcpy(context.device_context(), out_grad->data() + out_offset, outgrad_stride, outgrad_flatten_level.dims(), flatten_stride, outgrad_flatten_level.data()); out_offset += out_flatten_level.dims()[1] * out_stride[1]; // flatten backward out_level.ShareDataWith(out_flatten_level); out_level.Resize(out_shape); outgrad_level.ShareDataWith(outgrad_flatten_level); outgrad_level.Resize(out_shape); math::MaxPool2dGradFunctor pool2d_backward; pool2d_backward(context.device_context(), *in_x, *&out_level, *&outgrad_level, ksize, strides, paddings, in_x_grad); } } }; } // namespace operators } // namespace paddle