/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
34 35 36 37

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
38 39 40 41 42 43 44
};

class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
45 46 47 48

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
49 50 51 52
};

// Declares the inputs ("X"), outputs ("Out") and attributes (pooling_type,
// ksize, strides, paddings, ...) of pool2d; implemented in pool_op.cc.
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

// Declares the inputs ("X"), outputs ("Out") and attributes (pooling_type,
// ksize, strides, paddings, ...) of pool3d; implemented in pool_op.cc.
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
62
class PoolKernel : public framework::OpKernel<T> {
63 64
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
65
    const Tensor* in_x = context.Input<Tensor>("X");
66
    Tensor* out = context.Output<Tensor>("Out");
67

C
chengduoZH 已提交
68
    std::string pooling_type = context.Attr<std::string>("pooling_type");
69 70 71
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
72
    bool exclusive = context.Attr<bool>("exclusive");
73
    bool adaptive = context.Attr<bool>("adaptive");
C
chengduoZH 已提交
74
    if (context.Attr<bool>("global_pooling")) {
75
      for (size_t i = 0; i < ksize.size(); ++i) {
C
fix bug  
chengduoZH 已提交
76
        paddings[i] = 0;
C
chengduoZH 已提交
77
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
78 79
      }
    }
Q
QI JUN 已提交
80
    auto& dev_ctx = context.template device_context<DeviceContext>();
81 82 83
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
84
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
85
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
86
              pool2d_forward;
87
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
88
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
89
                         true, false, out);
90

C
chengduoZH 已提交
91
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
92
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
93
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
94
              pool2d_forward;
95
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
96
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
97
                         exclusive, adaptive, out);
98 99 100 101
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
102
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
103
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
104
              pool3d_forward;
105
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
106
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
107
                         true, false, out);
C
chengduoZH 已提交
108
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
109
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
110
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
111
              pool3d_forward;
112
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
113
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
114
                         exclusive, adaptive, out);
115 116
        }
      } break;
C
fix bug  
chengduoZH 已提交
117
      default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
118 119 120 121
    }
  }
};

template <typename DeviceContext, typename T>
C
chengduoZH 已提交
123
class PoolGradKernel : public framework::OpKernel<T> {
124 125
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
126
    const Tensor* in_x = context.Input<Tensor>("X");
127 128 129
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
130
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
131

C
chengduoZH 已提交
132
    std::string pooling_type = context.Attr<std::string>("pooling_type");
133 134 135
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
136
    bool exclusive = context.Attr<bool>("exclusive");
137
    bool adaptive = context.Attr<bool>("adaptive");
138

C
chengduoZH 已提交
139
    if (context.Attr<bool>("global_pooling")) {
C
fix bug  
chengduoZH 已提交
140 141
      for (size_t i = 0; i < ksize.size(); ++i) {
        paddings[i] = 0;
C
chengduoZH 已提交
142
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
C
fix bug  
chengduoZH 已提交
143
      }
144
    }
Q
QI JUN 已提交
145
    auto& dev_ctx = context.template device_context<DeviceContext>();
C
chengduoZH 已提交
146 147
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
148 149
      paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
      set_constant(dev_ctx, in_x_grad, 0.0);
150 151 152 153

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
154
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
155
                pool2d_backward;
Q
QI JUN 已提交
156 157
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
158
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
159
            paddle::operators::math::Pool2dGradFunctor<
Q
QI JUN 已提交
160
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
161
                pool2d_backward;
162
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
163
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
164 165
                            paddings, pool_process, exclusive, adaptive,
                            in_x_grad);
166 167 168 169
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
170
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
171
                pool3d_backward;
Q
QI JUN 已提交
172 173
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
174
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
175
            paddle::operators::math::Pool3dGradFunctor<
Q
QI JUN 已提交
176
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
177
                pool3d_backward;
178
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
179
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
180 181
                            paddings, pool_process, exclusive, adaptive,
                            in_x_grad);
182 183
          }
        } break;
C
fix bug  
chengduoZH 已提交
184
        default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
185 186 187 188 189 190 191
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle