# yolo.py: YOLOv3 architecture registered with the ppdet workspace.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from ppdet.core.workspace import register
from .meta_arch import BaseArch

__all__ = ['YOLOv3']


@register
class YOLOv3(BaseArch):
    """YOLOv3 architecture built from an anchor, a backbone and a YOLO head.

    The three components listed in ``__inject__`` are passed to the
    constructor and stored on the instance. All intermediate and final
    results are exchanged through ``self.gbd``, a dict-like object that this
    class reads and updates but does not create (presumably provided by
    ``BaseArch`` -- confirm against the base class).
    """

    __category__ = 'architecture'
    __inject__ = [
        'anchor',
        'backbone',
        'yolo_head',
    ]

    def __init__(self, anchor, backbone, yolo_head, *args, **kwargs):
        """Store the injected components.

        Args:
            anchor: Callable component that also exposes ``post_process``.
            backbone: Callable component run first in ``model_arch``.
            yolo_head: Callable component that also exposes ``loss``.
            *args: Forwarded unchanged to ``BaseArch.__init__``.
            **kwargs: Forwarded unchanged to ``BaseArch.__init__``.
        """
        super(YOLOv3, self).__init__(*args, **kwargs)
        self.anchor = anchor
        self.backbone = backbone
        self.yolo_head = yolo_head

    def model_arch(self):
        """Run backbone, YOLO head and anchor in order on ``self.gbd``.

        Each component is called with ``self.gbd`` and its result is merged
        back into ``self.gbd`` before the next component runs, so the order
        of the calls below matters. When ``self.gbd['mode']`` is ``'infer'``
        the anchor's ``post_process`` result is merged in as well.
        """
        # Backbone
        self.gbd.update(self.backbone(self.gbd))

        # YOLO Head
        self.gbd.update(self.yolo_head(self.gbd))

        # Anchor
        self.gbd.update(self.anchor(self.gbd))

        # Post-processing only happens in inference mode.
        if self.gbd['mode'] == 'infer':
            self.gbd.update(self.anchor.post_process(self.gbd))

    def loss(self):
        """Return ``{'loss': ...}`` with the value of ``yolo_head.loss``."""
        return {'loss': self.yolo_head.loss(self.gbd)}

    def infer(self):
        """Collect the inference outputs from ``self.gbd``.

        Returns:
            dict: Keys ``'bbox'``, ``'bbox_nums'`` and ``'im_id'``.
        """
        # NOTE(review): 'bbox_nums' is returned as stored, without the
        # ``.numpy()`` conversion applied to the other two entries --
        # confirm this is intentional.
        return {
            'bbox': self.gbd['predicted_bbox'].numpy(),
            'bbox_nums': self.gbd['predicted_bbox_nums'],
            'im_id': self.gbd['im_id'].numpy(),
        }