Unverified commit 01d905e3, authored by Guanghua Yu, committed by GitHub

add blazeface & update rcnn modelzoo (#2303)

* add blazeface & update rcnn modelzoo

* remove condition num_priors
Parent 1d1c5826
metric: WiderFace
num_classes: 1
TrainDataset:
  !WIDERFaceDataSet
    dataset_dir: dataset/wider_face
    anno_path: wider_face_split/wider_face_train_bbx_gt.txt
    image_dir: WIDER_train/images
    data_fields: ['image', 'gt_bbox', 'gt_class']
EvalDataset:
  !WIDERFaceDataSet
    dataset_dir: dataset/wider_face
    anno_path: wider_face_split/wider_face_val_bbx_gt.txt
    image_dir: WIDER_val/images
    data_fields: ['image']
TestDataset:
  !ImageFolder
    use_default_label: true
...@@ -7,10 +7,10 @@ ...@@ -7,10 +7,10 @@
| ResNet50-vd-FPN | Faster | c3-c5 | 1 | 2x | - | 43.7 | - | [Download](https://paddledet.bj.bcebos.com/models/faster_rcnn_dcn_r50_vd_fpn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/faster_rcnn_dcn_r50_vd_fpn_2x_coco.yml) | | ResNet50-vd-FPN | Faster | c3-c5 | 1 | 2x | - | 43.7 | - | [Download](https://paddledet.bj.bcebos.com/models/faster_rcnn_dcn_r50_vd_fpn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/faster_rcnn_dcn_r50_vd_fpn_2x_coco.yml) |
| ResNet101-vd-FPN | Faster | c3-c5 | 1 | 1x | - | 45.1 | - | [Download](https://paddledet.bj.bcebos.com/models/faster_rcnn_dcn_r101_vd_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/faster_rcnn_dcn_r101_vd_fpn_1x_coco.yml) | | ResNet101-vd-FPN | Faster | c3-c5 | 1 | 1x | - | 45.1 | - | [Download](https://paddledet.bj.bcebos.com/models/faster_rcnn_dcn_r101_vd_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/faster_rcnn_dcn_r101_vd_fpn_1x_coco.yml) |
| ResNeXt101-vd-FPN | Faster | c3-c5 | 1 | 1x | - | 46.5 | - | [Download](https://paddledet.bj.bcebos.com/models/faster_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/faster_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.yml) | | ResNeXt101-vd-FPN | Faster | c3-c5 | 1 | 1x | - | 46.5 | - | [Download](https://paddledet.bj.bcebos.com/models/faster_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/faster_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.yml) |
| ResNet50-FPN | Mask | c3-c5 | 1 | 1x | - | - | - | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_dcn_r50_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/mask_rcnn_dcn_r50_fpn_1x_coco.yml) | | ResNet50-FPN | Mask | c3-c5 | 1 | 1x | - | 42.7 | 38.4 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_dcn_r50_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/mask_rcnn_dcn_r50_fpn_1x_coco.yml) |
| ResNet50-vd-FPN | Mask | c3-c5 | 1 | 2x | - | - | - | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_dcn_r50_vd_fpn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/mask_rcnn_dcn_r50_vd_fpn_2x_coco.yml) | | ResNet50-vd-FPN | Mask | c3-c5 | 1 | 2x | - | 44.6 | 39.8 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_dcn_r50_vd_fpn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/mask_rcnn_dcn_r50_vd_fpn_2x_coco.yml) |
| ResNet101-vd-FPN | Mask | c3-c5 | 1 | 1x | - | 45.6 | 40.6 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_dcn_r101_vd_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/mask_rcnn_dcn_r101_vd_fpn_1x_coco.yml) | | ResNet101-vd-FPN | Mask | c3-c5 | 1 | 1x | - | 45.6 | 40.6 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_dcn_r101_vd_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/mask_rcnn_dcn_r101_vd_fpn_1x_coco.yml) |
| ResNeXt101-vd-FPN | Mask | c3-c5 | 1 | 1x | - | - | - | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/mask_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.yml) | | ResNeXt101-vd-FPN | Mask | c3-c5 | 1 | 1x | - | 47.3 | 42.0 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/mask_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.yml) |
| ResNet50-FPN | Cascade Faster | c3-c5 | 1 | 1x | - | 42.1 | - | [Download](https://paddledet.bj.bcebos.com/models/cascade_rcnn_dcn_r50_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/cascade_rcnn_dcn_r50_fpn_1x_coco.yml) | | ResNet50-FPN | Cascade Faster | c3-c5 | 1 | 1x | - | 42.1 | - | [Download](https://paddledet.bj.bcebos.com/models/cascade_rcnn_dcn_r50_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/cascade_rcnn_dcn_r50_fpn_1x_coco.yml) |
| ResNeXt101-vd-FPN | Cascade Faster | c3-c5 | 1 | 1x | - | 48.8 | - | [Download](https://paddledet.bj.bcebos.com/models/cascade_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/cascade_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.yml) | | ResNeXt101-vd-FPN | Cascade Faster | c3-c5 | 1 | 1x | - | 48.8 | - | [Download](https://paddledet.bj.bcebos.com/models/cascade_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/dcn/cascade_rcnn_dcn_x101_vd_64x4d_fpn_1x_coco.yml) |
......
# Face Detection Models
## Introduction
`face_detection` provides efficient, high-speed face detection solutions, including state-of-the-art and classic models.
![](../../docs/images/12_Group_Group_12_Group_Group_12_935.jpg)
## Model Zoo
#### mAP on the WIDER-FACE dataset
| Architecture | Input size | Images/GPU | LR schedule | Easy/Medium/Hard Set | Inference latency (SD855) | Model size (MB) | Download | Config |
|:------------:|:--------:|:----:|:-------:|:-------:|:---------:|:----------:|:---------:|:--------:|
| BlazeFace | 640 | 8 | 1000e | 0.889 / 0.859 / 0.740 | - | 0.472 |[Download](https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/face_detection/blazeface_1000e.yml) |
**Note:**
- The mAP values on the `Easy/Medium/Hard Set` are obtained with a multi-scale evaluation strategy. For details, see the section on evaluating on the WIDER-FACE dataset below.
## Quick Start
### Data preparation
We use the [WIDER-FACE dataset](http://shuoyang1213.me/WIDERFACE/) for training and testing; the official website describes the data in detail.
- WIDER-Face data source:
Datasets of type `wider_face` are loaded from the following directory structure:
```
dataset/wider_face/
├── wider_face_split
│ ├── wider_face_train_bbx_gt.txt
│ ├── wider_face_val_bbx_gt.txt
├── WIDER_train
│ ├── images
│ │ ├── 0--Parade
│ │ │ ├── 0_Parade_marchingband_1_100.jpg
│ │ │ ├── 0_Parade_marchingband_1_381.jpg
│ │ │ │ ...
│ │ ├── 10--People_Marching
│ │ │ ...
├── WIDER_val
│ ├── images
│ │ ├── 0--Parade
│ │ │ ├── 0_Parade_marchingband_1_1004.jpg
│ │ │ ├── 0_Parade_marchingband_1_1045.jpg
│ │ │ │ ...
│ │ ├── 10--People_Marching
│ │ │ ...
```
- Download the dataset manually:
To download the WIDER-FACE dataset, run the following command:
```
cd dataset/wider_face && ./download_wider_face.sh
```
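Before training, it can help to confirm that the download produced the layout shown above. The following is a minimal sketch, not part of the repository: `missing_entries` and `EXPECTED` are hypothetical names, and the checked paths are simply the ones listed in the directory tree.

```python
import os

# Relative paths taken from the directory tree shown above.
EXPECTED = [
    "wider_face_split/wider_face_train_bbx_gt.txt",
    "wider_face_split/wider_face_val_bbx_gt.txt",
    "WIDER_train/images",
    "WIDER_val/images",
]


def missing_entries(root="dataset/wider_face"):
    """Return the expected files or directories that are absent under `root`."""
    return [p for p in EXPECTED if not os.path.exists(os.path.join(root, p))]


# An empty list means every expected entry was found.
print(missing_entries())
```

An empty result only shows that the top-level entries exist; it does not verify that every image was extracted.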
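### Training and evaluation
Training and evaluation follow the same workflow as the other algorithms; see [GETTING_STARTED_cn.md](../../docs/tutorials/GETTING_STARTED_cn.md).
**Note:**
- Face detection models currently do not support evaluation during training.
#### Evaluation on the WIDER-FACE dataset
Run evaluation and generate the result files: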
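```shell
python -u tools/eval.py -c configs/face_detection/blazeface_1000e.yml \
       -o weights=output/blazeface_1000e/model_final \
       multi_scale_eval=True
```
Setting `multi_scale_eval=True` enables multi-scale evaluation; this is the key that the trainer reads and that `blazeface_1000e.yml` sets. When evaluation finishes, the test results are written in txt format to `output/pred`.
- Download the official evaluation script to compute the AP metrics: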
```
wget http://mmlab.ie.cuhk.edu.hk/projects/WIDERFace/support/eval_script/eval_tools.zip
unzip eval_tools.zip && rm -f eval_tools.zip
```
- In `eval_tools/wider_eval.m`, set the path where the results are stored and the name of the curve to plot:
```
% Modify the folder name where the result is stored.
pred_dir = './pred';
% Modify the name of the curve to be drawn
legend_name = 'Fluid-BlazeFace';
```
- `wider_eval.m` is the main entry point of the evaluation module. Run it as follows:
```
matlab -nodesktop -nosplash -nojvm -r "run wider_eval.m;quit;"
```
## Citations
```
@article{bazarevsky2019blazeface,
title={BlazeFace: Sub-millisecond Neural Face Detection on Mobile GPUs},
author={Valentin Bazarevsky and Yury Kartynnik and Andrey Vakunov and Karthik Raveendran and Matthias Grundmann},
year={2019},
eprint={1907.05047},
archivePrefix={arXiv}
}
```
architecture: SSD
SSD:
  backbone: BlazeNet
  ssd_head: FaceHead
  post_process: BBoxPostProcess
BlazeNet:
  blaze_filters: [[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]]
  double_blaze_filters: [[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
                         [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]]
FaceHead:
  in_channels: [96, 96]
  anchor_generator: AnchorGeneratorSSD
  loss: SSDLoss
SSDLoss:
  overlap_threshold: 0.35
  neg_overlap: 0.35
AnchorGeneratorSSD:
  steps: [8., 16.]
  aspect_ratios: [[1.], [1.]]
  min_sizes: [[16.,24.], [32., 48., 64., 80., 96., 128.]]
  max_sizes: [[], []]
  offset: 0.5
  flip: False
  min_max_aspect_ratios_order: false
BBoxPostProcess:
  decode:
    name: SSDBox
  nms:
    name: MultiClassNMS
    keep_top_k: 750
    score_threshold: 0.01
    nms_threshold: 0.3
    nms_top_k: 5000
    nms_eta: 1.0
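To get a feel for the `AnchorGeneratorSSD` settings above, the sketch below counts the priors they would produce for a 640x640 input. It rests on two assumptions that the config itself does not state: each feature map has side length `input_size / step`, and with a single aspect ratio of 1 and empty `max_sizes` every cell gets exactly one prior per entry in `min_sizes`. `count_priors` is a hypothetical helper, not a PaddleDetection function.

```python
# Values copied from the AnchorGeneratorSSD section of the config above.
steps = [8.0, 16.0]
min_sizes = [[16.0, 24.0], [32.0, 48.0, 64.0, 80.0, 96.0, 128.0]]
input_size = 640  # assumption: square 640x640 input, as in the training reader


def count_priors(input_size, steps, min_sizes):
    """Count priors, assuming one box per min_size at every feature-map cell."""
    per_level = []
    for step, sizes in zip(steps, min_sizes):
        cells = int(input_size // step)  # feature-map side length
        per_level.append(cells * cells * len(sizes))
    return per_level, sum(per_level)


per_level, total = count_priors(input_size, steps, min_sizes)
print(per_level, total)  # [12800, 9600] 22400
```

Under these assumptions the stride-8 map contributes 80 x 80 x 2 priors and the stride-16 map 40 x 40 x 6, which is why `nms_top_k: 5000` keeps only a fraction of the candidates before NMS.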
worker_num: 2
TrainReader:
  inputs_def:
    num_max_boxes: 90
  sample_transforms:
    - Decode: {}
    - RandomDistort: {brightness: [0.5, 1.125, 0.875], random_apply: False}
    - RandomExpand: {fill_value: [123.675, 116.28, 103.53]}
    - RandomFlip: {}
    - CropWithDataAchorSampling: {
      anchor_sampler: [[1, 10, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.2, 0.0]],
      batch_sampler: [
        [1, 50, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
        [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
        [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
        [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
        [1, 50, 0.3, 1.0, 1.0, 1.0, 0.0, 0.0, 1.0, 0.0],
      ],
      target_size: 640}
    - Resize: {target_size: [640, 640], keep_ratio: False, interp: 1}
    - NormalizeBox: {}
    - PadBox: {num_max_boxes: 90}
  batch_transforms:
    - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
    - Permute: {}
  batch_size: 8
  shuffle: true
  drop_last: true
EvalReader:
  sample_transforms:
    - Decode: {}
    - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
    - Permute: {}
  batch_size: 1
  drop_empty: false
TestReader:
  sample_transforms:
    - Decode: {}
    - NormalizeImage: {mean: [123, 117, 104], std: [127.502231, 127.502231, 127.502231], is_scale: false}
    - Permute: {}
  batch_size: 1
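All three readers end with the same `NormalizeImage` and `Permute` steps. The NumPy sketch below shows what that pair plausibly computes, assuming `is_scale: false` means pixel values are not divided by 255 first and that `Permute` reorders HWC to CHW. `normalize_and_permute` is a hypothetical name; the actual operators are not shown in this commit.

```python
import numpy as np

# mean/std copied from the NormalizeImage entries in the reader config above.
MEAN = np.array([123, 117, 104], dtype=np.float32)
STD = np.array([127.502231] * 3, dtype=np.float32)


def normalize_and_permute(image_hwc):
    """Apply (x - mean) / std per channel, then reorder HWC -> CHW."""
    x = image_hwc.astype(np.float32)
    # is_scale: false -> values stay in the 0..255 range before normalization
    x = (x - MEAN) / STD
    return x.transpose(2, 0, 1)


image = np.full((4, 6, 3), 255, dtype=np.uint8)  # dummy all-white image
out = normalize_and_permute(image)
print(out.shape)  # (3, 4, 6)
```

With these constants a 0..255 image is mapped to values of roughly unit magnitude, centered near zero.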
epoch: 1000
LearningRate:
  base_lr: 0.001
  schedulers:
  - !PiecewiseDecay
    gamma: 0.1
    milestones:
    - 333
    - 800
  - !LinearWarmup
    start_factor: 0.3333333333333333
    steps: 500
OptimizerBuilder:
  optimizer:
    momentum: 0.0
    type: RMSProp
  regularizer:
    factor: 0.0005
    type: L2
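The schedule above combines a piecewise decay with a linear warmup. The sketch below is one plausible reading of it, assuming milestones are counted in epochs, warmup in iterations, and that the warmup factor ramps linearly from `start_factor` to 1 and multiplies the decayed rate. How the framework actually composes the two schedulers is not shown in this commit, so `learning_rate` is a hypothetical illustration only.

```python
# Constants copied from the LearningRate section above.
BASE_LR = 0.001
GAMMA = 0.1
MILESTONES = [333, 800]   # assumed to be epochs
WARMUP_STEPS = 500        # assumed to be iterations
START_FACTOR = 1.0 / 3.0


def learning_rate(epoch, global_step):
    """Sketch: piecewise decay by epoch, scaled by a linear warmup factor."""
    decays = sum(1 for m in MILESTONES if epoch >= m)
    lr = BASE_LR * (GAMMA ** decays)
    if global_step < WARMUP_STEPS:
        alpha = global_step / WARMUP_STEPS
        lr *= START_FACTOR * (1 - alpha) + alpha
    return lr


for epoch, step in [(0, 0), (0, 500), (333, 10**6), (800, 10**6)]:
    print(epoch, step, learning_rate(epoch, step))
```

Under this reading the rate starts at one third of `base_lr`, reaches 0.001 after 500 iterations, and drops by a factor of 10 at epochs 333 and 800.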
_BASE_: [
'../datasets/wider_face.yml',
'../runtime.yml',
'_base_/optimizer_1000e.yml',
'_base_/blazeface.yml',
'_base_/face_reader.yml',
]
weights: output/blazeface_1000e/model_final
multi_scale_eval: True
...@@ -6,6 +6,9 @@ ...@@ -6,6 +6,9 @@
| :------------- | :------------- | :-----------: | :------: | :--------: |:-----: | :-----: | :----: | :----: | | :------------- | :------------- | :-----------: | :------: | :--------: |:-----: | :-----: | :----: | :----: |
| ResNet50-FPN | Faster | 1 | 2x | - | 41.9 | - | [Download](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_fpn_gn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/gn/faster_rcnn_r50_fpn_gn_2x_coco.yml) | | ResNet50-FPN | Faster | 1 | 2x | - | 41.9 | - | [Download](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_fpn_gn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/gn/faster_rcnn_r50_fpn_gn_2x_coco.yml) |
| ResNet50-FPN | Mask | 1 | 2x | - | 42.3 | 38.4 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_fpn_gn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/gn/mask_rcnn_r50_fpn_gn_2x_coco.yml) | | ResNet50-FPN | Mask | 1 | 2x | - | 42.3 | 38.4 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_fpn_gn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/gn/mask_rcnn_r50_fpn_gn_2x_coco.yml) |
| ResNet50-FPN | Cascade Faster | 1 | 2x | - | - | - | [Download](https://paddledet.bj.bcebos.com/models/cascade_rcnn_r50_fpn_gn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/gn/cascade_rcnn_r50_fpn_gn_2x_coco.yml) |
| ResNet50-FPN | Cascade Mask | 1 | 2x | - | 45.0 | 39.3 | [Download](https://paddledet.bj.bcebos.com/models/cascade_mask_rcnn_r50_fpn_gn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/gn/cascade_mask_rcnn_r50_fpn_gn_2x_coco.yml) |
**Note:** The Faster R-CNN baseline uses only a `2fc` head, whereas a [`4conv1fc` head](https://arxiv.org/abs/1803.08494) is used here (with GN between the 4 conv layers), and the FPN also uses GN; for Mask R-CNN, GN is additionally used between the 4 conv layers of the mask head. **Note:** The Faster R-CNN baseline uses only a `2fc` head, whereas a [`4conv1fc` head](https://arxiv.org/abs/1803.08494) is used here (with GN between the 4 conv layers), and the FPN also uses GN; for Mask R-CNN, GN is additionally used between the 4 conv layers of the mask head.
......
...@@ -5,7 +5,7 @@ _BASE_: [ ...@@ -5,7 +5,7 @@ _BASE_: [
'../cascade_rcnn/_base_/cascade_mask_rcnn_r50_fpn.yml', '../cascade_rcnn/_base_/cascade_mask_rcnn_r50_fpn.yml',
'../cascade_rcnn/_base_/cascade_mask_fpn_reader.yml', '../cascade_rcnn/_base_/cascade_mask_fpn_reader.yml',
] ]
weights: output/cascade_mask_rcnn_r50_fpn_gn_2x/model_final weights: output/cascade_mask_rcnn_r50_fpn_gn_2x_coco/model_final
CascadeRCNN: CascadeRCNN:
backbone: ResNet backbone: ResNet
......
...@@ -5,7 +5,7 @@ _BASE_: [ ...@@ -5,7 +5,7 @@ _BASE_: [
'../cascade_rcnn/_base_/cascade_rcnn_r50_fpn.yml', '../cascade_rcnn/_base_/cascade_rcnn_r50_fpn.yml',
'../cascade_rcnn/_base_/cascade_fpn_reader.yml', '../cascade_rcnn/_base_/cascade_fpn_reader.yml',
] ]
weights: output/cascade_rcnn_r50_fpn_gn_2x/model_final weights: output/cascade_rcnn_r50_fpn_gn_2x_coco/model_final
FPN: FPN:
out_channel: 256 out_channel: 256
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
| ResNet101-FPN | Mask | 1 | 1x | ---- | 40.6 | 36.6 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/mask_rcnn/mask_rcnn_r101_fpn_1x_coco.yml) | | ResNet101-FPN | Mask | 1 | 1x | ---- | 40.6 | 36.6 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/mask_rcnn/mask_rcnn_r101_fpn_1x_coco.yml) |
| ResNet101-vd-FPN | Mask | 1 | 1x | ---- | 42.4 | 38.1 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_vd_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/mask_rcnn/mask_rcnn_r101_vd_fpn_1x_coco.yml) | | ResNet101-vd-FPN | Mask | 1 | 1x | ---- | 42.4 | 38.1 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_r101_vd_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/mask_rcnn/mask_rcnn_r101_vd_fpn_1x_coco.yml) |
| ResNeXt101-vd-FPN | Mask | 1 | 1x | ---- | 44.0 | 39.5 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_x101_vd_64x4d_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/mask_rcnn/mask_rcnn_x101_vd_64x4d_fpn_1x_coco.yml) | | ResNeXt101-vd-FPN | Mask | 1 | 1x | ---- | 44.0 | 39.5 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_x101_vd_64x4d_fpn_1x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/mask_rcnn/mask_rcnn_x101_vd_64x4d_fpn_1x_coco.yml) |
| ResNeXt101-vd-FPN | Mask | 1 | 2x | ---- | - | - | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_x101_vd_64x4d_fpn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/mask_rcnn/mask_rcnn_x101_vd_64x4d_fpn_2x_coco.yml) | | ResNeXt101-vd-FPN | Mask | 1 | 2x | ---- | 44.6 | 39.8 | [Download](https://paddledet.bj.bcebos.com/models/mask_rcnn_x101_vd_64x4d_fpn_2x_coco.pdparams) | [Config](https://github.com/PaddlePaddle/PaddleDetection/tree/master/dygraph/configs/mask_rcnn/mask_rcnn_x101_vd_64x4d_fpn_2x_coco.yml) |
**Note:** The accuracy of the Mask R-CNN models depends on changes in the Paddle develop branch. To reproduce it, use the [nightly build](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev) or version 2.0.1 (to be released in 2021.03); Paddle 2.0.0 gives slightly lower accuracy. **Note:** The accuracy of the Mask R-CNN models depends on changes in the Paddle develop branch. To reproduce it, use the [nightly build](https://www.paddlepaddle.org.cn/documentation/docs/zh/install/Tables.html#whl-dev) or version 2.0.1 (to be released in 2021.03); Paddle 2.0.0 gives slightly lower accuracy.
......
# All rights `PaddleDetection` reserved
# References:
# @inproceedings{yang2016wider,
# Author = {Yang, Shuo and Luo, Ping and Loy, Chen Change and Tang, Xiaoou},
# Booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
# Title = {WIDER FACE: A Face Detection Benchmark},
# Year = {2016}}
DIR="$( cd "$(dirname "$0")" ; pwd -P )"
cd "$DIR"
# Download the data.
echo "Downloading..."
wget https://dataset.bj.bcebos.com/wider_face/WIDER_train.zip
wget https://dataset.bj.bcebos.com/wider_face/WIDER_val.zip
wget https://dataset.bj.bcebos.com/wider_face/wider_face_split.zip
# Extract the data.
echo "Extracting..."
unzip -q WIDER_train.zip
unzip -q WIDER_val.zip
unzip -q wider_face_split.zip
...@@ -16,7 +16,7 @@ import os ...@@ -16,7 +16,7 @@ import os
import numpy as np import numpy as np
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from .dataset import DataSet from .dataset import DetDataset
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
...@@ -24,7 +24,7 @@ logger = setup_logger(__name__) ...@@ -24,7 +24,7 @@ logger = setup_logger(__name__)
@register @register
@serializable @serializable
class WIDERFaceDataSet(DataSet): class WIDERFaceDataSet(DetDataset):
""" """
Load WiderFace records with 'anno_path' Load WiderFace records with 'anno_path'
...@@ -39,20 +39,23 @@ class WIDERFaceDataSet(DataSet): ...@@ -39,20 +39,23 @@ class WIDERFaceDataSet(DataSet):
dataset_dir=None, dataset_dir=None,
image_dir=None, image_dir=None,
anno_path=None, anno_path=None,
data_fields=['image'],
sample_num=-1, sample_num=-1,
with_lmk=False): with_lmk=False):
super(WIDERFaceDataSet, self).__init__( super(WIDERFaceDataSet, self).__init__(
dataset_dir=dataset_dir,
image_dir=image_dir, image_dir=image_dir,
anno_path=anno_path, anno_path=anno_path,
data_fields=data_fields,
sample_num=sample_num, sample_num=sample_num,
dataset_dir=dataset_dir) with_lmk=with_lmk)
self.anno_path = anno_path self.anno_path = anno_path
self.sample_num = sample_num self.sample_num = sample_num
self.roidbs = None self.roidbs = None
self.cname2cid = None self.cname2cid = None
self.with_lmk = with_lmk self.with_lmk = with_lmk
def load_roidb_and_cname2cid(self, ): def parse_dataset(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path) anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir) image_dir = os.path.join(self.dataset_dir, self.image_dir)
...@@ -67,7 +70,7 @@ class WIDERFaceDataSet(DataSet): ...@@ -67,7 +70,7 @@ class WIDERFaceDataSet(DataSet):
im_fname = item[0] im_fname = item[0]
im_id = np.array([ct]) im_id = np.array([ct])
gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32) gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32)
gt_class = np.ones((len(item) - 1, 1), dtype=np.int32) gt_class = np.zeros((len(item) - 1, 1), dtype=np.int32)
gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32) gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32)
lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32) lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32)
for index_box in range(len(item)): for index_box in range(len(item)):
...@@ -82,9 +85,14 @@ class WIDERFaceDataSet(DataSet): ...@@ -82,9 +85,14 @@ class WIDERFaceDataSet(DataSet):
widerface_rec = { widerface_rec = {
'im_file': im_fname, 'im_file': im_fname,
'im_id': im_id, 'im_id': im_id,
} if 'image' in self.data_fields else {}
gt_rec = {
'gt_bbox': gt_bbox, 'gt_bbox': gt_bbox,
'gt_class': gt_class, 'gt_class': gt_class,
} }
for k, v in gt_rec.items():
if k in self.data_fields:
widerface_rec[k] = v
if self.with_lmk: if self.with_lmk:
widerface_rec['gt_keypoint'] = gt_lmk_labels widerface_rec['gt_keypoint'] = gt_lmk_labels
widerface_rec['keypoint_ignore'] = lmk_ignore_flag widerface_rec['keypoint_ignore'] = lmk_ignore_flag
...@@ -105,18 +113,24 @@ class WIDERFaceDataSet(DataSet): ...@@ -105,18 +113,24 @@ class WIDERFaceDataSet(DataSet):
file_dict = {} file_dict = {}
num_class = 0 num_class = 0
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for i in range(len(lines_input_txt)): for i in range(len(lines_input_txt)):
line_txt = lines_input_txt[i].strip('\n\t\r') line_txt = lines_input_txt[i].strip('\n\t\r')
if '.jpg' in line_txt: split_str = line_txt.split(' ')
if i != 0: if len(split_str) == 1:
num_class += 1 img_file_name = os.path.split(split_str[0])[1]
file_dict[num_class] = [] split_txt = img_file_name.split('.')
file_dict[num_class].append(line_txt) if len(split_txt) < 2:
if '.jpg' not in line_txt: continue
elif split_txt[-1] in exts:
if i != 0:
num_class += 1
file_dict[num_class] = [line_txt]
else:
if len(line_txt) <= 6: if len(line_txt) <= 6:
continue continue
result_boxs = [] result_boxs = []
split_str = line_txt.split(' ')
xmin = float(split_str[0]) xmin = float(split_str[0])
ymin = float(split_str[1]) ymin = float(split_str[1])
w = float(split_str[2]) w = float(split_str[2])
......
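The parsing change above groups box lines under the image line that precedes them and recognizes image lines by file extension rather than by the substring `.jpg`. The standalone sketch below illustrates the same idea on an in-memory list of lines. `parse_wider_annotations` is a hypothetical helper, the conversion from `x y w h` to corner format and the skipping of non-positive sizes are assumptions made for illustration, and the sample file name is invented.

```python
import os

IMAGE_EXTS = {"jpg", "jpeg", "png", "bmp"}


def parse_wider_annotations(lines):
    """Group 'x y w h ...' box lines under the image line that precedes them."""
    records = []
    for raw in lines:
        fields = raw.strip().split()
        if not fields:
            continue
        if len(fields) == 1:
            ext = os.path.splitext(fields[0])[1].lstrip(".").lower()
            if ext in IMAGE_EXTS:
                records.append({"im_file": fields[0], "boxes": []})
            # Any other single-token line is treated as the face-count line.
            continue
        if not records or len(fields) < 4:
            continue
        x, y, w, h = (float(v) for v in fields[:4])
        if w <= 0 or h <= 0:
            continue  # assumption: skip degenerate boxes
        records[-1]["boxes"].append([x, y, x + w, y + h])
    return records


sample = [
    "0--Parade/example.jpg",   # hypothetical image path
    "1",
    "10 20 30 40 0 0 0 0 0 0",
]
print(parse_wider_annotations(sample))
```

Matching on the extension keeps the parser working for `.png` or upper-case `.JPG` entries, which the earlier `'.jpg' in line_txt` test would have misclassified.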
...@@ -994,7 +994,7 @@ class CropWithDataAchorSampling(BaseOperator): ...@@ -994,7 +994,7 @@ class CropWithDataAchorSampling(BaseOperator):
[max sample, max trial, min scale, max scale, [max sample, max trial, min scale, max scale,
min aspect ratio, max aspect ratio, min aspect ratio, max aspect ratio,
min overlap, max overlap, min coverage, max coverage] min overlap, max overlap, min coverage, max coverage]
target_size (bool): target image size. target_size (int): target image size.
das_anchor_scales (list[float]): a list of anchor scales in data das_anchor_scales (list[float]): a list of anchor scales in data
anchor sampling. anchor sampling.
min_size (float): minimum size of sampled bbox. min_size (float): minimum size of sampled bbox.
...@@ -1026,6 +1026,10 @@ class CropWithDataAchorSampling(BaseOperator): ...@@ -1026,6 +1026,10 @@ class CropWithDataAchorSampling(BaseOperator):
gt_bbox = sample['gt_bbox'] gt_bbox = sample['gt_bbox']
gt_class = sample['gt_class'] gt_class = sample['gt_class']
image_height, image_width = im.shape[:2] image_height, image_width = im.shape[:2]
gt_bbox[:, 0] /= image_width
gt_bbox[:, 1] /= image_height
gt_bbox[:, 2] /= image_width
gt_bbox[:, 3] /= image_height
gt_score = None gt_score = None
if 'gt_score' in sample: if 'gt_score' in sample:
gt_score = sample['gt_score'] gt_score = sample['gt_score']
...@@ -1073,10 +1077,16 @@ class CropWithDataAchorSampling(BaseOperator): ...@@ -1073,10 +1077,16 @@ class CropWithDataAchorSampling(BaseOperator):
continue continue
im = crop_image_sampling(im, sample_bbox, image_width, im = crop_image_sampling(im, sample_bbox, image_width,
image_height, self.target_size) image_height, self.target_size)
height, width = im.shape[:2]
crop_bbox[:, 0] *= width
crop_bbox[:, 1] *= height
crop_bbox[:, 2] *= width
crop_bbox[:, 3] *= height
sample['image'] = im sample['image'] = im
sample['gt_bbox'] = crop_bbox sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class sample['gt_class'] = crop_class
sample['gt_score'] = crop_score if 'gt_score' in sample:
sample['gt_score'] = crop_score
if 'gt_keypoint' in sample.keys(): if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = gt_keypoints[0] sample['gt_keypoint'] = gt_keypoints[0]
sample['keypoint_ignore'] = gt_keypoints[1] sample['keypoint_ignore'] = gt_keypoints[1]
...@@ -1124,10 +1134,16 @@ class CropWithDataAchorSampling(BaseOperator): ...@@ -1124,10 +1134,16 @@ class CropWithDataAchorSampling(BaseOperator):
ymin = int(sample_bbox[1] * image_height) ymin = int(sample_bbox[1] * image_height)
ymax = int(sample_bbox[3] * image_height) ymax = int(sample_bbox[3] * image_height)
im = im[ymin:ymax, xmin:xmax] im = im[ymin:ymax, xmin:xmax]
height, width = im.shape[:2]
crop_bbox[:, 0] *= width
crop_bbox[:, 1] *= height
crop_bbox[:, 2] *= width
crop_bbox[:, 3] *= height
sample['image'] = im sample['image'] = im
sample['gt_bbox'] = crop_bbox sample['gt_bbox'] = crop_bbox
sample['gt_class'] = crop_class sample['gt_class'] = crop_class
sample['gt_score'] = crop_score if 'gt_score' in sample:
sample['gt_score'] = crop_score
if 'gt_keypoint' in sample.keys(): if 'gt_keypoint' in sample.keys():
sample['gt_keypoint'] = gt_keypoints[0] sample['gt_keypoint'] = gt_keypoints[0]
sample['keypoint_ignore'] = gt_keypoints[1] sample['keypoint_ignore'] = gt_keypoints[1]
......
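The lines added to `CropWithDataAchorSampling` divide the ground-truth boxes by the source image size before sampling and multiply them by the cropped image size afterwards. The sketch below shows that round trip in isolation, assuming boxes are stored as `[xmin, ymin, xmax, ymax]`. `normalize_boxes` and `denormalize_boxes` are hypothetical helpers that return new arrays, whereas the diff updates `gt_bbox` and `crop_bbox` in place.

```python
import numpy as np


def normalize_boxes(boxes, width, height):
    """Scale [xmin, ymin, xmax, ymax] pixel boxes into [0, 1] coordinates."""
    scale = np.array([width, height, width, height], dtype=np.float32)
    return boxes.astype(np.float32) / scale


def denormalize_boxes(boxes, width, height):
    """Scale normalized boxes back to pixels for an image of the given size."""
    scale = np.array([width, height, width, height], dtype=np.float32)
    return boxes.astype(np.float32) * scale


gt_bbox = np.array([[64, 48, 320, 240]], dtype=np.float32)
norm = normalize_boxes(gt_bbox, width=640, height=480)
print(norm)
# After cropping, the boxes are rescaled with the size of the cropped image.
print(denormalize_boxes(norm, width=320, height=240))
```

Returning copies avoids mutating the sample's array and also works when the input boxes have an integer dtype, which an in-place `/=` would reject.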
...@@ -17,6 +17,7 @@ from __future__ import division ...@@ -17,6 +17,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import sys
import datetime import datetime
import paddle import paddle
...@@ -169,3 +170,15 @@ class Checkpointer(Callback): ...@@ -169,3 +170,15 @@ class Checkpointer(Callback):
else: else:
save_model(self.model.model, self.model.optimizer, save_dir, save_model(self.model.model, self.model.optimizer, save_dir,
save_name, epoch_id + 1) save_name, epoch_id + 1)
class WiferFaceEval(Callback):
    def __init__(self, model):
        super(WiferFaceEval, self).__init__(model)

    def on_epoch_begin(self, status):
        assert self.model.mode == 'eval', \
            "WiferFaceEval can only be set during evaluation"
        for metric in self.model._metrics:
            metric.update(self.model.model)
        sys.exit()
...@@ -31,10 +31,10 @@ from paddle.static import InputSpec ...@@ -31,10 +31,10 @@ from paddle.static import InputSpec
from ppdet.core.workspace import create from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results from ppdet.utils.visualizer import visualize_results
from ppdet.metrics import Metric, COCOMetric, VOCMetric, get_categories, get_infer_results from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_categories, get_infer_results
import ppdet.utils.stats as stats import ppdet.utils.stats as stats
from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval
from .export_utils import _dump_infer_config from .export_utils import _dump_infer_config
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
...@@ -90,8 +90,6 @@ class Trainer(object): ...@@ -90,8 +90,6 @@ class Trainer(object):
self.start_epoch = 0 self.start_epoch = 0
self.end_epoch = cfg.epoch self.end_epoch = cfg.epoch
self._weights_loaded = False
# initial default callbacks # initial default callbacks
self._init_callbacks() self._init_callbacks()
...@@ -105,6 +103,8 @@ class Trainer(object): ...@@ -105,6 +103,8 @@ class Trainer(object):
self._compose_callback = ComposeCallback(self._callbacks) self._compose_callback = ComposeCallback(self._callbacks)
elif self.mode == 'eval': elif self.mode == 'eval':
self._callbacks = [LogPrinter(self)] self._callbacks = [LogPrinter(self)]
if self.cfg.metric == 'WiderFace':
self._callbacks.append(WiferFaceEval(self))
self._compose_callback = ComposeCallback(self._callbacks) self._compose_callback = ComposeCallback(self._callbacks)
else: else:
self._callbacks = [] self._callbacks = []
...@@ -128,6 +128,15 @@ class Trainer(object): ...@@ -128,6 +128,15 @@ class Trainer(object):
class_num=self.cfg.num_classes, class_num=self.cfg.num_classes,
map_type=self.cfg.map_type) map_type=self.cfg.map_type)
] ]
elif self.cfg.metric == 'WiderFace':
multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
self._metrics = [
WiderFaceMetric(
image_dir=os.path.join(self.dataset.dataset_dir,
self.dataset.image_dir),
anno_file=self.dataset.get_anno(),
multi_scale=multi_scale)
]
else: else:
logger.warn("Metric not support for metric type {}".format( logger.warn("Metric not support for metric type {}".format(
self.cfg.metric)) self.cfg.metric))
...@@ -165,15 +174,10 @@ class Trainer(object): ...@@ -165,15 +174,10 @@ class Trainer(object):
weight_type) weight_type)
logger.debug("Load {} weights {} to start training".format( logger.debug("Load {} weights {} to start training".format(
weight_type, weights)) weight_type, weights))
self._weights_loaded = True
def train(self, validate=False): def train(self, validate=False):
assert self.mode == 'train', "Model not in 'train' mode" assert self.mode == 'train', "Model not in 'train' mode"
# if no given weights loaded, load backbone pretrain weights as default
if not self._weights_loaded:
self.load_weights(self.cfg.pretrain_weights)
model = self.model model = self.model
if self.cfg.fleet: if self.cfg.fleet:
model = fleet.distributed_model(model) model = fleet.distributed_model(model)
......
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
from ppdet.data.source.voc import pascalvoc_label from ppdet.data.source.voc import pascalvoc_label
from ppdet.data.source.widerface import widerface_label
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
...@@ -75,6 +76,9 @@ def get_categories(metric_type, anno_file=None): ...@@ -75,6 +76,9 @@ def get_categories(metric_type, anno_file=None):
logger.warn("only default categories support for OID19") logger.warn("only default categories support for OID19")
return _oid19_category() return _oid19_category()
elif metric_type.lower() == 'widerface':
return _widerface_category()
else: else:
raise ValueError("unknown metric type {}".format(metric_type)) raise ValueError("unknown metric type {}".format(metric_type))
...@@ -274,6 +278,16 @@ def _vocall_category(): ...@@ -274,6 +278,16 @@ def _vocall_category():
return clsid2catid, catid2name return clsid2catid, catid2name
def _widerface_category():
    label_map = widerface_label()
    label_map = sorted(label_map.items(), key=lambda x: x[1])
    cats = [l[0] for l in label_map]
    clsid2catid = {i: i for i in range(len(cats))}
    catid2name = {i: name for i, name in enumerate(cats)}
    return clsid2catid, catid2name
def _oid19_category(): def _oid19_category():
clsid2catid = {k: k + 1 for k in range(500)} clsid2catid = {k: k + 1 for k in range(500)}
......
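The `_widerface_category` helper added above sorts a label map by class id and builds two lookup tables from it. The sketch below reproduces that logic on an explicit dictionary so the result is easy to inspect. The single-class map `{"face": 0}` is an assumption, since the body of `widerface_label()` is not shown in this commit, and `build_category_maps` is a hypothetical name.

```python
def build_category_maps(label_map):
    """Return (clsid2catid, catid2name) for a {name: class_id} label map."""
    ordered = sorted(label_map.items(), key=lambda kv: kv[1])
    names = [name for name, _ in ordered]
    clsid2catid = {i: i for i in range(len(names))}
    catid2name = {i: name for i, name in enumerate(names)}
    return clsid2catid, catid2name


# Assumption: WIDER FACE uses a single 'face' class with id 0.
print(build_category_maps({"face": 0}))  # ({0: 0}, {0: 'face'})
```

Because class ids and category ids coincide here, `clsid2catid` is an identity mapping, unlike the OID19 mapping in the same file, which offsets ids by one.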
...@@ -25,17 +25,26 @@ import numpy as np ...@@ -25,17 +25,26 @@ import numpy as np
from .category import get_categories from .category import get_categories
from .map_utils import prune_zero_padding, DetectionMAP from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval from .coco_utils import get_infer_results, cocoapi_eval
from .widerface_utils import face_eval_run
from ppdet.utils.logger import setup_logger from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__) logger = setup_logger(__name__)
__all__ = ['Metric', 'COCOMetric', 'VOCMetric', 'get_infer_results'] __all__ = [
'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results'
]
class Metric(paddle.metric.Metric): class Metric(paddle.metric.Metric):
def name(self): def name(self):
return self.__class__.__name__ return self.__class__.__name__
def reset(self):
pass
def accumulate(self):
pass
# paddle.metric.Metric defined :meth:`update`, :meth:`accumulate` # paddle.metric.Metric defined :meth:`update`, :meth:`accumulate`
# :meth:`reset`, in ppdet, we also need following 2 methods: # :meth:`reset`, in ppdet, we also need following 2 methods:
...@@ -194,3 +203,21 @@ class VOCMetric(Metric): ...@@ -194,3 +203,21 @@ class VOCMetric(Metric):
def get_results(self): def get_results(self):
self.detection_map.get_map() self.detection_map.get_map()


class WiderFaceMetric(Metric):
    def __init__(self, image_dir, anno_file, multi_scale=True):
        self.image_dir = image_dir
        self.anno_file = anno_file
        self.multi_scale = multi_scale
        self.clsid2catid, self.catid2name = get_categories('widerface')

    def update(self, model):
        face_eval_run(
            model,
            self.image_dir,
            self.anno_file,
            pred_dir='output/pred',
            eval_mode='widerface',
            multi_scale=self.multi_scale)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import cv2
import numpy as np
from collections import OrderedDict
import paddle
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['face_eval_run', 'lmk2out']


def face_eval_run(model,
                  image_dir,
                  gt_file,
                  pred_dir='output/pred',
                  eval_mode='widerface',
                  multi_scale=False):
    # load ground truth files
    with open(gt_file, 'r') as f:
        gt_lines = f.readlines()
    imid2path = []
    pos_gt = 0
    while pos_gt < len(gt_lines):
        name_gt = gt_lines[pos_gt].strip('\n\t').split()[0]
        imid2path.append(name_gt)
        pos_gt += 1
        n_gt = int(gt_lines[pos_gt].strip('\n\t').split()[0])
        pos_gt += 1 + n_gt
    logger.info('The ground truth file load {} images'.format(len(imid2path)))

    dets_dist = OrderedDict()
    for iter_id, im_path in enumerate(imid2path):
        image_path = os.path.join(image_dir, im_path)
        if eval_mode == 'fddb':
            image_path += '.jpg'
        assert os.path.exists(image_path)
        image = cv2.imread(image_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if multi_scale:
            shrink, max_shrink = get_shrink(image.shape[0], image.shape[1])
            det0 = detect_face(model, image, shrink)
            det1 = flip_test(model, image, shrink)
            [det2, det3] = multi_scale_test(model, image, max_shrink)
            det4 = multi_scale_test_pyramid(model, image, max_shrink)
            det = np.row_stack((det0, det1, det2, det3, det4))
            dets = bbox_vote(det)
        else:
            dets = detect_face(model, image, 1)
        if eval_mode == 'widerface':
            save_widerface_bboxes(image_path, dets, pred_dir)
        else:
            dets_dist[im_path] = dets
        if iter_id % 100 == 0:
            logger.info('Test iter {}'.format(iter_id))
    if eval_mode == 'fddb':
        save_fddb_bboxes(dets_dist, pred_dir)
    logger.info("Finish evaluation.")
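
The ground-truth loop at the top of `face_eval_run` can be illustrated on its own. This is a minimal sketch, assuming the annotation layout that the loop implies (an image path line, a face-count line, then one line per box); `parse_image_names` and the sample lines are hypothetical and not part of the commit.

```python
def parse_image_names(gt_lines):
    """Collect image paths from annotation lines laid out as:
    path, number of boxes, then one line per box."""
    names = []
    pos = 0
    while pos < len(gt_lines):
        # First token of the current line is the image path.
        names.append(gt_lines[pos].strip('\n\t').split()[0])
        pos += 1
        # The next line holds the number of boxes for that image.
        n_boxes = int(gt_lines[pos].strip('\n\t').split()[0])
        # Skip the count line itself plus all of its box lines.
        pos += 1 + n_boxes
    return names


# Hypothetical annotation content with two images.
sample_lines = [
    "0--Parade/img_a.jpg\n",
    "2\n",
    "10 20 30 40 0 0 0 0 0 0\n",
    "50 60 20 20 0 0 0 0 0 0\n",
    "1--Handshaking/img_b.jpg\n",
    "1\n",
    "5 5 15 15 0 0 0 0 0 0\n",
]
image_names = parse_image_names(sample_lines)
```

Only the image paths are kept at this stage; the box lines are skipped because the evaluation writes predictions to disk rather than scoring them in this function.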


def detect_face(model, image, shrink):
    image_shape = [image.shape[0], image.shape[1]]
    if shrink != 1:
        h, w = int(image_shape[0] * shrink), int(image_shape[1] * shrink)
        image = cv2.resize(image, (w, h))
        image_shape = [h, w]

    img = face_img_process(image)
    image_shape = np.asarray([image_shape])
    scale_factor = np.asarray([[shrink, shrink]])
    data = {
        "image": paddle.to_tensor(
            img, dtype='float32'),
        "im_shape": paddle.to_tensor(
            image_shape, dtype='float32'),
        "scale_factor": paddle.to_tensor(
            scale_factor, dtype='float32')
    }
    model.eval()
    detection = model(data)
    detection = detection['bbox'].numpy()
    # layout of the returned array: xmin, ymin, xmax, ymax, score
    if np.prod(detection.shape) == 1:
        logger.info("No face detected")
        return np.array([[0, 0, 0, 0, 0]])
    det_conf = detection[:, 1]
    det_xmin = detection[:, 2]
    det_ymin = detection[:, 3]
    det_xmax = detection[:, 4]
    det_ymax = detection[:, 5]

    det = np.column_stack((det_xmin, det_ymin, det_xmax, det_ymax, det_conf))
    return det


def flip_test(model, image, shrink):
    img = cv2.flip(image, 1)
    det_f = detect_face(model, img, shrink)
    det_t = np.zeros(det_f.shape)
    img_width = image.shape[1]
    det_t[:, 0] = img_width - det_f[:, 2]
    det_t[:, 1] = det_f[:, 1]
    det_t[:, 2] = img_width - det_f[:, 0]
    det_t[:, 3] = det_f[:, 3]
    det_t[:, 4] = det_f[:, 4]
    return det_t
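
The coordinate mapping in `flip_test` swaps and mirrors the x coordinates while leaving y and the score untouched. A minimal pure-Python sketch of that mapping follows; `unflip_boxes` and the sample box are hypothetical names and values used only for illustration.

```python
def unflip_boxes(boxes, image_width):
    """Map boxes detected on a horizontally flipped image back to the
    original image. Each box is [xmin, ymin, xmax, ymax, score]."""
    restored = []
    for xmin, ymin, xmax, ymax, score in boxes:
        # The right edge of the flipped box becomes the left edge, and
        # the left edge becomes the right edge, of the restored box.
        restored.append(
            [image_width - xmax, ymin, image_width - xmin, ymax, score])
    return restored


# One box detected on the flipped version of a 100-pixel-wide image.
flipped_boxes = [[10.0, 5.0, 30.0, 25.0, 0.9]]
restored_boxes = unflip_boxes(flipped_boxes, image_width=100)
```

The box width is preserved: 20 pixels before and after the mapping.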


def multi_scale_test(model, image, max_shrink):
    # The shrunk image is only used to detect big faces
    st = 0.5 if max_shrink >= 0.75 else 0.5 * max_shrink
    det_s = detect_face(model, image, st)
    index = np.where(
        np.maximum(det_s[:, 2] - det_s[:, 0] + 1, det_s[:, 3] - det_s[:, 1] + 1)
        > 30)[0]
    det_s = det_s[index, :]
    # Enlarge once
    bt = min(2, max_shrink) if max_shrink > 1 else (st + max_shrink) / 2
    det_b = detect_face(model, image, bt)

    # Keep doubling the scale for small faces while it stays below max_shrink
    if max_shrink > 2:
        bt *= 2
        while bt < max_shrink:
            det_b = np.row_stack((det_b, detect_face(model, image, bt)))
            bt *= 2
        det_b = np.row_stack((det_b, detect_face(model, image, max_shrink)))

    # Enlarged images are only used to detect small faces.
    if bt > 1:
        index = np.where(
            np.minimum(det_b[:, 2] - det_b[:, 0] + 1,
                       det_b[:, 3] - det_b[:, 1] + 1) < 100)[0]
        det_b = det_b[index, :]
    # Shrunk images are only used to detect big faces.
    else:
        index = np.where(
            np.maximum(det_b[:, 2] - det_b[:, 0] + 1,
                       det_b[:, 3] - det_b[:, 1] + 1) > 30)[0]
        det_b = det_b[index, :]
    return det_s, det_b


def multi_scale_test_pyramid(model, image, max_shrink):
    # Use image pyramids to detect faces
    det_b = detect_face(model, image, 0.25)
    index = np.where(
        np.maximum(det_b[:, 2] - det_b[:, 0] + 1, det_b[:, 3] - det_b[:, 1] + 1)
        > 30)[0]
    det_b = det_b[index, :]

    st = [0.75, 1.25, 1.5, 1.75]
    for i in range(len(st)):
        if st[i] <= max_shrink:
            det_temp = detect_face(model, image, st[i])
            # Enlarged images are only used to detect small faces.
            if st[i] > 1:
                index = np.where(
                    np.minimum(det_temp[:, 2] - det_temp[:, 0] + 1,
                               det_temp[:, 3] - det_temp[:, 1] + 1) < 100)[0]
                det_temp = det_temp[index, :]
            # Shrunk images are only used to detect big faces.
            else:
                index = np.where(
                    np.maximum(det_temp[:, 2] - det_temp[:, 0] + 1,
                               det_temp[:, 3] - det_temp[:, 1] + 1) > 30)[0]
                det_temp = det_temp[index, :]
            det_b = np.row_stack((det_b, det_temp))
    return det_b


def to_chw(image):
    """
    Transpose image from HWC to CHW.
    Args:
        image (np.array): an image with HWC layout.
    """
    # HWC to CHW
    if len(image.shape) == 3:
        image = np.swapaxes(image, 1, 2)
        image = np.swapaxes(image, 1, 0)
    return image


def face_img_process(image,
                     mean=[104., 117., 123.],
                     std=[127.502231, 127.502231, 127.502231]):
    img = np.array(image)
    img = to_chw(img)
    img = img.astype('float32')
    img -= np.array(mean)[:, np.newaxis, np.newaxis].astype('float32')
    img /= np.array(std)[:, np.newaxis, np.newaxis].astype('float32')
    img = [img]
    img = np.array(img)
    return img


def get_shrink(height, width):
    """
    Args:
        height (int): image height.
        width (int): image width.
    """
    # avoid out of memory
    max_shrink_v1 = (0x7fffffff / 577.0 / (height * width))**0.5
    max_shrink_v2 = ((678 * 1024 * 2.0 * 2.0) / (height * width))**0.5

    def get_round(x, loc):
        str_x = str(x)
        if '.' in str_x:
            str_before, str_after = str_x.split('.')
            len_after = len(str_after)
            if len_after >= 3:
                str_final = str_before + '.' + str_after[0:loc]
                return float(str_final)
            else:
                return x

    max_shrink = get_round(min(max_shrink_v1, max_shrink_v2), 2) - 0.3
    if max_shrink >= 1.5 and max_shrink < 2:
        max_shrink = max_shrink - 0.1
    elif max_shrink >= 2 and max_shrink < 3:
        max_shrink = max_shrink - 0.2
    elif max_shrink >= 3 and max_shrink < 4:
        max_shrink = max_shrink - 0.3
    elif max_shrink >= 4 and max_shrink < 5:
        max_shrink = max_shrink - 0.4
    elif max_shrink >= 5:
        max_shrink = max_shrink - 0.5
    elif max_shrink <= 0.1:
        max_shrink = 0.1

    shrink = max_shrink if max_shrink < 1 else 1
    return shrink, max_shrink


def bbox_vote(det):
    # Sort detections by score, highest first.
    order = det[:, 4].ravel().argsort()[::-1]
    det = det[order, :]
    if det.shape[0] == 0:
        dets = np.array([[10, 10, 20, 20, 0.002]])
        det = np.empty(shape=[0, 5])
    while det.shape[0] > 0:
        # IOU of the top-scoring box against all remaining boxes
        area = (det[:, 2] - det[:, 0] + 1) * (det[:, 3] - det[:, 1] + 1)
        xx1 = np.maximum(det[0, 0], det[:, 0])
        yy1 = np.maximum(det[0, 1], det[:, 1])
        xx2 = np.minimum(det[0, 2], det[:, 2])
        yy2 = np.minimum(det[0, 3], det[:, 3])
        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h
        o = inter / (area[0] + area[:] - inter)

        # nms: group every box that overlaps the top box by IOU >= 0.3
        merge_index = np.where(o >= 0.3)[0]
        det_accu = det[merge_index, :]
        det = np.delete(det, merge_index, 0)
        if merge_index.shape[0] <= 1:
            if det.shape[0] == 0:
                # `dets` is unbound until the first result is stored, so the
                # except branch initialises it.
                try:
                    dets = np.row_stack((dets, det_accu))
                except:
                    dets = det_accu
            continue
        # Score-weighted average of the grouped box coordinates.
        det_accu[:, 0:4] = det_accu[:, 0:4] * np.tile(det_accu[:, -1:], (1, 4))
        max_score = np.max(det_accu[:, 4])
        det_accu_sum = np.zeros((1, 5))
        det_accu_sum[:, 0:4] = np.sum(det_accu[:, 0:4],
                                      axis=0) / np.sum(det_accu[:, -1:])
        det_accu_sum[:, 4] = max_score
        try:
            dets = np.row_stack((dets, det_accu_sum))
        except:
            dets = det_accu_sum
    dets = dets[0:750, :]
    keep_index = np.where(dets[:, 4] >= 0.01)[0]
    dets = dets[keep_index, :]
    return dets
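
The merge step inside `bbox_vote` is a score-weighted average of overlapping boxes, with the merged box keeping the highest score in the group. A minimal pure-Python sketch of that single step follows; `inclusive_iou`, `vote` and the two sample boxes are hypothetical and only mirror the arithmetic above (including the `+ 1` pixel-inclusive convention).

```python
def inclusive_iou(a, b):
    """IoU of two [xmin, ymin, xmax, ymax, score] boxes, counting both
    border pixels (hence the + 1 terms)."""
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]) + 1)
    h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]) + 1)
    inter = w * h
    return inter / (area_a + area_b - inter)


def vote(cluster):
    """Merge a group of boxes into one box: coordinates are averaged
    with the scores as weights, the score is the group maximum."""
    total_score = sum(box[4] for box in cluster)
    merged = [
        sum(box[i] * box[4] for box in cluster) / total_score
        for i in range(4)
    ]
    merged.append(max(box[4] for box in cluster))
    return merged


box_a = [0.0, 0.0, 10.0, 10.0, 0.75]
box_b = [2.0, 2.0, 12.0, 12.0, 0.25]
overlap = inclusive_iou(box_a, box_b)  # above the 0.3 merge threshold
merged_box = vote([box_a, box_b])
```

Because `box_a` carries three times the weight of `box_b`, the merged corners sit a quarter of the way from `box_a` towards `box_b`.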


def save_widerface_bboxes(image_path, bboxes_scores, output_dir):
    image_name = image_path.split('/')[-1]
    image_class = image_path.split('/')[-2]
    odir = os.path.join(output_dir, image_class)
    if not os.path.exists(odir):
        os.makedirs(odir)

    ofname = os.path.join(odir, '%s.txt' % (image_name[:-4]))
    # Use a context manager so the file is closed even if a write fails.
    with open(ofname, 'w') as f:
        f.write('{:s}\n'.format(image_class + '/' + image_name))
        f.write('{:d}\n'.format(bboxes_scores.shape[0]))
        for box_score in bboxes_scores:
            xmin, ymin, xmax, ymax, score = box_score
            f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'.format(xmin, ymin, (
                xmax - xmin + 1), (ymax - ymin + 1), score))
    logger.info("The predicted result is saved as {}".format(ofname))
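
Each prediction line written by `save_widerface_bboxes` stores the box as `x y width height score`, with width and height computed pixel-inclusively. A minimal sketch of that conversion follows; `format_widerface_line` and the sample numbers are hypothetical.

```python
def format_widerface_line(xmin, ymin, xmax, ymax, score):
    """Format one detection as 'x y w h score', converting corner
    coordinates to a pixel-inclusive width and height."""
    width = xmax - xmin + 1
    height = ymax - ymin + 1
    return '{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}'.format(xmin, ymin, width,
                                                       height, score)


# A box spanning x in [10, 29] and y in [20, 59].
line = format_widerface_line(10, 20, 29, 59, 0.987)
```

Note that `save_fddb_bboxes` uses `xmax - xmin` without the `+ 1`, so the two output formats differ by one pixel in width and height.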


def save_fddb_bboxes(bboxes_scores,
                     output_dir,
                     output_fname='pred_fddb_res.txt'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    predict_file = os.path.join(output_dir, output_fname)
    # `dict.iteritems()` exists only in Python 2; `items()` works in Python 3.
    # The context manager also closes the file, which the original never did.
    with open(predict_file, 'w') as f:
        for image_path, dets in bboxes_scores.items():
            f.write('{:s}\n'.format(image_path))
            f.write('{:d}\n'.format(dets.shape[0]))
            for box_score in dets:
                xmin, ymin, xmax, ymax, score = box_score
                width, height = xmax - xmin, ymax - ymin
                f.write('{:.1f} {:.1f} {:.1f} {:.1f} {:.3f}\n'
                        .format(xmin, ymin, width, height, score))
    logger.info("The predicted result is saved as {}".format(predict_file))
    return predict_file


def lmk2out(results, is_bbox_normalized=False):
    """
    Args:
        results: a list of dicts, each should include `landmark` and `im_id`;
                 if is_bbox_normalized=True, `im_shape` is also needed.
        is_bbox_normalized: whether or not landmark is normalized.
    """
    xywh_res = []
    for t in results:
        bboxes = t['bbox'][0]
        lengths = t['bbox'][1][0]
        im_ids = np.array(t['im_id'][0]).flatten()
        # Check for None first; `bboxes.shape` would fail on None.
        if bboxes is None or bboxes.shape == (1, 1):
            continue
        face_index = t['face_index'][0]
        prior_box = t['prior_boxes'][0]
        predict_lmk = t['landmark'][0]
        prior = np.reshape(prior_box, (-1, 4))
        predictlmk = np.reshape(predict_lmk, (-1, 10))

        k = 0
        for a in range(len(lengths)):
            num = lengths[a]
            im_id = int(im_ids[a])
            for i in range(num):
                score = bboxes[k][1]
                theindex = face_index[i][0]
                me_prior = prior[theindex, :]
                lmk_pred = predictlmk[theindex, :]
                prior_w = me_prior[2] - me_prior[0]
                prior_h = me_prior[3] - me_prior[1]
                prior_w_center = (me_prior[2] + me_prior[0]) / 2
                prior_h_center = (me_prior[3] + me_prior[1]) / 2
                lmk_decode = np.zeros((10))
                for j in [0, 2, 4, 6, 8]:
                    lmk_decode[j] = lmk_pred[j] * 0.1 * prior_w + prior_w_center
                for j in [1, 3, 5, 7, 9]:
                    lmk_decode[j] = lmk_pred[j] * 0.1 * prior_h + prior_h_center
                im_shape = t['im_shape'][0][a].tolist()
                image_h, image_w = int(im_shape[0]), int(im_shape[1])
                if is_bbox_normalized:
                    lmk_decode = lmk_decode * np.array([
                        image_w, image_h, image_w, image_h, image_w, image_h,
                        image_w, image_h, image_w, image_h
                    ])
                lmk_res = {
                    'image_id': im_id,
                    'landmark': lmk_decode,
                    'score': score,
                }
                xywh_res.append(lmk_res)
                k += 1
    return xywh_res
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import vgg
from . import resnet
from . import darknet
from . import mobilenet_v1
from . import mobilenet_v3
from . import hrnet
from . import blazenet

from .vgg import *
from .resnet import *
@@ -11,3 +26,4 @@ from .darknet import *
from .mobilenet_v1 import *
from .mobilenet_v3 import *
from .hrnet import *
from .blazenet import *
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.regularizer import L2Decay
from paddle.nn.initializer import KaimingNormal
from ppdet.core.workspace import register, serializable
from numbers import Integral
from ..shape_spec import ShapeSpec
__all__ = ['BlazeNet']


class ConvBNLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 num_groups=1,
                 act='relu',
                 conv_lr=0.1,
                 conv_decay=0.,
                 norm_decay=0.,
                 norm_type='bn',
                 name=None):
        super(ConvBNLayer, self).__init__()
        self.act = act
        self._conv = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=num_groups,
            weight_attr=ParamAttr(
                learning_rate=conv_lr,
                initializer=KaimingNormal(),
                name=name + "_weights"),
            bias_attr=False)

        param_attr = ParamAttr(name=name + "_bn_scale")
        bias_attr = ParamAttr(name=name + "_bn_offset")
        if norm_type == 'sync_bn':
            self._batch_norm = nn.SyncBatchNorm(
                out_channels, weight_attr=param_attr, bias_attr=bias_attr)
        else:
            self._batch_norm = nn.BatchNorm(
                out_channels,
                act=None,
                param_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=False,
                moving_mean_name=name + '_bn_mean',
                moving_variance_name=name + '_bn_variance')

    def forward(self, x):
        x = self._conv(x)
        x = self._batch_norm(x)
        if self.act == "relu":
            x = F.relu(x)
        elif self.act == "relu6":
            x = F.relu6(x)
        return x


class BlazeBlock(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels1,
                 out_channels2,
                 double_channels=None,
                 stride=1,
                 use_5x5kernel=True,
                 name=None):
        super(BlazeBlock, self).__init__()
        assert stride in [1, 2]
        self.use_pool = not stride == 1
        self.use_double_block = double_channels is not None
        self.conv_dw = []
        if use_5x5kernel:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=5,
                        stride=stride,
                        padding=2,
                        num_groups=out_channels1,
                        name=name + "1_dw")))
        else:
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_1",
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_1")))
            self.conv_dw.append(
                self.add_sublayer(
                    name + "1_dw_2",
                    ConvBNLayer(
                        in_channels=out_channels1,
                        out_channels=out_channels1,
                        kernel_size=3,
                        stride=stride,
                        padding=1,
                        num_groups=out_channels1,
                        name=name + "1_dw_2")))
        act = 'relu' if self.use_double_block else None
        self.conv_pw = ConvBNLayer(
            in_channels=out_channels1,
            out_channels=out_channels2,
            kernel_size=1,
            stride=1,
            padding=0,
            act=act,
            name=name + "1_sep")
        if self.use_double_block:
            self.conv_dw2 = []
            if use_5x5kernel:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=5,
                            stride=1,
                            padding=2,
                            num_groups=out_channels2,
                            name=name + "2_dw")))
            else:
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_1",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "1_dw_1")))
                self.conv_dw2.append(
                    self.add_sublayer(
                        name + "2_dw_2",
                        ConvBNLayer(
                            in_channels=out_channels2,
                            out_channels=out_channels2,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            num_groups=out_channels2,
                            name=name + "2_dw_2")))
            self.conv_pw2 = ConvBNLayer(
                in_channels=out_channels2,
                out_channels=double_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                name=name + "2_sep")
        # shortcut
        if self.use_pool:
            shortcut_channel = double_channels or out_channels2
            self._shortcut = []
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_pool',
                    nn.MaxPool2D(
                        kernel_size=stride, stride=stride, ceil_mode=True)))
            self._shortcut.append(
                self.add_sublayer(
                    name + '_shortcut_conv',
                    ConvBNLayer(
                        in_channels=in_channels,
                        out_channels=shortcut_channel,
                        kernel_size=1,
                        stride=1,
                        padding=0,
                        name="shortcut" + name)))

    def forward(self, x):
        y = x
        for conv_dw_block in self.conv_dw:
            y = conv_dw_block(y)
        y = self.conv_pw(y)
        if self.use_double_block:
            for conv_dw2_block in self.conv_dw2:
                y = conv_dw2_block(y)
            y = self.conv_pw2(y)
        if self.use_pool:
            for shortcut in self._shortcut:
                x = shortcut(x)
        return F.relu(paddle.add(x, y))


@register
@serializable
class BlazeNet(nn.Layer):
    """
    BlazeFace, see https://arxiv.org/abs/1907.05047

    Args:
        blaze_filters (list): number of filters for each blaze block.
        double_blaze_filters (list): number of filters for each double_blaze block.
        use_5x5kernel (bool): whether or not filter size is 5x5 in depth-wise conv.
    """

    def __init__(
            self,
            blaze_filters=[[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]],
            double_blaze_filters=[[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
                                  [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]],
            use_5x5kernel=True):
        super(BlazeNet, self).__init__()
        conv1_num_filters = blaze_filters[0][0]
        self.conv1 = ConvBNLayer(
            in_channels=3,
            out_channels=conv1_num_filters,
            kernel_size=3,
            stride=2,
            padding=1,
            name="conv1")
        in_channels = conv1_num_filters
        self.blaze_block = []
        self._out_channels = []
        for k, v in enumerate(blaze_filters):
            assert len(v) in [2, 3], \
                "length of blaze_filters entry {} not in [2, 3]".format(v)
            if len(v) == 2:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            use_5x5kernel=use_5x5kernel,
                            name='blaze_{}'.format(k))))
            elif len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            stride=v[2],
                            use_5x5kernel=use_5x5kernel,
                            name='blaze_{}'.format(k))))
            in_channels = v[1]

        for k, v in enumerate(double_blaze_filters):
            assert len(v) in [3, 4], \
                "length of double_blaze_filters entry {} not in [3, 4]".format(v)
            if len(v) == 3:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            use_5x5kernel=use_5x5kernel,
                            name='double_blaze_{}'.format(k))))
            elif len(v) == 4:
                self.blaze_block.append(
                    self.add_sublayer(
                        'double_blaze_{}'.format(k),
                        BlazeBlock(
                            in_channels,
                            v[0],
                            v[1],
                            double_channels=v[2],
                            stride=v[3],
                            use_5x5kernel=use_5x5kernel,
                            name='double_blaze_{}'.format(k))))
            in_channels = v[2]
            self._out_channels.append(in_channels)

    def forward(self, inputs):
        outs = []
        y = self.conv1(inputs['image'])
        for block in self.blaze_block:
            y = block(y)
            outs.append(y)
        return [outs[-4], outs[-1]]

    @property
    def out_shape(self):
        return [
            ShapeSpec(channels=c)
            for c in [self._out_channels[-4], self._out_channels[-1]]
        ]
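
To see which feature maps `BlazeNet.forward` returns with the default filter lists, the cumulative stride and channel count of every block can be tracked without building the network. This is a minimal sketch, not part of the commit: `trace_blocks` is a hypothetical helper, and it assumes that the overall stride is the product of the `conv1` stride and the per-block strides.

```python
def trace_blocks(blaze_filters, double_blaze_filters):
    """Return (cumulative_stride, out_channels) for every block, in the
    order the blocks are appended in BlazeNet.__init__."""
    stride = 2  # conv1 is a stride-2 convolution
    trace = []
    for v in blaze_filters:
        if len(v) == 3:  # [out1, out2, stride]
            stride *= v[2]
        trace.append((stride, v[1]))
    for v in double_blaze_filters:
        if len(v) == 4:  # [out1, out2, double_channels, stride]
            stride *= v[3]
        trace.append((stride, v[2]))
    return trace


# Default filter lists copied from BlazeNet.__init__.
blaze_filters = [[24, 24], [24, 24], [24, 48, 2], [48, 48], [48, 48]]
double_blaze_filters = [[48, 24, 96, 2], [96, 24, 96], [96, 24, 96],
                        [96, 24, 96, 2], [96, 24, 96], [96, 24, 96]]

trace = trace_blocks(blaze_filters, double_blaze_filters)
# forward() returns outs[-4] and outs[-1].
selected = [trace[-4], trace[-1]]
```

Under these assumptions the two returned feature maps have strides 8 and 16 and 96 channels each, which agrees with the `in_channels=(96, 96)` default of `FaceHead`.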
@@ -21,6 +21,7 @@ from . import fcos_head
from . import solov2_head
from . import ttf_head
from . import cascade_head
from . import face_head

from .bbox_head import *
from .mask_head import *
@@ -31,3 +32,4 @@ from .fcos_head import *
from .solov2_head import *
from .ttf_head import *
from .cascade_head import *
from .face_head import *
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register
from paddle.regularizer import L2Decay
from paddle import ParamAttr
from ..layers import AnchorGeneratorSSD


@register
class FaceHead(nn.Layer):
    """
    Head block for Face detection network

    Args:
        num_classes (int): Number of output classes.
        in_channels (int): Number of input channels.
        anchor_generator (object): instance of anchor generator method.
        kernel_size (int): kernel size of Conv2D in FaceHead.
        padding (int): padding of Conv2D in FaceHead.
        conv_decay (float): weight decay for conv layer weights.
        loss (object): loss of face detection model.
    """
    __shared__ = ['num_classes']
    __inject__ = ['anchor_generator', 'loss']

    def __init__(self,
                 num_classes=80,
                 in_channels=(96, 96),
                 anchor_generator=AnchorGeneratorSSD().__dict__,
                 kernel_size=3,
                 padding=1,
                 conv_decay=0.,
                 loss='SSDLoss'):
        super(FaceHead, self).__init__()
        # add background class
        self.num_classes = num_classes + 1
        self.in_channels = in_channels
        self.anchor_generator = anchor_generator
        self.loss = loss

        if isinstance(anchor_generator, dict):
            self.anchor_generator = AnchorGeneratorSSD(**anchor_generator)

        self.num_priors = self.anchor_generator.num_priors
        self.box_convs = []
        self.score_convs = []
        for i, num_prior in enumerate(self.num_priors):
            box_conv_name = "boxes{}".format(i)
            box_conv = self.add_sublayer(
                box_conv_name,
                nn.Conv2D(
                    in_channels=in_channels[i],
                    out_channels=num_prior * 4,
                    kernel_size=kernel_size,
                    padding=padding))
            self.box_convs.append(box_conv)

            score_conv_name = "scores{}".format(i)
            score_conv = self.add_sublayer(
                score_conv_name,
                nn.Conv2D(
                    in_channels=in_channels[i],
                    out_channels=num_prior * self.num_classes,
                    kernel_size=kernel_size,
                    padding=padding))
            self.score_convs.append(score_conv)

    @classmethod
    def from_config(cls, cfg, input_shape):
        return {'in_channels': [i.channels for i in input_shape], }

    def forward(self, feats, image, gt_bbox=None, gt_class=None):
        box_preds = []
        cls_scores = []
        prior_boxes = []
        for feat, box_conv, score_conv in zip(feats, self.box_convs,
                                              self.score_convs):
            box_pred = box_conv(feat)
            box_pred = paddle.transpose(box_pred, [0, 2, 3, 1])
            box_pred = paddle.reshape(box_pred, [0, -1, 4])
            box_preds.append(box_pred)

            cls_score = score_conv(feat)
            cls_score = paddle.transpose(cls_score, [0, 2, 3, 1])
            cls_score = paddle.reshape(cls_score, [0, -1, self.num_classes])
            cls_scores.append(cls_score)

        prior_boxes = self.anchor_generator(feats, image)

        if self.training:
            return self.get_loss(box_preds, cls_scores, gt_bbox, gt_class,
                                 prior_boxes)
        else:
            return (box_preds, cls_scores), prior_boxes

    def get_loss(self, boxes, scores, gt_bbox, gt_class, prior_boxes):
        return self.loss(boxes, scores, gt_bbox, gt_class, prior_boxes)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
......
@@ -271,8 +271,12 @@ class AnchorGeneratorSSD(object):
        self.num_priors = []
        for aspect_ratio, min_size, max_size in zip(
                aspect_ratios, self.min_sizes, self.max_sizes):
            if isinstance(min_size, (list, tuple)):
                self.num_priors.append(
                    len(_to_list(min_size)) + len(_to_list(max_size)))
            else:
                self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
                    _to_list(min_size)) + len(_to_list(max_size)))
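
The two branches of the `num_priors` computation can be compared on small inputs. This is a minimal sketch: `count_priors` is a hypothetical function, the local `_to_list` is an assumed stand-in for the helper of the same name in the layers module (its real body is not shown in this diff), and the size lists are illustrative values rather than values taken from a config file.

```python
def _to_list(value):
    # Assumed behaviour: wrap scalars in a list, keep sequences as lists.
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def count_priors(aspect_ratio, min_size, max_size):
    """Number of prior boxes per feature-map location for one level."""
    if isinstance(min_size, (list, tuple)):
        # Face-detector style: one prior per listed size.
        return len(_to_list(min_size)) + len(_to_list(max_size))
    # SSD style: each aspect ratio contributes two priors (ratio and
    # its reciprocal) plus one square prior, then one per max size.
    return ((len(aspect_ratio) * 2 + 1) * len(_to_list(min_size)) +
            len(_to_list(max_size)))


ssd_style = count_priors([2.], 30.0, 60.0)
face_style = count_priors([1.], [16., 24.], [])
```

With a list-valued `min_size`, the aspect ratios no longer multiply the prior count, so each location gets exactly one prior per listed size.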

    def __call__(self, inputs, image):
        boxes = []
......
@@ -136,8 +136,10 @@ def load_pretrain_weight(model,
    path = _strip_postfix(pretrain_weight)
    if not (os.path.isdir(path) or os.path.isfile(path) or
            os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path `{}` does not exist. "
                         "If you don't want to load pretrain model, "
                         "please delete `pretrain_weights` field in "
                         "config file.".format(path))

    model_dict = model.state_dict()
......
@@ -91,7 +91,7 @@ def run(FLAGS, cfg):
    trainer = Trainer(cfg, mode='train')

    # load weights
    if not FLAGS.slim_config and 'pretrain_weights' in cfg and cfg.pretrain_weights:
        trainer.load_weights(cfg.pretrain_weights, FLAGS.weight_type)

    # training
......