diff --git a/configs/_base_/datasets/voc.yml b/configs/_base_/datasets/voc.yml new file mode 100644 index 0000000000000000000000000000000000000000..de4d78eda57792f857eac95141cd28b5c34a6175 --- /dev/null +++ b/configs/_base_/datasets/voc.yml @@ -0,0 +1,18 @@ +metric: VOC +num_classes: 20 + +TrainDataset: + !VOCDataSet + dataset_dir: dataset/voc + anno_path: trainval.txt + label_list: label_list.txt + +EvalDataset: + !VOCDataSet + dataset_dir: dataset/voc + anno_path: test.txt + label_list: label_list.txt + +TestDataset: + !ImageFolder + anno_path: dataset/voc/label_list.txt diff --git a/configs/_base_/models/ssd_vgg16_300.yml b/configs/_base_/models/ssd_vgg16_300.yml new file mode 100644 index 0000000000000000000000000000000000000000..497292b5ab37b06e7d9331327447e77f26f87080 --- /dev/null +++ b/configs/_base_/models/ssd_vgg16_300.yml @@ -0,0 +1,41 @@ +architecture: SSD +pretrain_weights: https://paddlemodels.bj.bcebos.com/object_detection/dygraph/VGG16_caffe_pretrained.pdparams +weights: output/ssd_vgg16/model_final + +# Model Achitecture +SSD: + # model feat info flow + backbone: VGG + ssd_head: SSDHead + # post process + post_process: BBoxPostProcess + +VGG: + depth: 16 + normalizations: [20., -1, -1, -1, -1, -1] + +SSDHead: + in_channels: [512, 1024, 512, 256, 256, 256] + anchor_generator: AnchorGeneratorSSD + +AnchorGeneratorSSD: + steps: [8, 16, 32, 64, 100, 300] + aspect_ratios: [[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]] + min_ratio: 20 + max_ratio: 90 + min_sizes: [30.0, 60.0, 111.0, 162.0, 213.0, 264.0] + max_sizes: [60.0, 111.0, 162.0, 213.0, 264.0, 315.0] + offset: 0.5 + flip: true + min_max_aspect_ratios_order: true + +BBoxPostProcess: + decode: + name: SSDBox + nms: + name: MultiClassNMS + keep_top_k: 200 + score_threshold: 0.01 + nms_threshold: 0.45 + nms_top_k: 400 + nms_eta: 1.0 diff --git a/configs/_base_/readers/ssd_reader.yml b/configs/_base_/readers/ssd_reader.yml new file mode 100644 index 
0000000000000000000000000000000000000000..f493c402214de60302b438628db2af5cd2ba83e0 --- /dev/null +++ b/configs/_base_/readers/ssd_reader.yml @@ -0,0 +1,46 @@ +worker_num: 2 +TrainReader: + inputs_def: + fields: ['image', 'gt_bbox', 'gt_class'] + num_max_boxes: 90 + + sample_transforms: + - DecodeOp: {} + - RandomDistortOp: {brightness: [0.5, 1.125, 0.875], random_apply: False} + - RandomExpandOp: {fill_value: [104., 117., 123.]} + - RandomCropOp: {allow_no_crop: true} + - RandomFlipOp: {} + - NormalizeBoxOp: {} + - ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1} + - PadBoxOp: {num_max_boxes: 90} + + batch_transforms: + - NormalizeImageOp: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false} + - PermuteOp: {} + + batch_size: 8 + shuffle: true + drop_last: true + + +EvalReader: + inputs_def: + fields: ['image', 'im_shape', 'scale_factor', 'im_id', 'gt_bbox', 'gt_class', 'difficult'] + sample_transforms: + - DecodeOp: {} + - ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1} + - NormalizeImageOp: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false} + - PermuteOp: {} + batch_size: 1 + drop_empty: false + +TestReader: + inputs_def: + image_shape: [3, 300, 300] + fields: ['image', 'im_shape', 'scale_factor', 'im_id'] + sample_transforms: + - DecodeOp: {} + - ResizeOp: {target_size: [300, 300], keep_ratio: False, interp: 1} + - NormalizeImageOp: {mean: [104., 117., 123.], std: [1., 1., 1.], is_scale: false} + - PermuteOp: {} + batch_size: 1 diff --git a/configs/ssd_vgg16_300_120e_coco.yml b/configs/ssd_vgg16_300_120e_coco.yml new file mode 100644 index 0000000000000000000000000000000000000000..a94e3cf3e70b91033678e8d874eaeebaf0c0a907 --- /dev/null +++ b/configs/ssd_vgg16_300_120e_coco.yml @@ -0,0 +1,7 @@ +_BASE_: [ + './_base_/models/ssd_vgg16_300.yml', + './_base_/optimizers/ssd_120e.yml', + './_base_/datasets/coco.yml', + './_base_/readers/ssd_reader.yml', + './_base_/runtime.yml', +] diff --git 
a/configs/ssd_vgg16_300_240e_voc.yml b/configs/ssd_vgg16_300_240e_voc.yml new file mode 100644 index 0000000000000000000000000000000000000000..0052a3e03fbfaae87a232e1489bacfc6f529b542 --- /dev/null +++ b/configs/ssd_vgg16_300_240e_voc.yml @@ -0,0 +1,7 @@ +_BASE_: [ + './_base_/models/ssd_vgg16_300.yml', + './_base_/optimizers/ssd_240e.yml', + './_base_/datasets/voc.yml', + './_base_/readers/ssd_reader.yml', + './_base_/runtime.yml', +] diff --git a/deploy/python/infer.py b/deploy/python/infer.py index bf2b201ad4d38982934b04f42f6ccd0773c9bb51..933966d6219c1bc0ea603e24da5a862e20bf7ad9 100644 --- a/deploy/python/infer.py +++ b/deploy/python/infer.py @@ -33,6 +33,7 @@ from paddle.inference import create_predictor SUPPORT_MODELS = { 'YOLO', 'RCNN', + 'SSD', } @@ -73,7 +74,7 @@ class Detector(object): def postprocess(self, np_boxes, np_masks, inputs, threshold=0.5): # postprocess output of predictor results = {} - if self.pred_config.arch in ['SSD', 'Face']: + if self.pred_config.arch in ['Face']: h, w = inputs['im_shape'] scale_y, scale_x = inputs['scale_factor'] w, h = float(h) / scale_y, float(w) / scale_x diff --git a/docs/MODEL_ZOO_cn.md b/docs/MODEL_ZOO_cn.md index 88976c300472b50cb593f6a91d9c65b41330959e..b281e117fbfc13de7b74adfa36d80ea777a19584 100644 --- a/docs/MODEL_ZOO_cn.md +++ b/docs/MODEL_ZOO_cn.md @@ -41,3 +41,11 @@ Paddle提供基于ImageNet的骨架网络预训练模型。所有预训练模型 | ResNet50-FPN | Cascade Faster | 1 | 1x | ---- | 41.1 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/cascade_rcnn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/cascade_faster_rcnn_r50_fpn_1x_coco.yml) | | ResNet50-FPN | Cascade Mask | 1 | 1x | ---- | 41.6 | 35.3 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/cascade_mask_rcnn_r50_fpn_1x_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/cascade_mask_rcnn_r50_fpn_1x_coco.yml) | | DarkNet53 | YOLOv3 | 1 | 270e | 
---- | 39.0 | - | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/yolov3_darknet53_270e_coco.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/yolov3_darknet53_270e_coco.yml) | + +### SSD on Pascal VOC + +| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 | +| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | +| VGG | SSD | 8 | 240e | ---- | 78.2 | [下载链接](https://paddlemodels.bj.bcebos.com/object_detection/dygraph/ssd_vgg16_300_240e_voc.pdparams) | [配置文件](https://github.com/PaddlePaddle/PaddleDetection/tree/dygraph/configs/ssd_vgg16_300_240e_voc.yml) | + +**注意:** SSD使用4GPU训练,训练240个epoch diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index b53dca461ba25cd6f7c13bab523f840a94774f4b..94819a4f229ebfabafdb6dd8158d0dc53467097e 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -2,6 +2,7 @@ import copy import traceback import logging import threading +import six import sys if sys.version_info >= (3, 0): import queue as Queue @@ -118,25 +119,25 @@ class BaseDataLoader(object): batch_sampler=None, return_list=False, use_prefetch=True): - self._dataset = dataset - self._dataset.parse_dataset(self.with_background) + self.dataset = dataset + self.dataset.parse_dataset(self.with_background) # get data - self._dataset.set_out(self._sample_transforms, - copy.deepcopy(self._fields)) + self.dataset.set_out(self._sample_transforms, + copy.deepcopy(self._fields)) # set kwargs - self._dataset.set_kwargs(**self.kwargs) + self.dataset.set_kwargs(**self.kwargs) # batch sampler if batch_sampler is None: self._batch_sampler = DistributedBatchSampler( - self._dataset, + self.dataset, batch_size=self.batch_size, shuffle=self.shuffle, drop_last=self.drop_last) else: self._batch_sampler = batch_sampler - loader = DataLoader( - dataset=self._dataset, + self.loader = DataLoader( + dataset=self.dataset, 
batch_sampler=self._batch_sampler, collate_fn=self._batch_transforms, num_workers=worker_num, @@ -144,8 +145,29 @@ class BaseDataLoader(object): return_list=return_list, use_buffer_reader=use_prefetch, use_shared_memory=False) + self.loader = iter(self.loader) - return loader, len(self._batch_sampler) + return self + + def __len__(self): + return len(self._batch_sampler) + + def __iter__(self): + return self + + def __next__(self): + # pack {filed_name: field_data} here + # looking forward to support dictionary + # data structure in paddle.io.DataLoader + try: + data = next(self.loader) + return {k: v for k, v in zip(self._fields, data)} + except StopIteration: + six.reraise(*sys.exc_info()) + + def next(self): + # python2 compatibility + return self.__next__() @register diff --git a/ppdet/data/source/__init__.py b/ppdet/data/source/__init__.py index 573fe527ed581e5cb7d302bdd42b5d9a3f1a7f07..60c205d140cf8ac6a631be473ab816009c82ac6d 100644 --- a/ppdet/data/source/__init__.py +++ b/ppdet/data/source/__init__.py @@ -14,9 +14,9 @@ from . import coco # TODO add voc and widerface dataset -#from . import voc +from . import voc #from . import widerface from .coco import * -#from .voc import * +from .voc import * #from .widerface import * diff --git a/ppdet/data/source/voc.py b/ppdet/data/source/voc.py index 498ca5d16426cc0e67c8119aeff69c63e89b26b3..41c5b4c33c7eadcce56a918bd07c4ea53a25c7f5 100644 --- a/ppdet/data/source/voc.py +++ b/ppdet/data/source/voc.py @@ -19,14 +19,14 @@ import xml.etree.ElementTree as ET from ppdet.core.workspace import register, serializable -from .dataset import DataSet +from .dataset import DetDataset import logging logger = logging.getLogger(__name__) @register @serializable -class VOCDataSet(DataSet): +class VOCDataSet(DetDataset): """ Load dataset with PascalVOC format. @@ -38,8 +38,6 @@ class VOCDataSet(DataSet): image_dir (str): directory for images. anno_path (str): voc annotation file path. 
sample_num (int): number of samples to load, -1 means all. - use_default_label (bool): whether use the default mapping of - label to integer index. Default True. label_list (str): if use_default_label is False, will load mapping between category and class index. """ @@ -49,32 +47,15 @@ class VOCDataSet(DataSet): image_dir=None, anno_path=None, sample_num=-1, - use_default_label=True, - label_list='label_list.txt'): + label_list=None): super(VOCDataSet, self).__init__( + dataset_dir=dataset_dir, image_dir=image_dir, anno_path=anno_path, - sample_num=sample_num, - dataset_dir=dataset_dir) - # roidbs is list of dict whose structure is: - # { - # 'im_file': im_fname, # image file name - # 'im_id': im_id, # image id - # 'h': im_h, # height of image - # 'w': im_w, # width - # 'is_crowd': is_crowd, - # 'gt_class': gt_class, - # 'gt_score': gt_score, - # 'gt_bbox': gt_bbox, - # 'difficult': difficult - # } - self.roidbs = None - # 'cname2id' is a dict to map category name to class id - self.cname2cid = None - self.use_default_label = use_default_label + sample_num=sample_num) self.label_list = label_list - def load_roidb_and_cname2cid(self, with_background=True): + def parse_dataset(self, with_background=True): anno_path = os.path.join(self.dataset_dir, self.anno_path) image_dir = os.path.join(self.dataset_dir, self.image_dir) @@ -86,7 +67,7 @@ class VOCDataSet(DataSet): records = [] ct = 0 cname2cid = {} - if not self.use_default_label: + if self.label_list: label_path = os.path.join(self.dataset_dir, self.label_list) if not os.path.exists(label_path): raise ValueError("label_list {} does not exists".format( @@ -183,6 +164,9 @@ class VOCDataSet(DataSet): logger.debug('{} samples in file {}'.format(ct, anno_path)) self.roidbs, self.cname2cid = records, cname2cid + def get_label_list(self): + return os.path.join(self.dataset_dir, self.label_list) + def pascalvoc_label(with_background=True): labels_map = { diff --git a/ppdet/data/transform/operator.py 
b/ppdet/data/transform/operator.py index 7f7a7e5d52595a1367a39bef94de26669eb029c5..f98e95f06b7846d1760d2e184ce1659560d79902 100644 --- a/ppdet/data/transform/operator.py +++ b/ppdet/data/transform/operator.py @@ -1668,7 +1668,7 @@ class PadBoxOp(BaseOperator): # in training, for example in op ExpandImage, # the bbox and gt_class is expandded, but the difficult is not, # so, judging by it's length - if 'is_difficult' in sample: + if 'difficult' in sample: pad_diff = np.zeros((num_max, ), dtype=np.int32) if gt_num > 0: pad_diff[:gt_num] = sample['difficult'][:gt_num, 0] diff --git a/ppdet/modeling/architecture/__init__.py b/ppdet/modeling/architecture/__init__.py index e83f20b745301e4ea1d959b492a08ab63d9dcf86..c2203ab98b991486755decc4ef283bd6c16c2893 100644 --- a/ppdet/modeling/architecture/__init__.py +++ b/ppdet/modeling/architecture/__init__.py @@ -10,9 +10,11 @@ from . import faster_rcnn from . import mask_rcnn from . import yolo from . import cascade_rcnn +from . import ssd from .meta_arch import * from .faster_rcnn import * from .mask_rcnn import * from .yolo import * from .cascade_rcnn import * +from .ssd import * diff --git a/ppdet/modeling/architecture/meta_arch.py b/ppdet/modeling/architecture/meta_arch.py index 2731b3d456ff0fc08fb71bb57538f3715b9ca26f..b1f01f42c2dc1691c61e65f144653afecd2d7b4a 100644 --- a/ppdet/modeling/architecture/meta_arch.py +++ b/ppdet/modeling/architecture/meta_arch.py @@ -15,17 +15,8 @@ class BaseArch(nn.Layer): def __init__(self): super(BaseArch, self).__init__() - def forward(self, - input_tensor=None, - data=None, - input_def=None, - mode='infer'): - if input_tensor is None: - assert data is not None and input_def is not None - self.inputs = self.build_inputs(data, input_def) - else: - self.inputs = input_tensor - + def forward(self, inputs, mode='infer'): + self.inputs = inputs self.inputs['mode'] = mode self.model_arch() diff --git a/ppdet/modeling/architecture/ssd.py b/ppdet/modeling/architecture/ssd.py new file mode 100644 
index 0000000000000000000000000000000000000000..a1b4aec20d9130d544df08a4993f03ab1e906100 --- /dev/null +++ b/ppdet/modeling/architecture/ssd.py @@ -0,0 +1,50 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import fluid +from ppdet.core.workspace import register +from .meta_arch import BaseArch + +__all__ = ['SSD'] + + +@register +class SSD(BaseArch): + __category__ = 'architecture' + __inject__ = ['backbone', 'neck', 'ssd_head', 'post_process'] + + def __init__(self, backbone, ssd_head, post_process, neck=None): + super(SSD, self).__init__() + self.backbone = backbone + self.neck = neck + self.ssd_head = ssd_head + self.post_process = post_process + + def model_arch(self): + # Backbone + body_feats = self.backbone(self.inputs) + + # Neck + if self.neck is not None: + body_feats, spatial_scale = self.neck(body_feats) + + # SSD Head + self.ssd_head_outs, self.anchors = self.ssd_head(body_feats, + self.inputs['image']) + + def get_loss(self, ): + loss = self.ssd_head.get_loss(self.ssd_head_outs, self.inputs, + self.anchors) + return {"loss": loss} + + def get_pred(self, return_numpy=True): + output = {} + bbox, bbox_num = self.post_process(self.ssd_head_outs, self.anchors, + self.inputs['im_shape'], + self.inputs['scale_factor']) + outs = { + "bbox": bbox, + "bbox_num": bbox_num, + } + return outs diff --git a/ppdet/modeling/backbone/__init__.py b/ppdet/modeling/backbone/__init__.py index a4a38be4acc28243b86f7c110e18416de482b3e6..304c5ded5cb997ad365ee49cf6e1a702c27767b8 100644 --- a/ppdet/modeling/backbone/__init__.py +++ b/ppdet/modeling/backbone/__init__.py @@ -1,5 +1,7 @@ +from . import vgg from . import resnet from . 
import darknet +from .vgg import * from .resnet import * from .darknet import * diff --git a/ppdet/modeling/backbone/vgg.py b/ppdet/modeling/backbone/vgg.py new file mode 100755 index 0000000000000000000000000000000000000000..7f770b59a2b7c46d7270673e97a227a11128f631 --- /dev/null +++ b/ppdet/modeling/backbone/vgg.py @@ -0,0 +1,206 @@ +from __future__ import division + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +from paddle import ParamAttr +from paddle.fluid.regularizer import L2Decay +from paddle.nn import Conv2D, MaxPool2D +from ppdet.core.workspace import register, serializable + +__all__ = ['VGG'] + +VGG_cfg = {16: [2, 2, 3, 3, 3], 19: [2, 2, 4, 4, 4]} + + +class ConvBlock(nn.Layer): + def __init__(self, + in_channels, + out_channels, + groups, + pool_size=2, + pool_stride=2, + pool_padding=0, + name=None): + super(ConvBlock, self).__init__() + + self.groups = groups + self.conv0 = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr(name=name + "1_weights"), + bias_attr=ParamAttr(name=name + "1_bias")) + self.conv_out_list = [] + for i in range(1, groups): + conv_out = self.add_sublayer( + 'conv{}'.format(i), + Conv2D( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + padding=1, + weight_attr=ParamAttr( + name=name + "{}_weights".format(i + 1)), + bias_attr=ParamAttr(name=name + "{}_bias".format(i + 1)))) + self.conv_out_list.append(conv_out) + + self.pool = MaxPool2D( + kernel_size=pool_size, + stride=pool_stride, + padding=pool_padding, + ceil_mode=True) + + def forward(self, inputs): + out = self.conv0(inputs) + out = F.relu(out) + for conv_i in self.conv_out_list: + out = conv_i(out) + out = F.relu(out) + pool = self.pool(out) + return out, pool + + +class ExtraBlock(nn.Layer): + def __init__(self, + in_channels, + mid_channels, + out_channels, + padding, + stride, + kernel_size, + name=None): + 
super(ExtraBlock, self).__init__() + + self.conv0 = Conv2D( + in_channels=in_channels, + out_channels=mid_channels, + kernel_size=1, + stride=1, + padding=0) + self.conv1 = Conv2D( + in_channels=mid_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding) + + def forward(self, inputs): + out = self.conv0(inputs) + out = F.relu(out) + out = self.conv1(out) + out = F.relu(out) + return out + + +class L2NormScale(nn.Layer): + def __init__(self, num_channels, scale=1.0): + super(L2NormScale, self).__init__() + self.scale = self.create_parameter( + attr=ParamAttr(initializer=paddle.nn.initializer.Constant(scale)), + shape=[num_channels]) + + def forward(self, inputs): + out = F.normalize(inputs, axis=1, epsilon=1e-10) + # out = self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as( + # out) * out + out = self.scale.unsqueeze(0).unsqueeze(2).unsqueeze(3) * out + return out + + +@register +@serializable +class VGG(nn.Layer): + def __init__(self, + depth=16, + normalizations=[20., -1, -1, -1, -1, -1], + extra_block_filters=[[256, 512, 1, 2, 3], [128, 256, 1, 2, 3], + [128, 256, 0, 1, 3], + [128, 256, 0, 1, 3]]): + super(VGG, self).__init__() + + assert depth in [16, 19], \ + "depth as 16/19 supported currently, but got {}".format(depth) + self.depth = depth + self.groups = VGG_cfg[depth] + self.normalizations = normalizations + self.extra_block_filters = extra_block_filters + + self.conv_block_0 = ConvBlock( + 3, 64, self.groups[0], 2, 2, 0, name="conv1_") + self.conv_block_1 = ConvBlock( + 64, 128, self.groups[1], 2, 2, 0, name="conv2_") + self.conv_block_2 = ConvBlock( + 128, 256, self.groups[2], 2, 2, 0, name="conv3_") + self.conv_block_3 = ConvBlock( + 256, 512, self.groups[3], 2, 2, 0, name="conv4_") + self.conv_block_4 = ConvBlock( + 512, 512, self.groups[4], 3, 1, 1, name="conv5_") + + self.fc6 = Conv2D( + in_channels=512, + out_channels=1024, + kernel_size=3, + stride=1, + padding=6, + dilation=6) + self.fc7 = 
Conv2D( + in_channels=1024, + out_channels=1024, + kernel_size=1, + stride=1, + padding=0) + + # extra block + self.extra_convs = [] + last_channels = 1024 + for i, v in enumerate(self.extra_block_filters): + assert len(v) == 5, "extra_block_filters size not fix" + extra_conv = self.add_sublayer("conv{}".format(6 + i), + ExtraBlock(last_channels, v[0], v[1], + v[2], v[3], v[4])) + last_channels = v[1] + self.extra_convs.append(extra_conv) + + self.norms = [] + for i, n in enumerate(self.normalizations): + if n != -1: + norm = self.add_sublayer("norm{}".format(i), + L2NormScale( + self.extra_block_filters[i][1], n)) + else: + norm = None + self.norms.append(norm) + + def forward(self, inputs): + outputs = [] + + conv, pool = self.conv_block_0(inputs['image']) + conv, pool = self.conv_block_1(pool) + conv, pool = self.conv_block_2(pool) + conv, pool = self.conv_block_3(pool) + outputs.append(conv) + + conv, pool = self.conv_block_4(pool) + out = self.fc6(pool) + out = F.relu(out) + out = self.fc7(out) + out = F.relu(out) + outputs.append(out) + + if not self.extra_block_filters: + return out + + # extra block + for extra_conv in self.extra_convs: + out = extra_conv(out) + outputs.append(out) + + for i, n in enumerate(self.normalizations): + if n != -1: + outputs[i] = self.norms[i](outputs[i]) + + return outputs diff --git a/ppdet/modeling/head/__init__.py b/ppdet/modeling/head/__init__.py index 619b3ccf226b88a40961a7fea43701298b8855d1..a0fa75a5ab49e2480db92fe416f50a104298baf5 100644 --- a/ppdet/modeling/head/__init__.py +++ b/ppdet/modeling/head/__init__.py @@ -17,9 +17,11 @@ from . import bbox_head from . import mask_head from . import yolo_head from . import roi_extractor +from . 
import ssd_head from .rpn_head import * from .bbox_head import * from .mask_head import * from .yolo_head import * from .roi_extractor import * +from .ssd_head import * diff --git a/ppdet/modeling/head/ssd_head.py b/ppdet/modeling/head/ssd_head.py new file mode 100644 index 0000000000000000000000000000000000000000..abccd2c4475a225b572ba2aee713aada6c43fc88 --- /dev/null +++ b/ppdet/modeling/head/ssd_head.py @@ -0,0 +1,69 @@ +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import numpy as np +from ppdet.core.workspace import register + + +@register +class SSDHead(nn.Layer): + __shared__ = ['num_classes'] + __inject__ = ['anchor_generator', 'loss'] + + def __init__(self, + num_classes=81, + in_channels=(512, 1024, 512, 256, 256, 256), + anchor_generator='AnchorGeneratorSSD', + loss='SSDLoss'): + super(SSDHead, self).__init__() + self.num_classes = num_classes + self.in_channels = in_channels + self.anchor_generator = anchor_generator + self.loss = loss + self.num_priors = self.anchor_generator.num_priors + + self.box_convs = [] + self.score_convs = [] + for i, num_prior in enumerate(self.num_priors): + self.box_convs.append( + self.add_sublayer( + "boxes{}".format(i), + nn.Conv2D( + in_channels=in_channels[i], + out_channels=num_prior * 4, + kernel_size=3, + padding=1))) + self.score_convs.append( + self.add_sublayer( + "scores{}".format(i), + nn.Conv2D( + in_channels=in_channels[i], + out_channels=num_prior * num_classes, + kernel_size=3, + padding=1))) + + def forward(self, feats, image): + box_preds = [] + cls_scores = [] + prior_boxes = [] + for feat, box_conv, score_conv in zip(feats, self.box_convs, + self.score_convs): + box_pred = box_conv(feat) + box_pred = paddle.transpose(box_pred, [0, 2, 3, 1]) + box_pred = paddle.reshape(box_pred, [0, -1, 4]) + box_preds.append(box_pred) + + cls_score = score_conv(feat) + cls_score = paddle.transpose(cls_score, [0, 2, 3, 1]) + cls_score = paddle.reshape(cls_score, [0, -1, self.num_classes]) + 
cls_scores.append(cls_score) + + prior_boxes = self.anchor_generator(feats, image) + + outputs = {} + outputs['boxes'] = box_preds + outputs['scores'] = cls_scores + return outputs, prior_boxes + + def get_loss(self, inputs, targets, prior_boxes): + return self.loss(inputs, targets, prior_boxes) diff --git a/ppdet/modeling/layers.py b/ppdet/modeling/layers.py index 67b6a32f13c3f5352d2137ebbbcea03691f4f7ad..547c62b13e8c51ae2a6caaeb210dab119320b2c2 100644 --- a/ppdet/modeling/layers.py +++ b/ppdet/modeling/layers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math import numpy as np from numbers import Integral @@ -24,6 +25,12 @@ from . import ops import paddle.nn.functional as F +def _to_list(l): + if isinstance(l, (list, tuple)): + return list(l) + return [l] + + @register @serializable class AnchorGeneratorRPN(object): @@ -103,6 +110,57 @@ class AnchorTargetGeneratorRPN(object): return pred_cls_logits, pred_bbox_pred, tgt_labels, tgt_bboxes, bbox_inside_weights +@register +@serializable +class AnchorGeneratorSSD(object): + def __init__(self, + steps=[8, 16, 32, 64, 100, 300], + aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]], + min_ratio=15, + max_ratio=90, + min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0], + max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0], + offset=0.5, + flip=True, + clip=False, + min_max_aspect_ratios_order=False): + self.steps = steps + self.aspect_ratios = aspect_ratios + self.min_ratio = min_ratio + self.max_ratio = max_ratio + self.min_sizes = min_sizes + self.max_sizes = max_sizes + self.offset = offset + self.flip = flip + self.clip = clip + self.min_max_aspect_ratios_order = min_max_aspect_ratios_order + + self.num_priors = [] + for aspect_ratio, min_size, max_size in zip(aspect_ratios, min_sizes, + max_sizes): + self.num_priors.append((len(aspect_ratio) * 2 + 1) * len( + _to_list(min_size)) + len(_to_list(max_size))) + + def 
__call__(self, inputs, image): + boxes = [] + for input, min_size, max_size, aspect_ratio, step in zip( + inputs, self.min_sizes, self.max_sizes, self.aspect_ratios, + self.steps): + box, _ = ops.prior_box( + input=input, + image=image, + min_sizes=_to_list(min_size), + max_sizes=_to_list(max_size), + aspect_ratios=aspect_ratio, + flip=self.flip, + clip=self.clip, + steps=[step, step], + offset=self.offset, + min_max_aspect_ratios_order=self.min_max_aspect_ratios_order) + boxes.append(paddle.reshape(box, [-1, 4])) + return boxes + + @register @serializable class ProposalGenerator(object): @@ -420,7 +478,12 @@ class YOLOBox(object): self.clip_bbox = clip_bbox self.scale_x_y = scale_x_y - def __call__(self, yolo_head_out, anchors, im_shape, scale_factor): + def __call__(self, + yolo_head_out, + anchors, + im_shape, + scale_factor, + var_weight=None): boxes_list = [] scores_list = [] origin_shape = im_shape / scale_factor @@ -437,6 +500,54 @@ class YOLOBox(object): return yolo_boxes, yolo_scores +@register +@serializable +class SSDBox(object): + def __init__(self, is_normalized=True): + self.is_normalized = is_normalized + self.norm_delta = float(not self.is_normalized) + + def __call__(self, + preds, + prior_boxes, + im_shape, + scale_factor, + var_weight=None): + boxes, scores = preds['boxes'], preds['scores'] + outputs = [] + for box, score, prior_box in zip(boxes, scores, prior_boxes): + pb_w = prior_box[:, 2] - prior_box[:, 0] + self.norm_delta + pb_h = prior_box[:, 3] - prior_box[:, 1] + self.norm_delta + pb_x = prior_box[:, 0] + pb_w * 0.5 + pb_y = prior_box[:, 1] + pb_h * 0.5 + out_x = pb_x + box[:, :, 0] * pb_w * 0.1 + out_y = pb_y + box[:, :, 1] * pb_h * 0.1 + out_w = paddle.exp(box[:, :, 2] * 0.2) * pb_w + out_h = paddle.exp(box[:, :, 3] * 0.2) * pb_h + + if self.is_normalized: + h = im_shape[:, 0] / scale_factor[:, 0] + w = im_shape[:, 1] / scale_factor[:, 1] + output = paddle.stack( + [(out_x - out_w / 2.) * w, (out_y - out_h / 2.) 
* h, + (out_x + out_w / 2.) * w, (out_y + out_h / 2.) * h], + axis=-1) + else: + output = paddle.stack( + [ + out_x - out_w / 2., out_y - out_h / 2., + out_x + out_w / 2. - 1., out_y + out_h / 2. - 1. + ], + axis=-1) + outputs.append(output) + boxes = paddle.concat(outputs, axis=1) + + scores = F.softmax(paddle.concat(scores, axis=1)) + scores = paddle.transpose(scores, [0, 2, 1]) + + return boxes, scores + + @register @serializable class AnchorGrid(object): diff --git a/ppdet/modeling/loss/__init__.py b/ppdet/modeling/loss/__init__.py index 6d87bdc02f94ff6eb1e72f81fab717f3c001d427..b47ab3499c8c9ecf47eb85ceedd9d9a3afebe0fb 100644 --- a/ppdet/modeling/loss/__init__.py +++ b/ppdet/modeling/loss/__init__.py @@ -15,7 +15,9 @@ from . import yolo_loss from . import iou_aware_loss from . import iou_loss +from . import ssd_loss from .yolo_loss import * from .iou_aware_loss import * from .iou_loss import * +from .ssd_loss import * diff --git a/ppdet/modeling/loss/ssd_loss.py b/ppdet/modeling/loss/ssd_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1d2c944a5defdefd6e94c45035af971c2a39d9 --- /dev/null +++ b/ppdet/modeling/loss/ssd_loss.py @@ -0,0 +1,203 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import numpy as np
+from ppdet.core.workspace import register
+from ..ops import bipartite_match, box_coder, iou_similarity
+
+__all__ = ['SSDLoss']
+
+
+@register
+class SSDLoss(nn.Layer):
+    """SSD training loss.
+
+    Matches prior boxes to ground-truth boxes, mines hard negative priors and
+    returns the sum of the weighted confidence and localization losses,
+    divided by the number of matched priors.
+
+    Args:
+        match_type (str): Forwarded to ``bipartite_match`` as ``match_type``.
+        overlap_threshold (float): Forwarded to ``bipartite_match`` as
+            ``dist_threshold``.
+        neg_pos_ratio (float): Hard negatives kept per matched prior; the
+            per-image budget is ``num_pos * neg_pos_ratio``.
+        neg_overlap (float): Only priors whose matched distance is below this
+            value are candidates for hard-negative mining.
+        loc_loss_weight (float): Multiplier applied to the localization loss.
+        conf_loss_weight (float): Multiplier applied to the confidence loss.
+    """
+
+    def __init__(self,
+                 match_type='per_prediction',
+                 overlap_threshold=0.5,
+                 neg_pos_ratio=3.0,
+                 neg_overlap=0.5,
+                 loc_loss_weight=1.0,
+                 conf_loss_weight=1.0):
+        super(SSDLoss, self).__init__()
+        self.match_type = match_type
+        self.overlap_threshold = overlap_threshold
+        self.neg_pos_ratio = neg_pos_ratio
+        self.neg_overlap = neg_overlap
+        self.loc_loss_weight = loc_loss_weight
+        self.conf_loss_weight = conf_loss_weight
+
+    def _label_target_assign(self,
+                             gt_label,
+                             matched_indices,
+                             neg_mask=None,
+                             mismatch_value=0):
+        """Build per-prior classification targets and their loss weights.
+
+        Args:
+            gt_label (Tensor): Ground-truth labels, indexed as
+                ``gt_label[image][gt_row]``.
+            matched_indices (Tensor): Shape (batch_size, num_priors); the
+                matched ground-truth row of every prior, -1 when unmatched.
+            neg_mask (Tensor, optional): Shape (batch_size, num_priors);
+                added to the weights so mined negatives contribute as well.
+            mismatch_value (int): Label written for unmatched priors.
+
+        Returns:
+            tuple: int32 labels and float32 weights, both of shape
+                (batch_size, num_priors, 1).
+        """
+        gt_label = gt_label.numpy()
+        matched_indices = matched_indices.numpy()
+        if neg_mask is not None:
+            neg_mask = neg_mask.numpy()
+
+        batch_size, num_priors = matched_indices.shape
+        trg_lbl = np.full(
+            (batch_size, num_priors, 1), mismatch_value, dtype='int32')
+        trg_lbl_wt = np.zeros((batch_size, num_priors, 1), dtype='float32')
+
+        for i in range(batch_size):
+            col_ids = np.where(matched_indices[i] > -1)
+            col_val = matched_indices[i][col_ids]
+            trg_lbl[i][col_ids] = gt_label[i][col_val]
+            trg_lbl_wt[i][col_ids] = 1.0
+
+        if neg_mask is not None:
+            trg_lbl_wt += neg_mask[:, :, np.newaxis]
+
+        return paddle.to_tensor(trg_lbl), paddle.to_tensor(trg_lbl_wt)
+
+    def _bbox_target_assign(self, encoded_box, matched_indices):
+        """Build per-prior regression targets and their loss weights.
+
+        Args:
+            encoded_box (Tensor): Encoded boxes, indexed as
+                ``encoded_box[image][gt_row][prior]`` (4 values each).
+            matched_indices (Tensor): Shape (batch_size, num_priors); the
+                matched ground-truth row of every prior, -1 when unmatched.
+
+        Returns:
+            tuple: float32 targets of shape (batch_size, num_priors, 4) and
+                float32 weights of shape (batch_size, num_priors, 1).
+        """
+        encoded_box = encoded_box.numpy()
+        matched_indices = matched_indices.numpy()
+
+        batch_size, num_priors = matched_indices.shape
+        trg_bbox = np.zeros((batch_size, num_priors, 4), dtype='float32')
+        trg_bbox_wt = np.zeros((batch_size, num_priors, 1), dtype='float32')
+
+        for i in range(batch_size):
+            col_ids = np.where(matched_indices[i] > -1)
+            col_val = matched_indices[i][col_ids]
+            # One fancy-index copy per image: for every matched prior c take
+            # the encoding of its matched ground-truth row v.
+            trg_bbox[i][col_ids] = encoded_box[i][col_val, col_ids[0]]
+            trg_bbox_wt[i][col_ids] = 1.0
+
+        return paddle.to_tensor(trg_bbox), paddle.to_tensor(trg_bbox_wt)
+
+    def _mine_hard_example(self,
+                           conf_loss,
+                           matched_indices,
+                           matched_dist,
+                           neg_pos_ratio=3.0,
+                           neg_overlap=0.5):
+        """Select the highest-loss negative priors of every image.
+
+        Args:
+            conf_loss (Tensor): Per-prior confidence loss, shape
+                (batch_size, num_priors).
+            matched_indices (Tensor): Matched row per prior, -1 if unmatched.
+            matched_dist (Tensor): Matched distance per prior.
+            neg_pos_ratio (float): Negatives kept per matched prior.
+            neg_overlap (float): Upper bound (exclusive) on the matched
+                distance of a negative candidate.
+
+        Returns:
+            Tensor: Mask with the dtype of ``conf_loss`` and the shape of
+                ``matched_indices``; 1 marks a mined negative.
+        """
+        pos = (matched_indices > -1).astype(conf_loss.dtype)
+        num_pos = pos.sum(axis=1, keepdim=True)
+        neg = (matched_dist < neg_overlap).astype(conf_loss.dtype)
+
+        # Zero the loss of positives and of non-candidates so that they sort
+        # behind every real negative candidate.
+        conf_loss = conf_loss * (1.0 - pos) * neg
+        loss_idx = conf_loss.argsort(axis=1, descending=True)
+        # Rank of every prior inside the descending-loss ordering.
+        idx_rank = loss_idx.argsort(axis=1)
+        # Per-image negative budget, capped at the number of priors.
+        num_neg = paddle.clip(num_pos * neg_pos_ratio, max=pos.shape[1])
+        num_neg = num_neg.expand_as(idx_rank)
+        neg_mask = (idx_rank < num_neg).astype(conf_loss.dtype)
+        return neg_mask
+
+    def forward(self, inputs, targets, anchors):
+        """Compute the SSD loss of one batch.
+
+        Args:
+            inputs (dict): ``inputs['boxes']`` and ``inputs['scores']`` are
+                sequences of tensors that are concatenated along axis 1.
+            targets (dict): Provides ``'gt_bbox'`` and ``'gt_class'``.
+            anchors (list): Prior-box tensors, concatenated along axis 0.
+
+        Returns:
+            Tensor: The summed confidence and localization loss divided by
+                the number of matched priors (at least 1).
+        """
+        boxes = paddle.concat(inputs['boxes'], axis=1)
+        scores = paddle.concat(inputs['scores'], axis=1)
+        prior_boxes = paddle.concat(anchors, axis=0)
+        gt_box = targets['gt_bbox']
+        gt_label = targets['gt_class'].unsqueeze(-1)
+        batch_size, num_priors, _ = scores.shape
+
+        def _reshape_to_2d(x):
+            return paddle.flatten(x, start_axis=2)
+
+        # 1. Find matched bounding box by prior box.
+        #   1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
+        #   1.2 Compute matched bounding box by bipartite matching algorithm.
+        matched_indices = []
+        matched_dist = []
+        for i in range(gt_box.shape[0]):
+            iou = iou_similarity(gt_box[i], prior_boxes)
+            matched_indice, matched_d = bipartite_match(iou, self.match_type,
+                                                        self.overlap_threshold)
+            matched_indices.append(matched_indice)
+            matched_dist.append(matched_d)
+        matched_indices = paddle.concat(matched_indices, axis=0)
+        matched_indices.stop_gradient = True
+        matched_dist = paddle.concat(matched_dist, axis=0)
+        matched_dist.stop_gradient = True
+
+        # 2. Compute confidence for mining hard examples
+        # 2.1. Get the target label based on matched indices
+        target_label, _ = self._label_target_assign(gt_label, matched_indices)
+        confidence = _reshape_to_2d(scores)
+        # 2.2. Compute confidence loss.
+        target_label = _reshape_to_2d(target_label).astype('int64')
+        conf_loss = F.softmax_with_cross_entropy(confidence, target_label)
+        conf_loss = paddle.reshape(conf_loss, [batch_size, num_priors])
+
+        # 3. Mining hard examples
+        neg_mask = self._mine_hard_example(
+            conf_loss,
+            matched_indices,
+            matched_dist,
+            neg_pos_ratio=self.neg_pos_ratio,
+            neg_overlap=self.neg_overlap)
+
+        # 4. Assign classification and regression targets
+        # 4.1. Encoded bbox according to the prior boxes.
+        prior_box_var = paddle.to_tensor(
+            np.array(
+                [0.1, 0.1, 0.2, 0.2], dtype='float32')).reshape(
+                    [1, 4]).expand_as(prior_boxes)
+        encoded_bbox = []
+        for i in range(gt_box.shape[0]):
+            encoded_bbox.append(
+                box_coder(
+                    prior_box=prior_boxes,
+                    prior_box_var=prior_box_var,
+                    target_box=gt_box[i],
+                    code_type='encode_center_size'))
+        encoded_bbox = paddle.stack(encoded_bbox, axis=0)
+        # 4.2. Assign regression targets
+        target_bbox, target_loc_weight = self._bbox_target_assign(
+            encoded_bbox, matched_indices)
+        # 4.3. Assign classification targets
+        target_label, target_conf_weight = self._label_target_assign(
+            gt_label, matched_indices, neg_mask=neg_mask)
+
+        # 5. Compute loss.
+        # 5.1 Compute confidence loss.
+        target_label = _reshape_to_2d(target_label).astype('int64')
+        conf_loss = F.softmax_with_cross_entropy(confidence, target_label)
+
+        target_conf_weight = _reshape_to_2d(target_conf_weight)
+        conf_loss = conf_loss * target_conf_weight * self.conf_loss_weight
+
+        # 5.2 Compute regression loss.
+        location = _reshape_to_2d(boxes)
+        target_bbox = _reshape_to_2d(target_bbox)
+
+        loc_loss = F.smooth_l1_loss(location, target_bbox, reduction='none')
+        loc_loss = paddle.sum(loc_loss, axis=-1, keepdim=True)
+        target_loc_weight = _reshape_to_2d(target_loc_weight)
+        loc_loss = loc_loss * target_loc_weight * self.loc_loss_weight
+
+        # 5.3 Compute overall weighted loss.
+        loss = conf_loss + loc_loss
+        loss = paddle.reshape(loss, [batch_size, num_priors])
+        loss = paddle.sum(loss, axis=1, keepdim=True)
+        # The weights are 0/1, so their sum is the number of matched priors.
+        # Clamp it to 1 so a batch without any match yields 0 instead of NaN.
+        normalizer = paddle.clip(paddle.sum(target_loc_weight), min=1.0)
+        loss = paddle.sum(loss / normalizer)
+
+        return loss
diff --git a/ppdet/modeling/ops.py b/ppdet/modeling/ops.py
index 3a6346cde347dffc64049053c11da474ffbade80..5b65a9a1fa002705403a7e346d7c72cb815b80e5 100644
--- a/ppdet/modeling/ops.py
+++ b/ppdet/modeling/ops.py
@@ -806,15 +806,12 @@ def prior_box(input,
         cur_max_sizes = max_sizes
 
     if in_dygraph_mode():
-        attrs = [
-            'min_sizes', min_sizes, 'aspect_ratios', aspect_ratios, 'variances',
-            variance, 'flip', flip, 'clip', clip, 'step_w', steps[0], 'step_h',
-            steps[1], 'offset', offset, 'min_max_aspect_ratios_order',
-            min_max_aspect_ratios_order
-        ]
-        if cur_max_sizes is not None:
-            attrs.extend('max_sizes', max_sizes)
-        attrs = tuple(attrs)
+        assert cur_max_sizes is not None
+        attrs = ('min_sizes', min_sizes, 'max_sizes', cur_max_sizes,
+                 'aspect_ratios', aspect_ratios, 'variances', variance, 'flip',
+                 flip, 'clip', clip, 'step_w', steps[0], 'step_h', steps[1],
+                 'offset', offset, 'min_max_aspect_ratios_order',
+                 min_max_aspect_ratios_order)
         box, var = core.ops.prior_box(input, image, *attrs)
         return box, var
     else:
@@ -1254,6 +1251,111 @@ def matrix_nms(bboxes,
     return output
 
 
+def
bipartite_match(dist_matrix,
+                    match_type=None,
+                    dist_threshold=None,
+                    name=None):
+    """
+
+    This operator implements a greedy bipartite matching algorithm, which is
+    used to obtain the matching with the maximum distance based on the input
+    distance matrix. For input 2D matrix, the bipartite matching algorithm can
+    find the matched column for each row (matched means the largest distance),
+    also can find the matched row for each column. And this operator only
+    calculates matched indices from column to row. For each instance,
+    the number of matched indices is the column number of the input distance
+    matrix. **The OP only supports CPU**.
+
+    There are two outputs, matched indices and distance.
+    A simple description, this algorithm matched the best (maximum distance)
+    row entity to the column entity and the matched indices are not duplicated
+    in each row of ColToRowMatchIndices. If the column entity is not matched
+    to any row entity, set -1 in ColToRowMatchIndices.
+
+    NOTE: the input DistMat can be LoDTensor (with LoD) or Tensor.
+    If LoDTensor with LoD, the height of ColToRowMatchIndices is batch size.
+    If Tensor, the height of ColToRowMatchIndices is 1.
+
+    NOTE: This API is a very low level API. It is used by :code:`ssd_loss`
+    layer. Please consider using :code:`ssd_loss` instead.
+
+    Args:
+        dist_matrix(Tensor): This input is a 2-D LoDTensor with shape
+            [K, M]. The data type is float32 or float64. It is pair-wise
+            distance matrix between the entities represented by each row and
+            each column. For example, assumed one entity is A with shape [K],
+            another entity is B with shape [M]. The dist_matrix[i][j] is the
+            distance between A[i] and B[j]. The bigger the distance is, the
+            better matching the pairs are. NOTE: This tensor can contain LoD
+            information to represent a batch of inputs. One instance of this
+            batch can contain different numbers of entities.
+        match_type(str, optional): The type of matching method, should be
+            'bipartite' or 'per_prediction'. None ('bipartite') by default.
+            In dygraph mode a None value is not forwarded to the operator.
+        dist_threshold(float32, optional): If `match_type` is 'per_prediction',
+            this threshold is to determine the extra matching bboxes based
+            on the maximum distance, 0.5 by default. In dygraph mode a None
+            value is not forwarded to the operator.
+        name(str, optional): For detailed information, please refer
+            to :ref:`api_guide_Name`. Usually there is no need to set name and
+            it is None by default.
+
+    Returns:
+        Tuple:
+
+        matched_indices(Tensor): A 2-D Tensor with shape [N, M]. The data
+        type is int32. N is the batch size. If match_indices[i][j] is -1, it
+        means B[j] does not match any entity in i-th instance.
+        Otherwise, it means B[j] is matched to row
+        match_indices[i][j] in i-th instance. The row number of
+        i-th instance is saved in match_indices[i][j].
+
+        matched_distance(Tensor): A 2-D Tensor with shape [N, M]. The data
+        type is float32. N is batch size. If match_indices[i][j] is -1,
+        match_distance[i][j] is also -1.0. Otherwise, assumed
+        match_distance[i][j] = d, and the row offsets of each instance
+        are called LoD. Then match_distance[i][j] =
+        dist_matrix[d+LoD[i]][j].
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle
+            from ppdet.modeling import ops
+            from ppdet.modeling.utils import iou_similarity
+
+            paddle.enable_static()
+
+            x = paddle.static.data(name='x', shape=[None, 4], dtype='float32')
+            y = paddle.static.data(name='y', shape=[None, 4], dtype='float32')
+            iou = iou_similarity(x=x, y=y)
+            matched_indices, matched_dist = ops.bipartite_match(iou)
+    """
+    check_variable_and_dtype(dist_matrix, 'dist_matrix',
+                             ['float32', 'float64'], 'bipartite_match')
+
+    if in_dygraph_mode():
+        # Forward only the attributes the caller actually supplied: the
+        # documented default of both is None, which is not a valid attribute
+        # value, so leaving it out lets the operator use its own default.
+        dygraph_attrs = ()
+        if match_type is not None:
+            dygraph_attrs += ('match_type', match_type)
+        if dist_threshold is not None:
+            dygraph_attrs += ('dist_threshold', dist_threshold)
+        match_indices, match_distance = core.ops.bipartite_match(
+            dist_matrix, *dygraph_attrs)
+        return match_indices, match_distance
+
+    helper = LayerHelper('bipartite_match', **locals())
+    match_indices = helper.create_variable_for_type_inference(dtype='int32')
+    match_distance = helper.create_variable_for_type_inference(
+        dtype=dist_matrix.dtype)
+    helper.append_op(
+        type='bipartite_match',
+        inputs={'DistMat': dist_matrix},
+        attrs={
+            'match_type': match_type,
+            'dist_threshold': dist_threshold,
+        },
+        outputs={
+            'ColToRowMatchIndices': match_indices,
+            'ColToRowMatchDist': match_distance
+        })
+    return match_indices, match_distance
+
+
 @paddle.jit.not_to_static
 def box_coder(prior_box,
               prior_box_var,
diff --git a/ppdet/modeling/tests/test_ops.py b/ppdet/modeling/tests/test_ops.py
index 6c4e3a32af70000ba7d08aa579d141c324d45a4d..bccebc2587436235d2f61937f770a16b2501113e 100644
--- a/ppdet/modeling/tests/test_ops.py
+++ b/ppdet/modeling/tests/test_ops.py
@@ -376,6 +376,31 @@ class TestIoUSimilarity(LayerTest):
         self.assertTrue(np.array_equal(iou_np, iou_dy_np))
 
 
+class TestBipartiteMatch(LayerTest):
+    def test_bipartite_match(self):
+        distance = np.random.random((20, 10)).astype('float32')
+        with self.static_graph():
+            x = paddle.static.data(name='x', shape=[20, 10], dtype='float32')
+
+            match_indices, match_dist = ops.bipartite_match(
+                x, match_type='per_prediction', dist_threshold=0.5)
+            match_indices_np, match_dist_np = self.get_static_graph_result(
+                feed={'x': distance, },
+                fetch_list=[match_indices, match_dist],
+                with_lod=False)
+
+        with self.dynamic_graph():
+            x_dy = base.to_variable(distance)
+
+            match_indices_dy, match_dist_dy = ops.bipartite_match(
+                x_dy, match_type='per_prediction', dist_threshold=0.5)
+            match_indices_dy_np = match_indices_dy.numpy()
+            match_dist_dy_np = match_dist_dy.numpy()
+
+        self.assertTrue(np.array_equal(match_indices_np, match_indices_dy_np))
+        self.assertTrue(np.array_equal(match_dist_np, match_dist_dy_np))
+
+
 class TestYoloBox(LayerTest):
     def test_yolo_box(self):
 
diff --git a/ppdet/utils/eval_utils.py b/ppdet/utils/eval_utils.py
index ac047139f43dd449c4953442e85e4ecf7a747616..6049cb6356677df23c560c020bbb7e7975587bf5 100644
--- a/ppdet/utils/eval_utils.py
+++ b/ppdet/utils/eval_utils.py
@@ -33,7 +33,7 @@ def json_eval_results(metric, json_directory=None, dataset=None):
             logger.info("{} not exists!".format(v_json))
 
 
-def get_infer_results(outs_res, eval_type, catid, im_info):
+def get_infer_results(outs_res, eval_type, catid):
     """
     Get result at the stage of inference.
     The output format is dictionary containing bbox or mask result.
@@ -45,31 +45,27 @@ def get_infer_results(outs_res, eval_type, catid, im_info):
         raise ValueError(
             'The number of valid detection result if zero. Please use reasonable model and check input data.'
) - infer_res = {} - - if 'bbox' in eval_type: - box_res = [] - for i, outs in enumerate(outs_res): - im_ids = im_info[i][2] - box_res += get_det_res(outs['bbox'], outs['bbox_num'], im_ids, - catid) - infer_res['bbox'] = box_res - - if 'mask' in eval_type: - seg_res = [] - # mask post process - for i, outs in enumerate(outs_res): - im_shape = im_info[i][0] - scale_factor = im_info[i][1] - im_ids = im_info[i][2] - mask = outs['mask'] - seg_res += get_seg_res(mask, outs['bbox_num'], im_ids, catid) - infer_res['mask'] = seg_res + + infer_res = {k: [] for k in eval_type} + + for i, outs in enumerate(outs_res): + im_id = outs['im_id'] + im_shape = outs['im_shape'] + scale_factor = outs['scale_factor'] + + if 'bbox' in eval_type: + infer_res['bbox'] += get_det_res(outs['bbox'], outs['bbox_num'], + im_id, catid) + + if 'mask' in eval_type: + # mask post process + infer_res['mask'] += get_seg_res(outs['mask'], outs['bbox_num'], + im_id, catid) return infer_res -def eval_results(res, metric, anno_file): +def eval_results(res, metric, dataset): """ Evalute the inference result """ @@ -82,7 +78,8 @@ def eval_results(res, metric, anno_file): json.dump(res['bbox'], f) logger.info('The bbox result is saved to bbox.json.') - bbox_stats = cocoapi_eval('bbox.json', 'bbox', anno_file=anno_file) + bbox_stats = cocoapi_eval( + 'bbox.json', 'bbox', anno_file=dataset.get_anno()) eval_res.append(bbox_stats) sys.stdout.flush() if 'mask' in res: @@ -90,9 +87,14 @@ def eval_results(res, metric, anno_file): json.dump(res['mask'], f) logger.info('The mask result is saved to mask.json.') - seg_stats = cocoapi_eval('mask.json', 'segm', anno_file=anno_file) + seg_stats = cocoapi_eval( + 'mask.json', 'segm', anno_file=dataset.get_anno()) eval_res.append(seg_stats) sys.stdout.flush() + elif metric == 'VOC': + from ppdet.utils.voc_eval import bbox_eval + + bbox_stats = bbox_eval(res, 21) else: raise NotImplemented("Only COCO metric is supported now.") diff --git a/ppdet/utils/voc_eval.py 
b/ppdet/utils/voc_eval.py index 1b82928ff4ec7379fcd3b59d9735c0e9d71b11a8..2bf14489bfd1bf868757171e864e57185a69311e 100644 --- a/ppdet/utils/voc_eval.py +++ b/ppdet/utils/voc_eval.py @@ -63,46 +63,33 @@ def bbox_eval(results, evaluate_difficult=evaluate_difficult) for t in results: - bboxes = t['bbox'][0] - bbox_lengths = t['bbox'][1][0] + bboxes = t['bbox'] + bbox_lengths = t['bbox_num'] if bboxes.shape == (1, 1) or bboxes is None: continue - gt_boxes = t['gt_bbox'][0] - gt_labels = t['gt_class'][0] - difficults = t['is_difficult'][0] if not evaluate_difficult \ + gt_boxes = t['gt_bbox'] + gt_labels = t['gt_class'] + difficults = t['difficult'] if not evaluate_difficult \ else None - if len(t['gt_bbox'][1]) == 0: - # gt_bbox, gt_class, difficult read as zero padded Tensor - bbox_idx = 0 - for i in range(len(gt_boxes)): - gt_box = gt_boxes[i] - gt_label = gt_labels[i] - difficult = None if difficults is None \ - else difficults[i] - bbox_num = bbox_lengths[i] - bbox = bboxes[bbox_idx:bbox_idx + bbox_num] - gt_box, gt_label, difficult = prune_zero_padding( - gt_box, gt_label, difficult) - detection_map.update(bbox, gt_box, gt_label, difficult) - bbox_idx += bbox_num - else: - # gt_box, gt_label, difficult read as LoDTensor - gt_box_lengths = t['gt_bbox'][1][0] - bbox_idx = 0 - gt_box_idx = 0 - for i in range(len(bbox_lengths)): - bbox_num = bbox_lengths[i] - gt_box_num = gt_box_lengths[i] - bbox = bboxes[bbox_idx:bbox_idx + bbox_num] - gt_box = gt_boxes[gt_box_idx:gt_box_idx + gt_box_num] - gt_label = gt_labels[gt_box_idx:gt_box_idx + gt_box_num] - difficult = None if difficults is None else \ - difficults[gt_box_idx: gt_box_idx + gt_box_num] - detection_map.update(bbox, gt_box, gt_label, difficult) - bbox_idx += bbox_num - gt_box_idx += gt_box_num + scale_factor = t['scale_factor'] if 'scale_factor' in t else np.ones( + (gt_boxes.shape[0], 2)).astype('float32') + + bbox_idx = 0 + for i in range(gt_boxes.shape[0]): + gt_box = gt_boxes[i] + h, w = scale_factor[i] + 
gt_box = gt_box / np.array([w, h, w, h]) + gt_label = gt_labels[i] + difficult = None if difficults is None \ + else difficults[i] + bbox_num = bbox_lengths[i] + bbox = bboxes[bbox_idx:bbox_idx + bbox_num] + gt_box, gt_label, difficult = prune_zero_padding(gt_box, gt_label, + difficult) + detection_map.update(bbox, gt_box, gt_label, difficult) + bbox_idx += bbox_num logger.info("Accumulating evaluatation results...") detection_map.accumulate() diff --git a/tools/eval.py b/tools/eval.py index e40f80a2f74508989106e6775194ac889569d1a6..fe91dbbe644a3a8b30937d972fa130d2e3838772 100755 --- a/tools/eval.py +++ b/tools/eval.py @@ -69,34 +69,34 @@ def run(FLAGS, cfg, place): # Data Reader dataset = cfg.EvalDataset - eval_loader, _ = create('EvalReader')(dataset, cfg['worker_num'], place) + eval_loader = create('EvalReader')(dataset, cfg['worker_num'], place) + + extra_key = ['im_shape', 'scale_factor', 'im_id'] + if cfg.metric == 'VOC': + extra_key += ['gt_bbox', 'gt_class', 'difficult'] # Run Eval outs_res = [] start_time = time.time() sample_num = 0 - im_info = [] for iter_id, data in enumerate(eval_loader): # forward - fields = cfg['EvalReader']['inputs_def']['fields'] model.eval() - outs = model(data=data, input_def=fields, mode='infer') + outs = model(data, mode='infer') + for key in extra_key: + outs[key] = data[key] for key, value in outs.items(): outs[key] = value.numpy() - im_shape = data[fields.index('im_shape')].numpy() - scale_factor = data[fields.index('scale_factor')].numpy() - im_id = data[fields.index('im_id')].numpy() - im_info.append([im_shape, scale_factor, im_id]) if 'mask' in outs and 'bbox' in outs: mask_resolution = model.mask_post_process.mask_resolution from ppdet.py_op.post_process import mask_post_process - outs['mask'] = mask_post_process(outs, im_shape, scale_factor, - mask_resolution) + outs['mask'] = mask_post_process( + outs, outs['im_shape'], outs['scale_factor'], mask_resolution) outs_res.append(outs) # log - sample_num += im_shape.shape[0] 
+ sample_num += outs['im_id'].shape[0] if iter_id % 100 == 0: logger.info("Eval iter: {}".format(iter_id)) @@ -111,15 +111,22 @@ def run(FLAGS, cfg, place): eval_type.append('mask') # Metric # TODO: support other metric - from ppdet.utils.coco_eval import get_category_info - anno_file = dataset.get_anno() with_background = cfg.with_background use_default_label = dataset.use_default_label - clsid2catid, catid2name = get_category_info(anno_file, with_background, - use_default_label) + if cfg.metric == 'COCO': + from ppdet.utils.coco_eval import get_category_info + clsid2catid, catid2name = get_category_info( + dataset.get_anno(), with_background, use_default_label) + + infer_res = get_infer_results(outs_res, eval_type, clsid2catid) + + elif cfg.metric == 'VOC': + from ppdet.utils.voc_eval import get_category_info + clsid2catid, catid2name = get_category_info( + dataset.get_label_list(), with_background, use_default_label) + infer_res = outs_res - infer_res = get_infer_results(outs_res, eval_type, clsid2catid, im_info) - eval_results(infer_res, cfg.metric, anno_file) + eval_results(infer_res, cfg.metric, dataset) def main(): diff --git a/tools/export_utils.py b/tools/export_utils.py index 61a0c0dcce807f8ec9a6030b202fced2c201960a..a16561007cf7890a192ebc371494509b5f8095d7 100644 --- a/tools/export_utils.py +++ b/tools/export_utils.py @@ -49,6 +49,8 @@ def parse_reader(reader_cfg, dataset_cfg, metric, arch, image_shape): if metric == 'COCO': from ppdet.utils.coco_eval import get_category_info + elif metric == 'VOC': + from ppdet.utils.voc_eval import get_category_info else: raise ValueError("metric only supports COCO, but received {}".format( metric)) diff --git a/tools/infer.py b/tools/infer.py index 397dbd7fe771549e93aefccd0d0ee553d8909b72..20133f659815ef68e4fde8c9596e66ba4c2d1740 100755 --- a/tools/infer.py +++ b/tools/infer.py @@ -129,7 +129,11 @@ def run(FLAGS, cfg, place): dataset = cfg.TestDataset test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) 
dataset.set_images(test_images) - test_loader, _ = create('TestReader')(dataset, cfg['worker_num'], place) + test_loader = create('TestReader')(dataset, cfg['worker_num'], place) + + extra_key = ['im_shape', 'scale_factor', 'im_id'] + if cfg.metric == 'VOC': + extra_key += ['gt_bbox', 'gt_class', 'difficult'] # TODO: support other metrics imid2path = dataset.get_imid2path() @@ -147,24 +151,18 @@ def run(FLAGS, cfg, place): # Run Infer for iter_id, data in enumerate(test_loader): # forward - fields = cfg.TestReader['inputs_def']['fields'] model.eval() - outs = model( - data=data, - input_def=cfg.TestReader['inputs_def']['fields'], - mode='infer') + outs = model(data, mode='infer') + for key in extra_key: + outs[key] = data[key] for key, value in outs.items(): outs[key] = value.numpy() - im_shape = data[fields.index('im_shape')].numpy() - scale_factor = data[fields.index('scale_factor')].numpy() - im_ids = data[fields.index('im_id')].numpy() - im_info = [im_shape, scale_factor, im_ids] if 'mask' in outs and 'bbox' in outs: mask_resolution = model.mask_post_process.mask_resolution from ppdet.py_op.post_process import mask_post_process - outs['mask'] = mask_post_process(outs, im_shape, scale_factor, - mask_resolution) + outs['mask'] = mask_post_process( + outs, outs['im_shape'], outs['scale_factor'], mask_resolution) eval_type = [] if 'bbox' in outs: @@ -172,14 +170,14 @@ def run(FLAGS, cfg, place): if 'mask' in outs: eval_type.append('mask') - batch_res = get_infer_results([outs], eval_type, clsid2catid, [im_info]) + batch_res = get_infer_results([outs], eval_type, clsid2catid) logger.info('Infer iter {}'.format(iter_id)) bbox_res = None mask_res = None bbox_num = outs['bbox_num'] start = 0 - for i, im_id in enumerate(im_ids): + for i, im_id in enumerate(outs['im_id']): image_path = imid2path[int(im_id)] image = Image.open(image_path).convert('RGB') end = start + bbox_num[i] @@ -197,7 +195,7 @@ def run(FLAGS, cfg, place): mask_res = batch_res['mask'][start:end] image 
= visualize_results(image, bbox_res, mask_res, - int(im_id), catid2name, + int(outs['im_id']), catid2name, FLAGS.draw_threshold) # use VisualDL to log image with bbox diff --git a/tools/train.py b/tools/train.py index 115c38ef5bd81d852d3de5ede628bb25b3dafc70..85018c1f8fc1bb1532f835a7dab164151048048b 100755 --- a/tools/train.py +++ b/tools/train.py @@ -103,8 +103,8 @@ def run(FLAGS, cfg, place): # Data dataset = cfg.TrainDataset - train_loader, step_per_epoch = create('TrainReader')( - dataset, cfg['worker_num'], place) + train_loader = create('TrainReader')(dataset, cfg['worker_num'], place) + step_per_epoch = len(train_loader) # Model model = create(cfg.architecture) @@ -134,7 +134,6 @@ def run(FLAGS, cfg, place): if ParallelEnv().nranks > 1: model = paddle.DataParallel(model) - fields = train_loader.collate_fn.output_fields cfg_name = os.path.basename(FLAGS.config).split('.')[0] save_dir = os.path.join(cfg.save_dir, cfg_name) # Run Train @@ -155,7 +154,7 @@ def run(FLAGS, cfg, place): # Model Forward model.train() - outputs = model(data=data, input_def=fields, mode='train') + outputs = model(data, mode='train') # Model Backward loss = outputs['loss']