From 3e5bf8a63a1adf49beb2e41424487f082d8e0159 Mon Sep 17 00:00:00 2001 From: wjm <897383984@qq.com> Date: Tue, 13 Jun 2023 10:44:07 +0800 Subject: [PATCH] Support RT-DETR semi-supervised object detection (#8336) * rt-detr-ssod * add rt-detr-ssod * change name --- configs/semi_det/baseline/README.md | 12 + .../baseline/rtdetr_r50vd_6x_coco_sup005.yml | 35 ++ .../baseline/rtdetr_r50vd_6x_coco_sup010.yml | 35 ++ configs/semi_det/rtdetr_ssod/README.md | 109 ++++++ .../rt_detr_ssod005_coco_no_warmup.yml | 212 +++++++++++ .../rt_detr_ssod005_coco_with_warmup.yml | 215 +++++++++++ .../rt_detr_ssod010_coco_no_warmup.yml | 212 +++++++++++ .../rt_detr_ssod010_coco_with_warmup.yml | 215 +++++++++++ ppdet/data/reader.py | 6 +- ppdet/data/transform/batch_operators.py | 72 +++- ppdet/data/transform/operators.py | 83 +++-- ppdet/engine/callbacks.py | 140 ++++++- ppdet/engine/trainer.py | 8 +- ppdet/engine/trainer_ssod.py | 336 ++++++++++++++++- ppdet/modeling/architectures/__init__.py | 6 +- ppdet/modeling/architectures/detr.py | 4 +- ppdet/modeling/architectures/detr_ssod.py | 341 ++++++++++++++++++ .../architectures/multi_stream_detector.py | 69 ++++ ppdet/modeling/heads/detr_head.py | 9 +- ppdet/modeling/losses/detr_loss.py | 93 ++++- ppdet/modeling/post_process.py | 58 ++- ppdet/modeling/ssod/utils.py | 22 ++ ppdet/modeling/transformers/hybrid_encoder.py | 6 +- .../transformers/rtdetr_transformer.py | 9 +- ppdet/utils/checkpoint.py | 37 ++ tools/train.py | 11 +- 26 files changed, 2278 insertions(+), 77 deletions(-) create mode 100644 configs/semi_det/baseline/rtdetr_r50vd_6x_coco_sup005.yml create mode 100644 configs/semi_det/baseline/rtdetr_r50vd_6x_coco_sup010.yml create mode 100644 configs/semi_det/rtdetr_ssod/README.md create mode 100644 configs/semi_det/rtdetr_ssod/rt_detr_ssod005_coco_no_warmup.yml create mode 100644 configs/semi_det/rtdetr_ssod/rt_detr_ssod005_coco_with_warmup.yml create mode 100644 configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml create mode 100644 configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_with_warmup.yml create mode 100644 ppdet/modeling/architectures/detr_ssod.py create mode 100644 ppdet/modeling/architectures/multi_stream_detector.py diff --git a/configs/semi_det/baseline/README.md b/configs/semi_det/baseline/README.md index 457ad7f7c..aaf2800f1 100644 --- a/configs/semi_det/baseline/README.md +++ b/configs/semi_det/baseline/README.md @@ -53,6 +53,18 @@ - 以上模型训练默认使用8 GPUs,总batch_size默认为16,默认初始学习率为0.01。如果改动了总batch_size,请按线性比例相应地调整学习率。 + +### [RT-DETR](../../rtdetr) + +| 基础模型 | 监督数据比例 | mAPval
0.5:0.95 | 模型下载 | 配置文件 | +| :---------------: | :-------------: | :---------------------: |:--------: | :---------: | +| RT-DETR ResNet5vd | 5% | 39.1 | [download](https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/baseline/rtdetr_r50vd_6x_coco_sup005.pdparams) | [config](rtdetr_r50vd_6x_coco_sup005.yml) | +| RT-DETR ResNet5vd | 10% | 42.3 | [download](https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/baseline/rtdetr_r50vd_6x_coco_sup010.pdparams) | [config](rtdetr_r50vd_6x_coco_sup010.yml) | +| RT-DETR ResNet5vd | VOC2007 | 62.7 | [download](https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/baseline/rtdetr_r50vd_6x_voc2007.pdparams) | [config](rtdetr_r50vd_6x_voc2007.yml) | + +**注意:** + - RT-DETR模型训练默认使用4 GPUs,总batch_size默认为16,默认初始学习率为0.0001。如果改动了总batch_size,请按线性比例相应地调整学习率。 + ### 注意事项 - COCO部分监督数据集请参照 [数据集准备](../README.md) 去下载和准备,各个比例的训练集均为**从train2017中抽取部分百分比的子集**,默认使用`fold`号为1的划分子集,`sup010`表示抽取10%的监督数据训练,`sup005`表示抽取5%,`full`表示全部train2017,验证集均为val2017全量; - 抽取部分百分比的监督数据的抽法不同,或使用的`fold`号不同,精度都会因此而有约0.5 mAP之多的差异; diff --git a/configs/semi_det/baseline/rtdetr_r50vd_6x_coco_sup005.yml b/configs/semi_det/baseline/rtdetr_r50vd_6x_coco_sup005.yml new file mode 100644 index 000000000..a949553e8 --- /dev/null +++ b/configs/semi_det/baseline/rtdetr_r50vd_6x_coco_sup005.yml @@ -0,0 +1,35 @@ +_BASE_: [ + '../../rtdetr/rtdetr_r50vd_6x_coco.yml', +] +log_iter: 50 +snapshot_epoch: 2 +weights: output/rtdetr_r50vd_6x_coco/model_final + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@5.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class'] + + +worker_num: 4 +TrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + batch_transforms: + - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + batch_size: 4 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false diff --git a/configs/semi_det/baseline/rtdetr_r50vd_6x_coco_sup010.yml b/configs/semi_det/baseline/rtdetr_r50vd_6x_coco_sup010.yml new file mode 100644 index 000000000..a949553e8 --- /dev/null +++ b/configs/semi_det/baseline/rtdetr_r50vd_6x_coco_sup010.yml @@ -0,0 +1,35 @@ +_BASE_: [ + '../../rtdetr/rtdetr_r50vd_6x_coco.yml', +] +log_iter: 50 +snapshot_epoch: 2 +weights: output/rtdetr_r50vd_6x_coco/model_final + +pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams +TrainDataset: + !COCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@5.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class'] + + +worker_num: 4 +TrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [123.675, 116.28, 103.53]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + batch_transforms: + - BatchRandomResize: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + batch_size: 4 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false diff --git a/configs/semi_det/rtdetr_ssod/README.md b/configs/semi_det/rtdetr_ssod/README.md new file mode 100644 index 000000000..1fc3f8783 --- /dev/null +++ b/configs/semi_det/rtdetr_ssod/README.md @@ -0,0 +1,109 @@ +简体中文 | [English](README_en.md) + +# RTDETR-SSOD(基于RTDETR的配套半监督目标检测方法) +# 复现模型指标注意事项: 模型中指标是采用在监督数据训练饱和后加载监督数据所训练的模型进行半监督训练 + - 例如 使用 baseline/rtdetr_r50vd_6x_coco_sup005.yml使用5%coco数据训练全监督模型,得到rtdetr_r50vd_6x_coco_sup005.pdparams,在rt_detr_ssod005_coco_no_warmup.yml中设置 + - pretrain_student_weights: rtdetr_r50vd_6x_coco_sup005.pdparams + - pretrain_teacher_weights: rtdetr_r50vd_6x_coco_sup005.pdparams + - 1.使用coco数据集5%和10%有标记数据和voc数据集VOC2007trainval 所训练的权重已给出请参考 semi_det/baseline/README.md. + - 2.rt_detr_ssod_voc_no_warmup.yml rt_detr_ssod005_coco_no_warmup.yml rt_detr_ssod010_coco_no_warmup.yml 是使用训练好的全监督权中直接开启半监督训练(推荐) +## RTDETR-SSOD模型库 + +| 模型 | 监督数据比例 | Sup Baseline | Sup Epochs (Iters) | Sup mAPval
0.5:0.95 | Semi mAPval
0.5:0.95 | Semi Epochs (Iters) | 模型下载 | 配置文件 | +| :------------: | :---------: | :---------------------: | :---------------------: |:---------------------------: |:----------------------------: | :------------------: |:--------: |:----------: | +| RTDETR-SSOD | 5% | [sup_config](../baseline/rtdetr_r50vd_6x_coco_sup005.yml) | - | 39.0 | **42.3** | - | [download](https://bj.bcebos.com/v1/paddledet/rt_detr_ssod005_coco_no_warmup.pdparams) | [config](./rt_detr_ssod005_coco_no_warmup.yml) | +| RTDETR-SSOD | 10% | [sup_config](../baseline/rtdetr_r50vd_6x_coco_sup010.yml) | -| 42.3 | **44.8** | - | [download](https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/rt_detr_ssod010_coco/rt_detr_ssod010_coco_no_warmup.pdparams) | [config](./rt_detr_ssod010_coco_with_warmup.yml) | +| RTDETR-SSOD(VOC)| VOC | [sup_config](../baseline/rtdetr_r50vd_6x_coco_voc2007.yml) | - | 62.7 | **65.8(LSJ)** | - | [download](https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/rt_detr_ssod_voc/rt_detr_ssod_voc_no_warmup.pdparams) | [config](./rt_detr_ssod_voc_with_warmup.yml) | + +**注意:** + - 以上模型训练默认使用8 GPUs,监督数据总batch_size默认为16,无监督数据总batch_size默认也为16,默认初始学习率为0.01。如果改动了总batch_size,请按线性比例相应地调整学习率; + - **监督数据比例**是指使用的有标签COCO数据集占 COCO train2017 全量训练集的百分比,使用的无标签COCO数据集一般也是相同比例,但具体图片和有标签数据的图片不重合; + - `Semi Epochs (Iters)`表示**半监督训练**的模型的 Epochs (Iters),如果使用**自定义数据集**,需自行根据Iters换算到对应的Epochs调整,最好保证总Iters 和COCO数据集的设置较为接近; + - `Sup mAP`是**只使用有监督数据训练**的模型的精度,请参照**基础检测器的配置文件** 和 [baseline](../baseline); + - `Semi mAP`是**半监督训练**的模型的精度,模型下载和配置文件的链接均为**半监督模型**; + - `LSJ`表示 **large-scale jittering**,表示使用更大范围的多尺度训练,可进一步提升精度,但训练速度也会变慢; + - 半监督检测的配置讲解,请参照[文档](../README.md/#半监督检测配置); + - `Dense Teacher`原文使用`R50-va-caffe`预训练,PaddleDetection中默认使用`R50-vb`预训练,如果使用`R50-vd`结合[SSLD](../../../docs/feature_models/SSLD_PRETRAINED_MODEL.md)的预训练模型,可进一步显著提升检测精度,同时backbone部分配置也需要做出相应更改,如: + ```python + pretrain_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams + ResNet: + depth: 50 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + num_stages: 4 + lr_mult_list: [0.05, 0.05, 0.1, 0.15] + ``` + +## 使用说明 + +仅训练时必须使用半监督检测的配置文件去训练,评估、预测、部署也可以按基础检测器的配置文件去执行。 + +### 训练 + +```bash +# 单卡训练 (不推荐,需按线性比例相应地调整学习率) +CUDA_VISIBLE_DEVICES=0 python tools/train.py -c configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml --eval + +# 多卡训练 +python -m paddle.distributed.launch --log_dir=denseteacher_fcos_semi010/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml --eval +``` + +### 评估 + +```bash +CUDA_VISIBLE_DEVICES=0 python tools/eval.py -c configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml -o weights=output/rt_detr_ssod/model_final/model_final.pdparams +``` + +### 预测 + +```bash +CUDA_VISIBLE_DEVICES=0 python tools/infer.py -c configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml -o weights=output/rt_detr_ssod/model_final/model_final.pdparams --infer_img=demo/000000014439.jpg +``` + +### 部署 + +部署可以使用半监督检测配置文件,也可以使用基础检测器的配置文件去部署和使用。 + +```bash +# 导出模型 +CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml -o weights=https://paddledet.bj.bcebos.com/models/rt_detr_ssod010_coco_no_warmup.pdparams + +# 导出权重预测 +CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/rt_detr_ssod010_coco_no_warmup --image_file=demo/000000014439_640x640.jpg --device=GPU + +# 部署测速 +CUDA_VISIBLE_DEVICES=0 python deploy/python/infer.py --model_dir=output_inference/rt_detr_ssod010_coco_no_warmup --image_file=demo/000000014439_640x640.jpg --device=GPU --run_benchmark=True # --run_mode=trt_fp16 + +# 导出ONNX +paddle2onnx --model_dir output_inference/drt_detr_ssod010_coco_no_warmup/ --model_filename model.pdmodel --params_filename model.pdiparams --opset_version 12 --save_file rt_detr_ssod010_coco_no_warmup.onnx +``` + + +# RTDETR-SSOD 下游任务 + +我们验证了RTDETR-SSOD模型强大的泛化能力,在低光、工业、交通等不同场景下游任务检测效果稳定提升! + +voc数据集采用[voc](https://github.com/thsant/wgisd),是一个广泛使用的计算机视觉数据集,用于目标检测、图像分割和场景理解等任务。该数据集包含20个类别的图像,处理后的COCO格式,包含图片标注训练集5011张,图片无标注训练集11540张,测试集2510张,20个类别; + +低光数据集使用[ExDark](https://github.com/cs-chan/Exclusively-Dark-Image-Dataset/tree/master/Dataset),该数据集是一个专门在低光照环境下拍摄出针对低光目标检测的数据集,包括从极低光环境到暮光环境等10种不同光照条件下的图片,处理后的COCO格式,包含图片训练集5891张,测试集1472张,12个类别; + +工业数据集使用[PKU-Market-PCB](https://robotics.pkusz.edu.cn/resources/dataset/),该数据集用于印刷电路板(PCB)的瑕疵检测,提供了6种常见的PCB缺陷; + +商超数据集[SKU110k](https://github.com/eg4000/SKU110K_CVPR19)是商品超市场景下的密集目标检测数据集,包含11,762张图片和超过170个实例。其中包括8,233张用于训练的图像、588张用于验证的图像和2,941张用于测试的图像; + +自动驾驶数据集使用[sslad](https://soda-2d.github.io/index.html); + +交通数据集使用[visdrone](http://aiskyeye.com/home/); + +## 下游数据集实验结果: + +| 数据集 | 业务方向 | 划分 | labeled数据量 | 全监督mAP | 半监督mAP | +|----------|-----------|---------------------|-----------------|------------------|--------------| +| voc | 通用 | voc07, 12;1:2 | 5000 | 63.1 | 65.8(+2.7) | +| visdrone | 无人机交通 | 1:9 | 647 | 19.4 | 20.6 (+1.2) | +| pcb | 工业缺陷 | 1:9 | 55 | 22.9 | 26.8 (+3.9) | +| sku110k | 商品 | 1:9 | 821 | 38.9 | 52.4 (+13.5) | +| sslad | 自动驾驶 | 1:32 | 4967 | 42.1 | 43.3 (+1.2) | +| exdark | 低光照 | 1:9 | 589 | 39.6 | 44.1 (+4.5) | diff --git a/configs/semi_det/rtdetr_ssod/rt_detr_ssod005_coco_no_warmup.yml b/configs/semi_det/rtdetr_ssod/rt_detr_ssod005_coco_no_warmup.yml new file mode 100644 index 000000000..ac4665448 --- /dev/null +++ b/configs/semi_det/rtdetr_ssod/rt_detr_ssod005_coco_no_warmup.yml @@ -0,0 +1,212 @@ +_BASE_: [ + '../../runtime.yml', + '../../rtdetr/_base_/rtdetr_r50vd.yml', + '../../rtdetr/_base_/rtdetr_reader.yml', +] +eval_interval: 4000 +save_interval: 4000 +weights: output/rt_detr_ssod/model_final +find_unused_parameters: True +save_dir: output +log_iter: 1 +ssod_method: Semi_RTDETR +### global config +use_simple_ema: True +ema_decay: 0.9996 +use_gpu: true + +### reader config +worker_num: 4 + +SemiTrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [0., 0., 0.]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + weak_aug: + - RandomFlip: {prob: 0.0} + strong_aug: + - StrongAugImage: {transforms: [ + RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1}, + RandomErasingCrop: {}, + RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]}, + RandomGrayscale: {prob: 0.2}, + ]} + sup_batch_transforms: + - BatchRandomResizeForSSOD: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + unsup_batch_transforms: + - BatchRandomResizeForSSOD: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + sup_batch_size: 2 + unsup_batch_size: 2 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 2 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: { target_size: [640, 640], keep_ratio: False } + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false + + +pretrain_student_weights: https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/baseline/rtdetr_r50vd_6x_coco_sup005.pdparams +pretrain_teacher_weights: https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/baseline/rtdetr_r50vd_6x_coco_sup005.pdparams + +hidden_dim: 256 +use_focal_loss: True +eval_size: [640, 640] + +architecture: DETR +DETR: + backbone: ResNet + neck: HybridEncoder + transformer: RTDETRTransformer + detr_head: DINOHead + post_process: DETRPostProcess + post_process_semi: DETRBBoxSemiPostProcess +ResNet: + depth: 50 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + lr_mult_list: [0.1, 0.1, 0.1, 0.1] + num_stages: 4 + freeze_stem_only: True + +HybridEncoder: + hidden_dim: 256 + use_encoder_idx: [2] + num_encoder_layers: 1 + encoder_layer: + name: TransformerLayer + d_model: 256 + nhead: 8 + dim_feedforward: 1024 + dropout: 0. + activation: 'gelu' + expansion: 1.0 + + +PPDETRTransformer: + num_queries: 300 + position_embed_type: sine + feat_strides: [8, 16, 32] + num_levels: 3 + nhead: 8 + num_decoder_layers: 6 + dim_feedforward: 1024 + dropout: 0.0 + activation: relu + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 + learnt_init_query: False + +DINOHead: + loss: + name: DINOLoss + loss_coeff: {class: 1, bbox: 5, giou: 2} + aux_loss: True + use_vfl: True + matcher: + name: HungarianMatcher + matcher_coeff: {class: 2, bbox: 5, giou: 2} + +DETRPostProcess: + num_top_queries: 300 + + + +SSOD: DETR_SSOD +DETR_SSOD: + teacher: DETR + student: DETR + train_cfg: + sup_weight: 1.0 + unsup_weight: 1.0 + ema_start_iters: -1 + pseudo_label_initial_score_thr: 0.7 + min_pseduo_box_size: 0 + concat_sup_data: True + test_cfg: + inference_on: teacher + + + +metric: COCO +num_classes: 80 + +# partial labeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +TrainDataset: + !SemiCOCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@5.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] +# partial unlabeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +UnsupTrainDataset: + !SemiCOCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@5-unlabeled.json + dataset_dir: dataset/coco + data_fields: ['image'] + supervised: False + +EvalDataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + allow_empty: true + +TestDataset: + !ImageFolder + anno_path: annotations/instances_val2017.json # also support txt (like VOC's label_list.txt) + dataset_dir: dataset/coco # if set, anno_path will be 'dataset_dir/anno_path' + +epoch: 400 #epoch: 60 + +LearningRate: + base_lr: 0.0002 + schedulers: + - !PiecewiseDecay + gamma: 1.0 + milestones: [400] + use_warmup: false + - !LinearWarmup + start_factor: 0.001 + steps: 2000 + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/configs/semi_det/rtdetr_ssod/rt_detr_ssod005_coco_with_warmup.yml b/configs/semi_det/rtdetr_ssod/rt_detr_ssod005_coco_with_warmup.yml new file mode 100644 index 000000000..b739479aa --- /dev/null +++ b/configs/semi_det/rtdetr_ssod/rt_detr_ssod005_coco_with_warmup.yml @@ -0,0 +1,215 @@ +_BASE_: [ + '../../runtime.yml', + '../../rtdetr/_base_/rtdetr_r50vd.yml', + '../../rtdetr/_base_/rtdetr_reader.yml', +] + +#for debug +eval_interval: 4000 +save_interval: 4000 +weights: output/rt_detr_ssod/model_final +find_unused_parameters: True +save_dir: output +log_iter: 50 +ssod_method: Semi_RTDETR +### global config +use_simple_ema: True +ema_decay: 0.9996 +use_gpu: true + +### reader config +worker_num: 4 + +SemiTrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [0., 0., 0.]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + weak_aug: + - RandomFlip: {prob: 0.0} + strong_aug: + - StrongAugImage: {transforms: [ + RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1}, + RandomErasingCrop: {}, + RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]}, + RandomGrayscale: {prob: 0.2}, + ]} + sup_batch_transforms: + - BatchRandomResizeForSSOD: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + unsup_batch_transforms: + - BatchRandomResizeForSSOD: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + sup_batch_size: 2 + unsup_batch_size: 2 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 2 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: { target_size: [640, 640], keep_ratio: False } + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false + + +pretrain_student_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams +pretrain_teacher_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams + +hidden_dim: 256 +use_focal_loss: True +eval_size: [640, 640] + +architecture: DETR +DETR: + backbone: ResNet + neck: HybridEncoder + transformer: RTDETRTransformer + detr_head: DINOHead + post_process: DETRPostProcess + post_process_semi: DETRBBoxSemiPostProcess +ResNet: + # index 0 stands for res2 + depth: 50 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + lr_mult_list: [0.1, 0.1, 0.1, 0.1] + num_stages: 4 + freeze_stem_only: True + +HybridEncoder: + hidden_dim: 256 + use_encoder_idx: [2] + num_encoder_layers: 1 + encoder_layer: + name: TransformerLayer + d_model: 256 + nhead: 8 + dim_feedforward: 1024 + dropout: 0. + activation: 'gelu' + expansion: 1.0 + + +RTDETRTransformer: + num_queries: 300 + position_embed_type: sine + feat_strides: [8, 16, 32] + num_levels: 3 + nhead: 8 + num_decoder_layers: 6 + dim_feedforward: 1024 + dropout: 0.0 + activation: relu + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 + learnt_init_query: False + +DINOHead: + loss: + name: DINOLoss + loss_coeff: {class: 1, bbox: 5, giou: 2} + aux_loss: True + use_vfl: True + matcher: + name: HungarianMatcher + matcher_coeff: {class: 2, bbox: 5, giou: 2} + use_uni_match: True +DETRPostProcess: + num_top_queries: 300 + + + +SSOD: DETR_SSOD +DETR_SSOD: + teacher: DETR + student: DETR + train_cfg: + sup_weight: 1.0 + unsup_weight: 1.0 + ema_start_iters: 10000 + pseudo_label_initial_score_thr: 0.7 + min_pseduo_box_size: 0 + concat_sup_data: True + test_cfg: + inference_on: teacher + + + +metric: COCO +num_classes: 80 + +# partial labeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +TrainDataset: + !SemiCOCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@5.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] +# partial unlabeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +UnsupTrainDataset: + !SemiCOCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@5-unlabeled.json + dataset_dir: dataset/coco + data_fields: ['image'] + supervised: False + +EvalDataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + allow_empty: true + +TestDataset: + !ImageFolder + anno_path: annotations/instances_val2017.json # also support txt (like VOC's label_list.txt) + dataset_dir: dataset/coco # if set, anno_path will be 'dataset_dir/anno_path' + +epoch: 500 #epoch: 60 + +LearningRate: + base_lr: 0.0002 + schedulers: + - !PiecewiseDecay + gamma: 1.0 + milestones: [500] + use_warmup: false + - !LinearWarmup + start_factor: 0.001 + steps: 2000 + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml b/configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml new file mode 100644 index 000000000..fe53f58e7 --- /dev/null +++ b/configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_no_warmup.yml @@ -0,0 +1,212 @@ +_BASE_: [ + '../../runtime.yml', + '../../rtdetr/_base_/rtdetr_r50vd.yml', + '../../rtdetr/_base_/rtdetr_reader.yml', +] +eval_interval: 4000 +save_interval: 4000 +weights: output/rt_detr_ssod/model_final +find_unused_parameters: True +save_dir: output +log_iter: 1 +ssod_method: Semi_RTDETR +### global config +use_simple_ema: True +ema_decay: 0.9996 +use_gpu: true + +### reader config +worker_num: 4 + +SemiTrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [0., 0., 0.]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + weak_aug: + - RandomFlip: {prob: 0.0} + strong_aug: + - StrongAugImage: {transforms: [ + RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1}, + RandomErasingCrop: {}, + RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]}, + RandomGrayscale: {prob: 0.2}, + ]} + sup_batch_transforms: + - BatchRandomResizeForSSOD: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + unsup_batch_transforms: + - BatchRandomResizeForSSOD: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + sup_batch_size: 2 + unsup_batch_size: 2 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 2 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: { target_size: [640, 640], keep_ratio: False } + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false + + +pretrain_student_weights: https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/baseline/rtdetr_r50vd_6x_coco_sup010.pdparams +pretrain_teacher_weights: https://bj.bcebos.com/v1/paddledet/data/semidet/rtdetr_ssod/baseline/rtdetr_r50vd_6x_coco_sup010.pdparams + +hidden_dim: 256 +use_focal_loss: True +eval_size: [640, 640] + +architecture: DETR +DETR: + backbone: ResNet + neck: HybridEncoder + transformer: RTDETRTransformer + detr_head: DINOHead + post_process: DETRPostProcess + post_process_semi: DETRBBoxSemiPostProcess +ResNet: + depth: 50 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + lr_mult_list: [0.1, 0.1, 0.1, 0.1] + num_stages: 4 + freeze_stem_only: True + +HybridEncoder: + hidden_dim: 256 + use_encoder_idx: [2] + num_encoder_layers: 1 + encoder_layer: + name: TransformerLayer + d_model: 256 + nhead: 8 + dim_feedforward: 1024 + dropout: 0. + activation: 'gelu' + expansion: 1.0 + + +PPDETRTransformer: + num_queries: 300 + position_embed_type: sine + feat_strides: [8, 16, 32] + num_levels: 3 + nhead: 8 + num_decoder_layers: 6 + dim_feedforward: 1024 + dropout: 0.0 + activation: relu + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 + learnt_init_query: False + +DINOHead: + loss: + name: DINOLoss + loss_coeff: {class: 1, bbox: 5, giou: 2} + aux_loss: True + use_vfl: True + matcher: + name: HungarianMatcher + matcher_coeff: {class: 2, bbox: 5, giou: 2} + +DETRPostProcess: + num_top_queries: 300 + + + +SSOD: DETR_SSOD +DETR_SSOD: + teacher: DETR + student: DETR + train_cfg: + sup_weight: 1.0 + unsup_weight: 1.0 + ema_start_iters: -1 + pseudo_label_initial_score_thr: 0.7 + min_pseduo_box_size: 0 + concat_sup_data: True + test_cfg: + inference_on: teacher + + + +metric: COCO +num_classes: 80 + +# partial labeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +TrainDataset: + !SemiCOCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@10.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] +# partial unlabeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +UnsupTrainDataset: + !SemiCOCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@10-unlabeled.json + dataset_dir: dataset/coco + data_fields: ['image'] + supervised: False + +EvalDataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + allow_empty: true + +TestDataset: + !ImageFolder + anno_path: annotations/instances_val2017.json # also support txt (like VOC's label_list.txt) + dataset_dir: dataset/coco # if set, anno_path will be 'dataset_dir/anno_path' + +epoch: 400 #epoch: 60 + +LearningRate: + base_lr: 0.0002 + schedulers: + - !PiecewiseDecay + gamma: 1.0 + milestones: [400] + use_warmup: false + - !LinearWarmup + start_factor: 0.001 + steps: 2000 + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_with_warmup.yml b/configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_with_warmup.yml new file mode 100644 index 000000000..8e63225d4 --- /dev/null +++ b/configs/semi_det/rtdetr_ssod/rt_detr_ssod010_coco_with_warmup.yml @@ -0,0 +1,215 @@ +_BASE_: [ + '../../runtime.yml', + '../../rtdetr/_base_/rtdetr_r50vd.yml', + '../../rtdetr/_base_/rtdetr_reader.yml', +] + +#for debug +eval_interval: 4000 +save_interval: 4000 +weights: output/rt_detr_ssod/model_final +find_unused_parameters: True +save_dir: output +log_iter: 50 +ssod_method: Semi_RTDETR +### global config +use_simple_ema: True +ema_decay: 0.9996 +use_gpu: true + +### reader config +worker_num: 4 + +SemiTrainReader: + sample_transforms: + - Decode: {} + - RandomDistort: {prob: 0.8} + - RandomExpand: {fill_value: [0., 0., 0.]} + - RandomCrop: {prob: 0.8} + - RandomFlip: {} + weak_aug: + - RandomFlip: {prob: 0.0} + strong_aug: + - StrongAugImage: {transforms: [ + RandomColorJitter: {prob: 0.8, brightness: 0.4, contrast: 0.4, saturation: 0.4, hue: 0.1}, + RandomErasingCrop: {}, + RandomGaussianBlur: {prob: 0.5, sigma: [0.1, 2.0]}, + RandomGrayscale: {prob: 0.2}, + ]} + sup_batch_transforms: + - BatchRandomResizeForSSOD: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + unsup_batch_transforms: + - BatchRandomResizeForSSOD: {target_size: [480, 512, 544, 576, 608, 640, 640, 640, 672, 704, 736, 768, 800], random_size: True, random_interp: True, keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - NormalizeBox: {} + - BboxXYXY2XYWH: {} + - Permute: {} + sup_batch_size: 2 + unsup_batch_size: 2 + shuffle: true + drop_last: true + collate_batch: false + use_shared_memory: false + +EvalReader: + sample_transforms: + - Decode: {} + - Resize: {target_size: [640, 640], keep_ratio: False} + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 2 + shuffle: false + drop_last: false + + +TestReader: + sample_transforms: + - Decode: {} + - Resize: { target_size: [640, 640], keep_ratio: False } + - NormalizeImage: {mean: [0., 0., 0.], std: [1., 1., 1.], norm_type: none} + - Permute: {} + batch_size: 1 + shuffle: false + drop_last: false + + +pretrain_student_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams +pretrain_teacher_weights: https://paddledet.bj.bcebos.com/models/pretrained/ResNet50_vd_ssld_v2_pretrained.pdparams + +hidden_dim: 256 +use_focal_loss: True +eval_size: [640, 640] + +architecture: DETR +DETR: + backbone: ResNet + neck: HybridEncoder + transformer: RTDETRTransformer + detr_head: DINOHead + post_process: DETRPostProcess + post_process_semi: DETRBBoxSemiPostProcess +ResNet: + # index 0 stands for res2 + depth: 50 + variant: d + norm_type: bn + freeze_at: 0 + return_idx: [1, 2, 3] + lr_mult_list: [0.1, 0.1, 0.1, 0.1] + num_stages: 4 + freeze_stem_only: True + +HybridEncoder: + hidden_dim: 256 + use_encoder_idx: [2] + num_encoder_layers: 1 + encoder_layer: + name: TransformerLayer + d_model: 256 + nhead: 8 + dim_feedforward: 1024 + dropout: 0. + activation: 'gelu' + expansion: 1.0 + + +RTDETRTransformer: + num_queries: 300 + position_embed_type: sine + feat_strides: [8, 16, 32] + num_levels: 3 + nhead: 8 + num_decoder_layers: 6 + dim_feedforward: 1024 + dropout: 0.0 + activation: relu + num_denoising: 100 + label_noise_ratio: 0.5 + box_noise_scale: 1.0 + learnt_init_query: False + +DINOHead: + loss: + name: DINOLoss + loss_coeff: {class: 1, bbox: 5, giou: 2} + aux_loss: True + use_vfl: True + matcher: + name: HungarianMatcher + matcher_coeff: {class: 2, bbox: 5, giou: 2} + use_uni_match: True +DETRPostProcess: + num_top_queries: 300 + + + +SSOD: DETR_SSOD +DETR_SSOD: + teacher: DETR + student: DETR + train_cfg: + sup_weight: 1.0 + unsup_weight: 1.0 + ema_start_iters: 10000 + pseudo_label_initial_score_thr: 0.7 + min_pseduo_box_size: 0 + concat_sup_data: True + test_cfg: + inference_on: teacher + + + +metric: COCO +num_classes: 80 + +# partial labeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +TrainDataset: + !SemiCOCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@10.json + dataset_dir: dataset/coco + data_fields: ['image', 'gt_bbox', 'gt_class', 'is_crowd'] +# partial unlabeled COCO, use `SemiCOCODataSet` rather than `COCODataSet` +UnsupTrainDataset: + !SemiCOCODataSet + image_dir: train2017 + anno_path: semi_annotations/instances_train2017.1@10-unlabeled.json + dataset_dir: dataset/coco + data_fields: ['image'] + supervised: False + +EvalDataset: + !COCODataSet + image_dir: val2017 + anno_path: annotations/instances_val2017.json + dataset_dir: dataset/coco + allow_empty: true + +TestDataset: + !ImageFolder + anno_path: annotations/instances_val2017.json # also support txt (like VOC's label_list.txt) + dataset_dir: dataset/coco # if set, anno_path will be 'dataset_dir/anno_path' + +epoch: 500 #epoch: 60 + +LearningRate: + base_lr: 0.0002 + schedulers: + - !PiecewiseDecay + gamma: 1.0 + milestones: [500] + use_warmup: false + - !LinearWarmup + start_factor: 0.001 + steps: 2000 + +OptimizerBuilder: + clip_grad_by_norm: 0.1 + regularizer: false + optimizer: + type: AdamW + weight_decay: 0.0001 diff --git a/ppdet/data/reader.py b/ppdet/data/reader.py index 041f7735d..c40f3c378 100644 --- a/ppdet/data/reader.py +++ b/ppdet/data/reader.py @@ -392,7 +392,11 @@ class BatchCompose_SSOD(Compose): for f in self.transforms_cls: try: data = f(data) - strong_data = f(strong_data) + if 'BatchRandomResizeForSSOD' in f._id: + strong_data = f(strong_data, data[1])[0] + data = data[0] + else: + strong_data = f(strong_data) except Exception as e: stack_info = traceback.format_exc() logger.warning("fail to map batch transform [{}] " diff --git a/ppdet/data/transform/batch_operators.py b/ppdet/data/transform/batch_operators.py index 5b8bbcd3b..f1ea70243 100644 --- a/ppdet/data/transform/batch_operators.py +++ b/ppdet/data/transform/batch_operators.py @@ -38,19 +38,10 @@ from ppdet.modeling.keypoint_utils import get_affine_transform, affine_transform logger = setup_logger(__name__) __all__ = [ - 'PadBatch', - 'BatchRandomResize', - 'Gt2YoloTarget', - 'Gt2FCOSTarget', - 'Gt2TTFTarget', - 'Gt2Solov2Target', - 'Gt2SparseTarget', - 'PadMaskBatch', - 'Gt2GFLTarget', - 'Gt2CenterNetTarget', - 'Gt2CenterTrackTarget', - 'PadGT', - 'PadRGT', + 'PadBatch', 'BatchRandomResize', 'Gt2YoloTarget', 'Gt2FCOSTarget', + 'Gt2TTFTarget', 'Gt2Solov2Target', 'Gt2SparseTarget', 'PadMaskBatch', + 'Gt2GFLTarget', 'Gt2CenterNetTarget', 'Gt2CenterTrackTarget', 'PadGT', + 'PadRGT', 'BatchRandomResizeForSSOD' ] @@ -1484,3 +1475,58 @@ class Gt2CenterTrackTarget(BaseOperator): del sample return new_sample + + +@register_op +class BatchRandomResizeForSSOD(BaseOperator): + """ + Resize image to target size randomly. random target_size and interpolation method + Args: + target_size (int, list, tuple): image target size, if random size is True, must be list or tuple + keep_ratio (bool): whether keep_raio or not, default true + interp (int): the interpolation method + random_size (bool): whether random select target size of image + random_interp (bool): whether random select interpolation method + """ + + def __init__(self, + target_size, + keep_ratio, + interp=cv2.INTER_NEAREST, + random_size=True, + random_interp=False): + super(BatchRandomResizeForSSOD, self).__init__() + self.keep_ratio = keep_ratio + self.interps = [ + cv2.INTER_NEAREST, + cv2.INTER_LINEAR, + cv2.INTER_AREA, + cv2.INTER_CUBIC, + cv2.INTER_LANCZOS4, + ] + self.interp = interp + assert isinstance(target_size, ( + int, Sequence)), "target_size must be int, list or tuple" + if random_size and not isinstance(target_size, list): + raise TypeError( + "Type of target_size is invalid when random_size is True. Must be List, now is {}". + format(type(target_size))) + self.target_size = target_size + self.random_size = random_size + self.random_interp = random_interp + + def __call__(self, samples, context=None): + if self.random_size: + index = np.random.choice(len(self.target_size)) + target_size = self.target_size[index] + else: + target_size = self.target_size + if context is not None: + target_size = self.target_size[context] + if self.random_interp: + interp = np.random.choice(self.interps) + else: + interp = self.interp + + resizer = Resize(target_size, keep_ratio=self.keep_ratio, interp=interp) + return [resizer(samples, context=context), index] diff --git a/ppdet/data/transform/operators.py b/ppdet/data/transform/operators.py index 206d9a48d..23ebd7730 100644 --- a/ppdet/data/transform/operators.py +++ b/ppdet/data/transform/operators.py @@ -1601,6 +1601,8 @@ class RandomCrop(BaseOperator): # only used in semi-det as unsup data sample = self.set_fake_bboxes(sample) sample = self.random_crop(sample, fake_bboxes=True) + del sample['gt_bbox'] + del sample['gt_class'] return sample if 'gt_bbox' in sample and len(sample['gt_bbox']) == 0: @@ -2074,26 +2076,29 @@ class NormalizeBox(BaseOperator): def apply(self, sample, context): im = sample['image'] - gt_bbox = sample['gt_bbox'] - height, width, _ = im.shape - for i in range(gt_bbox.shape[0]): - gt_bbox[i][0] = gt_bbox[i][0] / width - gt_bbox[i][1] = gt_bbox[i][1] / height - gt_bbox[i][2] = gt_bbox[i][2] / width - gt_bbox[i][3] = gt_bbox[i][3] / height - sample['gt_bbox'] = gt_bbox + if 'gt_bbox' in sample.keys(): + gt_bbox = sample['gt_bbox'] + height, width, _ = im.shape + for i in range(gt_bbox.shape[0]): + gt_bbox[i][0] = gt_bbox[i][0] / width + gt_bbox[i][1] = gt_bbox[i][1] / height + gt_bbox[i][2] = gt_bbox[i][2] / width + gt_bbox[i][3] = gt_bbox[i][3] / height + sample['gt_bbox'] = gt_bbox - if 'gt_keypoint' in sample.keys(): - gt_keypoint = sample['gt_keypoint'] + if 'gt_keypoint' in sample.keys(): + gt_keypoint = sample['gt_keypoint'] - for i in range(gt_keypoint.shape[1]): - if i % 2: - gt_keypoint[:, i] = gt_keypoint[:, i] / height - else: - gt_keypoint[:, i] = gt_keypoint[:, i] / width - sample['gt_keypoint'] = gt_keypoint + for i in range(gt_keypoint.shape[1]): + if i % 2: + gt_keypoint[:, i] = gt_keypoint[:, i] / height + else: + gt_keypoint[:, i] = gt_keypoint[:, i] / width + sample['gt_keypoint'] = gt_keypoint - return sample + return sample + else: + return sample @register_op @@ -2106,12 +2111,14 @@ class BboxXYXY2XYWH(BaseOperator): super(BboxXYXY2XYWH, self).__init__() def apply(self, sample, context=None): - assert 'gt_bbox' in sample - bbox = sample['gt_bbox'] - bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2] - bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2. - sample['gt_bbox'] = bbox - return sample + if 'gt_bbox' in sample.keys(): + bbox = sample['gt_bbox'] + bbox[:, 2:4] = bbox[:, 2:4] - bbox[:, :2] + bbox[:, :2] = bbox[:, :2] + bbox[:, 2:4] / 2. + sample['gt_bbox'] = bbox + return sample + else: + return sample @register_op @@ -2803,6 +2810,36 @@ class RandomSelect(BaseOperator): return self.transforms2(sample) +@register_op +class RandomSelects(BaseOperator): + """ + Randomly choose a transformation between transforms1 and transforms2, + and the probability of choosing transforms1 is p. + + The code is based on https://github.com/facebookresearch/detr/blob/main/datasets/transforms.py + + """ + + def __init__(self, transforms_list, p=None): + super(RandomSelects, self).__init__() + if p is not None: + assert isinstance(p, (list, tuple)) + assert len(transforms_list) == len(p) + else: + assert len(transforms_list) > 0 + self.transforms = [Compose(t) for t in transforms_list] + self.p = p + + def apply(self, sample, context=None): + if self.p is None: + return random.choice(self.transforms)(sample) + else: + prob = random.random() + for p, t in zip(self.p, self.transforms): + if prob <= p: + return t(sample) + + @register_op class RandomShortSideResize(BaseOperator): def __init__(self, diff --git a/ppdet/engine/callbacks.py b/ppdet/engine/callbacks.py index 1f2d546d8..eeb2f06de 100644 --- a/ppdet/engine/callbacks.py +++ b/ppdet/engine/callbacks.py @@ -26,7 +26,7 @@ import json import paddle import paddle.distributed as dist -from ppdet.utils.checkpoint import save_model +from ppdet.utils.checkpoint import save_model, save_semi_model from ppdet.metrics import get_infer_results from ppdet.utils.logger import setup_logger @@ -555,3 +555,141 @@ class SniperProposalsGenerator(Callback): logger.info("save proposals in {}".format(self.cfg.proposals_path)) with open(self.cfg.proposals_path, 'w') as f: json.dump(proposals, f) + + +class SemiLogPrinter(LogPrinter): + def __init__(self, model): + super(SemiLogPrinter, self).__init__(model) + + def on_step_end(self, status): + if dist.get_world_size() < 2 or dist.get_rank() == 0: + mode = status['mode'] + if mode == 'train': + epoch_id = status['epoch_id'] + step_id = status['step_id'] + iter_id = status['iter_id'] + steps_per_epoch = status['steps_per_epoch'] + training_staus = status['training_staus'] + batch_time = status['batch_time'] + data_time = status['data_time'] + + epoches = self.model.cfg.epoch + batch_size = self.model.cfg['{}Reader'.format(mode.capitalize( + ))]['batch_size'] + iters = epoches * steps_per_epoch + logs = training_staus.log() + iter_space_fmt = ':' + str(len(str(iters))) + 'd' + space_fmt = ':' + str(len(str(iters))) + 'd' + if step_id % self.model.cfg.log_iter == 0: + eta_steps = (epoches - epoch_id) * steps_per_epoch - step_id + eta_sec = eta_steps * batch_time.global_avg + eta_str = str(datetime.timedelta(seconds=int(eta_sec))) + ips = float(batch_size) / batch_time.avg + fmt = ' '.join([ + '{' + iter_space_fmt + '}/{} iters', + 'Epoch: [{}]', + '[{' + space_fmt + '}/{}]', + 'learning_rate: {lr:.6f}', + '{meters}', + 'eta: {eta}', + 'batch_cost: {btime}', + 'data_cost: {dtime}', + 'ips: {ips:.4f} images/s', + ]) + fmt = fmt.format( + iter_id, + iters, + epoch_id, + step_id, + steps_per_epoch, + lr=status['learning_rate'], + meters=logs, + eta=eta_str, + btime=str(batch_time), + dtime=str(data_time), + ips=ips) + logger.info(fmt) + if mode == 'eval': + step_id = status['step_id'] + if step_id % 100 == 0: + logger.info("Eval iter: {}".format(step_id)) + + +class SemiCheckpointer(Checkpointer): + def __init__(self, model): + super(SemiCheckpointer, self).__init__(model) + cfg = self.model.cfg + self.best_ap = 0. + self.save_dir = os.path.join(self.model.cfg.save_dir, + self.model.cfg.filename) + if hasattr(self.model.model, 'student') and hasattr(self.model.model, + 'teacher'): + self.weight = (self.model.model.teacher, self.model.model.student) + elif hasattr(self.model.model, 'student') or hasattr(self.model.model, + 'teacher'): + raise AttributeError( + "model has no attribute 'student' or 'teacher'") + else: + raise AttributeError( + "model has no attribute 'student' and 'teacher'") + + def every_n_iters(self, iter_id, n): + return (iter_id + 1) % n == 0 if n > 0 else False + + def on_step_end(self, status): + # Checkpointer only performed during training + mode = status['mode'] + eval_interval = status['eval_interval'] + save_interval = status['save_interval'] + iter_id = status['iter_id'] + epoch_id = status['epoch_id'] + t_weight = None + s_weight = None + save_name = None + if dist.get_world_size() < 2 or dist.get_rank() == 0: + if self.every_n_iters(iter_id, save_interval) and mode == 'train': + save_name = "last_epoch" + # save_name = str(iter_id + 1) + t_weight = self.weight[0].state_dict() + s_weight = self.weight[1].state_dict() + save_semi_model(t_weight, s_weight, self.model.optimizer, + self.save_dir, save_name, epoch_id + 1, + iter_id + 1) + + def on_epoch_end(self, status): + # Checkpointer only performed during training + mode = status['mode'] + eval_interval = status['eval_interval'] + save_interval = status['save_interval'] + iter_id = status['iter_id'] + epoch_id = status['epoch_id'] + t_weight = None + s_weight = None + save_name = None + if dist.get_world_size() < 2 or dist.get_rank() == 0: + if self.every_n_iters(iter_id, eval_interval) and mode == 'eval': + if 'save_best_model' in status and status['save_best_model']: + for metric in self.model._metrics: + map_res = metric.get_results() + if 'bbox' in map_res: + key = 'bbox' + elif 'keypoint' in map_res: + key = 'keypoint' + else: + key = 'mask' + if key not in map_res: + logger.warning("Evaluation results empty, this may be due to " \ + "training iterations being too few or not " \ + "loading the correct weights.") + return + if map_res[key][0] > self.best_ap: + self.best_ap = map_res[key][0] + save_name = 'best_model' + t_weight = self.weight[0].state_dict() + s_weight = self.weight[1].state_dict() + logger.info("Best teacher test {} ap is {:0.3f}.". + format(key, self.best_ap)) + if t_weight and s_weight: + save_semi_model(t_weight, s_weight, + self.model.optimizer, self.save_dir, + save_name, epoch_id + 1, iter_id + 1) diff --git a/ppdet/engine/trainer.py b/ppdet/engine/trainer.py index 7fb8cd611..260dbc9b7 100644 --- a/ppdet/engine/trainer.py +++ b/ppdet/engine/trainer.py @@ -48,7 +48,7 @@ from ppdet.utils import profiler from ppdet.modeling.post_process import multiclass_nms from ppdet.modeling.lane_utils import imshow_lanes -from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback +from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback, SemiCheckpointer, SemiLogPrinter from .export_utils import _dump_infer_config, _prune_input_spec, apply_to_static from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients @@ -226,7 +226,11 @@ class Trainer(object): def _init_callbacks(self): if self.mode == 'train': - self._callbacks = [LogPrinter(self), Checkpointer(self)] + if self.cfg.get('ssod_method', + False) and self.cfg['ssod_method'] == 'Semi_RTDETR': + self._callbacks = [SemiLogPrinter(self), SemiCheckpointer(self)] + else: + self._callbacks = [LogPrinter(self), Checkpointer(self)] if self.cfg.get('use_vdl', False): self._callbacks.append(VisualDLWriter(self)) if self.cfg.get('save_proposals', False): diff --git a/ppdet/engine/trainer_ssod.py b/ppdet/engine/trainer_ssod.py index ac39c9a97..ab4a100f5 100644 --- a/ppdet/engine/trainer_ssod.py +++ b/ppdet/engine/trainer_ssod.py @@ -40,7 +40,7 @@ MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack'] logger = setup_logger('ppdet.engine') -__all__ = ['Trainer_DenseTeacher', 'Trainer_ARSL'] +__all__ = ['Trainer_DenseTeacher', 'Trainer_ARSL', 'Trainer_Semi_RTDETR'] class Trainer_DenseTeacher(Trainer): @@ -856,3 +856,337 @@ class EnsembleTSModel(nn.Layer): super(EnsembleTSModel, self).__init__() self.modelTeacher = modelTeacher self.modelStudent = modelStudent + + +class Trainer_Semi_RTDETR(Trainer): + def __init__(self, cfg, mode='train'): + self.cfg = cfg + assert mode.lower() in ['train', 'eval', 'test'], \ + "mode should be 'train', 'eval' or 'test'" + self.mode = mode.lower() + self.optimizer = None + self.is_loaded_weights = False + self.use_amp = self.cfg.get('amp', False) + self.amp_level = self.cfg.get('amp_level', 'O1') + self.custom_white_list = self.cfg.get('custom_white_list', None) + self.custom_black_list = self.cfg.get('custom_black_list', None) + + # build data loader + capital_mode = self.mode.capitalize() + self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create( + '{}Dataset'.format(capital_mode))() + + if self.mode == 'train': + self.dataset_unlabel = self.cfg['UnsupTrainDataset'] = create( + 'UnsupTrainDataset') + self.loader = create('SemiTrainReader')( + self.dataset, self.dataset_unlabel, cfg.worker_num) + + # build model + if 'model' not in self.cfg: + self.model = create(cfg.SSOD) + else: + self.model = self.cfg.model + self.is_loaded_weights = True + + # EvalDataset build with BatchSampler to evaluate in single device + # TODO: multi-device evaluate + if self.mode == 'eval': + self._eval_batch_sampler = paddle.io.BatchSampler( + self.dataset, batch_size=self.cfg.EvalReader['batch_size']) + # If metric is VOC, need to be set collate_batch=False. + if cfg.metric == 'VOC': + cfg['EvalReader']['collate_batch'] = False + self.loader = create('EvalReader')(self.dataset, cfg.worker_num, + self._eval_batch_sampler) + # TestDataset build after user set images, skip loader creation here + + # build optimizer in train mode + if self.mode == 'train': + steps_per_epoch = len(self.loader) + if steps_per_epoch < 1: + logger.warning( + "Samples in dataset are less than batch_size, please set smaller batch_size in TrainReader." + ) + self.lr = create('LearningRate')(steps_per_epoch) + self.optimizer = create('OptimizerBuilder')(self.lr, self.model) + + # Unstructured pruner is only enabled in the train mode. + if self.cfg.get('unstructured_prune'): + self.pruner = create('UnstructuredPruner')(self.model, + steps_per_epoch) + if self.use_amp and self.amp_level == 'O2': + self.model, self.optimizer = paddle.amp.decorate( + models=self.model, + optimizers=self.optimizer, + level=self.amp_level) + + self._nranks = dist.get_world_size() + self._local_rank = dist.get_rank() + + self.status = {} + + self.start_epoch = 0 + self.start_iter = 0 + self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch + + # initial default callbacks + self._init_callbacks() + + # initial default metrics + self._init_metrics() + self._reset_metrics() + + def load_semi_weights(self, t_weights, s_weights): + if self.is_loaded_weights: + return + self.start_epoch = 0 + load_pretrain_weight(self.model.teacher, t_weights) + load_pretrain_weight(self.model.student, s_weights) + logger.info("Load teacher weights {} to start training".format( + t_weights)) + logger.info("Load student weights {} to start training".format( + s_weights)) + + def resume_weights(self, weights, exchange=True): + # support Distill resume weights + if hasattr(self.model, 'student_model'): + self.start_epoch = load_weight(self.model.student_model, weights, + self.optimizer, exchange) + else: + self.start_iter, self.start_epoch = load_weight( + self.model, weights, self.optimizer, self.ema + if self.use_ema else None, exchange) + logger.debug("Resume weights of epoch {}".format(self.start_epoch)) + logger.debug("Resume weights of iter {}".format(self.start_iter)) + + def train(self, validate=False): + assert self.mode == 'train', "Model not in 'train' mode" + Init_mark = False + if validate: + self.cfg.EvalDataset = create("EvalDataset")() + + model = self.model + sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and + self.cfg.use_gpu and self._nranks > 1) + if sync_bn: + # self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm( + # self.model) + model.teacher = paddle.nn.SyncBatchNorm.convert_sync_batchnorm( + model.teacher) + model.student = paddle.nn.SyncBatchNorm.convert_sync_batchnorm( + self.model.student) + + if self.cfg.get('fleet', False): + # model = fleet.distributed_model(model) + model = fleet.distributed_model(model) + + self.optimizer = fleet.distributed_optimizer(self.optimizer) + elif self._nranks > 1: + find_unused_parameters = self.cfg[ + 'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False + model = paddle.DataParallel( + model, find_unused_parameters=find_unused_parameters) + + if self.cfg.get('amp', False): + scaler = amp.GradScaler( + enable=self.cfg.use_gpu or self.cfg.use_npu, + init_loss_scaling=1024) + + self.status.update({ + 'epoch_id': self.start_epoch, + 'iter_id': self.start_iter, + # 'step_id': self.start_step, + 'steps_per_epoch': len(self.loader), + }) + + self.status['batch_time'] = stats.SmoothedValue( + self.cfg.log_iter, fmt='{avg:.4f}') + self.status['data_time'] = stats.SmoothedValue( + self.cfg.log_iter, fmt='{avg:.4f}') + self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter) + + if self.cfg.get('print_flops', False): + flops_loader = create('{}Reader'.format(self.mode.capitalize()))( + self.dataset, self.cfg.worker_num) + self._flops(flops_loader) + profiler_options = self.cfg.get('profiler_options', None) + + self._compose_callback.on_train_begin(self.status) + iter_id = self.start_iter + self.status['iter_id'] = iter_id + self.status['eval_interval'] = self.cfg.eval_interval + self.status['save_interval'] = self.cfg.save_interval + for epoch_id in range(self.start_epoch, self.cfg.epoch): + self.status['mode'] = 'train' + self.status['epoch_id'] = epoch_id + self._compose_callback.on_epoch_begin(self.status) + self.loader.dataset_label.set_epoch(epoch_id) + self.loader.dataset_unlabel.set_epoch(epoch_id) + iter_tic = time.time() + if self._nranks > 1: + # print(model) + model._layers.teacher.eval() + model._layers.student.train() + else: + model.teacher.eval() + model.student.train() + iter_tic = time.time() + for step_id in range(len(self.loader)): + data = next(self.loader) + data_sup_w, data_sup_s, data_unsup_w, data_unsup_s = data + data_sup_w['epoch_id'] = epoch_id + data_sup_s['epoch_id'] = epoch_id + data_unsup_w['epoch_id'] = epoch_id + data_unsup_s['epoch_id'] = epoch_id + data = [data_sup_w, data_sup_s, data_unsup_w, data_unsup_s] + iter_id += 1 + self.status['data_time'].update(time.time() - iter_tic) + self.status['step_id'] = step_id + self.status['iter_id'] = iter_id + data.append(iter_id) + profiler.add_profiler_step(profiler_options) + self._compose_callback.on_step_begin(self.status) + if self.cfg.get('amp', False): + with amp.auto_cast(enable=self.cfg.use_gpu): + # model forward + if self._nranks > 1: + outputs = model._layers(data) + else: + outputs = model(data) + loss = outputs['loss'] + + scaled_loss = scaler.scale(loss) + scaled_loss.backward() + scaler.minimize(self.optimizer, scaled_loss) + else: + outputs = model(data) + loss = outputs['loss'] + # model backward + loss.backward() + self.optimizer.step() + curr_lr = self.optimizer.get_lr() + self.lr.step() + if self.cfg.get('unstructured_prune'): + self.pruner.step() + self.optimizer.clear_grad() + # print(outputs) + # outputs=reduce_dict(outputs) + # if self.model.debug: + # check_gradient(model) + # self.check_gradient() + self.status['learning_rate'] = curr_lr + if self._nranks < 2 or self._local_rank == 0: + self.status['training_staus'].update(outputs) + + self.status['batch_time'].update(time.time() - iter_tic) + + if validate and (self._nranks < 2 or self._local_rank == 0) and \ + ((iter_id + 1) % self.cfg.eval_interval == 0): + if not hasattr(self, '_eval_loader'): + # build evaluation dataset and loader + self._eval_dataset = self.cfg.EvalDataset + self._eval_batch_sampler = \ + paddle.io.BatchSampler( + self._eval_dataset, + batch_size=self.cfg.EvalReader['batch_size']) + # If metric is VOC, need to be set collate_batch=False. + if self.cfg.metric == 'VOC': + self.cfg['EvalReader']['collate_batch'] = False + self._eval_loader = create('EvalReader')( + self._eval_dataset, + self.cfg.worker_num, + batch_sampler=self._eval_batch_sampler) + # if validation in training is enabled, metrics should be re-init + # Init_mark makes sure this code will only execute once + if validate and Init_mark == False: + Init_mark = True + self._init_metrics(validate=validate) + self._reset_metrics() + + with paddle.no_grad(): + self.status['save_best_model'] = True + self._eval_with_loader(self._eval_loader) + model._layers.student.train() + + self._compose_callback.on_step_end(self.status) + + iter_tic = time.time() + + if self.cfg.get('unstructured_prune'): + self.pruner.update_params() + self._compose_callback.on_epoch_end(self.status) + + self._compose_callback.on_train_end(self.status) + + def _eval_with_loader(self, loader): + sample_num = 0 + tic = time.time() + self._compose_callback.on_epoch_begin(self.status) + self.status['mode'] = 'eval' + self.model.eval() + if self.cfg.get('print_flops', False): + flops_loader = create('{}Reader'.format(self.mode.capitalize()))( + self.dataset, self.cfg.worker_num, self._eval_batch_sampler) + self._flops(flops_loader) + print("*****teacher evaluate*****") + for step_id, data in enumerate(loader): + self.status['step_id'] = step_id + self._compose_callback.on_step_begin(self.status) + # forward + outs = self.model.teacher(data) + + # update metrics + for metric in self._metrics: + metric.update(data, outs) + + # multi-scale inputs: all inputs have same im_id + if isinstance(data, typing.Sequence): + sample_num += data[0]['im_id'].numpy().shape[0] + else: + sample_num += data['im_id'].numpy().shape[0] + self._compose_callback.on_step_end(self.status) + + self.status['sample_num'] = sample_num + self.status['cost_time'] = time.time() - tic + + # accumulate metric to log out + for metric in self._metrics: + metric.accumulate() + metric.log() + self._compose_callback.on_epoch_end(self.status) + # reset metric states for metric may performed multiple times + self._reset_metrics() + + print("*****student evaluate*****") + for step_id, data in enumerate(loader): + self.status['step_id'] = step_id + self._compose_callback.on_step_begin(self.status) + # forward + outs = self.model.student(data) + + # update metrics + for metric in self._metrics: + metric.update(data, outs) + + # multi-scale inputs: all inputs have same im_id + if isinstance(data, typing.Sequence): + sample_num += data[0]['im_id'].numpy().shape[0] + else: + sample_num += data['im_id'].numpy().shape[0] + self._compose_callback.on_step_end(self.status) + + self.status['sample_num'] = sample_num + self.status['cost_time'] = time.time() - tic + + # accumulate metric to log out + for metric in self._metrics: + metric.accumulate() + metric.log() + # reset metric states for metric may performed multiple times + self._reset_metrics() + self.status['mode'] = 'train' + + def evaluate(self): + with paddle.no_grad(): + self._eval_with_loader(self.loader) diff --git a/ppdet/modeling/architectures/__init__.py b/ppdet/modeling/architectures/__init__.py index ad60f0f24..d22df32d8 100644 --- a/ppdet/modeling/architectures/__init__.py +++ b/ppdet/modeling/architectures/__init__.py @@ -42,6 +42,8 @@ from . import yolof from . import pose3d_metro from . import centertrack from . import queryinst +from . import detr_ssod +from . import multi_stream_detector from . import clrnet from .meta_arch import * @@ -76,4 +78,6 @@ from .pose3d_metro import * from .centertrack import * from .queryinst import * from .keypoint_petr import * -from .clrnet import * \ No newline at end of file +from .detr_ssod import * +from .multi_stream_detector import * +from .clrnet import * diff --git a/ppdet/modeling/architectures/detr.py b/ppdet/modeling/architectures/detr.py index 7839a1263..085f63f8c 100644 --- a/ppdet/modeling/architectures/detr.py +++ b/ppdet/modeling/architectures/detr.py @@ -27,7 +27,7 @@ __all__ = ['DETR'] @register class DETR(BaseArch): __category__ = 'architecture' - __inject__ = ['post_process'] + __inject__ = ['post_process', 'post_process_semi'] __shared__ = ['with_mask', 'exclude_post_process'] def __init__(self, @@ -36,6 +36,7 @@ class DETR(BaseArch): detr_head='DETRHead', neck=None, post_process='DETRPostProcess', + post_process_semi=None, with_mask=False, exclude_post_process=False): super(DETR, self).__init__() @@ -46,6 +47,7 @@ class DETR(BaseArch): self.post_process = post_process self.with_mask = with_mask self.exclude_post_process = exclude_post_process + self.post_process_semi = post_process_semi @classmethod def from_config(cls, cfg, *args, **kwargs): diff --git a/ppdet/modeling/architectures/detr_ssod.py b/ppdet/modeling/architectures/detr_ssod.py new file mode 100644 index 000000000..567c23418 --- /dev/null +++ b/ppdet/modeling/architectures/detr_ssod.py @@ -0,0 +1,341 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from ppdet.core.workspace import register, create, merge_config +import paddle + +import numpy as np +import paddle +import paddle.nn.functional as F +from ppdet.core.workspace import register, create +from ppdet.utils.logger import setup_logger +from ppdet.modeling.ssod.utils import filter_invalid +from .multi_stream_detector import MultiSteamDetector +logger = setup_logger(__name__) + +__all__ = ['DETR_SSOD'] +__shared__ = ['num_classes'] + + +@register +class DETR_SSOD(MultiSteamDetector): + def __init__(self, + teacher, + student, + train_cfg=None, + test_cfg=None, + RTDETRTransformer=None, + num_classes=80): + super(DETR_SSOD, self).__init__( + dict( + teacher=teacher, student=student), + train_cfg=train_cfg, + test_cfg=test_cfg, ) + self.ema_start_iters = train_cfg['ema_start_iters'] + self.momentum = 0.9996 + self.cls_thr = None + self.cls_thr_ig = None + self.num_classes = num_classes + if train_cfg is not None: + self.freeze("teacher") + self.unsup_weight = self.train_cfg['unsup_weight'] + self.sup_weight = self.train_cfg['sup_weight'] + self._teacher = None + self._student = None + self._transformer = None + + @classmethod + def from_config(cls, cfg): + teacher = create(cfg['teacher']) + merge_config(cfg) + student = create(cfg['student']) + train_cfg = cfg['train_cfg'] + test_cfg = cfg['test_cfg'] + RTDETRTransformer = cfg['RTDETRTransformer'] + return { + 'teacher': teacher, + 'student': student, + 'train_cfg': train_cfg, + 'test_cfg': test_cfg, + 'RTDETRTransformer': RTDETRTransformer + } + + def forward_train(self, inputs, **kwargs): + if isinstance(inputs, dict): + iter_id = inputs['iter_id'] + elif isinstance(inputs, list): + iter_id = inputs[-1] + if iter_id == self.ema_start_iters: + self.update_ema_model(momentum=0) + elif iter_id > self.ema_start_iters: + self.update_ema_model(momentum=self.momentum) + if iter_id > self.ema_start_iters: + data_sup_w, data_sup_s, data_unsup_w, data_unsup_s, _ = inputs + + if data_sup_w['image'].shape != data_sup_s['image'].shape: + data_sup_w, data_sup_s = align_weak_strong_shape(data_sup_w, + data_sup_s) + + if 'gt_bbox' in data_unsup_s.keys(): + del data_unsup_s['gt_bbox'] + if 'gt_class' in data_unsup_s.keys(): + del data_unsup_s['gt_class'] + if 'gt_class' in data_unsup_w.keys(): + del data_unsup_w['gt_class'] + if 'gt_bbox' in data_unsup_w.keys(): + del data_unsup_w['gt_bbox'] + for k, v in data_sup_s.items(): + if k in ['epoch_id']: + continue + elif k in ['gt_class', 'gt_bbox', 'is_crowd']: + data_sup_s[k].extend(data_sup_w[k]) + else: + data_sup_s[k] = paddle.concat([v, data_sup_w[k]]) + + loss = {} + body_feats = self.student.backbone(data_sup_s) + if self.student.neck is not None: + body_feats = self.student.neck(body_feats) + out_transformer = self.student.transformer(body_feats, None, + data_sup_s) + sup_loss = self.student.detr_head(out_transformer, body_feats, + data_sup_s) + sup_loss.update({ + 'loss': paddle.add_n( + [v for k, v in sup_loss.items() if 'log' not in k]) + }) + sup_loss = {"sup_" + k: v for k, v in sup_loss.items()} + + loss.update(**sup_loss) + unsup_loss = self.foward_unsup_train(data_unsup_w, data_unsup_s) + unsup_loss.update({ + 'loss': paddle.add_n( + [v for k, v in unsup_loss.items() if 'log' not in k]) + }) + unsup_loss = {"unsup_" + k: v for k, v in unsup_loss.items()} + unsup_loss.update({ + 'loss': paddle.add_n( + [v for k, v in unsup_loss.items() if 'log' not in k]) + }) + loss.update(**unsup_loss) + loss.update({'loss': loss['sup_loss'] + loss['unsup_loss']}) + else: + if iter_id == self.ema_start_iters: + logger.info("start semi_supervised_traing") + data_sup_w, data_sup_s, data_unsup_w, data_unsup_s, _ = inputs + + if data_sup_w['image'].shape != data_sup_s['image'].shape: + data_sup_w, data_sup_s = align_weak_strong_shape(data_sup_w, + data_sup_s) + for k, v in data_sup_s.items(): + if k in ['epoch_id']: + continue + elif k in ['gt_class', 'gt_bbox', 'is_crowd']: + data_sup_s[k].extend(data_sup_w[k]) + else: + data_sup_s[k] = paddle.concat([v, data_sup_w[k]]) + loss = {} + sup_loss = self.student(data_sup_s) + unsup_loss = { + "unsup_" + k: v * paddle.to_tensor(0) + for k, v in sup_loss.items() + } + sup_loss = {"sup_" + k: v for k, v in sup_loss.items()} + loss.update(**sup_loss) + unsup_loss.update({ + 'loss': paddle.add_n( + [v * 0 for k, v in sup_loss.items() if 'log' not in k]) + }) + unsup_loss = {"unsup_" + k: v * 0 for k, v in unsup_loss.items()} + loss.update(**unsup_loss) + loss.update({'loss': loss['sup_loss']}) + return loss + + def foward_unsup_train(self, data_unsup_w, data_unsup_s): + + with paddle.no_grad(): + body_feats = self.teacher.backbone(data_unsup_w) + if self.teacher.neck is not None: + body_feats = self.teacher.neck(body_feats, is_teacher=True) + out_transformer = self.teacher.transformer( + body_feats, None, data_unsup_w, is_teacher=True) + preds = self.teacher.detr_head(out_transformer, body_feats) + bbox, bbox_num = self.teacher.post_process_semi(preds) + self.place = body_feats[0].place + + proposal_bbox_list = bbox[:, -4:] + proposal_bbox_list = proposal_bbox_list.split( + tuple(np.array(bbox_num)), 0) + + proposal_label_list = paddle.cast(bbox[:, :1], np.float32) + proposal_label_list = proposal_label_list.split( + tuple(np.array(bbox_num)), 0) + proposal_score_list = paddle.cast(bbox[:, 1:self.num_classes + 1], + np.float32) + proposal_score_list = proposal_score_list.split( + tuple(np.array(bbox_num)), 0) + proposal_bbox_list = [ + paddle.to_tensor( + p, place=self.place) for p in proposal_bbox_list + ] + proposal_label_list = [ + paddle.to_tensor( + p, place=self.place) for p in proposal_label_list + ] + # filter invalid box roughly + if isinstance(self.train_cfg['pseudo_label_initial_score_thr'], float): + thr = self.train_cfg['pseudo_label_initial_score_thr'] + else: + # TODO: use dynamic threshold + raise NotImplementedError( + "Dynamic Threshold is not implemented yet.") + proposal_bbox_list, proposal_label_list, proposal_score_list = list( + zip(* [ + filter_invalid( + proposal[:, :4], + proposal_label, + proposal_score, + thr=thr, + min_size=self.train_cfg['min_pseduo_box_size'], ) + for proposal, proposal_label, proposal_score in + zip(proposal_bbox_list, proposal_label_list, + proposal_score_list) + ])) + + teacher_bboxes = list(proposal_bbox_list) + teacher_labels = proposal_label_list + teacher_info = [teacher_bboxes, teacher_labels] + student_unsup = data_unsup_s + return self.compute_pseudo_label_loss(student_unsup, teacher_info, + proposal_score_list) + + def compute_pseudo_label_loss(self, student_unsup, teacher_info, + proposal_score_list): + + pseudo_bboxes = list(teacher_info[0]) + pseudo_labels = list(teacher_info[1]) + losses = dict() + for i in range(len(pseudo_bboxes)): + if pseudo_labels[i].shape[0] == 0: + pseudo_bboxes[i] = paddle.zeros([0, 4]).numpy() + pseudo_labels[i] = paddle.zeros([0, 1]).numpy() + else: + pseudo_bboxes[i] = pseudo_bboxes[i][:, :4].numpy() + pseudo_labels[i] = pseudo_labels[i].numpy() + for i in range(len(pseudo_bboxes)): + pseudo_labels[i] = paddle.to_tensor( + pseudo_labels[i], dtype=paddle.int32, place=self.place) + pseudo_bboxes[i] = paddle.to_tensor( + pseudo_bboxes[i], dtype=paddle.float32, place=self.place) + student_unsup.update({ + 'gt_bbox': pseudo_bboxes, + 'gt_class': pseudo_labels + }) + pseudo_sum = 0 + for i in range(len(pseudo_bboxes)): + pseudo_sum += pseudo_bboxes[i].sum() + if pseudo_sum == 0: #input fake data when there are no pseudo labels + pseudo_bboxes[0] = paddle.ones([1, 4]) - 0.5 + pseudo_labels[0] = paddle.ones([1, 1]).astype('int32') + student_unsup.update({ + 'gt_bbox': pseudo_bboxes, + 'gt_class': pseudo_labels + }) + body_feats = self.student.backbone(student_unsup) + if self.student.neck is not None: + body_feats = self.student.neck(body_feats) + out_transformer = self.student.transformer(body_feats, None, + student_unsup) + losses = self.student.detr_head(out_transformer, body_feats, + student_unsup) + for n, v in losses.items(): + losses[n] = v * 0 + else: + gt_bbox = [] + gt_class = [] + images = [] + proposal_score = [] + for i in range(len(pseudo_bboxes)): + if pseudo_labels[i].shape[0] == 0: + continue + else: + proposal_score.append(proposal_score_list[i].max(-1) + .unsqueeze(-1)) + gt_class.append(pseudo_labels[i]) + gt_bbox.append(pseudo_bboxes[i]) + images.append(student_unsup['image'][i]) + images = paddle.stack(images) + student_unsup.update({ + 'image': images, + 'gt_bbox': gt_bbox, + 'gt_class': gt_class + }) + body_feats = self.student.backbone(student_unsup) + if self.student.neck is not None: + body_feats = self.student.neck(body_feats) + out_transformer = self.student.transformer(body_feats, None, + student_unsup) + student_unsup.update({'gt_score': proposal_score}) + losses = self.student.detr_head(out_transformer, body_feats, + student_unsup) + return losses + + +def box_cxcywh_to_xyxy(x): + x_c, y_c, w, h = x.unbind(-1) + b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] + return paddle.stack(b, axis=-1) + + +def box_xyxy_to_cxcywh(x): + x0, y0, x1, y1 = x.unbind(-1) + b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] + return paddle.stack(b, axis=-1) + + +def get_size_with_aspect_ratio(image_size, size, max_size=None): + w, h = image_size + if max_size is not None: + min_original_size = float(min((w, h))) + max_original_size = float(max((w, h))) + if max_original_size / min_original_size * size > max_size: + size = int(round(max_size * min_original_size / max_original_size)) + + if (w <= h and w == size) or (h <= w and h == size): + return (w, h) + + if w < h: + ow = size + oh = int(size * h / w) + else: + oh = size + ow = int(size * w / h) + + return (ow, oh) + + +def align_weak_strong_shape(data_weak, data_strong): + shape_x = data_strong['image'].shape[2] + shape_y = data_strong['image'].shape[3] + + target_size = [shape_x, shape_y] + data_weak['image'] = F.interpolate( + data_weak['image'], + size=target_size, + mode='bilinear', + align_corners=False) + return data_weak, data_strong diff --git a/ppdet/modeling/architectures/multi_stream_detector.py b/ppdet/modeling/architectures/multi_stream_detector.py new file mode 100644 index 000000000..58c4fe02e --- /dev/null +++ b/ppdet/modeling/architectures/multi_stream_detector.py @@ -0,0 +1,69 @@ +from typing import Dict +from collections import OrderedDict +from ppdet.modeling.architectures.meta_arch import BaseArch + + +class MultiSteamDetector(BaseArch): + def __init__(self, + model: Dict[str, BaseArch], + train_cfg=None, + test_cfg=None): + super(MultiSteamDetector, self).__init__() + self.submodules = list(model.keys()) + for k, v in model.items(): + setattr(self, k, v) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.inference_on = self.test_cfg.get("inference_on", + self.submodules[0]) + self.first_load = True + + def forward(self, inputs, return_loss=True, **kwargs): + """Calls either :func:`forward_train` or :func:`forward_test` depending + on whether ``return_loss`` is ``True``. + + Note this setting will change the expected inputs. When + ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor + and List[dict]), and when ``resturn_loss=False``, img and img_meta + should be double nested (i.e. List[Tensor], List[List[dict]]), with + the outer list indicating test time augmentations. + """ + if return_loss: + return self.forward_train(inputs, **kwargs) + else: + return self.forward_test(inputs, **kwargs) + + def get_loss(self, **kwargs): + # losses = self(**data) + + return self.forward_train(self, **kwargs) + + def model(self, **kwargs) -> BaseArch: + if "submodule" in kwargs: + assert (kwargs["submodule"] in self.submodules + ), "Detector does not contain submodule {}".format(kwargs[ + "submodule"]) + model: BaseArch = getattr(self, kwargs["submodule"]) + else: + model: BaseArch = getattr(self, self.inference_on) + return model + + def freeze(self, model_ref: str): + assert model_ref in self.submodules + model = getattr(self, model_ref) + model.eval() + for param in model.parameters(): + param.stop_gradient = True + + def update_ema_model(self, momentum=0.9996): + # print(momentum) + model_dict = self.student.state_dict() + new_dict = OrderedDict() + for key, value in self.teacher.state_dict().items(): + if key in model_dict.keys(): + new_dict[key] = (model_dict[key] * + (1 - momentum) + value * momentum) + else: + raise Exception("{} is not found in student model".format(key)) + self.teacher.set_dict(new_dict) diff --git a/ppdet/modeling/heads/detr_head.py b/ppdet/modeling/heads/detr_head.py index f65a98434..d3c093fbc 100644 --- a/ppdet/modeling/heads/detr_head.py +++ b/ppdet/modeling/heads/detr_head.py @@ -368,9 +368,10 @@ class DeformableDETRHead(nn.Layer): class DINOHead(nn.Layer): __inject__ = ['loss'] - def __init__(self, loss='DINOLoss'): + def __init__(self, loss='DINOLoss', eval_idx=-1): super(DINOHead, self).__init__() self.loss = loss + self.eval_idx = eval_idx def forward(self, out_transformer, body_feats, inputs=None): (dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, @@ -456,9 +457,11 @@ class DINOHead(nn.Layer): inputs['gt_class'], dn_out_bboxes=dn_out_bboxes, dn_out_logits=dn_out_logits, - dn_meta=dn_meta) + dn_meta=dn_meta, + gt_score=inputs.get('gt_score', None)) else: - return (dec_out_bboxes[-1], dec_out_logits[-1], None) + return (dec_out_bboxes[self.eval_idx], + dec_out_logits[self.eval_idx], None) @register diff --git a/ppdet/modeling/losses/detr_loss.py b/ppdet/modeling/losses/detr_loss.py index 24f14c3d4..62b9c00ec 100644 --- a/ppdet/modeling/losses/detr_loss.py +++ b/ppdet/modeling/losses/detr_loss.py @@ -81,7 +81,8 @@ class DETRLoss(nn.Layer): bg_index, num_gts, postfix="", - iou_score=None): + iou_score=None, + gt_score=None): # logits: [b, query, num_classes], gt_class: list[[n, 1]] name_class = "loss_class" + postfix @@ -98,15 +99,35 @@ class DETRLoss(nn.Layer): target_label = F.one_hot(target_label, self.num_classes + 1)[..., :-1] if iou_score is not None and self.use_vfl: - target_score = paddle.zeros([bs, num_query_objects]) - if num_gt > 0: + if gt_score is not None: + target_score = paddle.zeros([bs, num_query_objects]) target_score = paddle.scatter( - target_score.reshape([-1, 1]), index, iou_score) - target_score = target_score.reshape( - [bs, num_query_objects, 1]) * target_label - loss_ = self.loss_coeff['class'] * varifocal_loss_with_logits( - logits, target_score, target_label, - num_gts / num_query_objects) + target_score.reshape([-1, 1]), index, gt_score) + target_score = target_score.reshape( + [bs, num_query_objects, 1]) * target_label + + target_score_iou = paddle.zeros([bs, num_query_objects]) + target_score_iou = paddle.scatter( + target_score_iou.reshape([-1, 1]), index, iou_score) + target_score_iou = target_score_iou.reshape( + [bs, num_query_objects, 1]) * target_label + target_score = paddle.multiply(target_score, + target_score_iou) + loss_ = self.loss_coeff[ + 'class'] * varifocal_loss_with_logits( + logits, target_score, target_label, + num_gts / num_query_objects) + else: + target_score = paddle.zeros([bs, num_query_objects]) + if num_gt > 0: + target_score = paddle.scatter( + target_score.reshape([-1, 1]), index, iou_score) + target_score = target_score.reshape( + [bs, num_query_objects, 1]) * target_label + loss_ = self.loss_coeff[ + 'class'] * varifocal_loss_with_logits( + logits, target_score, target_label, + num_gts / num_query_objects) else: loss_ = self.loss_coeff['class'] * sigmoid_focal_loss( logits, target_label, num_gts / num_query_objects) @@ -183,7 +204,8 @@ class DETRLoss(nn.Layer): dn_match_indices=None, postfix="", masks=None, - gt_mask=None): + gt_mask=None, + gt_score=None): loss_class = [] loss_bbox, loss_giou = [], [] loss_mask, loss_dice = [], [] @@ -216,12 +238,22 @@ class DETRLoss(nn.Layer): bbox_cxcywh_to_xyxy(target_bbox).split(4, -1)) else: iou_score = None + if gt_score is not None: + _, target_score = self._get_src_target_assign( + logits[-1].detach(), gt_score, match_indices) else: iou_score = None loss_class.append( - self._get_loss_class(aux_logits, gt_class, match_indices, - bg_index, num_gts, postfix, iou_score)[ - 'loss_class' + postfix]) + self._get_loss_class( + aux_logits, + gt_class, + match_indices, + bg_index, + num_gts, + postfix, + iou_score, + gt_score=target_score + if gt_score is not None else None)['loss_class' + postfix]) loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices, num_gts, postfix) loss_bbox.append(loss_['loss_bbox' + postfix]) @@ -284,7 +316,8 @@ class DETRLoss(nn.Layer): gt_mask=None, postfix="", dn_match_indices=None, - num_gts=1): + num_gts=1, + gt_score=None): if dn_match_indices is None: match_indices = self.matcher( boxes, logits, gt_bbox, gt_class, masks=masks, gt_mask=gt_mask) @@ -298,6 +331,9 @@ class DETRLoss(nn.Layer): iou_score = bbox_iou( bbox_cxcywh_to_xyxy(src_bbox).split(4, -1), bbox_cxcywh_to_xyxy(target_bbox).split(4, -1)) + if gt_score is not None: #ssod + _, target_score = self._get_src_target_assign( + logits[-1].detach(), gt_score, match_indices) else: iou_score = None else: @@ -305,8 +341,15 @@ class DETRLoss(nn.Layer): loss = dict() loss.update( - self._get_loss_class(logits, gt_class, match_indices, - self.num_classes, num_gts, postfix, iou_score)) + self._get_loss_class( + logits, + gt_class, + match_indices, + self.num_classes, + num_gts, + postfix, + iou_score, + gt_score=target_score if gt_score is not None else None)) loss.update( self._get_loss_bbox(boxes, gt_bbox, match_indices, num_gts, postfix)) @@ -324,6 +367,7 @@ class DETRLoss(nn.Layer): masks=None, gt_mask=None, postfix="", + gt_score=None, **kwargs): r""" Args: @@ -350,7 +394,8 @@ class DETRLoss(nn.Layer): gt_mask=gt_mask, postfix=postfix, dn_match_indices=dn_match_indices, - num_gts=num_gts) + num_gts=num_gts, + gt_score=gt_score if gt_score is not None else None) if self.aux_loss: total_loss.update( @@ -364,7 +409,8 @@ class DETRLoss(nn.Layer): dn_match_indices, postfix, masks=masks[:-1] if masks is not None else None, - gt_mask=gt_mask)) + gt_mask=gt_mask, + gt_score=gt_score if gt_score is not None else None)) return total_loss @@ -382,10 +428,16 @@ class DINOLoss(DETRLoss): dn_out_bboxes=None, dn_out_logits=None, dn_meta=None, + gt_score=None, **kwargs): num_gts = self._get_num_gts(gt_class) total_loss = super(DINOLoss, self).forward( - boxes, logits, gt_bbox, gt_class, num_gts=num_gts) + boxes, + logits, + gt_bbox, + gt_class, + num_gts=num_gts, + gt_score=gt_score) if dn_meta is not None: dn_positive_idx, dn_num_group = \ @@ -405,7 +457,8 @@ class DINOLoss(DETRLoss): gt_class, postfix="_dn", dn_match_indices=dn_match_indices, - num_gts=num_gts) + num_gts=num_gts, + gt_score=gt_score) total_loss.update(dn_loss) else: total_loss.update( diff --git a/ppdet/modeling/post_process.py b/ppdet/modeling/post_process.py index 24722ff67..efde830b1 100644 --- a/ppdet/modeling/post_process.py +++ b/ppdet/modeling/post_process.py @@ -26,7 +26,8 @@ except Exception: __all__ = [ 'BBoxPostProcess', 'MaskPostProcess', 'JDEBBoxPostProcess', - 'CenterNetPostProcess', 'DETRPostProcess', 'SparsePostProcess' + 'CenterNetPostProcess', 'DETRPostProcess', 'SparsePostProcess', + 'DETRBBoxSemiPostProcess' ] @@ -743,3 +744,58 @@ def nms(dets, match_threshold=0.6, match_metric='iou'): keep = np.where(suppressed == 0)[0] dets = dets[keep, :] return dets + + +@register +class DETRBBoxSemiPostProcess(object): + __shared__ = ['num_classes', 'use_focal_loss'] + __inject__ = [] + + def __init__(self, + num_classes=80, + num_top_queries=100, + use_focal_loss=False): + super(DETRBBoxSemiPostProcess, self).__init__() + self.num_classes = num_classes + self.num_top_queries = num_top_queries + self.use_focal_loss = use_focal_loss + + def __call__(self, head_out): + """ + Decode the bbox. + Args: + head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output. + im_shape (Tensor): The shape of the input image. + scale_factor (Tensor): The scale factor of the input image. + Returns: + bbox_pred (Tensor): The output prediction with shape [N, 6], including + labels, scores and bboxes. The size of bboxes are corresponding + to the input image, the bboxes may be used in other branch. + bbox_num (Tensor): The number of prediction boxes of each batch with + shape [bs], and is N. + """ + bboxes, logits, masks = head_out + bbox_pred = bboxes + + scores = F.softmax(logits, axis=2) + + import copy + soft_scores = copy.deepcopy(scores) + scores, index = paddle.topk(scores.max(-1), 300, axis=-1) + + batch_ind = paddle.arange(end=scores.shape[0]).unsqueeze(-1).tile( + [1, 300]) + index = paddle.stack([batch_ind, index], axis=-1) + labels = paddle.gather_nd(soft_scores.argmax(-1), index).astype('int32') + score_class = paddle.gather_nd(soft_scores, index) + bbox_pred = paddle.gather_nd(bbox_pred, index) + bbox_pred = paddle.concat( + [ + labels.unsqueeze(-1).astype('float32'), score_class, + scores.unsqueeze(-1), bbox_pred + ], + axis=-1) + bbox_num = paddle.to_tensor( + bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]]) + bbox_pred = bbox_pred.reshape([-1, bbox_pred.shape[-1]]) + return bbox_pred, bbox_num diff --git a/ppdet/modeling/ssod/utils.py b/ppdet/modeling/ssod/utils.py index 09753abfe..6c9e86f78 100644 --- a/ppdet/modeling/ssod/utils.py +++ b/ppdet/modeling/ssod/utils.py @@ -80,3 +80,25 @@ def QFLv2(pred_sigmoid, elif reduction == "sum": loss = loss[valid].sum() return loss + + +def filter_invalid(bbox, label=None, score=None, thr=0.0, min_size=0): + if score.numel() > 0: + soft_score = score.max(-1) + valid = soft_score >= thr + bbox = bbox[valid] + + if label is not None: + label = label[valid] + score = score[valid] + if min_size is not None and bbox.shape[0] > 0: + bw = bbox[:, 2] + bh = bbox[:, 3] + valid = (bw > min_size) & (bh > min_size) + bbox = bbox[valid] + + if label is not None: + label = label[valid] + score = score[valid] + + return bbox, label, score diff --git a/ppdet/modeling/transformers/hybrid_encoder.py b/ppdet/modeling/transformers/hybrid_encoder.py index b64c4ee3b..5694803eb 100644 --- a/ppdet/modeling/transformers/hybrid_encoder.py +++ b/ppdet/modeling/transformers/hybrid_encoder.py @@ -43,7 +43,7 @@ class CSPRepLayer(nn.Layer): in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act) self.conv2 = BaseConv( in_channels, hidden_channels, ksize=1, stride=1, bias=bias, act=act) - self.bottlenecks = nn.Sequential(*[ + self.bottlenecks = nn.Sequential(* [ RepVggBlock( hidden_channels, hidden_channels, act=act) for _ in range(num_blocks) @@ -237,7 +237,7 @@ class HybridEncoder(nn.Layer): ], axis=1)[None, :, :] - def forward(self, feats, for_mot=False): + def forward(self, feats, for_mot=False, is_teacher=False): assert len(feats) == len(self.in_channels) # get projection features proj_feats = [self.input_proj[i](feat) for i, feat in enumerate(feats)] @@ -248,7 +248,7 @@ class HybridEncoder(nn.Layer): # flatten [B, C, H, W] to [B, HxW, C] src_flatten = proj_feats[enc_ind].flatten(2).transpose( [0, 2, 1]) - if self.training or self.eval_size is None: + if self.training or self.eval_size is None or is_teacher: pos_embed = self.build_2d_sincos_position_embedding( w, h, self.hidden_dim, self.pe_temperature) else: diff --git a/ppdet/modeling/transformers/rtdetr_transformer.py b/ppdet/modeling/transformers/rtdetr_transformer.py index a01ccd666..f3d021f66 100644 --- a/ppdet/modeling/transformers/rtdetr_transformer.py +++ b/ppdet/modeling/transformers/rtdetr_transformer.py @@ -439,7 +439,7 @@ class RTDETRTransformer(nn.Layer): level_start_index.pop() return (feat_flatten, spatial_shapes, level_start_index) - def forward(self, feats, pad_mask=None, gt_meta=None): + def forward(self, feats, pad_mask=None, gt_meta=None, is_teacher=False): # input projection and embedding (memory, spatial_shapes, level_start_index) = self._get_encoder_input(feats) @@ -459,7 +459,7 @@ class RTDETRTransformer(nn.Layer): target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \ self._get_decoder_input( - memory, spatial_shapes, denoising_class, denoising_bbox_unact) + memory, spatial_shapes, denoising_class, denoising_bbox_unact,is_teacher) # decoder out_bboxes, out_logits = self.decoder( @@ -513,10 +513,11 @@ class RTDETRTransformer(nn.Layer): memory, spatial_shapes, denoising_class=None, - denoising_bbox_unact=None): + denoising_bbox_unact=None, + is_teacher=False): bs, _, _ = memory.shape # prepare input for decoder - if self.training or self.eval_size is None: + if self.training or self.eval_size is None or is_teacher: anchors, valid_mask = self._generate_anchors(spatial_shapes) else: anchors, valid_mask = self.anchors, self.valid_mask diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index f3dafd40f..101e46b32 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -323,3 +323,40 @@ def save_model(model, state_dict['last_epoch'] = last_epoch paddle.save(state_dict, save_path + ".pdopt") logger.info("Save checkpoint: {}".format(save_dir)) + + +def save_semi_model(teacher_model, student_model, optimizer, save_dir, + save_name, last_epoch, last_iter): + """ + save teacher and student model into disk. + Args: + teacher_model (dict): the teacher_model state_dict to save parameters. + student_model (dict): the student_model state_dict to save parameters. + optimizer (paddle.optimizer.Optimizer): the Optimizer instance to + save optimizer states. + save_dir (str): the directory to be saved. + save_name (str): the path to be saved. + last_epoch (int): the epoch index. + last_iter (int): the iter index. + """ + if paddle.distributed.get_rank() != 0: + return + assert isinstance(teacher_model, dict), ( + "teacher_model is not a instance of dict, " + "please call teacher_model.state_dict() to get.") + assert isinstance(student_model, dict), ( + "student_model is not a instance of dict, " + "please call student_model.state_dict() to get.") + if not os.path.exists(save_dir): + os.makedirs(save_dir) + save_path = os.path.join(save_dir, save_name) + # save model + paddle.save(teacher_model, save_path + str(last_epoch) + "epoch_t.pdparams") + paddle.save(student_model, save_path + str(last_epoch) + "epoch_s.pdparams") + + # save optimizer + state_dict = optimizer.state_dict() + state_dict['last_epoch'] = last_epoch + state_dict['last_iter'] = last_iter + paddle.save(state_dict, save_path + str(last_epoch) + "epoch.pdopt") + logger.info("Save checkpoint: {}".format(save_dir)) diff --git a/tools/train.py b/tools/train.py index 3aa0a21a7..21173c409 100755 --- a/tools/train.py +++ b/tools/train.py @@ -32,7 +32,7 @@ import paddle from ppdet.core.workspace import load_config, merge_config from ppdet.engine import Trainer, TrainerCot, init_parallel_env, set_random_seed, init_fleet_env -from ppdet.engine.trainer_ssod import Trainer_DenseTeacher, Trainer_ARSL +from ppdet.engine.trainer_ssod import Trainer_DenseTeacher, Trainer_ARSL, Trainer_Semi_RTDETR from ppdet.slim import build_slim_model @@ -134,10 +134,11 @@ def run(FLAGS, cfg): trainer = Trainer_DenseTeacher(cfg, mode='train') elif ssod_method == 'ARSL': trainer = Trainer_ARSL(cfg, mode='train') + elif ssod_method == 'Semi_RTDETR': + trainer = Trainer_Semi_RTDETR(cfg, mode='train') else: raise ValueError( - "Semi-Supervised Object Detection only support DenseTeacher and ARSL now." - ) + "Semi-Supervised Object Detection only no support this method.") elif cfg.get('use_cot', False): trainer = TrainerCot(cfg, mode='train') else: @@ -146,6 +147,10 @@ def run(FLAGS, cfg): # load weights if FLAGS.resume is not None: trainer.resume_weights(FLAGS.resume) + elif 'pretrain_student_weights' in cfg and 'pretrain_teacher_weights' in cfg \ + and cfg.pretrain_teacher_weights and cfg.pretrain_student_weights: + trainer.load_semi_weights(cfg.pretrain_teacher_weights, + cfg.pretrain_student_weights) elif 'pretrain_weights' in cfg and cfg.pretrain_weights: trainer.load_weights(cfg.pretrain_weights) -- GitLab