diff --git a/configs/fcos/README.md b/configs/fcos/README.md
index cdd4334235a30283ac9b8c9902098fdc94364c11..44c043440f1dc75e94e433d495bbeb05a218dff1 100644
--- a/configs/fcos/README.md
+++ b/configs/fcos/README.md
@@ -1,24 +1,18 @@
-# FCOS for Object Detection
+# FCOS (Fully Convolutional One-Stage Object Detection)
-## Introduction
+## Model Zoo on COCO
-FCOS (Fully Convolutional One-Stage Object Detection) is a fast anchor-free object detection framework with strong performance. We reproduced the model of the paper, and improved and optimized the accuracy of the FCOS.
+| Backbone | Model | images/GPU | lr schedule | FPS | Box AP | download | config |
+| :------------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
+| ResNet50-FPN | FCOS | 2 | 1x | ---- | 39.6 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_1x_coco.pdparams) | [config](./fcos_r50_fpn_1x_coco.yml) |
+| ResNet50-FPN | FCOS + iou | 2 | 1x | ---- | 40.0 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_iou_1x_coco.pdparams) | [config](./fcos_r50_fpn_iou_1x_coco.yml) |
+| ResNet50-FPN | FCOS + DCN | 2 | 1x | ---- | 44.3 | [download](https://paddledet.bj.bcebos.com/models/fcos_dcn_r50_fpn_1x_coco.pdparams) | [config](./fcos_dcn_r50_fpn_1x_coco.yml) |
+| ResNet50-FPN | FCOS + multiscale_train | 2 | 2x | ---- | 41.8 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_multiscale_2x_coco.pdparams) | [config](./fcos_r50_fpn_multiscale_2x_coco.yml) |
+| ResNet50-FPN | FCOS + multiscale_train + iou | 2 | 2x | ---- | 42.6 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_iou_multiscale_2x_coco.pdparams) | [config](./fcos_r50_fpn_iou_multiscale_2x_coco.yml) |
-**Highlights:**
+**Notes:**
+  - `+ iou` means that, compared with the original FCOS, the `iou` between the predicted and ground-truth boxes is used as the quality target in the loss instead of `centerness`.
-- Training Time: The training time of the model of `fcos_r50_fpn_1x` on Tesla v100 with 8 GPU is only 8.5 hours.
-
-## Model Zoo
-
-| Backbone | Model | images/GPU | lr schedule |FPS | Box AP | download | config |
-| :-------------- | :------------- | :-----: | :-----: | :------------: | :-----: | :-----------------------------------------------------: | :-----: |
-| ResNet50-FPN | FCOS | 2 | 1x | ---- | 39.6 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_r50_fpn_1x_coco.yml) |
-| ResNet50-FPN | FCOS+DCN | 2 | 1x | ---- | 44.3 | [download](https://paddledet.bj.bcebos.com/models/fcos_dcn_r50_fpn_1x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_dcn_r50_fpn_1x_coco.yml) |
-| ResNet50-FPN | FCOS+multiscale_train | 2 | 2x | ---- | 41.8 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_multiscale_2x_coco.pdparams) | [config](https://github.com/PaddlePaddle/PaddleDetection/tree/develop/configs/fcos/fcos_r50_fpn_multiscale_2x_coco.yml) |
-
-**Notes:**
-
-- FCOS is trained on COCO train2017 dataset and evaluated on val2017 results of `mAP(IoU=0.5:0.95)`.
## Citations
```
diff --git a/configs/fcos/fcos_r50_fpn_iou_1x_coco.yml b/configs/fcos/fcos_r50_fpn_iou_1x_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..943c5bc04dedb4759ef64b8eeffd2e1b1b074fb2
--- /dev/null
+++ b/configs/fcos/fcos_r50_fpn_iou_1x_coco.yml
@@ -0,0 +1,78 @@
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ '_base_/fcos_r50_fpn.yml',
+ '_base_/optimizer_1x.yml',
+ '_base_/fcos_reader.yml',
+]
+
+weights: output/fcos_r50_fpn_iou_1x_coco/model_final
+
+
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - RandomFlip: {}
+ batch_transforms:
+ - Permute: {}
+ - PadBatch: {pad_to_stride: 32}
+ - Gt2FCOSTarget:
+ object_sizes_boundary: [64, 128, 256, 512]
+ center_sampling_radius: 1.5
+ downsample_ratios: [8, 16, 32, 64, 128]
+ norm_reg_targets: True
+ batch_size: 2
+ shuffle: True
+ drop_last: True
+
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: 32}
+ batch_size: 1
+
+
+TestReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: 32}
+ batch_size: 1
+ fuse_normalize: True
+
+
+FCOSHead:
+ fcos_feat:
+ name: FCOSFeat
+ feat_in: 256
+ feat_out: 256
+ num_convs: 4
+ norm_type: "gn"
+ use_dcn: False
+ fpn_stride: [8, 16, 32, 64, 128]
+ prior_prob: 0.01
+ norm_reg_targets: True
+ centerness_on_reg: True
+ fcos_loss:
+ name: FCOSLoss
+ loss_alpha: 0.25
+ loss_gamma: 2.0
+ iou_loss_type: "giou"
+ reg_weights: 1.0
+ quality: "iou" # default 'centerness'
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 100
+ score_threshold: 0.025
+ nms_threshold: 0.6
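
The `quality: "iou"` entry above (the `+ iou` rows in the FCOS model zoo) replaces the centerness target of the quality branch with the IoU between predicted and ground-truth boxes, as implemented in the `FCOSLoss` changes later in this diff. Below is a minimal standalone sketch, not part of the PR and using made-up numbers, contrasting the two targets for one location in FCOS's l/t/r/b box parameterization:

```python
# Illustrative sketch only: contrasts the two quality targets that FCOSLoss
# switches between via the `quality` option. All numbers are example values
# for a single feature-map location.
import math

def centerness_target(l, t, r, b):
    # original FCOS quality target, computed from the ground-truth l/t/r/b offsets
    return math.sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))

def iou_target(pred, gt):
    # with `quality: iou`, the quality branch is supervised by the IoU between
    # the predicted box and the ground-truth box decoded at the same location
    pl, pt, pr, pb = pred
    gl, gt_, gr, gb = gt
    inter = max(min(pl, gl) + min(pr, gr), 0) * max(min(pt, gt_) + min(pb, gb), 0)
    area_pred = (pl + pr) * (pt + pb)
    area_gt = (gl + gr) * (gt_ + gb)
    return inter / (area_pred + area_gt - inter + 1e-9)

print(centerness_target(10, 40, 30, 20))               # ~0.41 for an off-center point
print(iou_target([12, 35, 28, 22], [10, 40, 30, 20]))  # ~0.81
```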
diff --git a/configs/fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml b/configs/fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml
new file mode 100644
index 0000000000000000000000000000000000000000..3f6a327db268501c05447a02805753012b76e9b3
--- /dev/null
+++ b/configs/fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml
@@ -0,0 +1,90 @@
+_BASE_: [
+ '../datasets/coco_detection.yml',
+ '../runtime.yml',
+ '_base_/fcos_r50_fpn.yml',
+ '_base_/optimizer_1x.yml',
+ '_base_/fcos_reader.yml',
+]
+
+weights: output/fcos_r50_fpn_iou_multiscale_2x_coco/model_final
+
+TrainReader:
+ sample_transforms:
+ - Decode: {}
+ - RandomResize: {target_size: [[640, 1333], [672, 1333], [704, 1333], [736, 1333], [768, 1333], [800, 1333]], keep_ratio: True, interp: 1}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - RandomFlip: {}
+ batch_transforms:
+ - Permute: {}
+ - PadBatch: {pad_to_stride: 32}
+ - Gt2FCOSTarget:
+ object_sizes_boundary: [64, 128, 256, 512]
+ center_sampling_radius: 1.5
+ downsample_ratios: [8, 16, 32, 64, 128]
+ norm_reg_targets: True
+ batch_size: 2
+ shuffle: True
+ drop_last: True
+
+
+EvalReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: 32}
+ batch_size: 1
+
+
+TestReader:
+ sample_transforms:
+ - Decode: {}
+ - Resize: {target_size: [800, 1333], keep_ratio: True, interp: 1}
+ - NormalizeImage: {mean: [0.485, 0.456, 0.406], std: [0.229, 0.224, 0.225], is_scale: True}
+ - Permute: {}
+ batch_transforms:
+ - PadBatch: {pad_to_stride: 32}
+ batch_size: 1
+ fuse_normalize: True
+
+
+epoch: 24
+
+LearningRate:
+ base_lr: 0.01
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [16, 22]
+ - !LinearWarmup
+ start_factor: 0.001
+ steps: 1000
+
+
+FCOSHead:
+ fcos_feat:
+ name: FCOSFeat
+ feat_in: 256
+ feat_out: 256
+ num_convs: 4
+ norm_type: "gn"
+ use_dcn: False
+ fpn_stride: [8, 16, 32, 64, 128]
+ prior_prob: 0.01
+ norm_reg_targets: True
+ centerness_on_reg: True
+ fcos_loss:
+ name: FCOSLoss
+ loss_alpha: 0.25
+ loss_gamma: 2.0
+ iou_loss_type: "giou"
+ reg_weights: 1.0
+ quality: "iou" # default 'centerness'
+ nms:
+ name: MultiClassNMS
+ nms_top_k: 1000
+ keep_top_k: 100
+ score_threshold: 0.025
+ nms_threshold: 0.6
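
As a worked example of the 2x schedule declared above (`base_lr` 0.01, `LinearWarmup` of 1000 steps with `start_factor` 0.001, then `PiecewiseDecay` with `gamma` 0.1 at epoch milestones 16 and 22 of 24 epochs), the sketch below, illustrative only and not part of the PR, reproduces the per-iteration learning rate; `iters_per_epoch` is an assumed value that depends on dataset size and total batch size:

```python
# Illustrative sketch of the LearningRate schedule above (LinearWarmup + PiecewiseDecay).
# `iters_per_epoch` is an assumed value; it depends on dataset size and total batch size.
def lr_at(iter_id, iters_per_epoch=7300, base_lr=0.01,
          warmup_steps=1000, start_factor=0.001,
          milestones=(16, 22), gamma=0.1):
    if iter_id < warmup_steps:
        # linear warmup from start_factor * base_lr up to base_lr
        alpha = iter_id / warmup_steps
        return base_lr * (start_factor + (1.0 - start_factor) * alpha)
    epoch = iter_id // iters_per_epoch
    num_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr * (gamma ** num_decays)

for it in (0, 500, 1000, 16 * 7300, 22 * 7300):
    print(it, lr_at(it))   # ~1e-05, ~0.005, 0.01, 0.001, 0.0001
```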
diff --git a/configs/ssod/README.md b/configs/ssod/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f0462882cc10f3f67c64204390b0d967a280a630
--- /dev/null
+++ b/configs/ssod/README.md
@@ -0,0 +1,109 @@
+Simplified Chinese | [English](README_en.md)
+
+# Semi-Supervised Object Detection (SSOD)
+
+## Contents
+- [Introduction](#introduction)
+- [Model Zoo](#model-zoo)
+- [Dataset Preparation](#dataset-preparation)
+- [Citations](#citations)
+
+## Introduction
+Semi-supervised object detection (SSOD) trains a detector on **labeled and unlabeled data at the same time**, which greatly reduces annotation cost and makes full use of unlabeled data to further improve detection accuracy.
+
+
+## Model Zoo
+
+### [Baseline](baseline)
+
+For the training and model zoo of the **fully supervised** baseline models, please refer to [Baseline](baseline).
+
+
+
+## Dataset Preparation
+
+Semi-supervised object detection requires **both labeled and unlabeled data**, and the amount of unlabeled data is usually **much larger than the amount of labeled data**.
+For the COCO dataset there are two common settings:
+
+(1) Sample a fraction of the original training set `train2017` as labeled data, with the remaining images used as unlabeled data.
+
+A fixed percentage (1%, 2%, 5%, 10%, etc.) is sampled from `train2017`. Since the sampling has a large impact on semi-supervised results, 5-fold cross validation is used for evaluation. Run the dataset split script as follows:
+```bash
+python tools/gen_semi_coco.py
+```
+This splits the full `train2017` set at supervised-data ratios of 1%, 2%, 5%, and 10%; for cross validation each split is randomly repeated 5 times. The generated semi-supervised annotation files are:
+- labeled set annotations: `instances_train2017.{fold}@{percent}.json`
+- unlabeled set annotations: `instances_train2017.{fold}@{percent}-unlabeled.json`
+where `fold` denotes the cross-validation split and `percent` denotes the percentage of labeled data.
+
+(2) Use the full original training set `train2017` as labeled data and the full original unlabeled image set `unlabeled2017` as unlabeled data.
+
+
+### Download Links
+
+The PaddleDetection team provides all the annotation files for the COCO dataset. Please download and extract them into the corresponding directories:
+
+```shell
+# download the full COCO dataset images and annotations,
+# including train2017, val2017, annotations
+wget https://bj.bcebos.com/v1/paddledet/data/coco.tar
+
+# download the partial-ratio COCO annotation files prepared by the PaddleDetection team
+wget https://bj.bcebos.com/v1/paddledet/data/coco/semi_annotations.zip
+
+# unlabeled2017 is optional; skip the downloads below if you do not train the 'full' setting
+# download the full COCO unlabeled dataset
+wget https://bj.bcebos.com/v1/paddledet/data/coco/unlabeled2017.zip
+wget https://bj.bcebos.com/v1/paddledet/data/coco/image_info_unlabeled2017.zip
+# download the converted unlabeled2017 annotation json file
+wget https://bj.bcebos.com/v1/paddledet/data/coco/instances_unlabeled2017.zip
+```
+
+If you need the full COCO unlabeled dataset, the original `image_info_unlabeled2017.json` must be converted into instance-annotation format by running the following code:
+
+
+COCO unlabeled annotation conversion code:
+
+```python
+import json
+anns_train = json.load(open('annotations/instances_train2017.json', 'r'))
+anns_unlabeled = json.load(open('annotations/image_info_unlabeled2017.json', 'r'))
+unlabeled_json = {
+ 'images': anns_unlabeled['images'],
+ 'annotations': [],
+ 'categories': anns_train['categories'],
+}
+path = 'annotations/instances_unlabeled2017.json'
+with open(path, 'w') as f:
+ json.dump(unlabeled_json, f)
+```
+
+
+
+
+
+The extracted dataset directory structure is as follows:
+
+```
+PaddleDetection
+├── dataset
+│ ├── coco
+│ │ ├── annotations
+│ │ │ ├── instances_train2017.json
+│ │ │ ├── instances_unlabeled2017.json
+│ │ │ ├── instances_val2017.json
+│ │ ├── semi_annotations
+│ │ │ ├── instances_train2017.1@1.json
+│ │ │ ├── instances_train2017.1@1-unlabeled.json
+│ │ │ ├── instances_train2017.1@2.json
+│ │ │ ├── instances_train2017.1@2-unlabeled.json
+│ │ │ ├── instances_train2017.1@5.json
+│ │ │ ├── instances_train2017.1@5-unlabeled.json
+│ │ │ ├── instances_train2017.1@10.json
+│ │ │ ├── instances_train2017.1@10-unlabeled.json
+│ │ ├── train2017
+│ │ ├── unlabeled2017
+│ │ ├── val2017
+```
+
+
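
A quick way to sanity-check a generated labeled/unlabeled pair after running the split script is a short snippet like the one below (illustrative only, not part of the PR; the paths assume the directory layout shown above and the default fold 1 at 10%):

```python
import json

sup = json.load(open('dataset/coco/semi_annotations/instances_train2017.1@10.json'))
unsup = json.load(open('dataset/coco/semi_annotations/instances_train2017.1@10-unlabeled.json'))

sup_ids = {im['id'] for im in sup['images']}
unsup_ids = {im['id'] for im in unsup['images']}
print(len(sup_ids), len(unsup_ids), len(sup['annotations']), len(unsup['annotations']))

assert not (sup_ids & unsup_ids)          # labeled and unlabeled image sets are disjoint
assert len(unsup['annotations']) == 0     # the unlabeled json intentionally carries no boxes
```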
diff --git a/configs/ssod/baseline/README.md b/configs/ssod/baseline/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1453b8b22ef102d1f10e6adcb75ccaac7eaefa6a
--- /dev/null
+++ b/configs/ssod/baseline/README.md
@@ -0,0 +1,64 @@
+# Supervised Baseline (fully supervised models)
+
+## Model Zoo on COCO
+
+### [FCOS](../../fcos)
+
+| Base Model | Supervision Ratio | mAP<sup>val<br>0.5:0.95</sup> | Download | Config |
+| :---------------: | :-------------: | :---------------------: |:--------: | :---------: |
+| FCOS ResNet50-FPN | 5% | 21.3 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_2x_coco_sup005.pdparams) | [config](fcos_r50_fpn_2x_coco_sup005.yml) |
+| FCOS ResNet50-FPN | 10% | 26.3 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_2x_coco_sup010.pdparams) | [config](fcos_r50_fpn_2x_coco_sup010.yml) |
+| FCOS ResNet50-FPN | full | 42.6 | [download](https://paddledet.bj.bcebos.com/models/fcos_r50_fpn_iou_multiscale_2x_coco.pdparams) | [config](../../fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml) |
+
+
+### [PP-YOLOE+](../../ppyoloe)
+
+| Base Model | Supervision Ratio | mAP<sup>val<br>0.5:0.95</sup> | Download | Config |
+| :---------------: | :-------------: | :---------------------: |:--------: | :---------: |
+| PP-YOLOE+_s | 5% | 32.8 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco_sup005.pdparams) | [config](ppyoloe_plus_crn_s_80e_coco_sup005.yml) |
+| PP-YOLOE+_s | 10% | 35.3 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco_sup010.pdparams) | [config](ppyoloe_plus_crn_s_80e_coco_sup010.yml) |
+| PP-YOLOE+_s | full | 43.7 | [download](https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_s_80e_coco.pdparams) | [config](../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml) |
+
+
+### [Faster R-CNN](../../faster_rcnn)
+
+| Base Model | Supervision Ratio | mAP<sup>val<br>0.5:0.95</sup> | Download | Config |
+| :---------------: | :-------------: | :---------------------: |:--------: | :---------: |
+| Faster R-CNN ResNet50-FPN | 10% | 25.6 | [download](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_fpn_2x_coco_sup010.pdparams) | [config](faster_rcnn_r50_fpn_2x_coco_sup010.yml) |
+| Faster R-CNN ResNet50-FPN | full | 40.0 | [download](https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_fpn_2x_coco.pdparams) | [config](../../faster_rcnn/faster_rcnn_r50_fpn_2x_coco.yml) |
+
+
+### [RetinaNet](../../retinanet)
+
+| Base Model | Supervision Ratio | mAP<sup>val<br>0.5:0.95</sup> | Download | Config |
+| :---------------: | :-------------: | :---------------------: |:--------: | :---------: |
+| RetinaNet ResNet50-FPN | 10% | 23.6 | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco_sup010.pdparams) | [config](retinanet_r50_fpn_2x_coco_sup010.yml) |
+| RetinaNet ResNet50-FPN | full | 37.5 (1x) | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_1x_coco.pdparams) | [config](../../retinanet/retinanet_r50_fpn_1x_coco.yml) |
+
+
+**Notes:**
+  - For the partially supervised COCO datasets, please follow [Dataset Preparation](../README.md) to download and prepare the data. Each training set is **a subset sampled at a given percentage from train2017**, using the split with `fold` number 1 by default; `sup010` means training on 10% of the supervised data, `sup005` means 5%, and `full` means the whole train2017. The validation set is always the full val2017.
+  - A different sampling of the partial supervised data, or a different `fold` number, can change accuracy by as much as about 0.5 mAP.
+  - PP-YOLOE+ uses Objects365 pre-trained weights; all other models use ImageNet pre-trained weights.
+  - PP-YOLOE+ is trained for 80 epochs; all other models are trained for 24 epochs.
+
+
+## Getting Started
+
+Write the following commands into a script file such as ```run.sh``` and run everything at once with ```sh run.sh```, or run the commands one by one on the command line:
+
+```bash
+model_type=ssod/baseline
+job_name=ppyoloe_plus_crn_s_80e_coco_sup010 # can be changed, e.g. fcos_r50_fpn_2x_coco_sup010
+
+config=configs/${model_type}/${job_name}.yml
+log_dir=log_dir/${job_name}
+weights=output/${job_name}/model_final.pdparams
+
+# 1.training
+# CUDA_VISIBLE_DEVICES=0 python tools/train.py -c ${config}
+python -m paddle.distributed.launch --log_dir=${log_dir} --gpus 0,1,2,3,4,5,6,7 tools/train.py -c ${config} --eval
+
+# 2.eval
+CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c ${config} -o weights=${weights}
+```
diff --git a/configs/ssod/baseline/faster_rcnn_r50_fpn_2x_coco_sup010.yml b/configs/ssod/baseline/faster_rcnn_r50_fpn_2x_coco_sup010.yml
new file mode 100644
index 0000000000000000000000000000000000000000..345b083a70feae74dfb9410aeab782c8368506b5
--- /dev/null
+++ b/configs/ssod/baseline/faster_rcnn_r50_fpn_2x_coco_sup010.yml
@@ -0,0 +1,26 @@
+_BASE_: [
+ '../../faster_rcnn/faster_rcnn_r50_fpn_2x_coco.yml',
+]
+log_iter: 50
+snapshot_epoch: 2
+weights: output/faster_rcnn_r50_fpn_2x_coco_sup010/model_final
+
+
+TrainDataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: semi_annotations/instances_train2017.1@10.json
+ dataset_dir: dataset/coco
+ data_fields: ['image', 'gt_bbox', 'gt_class']
+
+
+epoch: 24
+LearningRate:
+ base_lr: 0.01
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [16, 22]
+ - !LinearWarmup
+ start_factor: 0.1
+ steps: 500
diff --git a/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup005.yml b/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup005.yml
new file mode 100644
index 0000000000000000000000000000000000000000..a85b104293a588acfd252e67dc1982980a4205f0
--- /dev/null
+++ b/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup005.yml
@@ -0,0 +1,26 @@
+_BASE_: [
+ '../../fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml',
+]
+log_iter: 50
+snapshot_epoch: 2
+weights: output/fcos_r50_fpn_2x_coco_sup005/model_final
+
+
+TrainDataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: semi_annotations/instances_train2017.1@5.json
+ dataset_dir: dataset/coco
+ data_fields: ['image', 'gt_bbox', 'gt_class']
+
+
+epoch: 24
+LearningRate:
+ base_lr: 0.01
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [16, 22]
+ - !LinearWarmup
+ start_factor: 0.001
+ steps: 1000
diff --git a/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup010.yml b/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup010.yml
new file mode 100644
index 0000000000000000000000000000000000000000..dc44de406a4e93fd951ccb7c77bb27887ac044f8
--- /dev/null
+++ b/configs/ssod/baseline/fcos_r50_fpn_2x_coco_sup010.yml
@@ -0,0 +1,26 @@
+_BASE_: [
+ '../../fcos/fcos_r50_fpn_iou_multiscale_2x_coco.yml',
+]
+log_iter: 50
+snapshot_epoch: 2
+weights: output/fcos_r50_fpn_2x_coco_sup010/model_final
+
+
+TrainDataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: semi_annotations/instances_train2017.1@10.json
+ dataset_dir: dataset/coco
+ data_fields: ['image', 'gt_bbox', 'gt_class']
+
+
+epoch: 24
+LearningRate:
+ base_lr: 0.01
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [16, 22]
+ - !LinearWarmup
+ start_factor: 0.001
+ steps: 1000
diff --git a/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml b/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml
new file mode 100644
index 0000000000000000000000000000000000000000..88de96dcc44b7e5c2e7bee42c14ab6358ff308d1
--- /dev/null
+++ b/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup005.yml
@@ -0,0 +1,29 @@
+_BASE_: [
+ '../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml',
+]
+log_iter: 50
+snapshot_epoch: 5
+weights: output/ppyoloe_plus_crn_s_80e_coco_sup005/model_final
+
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
+depth_mult: 0.33
+width_mult: 0.50
+
+
+TrainDataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: semi_annotations/instances_train2017.1@5.json
+ dataset_dir: dataset/coco
+ data_fields: ['image', 'gt_bbox', 'gt_class']
+
+
+epoch: 80
+LearningRate:
+ base_lr: 0.001
+ schedulers:
+ - !CosineDecay
+ max_epochs: 96
+ - !LinearWarmup
+ start_factor: 0.
+ epochs: 5
diff --git a/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml b/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml
new file mode 100644
index 0000000000000000000000000000000000000000..aeb9435a0fee9ab502185f81a6b3710443471c89
--- /dev/null
+++ b/configs/ssod/baseline/ppyoloe_plus_crn_s_80e_coco_sup010.yml
@@ -0,0 +1,29 @@
+_BASE_: [
+ '../../ppyoloe/ppyoloe_plus_crn_s_80e_coco.yml',
+]
+log_iter: 50
+snapshot_epoch: 5
+weights: output/ppyoloe_plus_crn_s_80e_coco_sup010/model_final
+
+pretrain_weights: https://bj.bcebos.com/v1/paddledet/models/pretrained/ppyoloe_crn_s_obj365_pretrained.pdparams
+depth_mult: 0.33
+width_mult: 0.50
+
+
+TrainDataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: semi_annotations/instances_train2017.1@10.json
+ dataset_dir: dataset/coco
+ data_fields: ['image', 'gt_bbox', 'gt_class']
+
+
+epoch: 80
+LearningRate:
+ base_lr: 0.001
+ schedulers:
+ - !CosineDecay
+ max_epochs: 96
+ - !LinearWarmup
+ start_factor: 0.
+ epochs: 5
diff --git a/configs/ssod/baseline/retinanet_r50_fpn_2x_coco_sup010.yml b/configs/ssod/baseline/retinanet_r50_fpn_2x_coco_sup010.yml
new file mode 100644
index 0000000000000000000000000000000000000000..9b9cc72bc3ca1bc65b05bbf2692216b3c5a0de3f
--- /dev/null
+++ b/configs/ssod/baseline/retinanet_r50_fpn_2x_coco_sup010.yml
@@ -0,0 +1,26 @@
+_BASE_: [
+ '../../retinanet/retinanet_r50_fpn_1x_coco.yml',
+]
+log_iter: 50
+snapshot_epoch: 2
+weights: output/retinanet_r50_fpn_2x_coco_sup010/model_final
+
+
+TrainDataset:
+ !COCODataSet
+ image_dir: train2017
+ anno_path: semi_annotations/instances_train2017.1@10.json
+ dataset_dir: dataset/coco
+ data_fields: ['image', 'gt_bbox', 'gt_class']
+
+
+epoch: 24
+LearningRate:
+ base_lr: 0.01
+ schedulers:
+ - !PiecewiseDecay
+ gamma: 0.1
+ milestones: [16, 22]
+ - !LinearWarmup
+ start_factor: 0.001
+ steps: 500
diff --git a/ppdet/modeling/heads/fcos_head.py b/ppdet/modeling/heads/fcos_head.py
index 5c22244eed4525ee97e08864cf93f7d1ed519e66..79b69f08d09f57541d37c31a0bf6a8eb6a05db3b 100644
--- a/ppdet/modeling/heads/fcos_head.py
+++ b/ppdet/modeling/heads/fcos_head.py
@@ -139,6 +139,7 @@ class FCOSHead(nn.Layer):
norm_reg_targets=True,
centerness_on_reg=True,
num_shift=0.5,
+ sqrt_score=False,
fcos_loss='FCOSLoss',
nms='MultiClassNMS',
trt=False):
@@ -154,6 +155,7 @@ class FCOSHead(nn.Layer):
self.nms = nms
if isinstance(self.nms, MultiClassNMS) and trt:
self.nms.trt = trt
+ self.sqrt_score = sqrt_score
conv_cls_name = "fcos_head_cls"
bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
@@ -296,10 +298,17 @@ class FCOSHead(nn.Layer):
tag_labels, tag_bboxes, tag_centerness)
return losses_fcos
- def _post_process_by_level(self, locations, box_cls, box_reg, box_ctn):
+ def _post_process_by_level(self,
+ locations,
+ box_cls,
+ box_reg,
+ box_ctn,
+ sqrt_score=False):
box_scores = F.sigmoid(box_cls).flatten(2).transpose([0, 2, 1])
box_centerness = F.sigmoid(box_ctn).flatten(2).transpose([0, 2, 1])
pred_scores = box_scores * box_centerness
+ if sqrt_score:
+ pred_scores = paddle.sqrt(pred_scores)
box_reg_ch_last = box_reg.flatten(2).transpose([0, 2, 1])
box_reg_decoding = paddle.stack(
@@ -320,7 +329,8 @@ class FCOSHead(nn.Layer):
for pts, cls, reg, ctn in zip(locations, cls_logits, bboxes_reg,
centerness):
- scores, boxes = self._post_process_by_level(pts, cls, reg, ctn)
+ scores, boxes = self._post_process_by_level(pts, cls, reg, ctn,
+ self.sqrt_score)
pred_scores.append(scores)
pred_bboxes.append(boxes)
pred_bboxes = paddle.concat(pred_bboxes, axis=1)
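
The new `sqrt_score` option rescales the fused score by taking the square root of the classification probability times the quality (centerness or IoU) probability, i.e. their geometric mean, so the final score stays on the same scale as each individual probability. Since the square root is monotonic, the ranking of predictions is unchanged; only absolute values, e.g. against the NMS `score_threshold`, shift. A tiny standalone sketch of the effect (not part of the PR, with made-up probabilities):

```python
import math

cls_prob, quality_prob = 0.8, 0.6
fused = cls_prob * quality_prob    # 0.48, noticeably smaller than either probability
fused_sqrt = math.sqrt(fused)      # ~0.69, the geometric mean of the two probabilities
print(fused, fused_sqrt)
```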
diff --git a/ppdet/modeling/losses/fcos_loss.py b/ppdet/modeling/losses/fcos_loss.py
index c8d6005739bbbf70d307d8bd7fbc3a0e48c037cb..0cd6b581b8c32085b50725827d44a04c9facbb43 100644
--- a/ppdet/modeling/losses/fcos_loss.py
+++ b/ppdet/modeling/losses/fcos_loss.py
@@ -53,20 +53,28 @@ class FCOSLoss(nn.Layer):
loss_gamma (float): gamma in focal loss
iou_loss_type (str): location loss type, IoU/GIoU/LINEAR_IoU
reg_weights (float): weight for location loss
+ quality (str): quality branch, centerness/iou
"""
def __init__(self,
loss_alpha=0.25,
loss_gamma=2.0,
iou_loss_type="giou",
- reg_weights=1.0):
+ reg_weights=1.0,
+ quality='centerness'):
super(FCOSLoss, self).__init__()
self.loss_alpha = loss_alpha
self.loss_gamma = loss_gamma
self.iou_loss_type = iou_loss_type
self.reg_weights = reg_weights
+ self.quality = quality
- def __iou_loss(self, pred, targets, positive_mask, weights=None):
+ def __iou_loss(self,
+ pred,
+ targets,
+ positive_mask,
+ weights=None,
+ return_iou=False):
"""
Calculate the loss for location prediction
Args:
@@ -108,6 +116,9 @@ class FCOSLoss(nn.Layer):
area_predict + area_target - area_inter + 1.0)
ious = ious * positive_mask
+ if return_iou:
+ return ious
+
if self.iou_loss_type.lower() == "linear_iou":
loss = 1.0 - ious
elif self.iou_loss_type.lower() == "giou":
@@ -201,25 +212,53 @@ class FCOSLoss(nn.Layer):
cls_loss = F.sigmoid_focal_loss(
cls_logits_flatten, tag_labels_flatten_bin) / num_positive_fp32
- # 2. bboxes_reg: giou_loss
- mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
- tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
- reg_loss = self.__iou_loss(
- bboxes_reg_flatten,
- tag_bboxes_flatten,
- mask_positive_float,
- weights=tag_center_flatten)
- reg_loss = reg_loss * mask_positive_float / normalize_sum
-
- # 3. centerness: sigmoid_cross_entropy_with_logits_loss
- centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
- ctn_loss = ops.sigmoid_cross_entropy_with_logits(centerness_flatten,
- tag_center_flatten)
- ctn_loss = ctn_loss * mask_positive_float / num_positive_fp32
+ if self.quality == 'centerness':
+ # 2. bboxes_reg: giou_loss
+ mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
+ tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
+            reg_loss = self.__iou_loss(
+                bboxes_reg_flatten,
+                tag_bboxes_flatten,
+                mask_positive_float,
+                weights=tag_center_flatten)
+            # normalize the centerness-weighted IoU loss by the sum of centerness targets
+            reg_loss = reg_loss * mask_positive_float / normalize_sum
+
+ # 3. centerness: sigmoid_cross_entropy_with_logits_loss
+ centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
+ quality_loss = ops.sigmoid_cross_entropy_with_logits(
+ centerness_flatten, tag_center_flatten)
+ quality_loss = quality_loss * mask_positive_float / num_positive_fp32
+
+ elif self.quality == 'iou':
+ # 2. bboxes_reg: giou_loss
+ mask_positive_float = paddle.squeeze(mask_positive_float, axis=-1)
+ tag_center_flatten = paddle.squeeze(tag_center_flatten, axis=-1)
+ reg_loss = self.__iou_loss(
+ bboxes_reg_flatten,
+ tag_bboxes_flatten,
+ mask_positive_float,
+ weights=None)
+ reg_loss = reg_loss * mask_positive_float / num_positive_fp32
+ # num_positive_fp32 is num_foreground
+
+ # 3. centerness: sigmoid_cross_entropy_with_logits_loss
+ centerness_flatten = paddle.squeeze(centerness_flatten, axis=-1)
+ gt_ious = self.__iou_loss(
+ bboxes_reg_flatten,
+ tag_bboxes_flatten,
+ mask_positive_float,
+ weights=None,
+ return_iou=True)
+ quality_loss = ops.sigmoid_cross_entropy_with_logits(
+ centerness_flatten, gt_ious)
+ quality_loss = quality_loss * mask_positive_float / num_positive_fp32
+ else:
+ raise Exception(f'Unknown quality type: {self.quality}')
loss_all = {
- "loss_centerness": paddle.sum(ctn_loss),
"loss_cls": paddle.sum(cls_loss),
- "loss_box": paddle.sum(reg_loss)
+ "loss_box": paddle.sum(reg_loss),
+ "loss_quality": paddle.sum(quality_loss),
}
return loss_all
diff --git a/tools/gen_semi_coco.py b/tools/gen_semi_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..acacb5861279a6128971ccb49cc51a030a43e381
--- /dev/null
+++ b/tools/gen_semi_coco.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import argparse
+import numpy as np
+
+
+def save_json(path, images, annotations, categories):
+ new_json = {
+ 'images': images,
+ 'annotations': annotations,
+ 'categories': categories,
+ }
+ with open(path, 'w') as f:
+ json.dump(new_json, f)
+ print('{} saved, with {} images and {} annotations.'.format(
+ path, len(images), len(annotations)))
+
+
+def gen_semi_data(data_dir,
+ json_file,
+ percent=10.0,
+ seed=1,
+ seed_offset=0,
+ txt_file=None):
+ json_name = json_file.split('/')[-1].split('.')[0]
+ json_file = os.path.join(data_dir, json_file)
+ anno = json.load(open(json_file, 'r'))
+ categories = anno['categories']
+ all_images = anno['images']
+ all_anns = anno['annotations']
+ print(
+ 'Totally {} images and {} annotations, about {} gts per image.'.format(
+ len(all_images), len(all_anns), len(all_anns) / len(all_images)))
+
+ if txt_file:
+ print('Using percent {} and seed {}.'.format(percent, seed))
+ txt_file = os.path.join(data_dir, txt_file)
+ sup_idx = json.load(open(txt_file, 'r'))[str(percent)][str(seed)]
+        # note: sup_idx stores indices into the image list, not image ids
+ else:
+ np.random.seed(seed + seed_offset)
+ sup_len = int(percent / 100.0 * len(all_images))
+ sup_idx = np.random.choice(
+ range(len(all_images)), size=sup_len, replace=False)
+ labeled_images, labeled_anns = [], []
+ labeled_im_ids = []
+ unlabeled_images, unlabeled_anns = [], []
+
+ for i in range(len(all_images)):
+ if i in sup_idx:
+ labeled_im_ids.append(all_images[i]['id'])
+ labeled_images.append(all_images[i])
+ else:
+ unlabeled_images.append(all_images[i])
+
+ for an in all_anns:
+ im_id = an['image_id']
+ if im_id in labeled_im_ids:
+ labeled_anns.append(an)
+ else:
+ continue
+
+ save_path = '{}/{}'.format(data_dir, 'semi_annotations')
+ if not os.path.exists(save_path):
+ os.mkdir(save_path)
+
+ sup_name = '{}.{}@{}.json'.format(json_name, seed, int(percent))
+ sup_path = os.path.join(save_path, sup_name)
+ save_json(sup_path, labeled_images, labeled_anns, categories)
+
+ unsup_name = '{}.{}@{}-unlabeled.json'.format(json_name, seed, int(percent))
+ unsup_path = os.path.join(save_path, unsup_name)
+ save_json(unsup_path, unlabeled_images, unlabeled_anns, categories)
+
+
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--data_dir', type=str, default='./dataset/coco')
+ parser.add_argument(
+ '--json_file', type=str, default='annotations/instances_train2017.json')
+ parser.add_argument('--percent', type=float, default=10.0)
+ parser.add_argument('--seed', type=int, default=1)
+ parser.add_argument('--seed_offset', type=int, default=0)
+ parser.add_argument('--txt_file', type=str, default='COCO_supervision.txt')
+ args = parser.parse_args()
+ print(args)
+ gen_semi_data(args.data_dir, args.json_file, args.percent, args.seed,
+ args.seed_offset, args.txt_file)