# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function from collections import OrderedDict import paddle.fluid as fluid from ppdet.experimental import mixed_precision_global_state from ppdet.core.workspace import register from ppdet.modeling.ops import SSDOutputDecoder __all__ = ['SSD'] @register class SSD(object): """ Single Shot MultiBox Detector, see https://arxiv.org/abs/1512.02325 Args: backbone (object): backbone instance multi_box_head (object): `MultiBoxHead` instance output_decoder (object): `SSDOutputDecoder` instance num_classes (int): number of output classes """ __category__ = 'architecture' __inject__ = ['backbone', 'multi_box_head', 'output_decoder'] __shared__ = ['num_classes'] def __init__(self, backbone, multi_box_head='MultiBoxHead', output_decoder=SSDOutputDecoder().__dict__, num_classes=21): super(SSD, self).__init__() self.backbone = backbone self.multi_box_head = multi_box_head self.num_classes = num_classes self.output_decoder = output_decoder if isinstance(output_decoder, dict): self.output_decoder = SSDOutputDecoder(**output_decoder) def build(self, feed_vars, mode='train'): im = feed_vars['image'] if mode == 'train' or mode == 'eval': gt_bbox = feed_vars['gt_bbox'] gt_class = feed_vars['gt_class'] mixed_precision_enabled = mixed_precision_global_state() is not None # cast inputs to FP16 if mixed_precision_enabled: im = fluid.layers.cast(im, 'float16') # backbone body_feats = self.backbone(im) if isinstance(body_feats, OrderedDict): body_feat_names = list(body_feats.keys()) body_feats = [body_feats[name] for name in body_feat_names] # cast features back to FP32 if mixed_precision_enabled: body_feats = [fluid.layers.cast(v, 'float32') for v in body_feats] locs, confs, box, box_var = self.multi_box_head( inputs=body_feats, image=im, num_classes=self.num_classes) if mode == 'train': loss = fluid.layers.ssd_loss(locs, confs, gt_bbox, gt_class, box, box_var) loss = fluid.layers.reduce_sum(loss) return {'loss': loss} else: pred = self.output_decoder(locs, confs, box, box_var) return {'bbox': pred} def _inputs_def(self, image_shape): im_shape = [None] + image_shape # yapf: disable inputs_def = { 'image': {'shape': im_shape, 'dtype': 'float32', 'lod_level': 0}, 'im_id': {'shape': [None, 1], 'dtype': 'int64', 'lod_level': 0}, 'gt_bbox': {'shape': [None, 4], 'dtype': 'float32', 'lod_level': 1}, 'gt_class': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, 'im_shape': {'shape': [None, 3], 'dtype': 'int32', 'lod_level': 0}, 'is_difficult': {'shape': [None, 1], 'dtype': 'int32', 'lod_level': 1}, } # yapf: enable return inputs_def def build_inputs( self, image_shape=[3, None, None], fields=['image', 'im_id', 'gt_bbox', 'gt_class'], # for train use_dataloader=True, iterable=False): inputs_def = self._inputs_def(image_shape) feed_vars = OrderedDict([(key, fluid.data( name=key, shape=inputs_def[key]['shape'], dtype=inputs_def[key]['dtype'], lod_level=inputs_def[key]['lod_level'])) for key in fields]) loader = fluid.io.DataLoader.from_generator( feed_list=list(feed_vars.values()), capacity=16, use_double_buffer=True, iterable=iterable) if use_dataloader else None return feed_vars, loader def train(self, feed_vars): return self.build(feed_vars, 'train') def eval(self, feed_vars): return self.build(feed_vars, 'eval') def test(self, feed_vars): return self.build(feed_vars, 'test') def is_bbox_normalized(self): # SSD use output_decoder in output layers, bbox is normalized # to range [0, 1], is_bbox_normalized is used in eval.py and infer.py return True