pool_op.h 7.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
26 27 28 29 30 31

class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
32 33 34 35

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
36 37 38 39 40 41 42
};

class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
43 44 45 46

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
47 48 49 50
};

/// Declares the proto (inputs, outputs, attributes such as "pooling_type",
/// "ksize", "strides", "paddings", "global_pooling") for pool2d.
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool2dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};

/// Declares the proto (inputs, outputs, attributes) for pool3d;
/// counterpart of Pool2dOpMaker.
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool3dOpMaker(OpProto* proto, OpAttrChecker* op_checker);
};
58

Q
QI JUN 已提交
59
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
60
class PoolKernel : public framework::OpKernel<T> {
61 62
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
63
    const Tensor* in_x = context.Input<Tensor>("X");
64
    Tensor* out = context.Output<Tensor>("Out");
65

C
chengduoZH 已提交
66
    std::string pooling_type = context.Attr<std::string>("pooling_type");
67 68 69
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
C
chengduoZH 已提交
70
    if (context.Attr<bool>("global_pooling")) {
71
      for (size_t i = 0; i < ksize.size(); ++i) {
C
fix bug  
chengduoZH 已提交
72
        paddings[i] = 0;
C
chengduoZH 已提交
73
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
74 75
      }
    }
Q
QI JUN 已提交
76
    auto& dev_ctx = context.template device_context<DeviceContext>();
77 78 79
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
80
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
81
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
82
              pool2d_forward;
83
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
84 85
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
86

C
chengduoZH 已提交
87
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
88
          paddle::operators::math::Pool2dFunctor<
Q
QI JUN 已提交
89
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
90
              pool2d_forward;
91
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
92 93
          pool2d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
94 95 96 97
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
98
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
99
              DeviceContext, paddle::operators::math::MaxPool<T>, T>
100
              pool3d_forward;
101
          paddle::operators::math::MaxPool<T> pool_process;
Q
QI JUN 已提交
102 103
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
C
chengduoZH 已提交
104
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
105
          paddle::operators::math::Pool3dFunctor<
Q
QI JUN 已提交
106
              DeviceContext, paddle::operators::math::AvgPool<T>, T>
107
              pool3d_forward;
108
          paddle::operators::math::AvgPool<T> pool_process;
Q
QI JUN 已提交
109 110
          pool3d_forward(dev_ctx, *in_x, ksize, strides, paddings, pool_process,
                         out);
111 112
        }
      } break;
C
fix bug  
chengduoZH 已提交
113
      default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
114 115 116 117
    }
  }
};

Q
QI JUN 已提交
118
template <typename DeviceContext, typename T>
C
chengduoZH 已提交
119
class PoolGradKernel : public framework::OpKernel<T> {
120 121
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
122
    const Tensor* in_x = context.Input<Tensor>("X");
123 124 125
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
126
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
127

C
chengduoZH 已提交
128
    std::string pooling_type = context.Attr<std::string>("pooling_type");
129 130 131 132
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

C
chengduoZH 已提交
133
    if (context.Attr<bool>("global_pooling")) {
C
fix bug  
chengduoZH 已提交
134 135
      for (size_t i = 0; i < ksize.size(); ++i) {
        paddings[i] = 0;
C
chengduoZH 已提交
136
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
C
fix bug  
chengduoZH 已提交
137
      }
138
    }
Q
QI JUN 已提交
139
    auto& dev_ctx = context.template device_context<DeviceContext>();
C
chengduoZH 已提交
140 141 142
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
      auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
Q
QI JUN 已提交
143 144
      temp.device(
          *context.template device_context<DeviceContext>().eigen_device()) =
145 146 147 148 149
          temp.constant(static_cast<T>(0));

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
150
            paddle::operators::math::MaxPool2dGradFunctor<DeviceContext, T>
151
                pool2d_backward;
Q
QI JUN 已提交
152 153
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
154
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
155
            paddle::operators::math::Pool2dGradFunctor<
Q
QI JUN 已提交
156
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
157
                pool2d_backward;
158
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
159 160
            pool2d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, pool_process, in_x_grad);
161 162 163 164
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
Q
QI JUN 已提交
165
            paddle::operators::math::MaxPool3dGradFunctor<DeviceContext, T>
166
                pool3d_backward;
Q
QI JUN 已提交
167 168
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, in_x_grad);
C
chengduoZH 已提交
169
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
170
            paddle::operators::math::Pool3dGradFunctor<
Q
QI JUN 已提交
171
                DeviceContext, paddle::operators::math::AvgPoolGrad<T>, T>
172
                pool3d_backward;
173
            paddle::operators::math::AvgPoolGrad<T> pool_process;
Q
QI JUN 已提交
174 175
            pool3d_backward(dev_ctx, *in_x, *out, *out_grad, ksize, strides,
                            paddings, pool_process, in_x_grad);
176 177
          }
        } break;
C
fix bug  
chengduoZH 已提交
178
        default: { PADDLE_THROW("Pool op only supports 2D and 3D input."); }
179 180 181 182 183 184 185
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle