diff --git a/configs/face_detection/README.md b/configs/face_detection/README.md
index c1dc2e74ca806f5e90535754c0231335971c5ecb..5eb320f87bf18bd59df25e94b022a8e411f17a83 100644
--- a/configs/face_detection/README.md
+++ b/configs/face_detection/README.md
@@ -12,6 +12,7 @@
 | Model | Input size | Images/GPU | LR schedule | Easy/Medium/Hard Set | Latency (SD855) | Model size (MB) | Download | Config |
 |:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
 | BlazeFace | 640 | 8 | 1000e | 0.885 / 0.855 / 0.731 | - | 0.472 | [download](https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.1/configs/face_detection/blazeface_1000e.yml) |
+| BlazeFace-FPN-SSH | 640 | 8 | 1000e | 0.907 / 0.883 / 0.793 | - | 0.479 | [download](https://paddledet.bj.bcebos.com/models/blazeface_fpn_ssh_1000e.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/release/2.1/configs/face_detection/blazeface_fpn_ssh_1000e.yml) |
 
 **Note:**
 - The mAP on the `Easy/Medium/Hard Set` is obtained with the multi-scale evaluation strategy; see [Evaluation on the WIDER-FACE dataset](#在WIDER-FACE数据集上评估) for details.
@@ -52,6 +53,23 @@
 cd dataset/wider_face && ./download_wider_face.sh
 ```
 
+### Parameter configuration
+The base model is configured in `configs/face_detection/_base_/blazeface.yml`.
+The improved model adds an FPN and SSH neck; its configuration is `configs/face_detection/_base_/blazeface_fpn.yml`. FPN and SSH can be enabled as needed:
+```yaml
+BlazeNet:
+  blaze_filters: [[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]]
+  double_blaze_filters: [[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
+                         [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]]
+  act: hard_swish # activation of the BlazeBlocks in the backbone; the base model uses relu, hard_swish is required once FPN and SSH are added
+
+BlazeNeck:
+  neck_type: fpn_ssh # one of only_fpn, only_ssh and fpn_ssh
+  in_channel: [96,96]
+```
+
+
+
 ### Training and evaluation
 Training and evaluation follow the same workflow as the other models; please refer to [GETTING_STARTED_cn.md](../../docs/tutorials/GETTING_STARTED_cn.md).
 **Note:** Face detection models do not yet support evaluation during training.
diff --git a/configs/face_detection/_base_/blazeface.yml b/configs/face_detection/_base_/blazeface.yml
index 469aa9c4ca067e3dc38a6bba1832d8050e30b19e..de54100fe63c1d0dd004c5c1797b6a6587106993 100644
--- a/configs/face_detection/_base_/blazeface.yml
+++ b/configs/face_detection/_base_/blazeface.yml
@@ -1,17 +1,23 @@
-architecture: SSD
+architecture: BlazeFace
 
-SSD:
+BlazeFace:
   backbone: BlazeNet
-  ssd_head: FaceHead
+  neck: BlazeNeck
+  blaze_head: FaceHead
   post_process: BBoxPostProcess
 
 BlazeNet:
   blaze_filters: [[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]]
   double_blaze_filters: [[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
                          [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]]
+  act: relu
+
+BlazeNeck:
+  neck_type: None
+  in_channel: [96,96]
 
 FaceHead:
-  in_channels: [96, 96]
+  in_channels: [96,96]
   anchor_generator: AnchorGeneratorSSD
   loss: SSDLoss
 
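The `neck_type` choice determines the channel count the head sees. A minimal, self-contained sketch of the bookkeeping done in `BlazeNeck.__init__` (the helper name `neck_out_channels` is ours, not part of the patch):

```python
# Sketch of BlazeNeck's channel bookkeeping (helper name is illustrative only).
# With in_channel [96, 96]: FPN halves both levels to 48, SSH keeps the channel count,
# which is why blazeface_fpn.yml sets FaceHead in_channels to [48, 48].
def neck_out_channels(in_channel, neck_type):
    out = list(in_channel)
    if neck_type == 'None':        # base model: features pass through untouched
        return out
    if 'fpn' in neck_type:         # only_fpn / fpn_ssh: lateral convs output out_channels // 2
        out = [c // 2 for c in out]
    # 'ssh' branches (only_ssh / fpn_ssh) keep the channel count: 1/2 + 1/4 + 1/4 concat
    return out

assert neck_out_channels([96, 96], 'fpn_ssh') == [48, 48]
assert neck_out_channels([96, 96], 'only_ssh') == [96, 96]
```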
diff --git a/configs/face_detection/_base_/blazeface_fpn.yml b/configs/face_detection/_base_/blazeface_fpn.yml
new file mode 100644
index 0000000000000000000000000000000000000000..6572a99d301eda65a65c485e133cc00497a2eee2
--- /dev/null
+++ b/configs/face_detection/_base_/blazeface_fpn.yml
@@ -0,0 +1,45 @@
+architecture: BlazeFace
+
+BlazeFace:
+  backbone: BlazeNet
+  neck: BlazeNeck
+  blaze_head: FaceHead
+  post_process: BBoxPostProcess
+
+BlazeNet:
+  blaze_filters: [[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]]
+  double_blaze_filters: [[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
+                         [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]]
+  act: hard_swish
+
+BlazeNeck:
+  neck_type: fpn_ssh
+  in_channel: [96,96]
+
+FaceHead:
+  in_channels: [48, 48]
+  anchor_generator: AnchorGeneratorSSD
+  loss: SSDLoss
+
+SSDLoss:
+  overlap_threshold: 0.35
+
+AnchorGeneratorSSD:
+  steps: [8., 16.]
+  aspect_ratios: [[1.], [1.]]
+  min_sizes: [[16.,24.], [32., 48., 64., 80., 96., 128.]]
+  max_sizes: [[], []]
+  offset: 0.5
+  flip: False
+  min_max_aspect_ratios_order: false
+
+BBoxPostProcess:
+  decode:
+    name: SSDBox
+  nms:
+    name: MultiClassNMS
+    keep_top_k: 750
+    score_threshold: 0.01
+    nms_threshold: 0.3
+    nms_top_k: 5000
+    nms_eta: 1.0
diff --git a/configs/face_detection/blazeface_fpn_ssh_1000e.yml b/configs/face_detection/blazeface_fpn_ssh_1000e.yml
new file mode 100644
index 0000000000000000000000000000000000000000..21dbd26443856710a5674f8e93e1cc0075836a38
--- /dev/null
+++ b/configs/face_detection/blazeface_fpn_ssh_1000e.yml
@@ -0,0 +1,9 @@
+_BASE_: [
+  '../datasets/wider_face.yml',
+  '../runtime.yml',
+  '_base_/optimizer_1000e.yml',
+  '_base_/blazeface_fpn.yml',
+  '_base_/face_reader.yml',
+]
+weights: output/blazeface_fpn_ssh_1000e/model_final
+multi_scale_eval: True
diff --git a/deploy/python/infer.py b/deploy/python/infer.py
index 99d5424c084cf56b12e93fbc14bcfad55180d499..94ca6c903b75fed7204cb02a9bddcfa9b83062fb 100644
--- a/deploy/python/infer.py
+++ b/deploy/python/infer.py
@@ -36,6 +36,7 @@ SUPPORT_MODELS = {
     'YOLO',
     'RCNN',
     'SSD',
+    'Face',
     'FCOS',
     'SOLOv2',
     'TTFNet',
@@ -113,14 +114,6 @@ class Detector(object):
                     threshold=0.5):
         # postprocess output of predictor
         results = {}
-        if self.pred_config.arch in ['Face']:
-            h, w = inputs['im_shape']
-            scale_y, scale_x = inputs['scale_factor']
-            w, h = float(h) / scale_y, float(w) / scale_x
-            np_boxes[:, 2] *= h
-            np_boxes[:, 3] *= w
-            np_boxes[:, 4] *= h
-            np_boxes[:, 5] *= w
         results['boxes'] = np_boxes
         results['boxes_num'] = np_boxes_num
         if np_masks is not None:
diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py
index 20109678ecef375abdebaccd398a8a43f3a34ac7..cef560d9d3cadf923fd0683ba5996a541be8fd65 100644
--- a/ppdet/engine/trainer.py
+++ b/ppdet/engine/trainer.py
@@ -433,7 +433,6 @@ class Trainer(object):
                     if 'segm' in batch_res else None
                 keypoint_res = batch_res['keypoint'][start:end] \
                     if 'keypoint' in batch_res else None
-
                 image = visualize_results(
                     image, bbox_res, mask_res, segm_res, keypoint_res,
                     int(im_id), catid2name, draw_threshold)
diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py
index 6b2a33ae57801caa618b4e4ec22d9872d0fe8d7a..2efcd1d0eaae7824a3e81e2eca2d5601533e821d 100644
--- a/ppdet/modeling/architectures/__init__.py
+++ b/ppdet/modeling/architectures/__init__.py
@@ -38,3 +38,4 @@ from .jde import *
 from .deepsort import *
 from .fairmot import *
 from .centernet import *
+from .blazeface import *
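With the new architecture registered in `ppdet/modeling/architectures/__init__.py`, the model can be built straight from the new config through ppdet's workspace machinery. A minimal sketch, assuming a checkout of this branch with `ppdet` importable; `load_config` and `create` are the standard workspace helpers used by the trainer:

```python
# Minimal sketch: instantiate BlazeFace-FPN-SSH from the new config via ppdet's registry.
from ppdet.core.workspace import load_config, create

cfg = load_config('configs/face_detection/blazeface_fpn_ssh_1000e.yml')
model = create(cfg.architecture)   # cfg.architecture resolves to 'BlazeFace'
print(type(model).__name__)        # -> BlazeFace
```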
diff --git a/ppdet/modeling/architectures/blazeface.py b/ppdet/modeling/architectures/blazeface.py
new file mode 100644
index 0000000000000000000000000000000000000000..af6aa269d1680a0990bcfa78875ef07ac62c5af0
--- /dev/null
+++ b/ppdet/modeling/architectures/blazeface.py
@@ -0,0 +1,91 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from ppdet.core.workspace import register, create
+from .meta_arch import BaseArch
+
+__all__ = ['BlazeFace']
+
+
+@register
+class BlazeFace(BaseArch):
+    """
+    BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs,
+               see https://arxiv.org/abs/1907.05047
+
+    Args:
+        backbone (nn.Layer): backbone instance
+        neck (nn.Layer): neck instance
+        blaze_head (nn.Layer): `blazeHead` instance
+        post_process (object): `BBoxPostProcess` instance
+    """
+
+    __category__ = 'architecture'
+    __inject__ = ['post_process']
+
+    def __init__(self, backbone, blaze_head, neck, post_process):
+        super(BlazeFace, self).__init__()
+        self.backbone = backbone
+        self.neck = neck
+        self.blaze_head = blaze_head
+        self.post_process = post_process
+
+    @classmethod
+    def from_config(cls, cfg, *args, **kwargs):
+        # backbone
+        backbone = create(cfg['backbone'])
+        # fpn
+        kwargs = {'input_shape': backbone.out_shape}
+        neck = create(cfg['neck'], **kwargs)
+        # head
+        kwargs = {'input_shape': neck.out_shape}
+        blaze_head = create(cfg['blaze_head'], **kwargs)
+
+        return {
+            'backbone': backbone,
+            'neck': neck,
+            'blaze_head': blaze_head,
+        }
+
+    def _forward(self):
+        # Backbone
+        body_feats = self.backbone(self.inputs)
+        # neck
+        neck_feats = self.neck(body_feats)
+        # blaze Head
+        if self.training:
+            return self.blaze_head(neck_feats, self.inputs['image'],
+                                   self.inputs['gt_bbox'],
+                                   self.inputs['gt_class'])
+        else:
+            preds, anchors = self.blaze_head(neck_feats, self.inputs['image'])
+            bbox, bbox_num = self.post_process(preds, anchors,
+                                               self.inputs['im_shape'],
+                                               self.inputs['scale_factor'])
+            return bbox, bbox_num
+
+    def get_loss(self, ):
+        return {"loss": self._forward()}
+
+    def get_pred(self):
+        bbox_pred, bbox_num = self._forward()
+        output = {
+            "bbox": bbox_pred,
+            "bbox_num": bbox_num,
+        }
+        return output
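The `from_config` hook above builds the three components in dependency order, feeding each stage's `out_shape` into the next stage's `input_shape`. A toy, self-contained mimic of that pattern (simplified stand-in classes, not ppdet code) to make the wiring explicit:

```python
# Toy mimic of the shape-chaining in BlazeFace.from_config (stand-in classes, not ppdet code).
class Backbone:
    out_shape = [96, 96]            # BlazeNet exposes two 96-channel feature maps

class Neck:
    def __init__(self, input_shape):
        self.out_shape = [c // 2 for c in input_shape]   # fpn_ssh halves the channels

class Head:
    def __init__(self, input_shape):
        self.in_channels = input_shape                   # FaceHead convs are built from this

backbone = Backbone()
neck = Neck(input_shape=backbone.out_shape)
head = Head(input_shape=neck.out_shape)
assert head.in_channels == [48, 48]                      # matches blazeface_fpn.yml
```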
diff --git a/ppdet/modeling/backbones/blazenet.py b/ppdet/modeling/backbones/blazenet.py
index 97134c28e718b7f3a4beebe11109dfe2c54a50df..d5f345e600ce6addfefd306aa6426ef64a4cde42 100644
--- a/ppdet/modeling/backbones/blazenet.py
+++ b/ppdet/modeling/backbones/blazenet.py
@@ -29,6 +29,10 @@ from ..shape_spec import ShapeSpec
 __all__ = ['BlazeNet']
 
 
+def hard_swish(x):
+    return x * F.relu6(x + 3) / 6.
+
+
 class ConvBNLayer(nn.Layer):
     def __init__(self,
                  in_channels,
@@ -80,6 +84,10 @@
             x = F.relu(x)
         elif self.act == "relu6":
             x = F.relu6(x)
+        elif self.act == 'leaky':
+            x = F.leaky_relu(x)
+        elif self.act == 'hard_swish':
+            x = hard_swish(x)
         return x
 
 
@@ -91,6 +99,7 @@ class BlazeBlock(nn.Layer):
                  double_channels=None,
                  stride=1,
                  use_5x5kernel=True,
+                 act='relu',
                  name=None):
         super(BlazeBlock, self).__init__()
         assert stride in [1, 2]
@@ -132,14 +141,14 @@
                         padding=1,
                         num_groups=out_channels1,
                         name=name + "1_dw_2")))
-        act = 'relu' if self.use_double_block else None
+        self.act = act if self.use_double_block else None
         self.conv_pw = ConvBNLayer(
             in_channels=out_channels1,
             out_channels=out_channels2,
             kernel_size=1,
             stride=1,
             padding=0,
-            act=act,
+            act=self.act,
             name=name + "1_sep")
         if self.use_double_block:
             self.conv_dw2 = []
@@ -237,7 +246,8 @@
                 blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]],
                 double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
                                       [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]],
-                 use_5x5kernel=True):
+                 use_5x5kernel=True,
+                 act=None):
         super(BlazeNet, self).__init__()
         conv1_num_filters = blaze_filters[0][0]
         self.conv1 = ConvBNLayer(
@@ -262,6 +272,7 @@
                             v[0],
                             v[1],
                             use_5x5kernel=use_5x5kernel,
+                            act=act,
                             name='blaze_{}'.format(k))))
             elif len(v) == 3:
                 self.blaze_block.append(
@@ -273,6 +284,7 @@
                             v[1],
                             stride=v[2],
                             use_5x5kernel=use_5x5kernel,
+                            act=act,
                             name='blaze_{}'.format(k))))
             in_channels = v[1]
 
@@ -289,6 +301,7 @@
                             v[1],
                             double_channels=v[2],
                             use_5x5kernel=use_5x5kernel,
+                            act=act,
                             name='double_blaze_{}'.format(k))))
             elif len(v) == 4:
                 self.blaze_block.append(
@@ -301,6 +314,7 @@
                             double_channels=v[2],
                             stride=v[3],
                             use_5x5kernel=use_5x5kernel,
+                            act=act,
                             name='double_blaze_{}'.format(k))))
             in_channels = v[2]
         self._out_channels.append(in_channels)
diff --git a/ppdet/modeling/heads/face_head.py b/ppdet/modeling/heads/face_head.py
index 937f30db3e41ee4ffba8eaf15abacbc835b390c9..83f34c26c1819fae3d4c18cf4c9952005788924d 100644
--- a/ppdet/modeling/heads/face_head.py
+++ b/ppdet/modeling/heads/face_head.py
@@ -41,7 +41,7 @@ class FaceHead(nn.Layer):
 
     def __init__(self,
                  num_classes=80,
-                 in_channels=(96, 96),
+                 in_channels=[96, 96],
                  anchor_generator=AnchorGeneratorSSD().__dict__,
                  kernel_size=3,
                  padding=1,
@@ -65,7 +65,7 @@
             box_conv = self.add_sublayer(
                 box_conv_name,
                 nn.Conv2D(
-                    in_channels=in_channels[i],
+                    in_channels=self.in_channels[i],
                     out_channels=num_prior * 4,
                     kernel_size=kernel_size,
                     padding=padding))
@@ -75,7 +75,7 @@
             score_conv = self.add_sublayer(
                 score_conv_name,
                 nn.Conv2D(
-                    in_channels=in_channels[i],
+                    in_channels=self.in_channels[i],
                     out_channels=num_prior * self.num_classes,
                     kernel_size=kernel_size,
                     padding=padding))
diff --git a/ppdet/modeling/necks/__init__.py b/ppdet/modeling/necks/__init__.py
index 6de12cffb3f0beb10de74597968cf157255377ce..7a7e3af40fc5914f0e90e0409c7dad05731d0426 100644
--- a/ppdet/modeling/necks/__init__.py
+++ b/ppdet/modeling/necks/__init__.py
@@ -23,3 +23,4 @@ from .yolo_fpn import *
 from .hrfpn import *
 from .ttf_fpn import *
 from .centernet_fpn import *
+from .blazeface_fpn import *
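Both the backbone and the new neck define `hard_swish` as `x * relu6(x + 3) / 6`. A quick NumPy check of what this piecewise-linear gate does (pure illustration, independent of Paddle):

```python
# NumPy illustration of hard_swish(x) = x * relu6(x + 3) / 6 as defined in blazenet.py.
import numpy as np

def hard_swish(x):
    return x * np.clip(x + 3, 0, 6) / 6.0   # relu6(y) == clip(y, 0, 6)

x = np.array([-4.0, -1.0, 0.0, 1.0, 4.0])
print(hard_swish(x))   # [-0.    -0.333  0.     0.667  4.   ]: identity for x >= 3, zero for x <= -3
```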
diff --git a/ppdet/modeling/necks/blazeface_fpn.py b/ppdet/modeling/necks/blazeface_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f712372a12372a94b88914abf3de5f76202ee7fc
--- /dev/null
+++ b/ppdet/modeling/necks/blazeface_fpn.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import math
+import paddle
+import paddle.nn.functional as F
+from paddle import ParamAttr
+import paddle.nn as nn
+from paddle.nn.initializer import KaimingNormal
+from ppdet.core.workspace import register, serializable
+from ppdet.modeling.layers import ConvNormLayer
+from ..shape_spec import ShapeSpec
+
+__all__ = ['BlazeNeck']
+
+
+def hard_swish(x):
+    return x * F.relu6(x + 3) / 6.
+
+
+class ConvBNLayer(nn.Layer):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride,
+                 padding,
+                 num_groups=1,
+                 act='relu',
+                 conv_lr=0.1,
+                 conv_decay=0.,
+                 norm_decay=0.,
+                 norm_type='bn',
+                 name=None):
+        super(ConvBNLayer, self).__init__()
+        self.act = act
+        self._conv = nn.Conv2D(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=padding,
+            groups=num_groups,
+            weight_attr=ParamAttr(
+                learning_rate=conv_lr,
+                initializer=KaimingNormal(),
+                name=name + "_weights"),
+            bias_attr=False)
+
+        param_attr = ParamAttr(name=name + "_bn_scale")
+        bias_attr = ParamAttr(name=name + "_bn_offset")
+        if norm_type == 'sync_bn':
+            self._batch_norm = nn.SyncBatchNorm(
+                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
+        else:
+            self._batch_norm = nn.BatchNorm(
+                out_channels,
+                act=None,
+                param_attr=param_attr,
+                bias_attr=bias_attr,
+                use_global_stats=False,
+                moving_mean_name=name + '_bn_mean',
+                moving_variance_name=name + '_bn_variance')
+
+    def forward(self, x):
+        x = self._conv(x)
+        x = self._batch_norm(x)
+        if self.act == "relu":
+            x = F.relu(x)
+        elif self.act == "relu6":
+            x = F.relu6(x)
+        elif self.act == 'leaky':
+            x = F.leaky_relu(x)
+        elif self.act == 'hard_swish':
+            x = hard_swish(x)
+        return x
+
+
+class FPN(nn.Layer):
+    def __init__(self, in_channels, out_channels, name=None):
+        super(FPN, self).__init__()
+        self.conv1_fpn = ConvBNLayer(
+            in_channels,
+            out_channels // 2,
+            kernel_size=1,
+            padding=0,
+            stride=1,
+            act='leaky',
+            name=name + '_output1')
+        self.conv2_fpn = ConvBNLayer(
+            in_channels,
+            out_channels // 2,
+            kernel_size=1,
+            padding=0,
+            stride=1,
+            act='leaky',
+            name=name + '_output2')
+        self.conv3_fpn = ConvBNLayer(
+            out_channels // 2,
+            out_channels // 2,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            act='leaky',
+            name=name + '_merge')
+
+    def forward(self, input):
+        output1 = self.conv1_fpn(input[0])
+        output2 = self.conv2_fpn(input[1])
+        up2 = F.upsample(
+            output2, size=paddle.shape(output1)[-2:], mode='nearest')
+        output1 = paddle.add(output1, up2)
+        output1 = self.conv3_fpn(output1)
+        return output1, output2
+
+
+class SSH(nn.Layer):
+    def __init__(self, in_channels, out_channels, name=None):
+        super(SSH, self).__init__()
+        assert out_channels % 4 == 0
+        self.conv0_ssh = ConvBNLayer(
+            in_channels,
+            out_channels // 2,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            act=None,
+            name=name + 'ssh_conv3')
+        self.conv1_ssh = ConvBNLayer(
+            out_channels // 2,
+            out_channels // 4,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            act='leaky',
+            name=name + 'ssh_conv5_1')
+        self.conv2_ssh = ConvBNLayer(
+            out_channels // 4,
+            out_channels // 4,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            act=None,
+            name=name + 'ssh_conv5_2')
+        self.conv3_ssh = ConvBNLayer(
+            out_channels // 4,
+            out_channels // 4,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            act='leaky',
+            name=name + 'ssh_conv7_1')
+        self.conv4_ssh = ConvBNLayer(
+            out_channels // 4,
+            out_channels // 4,
+            kernel_size=3,
+            padding=1,
+            stride=1,
+            act=None,
+            name=name + 'ssh_conv7_2')
+
+    def forward(self, x):
+        conv0 = self.conv0_ssh(x)
+        conv1 = self.conv1_ssh(conv0)
+        conv2 = self.conv2_ssh(conv1)
+        conv3 = self.conv3_ssh(conv2)
+        conv4 = self.conv4_ssh(conv3)
+        concat = paddle.concat([conv0, conv2, conv4], axis=1)
+        return F.relu(concat)
+
+
+@register
+@serializable
+class BlazeNeck(nn.Layer):
+    def __init__(self, in_channel, neck_type="None", data_format='NCHW'):
+        super(BlazeNeck, self).__init__()
+        self.neck_type = neck_type
+        self.return_input = False
+        self._out_channels = in_channel
+        if self.neck_type == 'None':
+            self.return_input = True
+        if "fpn" in self.neck_type:
+            self.fpn = FPN(self._out_channels[0],
+                           self._out_channels[1],
+                           name='fpn')
+            self._out_channels = [
+                self._out_channels[0] // 2, self._out_channels[1] // 2
+            ]
+        if "ssh" in self.neck_type:
+            self.ssh1 = SSH(self._out_channels[0],
+                            self._out_channels[0],
+                            name='ssh1')
+            self.ssh2 = SSH(self._out_channels[1],
+                            self._out_channels[1],
+                            name='ssh2')
+            self._out_channels = [self._out_channels[0], self._out_channels[1]]
+
+    def forward(self, inputs):
+        if self.return_input:
+            return inputs
+        output1, output2 = None, None
+        if "fpn" in self.neck_type:
+            backout_4, backout_1 = inputs
+            output1, output2 = self.fpn([backout_4, backout_1])
+        if self.neck_type == "only_fpn":
+            return [output1, output2]
+        if self.neck_type == "only_ssh":
+            output1, output2 = inputs
+        feature1 = self.ssh1(output1)
+        feature2 = self.ssh2(output2)
+        return [feature1, feature2]
+
+    @property
+    def out_shape(self):
+        return [
+            ShapeSpec(channels=c)
+            for c in [self._out_channels[0], self._out_channels[1]]
+        ]
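A small shape smoke test for the neck, sketched under the assumption that Paddle and this patched repo are installed; the tensor sizes are illustrative. With two 96-channel inputs and `neck_type='fpn_ssh'`, the FPN halves the channels and the SSH modules keep them, so both outputs should come back with 48 channels:

```python
# Sketch: check BlazeNeck output channels for neck_type='fpn_ssh' (illustrative sizes).
import paddle
from ppdet.modeling.necks.blazeface_fpn import BlazeNeck

neck = BlazeNeck(in_channel=[96, 96], neck_type='fpn_ssh')
feats = [paddle.rand([1, 96, 80, 80]), paddle.rand([1, 96, 40, 40])]
out = neck(feats)
print([o.shape[1] for o in out])                   # expected: [48, 48]
print([spec.channels for spec in neck.out_shape])  # expected: [48, 48], fed to FaceHead
```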
diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py
index 22bffff7e9130db2421d9ff9a68e8b893cc3f6e3..c74a053a03171e0479903b22197e356bb2ff49ba 100644
--- a/ppdet/utils/checkpoint.py
+++ b/ppdet/utils/checkpoint.py
@@ -162,7 +162,7 @@ def load_pretrain_weight(model, pretrain_weight):
     # hack: fit for faster rcnn. Pretrain weights contain prefix of 'backbone'
     # while res5 module is located in bbox_head.head. Replace the prefix of
     # res5 with 'bbox_head.head' to load pretrain weights correctly.
-    for k in param_state_dict.keys():
+    for k in list(param_state_dict.keys()):
         if 'backbone.res5' in k:
             new_k = k.replace('backbone', 'bbox_head.head')
             if new_k in model_dict.keys():
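The `list(...)` wrapper in this last hunk matters, presumably because the loop body goes on to rename or remove entries of `param_state_dict` (that part is outside the hunk's context): iterating over a live `dict.keys()` view while keys are added or removed raises a RuntimeError on modern CPython. A standalone toy illustration, not the real state dict:

```python
# Toy example of why the loop snapshots the keys with list() before mutating the dict.
d = {'backbone.res5.conv.w': 1, 'backbone.conv1.w': 2}

# Unsafe pattern (live view): renaming keys inside the loop raises
# "RuntimeError: dictionary keys changed during iteration" on CPython 3.8+.
# for k in d.keys():
#     if 'backbone.res5' in k:
#         d[k.replace('backbone', 'bbox_head.head')] = d.pop(k)

# Safe pattern, matching the patched load_pretrain_weight: snapshot the keys first.
for k in list(d.keys()):
    if 'backbone.res5' in k:
        d[k.replace('backbone', 'bbox_head.head')] = d.pop(k)

print(sorted(d))   # ['backbone.conv1.w', 'bbox_head.head.res5.conv.w']
```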