/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/pooling.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
34 35 36 37

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
38 39 40 41 42 43 44
};

class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
45 46 47 48

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
49 50 51 52
};

// Declares inputs, outputs and attributes of the 2-D pooling operator.
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};

// Declares inputs, outputs and attributes of the 3-D pooling operator.
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool3dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
60

Q
QI JUN 已提交
61
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
62
class PoolKernel : public framework::OpKernel<T> {
63 64
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
65
    const Tensor* in_x = context.Input<Tensor>("X");
66
    Tensor* out = context.Output<Tensor>("Out");
67

C
chengduoZH 已提交
68
    std::string pooling_type = context.Attr<std::string>("pooling_type");
69 70 71
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
72
    if (context.Attr<bool>("global_pooling")) {
73
      for (size_t i = 0; i < ksize.size(); ++i) {
C
fix bug  
chengduoZH 已提交
74
        paddings[i] = 0;
C
chengduoZH 已提交
75
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
76 77
      }
    }
Q
QI JUN 已提交
78
    auto& dev_ctx = context.template device_context<DeviceContext>();
79 80 81
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
82
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
83
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
84
              pool2d_forward;
85
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
86 87
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
88

C
chengduoZH 已提交
89
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
90
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
91
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
92
              pool2d_forward;
93
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
94 95
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
96 97 98 99
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
100
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
101
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
102
              pool3d_forward;
103
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
104 105
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
C
chengduoZH 已提交
106
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
107
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
108
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
109
              pool3d_forward;
110
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
111 112
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
113 114
        }
      } break;
C
fix bug  
chengduoZH 已提交
115
      default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
116 117 118 119
    }
  }
};

Q
QI JUN 已提交
120
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
121
class PoolGradKernel : public framework::OpKernel<T> {
122 123
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
124
    const Tensor* in_x = context.Input<Tensor>("X");
125 126 127
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
128
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
129

C
chengduoZH 已提交
130
    std::string pooling_type = context.Attr<std::string>("pooling_type");
131 132 133 134
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

C
chengduoZH 已提交
135
    if (context.Attr<bool>("global_pooling")) {
C
fix bug  
chengduoZH 已提交
136 137
      for (size_t i = 0; i < ksize.size(); ++i) {
        paddings[i] = 0;
C
chengduoZH 已提交
138
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
C
fix bug  
chengduoZH 已提交
139
      }
140
    }
Q
QI JUN 已提交
141
    auto& dev_ctx = context.template device_context<DeviceContext>();
C
chengduoZH 已提交
142 143
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
144 145
      paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
      set_constant(dev_ctx, in_x_grad, 0.0);
146 147 148 149

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
150
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
151
                pool2d_backward;
Q
QI JUN 已提交
152 153
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
154
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
155
            paddle::operators::math::Pool2dGradFunctor<
Q
QI JUN 已提交
156
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
157
                pool2d_backward;
158
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
159 160
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, pool_process, in_x_grad);
161 162 163 164
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
165
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
166
                pool3d_backward;
Q
QI JUN 已提交
167 168
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
169
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
170
            paddle::operators::math::Pool3dGradFunctor<
Q
QI JUN 已提交
171
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
172
                pool3d_backward;
173
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
174 175
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, pool_process, in_x_grad);
176 177
          }
        } break;
C
fix bug  
chengduoZH 已提交
178
        default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
179 180 181 182 183 184 185
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle