/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

S
sneaxiy 已提交
15
#include <memory>
16
#include "paddle/fluid/framework/op_registry.h"
17
#include "paddle/fluid/framework/op_version_registry.h"
18
#include "paddle/phi/kernels/roi_pool_kernel.h"
W
wanghaox 已提交
19 20 21 22

namespace paddle {
namespace operators {

W
wanghaox 已提交
23
// Shorthand for framework::Tensor inside paddle::operators.
using Tensor = framework::Tensor;
24
// Shorthand for framework::LoDTensor inside paddle::operators.
using LoDTensor = framework::LoDTensor;
W
wanghaox 已提交
25

W
wanghaox 已提交
26
class ROIPoolOp : public framework::OperatorWithKernel {
W
wanghaox 已提交
27 28 29 30
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
31 32 33 34 35
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "roi_pool");
    OP_INOUT_CHECK(ctx->HasInput("ROIs"), "Input", "ROIs", "roi_pool");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "roi_pool");
    OP_INOUT_CHECK(ctx->HasOutput("Argmax"), "Output", "Argmax", "roi_pool");

W
wanghaox 已提交
36
    auto input_dims = ctx->GetInputDim("X");
W
wanghaox 已提交
37
    auto rois_dims = ctx->GetInputDim("ROIs");
38

39 40 41
    if (ctx->HasInput("RoisNum")) {
      auto rois_num_dims = ctx->GetInputDim("RoisNum");
      PADDLE_ENFORCE_EQ(rois_num_dims.size(), 1,
42
                        platform::errors::InvalidArgument(
43 44 45
                            "The second dimension of RoisNum should "
                            "be 1, but received dimension is %d",
                            rois_num_dims.size()));
F
FDInSky 已提交
46
    }
47 48 49 50 51 52 53 54 55 56 57 58 59 60
    PADDLE_ENFORCE_EQ(input_dims.size(), 4,
                      platform::errors::InvalidArgument(
                          "The input data should be a four-dimensional "
                          "tensor with [N,C,H,W], but received input data with "
                          " %d dimension",
                          input_dims.size()));
    PADDLE_ENFORCE_EQ(
        rois_dims.size(), 2,
        platform::errors::InvalidArgument(
            "ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
            "given as [[x1, y1, x2, y2], ...], but received ROIs is "
            "%d-dimensional LoDTensor",
            rois_dims.size()));
    PADDLE_ENFORCE_EQ(
61
        rois_dims[1], phi::kROISize,
62 63 64 65 66
        platform::errors::InvalidArgument(
            "ROIs should be a 2-D LoDTensor with shape (num_rois, 4)"
            "given as [[x1, y1, x2, y2], ...]. But the second dimension of  "
            "the received data is %d",
            rois_dims[1]));
W
wanghaox 已提交
67 68 69 70 71 72

    int pooled_height = ctx->Attrs().Get<int>("pooled_height");
    int pooled_width = ctx->Attrs().Get<int>("pooled_width");
    float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");

    PADDLE_ENFORCE_GT(pooled_height, 0,
73 74 75 76
                      platform::errors::OutOfRange(
                          "The pooled output height must be greater than 0"
                          "but received height is %d",
                          pooled_height));
W
wanghaox 已提交
77
    PADDLE_ENFORCE_GT(pooled_width, 0,
78 79 80 81
                      platform::errors::OutOfRange(
                          "The pooled output width must be greater than 0"
                          "but received width is %d",
                          pooled_width));
W
wanghaox 已提交
82
    PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
83 84 85 86
                      platform::errors::OutOfRange(
                          "The spatial scale must be greater than 0, "
                          "but received spatial scale is %f",
                          spatial_scale));
W
wanghaox 已提交
87 88 89 90 91 92 93 94 95

    auto out_dims = input_dims;
    out_dims[0] = rois_dims[0];
    out_dims[1] = input_dims[1];
    out_dims[2] = pooled_height;
    out_dims[3] = pooled_width;

    ctx->SetOutputDim("Out", out_dims);
    ctx->SetOutputDim("Argmax", out_dims);
96
  }
W
wanghaox 已提交
97 98

 protected:
99
  framework::OpKernelType GetExpectedKernelType(
W
wanghaox 已提交
100
      const framework::ExecutionContext& ctx) const override {
101 102 103
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
W
wanghaox 已提交
104 105 106
  }
};

W
wanghaox 已提交
107
class ROIPoolGradOp : public framework::OperatorWithKernel {
W
wanghaox 已提交
108 109 110 111
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
112 113 114 115
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   framework::GradVarName("Out"), "roi_pool");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   framework::GradVarName("X"), "roi_pool");
W
wanghaox 已提交
116 117 118 119
    ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
  }

 protected:
120
  framework::OpKernelType GetExpectedKernelType(
W
wanghaox 已提交
121
      const framework::ExecutionContext& ctx) const override {
122 123 124
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
W
wanghaox 已提交
125 126 127
  }
};

W
wanghaox 已提交
128
class ROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
W
wanghaox 已提交
129
 public:
Y
Yu Yang 已提交
130
  void Make() override {
W
wanghaox 已提交
131 132
    AddInput("X",
             "(Tensor), "
W
wanghaox 已提交
133 134 135 136 137 138
             "the input of ROIPoolOp. "
             "The format of input tensor is NCHW. Where N is batch size, "
             "C is the number of input channels, "
             "H is the height of the feature, and "
             "W is the width of the feature.");
    AddInput("ROIs",
139
             "(LoDTensor), "
W
wanghaox 已提交
140
             "ROIs (Regions of Interest) to pool over. "
141
             "should be a 2-D LoDTensor of shape (num_rois, 4)"
W
wopeizl 已提交
142
             "given as [[x1, y1, x2, y2], ...]. "
W
wanghaox 已提交
143 144 145
             "Where batch_id is the id of the data, "
             "(x1, y1) is the top left coordinates, and "
             "(x2, y2) is the bottom right coordinates.");
146 147
    AddInput("RoisNum", "(Tensor), The number of RoIs in each image.")
        .AsDispensable();
W
wanghaox 已提交
148 149
    AddOutput("Out",
              "(Tensor), "
W
wanghaox 已提交
150 151
              "The output of ROIPoolOp is a 4-D tensor with shape "
              "(num_rois, channels, pooled_h, pooled_w).");
W
wanghaox 已提交
152 153 154 155
    AddOutput("Argmax",
              "(Tensor), "
              "Argmaxes corresponding to indices in X used "
              "for gradient computation. Only output "
P
peizhilin 已提交
156
              "if arg \"is_test\" is false.")
157
        .AsIntermediate();
W
wanghaox 已提交
158
    AddAttr<float>("spatial_scale",
W
wanghaox 已提交
159 160 161 162
                   "(float, default 1.0), "
                   "Multiplicative spatial scale factor "
                   "to translate ROI coords from their input scale "
                   "to the scale used when pooling.")
163
        .SetDefault(1.0);
W
wanghaox 已提交
164
    AddAttr<int>("pooled_height",
W
wanghaox 已提交
165 166
                 "(int, default 1), "
                 "The pooled output height.")
167
        .SetDefault(1);
W
wanghaox 已提交
168
    AddAttr<int>("pooled_width",
W
wanghaox 已提交
169 170
                 "(int, default 1), "
                 "The pooled output width.")
171
        .SetDefault(1);
W
wanghaox 已提交
172
    AddComment(R"DOC(
Y
yi.wu 已提交
173
**ROIPool Operator**
W
wanghaox 已提交
174

Y
yi.wu 已提交
175 176 177 178 179
Region of interest pooling (also known as RoI pooling) is to perform
is to perform max pooling on inputs of nonuniform sizes to obtain
fixed-size feature maps (e.g. 7*7).

The operator has three steps:
Y
yi.wu 已提交
180

Y
yi.wu 已提交
181 182
1. Dividing each region proposal into equal-sized sections with
   the pooled_width and pooled_height
Y
update  
yi.wu 已提交
183

Y
yi.wu 已提交
184
2. Finding the largest value in each section
Y
update  
yi.wu 已提交
185

Y
yi.wu 已提交
186 187
3. Copying these max values to the output buffer

W
wanghaox 已提交
188 189 190 191 192 193
ROI Pooling for Faster-RCNN. The link below is a further introduction: 
https://stackoverflow.com/questions/43430056/what-is-roi-layer-in-fast-rcnn
    )DOC");
  }
};

H
hong 已提交
194 195
// Describes how to build the roi_pool_grad op from a forward roi_pool op.
// T is the op description type the maker is instantiated with.
template <typename T>
class ROIPoolGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  // Wires the forward inputs "X", "ROIs", "RoisNum", the forward output
  // "Argmax" and the gradient of "Out" into the grad op, names the gradient
  // of "X" as its output, and forwards all attributes.
  void Apply(GradOpPtr<T> op) const override {
    const auto out_grad_name = framework::GradVarName("Out");
    const auto x_grad_name = framework::GradVarName("X");

    op->SetType("roi_pool_grad");

    // Forward inputs reused by the backward op.
    op->SetInput("X", this->Input("X"));
    op->SetInput("ROIs", this->Input("ROIs"));
    op->SetInput("RoisNum", this->Input("RoisNum"));

    // Forward output consumed by the backward op.
    op->SetInput("Argmax", this->Output("Argmax"));

    // Incoming and outgoing gradients.
    op->SetInput(out_grad_name, this->OutputGrad("Out"));
    op->SetOutput(x_grad_name, this->InputGrad("X"));

    op->SetAttrMap(this->Attrs());
  }
};

W
wanghaox 已提交
212 213 214 215
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
Y
Yang Yang 已提交
216
REGISTER_OPERATOR(roi_pool, ops::ROIPoolOp, ops::ROIPoolOpMaker,
H
hong 已提交
217 218
                  ops::ROIPoolGradMaker<paddle::framework::OpDesc>,
                  ops::ROIPoolGradMaker<paddle::imperative::OpBase>);
219
REGISTER_OPERATOR(roi_pool_grad, ops::ROIPoolGradOp);
220

221
REGISTER_OP_VERSION(roi_pool)
222 223 224 225 226 227 228
    .AddCheckpoint(
        R"ROC(
              Incompatible upgrade of input [RpnRoisLod])ROC",
        paddle::framework::compatible::OpVersionDesc().DeleteInput(
            "RpnRoisLod",
            "Delete RpnRoisLod due to incorrect input name and "
            "it is not used in object detection models yet."))
229 230 231 232 233 234
    .AddCheckpoint(
        R"ROC(
              Upgrade roi_pool add a new input [RoisNum])ROC",
        paddle::framework::compatible::OpVersionDesc().NewInput(
            "RoisNum",
            "The number of RoIs in each image. RoisNum is dispensable."));