/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
34 35 36 37

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
38 39 40 41 42 43 44
};

class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
45 46 47 48

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
49 50 51 52
};

// Declares inputs, outputs, and attributes of the 2-D pooling operator.
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

// Declares inputs, outputs, and attributes of the 3-D pooling operator.
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
62
class PoolKernel : public framework::OpKernel<T> {
63 64
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
65
    const Tensor* in_x = context.Input<Tensor>("X");
66
    Tensor* out = context.Output<Tensor>("Out");
67

C
chengduoZH 已提交
68
    std::string pooling_type = context.Attr<std::string>("pooling_type");
69 70 71
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
72
    bool exclusive = context.Attr<bool>("exclusive");
C
chengduoZH 已提交
73
    if (context.Attr<bool>("global_pooling")) {
74
      for (size_t i = 0; i < ksize.size(); ++i) {
C
fix bug  
chengduoZH 已提交
75
        paddings[i] = 0;
C
chengduoZH 已提交
76
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
77 78
      }
    }
Q
QI JUN 已提交
79
    auto& dev_ctx = context.template device_context<DeviceContext>();
80 81 82
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
83
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
84
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
85
              pool2d_forward;
86
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
87
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
88
                         true, out);
89

C
chengduoZH 已提交
90
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
91
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
92
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
93
              pool2d_forward;
94
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
95
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
96
                         exclusive, out);
97 98 99 100
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
101
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
102
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
103
              pool3d_forward;
104
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
105
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
106
                         true, out);
C
chengduoZH 已提交
107
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
108
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
109
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
110
              pool3d_forward;
111
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
112
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
113
                         exclusive, out);
114 115
        }
      } break;
C
fix bug  
chengduoZH 已提交
116
      default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
117 118 119 120
    }
  }
};

template <typename DeviceContext, typename T>
C
chengduoZH 已提交
122
class PoolGradKernel : public framework::OpKernel<T> {
123 124
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
125
    const Tensor* in_x = context.Input<Tensor>("X");
126 127 128
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
129
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
130

C
chengduoZH 已提交
131
    std::string pooling_type = context.Attr<std::string>("pooling_type");
132 133 134
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
135
    bool exclusive = context.Attr<bool>("exclusive");
136

C
chengduoZH 已提交
137
    if (context.Attr<bool>("global_pooling")) {
C
fix bug  
chengduoZH 已提交
138 139
      for (size_t i = 0; i < ksize.size(); ++i) {
        paddings[i] = 0;
C
chengduoZH 已提交
140
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
C
fix bug  
chengduoZH 已提交
141
      }
142
    }
Q
QI JUN 已提交
143
    auto& dev_ctx = context.template device_context<DeviceContext>();
C
chengduoZH 已提交
144 145
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
146 147
      paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
      set_constant(dev_ctx, in_x_grad, 0.0);
148 149 150 151

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
152
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
153
                pool2d_backward;
Q
QI JUN 已提交
154 155
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
156
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
157
            paddle::operators::math::Pool2dGradFunctor<
Q
QI JUN 已提交
158
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
159
                pool2d_backward;
160
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
161
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
162
                            paddings, pool_process, exclusive, in_x_grad);
163 164 165 166
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
167
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
168
                pool3d_backward;
Q
QI JUN 已提交
169 170
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
171
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
172
            paddle::operators::math::Pool3dGradFunctor<
Q
QI JUN 已提交
173
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
174
                pool3d_backward;
175
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
176
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
177
                            paddings, pool_process, exclusive, in_x_grad);
178 179
          }
        } break;
C
fix bug  
chengduoZH 已提交
180
        default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
181 182 183 184 185 186 187
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle