/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"
#if defined(__HIPCC__) || defined(__NVCC__)
#include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
#endif

namespace paddle {
namespace operators {
31 32 33
template <typename T>
struct DivideFunctor {
  HOSTDEVICE explicit inline DivideFunctor(int n) : n_inv((T)(1.0 / n)) {}
34 35 36 37 38

  template <typename U>
  HOSTDEVICE inline U operator()(const U& x) const {
    return x * static_cast<U>(n_inv);
  }
39 40 41 42

 private:
  T n_inv;
};
43 44

using Tensor = framework::Tensor;
45 46 47 48 49 50

class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
51 52 53 54

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
55 56 57

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
58
      const framework::OpKernelType& expected_kernel_type) const override;
59 60 61 62 63 64 65
};

class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
66 67 68 69

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
70 71 72 73

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override;
74 75 76 77
};

/// Op-proto maker for the 2-D pooling operator; Make() is defined elsewhere.
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

/// Op-proto maker for the 3-D pooling operator; Make() is defined elsewhere.
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

template <typename T = int>
inline void UpdatePadding(std::vector<T>* paddings, const bool global_pooling,
88 89 90
                          const bool adaptive,
                          const std::string padding_algorithm,
                          const framework::DDim data_dims,
91 92
                          const std::vector<T>& strides,
                          const std::vector<T>& ksize) {
93
  // set padding size == data_dims.size() * 2
94
  auto data_shape = framework::vectorize<T>(data_dims);
95 96
  if (static_cast<int>(paddings->size()) == data_dims.size()) {
    for (int i = 0; i < data_dims.size(); ++i) {
97
      T copy_pad = *(paddings->begin() + 2 * i);
98 99 100
      paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
    }
  } else {
101 102 103 104 105
    PADDLE_ENFORCE_EQ(data_dims.size() * 2, paddings->size(),
                      platform::errors::InvalidArgument(
                          "Paddings size %d should be the same or twice as the "
                          "pooling size %d.",
                          paddings->size(), data_dims.size() * 2));
106 107
  }

108
  // when padding_algorithm is "VALID" or "SAME"
109
  if (padding_algorithm == "SAME") {
110
    for (int i = 0; i < data_dims.size(); ++i) {
111 112
      T out_size = (data_dims[i] + strides[i] - 1) / strides[i];
      T pad_sum =
113 114
          std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i],
                   static_cast<T>(0));
115 116
      T pad_0 = pad_sum / 2;
      T pad_1 = pad_sum - pad_0;
117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
      *(paddings->begin() + i * 2) = pad_0;
      *(paddings->begin() + i * 2 + 1) = pad_1;
    }
  } else if (padding_algorithm == "VALID") {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }

  // if global_pooling == true or adaptive == true, padding will be ignore
  if (global_pooling || adaptive) {
    for (auto it = paddings->begin(); it != paddings->end(); it++) {
      *it = 0;
    }
  }
}

134 135
template <typename T = int>
inline void UpdateKsize(std::vector<T>* ksize,
136 137 138
                        const framework::DDim data_dims) {
  ksize->resize(static_cast<size_t>(data_dims.size()));
  for (size_t i = 0; i < ksize->size(); ++i) {
139
    *(ksize->begin() + i) = static_cast<T>(data_dims[i]);
140 141
  }
}
142

143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
/// Helper for the reduce-based adaptive average pooling fast path.
///
/// When `data_format` is not "NHWC" and `output` is 1x1 in dims 2 and 3,
/// appends the axes {2, 3} to `reduce_dim` and returns
/// input.dims()[2] * input.dims()[3], the number of elements reduced per
/// output value. Otherwise returns 0 and leaves `reduce_dim` untouched, which
/// callers treat as "use the generic pooling functor".
///
/// Assumes both tensors have at least 4 dims (presumably NCHW); the only
/// caller in this file is the 2-D pooling path.
inline int getReduceNum(const framework::Tensor& input,
                        const framework::Tensor* output,
                        const std::string& data_format,
                        std::vector<int>* reduce_dim) {
  // The fast path is only implemented for channel-first data.
  if (data_format == "NHWC") {
    return 0;
  }
  const int output_height = output->dims()[2];
  const int output_width = output->dims()[3];
  if (output_height != 1 || output_width != 1) {
    return 0;
  }
  reduce_dim->push_back(2);
  reduce_dim->push_back(3);
  return input.dims()[2] * input.dims()[3];
}

template <typename DeviceContext, typename T>
C
chengduoZH 已提交
164
class PoolKernel : public framework::OpKernel<T> {
165 166
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
167
    const Tensor* in_x = context.Input<Tensor>("X");
168
    Tensor* out = context.Output<Tensor>("Out");
169

C
chengduoZH 已提交
170
    std::string pooling_type = context.Attr<std::string>("pooling_type");
171 172 173
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
174
    std::string data_format = context.Attr<std::string>("data_format");
175
    bool exclusive = context.Attr<bool>("exclusive");
176
    bool adaptive = context.Attr<bool>("adaptive");
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
    bool global_pooling = context.Attr<bool>("global_pooling");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");

    // update paddings
    auto in_x_dims = in_x->dims();
    framework::DDim data_dims;
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    } else {
      data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    }

    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
194 195
    if (data_dims.size() * 2 == static_cast<int>(paddings.size())) {
      for (int i = 0; i < data_dims.size(); ++i) {
196
        paddings.erase(paddings.begin() + i + 1);
197 198
      }
    }
199 200 201 202

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }
Q
QI JUN 已提交
203
    auto& dev_ctx = context.template device_context<DeviceContext>();
204 205 206
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
207
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
208
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
209
              pool2d_forward;
210
          paddle::operators::math::MaxPool<T> pool_process;
211
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
212
                         true, false, out, pool_process);
213

C
chengduoZH 已提交
214
        } else if (pooling_type == "avg") {
215 216 217 218 219
          std::vector<int> reduce_dim;
          int reduce_num = getReduceNum(*in_x, out, data_format, &reduce_dim);

          if (reduce_num > 0 &&
              adaptive) {  // for adaptive_avg_pool2d && output_size == 1
220
#if defined(__HIPCC__) || defined(__NVCC__)
221 222 223 224 225 226 227 228 229 230
            auto stream = dev_ctx.stream();
            TensorReduce<T, T, cub::Sum, DivideFunctor<T>>(
                *in_x, out, reduce_dim, static_cast<T>(0), cub::Sum(),
                DivideFunctor<T>(reduce_num), stream);
#else  // for cpu
            paddle::operators::math::Pool2dFunctor<
                DeviceContext, paddle::operators::math::AvgPool<T>, T>
                pool2d_forward;
            paddle::operators::math::AvgPool<T> pool_process;
            pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings,
231
                           data_format, exclusive, adaptive, out, pool_process);
232 233 234 235 236 237 238
#endif
          } else {  // avgpool_2d or  adaptive_avg_pool2d && output_size != 1
            paddle::operators::math::Pool2dFunctor<
                DeviceContext, paddle::operators::math::AvgPool<T>, T>
                pool2d_forward;
            paddle::operators::math::AvgPool<T> pool_process;
            pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings,
239
                           data_format, exclusive, adaptive, out, pool_process);
240
          }
241 242 243 244
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
245
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
246
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
247
              pool3d_forward;
248
          paddle::operators::math::MaxPool<T> pool_process;
249
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
250
                         true, false, out, pool_process);
251

C
chengduoZH 已提交
252
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
253
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
254
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
255
              pool3d_forward;
256
          paddle::operators::math::AvgPool<T> pool_process;
257
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, data_format,
258
                         exclusive, adaptive, out, pool_process);
259 260
        }
      } break;
261 262 263 264
      default: {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Pool op only supports 2D and 3D input."));
      }
265 266 267 268
    }
  }
};

Q
QI JUN 已提交
269
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
270
class PoolGradKernel : public framework::OpKernel<T> {
271 272
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
273
    const Tensor* in_x = context.Input<Tensor>("X");
274 275 276
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
277
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
278

C
chengduoZH 已提交
279
    std::string pooling_type = context.Attr<std::string>("pooling_type");
280 281 282
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
283
    bool exclusive = context.Attr<bool>("exclusive");
284
    bool adaptive = context.Attr<bool>("adaptive");
285 286 287 288 289 290
    std::string data_format = context.Attr<std::string>("data_format");
    bool global_pooling = context.Attr<bool>("global_pooling");
    std::string padding_algorithm =
        context.Attr<std::string>("padding_algorithm");

    const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
291

292 293 294 295 296 297 298 299 300 301
    // update paddings
    auto in_x_dims = in_x->dims();
    framework::DDim data_dims;
    if (channel_last) {
      data_dims = framework::slice_ddim(in_x_dims, 1, in_x_dims.size() - 1);
    } else {
      data_dims = framework::slice_ddim(in_x_dims, 2, in_x_dims.size());
    }
    UpdatePadding(&paddings, global_pooling, adaptive, padding_algorithm,
                  data_dims, strides, ksize);
302 303
    if (data_dims.size() * 2 == static_cast<int>(paddings.size())) {
      for (int i = 0; i < data_dims.size(); ++i) {
304
        paddings.erase(paddings.begin() + i + 1);
C
fix bug  
chengduoZH 已提交
305
      }
306
    }
307 308 309 310 311

    if (global_pooling) {
      UpdateKsize(&ksize, data_dims);
    }

Q
QI JUN 已提交
312
    auto& dev_ctx = context.template device_context<DeviceContext>();
C
chengduoZH 已提交
313 314
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
315
      paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
316
      set_constant(dev_ctx, in_x_grad, static_cast<T>(0.0));
317 318 319 320

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
321
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
322
                pool2d_backward;
Q
QI JUN 已提交
323
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
324
                            paddings, data_format, in_x_grad);
C
chengduoZH 已提交
325
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
326
            paddle::operators::math::Pool2dGradFunctor<
Q
QI JUN 已提交
327
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
328
                pool2d_backward;
329
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
330
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
331 332
                            paddings, data_format, exclusive, adaptive,
                            in_x_grad, pool_process);
333 334 335 336
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
337
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
338
                pool3d_backward;
Q
QI JUN 已提交
339
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
340
                            paddings, data_format, in_x_grad);
C
chengduoZH 已提交
341
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
342
            paddle::operators::math::Pool3dGradFunctor<
Q
QI JUN 已提交
343
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
344
                pool3d_backward;
345
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
346
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
347 348
                            paddings, data_format, exclusive, adaptive,
                            in_x_grad, pool_process);
349 350
          }
        } break;
351 352 353 354
        default: {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Pool op only supports 2D and 3D input."));
        }
355 356 357 358 359
      }
    }
  }
};

360 361 362 363 364 365 366 367 368 369 370 371 372 373
/// Double-gradient kernel for pooling.
///
/// For any pooling type other than "max" it delegates to PoolKernel::Compute
/// on the same execution context; "max" raises InvalidArgument.
/// NOTE(review): presumably this relies on average pooling being linear, so
/// its second-order gradient is the forward pooling of the incoming gradient —
/// confirm against the grad-grad op definition.
template <typename DeviceContext, typename T>
class PoolGradGradKernel : public PoolKernel<DeviceContext, T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const std::string pooling_type = context.Attr<std::string>("pooling_type");
    const bool is_max_pool = (pooling_type == "max");
    if (!is_max_pool) {
      PoolKernel<DeviceContext, T>::Compute(context);
      return;
    }
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Pool op grad grad only supports avgpool."));
  }
};

}  // namespace operators
}  // namespace paddle