detectionoutput.py 1.9 KB
Newer Older
S
SunAhong1993 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
from .register import register
from x2paddle.core.util import *


def detectionoutput_shape(input_shape):
    """Return the output shapes reported for a DetectionOutput layer.

    Args:
        input_shape: Shapes of the layer inputs. Not consulted; the result
            is the same for every input.

    Returns:
        A newly built list holding one shape, ``[-1, 6]``.
    """
    single_output_shape = [-1, 6]
    return [single_output_shape]


def detectionoutput_layer(inputs,
                          nms_param=None,
                          background_label_id=0,
                          share_location=True,
                          keep_top_k=100,
                          confidence_threshold=0.1,
                          input_shape=None,
                          name=None):
    """Build the Paddle ops for a ``DetectionOutput`` layer.

    Args:
        inputs: Indexable of three variables, in this order: box location
            predictions, flattened class confidences, prior boxes. The
            prior-box variable is split in two along axis 1; the first half
            is used as the prior boxes and the second as their variances.
        nms_param: Optional mapping that may provide ``nms_threshold``,
            ``top_k`` and ``eta``. Keys that are absent fall back to
            0.3, 10 and 1.0 respectively. The mapping is never modified.
        background_label_id: Forwarded as ``background_label``.
        share_location: Accepted for signature compatibility; not used.
        keep_top_k: Forwarded as ``keep_top_k``.
        confidence_threshold: Forwarded as ``score_threshold``.
        input_shape: Accepted for signature compatibility; not used.
        name: Accepted for signature compatibility; not used.

    Returns:
        The value returned by ``fluid.layers.detection_output``.
    """
    mbox_loc = inputs[0]
    mbox_conf_flatten = inputs[1]
    mbox_priorbox = inputs[2]

    # The split axis is passed positionally on purpose: the keyword is
    # spelled ``dim`` in older Paddle releases and ``axis`` in Paddle 2.x,
    # while the third positional slot means the same thing in both.
    pb, pbv = paddle.split(mbox_priorbox, 2, 1)
    pb = paddle.reshape(x=pb, shape=[-1, 4])
    pbv = paddle.reshape(x=pbv, shape=[-1, 4])

    # pb.shape[0] is the prior-box count after flattening to rows of 4.
    mbox_loc = paddle.reshape(x=mbox_loc, shape=[-1, pb.shape[0], 4])
    mbox_conf_flatten = paddle.reshape(
        x=mbox_conf_flatten, shape=[0, pb.shape[0], -1])

    # Resolve NMS settings into a private dict so that the caller's
    # ``nms_param`` is left untouched (it used to be filled in place).
    nms_settings = {"nms_threshold": 0.3, "top_k": 10, "eta": 1.0}
    if nms_param is not None:
        for key in list(nms_settings):
            if key in nms_param:
                nms_settings[key] = nms_param[key]

    out = fluid.layers.detection_output(
        scores=mbox_conf_flatten,
        loc=mbox_loc,
        prior_box=pb,
        prior_box_var=pbv,
        background_label=background_label_id,
        nms_threshold=nms_settings["nms_threshold"],
        nms_top_k=nms_settings["top_k"],
        keep_top_k=keep_top_k,
        score_threshold=confidence_threshold,
        nms_eta=nms_settings["eta"])
    return out


def detectionoutput_weights(name, data=None):
    """Return the weight names for a DetectionOutput layer.

    Args:
        name: Layer name. Not consulted.
        data: Layer data. Not consulted.

    Returns:
        A new empty list on every call.
    """
    return list()


J
jiangjiajun 已提交
55 56 57 58 59
# Hook the three DetectionOutput handlers defined in this module into the
# registry under the 'DetectionOutput' kind.
register(
    kind='DetectionOutput',
    layer=detectionoutput_layer,
    shape=detectionoutput_shape,
    weights=detectionoutput_weights,
)