/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
32 33 34 35

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
36 37 38 39 40 41 42
};

class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
43 44 45 46

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
47 48 49 50
};

// Operator maker (proto description + attribute checker) for 2-D pooling.
// The constructor, which fills in the proto and checker, is defined out of
// line.
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};

// Operator maker (proto description + attribute checker) for 3-D pooling.
// The constructor, which fills in the proto and checker, is defined out of
// line.
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool3dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
60
class PoolKernel : public framework::OpKernel<T> {
61 62
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
63
    const Tensor* in_x = context.Input<Tensor>("X");
64
    Tensor* out = context.Output<Tensor>("Out");
65

C
chengduoZH 已提交
66
    std::string pooling_type = context.Attr<std::string>("pooling_type");
67 68 69
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
70
    if (context.Attr<bool>("global_pooling")) {
71
      for (size_t i = 0; i < ksize.size(); ++i) {
C
fix bug  
chengduoZH 已提交
72
        paddings[i] = 0;
C
chengduoZH 已提交
73
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
74 75
      }
    }
Q
QI JUN 已提交
76
    auto& dev_ctx = context.template device_context<DeviceContext>();
77 78 79
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
80
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
81
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
82
              pool2d_forward;
83
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
84 85
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
86

C
chengduoZH 已提交
87
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
88
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
89
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
90
              pool2d_forward;
91
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
92 93
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
94 95 96 97
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
98
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
99
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
100
              pool3d_forward;
101
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
102 103
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
C
chengduoZH 已提交
104
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
105
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
106
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
107
              pool3d_forward;
108
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
109 110
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
111 112
        }
      } break;
C
fix bug  
chengduoZH 已提交
113
      default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
114 115 116 117
    }
  }
};

template <typename DeviceContext, typename T>
C
chengduoZH 已提交
119
class PoolGradKernel : public framework::OpKernel<T> {
120 121
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
122
    const Tensor* in_x = context.Input<Tensor>("X");
123 124 125
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
126
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
127

C
chengduoZH 已提交
128
    std::string pooling_type = context.Attr<std::string>("pooling_type");
129 130 131 132
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

C
chengduoZH 已提交
133
    if (context.Attr<bool>("global_pooling")) {
C
fix bug  
chengduoZH 已提交
134 135
      for (size_t i = 0; i < ksize.size(); ++i) {
        paddings[i] = 0;
C
chengduoZH 已提交
136
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
C
fix bug  
chengduoZH 已提交
137
      }
138
    }
Q
QI JUN 已提交
139
    auto& dev_ctx = context.template device_context<DeviceContext>();
C
chengduoZH 已提交
140 141
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
Q
QI JUN 已提交
142 143
      paddle::operators::math::SetConstant<DeviceContext, T> set_constant;
      set_constant(dev_ctx, in_x_grad, 0.0);
144 145 146 147

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
148
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
149
                pool2d_backward;
Q
QI JUN 已提交
150 151
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
152
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
153
            paddle::operators::math::Pool2dGradFunctor<
Q
QI JUN 已提交
154
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
155
                pool2d_backward;
156
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
157 158
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, pool_process, in_x_grad);
159 160 161 162
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
163
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
164
                pool3d_backward;
Q
QI JUN 已提交
165 166
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
167
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
168
            paddle::operators::math::Pool3dGradFunctor<
Q
QI JUN 已提交
169
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
170
                pool3d_backward;
171
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
172 173
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, pool_process, in_x_grad);
174 175
          }
        } break;
C
fix bug  
chengduoZH 已提交
176
        default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
177 178 179 180 181 182 183
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle