# yolo_head.py: YOLOv3 detection head
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
from ..backbones.darknet import ConvBNLayer


def _de_sigmoid(x, eps=1e-7):
    """Numerically guarded inverse of the sigmoid (the logit function).

    Computes ``-log(1 / x - 1)``, which equals ``log(x / (1 - x))``.

    Both the input and the intermediate ratio are clipped to
    ``[eps, 1 / eps]``. This keeps the division and the log away from zero,
    so the result stays within roughly ``[log(eps), -log(eps)]``.

    Args:
        x (Tensor): values to invert. Presumably probabilities in (0, 1);
            in this file the caller passes a product of sigmoid outputs.
        eps (float): clipping bound used for both clips.

    Returns:
        Tensor: element-wise logits of ``x``.
    """
    bound = 1. / eps
    prob = paddle.clip(x, eps, bound)
    # (1 - p) / p, clipped so the log argument is strictly positive
    inv_odds = paddle.clip(1. / prob - 1., eps, bound)
    return -paddle.log(inv_odds)


@register
class YOLOv3Head(nn.Layer):
    # Class attributes presumably consumed by ppdet's @register config system:
    # `__shared__` names values taken from the global config, and
    # `__inject__` names constructor args resolved to registered objects.
    # Confirm in ppdet.core.workspace.
    __shared__ = ['num_classes', 'data_format']
    __inject__ = ['loss']

    # NOTE: the list defaults are kept as-is because the signature is
    # presumably inspected by @register. This class only reads them and
    # never mutates them.
    def __init__(self,
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 num_classes=80,
                 loss='YOLOv3Loss',
                 iou_aware=False,
                 iou_aware_factor=0.4,
                 data_format='NCHW'):
        """
        Head for YOLOv3 network.

        Builds one 1x1 Conv2D output layer per entry of ``anchor_masks``.

        Args:
            anchors (list): all anchor sizes, each a 2-element pair such as
                [10, 13].
            anchor_masks (list): one list per output level, holding indices
                into ``anchors`` for that level.
            num_classes (int): number of foreground classes.
            loss (object): YOLOv3Loss instance, called as
                ``loss(outputs, targets, anchors)`` in training mode.
            iou_aware (bool): whether each anchor predicts an extra IoU
                channel.
            iou_aware_factor (float): exponent weight of the predicted IoU
                when fusing it into objectness at inference.
            data_format (str): data format of the conv layers, NCHW or NHWC.
        """
        super(YOLOv3Head, self).__init__()
        self.num_classes = num_classes
        self.loss = loss

        self.iou_aware = iou_aware
        self.iou_aware_factor = iou_aware_factor

        self.parse_anchor(anchors, anchor_masks)
        self.num_outputs = len(self.anchors)
        self.data_format = data_format

        # Channels per anchor:
        #   4 box values + 1 objectness + num_classes scores,
        #   plus 1 IoU channel when iou_aware.
        # This matches the slicing in _fuse_iou_aware.
        num_attrs = self.num_classes + (6 if self.iou_aware else 5)

        self.yolo_outputs = []
        for i, level_anchors in enumerate(self.anchors):
            num_filters = len(level_anchors) * num_attrs
            # Sublayer name determines parameter names; keep it stable.
            name = 'yolo_output.{}'.format(i)
            yolo_output = self.add_sublayer(
                name,
                nn.Conv2D(
                    # Assumes level i's input has 128 * 2**(num_outputs - i)
                    # channels (1024/512/256 for three levels).
                    # TODO confirm against the neck.
                    in_channels=128 * (2**self.num_outputs) // (2**i),
                    out_channels=num_filters,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    data_format=data_format,
                    bias_attr=ParamAttr(regularizer=L2Decay(0.))))
            self.yolo_outputs.append(yolo_output)

    def parse_anchor(self, anchors, anchor_masks):
        """Select the anchors used by each output level.

        Args:
            anchors (list): all anchor sizes, each a 2-element pair.
            anchor_masks (list): per output level, indices into ``anchors``.

        Sets:
            self.anchors: per level, the list of selected anchor pairs.
            self.mask_anchors: per level, the selected pairs flattened into
                one list.

        Raises:
            IndexError: if any mask index lies outside ``[0, len(anchors))``.
        """
        anchor_num = len(anchors)
        # Validate before any indexing. Previously the comprehension below
        # raised a bare IndexError before the overflow check could run, and
        # negative indices silently wrapped around to the end of `anchors`.
        for masks in anchor_masks:
            for mask in masks:
                if not 0 <= mask < anchor_num:
                    raise IndexError(
                        "anchor mask index overflow: {} not in [0, {})".format(
                            mask, anchor_num))
        self.anchors = [[anchors[i] for i in mask] for mask in anchor_masks]
        self.mask_anchors = [[v for i in mask for v in anchors[i]]
                             for mask in anchor_masks]

    def forward(self, feats, targets=None):
        """Apply the per-level output convs, then compute loss or outputs.

        Args:
            feats (list): one feature map per output level, in the same
                order as ``anchor_masks``.
            targets: passed through unchanged to ``self.loss`` in training.

        Returns:
            In training mode, the result of
            ``self.loss(outputs, targets, self.anchors)``.

            Otherwise, a list of per-level NCHW tensors. When ``iou_aware``
            is set, the IoU channels are folded into the objectness channel
            and removed.
        """
        assert len(feats) == len(self.anchors)
        yolo_outputs = []
        for i, feat in enumerate(feats):
            yolo_output = self.yolo_outputs[i](feat)
            if self.data_format == 'NHWC':
                # Normalize to NCHW; the channel slicing below uses axis 1.
                yolo_output = paddle.transpose(yolo_output, [0, 3, 1, 2])
            yolo_outputs.append(yolo_output)

        if self.training:
            return self.loss(yolo_outputs, targets, self.anchors)
        if self.iou_aware:
            return [
                self._fuse_iou_aware(out, len(level_anchors))
                for out, level_anchors in zip(yolo_outputs, self.anchors)
            ]
        return yolo_outputs

    def _fuse_iou_aware(self, out, na):
        """Fold the predicted IoU into the objectness logit for one level.

        Channel layout of ``out`` (NCHW), as implied by the slicing below:
          - the first ``na`` channels are IoU logits;
          - then follow ``na`` groups of [4 loc, 1 obj, class scores...].

        The returned tensor drops the IoU channels. It keeps the remaining
        layout, with obj replaced by the fused value.
        """
        ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
        b, c, h, w = x.shape
        no = c // na
        x = x.reshape((b, na, no, h * w))
        ioup = ioup.reshape((b, na, 1, h * w))
        obj = F.sigmoid(x[:, :, 4:5, :])
        ioup = F.sigmoid(ioup)
        # Geometric blend obj^(1-f) * iou^f, then mapped back to logit space.
        # Presumably this lets downstream decoding treat this channel like
        # the raw logits of the non-iou-aware path; confirm.
        obj_t = (obj**(1 - self.iou_aware_factor)) * (
            ioup**self.iou_aware_factor)
        obj_t = _de_sigmoid(obj_t)
        loc_t = x[:, :, :4, :]
        cls_t = x[:, :, 5:, :]
        y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
        return y_t.reshape((b, c, h, w))