generate_proposals_op.cc 20.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <numeric>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

// Upper bound applied to the width/height deltas (after optional variance
// scaling) right before std::exp in BoxCoder, so an extreme predicted delta
// can scale an anchor side by at most this factor's exponential.
// Value: log(1000 / 16) == log(62.5).
static const double kBBoxClipDefault = std::log(62.5);

// Copies all elements of `src` into `dst`, starting at element index `offset`
// of `dst`. The element size is taken from `src`'s data type. The caller must
// have allocated `dst` with room for at least offset + src.numel() elements
// of that type; no bounds check is performed here.
static void AppendProposals(Tensor *dst, int64_t offset, const Tensor &src) {
  const size_t elem_size = framework::SizeOfType(src.type());
  // Byte-offset arithmetic through char* is well defined, unlike
  // round-tripping the address through an integer type.
  char *dst_bytes = static_cast<char *>(dst->data<void>());
  const void *src_bytes = src.data<void>();
  std::memcpy(dst_bytes + offset * elem_size, src_bytes,
              src.numel() * elem_size);
}

// Operator definition for generate_proposals. Shape inference only verifies
// that every input is present and declares the outputs as [-1, 4] and
// [-1, 1]: the number of proposals is only known once the kernel has run.
class GenerateProposalsOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Scores"), true,
        platform::errors::NotFound("Input(Scores) shouldn't be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("BboxDeltas"), true,
        platform::errors::NotFound("Input(BboxDeltas) shouldn't be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("ImInfo"), true,
        platform::errors::NotFound("Input(ImInfo) shouldn't be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Anchors"), true,
        platform::errors::NotFound("Input(Anchors) shouldn't be null."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Variances"), true,
        platform::errors::NotFound("Input(Variances) shouldn't be null."));

    ctx->SetOutputDim("RpnRois", {-1, 4});
    ctx->SetOutputDim("RpnRoiProbs", {-1, 1});

    // LoD levels are only declared while building the program; at run time
    // the kernel attaches the actual LoD to both outputs itself.
    if (ctx->IsRuntime()) {
      return;
    }
    // Both outputs carry at least one LoD level (one segment per image).
    const auto out_lod_level = std::max(ctx->GetLoDLevel("Scores"), 1);
    ctx->SetLoDLevel("RpnRois", out_lod_level);
    ctx->SetLoDLevel("RpnRoiProbs", out_lod_level);
  }

 protected:
  // The kernel's data type follows the "Anchors" input.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto anchors_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "Anchors");
    return framework::OpKernelType(anchors_type, ctx.device_context());
  }
};

template <class T>
81 82 83
static inline void BoxCoder(const platform::DeviceContext &ctx,
                            Tensor *all_anchors, Tensor *bbox_deltas,
                            Tensor *variances, Tensor *proposals) {
84 85 86 87 88 89 90 91 92 93 94 95 96
  T *proposals_data = proposals->mutable_data<T>(ctx.GetPlace());

  int64_t row = all_anchors->dims()[0];
  int64_t len = all_anchors->dims()[1];

  auto *bbox_deltas_data = bbox_deltas->data<T>();
  auto *anchor_data = all_anchors->data<T>();
  const T *variances_data = nullptr;
  if (variances) {
    variances_data = variances->data<T>();
  }

  for (int64_t i = 0; i < row; ++i) {
97 98
    T anchor_width = anchor_data[i * len + 2] - anchor_data[i * len] + 1.0;
    T anchor_height = anchor_data[i * len + 3] - anchor_data[i * len + 1] + 1.0;
99

100 101
    T anchor_center_x = anchor_data[i * len] + 0.5 * anchor_width;
    T anchor_center_y = anchor_data[i * len + 1] + 0.5 * anchor_height;
102 103 104 105 106 107 108 109 110 111 112

    T bbox_center_x = 0, bbox_center_y = 0;
    T bbox_width = 0, bbox_height = 0;

    if (variances) {
      bbox_center_x =
          variances_data[i * len] * bbox_deltas_data[i * len] * anchor_width +
          anchor_center_x;
      bbox_center_y = variances_data[i * len + 1] *
                          bbox_deltas_data[i * len + 1] * anchor_height +
                      anchor_center_y;
113 114
      bbox_width = std::exp(std::min<T>(variances_data[i * len + 2] *
                                            bbox_deltas_data[i * len + 2],
115
                                        kBBoxClipDefault)) *
116
                   anchor_width;
117 118
      bbox_height = std::exp(std::min<T>(variances_data[i * len + 3] *
                                             bbox_deltas_data[i * len + 3],
119
                                         kBBoxClipDefault)) *
120 121 122 123 124 125
                    anchor_height;
    } else {
      bbox_center_x =
          bbox_deltas_data[i * len] * anchor_width + anchor_center_x;
      bbox_center_y =
          bbox_deltas_data[i * len + 1] * anchor_height + anchor_center_y;
126
      bbox_width = std::exp(std::min<T>(bbox_deltas_data[i * len + 2],
127
                                        kBBoxClipDefault)) *
128 129
                   anchor_width;
      bbox_height = std::exp(std::min<T>(bbox_deltas_data[i * len + 3],
130
                                         kBBoxClipDefault)) *
131
                    anchor_height;
132 133 134 135
    }

    proposals_data[i * len] = bbox_center_x - bbox_width / 2;
    proposals_data[i * len + 1] = bbox_center_y - bbox_height / 2;
136 137
    proposals_data[i * len + 2] = bbox_center_x + bbox_width / 2 - 1;
    proposals_data[i * len + 3] = bbox_center_y + bbox_height / 2 - 1;
138 139 140 141 142
  }
  // return proposals;
}

template <class T>
143 144
static inline void ClipTiledBoxes(const platform::DeviceContext &ctx,
                                  const Tensor &im_info, Tensor *boxes) {
145 146
  T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
  const T *im_info_data = im_info.data<T>();
147
  T zero(0);
148 149 150
  for (int64_t i = 0; i < boxes->numel(); ++i) {
    if (i % 4 == 0) {
      boxes_data[i] =
151
          std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
152 153
    } else if (i % 4 == 1) {
      boxes_data[i] =
154
          std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
155 156
    } else if (i % 4 == 2) {
      boxes_data[i] =
157
          std::max(std::min(boxes_data[i], im_info_data[1] - 1), zero);
158 159
    } else {
      boxes_data[i] =
160
          std::max(std::min(boxes_data[i], im_info_data[0] - 1), zero);
161 162 163 164 165
    }
  }
}

template <class T>
166 167 168
static inline void FilterBoxes(const platform::DeviceContext &ctx,
                               Tensor *boxes, float min_size,
                               const Tensor &im_info, Tensor *keep) {
169 170
  const T *im_info_data = im_info.data<T>();
  T *boxes_data = boxes->mutable_data<T>(ctx.GetPlace());
171
  T im_scale = im_info_data[2];
172
  keep->Resize({boxes->dims()[0]});
173
  min_size = std::max(min_size, 1.0f);
174 175 176 177 178 179
  int *keep_data = keep->mutable_data<int>(ctx.GetPlace());

  int keep_len = 0;
  for (int i = 0; i < boxes->dims()[0]; ++i) {
    T ws = boxes_data[4 * i + 2] - boxes_data[4 * i] + 1;
    T hs = boxes_data[4 * i + 3] - boxes_data[4 * i + 1] + 1;
180 181 182 183
    T ws_origin_scale =
        (boxes_data[4 * i + 2] - boxes_data[4 * i]) / im_scale + 1;
    T hs_origin_scale =
        (boxes_data[4 * i + 3] - boxes_data[4 * i + 1]) / im_scale + 1;
184 185
    T x_ctr = boxes_data[4 * i] + ws / 2;
    T y_ctr = boxes_data[4 * i + 1] + hs / 2;
186 187
    if (ws_origin_scale >= min_size && hs_origin_scale >= min_size &&
        x_ctr <= im_info_data[1] && y_ctr <= im_info_data[0]) {
188 189 190 191 192 193 194
      keep_data[keep_len++] = i;
    }
  }
  keep->Resize({keep_len});
}

// Pairs every score with its position in `scores` and returns the pairs
// sorted by score in ASCENDING order. The sort is stable, so equal scores
// keep their original relative order. The highest score therefore ends up at
// the back, which is where NMS pops candidates from.
template <class T>
static inline std::vector<std::pair<T, int>> GetSortedScoreIndex(
    const std::vector<T> &scores) {
  std::vector<std::pair<T, int>> sorted_indices;
  sorted_indices.reserve(scores.size());
  for (size_t i = 0; i < scores.size(); ++i) {
    sorted_indices.emplace_back(scores[i], static_cast<int>(i));
  }
  // Ascending by score (callers consume from the back for highest-first).
  std::stable_sort(sorted_indices.begin(), sorted_indices.end(),
                   [](const std::pair<T, int> &a, const std::pair<T, int> &b) {
                     return a.first < b.first;
                   });
  return sorted_indices;
}

// Area of an axis-aligned box [xmin, ymin, xmax, ymax].
// Returns 0 for an invalid box (xmax < xmin or ymax < ymin). When
// `normalized` is false the coordinates are inclusive pixel indices, so each
// side length gets +1.
template <class T>
static inline T BBoxArea(const T *box, bool normalized) {
  if (box[2] < box[0] || box[3] < box[1]) {
    return static_cast<T>(0.);
  }
  const T w = box[2] - box[0];
  const T h = box[3] - box[1];
  if (normalized) {
    return w * h;
  }
  return (w + 1) * (h + 1);
}

// Intersection-over-union of two [xmin, ymin, xmax, ymax] boxes.
// `normalized` selects the same side-length convention as BBoxArea: plain
// differences when true, inclusive pixel extents (+1) when false. Returns 0
// for disjoint boxes and when the union area is not positive (degenerate
// boxes), instead of dividing by zero.
template <class T>
static inline T JaccardOverlap(const T *box1, const T *box2, bool normalized) {
  if (box2[0] > box1[2] || box2[2] < box1[0] || box2[1] > box1[3] ||
      box2[3] < box1[1]) {
    return static_cast<T>(0.);
  }
  const T inter_xmin = std::max(box1[0], box2[0]);
  const T inter_ymin = std::max(box1[1], box2[1]);
  const T inter_xmax = std::min(box1[2], box2[2]);
  const T inter_ymax = std::min(box1[3], box2[3]);
  // The intersection must use the same convention as BBoxArea; applying +1
  // to normalized coordinates would make the IoU meaningless.
  const T pixel_offset = normalized ? static_cast<T>(0) : static_cast<T>(1);
  const T inter_w = std::max(T(0), inter_xmax - inter_xmin + pixel_offset);
  const T inter_h = std::max(T(0), inter_ymax - inter_ymin + pixel_offset);
  const T inter_area = inter_w * inter_h;
  const T union_area = BBoxArea<T>(box1, normalized) +
                       BBoxArea<T>(box2, normalized) - inter_area;
  if (union_area <= static_cast<T>(0)) {
    return static_cast<T>(0.);
  }
  return inter_area / union_area;
}

// Builds a 1-D CPU tensor holding the first `selected_num` elements of
// `selected_indices`. The caller must ensure
// selected_num <= selected_indices.size().
template <typename T>
static inline Tensor VectorToTensor(const std::vector<T> &selected_indices,
                                    int selected_num) {
  Tensor out;
  out.Resize({selected_num});
  T *out_data = out.mutable_data<T>(platform::CPUPlace());
  std::copy_n(selected_indices.begin(), selected_num, out_data);
  return out;
}

template <class T>
260 261
static inline Tensor NMS(const platform::DeviceContext &ctx, Tensor *bbox,
                         Tensor *scores, T nms_threshold, float eta) {
262 263 264 265 266 267
  int64_t num_boxes = bbox->dims()[0];
  // 4: [xmin ymin xmax ymax]
  int64_t box_size = bbox->dims()[1];

  std::vector<T> scores_data(num_boxes);
  std::copy_n(scores->data<T>(), num_boxes, scores_data.begin());
268 269
  std::vector<std::pair<T, int>> sorted_indices =
      GetSortedScoreIndex<T>(scores_data);
270 271 272 273 274 275

  std::vector<int> selected_indices;
  int selected_num = 0;
  T adaptive_threshold = nms_threshold;
  const T *bbox_data = bbox->data<T>();
  while (sorted_indices.size() != 0) {
276 277 278
    int idx = sorted_indices.back().second;
    bool flag = true;
    for (int kept_idx : selected_indices) {
279 280 281 282 283 284 285 286 287 288
      if (flag) {
        T overlap = JaccardOverlap<T>(bbox_data + idx * box_size,
                                      bbox_data + kept_idx * box_size, false);
        flag = (overlap <= adaptive_threshold);
      } else {
        break;
      }
    }
    if (flag) {
      selected_indices.push_back(idx);
289
      ++selected_num;
290
    }
J
jerrywgz 已提交
291
    sorted_indices.erase(sorted_indices.end() - 1);
292 293 294 295
    if (flag && eta < 1 && adaptive_threshold > 0.5) {
      adaptive_threshold *= eta;
    }
  }
296
  return VectorToTensor(selected_indices, selected_num);
297 298
}

299
template <typename T>
300 301 302 303 304 305
class GenerateProposalsKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *scores = context.Input<Tensor>("Scores");
    auto *bbox_deltas = context.Input<Tensor>("BboxDeltas");
    auto *im_info = context.Input<Tensor>("ImInfo");
306 307 308 309
    auto anchors = GET_DATA_SAFELY(context.Input<Tensor>("Anchors"), "Input",
                                   "Anchors", "GenerateProposals");
    auto variances = GET_DATA_SAFELY(context.Input<Tensor>("Variances"),
                                     "Input", "Variances", "GenerateProposals");
310 311 312 313 314 315 316 317 318 319

    auto *rpn_rois = context.Output<LoDTensor>("RpnRois");
    auto *rpn_roi_probs = context.Output<LoDTensor>("RpnRoiProbs");

    int pre_nms_top_n = context.Attr<int>("pre_nms_topN");
    int post_nms_top_n = context.Attr<int>("post_nms_topN");
    float nms_thresh = context.Attr<float>("nms_thresh");
    float min_size = context.Attr<float>("min_size");
    float eta = context.Attr<float>("eta");

320 321
    auto &dev_ctx =
        context.template device_context<platform::CPUDeviceContext>();
322

323
    auto &scores_dim = scores->dims();
324 325 326 327 328
    int64_t num = scores_dim[0];
    int64_t c_score = scores_dim[1];
    int64_t h_score = scores_dim[2];
    int64_t w_score = scores_dim[3];

329
    auto &bbox_dim = bbox_deltas->dims();
330 331 332 333 334 335
    int64_t c_bbox = bbox_dim[1];
    int64_t h_bbox = bbox_dim[2];
    int64_t w_bbox = bbox_dim[3];

    rpn_rois->mutable_data<T>({bbox_deltas->numel() / 4, 4},
                              context.GetPlace());
336
    rpn_roi_probs->mutable_data<T>({scores->numel(), 1}, context.GetPlace());
337 338 339 340 341 342 343

    Tensor bbox_deltas_swap, scores_swap;
    bbox_deltas_swap.mutable_data<T>({num, h_bbox, w_bbox, c_bbox},
                                     dev_ctx.GetPlace());
    scores_swap.mutable_data<T>({num, h_score, w_score, c_score},
                                dev_ctx.GetPlace());

344
    math::Transpose<platform::CPUDeviceContext, T, 4> trans;
345 346 347 348 349
    std::vector<int> axis = {0, 2, 3, 1};
    trans(dev_ctx, *bbox_deltas, &bbox_deltas_swap, axis);
    trans(dev_ctx, *scores, &scores_swap, axis);

    framework::LoD lod;
350 351 352 353 354
    lod.resize(1);
    auto &lod0 = lod[0];
    lod0.push_back(0);
    anchors.Resize({anchors.numel() / 4, 4});
    variances.Resize({variances.numel() / 4, 4});
355
    std::vector<int> tmp_num;
356 357 358 359 360 361 362 363 364 365 366

    int64_t num_proposals = 0;
    for (int64_t i = 0; i < num; ++i) {
      Tensor im_info_slice = im_info->Slice(i, i + 1);
      Tensor bbox_deltas_slice = bbox_deltas_swap.Slice(i, i + 1);
      Tensor scores_slice = scores_swap.Slice(i, i + 1);

      bbox_deltas_slice.Resize({h_bbox * w_bbox * c_bbox / 4, 4});
      scores_slice.Resize({h_score * w_score * c_score, 1});

      std::pair<Tensor, Tensor> tensor_pair =
367
          ProposalForOneImage(dev_ctx, im_info_slice, anchors, variances,
368 369
                              bbox_deltas_slice, scores_slice, pre_nms_top_n,
                              post_nms_top_n, nms_thresh, min_size, eta);
370 371
      Tensor &proposals = tensor_pair.first;
      Tensor &scores = tensor_pair.second;
372

373 374
      AppendProposals(rpn_rois, 4 * num_proposals, proposals);
      AppendProposals(rpn_roi_probs, num_proposals, scores);
375
      num_proposals += proposals.dims()[0];
376
      lod0.push_back(num_proposals);
377
      tmp_num.push_back(proposals.dims()[0]);
F
FDInSky 已提交
378
    }
379 380 381 382
    if (context.HasOutput("RpnRoisNum")) {
      auto *rpn_rois_num = context.Output<Tensor>("RpnRoisNum");
      rpn_rois_num->mutable_data<int>({num}, context.GetPlace());
      int *num_data = rpn_rois_num->data<int>();
F
FDInSky 已提交
383
      for (int i = 0; i < num; i++) {
384
        num_data[i] = tmp_num[i];
F
FDInSky 已提交
385
      }
386
      rpn_rois_num->Resize({num});
387 388 389 390 391 392 393 394
    }
    rpn_rois->set_lod(lod);
    rpn_roi_probs->set_lod(lod);
    rpn_rois->Resize({num_proposals, 4});
    rpn_roi_probs->Resize({num_proposals, 1});
  }

  std::pair<Tensor, Tensor> ProposalForOneImage(
395
      const platform::CPUDeviceContext &ctx, const Tensor &im_info_slice,
396 397 398 399 400 401 402 403 404 405 406 407 408 409
      const Tensor &anchors, const Tensor &variances,
      const Tensor &bbox_deltas_slice,  // [M, 4]
      const Tensor &scores_slice,       // [N, 1]
      int pre_nms_top_n, int post_nms_top_n, float nms_thresh, float min_size,
      float eta) const {
    auto *scores_data = scores_slice.data<T>();

    // Sort index
    Tensor index_t;
    index_t.Resize({scores_slice.numel()});
    int *index = index_t.mutable_data<int>(ctx.GetPlace());
    for (int i = 0; i < scores_slice.numel(); ++i) {
      index[i] = i;
    }
410 411 412
    auto compare = [scores_data](const int64_t &i, const int64_t &j) {
      return scores_data[i] > scores_data[j];
    };
413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440

    if (pre_nms_top_n <= 0 || pre_nms_top_n >= scores_slice.numel()) {
      std::sort(index, index + scores_slice.numel(), compare);
    } else {
      std::nth_element(index, index + pre_nms_top_n,
                       index + scores_slice.numel(), compare);
      index_t.Resize({pre_nms_top_n});
    }

    Tensor scores_sel, bbox_sel, anchor_sel, var_sel;
    scores_sel.mutable_data<T>({index_t.numel(), 1}, ctx.GetPlace());
    bbox_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
    anchor_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
    var_sel.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());

    CPUGather<T>(ctx, scores_slice, index_t, &scores_sel);
    CPUGather<T>(ctx, bbox_deltas_slice, index_t, &bbox_sel);
    CPUGather<T>(ctx, anchors, index_t, &anchor_sel);
    CPUGather<T>(ctx, variances, index_t, &var_sel);

    Tensor proposals;
    proposals.mutable_data<T>({index_t.numel(), 4}, ctx.GetPlace());
    BoxCoder<T>(ctx, &anchor_sel, &bbox_sel, &var_sel, &proposals);

    ClipTiledBoxes<T>(ctx, im_info_slice, &proposals);

    Tensor keep;
    FilterBoxes<T>(ctx, &proposals, min_size, im_info_slice, &keep);
441 442 443 444 445 446 447 448 449 450
    // Handle the case when there is no keep index left
    if (keep.numel() == 0) {
      math::SetConstant<platform::CPUDeviceContext, T> set_zero;
      bbox_sel.mutable_data<T>({1, 4}, ctx.GetPlace());
      set_zero(ctx, &bbox_sel, static_cast<T>(0));
      Tensor scores_filter;
      scores_filter.mutable_data<T>({1, 1}, ctx.GetPlace());
      set_zero(ctx, &scores_filter, static_cast<T>(0));
      return std::make_pair(bbox_sel, scores_filter);
    }
451 452 453 454 455 456 457

    Tensor scores_filter;
    bbox_sel.mutable_data<T>({keep.numel(), 4}, ctx.GetPlace());
    scores_filter.mutable_data<T>({keep.numel(), 1}, ctx.GetPlace());
    CPUGather<T>(ctx, proposals, keep, &bbox_sel);
    CPUGather<T>(ctx, scores_sel, keep, &scores_filter);
    if (nms_thresh <= 0) {
458
      return std::make_pair(bbox_sel, scores_filter);
459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478
    }

    Tensor keep_nms = NMS<T>(ctx, &bbox_sel, &scores_filter, nms_thresh, eta);

    if (post_nms_top_n > 0 && post_nms_top_n < keep_nms.numel()) {
      keep_nms.Resize({post_nms_top_n});
    }

    proposals.mutable_data<T>({keep_nms.numel(), 4}, ctx.GetPlace());
    scores_sel.mutable_data<T>({keep_nms.numel(), 1}, ctx.GetPlace());
    CPUGather<T>(ctx, bbox_sel, keep_nms, &proposals);
    CPUGather<T>(ctx, scores_filter, keep_nms, &scores_sel);

    return std::make_pair(proposals, scores_sel);
  }
};

// Declares the inputs, outputs, attributes and user documentation of the
// generate_proposals operator.
class GenerateProposalsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Scores",
             "(Tensor) The scores from conv is in shape (N, A, H, W), "
             "N is batch size, A is number of anchors, "
             "H and W are height and width of the feature map");
    AddInput("BboxDeltas",
             "(Tensor) Bounding box deltas from conv is in "
             "shape (N, 4*A, H, W).");
    AddInput("ImInfo",
             "(Tensor) Information for image reshape is in shape (N, 3), "
             "in format (height, width, scale)");
    // The kernel gathers anchor rows with the same indices as the
    // NCHW->NHWC-transposed scores/deltas, whose rows run over (H, W, A).
    // The anchors therefore have to be laid out as (H, W, A, 4).
    AddInput("Anchors",
             "(Tensor) Bounding box anchors from anchor_generator_op "
             "is in shape (H, W, A, 4).");
    AddInput("Variances",
             "(Tensor) Bounding box variances with same shape as `Anchors`.");

    AddOutput("RpnRois",
              "(LoDTensor), Output proposals with shape (rois_num, 4).");
    AddOutput("RpnRoiProbs",
              "(LoDTensor) Scores of proposals with shape (rois_num, 1).");
    AddOutput("RpnRoisNum", "(Tensor), The number of Rpn RoIs in each image")
        .AsDispensable();
    AddAttr<int>("pre_nms_topN",
                 "Number of top scoring RPN proposals to keep before "
                 "applying NMS.");
    AddAttr<int>("post_nms_topN",
                 "Number of top scoring RPN proposals to keep after "
                 "applying NMS");
    AddAttr<float>("nms_thresh", "NMS threshold used on RPN proposals.");
    // The filter compares with >=, at the original image scale, and treats
    // values below 1 as 1.
    AddAttr<float>("min_size",
                   "Proposal height and width, measured at the original "
                   "image scale, both need to be no less than this min_size "
                   "(values below 1 are treated as 1).");
    AddAttr<float>("eta", "The parameter for adaptive NMS.");
    AddComment(R"DOC(
This operator generates bounding box proposals for Faster RCNN.
The proposals are generated for a list of images based on image
score 'Scores', bounding box regression result 'BboxDeltas' as
well as predefined bounding box shapes 'Anchors'. Greedy
non-maximum suppression is applied to generate the final bounding
boxes.

)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
H
hong 已提交
528 529 530 531
REGISTER_OPERATOR(
    generate_proposals, ops::GenerateProposalsOp, ops::GenerateProposalsOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
532 533
REGISTER_OP_CPU_KERNEL(generate_proposals, ops::GenerateProposalsKernel<float>,
                       ops::GenerateProposalsKernel<double>);
534 535 536 537 538 539 540 541
REGISTER_OP_VERSION(generate_proposals)
    .AddCheckpoint(
        R"ROC(
              Upgrade generate_proposals add a new output [RpnRoisNum])ROC",
        paddle::framework::compatible::OpVersionDesc().NewOutput(
            "RpnRoisNum",
            "The number of Rpn RoIs in each image. RpnRoisNum is "
            "dispensable."));