diff --git a/ppdet/modeling/heads/roi_extractor.py b/ppdet/modeling/heads/roi_extractor.py index eb9b75b10e1a6fca88b12f27b48a5e51969a52c2..b96bb4e91a0ca720d099562c3cf51d95080051fa 100644 --- a/ppdet/modeling/heads/roi_extractor.py +++ b/ppdet/modeling/heads/roi_extractor.py @@ -88,21 +88,17 @@ class RoIAlign(object): k_min = self.start_level + offset k_max = self.end_level + offset if hasattr(paddle.vision.ops, "distribute_fpn_proposals"): - rois_dist, restore_index, rois_num_dist = paddle.vision.ops.distribute_fpn_proposals( - roi, - k_min, - k_max, - self.canconical_level, - self.canonical_size, - rois_num=rois_num) + distribute_fpn_proposals = getattr(paddle.vision.ops, + "distribute_fpn_proposals") else: - rois_dist, restore_index, rois_num_dist = ops.distribute_fpn_proposals( - roi, - k_min, - k_max, - self.canconical_level, - self.canonical_size, - rois_num=rois_num) + distribute_fpn_proposals = ops.distribute_fpn_proposals + rois_dist, restore_index, rois_num_dist = distribute_fpn_proposals( + roi, + k_min, + k_max, + self.canconical_level, + self.canonical_size, + rois_num=rois_num) rois_feat_list = [] for lvl in range(self.start_level, self.end_level + 1): diff --git a/ppdet/modeling/proposal_generator/proposal_generator.py b/ppdet/modeling/proposal_generator/proposal_generator.py index 6c722c8cf0872140b77acb2ff9bb1af352cb66e7..b87a72ced5f0ddcd9515332a17b52d5210e9398a 100644 --- a/ppdet/modeling/proposal_generator/proposal_generator.py +++ b/ppdet/modeling/proposal_generator/proposal_generator.py @@ -63,30 +63,21 @@ class ProposalGenerator(object): top_n = self.pre_nms_top_n if self.topk_after_collect else self.post_nms_top_n variances = paddle.ones_like(anchors) if hasattr(paddle.vision.ops, "generate_proposals"): - rpn_rois, rpn_rois_prob, rpn_rois_num = paddle.vision.ops.generate_proposals( - scores, - bbox_deltas, - im_shape, - anchors, - variances, - pre_nms_top_n=self.pre_nms_top_n, - post_nms_top_n=top_n, - nms_thresh=self.nms_thresh, - min_size=self.min_size, - eta=self.eta, - return_rois_num=True) + generate_proposals = getattr(paddle.vision.ops, + "generate_proposals") else: - rpn_rois, rpn_rois_prob, rpn_rois_num = ops.generate_proposals( - scores, - bbox_deltas, - im_shape, - anchors, - variances, - pre_nms_top_n=self.pre_nms_top_n, - post_nms_top_n=top_n, - nms_thresh=self.nms_thresh, - min_size=self.min_size, - eta=self.eta, - return_rois_num=True) + generate_proposals = ops.generate_proposals + rpn_rois, rpn_rois_prob, rpn_rois_num = generate_proposals( + scores, + bbox_deltas, + im_shape, + anchors, + variances, + pre_nms_top_n=self.pre_nms_top_n, + post_nms_top_n=top_n, + nms_thresh=self.nms_thresh, + min_size=self.min_size, + eta=self.eta, + return_rois_num=True) return rpn_rois, rpn_rois_prob, rpn_rois_num, self.post_nms_top_n