/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *     Unless required by applicable law or agreed to in writing, software
 *     distributed under the License is distributed on an "AS IS" BASIS,
 *     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *     See the License for the specific language governing permissions and
 *     limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/maxout_op.h"
16 17
#include <vector>

W
wanghaox 已提交
18 19 20 21 22 23 24
namespace paddle {
namespace operators {

using framework::Tensor;

// Declares the proto of the maxout operator: one input ("X"), one output
// ("Out"), two integer attributes ("groups" and "axis", the latter
// defaulting to 1) and the operator's user-facing documentation.
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers the inputs, outputs, attributes and comment of the operator.
  void Make() override {
    // Input feature map.
    AddInput("X",
             "A 4-D Tensor with data type of float32 or float64. "
             "The data format is NCHW or NHWC. "
             "Where N is batch size, C is the number of channels, "
             "H and W is the height and width of feature. ");

    // Result of the maxout reduction.
    AddOutput("Out",
              "A 4-D Tensor with same data type "
              "and data format with input Tensor. ");

    // Number of channel groups; no default value is set here.
    AddAttr<int>(
        "groups",
        "Specifies how many groups the input tensor will be split into at the "
        "channel dimension. And the number of output channel is the number of "
        "channels divided by groups. ");

    // Index of the channel dimension; defaults to 1.
    AddAttr<int>(
        "axis",
        "Specifies the index of channel dimension where maxout will be "
        "performed. It should be 1 when data format is NCHW, "
        "-1 or 3 when data format is NHWC. Default: 1. ")
        .SetDefault(1);

    // The raw string below is emitted verbatim as the operator documentation.
    AddComment(R"DOC(
MaxOut Operator.

Assumed the input shape is (N, Ci, H, W).
The output shape is (N, Co, H, W).
Then $Co = Ci / groups$ and the operator formula is as follows:

$$ y_{si+j} = \max_{k} x_{gsi + sk + j} $$
$$ g = groups $$
$$ s = \\frac{input.size}{num\\_channels} $$
$$ 0 \\le i < \\frac{num\\_channels}{groups} $$
$$ 0 \\le j < s $$
$$ 0 \\le k < groups $$

Please refer to Paper:
  - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
  - Multi-digit Number Recognition from Street View \
    Imagery using Deep Convolutional Neural Networks: \
    https://arxiv.org/pdf/1312.6082v4.pdf

)DOC");
  }
};

// Forward maxout operator. Its InferShape validates the "groups" and "axis"
// attributes and the rank of input "X", then sets the shape of output "Out"
// to the input shape with the channel extent divided by "groups".
class MaxOutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Infers the shape of "Out" from the shape of "X".
  //
  // Raises InvalidArgument when:
  //   - groups is not larger than 1;
  //   - axis is not one of 1, -1 or 3;
  //   - X is not 4-D;
  //   - the extent of X along the channel axis is not divisible by groups.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "maxout");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "maxout");

    auto in_x_dims = ctx->GetInputDim("X");
    int groups = ctx->Attrs().Get<int>("groups");
    int axis = ctx->Attrs().Get<int>("axis");

    // check groups > 1
    PADDLE_ENFORCE_GT(groups, 1, platform::errors::InvalidArgument(
                                     "Attr(groups) of Op(maxout) should be "
                                     "larger than 1. But received %d.",
                                     groups));
    PADDLE_ENFORCE_EQ(
        axis == 1 || axis == -1 || axis == 3, true,
        platform::errors::InvalidArgument(
            "axis only supported 1, -1 or 3, but received axis is: %d", axis));
    PADDLE_ENFORCE_EQ(in_x_dims.size(), 4,
                      platform::errors::InvalidArgument(
                          "x's dims should be 4, but received x's dims is: %d",
                          in_x_dims.size()));

    // Normalize a negative axis (-1) to its non-negative index (3 for rank 4).
    if (axis < 0) {
      axis += in_x_dims.size();
    }

    // NOTE(review): a negative channel extent would yield a nonzero remainder
    // here and be rejected. If unknown dimensions are encoded as negative
    // values at graph-build time, this check may need a guard -- confirm.
    PADDLE_ENFORCE_EQ(
        in_x_dims[axis] % groups, 0,
        platform::errors::InvalidArgument(
            "The number of input channels for Op(maxout) "
            "should be divisible by Attr(groups). But received: the "
            "input's channels is [%d], the shape of input is [%s], "
            "the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
            "error may come from wrong Attr(groups) or Attr(axis) setting.",
            in_x_dims[axis], in_x_dims, groups, axis));

    // Output keeps every extent of X except the channel axis, which shrinks
    // by a factor of groups.
    std::vector<int64_t> output_shape(
        {in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
    output_shape[axis] = in_x_dims[axis] / groups;
    ctx->SetOutputDim("Out", phi::make_ddim(output_shape));
  }
};

// Backward maxout operator. Its InferShape gives the gradient of "X" the
// same shape as "X" itself.
class MaxOutOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Requires input "X" and the gradient output of "X" to be present, then
  // copies the dims of "X" onto that gradient output.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "maxout_grad");

    // Name of the gradient variable of "X"; used for both the presence check
    // and the shape assignment below.
    const auto x_grad_name = framework::GradVarName("X");
    OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", "X@Grad",
                   "maxout_grad");

    const auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim(x_grad_name, x_dims);
  }
};
124 125
}  // namespace operators
}  // namespace paddle
W
wanghaox 已提交
126 127

namespace ops = paddle::operators;
H
hong 已提交
128 129 130 131
REGISTER_OPERATOR(
    maxout, ops::MaxOutOp, ops::MaxOutOpMaker,
    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
132
REGISTER_OPERATOR(maxout_grad, ops::MaxOutOpGrad);
133
REGISTER_OP_CPU_KERNEL(
134 135
    maxout, ops::MaxOutKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MaxOutKernel<paddle::platform::CPUDeviceContext, double>);
Q
QI JUN 已提交
136 137
REGISTER_OP_CPU_KERNEL(
    maxout_grad,
138 139
    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, double>);