From 23ad49103878428852ddfa5382b91498f551cd7c Mon Sep 17 00:00:00 2001 From: Feng Ni Date: Wed, 16 Nov 2022 15:19:38 +0800 Subject: [PATCH] add iou-fcos and ssod baseline (#7335) * add iou-fcos and ssod baseline * add gen_semi_coco, fix doc * fix doc links --- configs/fcos/README.md | 28 ++--- configs/fcos/fcos_r50_fpn_iou_1x_coco.yml | 78 +++++++++++++ .../fcos_r50_fpn_iou_multiscale_2x_coco.yml | 90 +++++++++++++++ configs/ssod/README.md | 109 ++++++++++++++++++ configs/ssod/baseline/README.md | 64 ++++++++++ .../faster_rcnn_r50_fpn_2x_coco_sup010.yml | 26 +++++ .../baseline/fcos_r50_fpn_2x_coco_sup005.yml | 26 +++++ .../baseline/fcos_r50_fpn_2x_coco_sup010.yml | 26 +++++ .../ppyoloe_plus_crn_s_80e_coco_sup005.yml | 29 +++++ .../ppyoloe_plus_crn_s_80e_coco_sup010.yml | 29 +++++ .../retinanet_r50_fpn_2x_coco_sup010.yml | 26 +++++ ppdet/modeling/heads/fcos_head.py | 14 ++- ppdet/modeling/losses/fcos_loss.py | 77 ++++++++++--- tools/gen_semi_coco.py | 102 ++++++++++++++++ 14 files changed, 686 insertions(+), 38 deletions(-) create mode 100644 configs/fcos/fcos_r50_fpn_iou_1x_coco.yml create mode 100644 configs/fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml create mode 100644 configs/ssod/README.md create mode 100644 configs/ssod/baseline/README.md create mode 100644 configs/ssod/baseline/faster_rcnn_r50_fpn_2x_coco_sup010.yml create mode 100644 configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup005.yml create mode 100644 configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup010.yml create mode 100644 configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml create mode 100644 configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml create mode 100644 configs/ssod/baseline/retinanet_r50_fpn_2x_coco_sup010.yml create mode 100644 tools/gen_semi_coco.py diff --git a/configs/fcos/README.md b/configs/fcos/README.md index cdd433423..44c043440 100644 --- a/configs/fcos/README.md +++ b/configs/fcos/README.md @@ -1,24 +1,18 @@ -# FCOS for Object Detection +# FCOS (Fully Convolutional 
One-Stage Object Detection) -## Introduction +## Model Zoo on COCO -FCOS (Fully Convolutional One-Stage Object Detection) is a fast anchor-free object detection framework with strong performance. We reproduced the model of the paper, and improved and optimized the accuracy of the FCOS. +| 骨架网络 | 网络类型 | 每张GPU图片个数 | 学习率策略 |推理时间(fps) | Box AP | 下载 | 配置文件 | +| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | +| ResNet50-FPN | FCOS | 2 | 1x | ---- | 39.6 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_1x_coco.pdparams) | [config](./fcos_r50_fpn_1x_coco.yml) | +| ResNet50-FPN | FCOS + iou | 2 | 1x | ---- | 40.0 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_iou_1x_coco.pdparams) | [config](./fcos_r50_fpn_iou_1x_coco.yml) | +| ResNet50-FPN | FCOS + DCN | 2 | 1x | ---- | 44.3 | [download](https://paddledet.bj.bcebos.com/models/fcos_dcn_r50_fpn_1x_coco.pdparams) | [config](./fcos_dcn_r50_fpn_1x_coco.yml) | +| ResNet50-FPN | FCOS + multiscale_train | 2 | 2x | ---- | 41.8 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_multiscale_2x_coco.pdparams) | [config](./fcos_r50_fpn_multiscale_2x_coco.yml) | +| ResNet50-FPN | FCOS + multiscale_train + iou | 2 | 2x | ---- | 42.6 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_iou_multiscale_2x_coco.pdparams) | [config](./fcos_r50_fpn_iou_multiscale_2x_coco.yml) | -**Highlights:** +**注意:** + - `+ iou` 表示与原版 FCOS 相比,不使用 `centerness` 而是使用 `iou` 来参与计算loss。 -- Training Time: The training time of the model of `fcos_r50_fpn_1x` on Tesla v100 with 8 GPU is only 8.5 hours. 
- -## Model Zoo - -| Backbone | Model | images/GPU | lr schedule |FPS | Box AP | download | config | -| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: | -| ResNet50-FPN | FCOS | 2 | 1x | ---- | 39.6 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_r50_fpn_1x_coco.yml) | -| ResNet50-FPN | FCOS+DCN | 2 | 1x | ---- | 44.3 | [download](https://paddledet.bj.bcebos.com/models/fcos_dcn_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_dcn_r50_fpn_1x_coco.yml) | -| ResNet50-FPN | FCOS+multiscale_train | 2 | 2x | ---- | 41.8 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_multiscale_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_r50_fpn_multiscale_2x_coco.yml) | - -**Notes:** - -- FCOS is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`. 
## Citations ``` diff --git a/configs/fcos/fcos_r50_fpn_iou_1x_coco.yml b/configs/fcos/fcos_r50_fpn_iou_1x_coco.yml new file mode 100644 index 000000000..943c5bc04 --- /dev/null +++ b/configs/fcos/fcos_r50_fpn_iou_1x_coco.yml @@ -0,0 +1,78 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/fcos_r50_fpn.yml', + '_base_/optimizer_1x.yml', + '_base_/fcos_reader.yml', +] + +weights: output/fcos_r50_fpn_iou_1x_coco/model_final + + +TrainReader: + sample_transforms: + - Decode: {} + - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - RandomFlip: {} + batch_transforms: + - Permute: {} + - PadBatch: {pad_to_stride: 32} + - Gt2FCOSTarget: + object_sizes_boundary: [64, 128, 256, 512] + center_sampling_radius: 1.5 + downsample_ratios: [8, 16, 32, 64, 128] + norm_reg_targets: True + batch_size: 2 + shuffle: True + drop_last: True + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + fuse_normalize: True + + +FCOSHead: + fcos_feat: + name: FCOSFeat + feat_in: 256 + feat_out: 256 + num_convs: 4 + norm_type: "gn" + use_dcn: False + fpn_stride: [8, 16, 32, 64, 128] + prior_prob: 0.01 + norm_reg_targets: True + centerness_on_reg: True + fcos_loss: + name: FCOSLoss + loss_alpha: 0.25 + loss_gamma: 2.0 + iou_loss_type: "giou" + 
reg_weights: 1.0 + quality: "iou" # default 'centerness' + nms: + name: MultiClassNMS + nms_top_k: 1000 + keep_top_k: 100 + score_threshold: 0.025 + nms_threshold: 0.6 diff --git a/configs/fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml b/configs/fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml new file mode 100644 index 000000000..3f6a327db --- /dev/null +++ b/configs/fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml @@ -0,0 +1,90 @@ +_BASE_: [ + '../datasets/coco_detection.yml', + '../runtime.yml', + '_base_/fcos_r50_fpn.yml', + '_base_/optimizer_1x.yml', + '_base_/fcos_reader.yml', +] + +weights: output/fcos_r50_fpn_iou_multiscale_2x_coco_010/model_final + +TrainReader: + sample_transforms: + - Decode: {} + - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - RandomFlip: {} + batch_transforms: + - Permute: {} + - PadBatch: {pad_to_stride: 32} + - Gt2FCOSTarget: + object_sizes_boundary: [64, 128, 256, 512] + center_sampling_radius: 1.5 + downsample_ratios: [8, 16, 32, 64, 128] + norm_reg_targets: True + batch_size: 2 + shuffle: True + drop_last: True + + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1} + - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True} + - Permute: {} + batch_transforms: + - PadBatch: {pad_to_stride: 32} + batch_size: 1 + fuse_normalize: True + + +epoch: 24 + +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [16, 22] + - 
!LinearWarmup + start_factor: 0.001 + steps: 1000 + + +FCOSHead: + fcos_feat: + name: FCOSFeat + feat_in: 256 + feat_out: 256 + num_convs: 4 + norm_type: "gn" + use_dcn: False + fpn_stride: [8, 16, 32, 64, 128] + prior_prob: 0.01 + norm_reg_targets: True + centerness_on_reg: True + fcos_loss: + name: FCOSLoss + loss_alpha: 0.25 + loss_gamma: 2.0 + iou_loss_type: "giou" + reg_weights: 1.0 + quality: "iou" # default 'centerness' + nms: + name: MultiClassNMS + nms_top_k: 1000 + keep_top_k: 100 + score_threshold: 0.025 + nms_threshold: 0.6 diff --git a/configs/ssod/README.md b/configs/ssod/README.md new file mode 100644 index 000000000..f0462882c --- /dev/null +++ b/configs/ssod/README.md @@ -0,0 +1,109 @@ +简体中文 | [English](README_en.md) + +# Semi-Supervised Object Detection (SSOD) 半监督目标检测 + +## 内容 +- [简介](#简介) +- [模型库](#模型库) +- [数据集准备](#数据集准备) +- [引用](#引用) + +## 简介 +半监督目标检测(SSOD)是**同时使用有标注数据和无标注数据**进行训练的目标检测,既可以极大地节省标注成本,也可以充分利用无标注数据进一步提高检测精度。 + + +## 模型库 + +### [Baseline](baseline) + +**纯监督数据**模型的训练和模型库,请参照[Baseline](baseline); + + + +## 数据集准备 + +半监督目标检测**同时需要有标注数据和无标注数据**,且无标注数据量一般**远多于有标注数据量**。 +对于COCO数据集一般有两种常规设置: + +(1)抽取部分比例的原始训练集`train2017`作为标注数据和无标注数据; + +从`train2017`中按固定百分比(1%、2%、5%、10%等)抽取,由于抽取方法会对半监督训练的结果影响较大,所以采用五折交叉验证来评估。运行数据集划分制作的脚本如下: +```bash +python tools/gen_semi_coco.py +``` +会按照 1%、2%、5%、10% 的监督数据比例来划分`train2017`全集,为了交叉验证每一种划分会随机重复5次,生成的半监督标注文件如下: +- 标注数据集标注:`instances_train2017.{fold}@{percent}.json` +- 无标注数据集标注:`instances_train2017.{fold}@{percent}-unlabeled.json` +其中,`fold` 表示交叉验证,`percent` 表示有标注数据的百分比。 + +(2)使用全量原始训练集`train2017`作为有标注数据 和 全量原始无标签图片集`unlabeled2017`作为无标注数据; + + +### 下载链接 + +PaddleDetection团队提供了COCO数据集全部的标注文件,请下载并解压存放至对应目录: + +```shell +# 下载COCO全量数据集图片和标注 +# 包括 train2017, val2017, annotations +wget https://bj.bcebos.com/v1/paddledet/data/coco.tar + +# 下载PaddleDetection团队整理的COCO部分比例数据的标注文件 +wget https://bj.bcebos.com/v1/paddledet/data/coco/semi_annotations.zip + +# unlabeled2017是可选,如果不需要训‘full’则无需下载 +# 下载COCO全量 unlabeled 无标注数据集 
+wget https://bj.bcebos.com/v1/paddledet/data/coco/unlabeled2017.zip +wget https://bj.bcebos.com/v1/paddledet/data/coco/image_info_unlabeled2017.zip +# 下载转换完的 unlabeled2017 无标注json文件 +wget https://bj.bcebos.com/v1/paddledet/data/coco/instances_unlabeled2017.zip +``` + +如果需要用到COCO全量unlabeled无标注数据集,需要将原版的`image_info_unlabeled2017.json`进行格式转换,运行以下代码: + +
+ COCO unlabeled 标注转换代码: + +```python +import json +anns_train = json.load(open('annotations/instances_train2017.json', 'r')) +anns_unlabeled = json.load(open('annotations/image_info_unlabeled2017.json', 'r')) +unlabeled_json = { + 'images': anns_unlabeled['images'], + 'annotations': [], + 'categories': anns_train['categories'], +} +path = 'annotations/instances_unlabeled2017.json' +with open(path, 'w') as f: + json.dump(unlabeled_json, f) +``` + +
+ + +
+ 解压后的数据集目录如下: + +``` +PaddleDetection +├── dataset +│ ├── coco +│ │ ├── annotations +│ │ │ ├── instances_train2017.json +│ │ │ ├── instances_unlabeled2017.json +│ │ │ ├── instances_val2017.json +│ │ ├── semi_annotations +│ │ │ ├── instances_train2017.1@1.json +│ │ │ ├── instances_train2017.1@1-unlabeled.json +│ │ │ ├── instances_train2017.1@2.json +│ │ │ ├── instances_train2017.1@2-unlabeled.json +│ │ │ ├── instances_train2017.1@5.json +│ │ │ ├── instances_train2017.1@5-unlabeled.json +│ │ │ ├── instances_train2017.1@10.json +│ │ │ ├── instances_train2017.1@10-unlabeled.json +│ │ ├── train2017 +│ │ ├── unlabeled2017 +│ │ ├── val2017 +``` + +
diff --git a/configs/ssod/baseline/README.md b/configs/ssod/baseline/README.md new file mode 100644 index 000000000..1453b8b22 --- /dev/null +++ b/configs/ssod/baseline/README.md @@ -0,0 +1,64 @@ +# Supervised Baseline 纯监督模型基线 + +## COCO数据集模型库 + +### [FCOS](../../fcos) + +| 基础模型 | 监督数据比例 | mAPval
0.5:0.95 | 模型下载 | 配置文件 | +| :---------------: | :-------------: | :---------------------: |:--------: | :---------: | +| FCOS ResNet50-FPN | 5% | 21.3 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_2x_coco_sup005.pdparams) | [config](fcos_r50_fpn_2x_coco_sup005.yml) | +| FCOS ResNet50-FPN | 10% | 26.3 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_2x_coco_sup010.pdparams) | [config](fcos_r50_fpn_2x_coco_sup010.yml) | +| FCOS ResNet50-FPN | full | 42.6 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_iou_multiscale_2x_coco.pdparams) | [config](../../fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml) | + + +### [PP-YOLOE+](../../ppyoloe) + +| 基础模型 | 监督数据比例 | mAPval
0.5:0.95 | 模型下载 | 配置文件 | +| :---------------: | :-------------: | :---------------------: |:--------: | :---------: | +| PP-YOLOE+_s | 5% | 32.8 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco_sup005.pdparams) | [config](ppyoloe_plus_crn_s_80e_coco_sup005.yml) | +| PP-YOLOE+_s | 10% | 35.3 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco_sup010.pdparams) | [config](ppyoloe_plus_crn_s_80e_coco_sup010.yml) | +| PP-YOLOE+_s | full | 43.7 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco.pdparams) | [config](../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml) | + + +### [Faster R-CNN](../../faster_rcnn) + +| 基础模型 | 监督数据比例 | mAPval
0.5:0.95 | 模型下载 | 配置文件 | +| :---------------: | :-------------: | :---------------------: |:--------: | :---------: | +| Faster R-CNN ResNet50-FPN | 10% | 25.6 | [download](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_fpn_2x_coco_sup010.pdparams) | [config](faster_rcnn_r50_fpn_2x_coco_sup010.yml) | +| Faster R-CNN ResNet50-FPN | full | 40.0 | [download](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_fpn_2x_coco.pdparams) | [config](../../faster_rcnn/faster_rcnn_r50_fpn_2x_coco.yml) | + + +### [RetinaNet](../../retinanet) + +| 基础模型 | 监督数据比例 | mAPval
0.5:0.95 | 模型下载 | 配置文件 | +| :---------------: | :-------------: | :---------------------: |:--------: | :---------: | +| RetinaNet ResNet50-FPN | 10% | 23.6 | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco_sup010.pdparams) | [config](retinanet_r50_fpn_2x_coco_sup010.yml) | +| RetinaNet ResNet50-FPN | full | 37.5(1x) | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_1x_coco.pdparams) | [config](../../retinanet/retinanet_r50_fpn_1x_coco.yml) | + + +**注意:** + - COCO部分监督数据集请参照 [数据集准备](../README.md) 去下载和准备,各个比例的训练集均为**从train2017中抽取部分百分比的子集**,默认使用`fold`号为1的划分子集,`sup010`表示抽取10%的监督数据训练,`sup005`表示抽取5%,`full`表示全部train2017,验证集均为val2017全量; + - 抽取部分百分比的监督数据的抽法不同,或使用的`fold`号不同,精度都会因此而有约0.5 mAP之多的差异; + - PP-YOLOE+ 使用Objects365预训练,其余模型均使用ImageNet预训练; + - PP-YOLOE+ 训练80 epoch,其余模型均训练24 epoch; + + +## 使用教程 + +将以下命令写在一个脚本文件里如```run.sh```,一键运行命令为:```sh run.sh```,也可命令行一句句去运行: + +```bash +model_type=ssod/baseline +job_name=ppyoloe_plus_crn_s_80e_coco_sup010 # 可修改,如 fcos_r50_fpn_2x_coco_sup010 + +config=configs/${model_type}/${job_name}.yml +log_dir=log_dir/${job_name} +weights=output/${job_name}/model_final.pdparams + +# 1.training +# CUDA_VISIBLE_DEVICES=0 python tools/train.py -c ${config} +python -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --eval + +# 2.eval +CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c ${config} -o weights=${weights} +``` diff --git a/configs/ssod/baseline/faster_rcnn_r50_fpn_2x_coco_sup010.yml b/configs/ssod/baseline/faster_rcnn_r50_fpn_2x_coco_sup010.yml new file mode 100644 index 000000000..345b083a7 --- /dev/null +++ b/configs/ssod/baseline/faster_rcnn_r50_fpn_2x_coco_sup010.yml @@ -0,0 +1,26 @@ +_BASE_: [ + '../../faster_rcnn/faster_rcnn_r50_fpn_2x_coco.yml', +] +log_iter: 50 +snapshot_epoch: 2 +weights: output/faster_rcnn_r50_fpn_2x_coco_sup010/model_final + + +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path:
semi_annotations/instances_train2017.1@10.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class'] + + +epoch: 24 +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [16, 22] + - !LinearWarmup + start_factor: 0.1 + steps: 500 diff --git a/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup005.yml b/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup005.yml new file mode 100644 index 000000000..a85b10429 --- /dev/null +++ b/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup005.yml @@ -0,0 +1,26 @@ +_BASE_: [ + '../../fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml', +] +log_iter: 50 +snapshot_epoch: 2 +weights: output/fcos_r50_fpn_2x_coco_sup005/model_final + + +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@5.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class'] + + +epoch: 24 +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [16, 22] + - !LinearWarmup + start_factor: 0.001 + steps: 1000 diff --git a/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup010.yml b/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup010.yml new file mode 100644 index 000000000..dc44de406 --- /dev/null +++ b/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup010.yml @@ -0,0 +1,26 @@ +_BASE_: [ + '../../fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml', +] +log_iter: 50 +snapshot_epoch: 2 +weights: output/fcos_r50_fpn_2x_coco_sup010/model_final + + +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@10.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class'] + + +epoch: 24 +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [16, 22] + - !LinearWarmup + start_factor: 0.001 + steps: 1000 diff --git a/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml b/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml new 
file mode 100644 index 000000000..88de96dcc --- /dev/null +++ b/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml @@ -0,0 +1,29 @@ +_BASE_: [ + '../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml', +] +log_iter: 50 +snapshot_epoch: 5 +weights: output/ppyoloe_plus_crn_s_80e_coco_sup005/model_final + +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams +depth_mult: 0.33 +width_mult: 0.50 + + +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@5.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class'] + + +epoch: 80 +LearningRate: + base_lr: 0.001 + schedulers: + - !CosineDecay + max_epochs: 96 + - !LinearWarmup + start_factor: 0. + epochs: 5 diff --git a/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml b/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml new file mode 100644 index 000000000..aeb9435a0 --- /dev/null +++ b/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml @@ -0,0 +1,29 @@ +_BASE_: [ + '../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml', +] +log_iter: 50 +snapshot_epoch: 5 +weights: output/ppyoloe_plus_crn_s_80e_coco_sup010/model_final + +pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams +depth_mult: 0.33 +width_mult: 0.50 + + +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@10.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class'] + + +epoch: 80 +LearningRate: + base_lr: 0.001 + schedulers: + - !CosineDecay + max_epochs: 96 + - !LinearWarmup + start_factor: 0. 
+ epochs: 5 diff --git a/configs/ssod/baseline/retinanet_r50_fpn_2x_coco_sup010.yml b/configs/ssod/baseline/retinanet_r50_fpn_2x_coco_sup010.yml new file mode 100644 index 000000000..9b9cc72bc --- /dev/null +++ b/configs/ssod/baseline/retinanet_r50_fpn_2x_coco_sup010.yml @@ -0,0 +1,26 @@ +_BASE_: [ + '../../retinanet/retinanet_r50_fpn_1x_coco.yml', +] +log_iter: 50 +snapshot_epoch: 2 +weights: output/retinanet_r50_fpn_2x_coco_sup010/model_final + + +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@10.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class'] + + +epoch: 24 +LearningRate: + base_lr: 0.01 + schedulers: + - !PiecewiseDecay + gamma: 0.1 + milestones: [16, 22] + - !LinearWarmup + start_factor: 0.001 + steps: 500 diff --git a/ppdet/modeling/heads/fcos_head.py b/ppdet/modeling/heads/fcos_head.py index 5c22244ee..79b69f08d 100644 --- a/ppdet/modeling/heads/fcos_head.py +++ b/ppdet/modeling/heads/fcos_head.py @@ -139,6 +139,7 @@ class FCOSHead(nn.Layer): norm_reg_targets=True, centerness_on_reg=True, num_shift=0.5, + sqrt_score=False, fcos_loss='FCOSLoss', nms='MultiClassNMS', trt=False): @@ -154,6 +155,7 @@ class FCOSHead(nn.Layer): self.nms = nms if isinstance(self.nms, MultiClassNMS) and trt: self.nms.trt = trt + self.sqrt_score = sqrt_score conv_cls_name = "fcos_head_cls" bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob) @@ -296,10 +298,17 @@ class FCOSHead(nn.Layer): tag_labels, tag_bboxes, tag_centerness) return losses_fcos - def _post_process_by_level(self, locations, box_cls, box_reg, box_ctn): + def _post_process_by_level(self, + locations, + box_cls, + box_reg, + box_ctn, + sqrt_score=False): box_scores = F.sigmoid(box_cls).flatten(2).transpose([0, 2, 1]) box_centerness = F.sigmoid(box_ctn).flatten(2).transpose([0, 2, 1]) pred_scores = box_scores * box_centerness + if sqrt_score: + pred_scores = paddle.sqrt(pred_scores) box_reg_ch_last = 
box_reg.flatten(2).transpose([0, 2, 1]) box_reg_decoding = paddle.stack( @@ -320,7 +329,8 @@ class FCOSHead(nn.Layer): for pts, cls, reg, ctn in zip(locations, cls_logits, bboxes_reg, centerness): - scores, boxes = self._post_process_by_level(pts, cls, reg, ctn) + scores, boxes = self._post_process_by_level(pts, cls, reg, ctn, + self.sqrt_score) pred_scores.append(scores) pred_bboxes.append(boxes) pred_bboxes = paddle.concat(pred_bboxes, axis=1) diff --git a/ppdet/modeling/losses/fcos_loss.py b/ppdet/modeling/losses/fcos_loss.py index c8d600573..0cd6b581b 100644 --- a/ppdet/modeling/losses/fcos_loss.py +++ b/ppdet/modeling/losses/fcos_loss.py @@ -53,20 +53,28 @@ class FCOSLoss(nn.Layer): loss_gamma (float): gamma in focal loss iou_loss_type (str): location loss type, IoU/GIoU/LINEAR_IoU reg_weights (float): weight for location loss + quality (str): quality branch, centerness/iou """ def __init__(self, loss_alpha=0.25, loss_gamma=2.0, iou_loss_type="giou", - reg_weights=1.0): + reg_weights=1.0, + quality='centerness'): super(FCOSLoss, self).__init__() self.loss_alpha = loss_alpha self.loss_gamma = loss_gamma self.iou_loss_type = iou_loss_type self.reg_weights = reg_weights + self.quality = quality - def __iou_loss(self, pred, targets, positive_mask, weights=None): + def __iou_loss(self, + pred, + targets, + positive_mask, + weights=None, + return_iou=False): """ Calculate the loss for location prediction Args: @@ -108,6 +116,9 @@ class FCOSLoss(nn.Layer): area_predict + area_target - area_inter + 1.0) ious = ious * positive_mask + if return_iou: + return ious + if self.iou_loss_type.lower() == "linear_iou": loss = 1.0 - ious elif self.iou_loss_type.lower() == "giou": @@ -201,25 +212,53 @@ class FCOSLoss(nn.Layer): cls_loss = F.sigmoid_focal_loss( cls_logits_flatten, tag_labels_flatten_bin) / num_positive_fp32 - # 2. 
bboxes_reg: giou_loss - mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1) - tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1) - reg_loss = self.__iou_loss( - bboxes_reg_flatten, - tag_bboxes_flatten, - mask_positive_float, - weights=tag_center_flatten) - reg_loss = reg_loss * mask_positive_float / normalize_sum - - # 3. centerness: sigmoid_cross_entropy_with_logits_loss - centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1) - ctn_loss = ops.sigmoid_cross_entropy_with_logits(centerness_flatten, - tag_center_flatten) - ctn_loss = ctn_loss * mask_positive_float / num_positive_fp32 + if self.quality == 'centerness': + # 2. bboxes_reg: giou_loss + mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1) + tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1) + reg_loss = self.__iou_loss( + bboxes_reg_flatten, # [61570, 4] + tag_bboxes_flatten, + mask_positive_float, # [61570] sum 57 + weights=tag_center_flatten + ) # [61570] tag_center_flatten.sum()=34.43262482 + reg_loss = reg_loss * mask_positive_float / normalize_sum # 34.43262482 + + # 3. centerness: sigmoid_cross_entropy_with_logits_loss + centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1) + quality_loss = ops.sigmoid_cross_entropy_with_logits( + centerness_flatten, tag_center_flatten) + quality_loss = quality_loss * mask_positive_float / num_positive_fp32 + + elif self.quality == 'iou': + # 2. bboxes_reg: giou_loss + mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1) + tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1) + reg_loss = self.__iou_loss( + bboxes_reg_flatten, + tag_bboxes_flatten, + mask_positive_float, + weights=None) + reg_loss = reg_loss * mask_positive_float / num_positive_fp32 + # num_positive_fp32 is num_foreground + + # 3. 
centerness: sigmoid_cross_entropy_with_logits_loss + centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1) + gt_ious = self.__iou_loss( + bboxes_reg_flatten, + tag_bboxes_flatten, + mask_positive_float, + weights=None, + return_iou=True) + quality_loss = ops.sigmoid_cross_entropy_with_logits( + centerness_flatten, gt_ious) + quality_loss = quality_loss * mask_positive_float / num_positive_fp32 + else: + raise Exception(f'Unknown quality type: {self.quality}') loss_all = { - "loss_centerness": paddle.sum(ctn_loss), "loss_cls": paddle.sum(cls_loss), - "loss_box": paddle.sum(reg_loss) + "loss_box": paddle.sum(reg_loss), + "loss_quality": paddle.sum(quality_loss), } return loss_all diff --git a/tools/gen_semi_coco.py b/tools/gen_semi_coco.py new file mode 100644 index 000000000..acacb5861 --- /dev/null +++ b/tools/gen_semi_coco.py @@ -0,0 +1,102 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os
import json
import argparse
import numpy as np


def _load_json(path):
    """Load a JSON document from ``path``, closing the file afterwards."""
    with open(path, 'r') as f:
        return json.load(f)


def save_json(path, images, annotations, categories):
    """Write a COCO-style annotation file and print a one-line summary.

    Args:
        path (str): Destination file path.
        images (list): Entries for the ``images`` field.
        annotations (list): Entries for the ``annotations`` field.
        categories (list): Entries for the ``categories`` field.
    """
    new_json = {
        'images': images,
        'annotations': annotations,
        'categories': categories,
    }
    with open(path, 'w') as f:
        json.dump(new_json, f)
    print('{} saved, with {} images and {} annotations.'.format(
        path, len(images), len(annotations)))


def gen_semi_data(data_dir,
                  json_file,
                  percent=10.0,
                  seed=1,
                  seed_offset=0,
                  txt_file=None):
    """Split a COCO annotation file into labeled / unlabeled subsets.

    Two files are written to ``<data_dir>/semi_annotations``:
    ``<name>.<seed>@<int(percent)>.json`` (labeled images and their
    annotations) and ``<name>.<seed>@<int(percent)>-unlabeled.json``
    (the remaining images, with an empty annotation list).

    Args:
        data_dir (str): Dataset root directory.
        json_file (str): Annotation file path relative to ``data_dir``.
        percent (float): Percentage of images used as labeled data.
        seed (int): Split (fold) number; also the random seed base.
        seed_offset (int): Added to ``seed`` before seeding numpy. Only
            used when ``txt_file`` is empty.
        txt_file (str|None): Optional JSON file, relative to ``data_dir``,
            laid out as ``{str(percent): {str(seed): [index, ...]}}``.
            The indices are positions in the ``images`` list, NOT image
            ids. When empty/None a random split is drawn instead.

    Raises:
        KeyError: If ``txt_file`` has no entry for this percent/seed.
    """
    json_name = os.path.basename(json_file).split('.')[0]
    anno = _load_json(os.path.join(data_dir, json_file))
    categories = anno['categories']
    all_images = anno['images']
    all_anns = anno['annotations']
    # Guard the average so an annotation file without images does not
    # abort with ZeroDivisionError before anything is written.
    gts_per_image = len(all_anns) / len(all_images) if all_images else 0.0
    print(
        'Totally {} images and {} annotations, about {} gts per image.'.format(
            len(all_images), len(all_anns), gts_per_image))

    if txt_file:
        print('Using percent {} and seed {}.'.format(percent, seed))
        split_table = _load_json(os.path.join(data_dir, txt_file))
        sup_idx = split_table[str(percent)][str(seed)]
    else:
        # Same seeding and sampling calls as before, so previously
        # generated random splits stay reproducible.
        np.random.seed(seed + seed_offset)
        sup_len = int(percent / 100.0 * len(all_images))
        sup_idx = np.random.choice(
            range(len(all_images)), size=sup_len, replace=False)

    # Sets give O(1) membership tests; scanning a list/ndarray for every
    # image and every annotation is quadratic on a COCO-sized dataset.
    sup_idx = {int(i) for i in sup_idx}

    labeled_images, unlabeled_images = [], []
    labeled_im_ids = set()
    for i, image in enumerate(all_images):
        if i in sup_idx:
            labeled_im_ids.add(image['id'])
            labeled_images.append(image)
        else:
            unlabeled_images.append(image)

    labeled_anns = [an for an in all_anns if an['image_id'] in labeled_im_ids]
    # The unlabeled split deliberately carries no annotations.
    unlabeled_anns = []

    save_path = os.path.join(data_dir, 'semi_annotations')
    os.makedirs(save_path, exist_ok=True)

    sup_name = '{}.{}@{}.json'.format(json_name, seed, int(percent))
    save_json(
        os.path.join(save_path, sup_name), labeled_images, labeled_anns,
        categories)

    unsup_name = '{}.{}@{}-unlabeled.json'.format(json_name, seed,
                                                  int(percent))
    save_json(
        os.path.join(save_path, unsup_name), unlabeled_images, unlabeled_anns,
        categories)


def main():
    """Parse command-line options and generate one labeled/unlabeled split."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='./dataset/coco')
    parser.add_argument(
        '--json_file', type=str, default='annotations/instances_train2017.json')
    parser.add_argument('--percent', type=float, default=10.0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--seed_offset', type=int, default=0)
    parser.add_argument('--txt_file', type=str, default='COCO_supervision.txt')
    args = parser.parse_args()
    print(args)
    gen_semi_data(args.data_dir, args.json_file, args.percent, args.seed,
                  args.seed_offset, args.txt_file)


if __name__ == '__main__':
    main()