# yolo_head.py
# Blame annotation: lines 1-8 committed by qingqing01.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register


# Blame annotation: lines 9-15 committed by wangxinxin08.
def _de_sigmoid(x, eps=1e-7):
    """Invert the sigmoid: return -log(1/x - 1), i.e. log(x / (1 - x)).

    Both the input and the intermediate ratio ``1/x - 1`` are clipped to
    ``[eps, 1/eps]`` so the division and the log never see zero.

    Args:
        x: tensor of values to map back to logit space.
        eps (float): lower clip bound; ``1/eps`` is used as the upper bound.

    Returns:
        Tensor of the same shape as ``x`` holding the clipped logits.
    """
    upper = 1. / eps
    prob = paddle.clip(x, eps, upper)
    inv_odds = paddle.clip(1. / prob - 1., eps, upper)
    return -paddle.log(inv_odds)


# Blame annotation: lines 16-17 committed by qingqing01.
@register
class YOLOv3Head(nn.Layer):
    """YOLOv3 prediction head.

    Applies one 1x1 convolution per input feature level to produce that
    level's raw YOLO predictions. In training mode the outputs are handed to
    the injected loss; otherwise they are returned, with the IoU-aware branch
    fused into objectness when ``iou_aware`` is enabled.
    """
    __shared__ = ['num_classes', 'data_format']
    __inject__ = ['loss']

    # NOTE(review): the list defaults below are never mutated in this class.
    # They are kept as literal lists because the @register config machinery
    # presumably reads them from the signature; confirm before switching them
    # to None sentinels.
    def __init__(self,
                 in_channels=[1024, 512, 256],
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 num_classes=80,
                 loss='YOLOv3Loss',
                 iou_aware=False,
                 iou_aware_factor=0.4,
                 data_format='NCHW'):
        """
        Head for YOLOv3 network

        Args:
            in_channels (list[int]): channel count of each input feature map;
                entry i feeds the output conv of anchor group i
            anchors (list): all anchors (pairs such as [w, h] in the defaults)
            anchor_masks (list): for each output level, the indices into
                `anchors` used by that level
            num_classes (int): number of foreground classes
            loss (object): loss module, called as
                loss(outputs, targets, anchors) in training; injected via
                `__inject__` (default config name 'YOLOv3Loss')
            iou_aware (bool): whether to predict one extra IoU channel per
                anchor
            iou_aware_factor (float): exponent weighting the IoU score against
                objectness when they are fused at inference
            data_format (str): data format, NCHW or NHWC
        """
        super(YOLOv3Head, self).__init__()
        assert len(in_channels) > 0, "in_channels length should > 0"
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.loss = loss

        self.iou_aware = iou_aware
        self.iou_aware_factor = iou_aware_factor

        self.parse_anchor(anchors, anchor_masks)
        self.num_outputs = len(self.anchors)
        self.data_format = data_format

        # Channels per anchor: 4 box terms + 1 objectness + class scores,
        # plus one IoU channel when iou_aware is on (see _fuse_iou_aware).
        per_anchor = self.num_classes + (6 if self.iou_aware else 5)
        self.yolo_outputs = []
        for i, anchor_group in enumerate(self.anchors):
            # Sublayer name kept stable: it determines parameter names.
            name = 'yolo_output.{}'.format(i)
            conv = nn.Conv2D(
                in_channels=self.in_channels[i],
                out_channels=len(anchor_group) * per_anchor,
                kernel_size=1,
                stride=1,
                padding=0,
                data_format=data_format,
                bias_attr=ParamAttr(regularizer=L2Decay(0.)))
            # Presumably tells quantization passes to leave this conv alone.
            conv.skip_quant = True
            yolo_output = self.add_sublayer(name, conv)
            self.yolo_outputs.append(yolo_output)

    def parse_anchor(self, anchors, anchor_masks):
        """Resolve anchor masks into per-level anchor lists.

        Sets:
            self.anchors: for each mask group, the list of anchors it selects.
            self.mask_anchors: the same selection, with each group's anchors
                concatenated into one flat list.

        Raises:
            AssertionError: if any mask index is outside [0, len(anchors)).
        """
        anchor_num = len(anchors)
        # Validate before indexing. Otherwise the lookups below would raise a
        # bare IndexError for an out-of-range mask, or silently wrap around
        # for a negative one, and this assertion would never fire.
        for masks in anchor_masks:
            for mask in masks:
                assert 0 <= mask < anchor_num, "anchor mask index overflow"
        self.anchors = [[anchors[i] for i in masks] for masks in anchor_masks]
        self.mask_anchors = [[v for i in masks for v in anchors[i]]
                             for masks in anchor_masks]

    def forward(self, feats, targets=None):
        """Apply the per-level output convs.

        Args:
            feats (list): one feature map per anchor group, in the same order
                as `in_channels` / `anchor_masks`
            targets: training targets, forwarded unchanged to the loss

        Returns:
            In training mode, the result of
            ``self.loss(outputs, targets, self.anchors)``. Otherwise, the list
            of per-level outputs in NCHW layout. When `iou_aware` is set, the
            IoU channels are folded into objectness (see _fuse_iou_aware).
        """
        assert len(feats) == len(self.anchors)
        yolo_outputs = []
        for conv, feat in zip(self.yolo_outputs, feats):
            yolo_output = conv(feat)
            if self.data_format == 'NHWC':
                # Normalize to NCHW so channels are always on axis 1.
                yolo_output = paddle.transpose(yolo_output, [0, 3, 1, 2])
            yolo_outputs.append(yolo_output)

        if self.training:
            return self.loss(yolo_outputs, targets, self.anchors)
        if not self.iou_aware:
            return yolo_outputs
        return [
            self._fuse_iou_aware(out, len(anchor_group))
            for out, anchor_group in zip(yolo_outputs, self.anchors)
        ]

    def _fuse_iou_aware(self, out, na):
        """Fold the IoU-aware channels of one output level into objectness.

        Input layout:
            ``out`` is (b, na + na * no, h, w). The first ``na`` channels are
            IoU logits. The rest are ``na`` groups of ``no`` channels, ordered
            [4 box, 1 objectness, class scores...].

        Output:
            A (b, na * no, h, w) tensor whose objectness channel is replaced
            by the clipped logit of
            ``sigmoid(obj) ** (1 - f) * sigmoid(iou) ** f``, where
            ``f = self.iou_aware_factor``.
        """
        ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
        b, c, h, w = x.shape
        no = c // na
        x = x.reshape((b, na, no, h * w))
        ioup = F.sigmoid(ioup.reshape((b, na, 1, h * w)))
        obj = F.sigmoid(x[:, :, 4:5, :])
        obj_t = (obj**(1 - self.iou_aware_factor)) * (
            ioup**self.iou_aware_factor)
        # Back to logit space so the fused score has the same form as a raw
        # objectness output.
        obj_t = _de_sigmoid(obj_t)
        loc_t = x[:, :, :4, :]
        cls_t = x[:, :, 5:, :]
        y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
        return y_t.reshape((b, c, h, w))

    @classmethod
    def from_config(cls, cfg, input_shape):
        """Derive `in_channels` from the `channels` of each input shape."""
        return {'in_channels': [i.channels for i in input_shape], }