roi_pool_op.cc 9.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
W
wanghaox 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/roi_pool_op.h"
S
sneaxiy 已提交
16
#include <memory>
17
#include "paddle/fluid/framework/op_version_registry.h"
W
wanghaox 已提交
18 19 20 21

namespace paddle {
namespace operators {

W
wanghaox 已提交
22
using Tensor = framework::Tensor;
23
using LoDTensor = framework::LoDTensor;
W
wanghaox 已提交
24

W
wanghaox 已提交
25
class ROIPoolOp : public framework::OperatorWithKernel {
W
wanghaox 已提交
26 27 28 29
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
30 31 32 33 34
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "roi_pool");
    OP_INOUT_CHECK(ctx->HasInput("ROIs"), "Input", "ROIs", "roi_pool");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "roi_pool");
    OP_INOUT_CHECK(ctx->HasOutput("Argmax"), "Output", "Argmax", "roi_pool");

W
wanghaox 已提交
35
    auto input_dims = ctx->GetInputDim("X");
W
wanghaox 已提交
36
    auto rois_dims = ctx->GetInputDim("ROIs");
37

38 39 40
    if (ctx->HasInput("RoisNum")) {
      auto rois_num_dims = ctx->GetInputDim("RoisNum");
      PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
41
                        platform::errors::InvalidArgument(
42 43 44
                            "The second dimension of RoisNum should "
                            "be 1, but received dimension is %d",
                            rois_num_dims.size()));
F
FDInSky 已提交
45
    }
46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
    PADDLE_ENFORCE_EQ(input_dims.size(), 4,
                      platform::errors::InvalidArgument(
                          "The input data should be a four-dimensional "
                          "tensor with [N,C,H,W], but received input data with "
                          " %d dimension",
                          input_dims.size()));
    PADDLE_ENFORCE_EQ(
        rois_dims.size(), 2,
        platform::errors::InvalidArgument(
            "ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
            "given as [[x1, y1, x2, y2], ...], but received ROIs is "
            "%d-dimensional LoDTensor",
            rois_dims.size()));
    PADDLE_ENFORCE_EQ(
        rois_dims[1], kROISize,
        platform::errors::InvalidArgument(
            "ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
            "given as [[x1, y1, x2, y2], ...]. But the second dimension of  "
            "the received data is %d",
            rois_dims[1]));
W
wanghaox 已提交
66 67 68 69 70 71

    int pooled_height = ctx->Attrs().Get<int>("pooled_height");
    int pooled_width = ctx->Attrs().Get<int>("pooled_width");
    float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");

    PADDLE_ENFORCE_GT(pooled_height, 0,
72 73 74 75
                      platform::errors::OutOfRange(
                          "The pooled output height must be greater than 0"
                          "but received height is %d",
                          pooled_height));
W
wanghaox 已提交
76
    PADDLE_ENFORCE_GT(pooled_width, 0,
77 78 79 80
                      platform::errors::OutOfRange(
                          "The pooled output width must be greater than 0"
                          "but received width is %d",
                          pooled_width));
W
wanghaox 已提交
81
    PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
82 83 84 85
                      platform::errors::OutOfRange(
                          "The spatial scale must be greater than 0, "
                          "but received spatial scale is %f",
                          spatial_scale));
W
wanghaox 已提交
86 87 88 89 90 91 92 93 94

    auto out_dims = input_dims;
    out_dims[0] = rois_dims[0];
    out_dims[1] = input_dims[1];
    out_dims[2] = pooled_height;
    out_dims[3] = pooled_width;

    ctx->SetOutputDim("Out", out_dims);
    ctx->SetOutputDim("Argmax", out_dims);
95
  }
W
wanghaox 已提交
96 97

 protected:
98
  framework::OpKernelType GetExpectedKernelType(
W
wanghaox 已提交
99
      const framework::ExecutionContext& ctx) const override {
100 101 102
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
W
wanghaox 已提交
103 104 105
  }
};

W
wanghaox 已提交
106
class ROIPoolGradOp : public framework::OperatorWithKernel {
W
wanghaox 已提交
107 108 109 110
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
111 112 113 114
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   framework::GradVarName("Out"), "roi_pool");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   framework::GradVarName("X"), "roi_pool");
W
wanghaox 已提交
115 116 117 118
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }

 protected:
119
  framework::OpKernelType GetExpectedKernelType(
W
wanghaox 已提交
120
      const framework::ExecutionContext& ctx) const override {
121 122 123
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
W
wanghaox 已提交
124 125 126
  }
};

W
wanghaox 已提交
127
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
W
wanghaox 已提交
128
 public:
Y
Yu Yang 已提交
129
  void Make() override {
W
wanghaox 已提交
130 131
    AddInput("X",
             "(Tensor), "
W
wanghaox 已提交
132 133 134 135 136 137
             "the input of ROIPoolOp. "
             "The format of input tensor is NCHW. Where N is batch size, "
             "C is the number of input channels, "
             "H is the height of the feature, and "
             "W is the width of the feature.");
    AddInput("ROIs",
138
             "(LoDTensor), "
W
wanghaox 已提交
139
             "ROIs (Regions of Interest) to pool over. "
140
             "should be a 2-D LoDTensor of shape (num_rois, 4)"
W
wopeizl 已提交
141
             "given as [[x1, y1, x2, y2], ...]. "
W
wanghaox 已提交
142 143 144
             "Where batch_id is the id of the data, "
             "(x1, y1) is the top left coordinates, and "
             "(x2, y2) is the bottom right coordinates.");
145 146
    AddInput("RoisNum", "(Tensor), The number of RoIs in each image.")
        .AsDispensable();
W
wanghaox 已提交
147 148
    AddOutput("Out",
              "(Tensor), "
W
wanghaox 已提交
149 150
              "The output of ROIPoolOp is a 4-D tensor with shape "
              "(num_rois, channels, pooled_h, pooled_w).");
W
wanghaox 已提交
151 152 153 154
    AddOutput("Argmax",
              "(Tensor), "
              "Argmaxes corresponding to indices in X used "
              "for gradient computation. Only output "
P
peizhilin 已提交
155
              "if arg \"is_test\" is false.")
156
        .AsIntermediate();
W
wanghaox 已提交
157
    AddAttr<float>("spatial_scale",
W
wanghaox 已提交
158 159 160 161
                   "(float, default 1.0), "
                   "Multiplicative spatial scale factor "
                   "to translate ROI coords from their input scale "
                   "to the scale used when pooling.")
162
        .SetDefault(1.0);
W
wanghaox 已提交
163
    AddAttr<int>("pooled_height",
W
wanghaox 已提交
164 165
                 "(int, default 1), "
                 "The pooled output height.")
166
        .SetDefault(1);
W
wanghaox 已提交
167
    AddAttr<int>("pooled_width",
W
wanghaox 已提交
168 169
                 "(int, default 1), "
                 "The pooled output width.")
170
        .SetDefault(1);
W
wanghaox 已提交
171
    AddComment(R"DOC(
Y
yi.wu 已提交
172
**ROIPool Operator**
W
wanghaox 已提交
173

Y
yi.wu 已提交
174 175 176 177 178
Region of interest pooling (also known as RoI pooling) is to perform
is to perform max pooling on inputs of nonuniform sizes to obtain
fixed-size feature maps (e.g. 7*7).

The operator has three steps:
Y
yi.wu 已提交
179

Y
yi.wu 已提交
180 181
1. Dividing each region proposal into equal-sized sections with
   the pooled_width and pooled_height
Y
update  
yi.wu 已提交
182

Y
yi.wu 已提交
183
2. Finding the largest value in each section
Y
update  
yi.wu 已提交
184

Y
yi.wu 已提交
185 186
3. Copying these max values to the output buffer

W
wanghaox 已提交
187 188 189 190 191 192
ROI Pooling for Faster-RCNN. The link below is a further introduction: 
https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
    )DOC");
  }
};

H
hong 已提交
193 194
// Builds the roi_pool_grad op description for roi_pool. T is the op
// representation (OpDesc or imperative::OpBase at the registration site).
template <typename T>
class ROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("roi_pool_grad");

    // Forward inputs are passed through unchanged (RoisNum may be empty).
    for (const char* forward_input : {"X", "ROIs", "RoisNum"}) {
      grad_op->SetInput(forward_input, this->Input(forward_input));
    }

    // The recorded Argmax indices plus the incoming gradient of Out
    // produce the gradient of X.
    grad_op->SetInput("Argmax", this->Output("Argmax"));
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));

    // The grad op sees the same attributes as the forward op.
    grad_op->SetAttrMap(this->Attrs());
  }
};

W
wanghaox 已提交
211 212 213 214
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yang Yang 已提交
215
REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
H
hong 已提交
216 217
                  ops::ROIPoolGradMaker<paddle::framework::OpDesc>,
                  ops::ROIPoolGradMaker<paddle::imperative::OpBase>);
218
REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp);
W
wanghaox 已提交
219
REGISTER_OP_CPU_KERNEL(
Q
QI JUN 已提交
220 221
    roi_pool,
    ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, float>,
F
FDInSky 已提交
222 223
    ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::CPUROIPoolOpKernel<paddle::platform::CPUDeviceContext, int>);
W
wanghaox 已提交
224 225
REGISTER_OP_CPU_KERNEL(
    roi_pool_grad,
Q
QI JUN 已提交
226
    ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, float>,
F
FDInSky 已提交
227 228
    ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, double>,
    ops::CPUROIPoolGradOpKernel<paddle::platform::CPUDeviceContext, int>);
229 230 231 232 233 234 235
REGISTER_OP_VERSION(roi_pool)
    .AddCheckpoint(
        R"ROC(
              Upgrade roi_pool add a new input [RoisNum])ROC",
        paddle::framework::compatible::OpVersionDesc().NewInput(
            "RoisNum",
            "The number of RoIs in each image. RoisNum is dispensable."));