/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/pooling.h"

namespace paddle {
namespace operators {
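// SppKernel implements spatial pyramid pooling (SPP): for each pyramid level
// p in [0, pyramid_height), the input feature map is pooled into a
// 2^p x 2^p grid of bins, each pooled level is flattened to
// [N, C * bins * bins], and the flattened levels are concatenated along
// axis 1. E.g. with pyramid_height = 3 and input [2, 64, 21, 21], the levels
// contribute 64 * (1 + 4 + 16) = 1344 columns, so Out is [2, 1344].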
template <typename DeviceContext, typename T>
class SppKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const phi::DenseTensor* in_x = context.Input<phi::DenseTensor>("X");
    auto* out = context.Output<phi::DenseTensor>("Out");
    int pyramid_height = context.template Attr<int>("pyramid_height");
    std::string pooling_type =
        context.template Attr<std::string>("pooling_type");
    out->mutable_data<T>(context.GetPlace());
    auto out_stride = phi::stride(out->dims());
    int input_h = in_x->dims()[2];
    int input_w = in_x->dims()[3];
    size_t output_offset = 0;
    for (int p = 0; p < pyramid_height; ++p) {
      int bins = std::pow(2, p);
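      // stride == kernel, so the bins x bins windows tile the input;
      // padding centers the grid when the input size is not divisible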
      int kernel_size_h = std::ceil(input_h / static_cast<double>(bins));
      int kernel_size_w = std::ceil(input_w / static_cast<double>(bins));
      int padding_h = (kernel_size_h * bins - input_h + 1) / 2;
      int padding_w = (kernel_size_w * bins - input_w + 1) / 2;
      std::vector<int> kernel_size({kernel_size_h, kernel_size_w});
      std::vector<int> strides({kernel_size_h, kernel_size_w});
      std::vector<int> paddings({padding_h, padding_w});
      // pooling output shape
      phi::DenseTensor out_level;
      std::vector<int64_t> output_shape_vec(
          {in_x->dims()[0], in_x->dims()[1], bins, bins});
      framework::DDim output_shape(phi::make_ddim(output_shape_vec));
      out_level.mutable_data<T>(output_shape, context.GetPlace());
      // pooling
      if (pooling_type == "max") {
        phi::funcs::Pool2dFunctor<
            typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
            phi::funcs::MaxPool<T>,
            T>
            pool_forward;
        phi::funcs::MaxPool<T> max_process;
        pool_forward(context.template device_context<DeviceContext>(),
                     *in_x,
                     kernel_size,
                     strides,
                     paddings,
                     true,
                     false,
                     &out_level,
                     max_process);
      } else if (pooling_type == "avg") {
        phi::funcs::Pool2dFunctor<
            typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
            phi::funcs::AvgPool<T>,
            T>
            pool_forward;
        phi::funcs::AvgPool<T> avg_process;
        pool_forward(context.template device_context<DeviceContext>(),
                     *in_x,
                     kernel_size,
                     strides,
                     paddings,
                     true,
                     false,
                     &out_level,
                     avg_process);
      }
      // flatten pooling output shape
      int output_flatten_w = in_x->dims()[1] * bins * bins;
      std::vector<int64_t> output_flatten_shape_vec(
          {in_x->dims()[0], output_flatten_w});
      framework::DDim output_flatten_shape(
          phi::make_ddim(output_flatten_shape_vec));
      out_level.Resize(output_flatten_shape);
      // concat
      auto out_level_stride = phi::stride(out_level.dims());
      StridedMemcpy<T>(context.template device_context<DeviceContext>(),
                       out_level.data<T>(),
                       out_level_stride,
                       out_level.dims(),
                       out_stride,
                       out->data<T>() + output_offset);
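      // advance past this level's columns in the concatenated output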
      output_offset += out_level.dims()[1] * out_level_stride[1];
    }
  }
};
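
// SppGradKernel reverses the forward pass: each pyramid level's slice of Out
// and Out@GRAD is split back out of the concatenated tensors, reshaped to the
// pooled NCHW shape, and fed to the matching pooling backward functor; the
// per-level gradients accumulate into X@GRAD.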
template <typename DeviceContext, typename T>
class SppGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const phi::DenseTensor* in_x = context.Input<phi::DenseTensor>("X");
    const phi::DenseTensor* out = context.Input<phi::DenseTensor>("Out");
    const phi::DenseTensor* out_grad =
        context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
    phi::DenseTensor* in_x_grad =
        context.Output<phi::DenseTensor>(framework::GradVarName("X"));
    int pyramid_height = context.template Attr<int>("pyramid_height");
    std::string pooling_type =
        context.template Attr<std::string>("pooling_type");
    auto& device_ctx = context.template device_context<DeviceContext>();
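    // every pyramid level accumulates into the same input gradient, so it
    // must start out zeroed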
    phi::funcs::SetConstant<
        typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
        T>
        zero;
    in_x_grad->mutable_data<T>(context.GetPlace());
    zero(device_ctx, in_x_grad, static_cast<T>(0));
    auto out_stride = phi::stride(out->dims());
    int input_h = in_x->dims()[2];
    int input_w = in_x->dims()[3];
    size_t out_offset = 0;
    for (int p = 0; p < pyramid_height; ++p) {
      int bins = std::pow(2, p);
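      // recompute the same per-level pooling geometry as the forward pass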
      int kernel_size_h = std::ceil(input_h / static_cast<double>(bins));
      int kernel_size_w = std::ceil(input_w / static_cast<double>(bins));
      int padding_h = (kernel_size_h * bins - input_h + 1) / 2;
      int padding_w = (kernel_size_w * bins - input_w + 1) / 2;
      std::vector<int> kernel_size({kernel_size_h, kernel_size_w});
      std::vector<int> strides({kernel_size_h, kernel_size_w});
      std::vector<int> paddings({padding_h, padding_w});
      // split out and out_grad into this level's flattened slices
      phi::DenseTensor out_level;
      phi::DenseTensor outgrad_level;
      int out_flatten_w = in_x->dims()[1] * bins * bins;
      std::vector<int64_t> out_flatten_shape_vec(
          {in_x->dims()[0], out_flatten_w});
      framework::DDim out_flatten_shape(phi::make_ddim(out_flatten_shape_vec));
      out_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
      outgrad_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
      auto flatten_stride = phi::stride(out_level.dims());
      // copy this level's slices of out and out_grad into the flattened
      // buffers
      StridedMemcpy<T>(context.template device_context<DeviceContext>(),
                       out->data<T>() + out_offset,
                       out_stride,
                       out_level.dims(),
                       flatten_stride,
                       out_level.data<T>());

      StridedMemcpy<T>(context.template device_context<DeviceContext>(),
                       out_grad->data<T>() + out_offset,
                       out_stride,
                       outgrad_level.dims(),
                       flatten_stride,
                       outgrad_level.data<T>());
      out_offset += out_level.dims()[1] * out_stride[1];

      // reshape the flattened slices back to NCHW for pooling backward
      std::vector<int64_t> out_shape_vec({in_x->dims()[0], in_x->dims()[1]});
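      // recover the pooled spatial size via the standard pooling formula
      // (in - kernel + 2 * pad) / stride + 1, with stride == kernel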
      out_shape_vec.push_back(
          (input_h - kernel_size_h + 2 * padding_h) / kernel_size_h + 1);
      out_shape_vec.push_back(
          (input_w - kernel_size_w + 2 * padding_w) / kernel_size_w + 1);
      framework::DDim out_shape(phi::make_ddim(out_shape_vec));
      out_level.Resize(out_shape);
      outgrad_level.Resize(out_shape);
      // pooling backward
      if (pooling_type == "max") {
        phi::funcs::MaxPool2dGradFunctor<
            typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
            T>
            pool2d_backward;
        pool2d_backward(context.template device_context<DeviceContext>(),
                        *in_x,
                        out_level,
                        outgrad_level,
                        kernel_size,
                        strides,
                        paddings,
                        in_x_grad);
      } else if (pooling_type == "avg") {
        phi::funcs::Pool2dGradFunctor<
            typename framework::ConvertToPhiContext<DeviceContext>::TYPE,
            phi::funcs::AvgPoolGrad<T>,
            T>
            pool_backward;
        phi::funcs::AvgPoolGrad<T> avg_process;
        pool_backward(context.template device_context<DeviceContext>(),
                      *in_x,
                      out_level,
                      outgrad_level,
                      kernel_size,
                      strides,
                      paddings,
                      true,
                      false,
                      in_x_grad,
                      avg_process);
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle