pool_op.h 10.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <algorithm>
18 19
#include <string>
#include <vector>
Y
Yi Wang 已提交
20 21 22 23
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
24 25 26 27
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
28 29 30 31 32 33

class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
34 35 36 37

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
38 39 40 41 42 43 44
};

class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
45 46 47 48

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
49 50 51 52
};

// Proto/attribute maker for the 2D pooling operator. Make() is only declared
// here; its definition lives outside this header.
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

// Proto/attribute maker for the 3D pooling operator. Make() is only declared
// here; its definition lives outside this header.
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};
60 61 62 63 64 65 66 67 68
inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling,
                          const bool adaptive,
                          const std::string padding_algorithm,
                          const framework::DDim data_dims,
                          const std::vector<int>& strides,
                          const std::vector<int>& ksize) {
  // set padding size == data_dims.size() * 2
  auto data_shape = framework::vectorize<int>(data_dims);
  if (paddings->size() == data_dims.size()) {
69
    for (size_t i = 0; i < data_dims.size(); ++i) {
70 71 72 73 74 75 76 77 78
      int copy_pad = *(paddings->begin() + 2 * i);
      paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
    }
  } else {
    PADDLE_ENFORCE_EQ(
        data_dims.size() * 2, paddings->size(),
        "Paddings size should be the same or twice as the pooling size.");
  }

79
  // when padding_algorithm is "VALID" or "SAME"
80
  if (padding_algorithm == "SAME") {
81
    for (int i = 0; i < data_dims.size(); ++i) {
82
      int out_size = (data_dims[i] + strides[i] - 1) / strides[i];
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
      int pad_sum =
          std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
      int pad_0 = pad_sum / 2;
      int pad_1 = pad_sum - pad_0;
      *(paddings->begin() + i * 2) = pad_0;
      *(paddings->begin() + i * 2 + 1) = pad_1;
    }
  } else if (padding_algorithm == "VALID") {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }

  // if global_pooling == true or adaptive == true, padding will be ignore
  if (global_pooling || adaptive) {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }
}

// Overwrites `ksize` with the full extent of every dimension in `data_dims`,
// so a single window spans the whole input. It is called when global_pooling
// is set.
inline void UpdateKsize(std::vector<int>* ksize,
                        const framework::DDim data_dims) {
  const int rank = data_dims.size();
  ksize->clear();
  ksize->reserve(static_cast<size_t>(rank));
  for (int axis = 0; axis < rank; ++axis) {
    ksize->push_back(static_cast<int>(data_dims[axis]));
  }
}
111

Q
QI JUN 已提交
112
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
113
class PoolKernel : public framework::OpKernel<T> {
114 115
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
116
    const Tensor* in_x = context.Input<Tensor>("X");
117
    Tensor* out = context.Output<Tensor>("Out");
118

C
chengduoZH 已提交
119
    std::string pooling_type = context.Attr<std::string>("pooling_type");
120 121 122
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
123
    std::string data_format = context.Attr<std::string>("data_format");
124
    bool exclusive = context.Attr<bool>("exclusive");
125
    bool adaptive = context.Attr<bool>("adaptive");
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
    bool global_pooling = context.Attr<bool>("global_pooling");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    // update paddings
    auto in_x_dims = in_x->dims();
    framework::DDim data_dims;
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    } else {
      data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    }

    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
    if (data_dims.size() * 2 == paddings.size()) {
      for (size_t i = 0; i < data_dims.size(); ++i) {
        paddings.erase(paddings.begin() + i + 1);
146 147
      }
    }
148 149 150 151 152

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

Q
QI JUN 已提交
153
    auto& dev_ctx = context.template device_context<DeviceContext>();
154 155 156
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
157
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
158
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
159
              pool2d_forward;
160
          paddle::operators::math::MaxPool<T> pool_process;
161 162
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, true, false, out);
163

C
chengduoZH 已提交
164
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
165
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
166
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
167
              pool2d_forward;
168
          paddle::operators::math::AvgPool<T> pool_process;
169 170
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, exclusive, adaptive, out);
171 172 173 174
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
175
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
176
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
177
              pool3d_forward;
178
          paddle::operators::math::MaxPool<T> pool_process;
179 180 181
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, true, false, out);

C
chengduoZH 已提交
182
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
183
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
184
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
185
              pool3d_forward;
186
          paddle::operators::math::AvgPool<T> pool_process;
187 188
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, exclusive, adaptive, out);
189 190
        }
      } break;
C
fix bug  
chengduoZH 已提交
191
      default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
192 193 194 195
    }
  }
};

Q
QI JUN 已提交
196
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
197
class PoolGradKernel : public framework::OpKernel<T> {
198 199
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
200
    const Tensor* in_x = context.Input<Tensor>("X");
201 202 203
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
204
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
205

C
chengduoZH 已提交
206
    std::string pooling_type = context.Attr<std::string>("pooling_type");
207 208 209
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
210
    bool exclusive = context.Attr<bool>("exclusive");
211
    bool adaptive = context.Attr<bool>("adaptive");
212 213 214 215 216 217
    std::string data_format = context.Attr<std::string>("data_format");
    bool global_pooling = context.Attr<bool>("global_pooling");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
218

219 220 221 222 223 224 225 226 227 228 229 230 231
    // update paddings
    auto in_x_dims = in_x->dims();
    framework::DDim data_dims;
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    } else {
      data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    }
    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
    if (data_dims.size() * 2 == paddings.size()) {
      for (size_t i = 0; i < data_dims.size(); ++i) {
        paddings.erase(paddings.begin() + i + 1);
C
fix bug  
chengduoZH 已提交
232
      }
233
    }
234 235 236 237 238

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

Q
QI JUN 已提交
239
    auto& dev_ctx = context.template device_context<DeviceContext>();
C
chengduoZH 已提交
240 241
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
242 243
      paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
      set_constant(dev_ctx, in_x_grad, 0.0);
244 245 246 247

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
248
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
249
                pool2d_backward;
Q
QI JUN 已提交
250
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
251
                            paddings, data_format, in_x_grad);
C
chengduoZH 已提交
252
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
253
            paddle::operators::math::Pool2dGradFunctor<
Q
QI JUN 已提交
254
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
255
                pool2d_backward;
256
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
257
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
258 259
                            paddings, data_format, pool_process, exclusive,
                            adaptive, in_x_grad);
260 261 262 263
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
264
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
265
                pool3d_backward;
Q
QI JUN 已提交
266
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
267
                            paddings, data_format, in_x_grad);
C
chengduoZH 已提交
268
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
269
            paddle::operators::math::Pool3dGradFunctor<
Q
QI JUN 已提交
270
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
271
                pool3d_backward;
272
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
273
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
274 275
                            paddings, data_format, pool_process, exclusive,
                            adaptive, in_x_grad);
276 277
          }
        } break;
C
fix bug  
chengduoZH 已提交
278
        default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
279 280 281 282 283 284 285
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle