pool_op.h 11.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include <algorithm>
18 19
#include <string>
#include <vector>
20

Y
Yi Wang 已提交
21 22 23 24
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
25 26 27 28
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
29 30 31 32 33 34

class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
35 36 37 38

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
39 40 41

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
42
      const framework::OpKernelType& expected_kernel_type) const override;
43 44 45 46 47 48 49
};

class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
50 51 52 53

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
54 55 56 57

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
58 59 60 61
};

class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
62
  void Make() override;
63 64 65 66
};

class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
Y
Yu Yang 已提交
67
  void Make() override;
68
};
69 70 71

template <typename T = int>
inline void UpdatePadding(std::vector<T>* paddings, const bool global_pooling,
72 73 74
                          const bool adaptive,
                          const std::string padding_algorithm,
                          const framework::DDim data_dims,
75 76
                          const std::vector<T>& strides,
                          const std::vector<T>& ksize) {
77
  // set padding size == data_dims.size() * 2
78
  auto data_shape = framework::vectorize<T>(data_dims);
79 80
  if (static_cast<int>(paddings->size()) == data_dims.size()) {
    for (int i = 0; i < data_dims.size(); ++i) {
81
      T copy_pad = *(paddings->begin() + 2 * i);
82 83 84
      paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
    }
  } else {
85 86 87 88 89
    PADDLE_ENFORCE_EQ(data_dims.size() * 2, paddings->size(),
                      platform::errors::InvalidArgument(
                          "Paddings size %d should be the same or twice as the "
                          "pooling size %d.",
                          paddings->size(), data_dims.size() * 2));
90 91
  }

92
  // when padding_algorithm is "VALID" or "SAME"
93
  if (padding_algorithm == "SAME") {
94
    for (int i = 0; i < data_dims.size(); ++i) {
95 96
      T out_size = (data_dims[i] + strides[i] - 1) / strides[i];
      T pad_sum =
97 98
          std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i],
                   static_cast<T>(0));
99 100
      T pad_0 = pad_sum / 2;
      T pad_1 = pad_sum - pad_0;
101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117
      *(paddings->begin() + i * 2) = pad_0;
      *(paddings->begin() + i * 2 + 1) = pad_1;
    }
  } else if (padding_algorithm == "VALID") {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }

  // if global_pooling == true or adaptive == true, padding will be ignore
  if (global_pooling || adaptive) {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }
}

118 119
template <typename T = int>
inline void UpdateKsize(std::vector<T>* ksize,
120 121 122
                        const framework::DDim data_dims) {
  ksize->resize(static_cast<size_t>(data_dims.size()));
  for (size_t i = 0; i < ksize->size(); ++i) {
123
    *(ksize->begin() + i) = static_cast<T>(data_dims[i]);
124 125
  }
}
126

Q
QI JUN 已提交
127
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
128
class PoolKernel : public framework::OpKernel<T> {
129 130
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
131
    const Tensor* in_x = context.Input<Tensor>("X");
132
    Tensor* out = context.Output<Tensor>("Out");
133

C
chengduoZH 已提交
134
    std::string pooling_type = context.Attr<std::string>("pooling_type");
135 136 137
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
138
    std::string data_format = context.Attr<std::string>("data_format");
139
    bool exclusive = context.Attr<bool>("exclusive");
140
    bool adaptive = context.Attr<bool>("adaptive");
141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
    bool global_pooling = context.Attr<bool>("global_pooling");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    // update paddings
    auto in_x_dims = in_x->dims();
    framework::DDim data_dims;
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    } else {
      data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    }

    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
158 159
    if (data_dims.size() * 2 == static_cast<int>(paddings.size())) {
      for (int i = 0; i < data_dims.size(); ++i) {
160
        paddings.erase(paddings.begin() + i + 1);
161 162
      }
    }
163 164 165 166 167

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

Q
QI JUN 已提交
168
    auto& dev_ctx = context.template device_context<DeviceContext>();
169 170 171
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
172
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
173
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
174
              pool2d_forward;
175
          paddle::operators::math::MaxPool<T> pool_process;
176 177
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, true, false, out);
178

C
chengduoZH 已提交
179
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
180
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
181
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
182
              pool2d_forward;
183
          paddle::operators::math::AvgPool<T> pool_process;
184 185
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, exclusive, adaptive, out);
186 187 188 189
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
190
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
191
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
192
              pool3d_forward;
193
          paddle::operators::math::MaxPool<T> pool_process;
194 195 196
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, true, false, out);

C
chengduoZH 已提交
197
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
198
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
199
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
200
              pool3d_forward;
201
          paddle::operators::math::AvgPool<T> pool_process;
202 203
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, exclusive, adaptive, out);
204 205
        }
      } break;
206 207 208 209
      default: {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Pool op only supports 2D and 3D input."));
      }
210 211 212 213
    }
  }
};

Q
QI JUN 已提交
214
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
215
class PoolGradKernel : public framework::OpKernel<T> {
216 217
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
218
    const Tensor* in_x = context.Input<Tensor>("X");
219 220 221
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
222
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
223

C
chengduoZH 已提交
224
    std::string pooling_type = context.Attr<std::string>("pooling_type");
225 226 227
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
228
    bool exclusive = context.Attr<bool>("exclusive");
229
    bool adaptive = context.Attr<bool>("adaptive");
230 231 232 233 234 235
    std::string data_format = context.Attr<std::string>("data_format");
    bool global_pooling = context.Attr<bool>("global_pooling");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
236

237 238 239 240 241 242 243 244 245 246
    // update paddings
    auto in_x_dims = in_x->dims();
    framework::DDim data_dims;
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    } else {
      data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    }
    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
247 248
    if (data_dims.size() * 2 == static_cast<int>(paddings.size())) {
      for (int i = 0; i < data_dims.size(); ++i) {
249
        paddings.erase(paddings.begin() + i + 1);
C
fix bug  
chengduoZH 已提交
250
      }
251
    }
252 253 254 255 256

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

Q
QI JUN 已提交
257
    auto& dev_ctx = context.template device_context<DeviceContext>();
C
chengduoZH 已提交
258 259
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
260
      paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
261
      set_constant(dev_ctx, in_x_grad, static_cast<T>(0.0));
262 263 264 265

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
266
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
267
                pool2d_backward;
Q
QI JUN 已提交
268
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
269
                            paddings, data_format, in_x_grad);
C
chengduoZH 已提交
270
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
271
            paddle::operators::math::Pool2dGradFunctor<
Q
QI JUN 已提交
272
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
273
                pool2d_backward;
274
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
275
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
276 277
                            paddings, data_format, pool_process, exclusive,
                            adaptive, in_x_grad);
278 279 280 281
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
282
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
283
                pool3d_backward;
Q
QI JUN 已提交
284
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
285
                            paddings, data_format, in_x_grad);
C
chengduoZH 已提交
286
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
287
            paddle::operators::math::Pool3dGradFunctor<
Q
QI JUN 已提交
288
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
289
                pool3d_backward;
290
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
291
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
292 293
                            paddings, data_format, pool_process, exclusive,
                            adaptive, in_x_grad);
294 295
          }
        } break;
296 297 298 299
        default: {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Pool op only supports 2D and 3D input."));
        }
300 301 302 303 304 305 306
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle