yolo.py 1.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register
from .meta_arch import BaseArch

__all__ = ['YOLOv3']


@register
class YOLOv3(BaseArch):
    """YOLOv3 detection architecture.

    Composes three components listed in ``__inject__``: an anchor
    component, a backbone and a YOLO head. The computation is split into
    ``model_arch`` (feature extraction, head outputs and anchor data, all
    cached on ``self``), ``loss`` (training objective) and ``infer``
    (post-processed detections).

    Args:
        anchor: callable that, invoked with no arguments, returns
            ``(anchors, anchor_masks, mask_anchors)``; also provides a
            ``post_process`` method used by ``infer``.
        backbone: callable applied to ``self.inputs`` to produce features.
        yolo_head: callable applied to the backbone features; also provides
            a ``loss`` method used by ``loss``.
    """
    __category__ = 'architecture'
    # Constructor arguments presumably supplied by ppdet.core.workspace
    # from config (the class is @register-ed) -- confirm.
    __inject__ = [
        'anchor',
        'backbone',
        'yolo_head',
    ]

    def __init__(self, anchor, backbone, yolo_head):
        super(YOLOv3, self).__init__()
        self.anchor = anchor
        self.backbone = backbone
        self.yolo_head = yolo_head

    def model_arch(self):
        """Run backbone and head, and cache anchor data on ``self``.

        Reads ``self.inputs`` (presumably populated by ``BaseArch`` before
        this method is called -- confirm). Stores ``yolo_head_out``,
        ``anchors``, ``anchor_masks`` and ``mask_anchors`` as attributes;
        ``loss`` and ``infer`` both read them, so this must run first.
        """
        # Backbone
        body_feats = self.backbone(self.inputs)

        # YOLO Head
        self.yolo_head_out = self.yolo_head(body_feats)

        # Anchor: called without arguments, so it does not depend on inputs.
        self.anchors, self.anchor_masks, self.mask_anchors = self.anchor()

    def loss(self):
        """Compute the YOLO training loss.

        Requires ``model_arch`` to have been called first, since it uses the
        head outputs and anchor data cached there.

        Returns:
            Whatever ``self.yolo_head.loss`` returns.
        """
        yolo_loss = self.yolo_head.loss(self.inputs, self.yolo_head_out,
                                        self.anchors, self.anchor_masks,
                                        self.mask_anchors)
        return yolo_loss

    def infer(self):
        """Post-process head outputs into final detections.

        Requires ``model_arch`` to have been called first.

        Returns:
            dict with keys ``"bbox"`` (``bbox.numpy()``), ``"bbox_num"``
            (as returned by ``self.anchor.post_process``) and ``'im_id'``
            (``self.inputs['im_id'].numpy()``).
        """
        bbox, bbox_num = self.anchor.post_process(
            self.inputs['im_size'], self.yolo_head_out, self.mask_anchors)
        outs = {
            "bbox": bbox.numpy(),
            # NOTE(review): unlike bbox and im_id, bbox_num is returned
            # without .numpy(); confirm whether post_process already yields
            # a host value here or whether the conversion is missing.
            "bbox_num": bbox_num,
            'im_id': self.inputs['im_id'].numpy()
        }
        return outs