ssd.py 5.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

19
from collections import OrderedDict
20

21 22 23
import paddle.fluid as fluid

from ppdet.experimental import mixed_precision_global_state
24
from ppdet.core.workspace import register
25
from ppdet.modeling.ops import SSDOutputDecoder
26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42

__all__ = ['SSD']


@register
class SSD(object):
    """
    Single Shot MultiBox Detector, see https://arxiv.org/abs/1512.02325

    Args:
        backbone (object): backbone instance
        multi_box_head (object): `MultiBoxHead` instance
        output_decoder (object): `SSDOutputDecoder` instance, or a dict of
            keyword arguments used to construct one
        num_classes (int): number of output classes
    """

    __category__ = 'architecture'
    __inject__ = ['backbone', 'multi_box_head', 'output_decoder']
    __shared__ = ['num_classes']

    def __init__(self,
                 backbone,
                 multi_box_head='MultiBoxHead',
                 output_decoder=SSDOutputDecoder().__dict__,
                 num_classes=21):
        # NOTE(review): the `output_decoder` default is a dict evaluated once
        # at class-definition time and shared by every instance. It is only
        # read below (unpacked into a fresh SSDOutputDecoder), never mutated,
        # so the sharing is harmless. It is presumably kept as a literal
        # default so the `@register` config machinery can see it -- confirm
        # before replacing it with None.
        super(SSD, self).__init__()
        self.backbone = backbone
        self.multi_box_head = multi_box_head
        self.num_classes = num_classes
        self.output_decoder = output_decoder
        # Accept either a ready decoder object or its constructor kwargs.
        if isinstance(output_decoder, dict):
            self.output_decoder = SSDOutputDecoder(**output_decoder)

    def build(self, feed_vars, mode='train'):
        """
        Build the SSD network graph.

        Args:
            feed_vars (dict): input variables keyed by field name. Must
                contain 'image'; in 'train' mode it must also contain
                'gt_bbox' and 'gt_class'.
            mode (str): 'train' builds the loss; any other value (e.g.
                'eval', 'test') builds the decoded detections.

        Returns:
            dict: {'loss': <summed SSD loss>} in 'train' mode, otherwise
            {'bbox': <output of self.output_decoder>}.
        """
        im = feed_vars['image']

        mixed_precision_enabled = mixed_precision_global_state() is not None
        # cast inputs to FP16
        if mixed_precision_enabled:
            im = fluid.layers.cast(im, 'float16')

        # backbone
        body_feats = self.backbone(im)

        # The head consumes a list; keep the backbone's output order.
        if isinstance(body_feats, OrderedDict):
            body_feats = list(body_feats.values())

        # cast features back to FP32
        if mixed_precision_enabled:
            body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats]

        locs, confs, box, box_var = self.multi_box_head(
            inputs=body_feats, image=im, num_classes=self.num_classes)

        if mode == 'train':
            # Ground truth is consumed only by the loss, so it is read here;
            # eval/test feeds are therefore not required to provide it.
            gt_bbox = feed_vars['gt_bbox']
            gt_class = feed_vars['gt_class']
            loss = fluid.layers.ssd_loss(locs, confs, gt_bbox, gt_class, box,
                                         box_var)
            loss = fluid.layers.reduce_sum(loss)
            return {'loss': loss}

        pred = self.output_decoder(locs, confs, box, box_var)
        return {'bbox': pred}

    def _inputs_def(self, image_shape):
        """
        Describe every feed variable this architecture knows about.

        Args:
            image_shape (list): per-sample image shape; a leading batch
                dimension of None is prepended for the 'image' entry.

        Returns:
            dict: field name -> {'shape', 'dtype', 'lod_level'}.
        """
        im_shape = [None] + image_shape
        # yapf: disable
        inputs_def = {
            'image':        {'shape': im_shape,  'dtype': 'float32', 'lod_level': 0},
            'im_id':        {'shape': [None, 1], 'dtype': 'int64',   'lod_level': 0},
            'gt_bbox':      {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1},
            'gt_class':     {'shape': [None, 1], 'dtype': 'int32',   'lod_level': 1},
            'im_shape':     {'shape': [None, 3], 'dtype': 'int32',   'lod_level': 0},
            'is_difficult': {'shape': [None, 1], 'dtype': 'int32',   'lod_level': 1},
        }
        # yapf: enable
        return inputs_def

    def build_inputs(
            self,
            image_shape=[3, None, None],
            fields=['image', 'im_id', 'gt_bbox', 'gt_class'],  # for train
            use_dataloader=True,
            iterable=False):
        """
        Create the `fluid.data` feed variables and, optionally, a DataLoader.

        The list defaults are only read (never mutated), so sharing them
        across calls is safe.

        Args:
            image_shape (list): per-sample image shape.
            fields (list): names of the feed variables to create, in order;
                each must be a key returned by `_inputs_def`.
            use_dataloader (bool): whether to build a DataLoader over the
                created variables.
            iterable (bool): forwarded to `DataLoader.from_generator`.

        Returns:
            tuple: (OrderedDict of feed variables keyed by field name,
            DataLoader or None when `use_dataloader` is False).
        """
        inputs_def = self._inputs_def(image_shape)
        feed_vars = OrderedDict([(key, fluid.data(
            name=key,
            shape=inputs_def[key]['shape'],
            dtype=inputs_def[key]['dtype'],
            lod_level=inputs_def[key]['lod_level'])) for key in fields])
        loader = fluid.io.DataLoader.from_generator(
            feed_list=list(feed_vars.values()),
            capacity=64,
            use_double_buffer=True,
            iterable=iterable) if use_dataloader else None
        return feed_vars, loader

    def train(self, feed_vars):
        """Build the training graph; returns {'loss': ...}."""
        return self.build(feed_vars, 'train')

    def eval(self, feed_vars):
        """Build the evaluation graph; returns {'bbox': ...}."""
        return self.build(feed_vars, 'eval')

    def test(self, feed_vars):
        """Build the inference graph; returns {'bbox': ...}."""
        return self.build(feed_vars, 'test')

    def is_bbox_normalized(self):
        # SSD uses output_decoder in its output layers, so predicted bboxes
        # are normalized to the range [0, 1]; is_bbox_normalized is used in
        # eval.py and infer.py
        return True