# yolo.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register
from .meta_arch import BaseArch

# Names exported by `from <this module> import *`.
__all__ = ['YOLOv3']


@register
class YOLOv3(BaseArch):
    """YOLOv3 architecture wired together from four pluggable components.

    The forward pass chains ``backbone`` -> ``neck`` -> ``yolo_head``.
    ``post_process`` is only invoked by :meth:`infer`.
    """

    __category__ = 'architecture'
    # NOTE(review): the constructor arguments named here are presumably
    # resolved by the ppdet workspace from registered names to component
    # instances -- confirm against ppdet.core.workspace.
    __inject__ = [
        'backbone',
        'neck',
        'yolo_head',
        'post_process',
    ]

    def __init__(self,
                 backbone='DarkNet',
                 neck='YOLOv3FPN',
                 yolo_head='YOLOv3Head',
                 post_process='BBoxPostProcess'):
        """Store the four components on the instance.

        Args:
            backbone: Component called with ``self.inputs`` in
                :meth:`model_arch`. Defaults to the name ``'DarkNet'``.
            neck: Component called with the backbone output. Defaults to
                the name ``'YOLOv3FPN'``.
            yolo_head: Component called with the neck output; it must also
                provide ``loss(...)`` and ``mask_anchors``. Defaults to the
                name ``'YOLOv3Head'``.
            post_process: Callable used by :meth:`infer`. Defaults to the
                name ``'BBoxPostProcess'``.

        The string defaults are not callable themselves; presumably they
        are replaced with instances before use (see ``__inject__``) --
        TODO confirm.
        """
        super(YOLOv3, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.yolo_head = yolo_head
        self.post_process = post_process

    def model_arch(self):
        """Run backbone, neck and head, caching the head outputs.

        Reads ``self.inputs``, which is not assigned in this class
        (presumably set by ``BaseArch`` -- confirm). The head result is
        stored in ``self.yolo_head_outs`` for :meth:`loss` and
        :meth:`infer`; nothing is returned.
        """
        # Backbone
        body_feats = self.backbone(self.inputs)

        # Neck
        body_feats = self.neck(body_feats)

        # YOLO head
        self.yolo_head_outs = self.yolo_head(body_feats)

    def loss(self):
        """Return ``yolo_head.loss`` computed on the cached head outputs.

        Requires :meth:`model_arch` to have run first, since it reads
        ``self.yolo_head_outs``.
        """
        yolo_loss = self.yolo_head.loss(self.inputs, self.yolo_head_outs)
        return yolo_loss

    def infer(self):
        """Post-process the cached head outputs into a result dict.

        Requires :meth:`model_arch` to have run first, since it reads
        ``self.yolo_head_outs``.

        Returns:
            dict: keys ``'bbox'``, ``'bbox_num'`` and ``'im_id'``, each
            holding the ``.numpy()`` conversion of the corresponding value.
        """
        bbox, bbox_num = self.post_process(self.yolo_head_outs,
                                           self.yolo_head.mask_anchors,
                                           self.inputs['im_size'])
        outs = {
            'bbox': bbox.numpy(),
            'bbox_num': bbox_num.numpy(),
            'im_id': self.inputs['im_id'].numpy(),
        }
        return outs