yolo_head.py 4.0 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
7
from ..backbones.darknet import ConvBNLayer
Q
qingqing01 已提交
8 9


W
wangxinxin08 已提交
10 11 12 13 14 15 16
def _de_sigmoid(x, eps=1e-7):
    """Numerically guarded inverse of the sigmoid (the logit function).

    Computes ``-log(1 / x - 1)``, which equals ``log(x / (1 - x))``.
    Two clips keep the computation finite:

    * ``x`` is clipped to ``[eps, 1 / eps]`` so ``1 / x`` never divides by zero.
    * ``1 / x - 1`` is clipped to ``[eps, 1 / eps]`` so the log argument stays
      positive. For ``x >= 1`` this yields ``-log(eps)``.

    Args:
        x: paddle tensor of values to invert (presumably probabilities).
        eps (float): clipping tolerance.

    Returns:
        paddle tensor with the same shape as ``x``.
    """
    upper = 1. / eps
    safe_x = paddle.clip(x, eps, upper)
    inv_odds = paddle.clip(1. / safe_x - 1., eps, upper)
    return -paddle.log(inv_odds)


Q
qingqing01 已提交
17 18
@register
class YOLOv3Head(nn.Layer):
    """YOLOv3 detection head: one 1x1 output convolution per feature level.

    In training mode ``forward`` returns whatever the injected ``loss`` module
    returns. In evaluation mode it returns the per-level convolution outputs
    in NCHW layout, optionally re-scored by the IoU-aware branch.

    Args:
        anchors (list[list[int]]): all anchors, each a 2-element list
            (presumably ``[w, h]`` in pixels -- confirm against the loss).
        anchor_masks (list[list[int]]): for each output level, indices into
            ``anchors`` selecting the anchors used at that level.
        num_classes (int): number of object classes.
        loss: loss module, injected by the ppdet workspace (see
            ``__inject__``). The default is the registered name.
        iou_aware (bool): if True, each anchor predicts one extra IoU channel.
            These channels come first in each level's output.
        iou_aware_factor (float): exponent weighting the predicted IoU when
            re-scoring objectness at inference time.
        data_format (str): ``'NCHW'`` or ``'NHWC'`` for the output convs.
        in_channels (list[int] | None): channels of each level's input
            feature map. When None, uses ``128 * 2**num_outputs // 2**i``
            for level ``i`` (1024, 512, 256 for three levels).
    """
    __shared__ = ['num_classes', 'data_format']
    __inject__ = ['loss']

    def __init__(self,
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 num_classes=80,
                 loss='YOLOv3Loss',
                 iou_aware=False,
                 iou_aware_factor=0.4,
                 data_format='NCHW',
                 in_channels=None):
        # NOTE: the list defaults above are only read, never mutated
        # (parse_anchor builds fresh lists), so sharing them is safe.
        super(YOLOv3Head, self).__init__()
        self.num_classes = num_classes
        self.loss = loss

        self.iou_aware = iou_aware
        self.iou_aware_factor = iou_aware_factor

        self.parse_anchor(anchors, anchor_masks)
        self.num_outputs = len(self.anchors)
        self.data_format = data_format

        if in_channels is None:
            # Default: channels halve at each successive level, starting
            # from 128 * 2**num_outputs.
            in_channels = [
                128 * (2**self.num_outputs) // (2**i)
                for i in range(self.num_outputs)
            ]
        elif len(in_channels) != self.num_outputs:
            raise ValueError(
                "in_channels has {} entries but there are {} output levels".
                format(len(in_channels), self.num_outputs))
        self.in_channels = list(in_channels)

        # Per anchor: 4 loc + 1 obj + num_classes cls channels, plus one
        # IoU channel when iou_aware is enabled.
        num_attrs = self.num_classes + (6 if self.iou_aware else 5)

        self.yolo_outputs = []
        for i, (level_anchors, ch_in) in enumerate(
                zip(self.anchors, self.in_channels)):
            # Sublayer/parameter names must stay stable for checkpoint loading.
            name = 'yolo_output.{}'.format(i)
            yolo_output = self.add_sublayer(
                name,
                nn.Conv2D(
                    in_channels=ch_in,
                    out_channels=len(level_anchors) * num_attrs,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    data_format=data_format,
                    weight_attr=ParamAttr(name=name + '.conv.weights'),
                    bias_attr=ParamAttr(
                        name=name + '.conv.bias', regularizer=L2Decay(0.))))
            self.yolo_outputs.append(yolo_output)

    def parse_anchor(self, anchors, anchor_masks):
        """Select the anchors used by each output level.

        Sets ``self.anchors`` to, per level, the list of selected anchor
        entries. Sets ``self.mask_anchors`` to, per level, those same anchors
        flattened into a single list.

        Raises:
            AssertionError: if any mask index is outside ``[0, len(anchors))``.
        """
        anchor_num = len(anchors)
        # Validate before indexing, so a bad mask fails with this message
        # rather than a bare IndexError (or, for a negative value, silently
        # wrapping around to the end of the list).
        for masks in anchor_masks:
            for mask in masks:
                assert 0 <= mask < anchor_num, "anchor mask index overflow"
        self.anchors = [[anchors[i] for i in mask] for mask in anchor_masks]
        self.mask_anchors = [[v for i in masks for v in anchors[i]]
                             for masks in anchor_masks]

    def forward(self, feats, targets=None):
        """Apply the per-level output convs.

        Args:
            feats: one feature map per output level, in the same order as
                ``anchor_masks``.
            targets: training targets, forwarded to the loss unchanged.

        Returns:
            In training mode, the result of ``self.loss(outputs, targets,
            self.anchors)``. Otherwise, a list of per-level NCHW tensors,
            re-scored by ``_iou_aware_rescore`` when ``iou_aware`` is set.
        """
        assert len(feats) == len(self.anchors), \
            "expected {} feature maps, got {}".format(
                len(self.anchors), len(feats))
        yolo_outputs = []
        for conv, feat in zip(self.yolo_outputs, feats):
            yolo_output = conv(feat)
            if self.data_format == 'NHWC':
                # Normalize to NCHW: the code below (and presumably the
                # loss) indexes channels on axis 1.
                yolo_output = paddle.transpose(yolo_output, [0, 3, 1, 2])
            yolo_outputs.append(yolo_output)

        if self.training:
            return self.loss(yolo_outputs, targets, self.anchors)
        if not self.iou_aware:
            return yolo_outputs
        return [
            self._iou_aware_rescore(out, len(level_anchors))
            for out, level_anchors in zip(yolo_outputs, self.anchors)
        ]

    def _iou_aware_rescore(self, out, na):
        """Fold one level's IoU prediction into its objectness logit.

        Expected layout of ``out`` (NCHW):

        * The first ``na`` channels are IoU logits.
        * The remaining channels are ``na`` anchor-major groups of
          ``[4 loc, 1 obj, num_classes cls]`` logits.

        The result contains only the anchor groups, with each objectness
        logit replaced by
        ``_de_sigmoid(sigmoid(obj)**(1 - f) * sigmoid(iou)**f)``,
        where ``f = self.iou_aware_factor``. It has the same shape as the
        anchor-group part of ``out``.
        """
        ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
        b, c, h, w = x.shape
        no = c // na
        x = x.reshape((b, na, no, h * w))
        ioup = ioup.reshape((b, na, 1, h * w))
        obj = F.sigmoid(x[:, :, 4:5, :])
        ioup = F.sigmoid(ioup)
        obj_t = (obj**(1 - self.iou_aware_factor)) * (
            ioup**self.iou_aware_factor)
        obj_t = _de_sigmoid(obj_t)
        loc_t = x[:, :, :4, :]
        cls_t = x[:, :, 5:, :]
        y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
        return y_t.reshape((b, c, h, w))