# faster_rcnn.py (3.3 KB) -- recovered from a repository blame view.
# Committed by FDInSky.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from ppdet.core.workspace import register
from .meta_arch import BaseArch

__all__ = ['FasterRCNN']


@register
class FasterRCNN(BaseArch):
    """Two-stage Faster R-CNN detector assembled from injected sub-modules.

    ``model_arch`` runs backbone -> optional neck -> RPN head -> anchors ->
    proposal RoIs -> bbox head, caching intermediate outputs on ``self``.
    In ``'infer'`` mode it also decodes and post-processes the final boxes
    into ``self.bboxes``, which ``get_pred`` returns.

    Args:
        anchor: anchor module; called on the RPN features and also used to
            build the RPN loss inputs (``generate_loss_inputs``).
        proposal: proposal module; produces RoIs from the RPN outputs and
            exposes the bbox-head targets through ``get_targets()``.
        backbone: feature extractor applied to ``self.inputs``.
        rpn_head: region proposal head; provides ``get_loss``.
        bbox_head: second-stage head; provides ``get_prediction`` and
            ``get_loss``.
        bbox_post_process: refines the decoded boxes at inference time.
        neck: optional neck; when set it returns both the new features and
            the spatial scale(s) passed to the bbox head.
    """
    __category__ = 'architecture'
    __inject__ = [
        'anchor', 'proposal', 'backbone', 'neck', 'rpn_head', 'bbox_head',
        'bbox_post_process'
    ]

    def __init__(self,
                 anchor,
                 proposal,
                 backbone,
                 rpn_head,
                 bbox_head,
                 bbox_post_process,
                 neck=None):
        super(FasterRCNN, self).__init__()
        self.anchor = anchor
        self.proposal = proposal
        self.backbone = backbone
        self.rpn_head = rpn_head
        self.bbox_head = bbox_head
        self.bbox_post_process = bbox_post_process
        self.neck = neck

    def model_arch(self):
        """Run the forward pass, caching intermediate outputs on ``self``.

        Reads ``self.inputs`` (presumably populated by ``BaseArch`` before
        this is called -- confirm) and sets ``rpn_head_out``, ``anchor_out``,
        ``bbox_head_out`` and ``bbox_head_feat_func`` for later use by
        ``get_loss``. When ``self.inputs['mode'] == 'infer'`` it also sets
        ``self.bboxes`` for ``get_pred``.
        """
        # Backbone
        body_feats = self.backbone(self.inputs)
        # 0.0625 == 1 / 16, the scale used when no neck is configured.
        # NOTE(review): presumably the stride of the single backbone output
        # feature map -- confirm against the backbone config.
        spatial_scale = 0.0625

        # Neck (optional): replaces both the features and their scale(s).
        if self.neck is not None:
            body_feats, spatial_scale = self.neck(body_feats)

        # RPN
        # rpn_head returns two lists: rpn_feat and rpn_head_out.
        # Each element of rpn_feat is the RPN feature of one level (the list
        # has length 1 when no neck is applied); each element of
        # rpn_head_out is (rpn_rois_score, rpn_rois_delta).
        rpn_feat, self.rpn_head_out = self.rpn_head(self.inputs, body_feats)

        # Anchor: a list whose elements are (anchor, anchor_var).
        self.anchor_out = self.anchor(rpn_feat)

        # Proposal RoIs; the proposal module also computes the bbox-head
        # training targets when training (read back in get_loss).
        rois = self.proposal(self.inputs, self.rpn_head_out, self.anchor_out)

        # BBox head. Its first output (bbox_feat) is not used by this arch.
        _, self.bbox_head_out, self.bbox_head_feat_func = self.bbox_head(
            body_feats, rois, spatial_scale)

        if self.inputs['mode'] == 'infer':
            bbox_pred, bboxes = self.bbox_head.get_prediction(
                self.bbox_head_out, rois)
            # Refine bbox by the output from bbox_head at test stage.
            self.bboxes = self.bbox_post_process(bbox_pred, bboxes,
                                                 self.inputs['im_shape'],
                                                 self.inputs['scale_factor'])

    def get_loss(self):
        """Compute the RPN and bbox-head losses.

        Must be called after ``model_arch``, whose cached outputs it reads;
        presumably only meaningful in training mode, when the proposal
        module has computed targets.

        Returns:
            dict: every individual loss term returned by ``rpn_head.get_loss``
            and ``bbox_head.get_loss``, plus ``'loss'``, their sum computed
            with ``paddle.add_n``.
        """
        loss = {}

        # RPN loss
        rpn_loss_inputs = self.anchor.generate_loss_inputs(
            self.inputs, self.rpn_head_out, self.anchor_out)
        loss.update(self.rpn_head.get_loss(rpn_loss_inputs))

        # BBox loss, against the targets produced by the proposal module.
        bbox_targets = self.proposal.get_targets()
        loss.update(self.bbox_head.get_loss(self.bbox_head_out, bbox_targets))

        total_loss = paddle.add_n(list(loss.values()))
        loss.update({'loss': total_loss})
        return loss

    def get_pred(self, return_numpy=True):
        """Collect the inference outputs produced by ``model_arch``.

        Args:
            return_numpy (bool): if True (default), convert each output with
                ``.numpy()``; otherwise return the framework tensors as-is.

        Returns:
            dict: with keys ``'bbox'``, ``'bbox_num'`` (from the
            post-processed ``self.bboxes``) and ``'im_id'`` (from
            ``self.inputs``).
        """
        bbox, bbox_num = self.bboxes
        output = {
            'bbox': bbox,
            'bbox_num': bbox_num,
            'im_id': self.inputs['im_id'],
        }
        if return_numpy:
            output = {key: value.numpy() for key, value in output.items()}
        return output