# cascade_rcnn.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from paddle import fluid
from ppdet.core.workspace import register
from .meta_arch import BaseArch

__all__ = ['CascadeRCNN']


@register
class CascadeRCNN(BaseArch):
    """Multi-stage detection architecture with a mask branch.

    The model is assembled from injected components (anchor, proposal,
    mask, backbone, rpn_head, bbox_head, mask_head). Every component is
    called with the shared ``self.gbd`` mapping and its outputs are merged
    back into that mapping.

    NOTE(review): ``self.gbd`` is not created in this class; it is
    presumably provided by ``BaseArch`` -- confirm against meta_arch.
    """

    __category__ = 'architecture'
    __shared__ = ['num_stages']
    __inject__ = [
        'anchor',
        'proposal',
        'mask',
        'backbone',
        'rpn_head',
        'bbox_head',
        'mask_head',
    ]

    def __init__(self,
                 anchor,
                 proposal,
                 mask,
                 backbone,
                 rpn_head,
                 bbox_head,
                 mask_head,
                 num_stages=3,
                 *args,
                 **kwargs):
        """Store the injected components and the number of cascade stages.

        Args:
            anchor: component called as ``anchor(gbd)`` after the RPN head.
            proposal: component called as ``proposal(gbd)`` once per stage;
                also provides ``refine_bbox`` and ``post_process``.
            mask: component called as ``mask(gbd)``; also provides
                ``post_process``.
            backbone: component called as ``backbone(gbd)``.
            rpn_head: component called as ``rpn_head(gbd)``; also provides
                ``get_loss``.
            bbox_head: component called as ``bbox_head(gbd)`` once per
                stage; also provides ``get_loss``.
            mask_head: component called as ``mask_head(gbd)``; also
                provides ``get_loss``.
            num_stages: number of proposal / bbox-head iterations.
            *args: forwarded to ``BaseArch.__init__``.
            **kwargs: forwarded to ``BaseArch.__init__``.
        """
        super(CascadeRCNN, self).__init__(*args, **kwargs)
        self.anchor = anchor
        self.proposal = proposal
        self.mask = mask
        self.backbone = backbone
        self.rpn_head = rpn_head
        self.bbox_head = bbox_head
        self.mask_head = mask_head
        self.num_stages = num_stages

    def model_arch(self):
        """Run all components in order, accumulating outputs in ``self.gbd``.

        Order: backbone, RPN head, anchor, then ``num_stages`` rounds of
        proposal / bbox head / box refinement, then mask and mask head.
        The ``post_process`` steps run only when ``self.gbd['mode']`` is
        ``'infer'``.
        """
        # Backbone
        bb_out = self.backbone(self.gbd)
        self.gbd.update(bb_out)

        # RPN
        rpn_head_out = self.rpn_head(self.gbd)
        self.gbd.update(rpn_head_out)

        # Anchor
        anchor_out = self.anchor(self.gbd)
        self.gbd.update(anchor_out)

        self.gbd['stage'] = 0
        for i in range(self.num_stages):
            # Publish the current stage index before calling the per-stage
            # components (presumably they read it from gbd -- confirm).
            self.gbd.update_v('stage', i)

            # Proposal BBox: stored under a per-stage key.
            proposal_out = self.proposal(self.gbd)
            self.gbd.update({"proposal_" + str(i): proposal_out})

            # BBox Head: stored under a per-stage key.
            bbox_head_out = self.bbox_head(self.gbd)
            self.gbd.update({'bbox_head_' + str(i): bbox_head_out})

            # Refined boxes are merged into this stage's proposal entry.
            refine_bbox_out = self.proposal.refine_bbox(self.gbd)
            self.gbd['proposal_' + str(i)].update(refine_bbox_out)

        if self.gbd['mode'] == 'infer':
            bbox_out = self.proposal.post_process(self.gbd)
            self.gbd.update(bbox_out)

        # Mask
        mask_out = self.mask(self.gbd)
        self.gbd.update(mask_out)

        # Mask Head
        mask_head_out = self.mask_head(self.gbd)
        self.gbd.update(mask_head_out)

        if self.gbd['mode'] == 'infer':
            mask_out = self.mask.post_process(self.gbd)
            self.gbd.update(mask_out)

    def get_loss(self):
        """Collect the individual losses and their sum.

        Returns:
            dict with keys ``loss_rpn_cls``, ``loss_rpn_reg``,
            ``loss_bbox_cls_<i>`` and ``loss_bbox_reg_<i>`` for each stage
            ``i``, ``mask_loss``, and ``loss`` (the result of
            ``fluid.layers.sum`` over all of the above).
        """
        outs = {}
        losses = []

        rpn_cls_loss, rpn_reg_loss = self.rpn_head.get_loss(self.gbd)
        outs['loss_rpn_cls'] = rpn_cls_loss
        outs['loss_rpn_reg'] = rpn_reg_loss
        losses.extend([rpn_cls_loss, rpn_reg_loss])

        bbox_cls_loss_list = []
        bbox_reg_loss_list = []
        for i in range(self.num_stages):
            # Select the stage before asking the bbox head for its loss.
            self.gbd.update_v('stage', i)
            bbox_cls_loss, bbox_reg_loss = self.bbox_head.get_loss(self.gbd)
            bbox_cls_loss_list.append(bbox_cls_loss)
            bbox_reg_loss_list.append(bbox_reg_loss)
            outs['loss_bbox_cls_' + str(i)] = bbox_cls_loss
            outs['loss_bbox_reg_' + str(i)] = bbox_reg_loss
        losses.extend(bbox_cls_loss_list)
        losses.extend(bbox_reg_loss_list)

        mask_loss = self.mask_head.get_loss(self.gbd)
        outs['mask_loss'] = mask_loss
        losses.append(mask_loss)

        loss = fluid.layers.sum(losses)
        outs['loss'] = loss
        return outs

    def get_pred(self):
        """Return the predictions stored in ``self.gbd``.

        Returns:
            dict with keys ``bbox``, ``bbox_nums``, ``mask`` and ``im_id``,
            each the ``.numpy()`` conversion of the corresponding
            ``predicted_bbox`` / ``predicted_bbox_nums`` /
            ``predicted_mask`` / ``im_id`` entry of ``self.gbd``.
        """
        outs = {
            'bbox': self.gbd['predicted_bbox'].numpy(),
            'bbox_nums': self.gbd['predicted_bbox_nums'].numpy(),
            'mask': self.gbd['predicted_mask'].numpy(),
            'im_id': self.gbd['im_id'].numpy(),
        }
        # Return the dict built above; `inputs` is not defined in this scope.
        return outs