pool_op.h 5.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename Place, typename T>
C
chengduoZH 已提交
28
class PoolKernel : public framework::OpKernel<T> {
29 30
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
31
    const Tensor* in_x = context.Input<Tensor>("X");
32
    Tensor* out = context.Output<Tensor>("Out");
33

34
    std::string pooling_type = context.Attr<std::string>("poolingType");
35 36 37
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
38
    if (context.Attr<bool>("globalPooling")) {
39
      for (size_t i = 0; i < ksize.size(); ++i) {
C
chengduoZH 已提交
40
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
41 42 43 44 45 46
      }
    }

    switch (ksize.size()) {
      case 2: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
47
          paddle::operators::math::Pool2dFunctor<
48
              Place, paddle::operators::math::MaxPool<T>, T>
49
              pool2d_forward;
50
          paddle::operators::math::MaxPool<T> pool_process;
C
chengduoZH 已提交
51
          pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
52
                         paddings, pool_process);
53

C
chengduoZH 已提交
54
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
55
          paddle::operators::math::Pool2dFunctor<
56
              Place, paddle::operators::math::AvgPool<T>, T>
57
              pool2d_forward;
58
          paddle::operators::math::AvgPool<T> pool_process;
C
chengduoZH 已提交
59
          pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
60
                         paddings, pool_process);
61 62 63 64
        }
      } break;
      case 3: {
        if (pooling_type == "max") {
C
chengduoZH 已提交
65
          paddle::operators::math::Pool3dFunctor<
66
              Place, paddle::operators::math::MaxPool<T>, T>
67
              pool3d_forward;
68
          paddle::operators::math::MaxPool<T> pool_process;
C
chengduoZH 已提交
69
          pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
70
                         paddings, pool_process);
C
chengduoZH 已提交
71
        } else if (pooling_type == "avg") {
C
chengduoZH 已提交
72
          paddle::operators::math::Pool3dFunctor<
73
              Place, paddle::operators::math::AvgPool<T>, T>
74
              pool3d_forward;
75
          paddle::operators::math::AvgPool<T> pool_process;
C
chengduoZH 已提交
76
          pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
77
                         paddings, pool_process);
78 79 80 81 82 83 84
        }
      } break;
    }
  }
};

template <typename Place, typename T>
C
chengduoZH 已提交
85
class PoolGradKernel : public framework::OpKernel<T> {
86 87
 public:
  void Compute(const framework::ExecutionContext& context) const override {
C
chengduoZH 已提交
88
    const Tensor* in_x = context.Input<Tensor>("X");
89 90 91
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
C
chengduoZH 已提交
92
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
93 94

    std::string pooling_type = context.Attr<std::string>("poolingType");
95 96 97 98
    std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

99
    if (context.Attr<bool>("globalPooling")) {
C
chengduoZH 已提交
100 101
      for (size_t i = 0; i < ksize.size(); ++i)
        ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
102 103
    }

C
chengduoZH 已提交
104 105 106
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
      auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
107 108 109 110 111 112
      temp.device(context.GetEigenDevice<Place>()) =
          temp.constant(static_cast<T>(0));

      switch (ksize.size()) {
        case 2: {
          if (pooling_type == "max") {
C
chengduoZH 已提交
113
            paddle::operators::math::MaxPool2dGradFunctor<Place, T>
114
                pool2d_backward;
C
chengduoZH 已提交
115
            pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
C
chengduoZH 已提交
116
                            *out_grad, ksize, strides, paddings);
C
chengduoZH 已提交
117
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
118
            paddle::operators::math::Pool2dGradFunctor<
119
                Place, paddle::operators::math::AvgPoolGrad<T>, T>
120
                pool2d_backward;
121
            paddle::operators::math::AvgPoolGrad<T> pool_process;
C
chengduoZH 已提交
122
            pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
123
                            *out_grad, ksize, strides, paddings, pool_process);
124 125 126 127
          }
        } break;
        case 3: {
          if (pooling_type == "max") {
C
chengduoZH 已提交
128
            paddle::operators::math::MaxPool3dGradFunctor<Place, T>
129
                pool3d_backward;
C
chengduoZH 已提交
130
            pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
C
chengduoZH 已提交
131
                            *out_grad, ksize, strides, paddings);
C
chengduoZH 已提交
132
          } else if (pooling_type == "avg") {
C
chengduoZH 已提交
133
            paddle::operators::math::Pool3dGradFunctor<
134
                Place, paddle::operators::math::AvgPoolGrad<T>, T>
135
                pool3d_backward;
136
            paddle::operators::math::AvgPoolGrad<T> pool_process;
C
chengduoZH 已提交
137
            pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
138
                            *out_grad, ksize, strides, paddings, pool_process);
139 140 141 142 143 144 145 146 147
          }
        } break;
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle