/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

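// Forward pooling operator. Shape inference and kernel-type selection are
// declared here and implemented out of line.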
class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
};

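// Gradient of the pooling operator: infers the shape of X@GRAD and selects
// the kernel type for the backward pass.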
class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

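// Defines the attributes and documentation of the 2-D pooling op.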
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

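// Defines the attributes and documentation of the 3-D pooling op.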
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};
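
// Expands `paddings` so that every spatial dimension has a (begin, end) pair,
// then rewrites it according to `padding_algorithm`: "SAME" pads so that the
// output covers the whole input, "VALID" uses zero padding. Padding is also
// cleared when global or adaptive pooling is requested.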
inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling,
                          const bool adaptive,
                          const std::string padding_algorithm,
                          const framework::DDim data_dims,
                          const std::vector<int>& strides,
                          const std::vector<int>& ksize) {
  // Make sure paddings holds two values (begin and end) per spatial dimension.
  auto data_shape = framework::vectorize<int>(data_dims);
  if (paddings->size() == data_dims.size()) {
    for (size_t i = 0; i < data_dims.size(); ++i) {
      int copy_pad = *(paddings->begin() + 2 * i);
      paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
    }
  } else {
    PADDLE_ENFORCE_EQ(
        data_dims.size() * 2, paddings->size(),
        "Paddings size should be the same or twice as the pooling size.");
  }

  // when padding_algorithm is "VALID" or "SAME"
  if (padding_algorithm == "SAME") {
    for (int i = 0; i < data_dims.size(); ++i) {
      int out_size = (data_dims[i] + strides[i] - 1) / strides[i];
      int pad_sum =
          std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
      int pad_0 = pad_sum / 2;
      int pad_1 = pad_sum - pad_0;
      *(paddings->begin() + i * 2) = pad_0;
      *(paddings->begin() + i * 2 + 1) = pad_1;
    }
  } else if (padding_algorithm == "VALID") {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }

  // if global_pooling or adaptive is true, padding will be ignored
  if (global_pooling || adaptive) {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }
}

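// For global pooling the kernel spans the whole spatial extent, so ksize is
// reset to the spatial dimensions of the input.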
inline void UpdateKsize(std::vector<int>* ksize,
                        const framework::DDim data_dims) {
  ksize->resize(static_cast<size_t>(data_dims.size()));
  for (size_t i = 0; i < ksize->size(); ++i) {
    *(ksize->begin() + i) = static_cast<int>(data_dims[i]);
  }
}

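// Forward pooling kernel: dispatches to the 2-D or 3-D max/avg pooling
// functors in paddle::operators::math based on the operator attributes.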
template <typename DeviceContext, typename T>
class PoolKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* in_x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");

    std::string pooling_type = context.Attr<std::string>("pooling_type");
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::string data_format = context.Attr<std::string>("data_format");
    bool exclusive = context.Attr<bool>("exclusive");
    bool adaptive = context.Attr<bool>("adaptive");
    bool global_pooling = context.Attr<bool>("global_pooling");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    // Slice out the spatial dims (excluding batch and channel) and update
    // paddings to match them.
    auto in_x_dims = in_x->dims();
    framework::DDim data_dims;
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    } else {
      data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    }

    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
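    // The math functors expect one padding value per spatial dim, so collapse
    // the (begin, end) pairs produced above, keeping the begin value of each.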
    if (data_dims.size() * 2 == paddings.size()) {
      for (size_t i = 0; i < data_dims.size(); ++i) {
        paddings.erase(paddings.begin() + i + 1);
      }
    }

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
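    // ksize.size() determines the pooling rank: 2 -> Pool2d functors,
    // 3 -> Pool3d functors.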
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
          paddle::operators::math::Pool2dFunctor<
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
              pool2d_forward;
          paddle::operators::math::MaxPool<T> pool_process;
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, true, false, out);

        } else if (pooling_type == "avg") {
          paddle::operators::math::Pool2dFunctor<
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
              pool2d_forward;
          paddle::operators::math::AvgPool<T> pool_process;
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, exclusive, adaptive, out);
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
          paddle::operators::math::Pool3dFunctor<
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
              pool3d_forward;
          paddle::operators::math::MaxPool<T> pool_process;
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, true, false, out);

        } else if (pooling_type == "avg") {
          paddle::operators::math::Pool3dFunctor<
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
              pool3d_forward;
          paddle::operators::math::AvgPool<T> pool_process;
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
                         pool_process, exclusive, adaptive, out);
        }
      } break;
      default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
    }
  }
};

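// Backward pooling kernel: zero-initializes X@GRAD and dispatches to the 2-D
// or 3-D max/avg pooling gradient functors.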
template <typename DeviceContext, typename T>
class PoolGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* in_x = context.Input<Tensor>("X");
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));

    std::string pooling_type = context.Attr<std::string>("pooling_type");
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    bool exclusive = context.Attr<bool>("exclusive");
    bool adaptive = context.Attr<bool>("adaptive");
    std::string data_format = context.Attr<std::string>("data_format");
    bool global_pooling = context.Attr<bool>("global_pooling");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    // Slice out the spatial dims (excluding batch and channel) and update
    // paddings to match them.
    auto in_x_dims = in_x->dims();
    framework::DDim data_dims;
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    } else {
      data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    }
    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
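    // As in the forward kernel, collapse the (begin, end) padding pairs into
    // one value per spatial dim.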
    if (data_dims.size() * 2 == paddings.size()) {
      for (size_t i = 0; i < data_dims.size(); ++i) {
        paddings.erase(paddings.begin() + i + 1);
      }
    }

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
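    // Only compute the input gradient when it is requested as an output.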
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
      paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
      set_constant(dev_ctx, in_x_grad, 0.0);

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
                pool2d_backward;
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, data_format, in_x_grad);
          } else if (pooling_type == "avg") {
            paddle::operators::math::Pool2dGradFunctor<
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
                pool2d_backward;
            paddle::operators::math::AvgPoolGrad<T> pool_process;
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, data_format, pool_process, exclusive,
                            adaptive, in_x_grad);
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
                pool3d_backward;
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, data_format, in_x_grad);
          } else if (pooling_type == "avg") {
            paddle::operators::math::Pool3dGradFunctor<
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
                pool3d_backward;
            paddle::operators::math::AvgPoolGrad<T> pool_process;
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, data_format, pool_process, exclusive,
                            adaptive, in_x_grad);
          }
        } break;
        default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle