/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

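// Forward pooling operator. InferShape is defined out of line and derives the
// output shape from the input dims and the ksize/strides/paddings attributes.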
class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
};

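// Backward pooling operator; the gradient of X has the same shape as X.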
class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override;
};

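// The makers declare the operator's inputs, outputs, and attributes
// (pooling_type, ksize, strides, paddings, global_pooling) for the
// 2-D and 3-D variants.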
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool2dOpMaker(framework::OpProto* proto,
                framework::OpAttrChecker* op_checker);
};

class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  Pool3dOpMaker(framework::OpProto* proto,
                framework::OpAttrChecker* op_checker);
};

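// Forward compute kernel shared by pool2d and pool3d: the rank of the ksize
// attribute selects the 2-D or 3-D math functor below.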
template <typename Place, typename T>
class PoolKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* in_x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");

    std::string pooling_type = context.Attr<std::string>("pooling_type");
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
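    // Global pooling covers the whole spatial extent: the input layout is
    // NC[D]HW, so the spatial dims start at axis 2.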
    if (context.Attr<bool>("global_pooling")) {
      for (size_t i = 0; i < ksize.size(); ++i) {
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
      }
    }

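    // Dispatch on the number of spatial dimensions: 2 -> pool2d, 3 -> pool3d.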
    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
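          // Pool2dFunctor drives the sliding window on the target device;
          // MaxPool<T>/AvgPool<T> supply the per-window reduction.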
          paddle::operators::math::Pool2dFunctor<
              Place, paddle::operators::math::MaxPool<T>, T>
              pool2d_forward;
          paddle::operators::math::MaxPool<T> pool_process;
          pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
                         paddings, pool_process);

        } else if (pooling_type == "avg") {
          paddle::operators::math::Pool2dFunctor<
              Place, paddle::operators::math::AvgPool<T>, T>
              pool2d_forward;
          paddle::operators::math::AvgPool<T> pool_process;
          pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
                         paddings, pool_process);
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
          paddle::operators::math::Pool3dFunctor<
              Place, paddle::operators::math::MaxPool<T>, T>
              pool3d_forward;
          paddle::operators::math::MaxPool<T> pool_process;
          pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
                         paddings, pool_process);
        } else if (pooling_type == "avg") {
          paddle::operators::math::Pool3dFunctor<
              Place, paddle::operators::math::AvgPool<T>, T>
              pool3d_forward;
          paddle::operators::math::AvgPool<T> pool_process;
          pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
                         paddings, pool_process);
        }
      } break;
    }
  }
};

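// Backward compute kernel shared by pool2d_grad and pool3d_grad.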
template <typename Place, typename T>
class PoolGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* in_x = context.Input<Tensor>("X");
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));

    std::string pooling_type = context.Attr<std::string>("pooling_type");
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

    if (context.Attr<bool>("global_pooling")) {
      for (size_t i = 0; i < ksize.size(); ++i) {
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
      }
    }

    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
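      // Zero-initialize X@GRAD before the backward functors scatter
      // gradients into it.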
      auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
      temp.device(context.GetEigenDevice<Place>()) =
          temp.constant(static_cast<T>(0));

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
            paddle::operators::math::MaxPool2dGradFunctor<Place, T>
                pool2d_backward;
            pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
                            *out_grad, ksize, strides, paddings);
          } else if (pooling_type == "avg") {
            paddle::operators::math::Pool2dGradFunctor<
                Place, paddle::operators::math::AvgPoolGrad<T>, T>
                pool2d_backward;
            paddle::operators::math::AvgPoolGrad<T> pool_process;
            pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
                            *out_grad, ksize, strides, paddings, pool_process);
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
            paddle::operators::math::MaxPool3dGradFunctor<Place, T>
                pool3d_backward;
            pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
                            *out_grad, ksize, strides, paddings);
          } else if (pooling_type == "avg") {
            paddle::operators::math::Pool3dGradFunctor<
                Place, paddle::operators::math::AvgPoolGrad<T>, T>
                pool3d_backward;
            paddle::operators::math::AvgPoolGrad<T> pool_process;
            pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
                            *out_grad, ksize, strides, paddings, pool_process);
          }
        } break;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle