# yolo.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register
from .meta_arch import BaseArch

__all__ = ['YOLOv3']


@register
class YOLOv3(BaseArch):
    """YOLOv3 detection architecture.

    Wires together an anchor generator, a backbone, a YOLO head and a
    post-processor. ``model_arch`` runs the shared forward pass and caches
    its results on ``self``. ``loss`` (training) and ``infer`` (inference)
    then read those cached attributes, so ``model_arch`` must run before
    either of them.

    Args:
        anchor: callable taking no arguments and returning
            ``(anchors, anchor_masks, mask_anchors)``.
        backbone: callable mapping ``self.inputs`` to backbone features.
        yolo_head: callable mapping backbone features to head outputs; it
            must also provide a ``loss(inputs, head_out, anchors,
            anchor_masks, mask_anchors)`` method.
        post_process: callable mapping ``(head_out, mask_anchors, im_size)``
            to ``(bbox, bbox_num)``.
    """

    __category__ = 'architecture'
    # Constructor arguments listed here are presumably built from config and
    # injected by the ppdet workspace registry (see `register`) — confirm
    # against ppdet.core.workspace.
    __inject__ = [
        'anchor',
        'backbone',
        'yolo_head',
        'post_process',
    ]

    def __init__(self, anchor, backbone, yolo_head, post_process):
        # Keep the two-argument super() form for consistency with the
        # file's py2-compatible `__future__` imports.
        super(YOLOv3, self).__init__()
        self.anchor = anchor
        self.backbone = backbone
        self.yolo_head = yolo_head
        self.post_process = post_process

    def model_arch(self):
        """Run the shared forward pass and cache its results on ``self``.

        Sets ``self.yolo_head_out``, ``self.anchors``, ``self.anchor_masks``
        and ``self.mask_anchors``, which ``loss`` and ``infer`` read.

        NOTE(review): ``self.inputs`` is not assigned in this class; it is
        presumably populated by ``BaseArch`` before this is called — confirm.
        """
        # Backbone
        body_feats = self.backbone(self.inputs)

        # YOLO Head
        self.yolo_head_out = self.yolo_head(body_feats)

        # Anchor
        self.anchors, self.anchor_masks, self.mask_anchors = self.anchor()

    def loss(self):
        """Return the YOLO head loss for the cached forward-pass outputs.

        Requires ``model_arch`` to have run first.
        """
        yolo_loss = self.yolo_head.loss(self.inputs, self.yolo_head_out,
                                        self.anchors, self.anchor_masks,
                                        self.mask_anchors)
        return yolo_loss

    def infer(self):
        """Post-process the cached head outputs into detections.

        Requires ``model_arch`` to have run first.

        Returns:
            dict with keys ``"bbox"``, ``"bbox_num"`` and ``'im_id'``. Each
            value is the result of calling ``.numpy()`` on the corresponding
            output (presumably Paddle tensors converted to NumPy arrays).
        """
        bbox, bbox_num = self.post_process(
            self.yolo_head_out, self.mask_anchors, self.inputs['im_size'])
        outs = {
            "bbox": bbox.numpy(),
            "bbox_num": bbox_num.numpy(),
            'im_id': self.inputs['im_id'].numpy()
        }
        return outs