/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/strided_memcpy.h"

namespace paddle {
namespace operators {
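// Spatial pyramid pooling: at pyramid level p the input feature map is
// pooled into a fixed 2^p x 2^p grid of bins. Each level's result is
// flattened and the levels are concatenated along axis 1, so Out has shape
// [N, C * sum_p(4^p)] regardless of the input's spatial size.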
template <typename DeviceContext, typename T>
class SppKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const phi::DenseTensor* in_x = context.Input<phi::DenseTensor>("X");
    auto* out = context.Output<phi::DenseTensor>("Out");
    int pyramid_height = context.template Attr<int>("pyramid_height");
    std::string pooling_type =
        context.template Attr<std::string>("pooling_type");
    out->mutable_data<T>(context.GetPlace());
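    // Out is laid out as [N, C * sum_p(4^p)]; its stride is used below to
    // scatter each level's flattened slice into the concatenated output.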
    auto out_stride = phi::stride(out->dims());
    int input_h = in_x->dims()[2];
    int input_w = in_x->dims()[3];
    size_t output_offset = 0;
    for (int p = 0; p < pyramid_height; ++p) {
      int bins = std::pow(2, p);
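      // Choose kernel and padding so that `bins` windows exactly tile the
      // (padded) input. For example, with input_h = 13 and p = 2 (bins = 4):
      // kernel_size_h = ceil(13 / 4) = 4, padding_h = (4 * 4 - 13 + 1) / 2 = 2,
      // giving (13 - 4 + 2 * 2) / 4 + 1 = 4 output rows.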
      int kernel_size_h = std::ceil(input_h / static_cast<double>(bins));
      int kernel_size_w = std::ceil(input_w / static_cast<double>(bins));
      int padding_h = (kernel_size_h * bins - input_h + 1) / 2;
      int padding_w = (kernel_size_w * bins - input_w + 1) / 2;
      std::vector<int> kernel_size({kernel_size_h, kernel_size_w});
      std::vector<int> strides({kernel_size_h, kernel_size_w});
      std::vector<int> paddings({padding_h, padding_w});
      // pooling output shape
      phi::DenseTensor out_level;
      std::vector<int64_t> output_shape_vec(
          {in_x->dims()[0], in_x->dims()[1], bins, bins});
      framework::DDim output_shape(phi::make_ddim(output_shape_vec));
      out_level.mutable_data<T>(output_shape, context.GetPlace());
      // pooling
      if (pooling_type == "max") {
        phi::funcs::Pool2dFunctor<
            typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
            phi::funcs::MaxPool<T>,
            T>
            pool_forward;
        phi::funcs::MaxPool<T> max_process;
        pool_forward(context.template device_context<DeviceContext>(),
                     *in_x,
                     kernel_size,
                     strides,
                     paddings,
                     true,
                     false,
                     &out_level,
                     max_process);
      } else if (pooling_type == "avg") {
        phi::funcs::Pool2dFunctor<
            typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
            phi::funcs::AvgPool<T>,
            T>
            pool_forward;
        phi::funcs::AvgPool<T> avg_process;
        pool_forward(context.template device_context<DeviceContext>(),
                     *in_x,
                     kernel_size,
                     strides,
                     paddings,
                     true,
                     false,
                     &out_level,
                     avg_process);
      }
      // flatten pooling output shape
      int output_flatten_w = in_x->dims()[1] * bins * bins;
      std::vector<int64_t> output_flatten_shape_vec(
          {in_x->dims()[0], output_flatten_w});
      framework::DDim output_flatten_shape(
          phi::make_ddim(output_flatten_shape_vec));
      out_level.Resize(output_flatten_shape);
      // concat
      auto out_level_stride = phi::stride(out_level.dims());
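      // strided copy of this level's [N, C * bins * bins] block into its
      // column slice of Out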
      phi::funcs::StridedMemcpy<T>(
          context.template device_context<DeviceContext>(),
          out_level.data<T>(),
          out_level_stride,
          out_level.dims(),
          out_stride,
          out->data<T>() + output_offset);
      output_offset += out_level.dims()[1] * out_level_stride[1];
    }
  }
};
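// Gradient of spatial pyramid pooling: re-derive each level's geometry,
// slice that level's columns back out of Out and Out@GRAD, reshape them to
// NCHW, and run the matching pooling backward pass, accumulating into
// X@GRAD.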
template <typename DeviceContext, typename T>
class SppGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const phi::DenseTensor* in_x = context.Input<phi::DenseTensor>("X");
    const phi::DenseTensor* out = context.Input<phi::DenseTensor>("Out");
    const phi::DenseTensor* out_grad =
        context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    phi::DenseTensor* in_x_grad =
        context.Output<phi::DenseTensor>(framework::GradVarName("X"));
    int pyramid_height = context.template Attr<int>("pyramid_height");
    std::string pooling_type =
        context.template Attr<std::string>("pooling_type");
    auto& device_ctx = context.template device_context<DeviceContext>();
    phi::funcs::SetConstant<
        typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
        T>
        zero;
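    // every pyramid level accumulates into X@GRAD, so it must start from zero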
    in_x_grad->mutable_data<T>(context.GetPlace());
    zero(device_ctx, in_x_grad, static_cast<T>(0));
    auto out_stride = phi::stride(out->dims());
    int input_h = in_x->dims()[2];
    int input_w = in_x->dims()[3];
    size_t out_offset = 0;
    for (int p = 0; p < pyramid_height; ++p) {
      int bins = std::pow(2, p);
      int kernel_size_h = std::ceil(input_h / static_cast<double>(bins));
      int kernel_size_w = std::ceil(input_w / static_cast<double>(bins));
      int padding_h = (kernel_size_h * bins - input_h + 1) / 2;
      int padding_w = (kernel_size_w * bins - input_w + 1) / 2;
      std::vector<int> kernel_size({kernel_size_h, kernel_size_w});
      std::vector<int> strides({kernel_size_h, kernel_size_w});
      std::vector<int> paddings({padding_h, padding_w});
      // split Out and Out@GRAD into this level's flattened slices
      phi::DenseTensor out_level;
      phi::DenseTensor outgrad_level;
      int out_flatten_w = in_x->dims()[1] * bins * bins;
      std::vector<int64_t> out_flatten_shape_vec(
          {in_x->dims()[0], out_flatten_w});
      framework::DDim out_flatten_shape(phi::make_ddim(out_flatten_shape_vec));
      out_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
      outgrad_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
      auto flatten_stride = phi::stride(out_level.dims());
      // strided copy of this level's slices of Out and Out@GRAD into the
      // flattened buffers
      phi::funcs::StridedMemcpy<T>(
          context.template device_context<DeviceContext>(),
          out->data<T>() + out_offset,
          out_stride,
          out_level.dims(),
          flatten_stride,
          out_level.data<T>());

      phi::funcs::StridedMemcpy<T>(
          context.template device_context<DeviceContext>(),
          out_grad->data<T>() + out_offset,
          out_stride,
          outgrad_level.dims(),
          flatten_stride,
          outgrad_level.data<T>());
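      // advance past this level's columns in the concatenated output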
      out_offset += out_level.dims()[1] * out_stride[1];
      // flatten backward to nchw
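      // recover this level's pooled NCHW shape via the usual pooling
      // output-size formula (stride equals kernel size here):
      //   out = (in - kernel + 2 * padding) / stride + 1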
      std::vector<int64_t> out_shape_vec({in_x->dims()[0], in_x->dims()[1]});
      out_shape_vec.push_back(
          (input_h - kernel_size_h + 2 * padding_h) / kernel_size_h + 1);
      out_shape_vec.push_back(
          (input_w - kernel_size_w + 2 * padding_w) / kernel_size_w + 1);
      framework::DDim out_shape(phi::make_ddim(out_shape_vec));
      out_level.Resize(out_shape);
      outgrad_level.Resize(out_shape);
      // pooling backward
      if (pooling_type == "max") {
        phi::funcs::MaxPool2dGradFunctor<
            typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
            T>
            pool2d_backward;
        pool2d_backward(context.template device_context<DeviceContext>(),
                        *in_x,
                        out_level,
                        outgrad_level,
                        kernel_size,
                        strides,
                        paddings,
                        in_x_grad);
      } else if (pooling_type == "avg") {
        phi::funcs::Pool2dGradFunctor<
            typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
            phi::funcs::AvgPoolGrad<T>,
            T>
            pool_backward;
        phi::funcs::AvgPoolGrad<T> avg_process;
        pool_backward(context.template device_context<DeviceContext>(),
                      *in_x,
                      out_level,
                      outgrad_level,
                      kernel_size,
                      strides,
                      paddings,
                      true,
                      false,
                      in_x_grad,
                      avg_process);
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle