From e3ec5d0f0262d396e5eda2768cf67df3face18a1 Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Wed, 8 Feb 2023 19:29:54 +0800 Subject: [PATCH] update rcnn fit for paddle 2.2 (#7706) --- ppdet/modeling/heads/roi_extractor.py | 24 +++++------- .../proposal_generator/proposal_generator.py | 39 +++++++------------ 2 files changed, 25 insertions(+), 38 deletions(-) diff --git a/ppdet/modeling/heads/roi_extractor.py b/ppdet/modeling/heads/roi_extractor.py index eb9b75b10..b96bb4e91 100644 --- a/ppdet/modeling/heads/roi_extractor.py +++ b/ppdet/modeling/heads/roi_extractor.py @@ -88,21 +88,17 @@ class RoIAlign(object): k_min = self.start_level + offset k_max = self.end_level + offset if hasattr(paddle.vision.ops, "distribute_fpn_proposals"): - rois_dist, restore_index, rois_num_dist = paddle.vision.ops.distribute_fpn_proposals( - roi, - k_min, - k_max, - self.canconical_level, - self.canonical_size, - rois_num=rois_num) + distribute_fpn_proposals = getattr(paddle.vision.ops, + "distribute_fpn_proposals") else: - rois_dist, restore_index, rois_num_dist = ops.distribute_fpn_proposals( - roi, - k_min, - k_max, - self.canconical_level, - self.canonical_size, - rois_num=rois_num) + distribute_fpn_proposals = ops.distribute_fpn_proposals + rois_dist, restore_index, rois_num_dist = distribute_fpn_proposals( + roi, + k_min, + k_max, + self.canconical_level, + self.canonical_size, + rois_num=rois_num) rois_feat_list = [] for lvl in range(self.start_level, self.end_level + 1): diff --git a/ppdet/modeling/proposal_generator/proposal_generator.py b/ppdet/modeling/proposal_generator/proposal_generator.py index 6c722c8cf..b87a72ced 100644 --- a/ppdet/modeling/proposal_generator/proposal_generator.py +++ b/ppdet/modeling/proposal_generator/proposal_generator.py @@ -63,30 +63,21 @@ class ProposalGenerator(object): top_n = self.pre_nms_top_n if self.topk_after_collect else self.post_nms_top_n variances = paddle.ones_like(anchors) if hasattr(paddle.vision.ops, "generate_proposals"): - rpn_rois, rpn_rois_prob, rpn_rois_num = paddle.vision.ops.generate_proposals( - scores, - bbox_deltas, - im_shape, - anchors, - variances, - pre_nms_top_n=self.pre_nms_top_n, - post_nms_top_n=top_n, - nms_thresh=self.nms_thresh, - min_size=self.min_size, - eta=self.eta, - return_rois_num=True) + generate_proposals = getattr(paddle.vision.ops, + "generate_proposals") else: - rpn_rois, rpn_rois_prob, rpn_rois_num = ops.generate_proposals( - scores, - bbox_deltas, - im_shape, - anchors, - variances, - pre_nms_top_n=self.pre_nms_top_n, - post_nms_top_n=top_n, - nms_thresh=self.nms_thresh, - min_size=self.min_size, - eta=self.eta, - return_rois_num=True) + generate_proposals = ops.generate_proposals + rpn_rois, rpn_rois_prob, rpn_rois_num = generate_proposals( + scores, + bbox_deltas, + im_shape, + anchors, + variances, + pre_nms_top_n=self.pre_nms_top_n, + post_nms_top_n=top_n, + nms_thresh=self.nms_thresh, + min_size=self.min_size, + eta=self.eta, + return_rois_num=True) return rpn_rois, rpn_rois_prob, rpn_rois_num, self.post_nms_top_n -- GitLab