yolo_head.py 3.7 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from ppdet.core.workspace import register
7
from ..backbones.darknet import ConvBNLayer
Q
qingqing01 已提交
8 9


W
wangxinxin08 已提交
10 11 12 13 14 15 16
def _de_sigmoid(x, eps=1e-7):
    """Map sigmoid outputs back to pre-activation values (clipped logit).

    Computes ``-log(1 / x - 1)``. Both ``x`` and the intermediate
    ``1 / x - 1`` are clipped to ``[eps, 1 / eps]`` so that neither the
    division nor the logarithm sees zero.

    Args:
        x: value(s) accepted by ``paddle.clip``.
        eps: lower clipping bound; its reciprocal is the upper bound.

    Returns:
        The result of ``-paddle.log`` applied to the clipped ``1 / x - 1``.
    """
    upper = 1. / eps
    prob = paddle.clip(x, eps, upper)
    inv_odds = paddle.clip(1. / prob - 1., eps, upper)
    return -paddle.log(inv_odds)


Q
qingqing01 已提交
17 18 19 20 21 22 23 24 25 26
@register
class YOLOv3Head(nn.Layer):
    """YOLOv3 prediction head: one 1x1 convolution per input feature level.

    Per anchor, each output convolution produces ``num_classes + 5`` channels
    (4 box terms, 1 objectness, then the class scores), or
    ``num_classes + 6`` when ``iou_aware`` is enabled. In the IoU-aware
    layout the extra per-anchor IoU channels come first, followed by the
    regular per-anchor predictions (this is the layout ``get_outputs``
    decodes).

    Args:
        anchors (list): anchor entries; each default entry is a two-element
            list (presumably ``[width, height]`` -- confirm with the loss).
        anchor_masks (list): for every output level, the indices into
            ``anchors`` that this level uses.
        num_classes (int): number of class channels per anchor.
        loss: loss object, injected via ``__inject__``; called as
            ``loss(inputs, targets, anchors)`` in ``get_loss``.
        iou_aware (bool): add one IoU channel per anchor and fuse it into
            the objectness in ``get_outputs``.
        iou_aware_factor (float): exponent given to the IoU prediction when
            fusing; the objectness gets ``1 - iou_aware_factor``.
    """
    __shared__ = ['num_classes']
    __inject__ = ['loss']

    def __init__(self,
                 anchors=[[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
                          [59, 119], [116, 90], [156, 198], [373, 326]],
                 anchor_masks=[[6, 7, 8], [3, 4, 5], [0, 1, 2]],
                 num_classes=80,
                 loss='YOLOv3Loss',
                 iou_aware=False,
                 iou_aware_factor=0.4):
        super(YOLOv3Head, self).__init__()
        self.num_classes = num_classes
        self.loss = loss

        self.iou_aware = iou_aware
        self.iou_aware_factor = iou_aware_factor

        self.parse_anchor(anchors, anchor_masks)
        # Number of output levels (one per anchor mask).
        self.num_outputs = len(self.anchors)

        # 4 box terms + 1 objectness + classes, plus 1 IoU term if enabled.
        channels_per_anchor = self.num_classes + (6 if self.iou_aware else 5)

        self.yolo_outputs = []
        for i, level_anchors in enumerate(self.anchors):
            # Size the conv by the anchors of THIS level, not by the number
            # of levels: get_outputs splits the channels with
            # len(self.anchors[i]), and the two only coincide when every
            # level has as many anchors as there are levels.
            num_filters = len(level_anchors) * channels_per_anchor
            name = 'yolo_output.{}'.format(i)
            yolo_output = self.add_sublayer(
                name,
                nn.Conv2D(
                    # Input width is fixed to 1024, 512, 256, ... per level.
                    in_channels=1024 // (2**i),
                    out_channels=num_filters,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    weight_attr=ParamAttr(name=name + '.conv.weights'),
                    bias_attr=ParamAttr(
                        name=name + '.conv.bias', regularizer=L2Decay(0.))))
            self.yolo_outputs.append(yolo_output)

    def parse_anchor(self, anchors, anchor_masks):
        """Group ``anchors`` per output level according to ``anchor_masks``.

        Sets:
            self.anchors: per level, the list of selected anchor entries.
            self.mask_anchors: per level, the selected anchor entries
                flattened into a single list.

        Raises:
            AssertionError: if a mask index is not smaller than
                ``len(anchors)``.
        """
        anchor_num = len(anchors)
        # Check every index before using any of them, so a bad mask reports
        # the descriptive message instead of a bare IndexError.
        for masks in anchor_masks:
            for mask in masks:
                assert mask < anchor_num, "anchor mask index overflow"

        self.anchors = [[anchors[i] for i in mask] for mask in anchor_masks]
        self.mask_anchors = []
        for masks in anchor_masks:
            flattened = []
            for mask in masks:
                flattened.extend(anchors[mask])
            self.mask_anchors.append(flattened)

    def forward(self, feats):
        """Apply the i-th output convolution to the i-th feature map.

        Args:
            feats: sequence of feature maps, one per output level.

        Returns:
            list: raw head outputs in the same order as ``feats``.
        """
        assert len(feats) == len(self.anchors)
        return [conv(feat) for conv, feat in zip(self.yolo_outputs, feats)]

    def get_loss(self, inputs, targets):
        """Delegate to the injected loss with the per-level anchors."""
        return self.loss(inputs, targets, self.anchors)

    def get_outputs(self, outputs):
        """Return head outputs with the IoU prediction fused into objectness.

        Without ``iou_aware`` the outputs are returned untouched. Otherwise,
        for each level, the leading ``na`` IoU channels are removed and the
        objectness logit is replaced by the logit of
        ``sigmoid(obj) ** (1 - f) * sigmoid(iou) ** f`` with
        ``f = iou_aware_factor``; box and class channels are passed through.

        Args:
            outputs: sequence of 4-D head outputs, one per level.

        Returns:
            The input object itself when ``iou_aware`` is false, else a new
            list of tensors with ``na`` fewer channels per level.
        """
        if not self.iou_aware:
            return outputs

        y = []
        for i, out in enumerate(outputs):
            na = len(self.anchors[i])
            # First na channels: IoU logits; the rest: regular predictions.
            ioup, x = out[:, 0:na, :, :], out[:, na:, :, :]
            b, c, h, w = x.shape
            no = c // na  # channels per anchor
            x = x.reshape((b, na, no, h * w))
            ioup = ioup.reshape((b, na, 1, h * w))
            obj = x[:, :, 4:5, :]
            ioup = F.sigmoid(ioup)
            obj = F.sigmoid(obj)
            obj_t = (obj**(1 - self.iou_aware_factor)) * (
                ioup**self.iou_aware_factor)
            # Back to logit space so the result matches the other channels.
            obj_t = _de_sigmoid(obj_t)
            loc_t = x[:, :, :4, :]
            cls_t = x[:, :, 5:, :]
            y_t = paddle.concat([loc_t, obj_t, cls_t], axis=2)
            y_t = y_t.reshape((b, c, h, w))
            y.append(y_t)
        return y