/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

#include "paddle/fluid/operators/maxout_op.h"
#include <vector>

namespace paddle {
namespace operators {

using framework::Tensor;

class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensor) The input tensor of maxout operator with data type of "
        "float32. The format of input tensor is NCHW. Where N is batch size,"
        " C is the number of channels, H and W is the height and width of "
        "feature.");
    AddOutput("Out",
              "(Tensor) The output tensor of maxout operator."
              "The data type is float32."
              "The format of output tensor is also NCHW."
              "Where N is batch size, C is "
              "the number of channels, H and W is the height and "
              "width of feature.");
    AddAttr<int>(
        "groups",
        "(int),"
        "Specifies how many groups the input tensor will be split"
        "in the channel dimension. And the number of output channel is "
        "the number of channels divided by groups.");
    AddComment(R"DOC(
MaxOut Operator.

Assume the input shape is (N, Ci, H, W).
The output shape is (N, Co, H, W).
Then $Co = Ci / groups$ and the operator formula is as follows:

$$ y_{si+j} = \max_{k} x_{gsi + sk + j} $$
$$ g = groups $$
$$ s = \frac{input.size}{num\_channels} $$
$$ 0 \le i < \frac{num\_channels}{groups} $$
$$ 0 \le j < s $$
$$ 0 \le k < groups $$
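
For example, with groups = 3 and an input of shape (N, 6, H, W), each output
channel i is the element-wise maximum over input channels 3i, 3i + 1 and
3i + 2, so the output shape is (N, 2, H, W).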

Please refer to the papers:
  - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
  - Multi-digit Number Recognition from Street View Imagery using Deep
    Convolutional Neural Networks: https://arxiv.org/pdf/1312.6082v4.pdf

)DOC");
  }
};

class MaxOutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of MaxoutOpshould not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of MaxoutOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    int groups = ctx->Attrs().Get<int>("groups");
    // groups must be larger than 1; with groups == 1 maxout is an identity map
    PADDLE_ENFORCE_GT(groups, 1,
                      "groups should be larger than 1 in maxout op");
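    // The output keeps the batch size and the spatial dimensions of the 4-D
    // NCHW input and divides the channel dimension by groups, e.g. an input
    // of shape (N, 6, H, W) with groups = 3 yields an output of (N, 2, H, W).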
    std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
    output_shape.push_back(in_x_dims[2]);
    output_shape.push_back(in_x_dims[3]);
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }
};

class MaxOutOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of MaxOutOpGrad must not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(Grad@X) of MaxOutOpGrad should not be null.");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
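// Register the operator, its gradient, and the float32 CPU kernels; the
// maxout_grad op description is generated by DefaultGradOpDescMaker.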
REGISTER_OPERATOR(maxout, ops::MaxOutOp, ops::MaxOutOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(maxout_grad, ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(
    maxout, ops::MaxOutKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
    maxout_grad,
    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, float>);