nms.py 4.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright (c) 2020  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
W
wjj19950828 已提交
16 17 18
from paddle import _C_ops
from paddle import in_dynamic_mode
from paddle.common_ops_import import Variable, LayerHelper
19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35


def multiclass_nms(bboxes,
                   scores,
                   score_threshold,
                   nms_top_k,
                   keep_top_k,
                   nms_threshold=0.3,
                   normalized=True,
                   nms_eta=1.,
                   background_label=-1,
                   return_index=False,
                   return_rois_num=True,
                   rois_num=None,
                   name=None):
    """Run Paddle's ``multiclass_nms3`` operator on ``bboxes`` / ``scores``.

    Works in both dynamic (eager) mode, where ``_C_ops.multiclass_nms3`` is
    called directly, and static-graph mode, where a ``multiclass_nms3`` op is
    appended to the current program.

    Args:
        bboxes: box tensor fed as the op's ``BBoxes`` input (layout is the
            one ``multiclass_nms3`` expects -- presumably [N, M, 4]; confirm).
        scores: score tensor fed as the op's ``Scores`` input.
        score_threshold, nms_top_k, keep_top_k, nms_threshold, normalized,
        nms_eta, background_label: forwarded unchanged as op attributes.
        return_index: if False, the returned ``index`` is ``None``.
        return_rois_num: if False, the returned ``nms_rois_num`` is ``None``
            (in static mode the ``NmsRoisNum`` output is not even created).
        rois_num: optional tensor fed as the op's ``RoisNum`` input.
        name: only consumed by ``LayerHelper`` (via ``locals()``).

    Returns:
        tuple ``(output, nms_rois_num, index)``; the last two may be ``None``
        depending on ``return_rois_num`` / ``return_index``.
    """
    if in_dynamic_mode():
        # Legacy eager calling convention: attributes as a flat
        # (name, value, name, value, ...) sequence.
        attrs = ('background_label', background_label, 'score_threshold',
                 score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold',
                 nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta,
                 'normalized', normalized)
        output, index, nms_rois_num = _C_ops.multiclass_nms3(bboxes, scores,
                                                             rois_num, *attrs)
        if not return_index:
            index = None
        # Honour return_rois_num exactly like the static-graph branch does.
        if not return_rois_num:
            nms_rois_num = None
        return output, nms_rois_num, index

    # Static graph: the helper is only needed here.  At this point locals()
    # contains exactly the function parameters.
    helper = LayerHelper('multiclass_nms3', **locals())

    output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
    index = helper.create_variable_for_type_inference(dtype='int')

    inputs = {'BBoxes': bboxes, 'Scores': scores}
    outputs = {'Out': output, 'Index': index}

    if rois_num is not None:
        inputs['RoisNum'] = rois_num

    nms_rois_num = None
    if return_rois_num:
        nms_rois_num = helper.create_variable_for_type_inference(
            dtype='int32')
        outputs['NmsRoisNum'] = nms_rois_num

    helper.append_op(
        type="multiclass_nms3",
        inputs=inputs,
        attrs={
            'background_label': background_label,
            'score_threshold': score_threshold,
            'nms_top_k': nms_top_k,
            'nms_threshold': nms_threshold,
            'keep_top_k': keep_top_k,
            'nms_eta': nms_eta,
            'normalized': normalized
        },
        outputs=outputs)
    # NMS results are selections, not differentiable quantities.
    output.stop_gradient = True
    index.stop_gradient = True
    if not return_index:
        index = None

    return output, nms_rois_num, index


class NMS(object):
    """Callable wrapper around ``multiclass_nms`` that yields index triples.

    Calling an instance returns an int64 tensor built by concatenating, along
    axis 1, ``[batch_index, class_column, box_index]`` for every kept box.
    """

    def __init__(self, score_threshold, keep_top_k, nms_threshold):
        # Stored and forwarded unchanged to multiclass_nms on each call.
        self.score_threshold = score_threshold
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold

    def __call__(self, bboxes, scores):
        """Run NMS on ``bboxes`` / ``scores`` and return index triples.

        Args:
            bboxes: box tensor; ``shape[0]`` is treated as the batch size and
                ``shape[1]`` as the number of boxes per batch item.
            scores: score tensor passed straight to ``multiclass_nms``.

        Returns:
            int64 tensor, rows ``[batch_index, class, box_index]``.
        """
        attrs = {
            'background_label': -1,
            'score_threshold': self.score_threshold,
            'nms_top_k': -1,
            'nms_threshold': self.nms_threshold,
            'keep_top_k': self.keep_top_k,
            'nms_eta': 1.0,
            'normalized': False,
            'return_index': True
        }
        output, _, index = multiclass_nms(bboxes, scores, **attrs)
        # Column 0 of the NMS output is taken as the class label.
        clas = paddle.slice(output, axes=[1], starts=[0], ends=[1])
        clas = paddle.cast(clas, dtype="int64")
        index = paddle.cast(index, dtype="int64")
        if bboxes.shape[0] == 1:
            # Single batch item: every selected box belongs to batch 0.
            batch = paddle.zeros_like(clas, dtype="int64")
        else:
            # Assumes `index` is a flat offset over (batch, box), which is
            # what the div/mod decomposition below relies on -- confirm
            # against multiclass_nms3 docs for the Paddle version in use.
            bboxes_count = paddle.shape(bboxes)[1]
            bboxes_count = paddle.cast(bboxes_count, dtype="int64")
            # floor_divide makes the integer division explicit; plain
            # divide's behaviour on int tensors is version dependent.
            batch = paddle.floor_divide(index, bboxes_count)
            index = paddle.mod(index, bboxes_count)
        res = paddle.concat([batch, clas, index], axis=1)
        return res