# detectionoutput.py (2.3 KB) -- committed by SunAhong1993
from .register import register
from x2paddle.core.util import *


def detectionoutput_shape(input_shape):
    """Return the inferred output shape(s) of a DetectionOutput layer.

    ``input_shape`` is ignored: the layer always produces exactly one output
    whose shape is ``[-1, 6]`` (a variable number of rows, six values each).
    """
    single_output = [-1, 6]
    return [single_output]


# Fallback NMS settings used for any field the converted layer does not specify.
_DEFAULT_NMS_PARAM = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}


def _parse_nms_param(nms_param):
    """Normalize ``nms_param`` into a new dict with every NMS field populated.

    Accepts ``None`` (all defaults), a dict (copied, never mutated), or a
    comma-separated ``"key: value"`` string such as
    ``"nms_threshold: 0.45,top_k: 400,"``. ``top_k`` is parsed as ``int``;
    every other value is parsed as ``float``. Fields missing from the input
    are filled from ``_DEFAULT_NMS_PARAM``.

    Raises:
        ValueError: if a non-empty string segment has no ``:`` separator, or
            if its value cannot be converted to a number.
    """
    parsed = {}
    if isinstance(nms_param, dict):
        parsed.update(nms_param)
    elif nms_param is not None:
        for segment in nms_param.split(','):
            segment = segment.strip()
            # Skip empty segments (e.g. from a trailing comma) rather than
            # stopping, so fields after them are not silently dropped.
            if not segment:
                continue
            key, sep, value = segment.partition(':')
            if not sep:
                raise ValueError(
                    "malformed nms_param segment %r; expected 'key: value'" %
                    segment)
            key = key.strip()
            value = value.strip()
            parsed[key] = int(value) if key == 'top_k' else float(value)
    for key, value in _DEFAULT_NMS_PARAM.items():
        parsed.setdefault(key, value)
    return parsed


def detectionoutput_layer(inputs,
                          nms_param=None,
                          background_label_id=0,
                          share_location=True,
                          keep_top_k=100,
                          confidence_threshold=0.1,
                          input_shape=None,
                          name=None):
    """Build a ``fluid.layers.detection_output`` op for a DetectionOutput layer.

    Args:
        inputs: indexable of at least three tensors, used as
            ``[mbox_loc, mbox_conf_flatten, mbox_priorbox]``.
        nms_param: NMS settings, given as None, a dict, or a
            ``"key: value,..."`` string; see ``_parse_nms_param``.
        background_label_id: forwarded as ``background_label``.
        share_location: accepted for interface compatibility; unused here.
        keep_top_k: forwarded as ``keep_top_k``.
        confidence_threshold: forwarded as ``score_threshold``.
        input_shape: accepted for interface compatibility; unused here.
        name: accepted for interface compatibility; unused here.

    Returns:
        The output of ``fluid.layers.detection_output``.

    Raises:
        ValueError: if ``nms_param`` is a malformed string.
    """
    nms = _parse_nms_param(nms_param)

    mbox_loc = inputs[0]
    mbox_conf_flatten = inputs[1]
    mbox_priorbox = inputs[2]

    # Split the prior-box input in two along dim 1. The first half is used as
    # the prior boxes and the second half as their variances (see the
    # prior_box / prior_box_var arguments below). Each half is flattened to
    # rows of 4 values.
    pb, pbv = fluid.layers.split(mbox_priorbox, 2, dim=1)
    pb = fluid.layers.reshape(x=pb, shape=[-1, 4])
    pbv = fluid.layers.reshape(x=pbv, shape=[-1, 4])

    num_priors = pb.shape[0]
    mbox_loc = fluid.layers.reshape(x=mbox_loc, shape=[-1, num_priors, 4])
    # A 0 in a fluid reshape target copies the corresponding input dimension
    # (per the Paddle reshape docs).
    mbox_conf_flatten = fluid.layers.reshape(x=mbox_conf_flatten,
                                             shape=[0, num_priors, -1])

    out = fluid.layers.detection_output(
        scores=mbox_conf_flatten,
        loc=mbox_loc,
        prior_box=pb,
        prior_box_var=pbv,
        background_label=background_label_id,
        nms_threshold=nms["nms_threshold"],
        nms_top_k=nms["top_k"],
        keep_top_k=keep_top_k,
        score_threshold=confidence_threshold,
        nms_eta=nms["eta"])
    return out


def detectionoutput_weights(name, data=None):
    """Return the weight names this converter registers for DetectionOutput.

    No weights are registered for this layer, so the result is always a
    fresh empty list, whatever ``name`` and ``data`` are.
    """
    return list()


# Hook the DetectionOutput converter functions into the layer registry.
register(
    kind='DetectionOutput',
    shape=detectionoutput_shape,
    layer=detectionoutput_layer,
    weights=detectionoutput_weights,
)