pool_op.cc 8.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/pool_op.h"

namespace paddle {
namespace operators {

C
chengduoZH 已提交
20
// Number of window positions along one spatial axis.
//
// A window of `filter_size` elements slides with step `stride` over an input
// of `input_size` elements that is padded by `padding` on both sides. The
// division is C++ integer division (truncating), so a trailing partial window
// is dropped. `stride` must be non-zero; callers are expected to validate it.
int OutputSizePool(int input_size, int filter_size, int padding, int stride) {
  const int padded_extent = input_size + 2 * padding;
  return 1 + (padded_extent - filter_size) / stride;
}

// Forward pooling operator shared by pool2d and pool3d. This class only
// performs shape inference; the computation is done by the registered kernel.
class PoolOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Validates the pooling attributes against Input(X) and sets the shape of
  // Output(Out) to [X.dims[0], X.dims[1], out_0, ..., out_{k-1}], where
  // k = rank(X) - 2 and
  // out_i = OutputSizePool(X.dims[i + 2], ksize[i], paddings[i], strides[i]).
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "X(Input) of Pooling should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Out(Output) of Pooling should not be null.");

    auto in_x_dims = ctx->GetInputDim("X");

    // ksize is copied because global pooling overwrites it below.
    std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
    std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
    std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");

    PADDLE_ENFORCE(in_x_dims.size() == 4 || in_x_dims.size() == 5,
                   "Pooling input should be 4-D or 5-D");

    if (ctx->Attrs().Get<bool>("globalPooling")) {
      // Global pooling: each window spans the whole spatial extent of X.
      // NOTE(review): paddings are not reset here, so a non-zero padding still
      // enlarges the output beyond 1 per axis; confirm this matches the kernel.
      ksize.resize(static_cast<size_t>(in_x_dims.size()) - 2);
      for (size_t i = 0; i < ksize.size(); ++i) {
        ksize[i] = static_cast<int>(in_x_dims[i + 2]);
      }
    }

    // Compare without unsigned subtraction, so the intent is not hidden behind
    // size_t wraparound when ksize has more entries than X has dims.
    PADDLE_ENFORCE(static_cast<size_t>(in_x_dims.size()) == ksize.size() + 2,
                   "Input size and pooling size should be consistent.");
    PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
                      "Strides size and pooling size should be the same.");
    PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
                      "Paddings size and pooling size should be the same.");

    std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
    for (size_t i = 0; i < ksize.size(); ++i) {
      // OutputSizePool divides by the stride; a zero stride would be an
      // integer division by zero (undefined behavior).
      PADDLE_ENFORCE(strides[i] > 0, "Strides of Pooling should be positive.");
      output_shape.push_back(
          OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
    }
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }
};

// Gradient operator shared by pool2d_grad and pool3d_grad. The gradient with
// respect to X (X@GRAD) has exactly the shape of the forward input X.
class PoolOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
    // X@GRAD is produced by this op, so the message must call it an output.
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should not be null.");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};

C
chengduoZH 已提交
81
class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
82
 public:
C
chengduoZH 已提交
83
  Pool2dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
84 85
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
86
        "X",
87
        "The input tensor of pooling operator. "
C
chengduoZH 已提交
88
        "The format of input tensor is NCHW. Where N is batch size, C is the "
C
chengduoZH 已提交
89
        "number of channels, H and W is the height and width of feature.");
90
    AddOutput("Out",
91
              "The output tensor of pooling operator."
C
fix doc  
chengduoZH 已提交
92 93 94 95
              "The format of output tensor is also NCHW."
              "Where N is batch size, C is "
              "the number of channels, H and W is the height and "
              "width of feature.");
96

97
    AddAttr<std::string>("poolingType",
98 99 100
                         "PoolingType of pooling operator."
                         "Str constant equal to 'max' or 'avg'.")
        .InEnum({"max", "avg"});
C
fix doc  
chengduoZH 已提交
101

102
    AddAttr<std::vector<int>>(
103
        "ksize",
C
fix doc  
chengduoZH 已提交
104
        "The pooling size(height, width) of pooling operator."
105
        "If globalPooling = true, ksize is ignored and need not be "
C
fix doc  
chengduoZH 已提交
106 107
        "specified.");  // TODO(Chengduo): Add checker. (Currently,
                        // TypedAttrChecker don't support vector type.)
108
    AddAttr<bool>(
C
chengduoZH 已提交
109
        "globalPooling",
110 111 112
        "Whether to use the globalPooling."
        "Bool constant equal to false or true."
        "Default false."
113 114
        "If globalPooling = true, ksize is ignored and need not be specified.")
        .SetDefault(false);
C
chengduoZH 已提交
115
    AddAttr<std::vector<int>>("strides",
116
                              "Strides(height, width) of pooling operator."
C
fix doc  
chengduoZH 已提交
117 118 119
                              "Default {1,1}.")
        .SetDefault({1, 1});  // TODO(Chengduo): Add checker. (Currently,
                              // TypedAttrChecker don't support vector type.)
C
chengduoZH 已提交
120
    AddAttr<std::vector<int>>("paddings",
121 122
                              "Paddings(height, width) of pooling operator."
                              "Default {0,0}.")
C
fix doc  
chengduoZH 已提交
123 124 125
        .SetDefault({0, 0});  // TODO(Chengduo): Add checker. (Currently,
                              // TypedAttrChecker don't support vector type.)

126
    AddComment(R"DOC(
C
chengduoZH 已提交
127
The pooling2d operation calculates the output based on
128
the input, poolingType and ksize, strides, paddings parameters.
C
fix doc  
chengduoZH 已提交
129 130 131 132
Input(X) and output(Out) are in NCHW format. Where N is batch size, C is the
number of channels, H and W is the height and width of feature.
Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively.
133 134 135
)DOC");
  }
};
136

C
chengduoZH 已提交
137
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
138
 public:
C
chengduoZH 已提交
139
  Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
140
      : OpProtoAndCheckerMaker(proto, op_checker) {
C
fix doc  
chengduoZH 已提交
141 142 143 144 145 146
    AddInput(
        "X",
        "The input tensor of pooling operator. "
        "The format of input tensor is NCDHW. Where N is batch size, C is "
        "the number of channels, D, H and W is the depth, height and width of "
        "feature.");
147
    AddOutput("Out",
148
              "The output tensor of pooling operator."
C
fix doc  
chengduoZH 已提交
149 150 151 152
              "The format of output tensor is also NCDHW."
              "Where N is batch size, C is "
              "the number of channels, D, H and W is the depth, height and "
              "width of feature.");
153

154
    AddAttr<std::string>("poolingType",
155
                         "PoolingType of pooling operator."
C
fix doc  
chengduoZH 已提交
156
                         "Str constant equal to 'max' or 'avg'.")
157
        .InEnum({"max", "avg"});
C
fix doc  
chengduoZH 已提交
158

159
    AddAttr<std::vector<int>>(
160
        "ksize",
C
fix doc  
chengduoZH 已提交
161
        "The pooling size(depth, height, width) of pooling operator."
162
        "If globalPooling = true, ksize is ignored and need not be "
C
fix doc  
chengduoZH 已提交
163 164
        "specified.");  // TODO(Chengduo): Add checker. (Currently,
                        // TypedAttrChecker don't support vector type.)
165
    AddAttr<bool>(
C
chengduoZH 已提交
166
        "globalPooling",
167 168 169
        "Whether to use the globalPooling."
        "Bool constant equal to false or true."
        "Default false."
170 171
        "If globalPooling = true, ksize is ignored and need not be specified.")
        .SetDefault(false);
C
chengduoZH 已提交
172 173
    AddAttr<std::vector<int>>(
        "strides",
174 175
        "Strides(depth, height, width) of pooling operator."
        "Default {1,1,1}.")
C
fix doc  
chengduoZH 已提交
176 177
        .SetDefault({1, 1, 1});  // TODO(Chengduo): Add checker. (Currently,
                                 // TypedAttrChecker don't support vector type.)
C
chengduoZH 已提交
178 179
    AddAttr<std::vector<int>>(
        "paddings",
180 181
        "Paddings(depth, height, width) of pooling operator."
        "Default {0,0,0}.")
C
fix doc  
chengduoZH 已提交
182 183 184
        .SetDefault({0, 0, 0});  // TODO(Chengduo): Add checker. (Currently,
                                 // TypedAttrChecker don't support vector type.)

185
    AddComment(R"DOC(
C
chengduoZH 已提交
186
The pooling3d operation calculates the output based on
187
the input, poolingType and ksize, strides, paddings parameters.
C
fix doc  
chengduoZH 已提交
188 189 190 191
Input(X) and output(Out) are in NCDHW format. Where N is batch
size, C is the number of channels, D, H and W is the depth, height and
width of feature. Parameters(ksize, strides, paddings) are three elements.
These three elements represent depth, height and width, respectively.
192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214
)DOC");
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

// pool2d and pool3d reuse the same operator classes (PoolOp / PoolOpGrad) and
// the same CPU kernels; only the proto maker differs. Kernels are registered
// for float only.
REGISTER_OP(pool2d, ops::PoolOp, ops::Pool2dOpMaker, pool2d_grad,
            ops::PoolOpGrad);

REGISTER_OP_CPU_KERNEL(pool2d,
                       ops::PoolKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pool2d_grad,
                       ops::PoolGradKernel<paddle::platform::CPUPlace, float>);

REGISTER_OP(pool3d, ops::PoolOp, ops::Pool3dOpMaker, pool3d_grad,
            ops::PoolOpGrad);

REGISTER_OP_CPU_KERNEL(pool3d,
                       ops::PoolKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pool3d_grad,
                       ops::PoolGradKernel<paddle::platform::CPUPlace, float>);