diff --git a/configs/ppyoloe/distill/README.md b/configs/ppyoloe/distill/README.md index 9f8761d83bddaeb519ca392b59e20ba48aa8b703..868d70b88805dca01e63bd56dff7c08c06a2f5cb 100644 --- a/configs/ppyoloe/distill/README.md +++ b/configs/ppyoloe/distill/README.md @@ -1,6 +1,6 @@ # PPYOLOE+ Distillation(PPYOLOE+ 蒸馏) -PaddleDetection提供了对PPYOLOE+ 进行模型蒸馏的方案,结合了logits蒸馏和feature蒸馏。 +PaddleDetection提供了对PPYOLOE+ 进行模型蒸馏的方案,结合了logits蒸馏和feature蒸馏。更多蒸馏方案可以查看[slim/distill](../../slim/distill/)。 ## 模型库 diff --git a/configs/slim/distill/README.md b/configs/slim/distill/README.md index 6ffdf50dad2b55748bc7febfafcf6fb18ec15132..97c93fcc42d3f7d233e5e8794144bfeb8c1cd5b0 100644 --- a/configs/slim/distill/README.md +++ b/configs/slim/distill/README.md @@ -1,5 +1,13 @@ # Distillation(蒸馏) +## 内容 +- [YOLOv3模型蒸馏](#YOLOv3模型蒸馏) +- [FGD模型蒸馏](#FGD模型蒸馏) +- [CWD模型蒸馏](#CWD模型蒸馏) +- [LD模型蒸馏](#LD模型蒸馏) +- [PPYOLOE模型蒸馏](#PPYOLOE模型蒸馏) +- [引用](#引用) + ## YOLOv3模型蒸馏 以YOLOv3-MobileNetV1为例,使用YOLOv3-ResNet34作为蒸馏训练的teacher网络, 对YOLOv3-MobileNetV1结构的student网络进行蒸馏。 @@ -12,6 +20,25 @@ COCO数据集作为目标检测任务的训练目标难度更大,意味着teac | YOLOv3-MobileNetV1 | student | 608 | 270e | 29.4 | [config](../../yolov3/yolov3_mobilenet_v1_270e_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/yolov3_mobilenet_v1_270e_coco.pdparams) | | YOLOv3-MobileNetV1 | distill | 608 | 270e | 31.0(+1.6) | [config](../../yolov3/yolov3_mobilenet_v1_270e_coco.yml),[slim_config](./yolov3_mobilenet_v1_coco_distill.yml) | [download](https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_coco_distill.pdparams) | +
+ 快速开始 + +```shell +# 单卡训练(不推荐) +python tools/train.py -c configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml --slim_config configs/slim/distill/yolov3_mobilenet_v1_coco_distill.yml +# 多卡训练 +python -m paddle.distributed.launch --log_dir=logs/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml --slim_config configs/slim/distill/yolov3_mobilenet_v1_coco_distill.yml +# 评估 +python tools/eval.py -c configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_coco_distill.pdparams +# 预测 +python tools/infer.py -c configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/slim/yolov3_mobilenet_v1_coco_distill.pdparams --infer_img=demo/000000014439_640x640.jpg +``` + +- `-c`: 指定模型配置文件,也是student配置文件。 +- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 + +
+ ## FGD模型蒸馏 @@ -24,16 +51,24 @@ FGD全称为[Focal and Global Knowledge Distillation for Detectors](https://arxi | RetinaNet-ResNet50 | student | 1333x800 | 2x | 39.1 | [config](../../retinanet/retinanet_r50_fpn_2x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco.pdparams) | | RetinaNet-ResNet50 | FGD | 1333x800 | 2x | 40.8(+1.7) | [config](../../retinanet/retinanet_r50_fpn_2x_coco.yml),[slim_config](./retinanet_resnet101_coco_distill.yml) | [download](https://paddledet.bj.bcebos.com/models/retinanet_r101_distill_r50_2x_coco.pdparams) | +
+ 快速开始 -## LD模型蒸馏 +```shell +# 单卡训练(不推荐) +python tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --slim_config configs/slim/distill/retinanet_resnet101_coco_distill.yml +# 多卡训练 +python -m paddle.distributed.launch --log_dir=logs/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --slim_config configs/slim/distill/retinanet_resnet101_coco_distill.yml +# 评估 +python tools/eval.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/retinanet_r101_distill_r50_2x_coco.pdparams +# 预测 +python tools/infer.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/retinanet_r101_distill_r50_2x_coco.pdparams --infer_img=demo/000000014439_640x640.jpg +``` -LD全称为[Localization Distillation for Dense Object Detection](https://arxiv.org/abs/2102.12252),将回归框表示为概率分布,把分类任务的KD用在定位任务上,并且使用因地制宜、分而治之的策略,在不同的区域分别学习分类知识与定位知识。在PaddleDetection中,我们实现了LD算法,并基于GFL模型进行验证,实验结果如下: +- `-c`: 指定模型配置文件,也是student配置文件。 +- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 -| 模型 | 方案 | 输入尺寸 | epochs | Box mAP | 配置文件 | 下载链接 | -| ----------------- | ----------- | ------ | :----: | :-----------: | :--------------: | :------------: | -| GFL_ResNet101-vd| teacher | 1333x800 | 2x | 46.8 | [config](../../gfl/gfl_r101vd_fpn_mstrain_2x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams) | -| GFL_ResNet18-vd | student | 1333x800 | 1x | 36.6 | [config](../../gfl/gfl_r18vd_1x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/gfl_r18vd_1x_coco.pdparams) | -| GFL_ResNet18-vd | LD | 1333x800 | 1x | 38.2(+1.6) | [config](../../gfl/gfl_slim_ld_r18vd_1x_coco.yml),[slim_config](./gfl_ld_distill.yml) | [download](https://bj.bcebos.com/v1/paddledet/models/gfl_slim_ld_r18vd_1x_coco.pdparams) | +
## CWD模型蒸馏 @@ -44,60 +79,104 @@ CWD全称为[Channel-wise Knowledge Distillation for Dense Prediction*](https:// | ----------------- | ----------- | ------ | :----: | :-----------: | :--------------: | :------------: | | RetinaNet-ResNet101| teacher | 1333x800 | 2x | 40.6 | [config](../../retinanet/retinanet_r101_fpn_2x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/retinanet_r101_fpn_2x_coco.pdparams) | | RetinaNet-ResNet50 | student | 1333x800 | 2x | 39.1 | [config](../../retinanet/retinanet_r50_fpn_2x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco.pdparams) | -| RetinaNet-ResNet50 | CWD | 1333x800 | 2x | 40.5(+1.4) | [config](../../retinanet/retinanet_r50_fpn_2x_coco_cwd.yml),[slim_config](./retinanet_resnet101_coco_distill_cwd.yml) | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco_cwd.pdparams) | +| RetinaNet-ResNet50 | CWD | 1333x800 | 2x | 40.5(+1.4) | [config](../../retinanet/retinanet_r50_fpn_2x_coco.yml),[slim_config](./retinanet_resnet101_coco_distill_cwd.yml) | [download](https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco_cwd.pdparams) | | GFL_ResNet101-vd| teacher | 1333x800 | 2x | 46.8 | [config](../../gfl/gfl_r101vd_fpn_mstrain_2x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams) | | GFL_ResNet50 | student | 1333x800 | 1x | 41.0 | [config](../../gfl/gfl_r50_fpn_1x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_1x_coco.pdparams) | -| GFL_ResNet50 | LD | 1333x800 | 2x | 44.0(+3.0) | [config](../../gfl/gfl_r50_fpn_2x_coco_cwd.yml),[slim_config](./gfl_r101vd_fpn_coco_distill_cwd.yml) | [download](https://bj.bcebos.com/v1/paddledet/models/gfl_r50_fpn_2x_coco_cwd.pdparams) | +| GFL_ResNet50 | CWD | 1333x800 | 2x | 44.0(+3.0) | [config](../../gfl/gfl_r50_fpn_1x_coco.yml),[slim_config](./gfl_r101vd_fpn_coco_distill_cwd.yml) | 
[download](https://bj.bcebos.com/v1/paddledet/models/gfl_r50_fpn_2x_coco_cwd.pdparams) | + +
+ 快速开始 + +```shell +# 单卡训练(不推荐) +python tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --slim_config configs/slim/distill/retinanet_resnet101_coco_distill_cwd.yml +# 多卡训练 +python -m paddle.distributed.launch --log_dir=logs/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --slim_config configs/slim/distill/retinanet_resnet101_coco_distill_cwd.yml +# 评估 +python tools/eval.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco_cwd.pdparams +# 预测 +python tools/infer.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/retinanet_r50_fpn_2x_coco_cwd.pdparams --infer_img=demo/000000014439_640x640.jpg + +# 单卡训练(不推荐) +python tools/train.py -c configs/gfl/gfl_r50_fpn_1x_coco.yml --slim_config configs/slim/distill/gfl_r101vd_fpn_coco_distill_cwd.yml +# 多卡训练 +python -m paddle.distributed.launch --log_dir=logs/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/gfl/gfl_r50_fpn_1x_coco.yml --slim_config configs/slim/distill/gfl_r101vd_fpn_coco_distill_cwd.yml +# 评估 +python tools/eval.py -c configs/gfl/gfl_r50_fpn_1x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams +# 预测 +python tools/infer.py -c configs/gfl/gfl_r50_fpn_1x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/gfl_r50_fpn_2x_coco_cwd.pdparams --infer_img=demo/000000014439_640x640.jpg +``` +- `-c`: 指定模型配置文件,也是student配置文件。 +- `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 -## PPYOLOE+ 模型蒸馏 +
-PaddleDetection提供了对PPYOLOE+ 进行模型蒸馏的方案,结合了logits蒸馏和feature蒸馏。 + +## LD模型蒸馏 + +LD全称为[Localization Distillation for Dense Object Detection](https://arxiv.org/abs/2102.12252),将回归框表示为概率分布,把分类任务的KD用在定位任务上,并且使用因地制宜、分而治之的策略,在不同的区域分别学习分类知识与定位知识。在PaddleDetection中,我们实现了LD算法,并基于GFL模型进行验证,实验结果如下: | 模型 | 方案 | 输入尺寸 | epochs | Box mAP | 配置文件 | 下载链接 | | ----------------- | ----------- | ------ | :----: | :-----------: | :--------------: | :------------: | -| PP-YOLOE+_x | teacher | 640 | 80e | 54.7 | [config](../../ppyoloe/ppyoloe_plus_crn_x_80e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_x_80e_coco.pdparams) | -| PP-YOLOE+_l | student | 640 | 80e | 52.9 | [config](../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_l_80e_coco.pdparams) | -| PP-YOLOE+_l | distill | 640 | 80e | 53.9(+1.0) | [config](../../ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml),[slim_config](./ppyoloe_plus_distill_x_distill_l.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_l_80e_coco_distill.pdparams) | -| PP-YOLOE+_l | teacher | 640 | 80e | 52.9 | [config](../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_l_80e_coco.pdparams) | -| PP-YOLOE+_m | student | 640 | 80e | 49.8 | [config](../../ppyoloe/ppyoloe_plus_crn_m_80e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_m_80e_coco.pdparams) | -| PP-YOLOE+_m | distill | 640 | 80e | 50.7(+0.9) | [config](../../ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml),[slim_config](./ppyoloe_plus_distill_l_distill_m.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_m_80e_coco_distill.pdparams) | - +| GFL_ResNet101-vd| teacher | 1333x800 | 2x | 46.8 | [config](../../gfl/gfl_r101vd_fpn_mstrain_2x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/gfl_r101vd_fpn_mstrain_2x_coco.pdparams) | +| 
GFL_ResNet18-vd | student | 1333x800 | 1x | 36.6 | [config](../../gfl/gfl_r18vd_1x_coco.yml) | [download](https://paddledet.bj.bcebos.com/models/gfl_r18vd_1x_coco.pdparams) | +| GFL_ResNet18-vd | LD | 1333x800 | 1x | 38.2(+1.6) | [config](../../gfl/gfl_slim_ld_r18vd_1x_coco.yml),[slim_config](./gfl_ld_distill.yml) | [download](https://bj.bcebos.com/v1/paddledet/models/gfl_slim_ld_r18vd_1x_coco.pdparams) | -## 快速开始 +
+ 快速开始 -### 训练 ```shell -# 单卡 -python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_distill_l.yml -# 多卡 -python -m paddle.distributed.launch --log_dir=ppyoloe_plus_distill_x_distill_l/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_distill_l.yml +# 单卡训练(不推荐) +python tools/train.py -c configs/gfl/gfl_slim_ld_r18vd_1x_coco.yml --slim_config configs/slim/distill/gfl_ld_distill.yml +# 多卡训练 +python -m paddle.distributed.launch --log_dir=logs/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/gfl/gfl_slim_ld_r18vd_1x_coco.yml --slim_config configs/slim/distill/gfl_ld_distill.yml +# 评估 +python tools/eval.py -c configs/gfl/gfl_slim_ld_r18vd_1x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/gfl_slim_ld_r18vd_1x_coco.pdparams +# 预测 +python tools/infer.py -c configs/gfl/gfl_slim_ld_r18vd_1x_coco.yml -o weights=https://paddledet.bj.bcebos.com/models/gfl_slim_ld_r18vd_1x_coco.pdparams --infer_img=demo/000000014439_640x640.jpg ``` - `-c`: 指定模型配置文件,也是student配置文件。 - `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 -### 评估 +
+ + +## PPYOLOE模型蒸馏 + +PaddleDetection提供了对PPYOLOE+ 进行模型蒸馏的方案,结合了logits蒸馏和feature蒸馏。 + +| 模型 | 方案 | 输入尺寸 | epochs | Box mAP | 配置文件 | 下载链接 | +| ----------------- | ----------- | ------ | :----: | :-----------: | :--------------: | :------------: | +| PP-YOLOE+_x | teacher | 640 | 80e | 54.7 | [config](../../ppyoloe/ppyoloe_plus_crn_x_80e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_x_80e_coco.pdparams) | +| PP-YOLOE+_l | student | 640 | 80e | 52.9 | [config](../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_l_80e_coco.pdparams) | +| PP-YOLOE+_l | distill | 640 | 80e | **54.0(+1.1)** | [config](../../ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml),[slim_config](./ppyoloe_plus_distill_x_distill_l.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_l_80e_coco_distill.pdparams) | +| PP-YOLOE+_l | teacher | 640 | 80e | 52.9 | [config](../../ppyoloe/ppyoloe_plus_crn_l_80e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_l_80e_coco.pdparams) | +| PP-YOLOE+_m | student | 640 | 80e | 49.8 | [config](../../ppyoloe/ppyoloe_plus_crn_m_80e_coco.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_m_80e_coco.pdparams) | +| PP-YOLOE+_m | distill | 640 | 80e | **51.0(+1.2)** | [config](../../ppyoloe/distill/ppyoloe_plus_crn_m_80e_coco_distill.yml),[slim_config](./ppyoloe_plus_distill_l_distill_m.yml) | [model](https://bj.bcebos.com/v1/paddledet/models/ppyoloe_plus_crn_m_80e_coco_distill.pdparams) | + +
+ 快速开始 + ```shell -python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams +# 单卡训练(不推荐) +python tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_distill_l.yml +# 多卡训练 +python -m paddle.distributed.launch --log_dir=logs/ --gpus 0,1,2,3,4,5,6,7 tools/train.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml --slim_config configs/slim/distill/ppyoloe_plus_distill_x_distill_l.yml +# 评估 +python tools/eval.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco_distill.pdparams +# 预测 +python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_plus_crn_l_80e_coco_distill.pdparams --infer_img=demo/000000014439_640x640.jpg ``` - `-c`: 指定模型配置文件,也是student配置文件。 - `--slim_config`: 指定压缩策略配置文件,也是teacher配置文件。 -- `-o weights`: 指定压缩算法训好的模型路径。 - -### 测试 -```shell -python tools/infer.py -c configs/ppyoloe/distill/ppyoloe_plus_crn_l_80e_coco_distill.yml -o weights=output/ppyoloe_plus_crn_l_80e_coco_distill/model_final.pdparams --infer_img=demo/000000014439_640x640.jpg -``` -- `-c`: 指定模型配置文件。 -- `--slim_config`: 指定压缩策略配置文件。 -- `-o weights`: 指定压缩算法训好的模型路径。 -- `--infer_img`: 指定测试图像路径。 +
-## Citations +## 引用 ``` @article{mehta2018object, title={Object detection at 200 Frames Per Second}, diff --git a/ppdet/modeling/heads/__init__.py b/ppdet/modeling/heads/__init__.py index ecd15b2f139e1244e908cfec59592dafcb1f3ec4..9cceb268a7d32b8746b9c3995323b01c60cae99a 100644 --- a/ppdet/modeling/heads/__init__.py +++ b/ppdet/modeling/heads/__init__.py @@ -35,7 +35,6 @@ from . import retina_head from . import ppyoloe_head from . import fcosr_head from . import ppyoloe_r_head -from . import ld_gfl_head from . import yolof_head from . import ppyoloe_contrast_head from . import centertrack_head @@ -63,7 +62,6 @@ from .tood_head import * from .retina_head import * from .ppyoloe_head import * from .fcosr_head import * -from .ld_gfl_head import * from .ppyoloe_r_head import * from .yolof_head import * from .ppyoloe_contrast_head import * diff --git a/ppdet/modeling/heads/gfl_head.py b/ppdet/modeling/heads/gfl_head.py index a1f518da5cf6a16208802816115372f00dd8a15f..040a3f7090d4a82ed5f4641967ceae1c0349d3fb 100644 --- a/ppdet/modeling/heads/gfl_head.py +++ b/ppdet/modeling/heads/gfl_head.py @@ -28,10 +28,11 @@ from paddle import ParamAttr from paddle.nn.initializer import Normal, Constant from ppdet.core.workspace import register -from ppdet.modeling.layers import ConvNormLayer from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance, batch_distance2bbox from ppdet.data.transform.atss_assigner import bbox_overlaps +__all__ = ['GFLHead', 'LDGFLHead'] + class ScaleReg(nn.Layer): """ @@ -437,3 +438,299 @@ class GFLHead(nn.Layer): mlvl_scores = mlvl_scores.transpose([0, 2, 1]) bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores) return bbox_pred, bbox_num + + +@register +class LDGFLHead(GFLHead): + """ + GFLHead for LD distill + Args: + conv_feat (object): Instance of 'FCOSFeat' + num_classes (int): Number of classes + fpn_stride (list): The stride of each FPN Layer + prior_prob (float): Used to set the bias init for the class prediction layer + loss_class 
(object): Instance of QualityFocalLoss. + loss_dfl (object): Instance of DistributionFocalLoss. + loss_bbox (object): Instance of bbox loss. + reg_max: Max value of integral set :math: `{0, ..., reg_max}` + in QFL setting. Default: 16. + """ + __inject__ = [ + 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', + 'loss_ld', 'loss_ld_vlr', 'loss_kd', 'nms' + ] + __shared__ = ['num_classes'] + + def __init__(self, + conv_feat='FCOSFeat', + dgqp_module=None, + num_classes=80, + fpn_stride=[8, 16, 32, 64, 128], + prior_prob=0.01, + loss_class='QualityFocalLoss', + loss_dfl='DistributionFocalLoss', + loss_bbox='GIoULoss', + loss_ld='KnowledgeDistillationKLDivLoss', + loss_ld_vlr='KnowledgeDistillationKLDivLoss', + loss_kd='KnowledgeDistillationKLDivLoss', + reg_max=16, + feat_in_chan=256, + nms=None, + nms_pre=1000, + cell_offset=0): + + super(LDGFLHead, self).__init__( + conv_feat=conv_feat, + dgqp_module=dgqp_module, + num_classes=num_classes, + fpn_stride=fpn_stride, + prior_prob=prior_prob, + loss_class=loss_class, + loss_dfl=loss_dfl, + loss_bbox=loss_bbox, + reg_max=reg_max, + feat_in_chan=feat_in_chan, + nms=nms, + nms_pre=nms_pre, + cell_offset=cell_offset) + self.loss_ld = loss_ld + self.loss_kd = loss_kd + self.loss_ld_vlr = loss_ld_vlr + + def forward(self, fpn_feats): + assert len(fpn_feats) == len( + self.fpn_stride + ), "The size of fpn_feats is not equal to size of fpn_stride" + cls_logits_list = [] + bboxes_reg_list = [] + for stride, scale_reg, fpn_feat in zip(self.fpn_stride, + self.scales_regs, fpn_feats): + conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat) + cls_score = self.gfl_head_cls(conv_cls_feat) + bbox_pred = scale_reg(self.gfl_head_reg(conv_reg_feat)) + + if self.dgqp_module: + quality_score = self.dgqp_module(bbox_pred) + cls_score = F.sigmoid(cls_score) * quality_score + if not self.training: + cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1])) + bbox_pred = bbox_pred.transpose([0, 2, 3, 1]) + b, cell_h, cell_w, _ 
= paddle.shape(cls_score) + y, x = self.get_single_level_center_point( + [cell_h, cell_w], stride, cell_offset=self.cell_offset) + center_points = paddle.stack([x, y], axis=-1) + cls_score = cls_score.reshape([b, -1, self.cls_out_channels]) + bbox_pred = self.distribution_project(bbox_pred) * stride + bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4]) + + # NOTE: If keep_ratio=False and image shape value that + # multiples of 32, distance2bbox not set max_shapes parameter + # to speed up model prediction. If need to set max_shapes, + # please use inputs['im_shape']. + bbox_pred = batch_distance2bbox( + center_points, bbox_pred, max_shapes=None) + + cls_logits_list.append(cls_score) + bboxes_reg_list.append(bbox_pred) + + return (cls_logits_list, bboxes_reg_list) + + def get_loss(self, gfl_head_outs, gt_meta, soft_label_list, + soft_targets_list): + cls_logits, bboxes_reg = gfl_head_outs + + num_level_anchors = [ + featmap.shape[-2] * featmap.shape[-1] for featmap in cls_logits + ] + + grid_cells_list = self._images_to_levels(gt_meta['grid_cells'], + num_level_anchors) + + labels_list = self._images_to_levels(gt_meta['labels'], + num_level_anchors) + + label_weights_list = self._images_to_levels(gt_meta['label_weights'], + num_level_anchors) + bbox_targets_list = self._images_to_levels(gt_meta['bbox_targets'], + num_level_anchors) + # vlr regions + vlr_regions_list = self._images_to_levels(gt_meta['vlr_regions'], + num_level_anchors) + + num_total_pos = sum(gt_meta['pos_num']) + try: + paddle.distributed.all_reduce(num_total_pos) + num_total_pos = paddle.clip( + num_total_pos / paddle.distributed.get_world_size(), min=1.) 
+ except: + num_total_pos = max(num_total_pos, 1) + + loss_bbox_list, loss_dfl_list, loss_qfl_list, loss_ld_list, avg_factor = [], [], [], [], [] + loss_ld_vlr_list, loss_kd_list = [], [] + + for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride, soft_targets,\ + soft_label, vlr_region in zip( + cls_logits, bboxes_reg, grid_cells_list, labels_list, + label_weights_list, bbox_targets_list, self.fpn_stride, soft_targets_list, + soft_label_list, vlr_regions_list): + + grid_cells = grid_cells.reshape([-1, 4]) + cls_score = cls_score.transpose([0, 2, 3, 1]).reshape( + [-1, self.cls_out_channels]) + bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( + [-1, 4 * (self.reg_max + 1)]) + + soft_targets = soft_targets.transpose([0, 2, 3, 1]).reshape( + [-1, 4 * (self.reg_max + 1)]) + + soft_label = soft_label.transpose([0, 2, 3, 1]).reshape( + [-1, self.cls_out_channels]) + + # feature im + # teacher_x = teacher_x.transpose([0, 2, 3, 1]).reshape([-1, 256]) + # x = x.transpose([0, 2, 3, 1]).reshape([-1, 256]) + + bbox_targets = bbox_targets.reshape([-1, 4]) + labels = labels.reshape([-1]) + label_weights = label_weights.reshape([-1]) + + vlr_region = vlr_region.reshape([-1]) + + bg_class_ind = self.num_classes + pos_inds = paddle.nonzero( + paddle.logical_and((labels >= 0), (labels < bg_class_ind)), + as_tuple=False).squeeze(1) + score = np.zeros(labels.shape) + + remain_inds = (vlr_region > 0).nonzero() + + if len(pos_inds) > 0: + pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0) + pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0) + pos_grid_cells = paddle.gather(grid_cells, pos_inds, axis=0) + + pos_grid_cell_centers = self._grid_cells_to_center( + pos_grid_cells) / stride + + weight_targets = F.sigmoid(cls_score.detach()) + weight_targets = paddle.gather( + weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0) + pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred) + pos_decode_bbox_pred = 
distance2bbox(pos_grid_cell_centers, + pos_bbox_pred_corners) + pos_decode_bbox_targets = pos_bbox_targets / stride + bbox_iou = bbox_overlaps( + pos_decode_bbox_pred.detach().numpy(), + pos_decode_bbox_targets.detach().numpy(), + is_aligned=True) + score[pos_inds.numpy()] = bbox_iou + pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1]) + + pos_soft_targets = paddle.gather(soft_targets, pos_inds, axis=0) + soft_corners = pos_soft_targets.reshape([-1, self.reg_max + 1]) + + target_corners = bbox2distance(pos_grid_cell_centers, + pos_decode_bbox_targets, + self.reg_max).reshape([-1]) + # regression loss + loss_bbox = paddle.sum( + self.loss_bbox(pos_decode_bbox_pred, + pos_decode_bbox_targets) * weight_targets) + + # dfl loss + loss_dfl = self.loss_dfl( + pred_corners, + target_corners, + weight=weight_targets.expand([-1, 4]).reshape([-1]), + avg_factor=4.0) + + # ld loss + loss_ld = self.loss_ld( + pred_corners, + soft_corners, + weight=weight_targets.expand([-1, 4]).reshape([-1]), + avg_factor=4.0) + + loss_kd = self.loss_kd( + paddle.gather( + cls_score, pos_inds, axis=0), + paddle.gather( + soft_label, pos_inds, axis=0), + weight=paddle.gather( + label_weights, pos_inds, axis=0), + avg_factor=pos_inds.shape[0]) + + else: + loss_bbox = bbox_pred.sum() * 0 + loss_dfl = bbox_pred.sum() * 0 + loss_ld = bbox_pred.sum() * 0 + loss_kd = bbox_pred.sum() * 0 + weight_targets = paddle.to_tensor([0], dtype='float32') + + if len(remain_inds) > 0: + neg_pred_corners = bbox_pred[remain_inds].reshape( + [-1, self.reg_max + 1]) + neg_soft_corners = soft_targets[remain_inds].reshape( + [-1, self.reg_max + 1]) + + remain_targets = vlr_region[remain_inds] + + loss_ld_vlr = self.loss_ld_vlr( + neg_pred_corners, + neg_soft_corners, + weight=remain_targets.expand([-1, 4]).reshape([-1]), + avg_factor=16.0) + else: + loss_ld_vlr = bbox_pred.sum() * 0 + + # qfl loss + score = paddle.to_tensor(score) + loss_qfl = self.loss_qfl( + cls_score, (labels, score), + 
weight=label_weights, + avg_factor=num_total_pos) + + loss_bbox_list.append(loss_bbox) + loss_dfl_list.append(loss_dfl) + loss_qfl_list.append(loss_qfl) + loss_ld_list.append(loss_ld) + loss_ld_vlr_list.append(loss_ld_vlr) + loss_kd_list.append(loss_kd) + avg_factor.append(weight_targets.sum()) + + avg_factor = sum(avg_factor) # + 1e-6 + try: + paddle.distributed.all_reduce(avg_factor) + avg_factor = paddle.clip( + avg_factor / paddle.distributed.get_world_size(), min=1) + except: + avg_factor = max(avg_factor.item(), 1) + + if avg_factor <= 0: + loss_qfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + loss_bbox = paddle.to_tensor( + 0, dtype='float32', stop_gradient=False) + loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + loss_ld = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + loss_ld_vlr = paddle.to_tensor( + 0, dtype='float32', stop_gradient=False) + loss_kd = paddle.to_tensor(0, dtype='float32', stop_gradient=False) + else: + losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list)) + losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list)) + loss_qfl = sum(loss_qfl_list) + loss_bbox = sum(losses_bbox) + loss_dfl = sum(losses_dfl) + loss_ld = sum(loss_ld_list) + loss_ld_vlr = sum(loss_ld_vlr_list) + loss_kd = sum(loss_kd_list) + + loss_states = dict( + loss_qfl=loss_qfl, + loss_bbox=loss_bbox, + loss_dfl=loss_dfl, + loss_ld=loss_ld, + loss_ld_vlr=loss_ld_vlr, + loss_kd=loss_kd) + + return loss_states diff --git a/ppdet/modeling/heads/ld_gfl_head.py b/ppdet/modeling/heads/ld_gfl_head.py deleted file mode 100644 index dbff7ecbab0268e741da1e687751e07acc05d0fc..0000000000000000000000000000000000000000 --- a/ppdet/modeling/heads/ld_gfl_head.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# The code is based on: -# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/ld_head.py - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math -import numpy as np -import paddle -import paddle.nn as nn -import paddle.nn.functional as F -from paddle import ParamAttr -from paddle.nn.initializer import Normal, Constant - -from ppdet.core.workspace import register, serializable -from ppdet.modeling.layers import ConvNormLayer -from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance, batch_distance2bbox -from ppdet.data.transform.atss_assigner import bbox_overlaps -from .gfl_head import GFLHead - - -@register -class LDGFLHead(GFLHead): - """ - GFLHead for LD distill - Args: - conv_feat (object): Instance of 'FCOSFeat' - num_classes (int): Number of classes - fpn_stride (list): The stride of each FPN Layer - prior_prob (float): Used to set the bias init for the class prediction layer - loss_class (object): Instance of QualityFocalLoss. - loss_dfl (object): Instance of DistributionFocalLoss. - loss_bbox (object): Instance of bbox loss. - reg_max: Max value of integral set :math: `{0, ..., reg_max}` - n QFL setting. Default: 16. 
- """ - __inject__ = [ - 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', - 'loss_ld', 'loss_ld_vlr', 'loss_kd', 'nms' - ] - __shared__ = ['num_classes'] - - def __init__(self, - conv_feat='FCOSFeat', - dgqp_module=None, - num_classes=80, - fpn_stride=[8, 16, 32, 64, 128], - prior_prob=0.01, - loss_class='QualityFocalLoss', - loss_dfl='DistributionFocalLoss', - loss_bbox='GIoULoss', - loss_ld='KnowledgeDistillationKLDivLoss', - loss_ld_vlr='KnowledgeDistillationKLDivLoss', - loss_kd='KnowledgeDistillationKLDivLoss', - reg_max=16, - feat_in_chan=256, - nms=None, - nms_pre=1000, - cell_offset=0): - - super(LDGFLHead, self).__init__( - conv_feat=conv_feat, - dgqp_module=dgqp_module, - num_classes=num_classes, - fpn_stride=fpn_stride, - prior_prob=prior_prob, - loss_class=loss_class, - loss_dfl=loss_dfl, - loss_bbox=loss_bbox, - reg_max=reg_max, - feat_in_chan=feat_in_chan, - nms=nms, - nms_pre=nms_pre, - cell_offset=cell_offset) - self.loss_ld = loss_ld - self.loss_kd = loss_kd - self.loss_ld_vlr = loss_ld_vlr - - def forward(self, fpn_feats): - assert len(fpn_feats) == len( - self.fpn_stride - ), "The size of fpn_feats is not equal to size of fpn_stride" - cls_logits_list = [] - bboxes_reg_list = [] - for stride, scale_reg, fpn_feat in zip(self.fpn_stride, - self.scales_regs, fpn_feats): - conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat) - cls_score = self.gfl_head_cls(conv_cls_feat) - bbox_pred = scale_reg(self.gfl_head_reg(conv_reg_feat)) - - if self.dgqp_module: - quality_score = self.dgqp_module(bbox_pred) - cls_score = F.sigmoid(cls_score) * quality_score - if not self.training: - cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1])) - bbox_pred = bbox_pred.transpose([0, 2, 3, 1]) - b, cell_h, cell_w, _ = paddle.shape(cls_score) - y, x = self.get_single_level_center_point( - [cell_h, cell_w], stride, cell_offset=self.cell_offset) - center_points = paddle.stack([x, y], axis=-1) - cls_score = cls_score.reshape([b, -1, 
self.cls_out_channels]) - bbox_pred = self.distribution_project(bbox_pred) * stride - bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4]) - - # NOTE: If keep_ratio=False and image shape value that - # multiples of 32, distance2bbox not set max_shapes parameter - # to speed up model prediction. If need to set max_shapes, - # please use inputs['im_shape']. - bbox_pred = batch_distance2bbox( - center_points, bbox_pred, max_shapes=None) - - cls_logits_list.append(cls_score) - bboxes_reg_list.append(bbox_pred) - - return (cls_logits_list, bboxes_reg_list) - - def get_loss(self, gfl_head_outs, gt_meta, soft_label_list, - soft_targets_list): - cls_logits, bboxes_reg = gfl_head_outs - - num_level_anchors = [ - featmap.shape[-2] * featmap.shape[-1] for featmap in cls_logits - ] - - grid_cells_list = self._images_to_levels(gt_meta['grid_cells'], - num_level_anchors) - - labels_list = self._images_to_levels(gt_meta['labels'], - num_level_anchors) - - label_weights_list = self._images_to_levels(gt_meta['label_weights'], - num_level_anchors) - bbox_targets_list = self._images_to_levels(gt_meta['bbox_targets'], - num_level_anchors) - # vlr regions - vlr_regions_list = self._images_to_levels(gt_meta['vlr_regions'], - num_level_anchors) - - num_total_pos = sum(gt_meta['pos_num']) - try: - paddle.distributed.all_reduce(num_total_pos) - num_total_pos = paddle.clip( - num_total_pos / paddle.distributed.get_world_size(), min=1.) 
- except: - num_total_pos = max(num_total_pos, 1) - - loss_bbox_list, loss_dfl_list, loss_qfl_list, loss_ld_list, avg_factor = [], [], [], [], [] - loss_ld_vlr_list, loss_kd_list = [], [] - - for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride, soft_targets,\ - soft_label, vlr_region in zip( - cls_logits, bboxes_reg, grid_cells_list, labels_list, - label_weights_list, bbox_targets_list, self.fpn_stride, soft_targets_list, - soft_label_list, vlr_regions_list): - - grid_cells = grid_cells.reshape([-1, 4]) - cls_score = cls_score.transpose([0, 2, 3, 1]).reshape( - [-1, self.cls_out_channels]) - bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape( - [-1, 4 * (self.reg_max + 1)]) - - soft_targets = soft_targets.transpose([0, 2, 3, 1]).reshape( - [-1, 4 * (self.reg_max + 1)]) - - soft_label = soft_label.transpose([0, 2, 3, 1]).reshape( - [-1, self.cls_out_channels]) - - # feture im - # teacher_x = teacher_x.transpose([0, 2, 3, 1]).reshape([-1, 256]) - # x = x.transpose([0, 2, 3, 1]).reshape([-1, 256]) - - bbox_targets = bbox_targets.reshape([-1, 4]) - labels = labels.reshape([-1]) - label_weights = label_weights.reshape([-1]) - - vlr_region = vlr_region.reshape([-1]) - - bg_class_ind = self.num_classes - pos_inds = paddle.nonzero( - paddle.logical_and((labels >= 0), (labels < bg_class_ind)), - as_tuple=False).squeeze(1) - score = np.zeros(labels.shape) - - remain_inds = (vlr_region > 0).nonzero() - - if len(pos_inds) > 0: - pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0) - pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0) - pos_grid_cells = paddle.gather(grid_cells, pos_inds, axis=0) - - pos_grid_cell_centers = self._grid_cells_to_center( - pos_grid_cells) / stride - - weight_targets = F.sigmoid(cls_score.detach()) - weight_targets = paddle.gather( - weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0) - pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred) - pos_decode_bbox_pred = 
distance2bbox(pos_grid_cell_centers, - pos_bbox_pred_corners) - pos_decode_bbox_targets = pos_bbox_targets / stride - bbox_iou = bbox_overlaps( - pos_decode_bbox_pred.detach().numpy(), - pos_decode_bbox_targets.detach().numpy(), - is_aligned=True) - score[pos_inds.numpy()] = bbox_iou - pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1]) - - pos_soft_targets = paddle.gather(soft_targets, pos_inds, axis=0) - soft_corners = pos_soft_targets.reshape([-1, self.reg_max + 1]) - - target_corners = bbox2distance(pos_grid_cell_centers, - pos_decode_bbox_targets, - self.reg_max).reshape([-1]) - # regression loss - loss_bbox = paddle.sum( - self.loss_bbox(pos_decode_bbox_pred, - pos_decode_bbox_targets) * weight_targets) - - # dfl loss - loss_dfl = self.loss_dfl( - pred_corners, - target_corners, - weight=weight_targets.expand([-1, 4]).reshape([-1]), - avg_factor=4.0) - - # ld loss - loss_ld = self.loss_ld( - pred_corners, - soft_corners, - weight=weight_targets.expand([-1, 4]).reshape([-1]), - avg_factor=4.0) - - loss_kd = self.loss_kd( - paddle.gather( - cls_score, pos_inds, axis=0), - paddle.gather( - soft_label, pos_inds, axis=0), - weight=paddle.gather( - label_weights, pos_inds, axis=0), - avg_factor=pos_inds.shape[0]) - - else: - loss_bbox = bbox_pred.sum() * 0 - loss_dfl = bbox_pred.sum() * 0 - loss_ld = bbox_pred.sum() * 0 - loss_kd = bbox_pred.sum() * 0 - weight_targets = paddle.to_tensor([0], dtype='float32') - - if len(remain_inds) > 0: - neg_pred_corners = bbox_pred[remain_inds].reshape( - [-1, self.reg_max + 1]) - neg_soft_corners = soft_targets[remain_inds].reshape( - [-1, self.reg_max + 1]) - - remain_targets = vlr_region[remain_inds] - - loss_ld_vlr = self.loss_ld_vlr( - neg_pred_corners, - neg_soft_corners, - weight=remain_targets.expand([-1, 4]).reshape([-1]), - avg_factor=16.0) - else: - loss_ld_vlr = bbox_pred.sum() * 0 - - # qfl loss - score = paddle.to_tensor(score) - loss_qfl = self.loss_qfl( - cls_score, (labels, score), - 
weight=label_weights, - avg_factor=num_total_pos) - - loss_bbox_list.append(loss_bbox) - loss_dfl_list.append(loss_dfl) - loss_qfl_list.append(loss_qfl) - loss_ld_list.append(loss_ld) - loss_ld_vlr_list.append(loss_ld_vlr) - loss_kd_list.append(loss_kd) - avg_factor.append(weight_targets.sum()) - - avg_factor = sum(avg_factor) # + 1e-6 - try: - paddle.distributed.all_reduce(avg_factor) - avg_factor = paddle.clip( - avg_factor / paddle.distributed.get_world_size(), min=1) - except: - avg_factor = max(avg_factor.item(), 1) - - if avg_factor <= 0: - loss_qfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) - loss_bbox = paddle.to_tensor( - 0, dtype='float32', stop_gradient=False) - loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False) - loss_ld = paddle.to_tensor(0, dtype='float32', stop_gradient=False) - loss_ld_vlr = paddle.to_tensor( - 0, dtype='float32', stop_gradient=False) - loss_kd = paddle.to_tensor(0, dtype='float32', stop_gradient=False) - else: - losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list)) - losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list)) - loss_qfl = sum(loss_qfl_list) - loss_bbox = sum(losses_bbox) - loss_dfl = sum(losses_dfl) - loss_ld = sum(loss_ld_list) - loss_ld_vlr = sum(loss_ld_vlr_list) - loss_kd = sum(loss_kd_list) - - loss_states = dict( - loss_qfl=loss_qfl, - loss_bbox=loss_bbox, - loss_dfl=loss_dfl, - loss_ld=loss_ld, - loss_ld_vlr=loss_ld_vlr, - loss_kd=loss_kd) - - return loss_states diff --git a/ppdet/slim/distill_loss.py b/ppdet/slim/distill_loss.py index 6e94fd8410e351fcaa2719fc1b6e78868905c888..d325a5b2ac93983256bf8c07b165354f0b4ffd98 100644 --- a/ppdet/slim/distill_loss.py +++ b/ppdet/slim/distill_loss.py @@ -17,14 +17,12 @@ from __future__ import division from __future__ import print_function import math -import numpy as np - import paddle import paddle.nn as nn import paddle.nn.functional as F from paddle import ParamAttr -from ppdet.core.workspace import register, create 
+from ppdet.core.workspace import register from ppdet.modeling import ops from ppdet.modeling.losses.iou_loss import GIoULoss from ppdet.utils.logger import setup_logger @@ -456,7 +454,7 @@ class CWDFeatureLoss(nn.Layer): x /= tau return F.softmax(x, axis=1) - def forward(self, preds_s, preds_t, inputs): + def forward(self, preds_s, preds_t, inputs=None): assert preds_s.shape[-2:] == preds_t.shape[-2:] N, C, H, W = preds_s.shape eps = 1e-5 @@ -676,7 +674,7 @@ class FGDFeatureLoss(nn.Layer): wmin, wmax, hmin, hmax = [], [], [], [] - if gt_bboxes.shape[1] == 0: + if len(gt_bboxes) == 0: loss = self.relation_loss(stu_feature, tea_feature) return self.lambda_fgd * loss @@ -750,7 +748,7 @@ class PKDFeatureLoss(nn.Layer): self.loss_weight = loss_weight self.resize_stu = resize_stu - def forward(self, stu_feature, tea_feature, inputs): + def forward(self, stu_feature, tea_feature, inputs=None): size_s, size_t = stu_feature.shape[2:], tea_feature.shape[2:] if size_s[0] != size_t[0]: if self.resize_stu: @@ -791,7 +789,7 @@ class MimicFeatureLoss(nn.Layer): else: self.align = None - def forward(self, stu_feature, tea_feature, inputs): + def forward(self, stu_feature, tea_feature, inputs=None): if self.align is not None: stu_feature = self.align(stu_feature) @@ -839,7 +837,7 @@ class MGDFeatureLoss(nn.Layer): nn.Conv2D( teacher_channels, teacher_channels, kernel_size=3, padding=1)) - def forward(self, stu_feature, tea_feature, inputs): + def forward(self, stu_feature, tea_feature, inputs=None): N = stu_feature.shape[0] if self.align is not None: stu_feature = self.align(stu_feature) diff --git a/ppdet/slim/distill_model.py b/ppdet/slim/distill_model.py index c06f92f08e3ade74f780fae36f623c6e45f45a5b..96e1366381308aab6ee5e3c46d5b4d378da0783c 100644 --- a/ppdet/slim/distill_model.py +++ b/ppdet/slim/distill_model.py @@ -18,8 +18,6 @@ from __future__ import print_function import paddle import paddle.nn as nn -import paddle.nn.functional as F -from paddle import ParamAttr from 
ppdet.core.workspace import register, create, load_config from ppdet.utils.checkpoint import load_pretrain_weight @@ -206,13 +204,13 @@ class CWDDistillModel(DistillModel): def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs): loss = self.student_model.head(stu_fea_list, inputs) - distill_loss = {} - for idx, k in enumerate(self.loss_dic): - distill_loss[k] = self.loss_dic[k](stu_fea_list[idx], - tea_fea_list[idx]) + loss_dict = {} + for idx, k in enumerate(self.distill_loss): + loss_dict[k] = self.distill_loss[k](stu_fea_list[idx], + tea_fea_list[idx]) - loss['loss'] += distill_loss[k] - loss[k] = distill_loss[k] + loss['loss'] += loss_dict[k] + loss[k] = loss_dict[k] return loss def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs): @@ -234,10 +232,11 @@ class CWDDistillModel(DistillModel): s_cls_feat.append(cls_score) t_cls_feat.append(t_cls_score) - for idx, k in enumerate(self.loss_dic): - loss_dict[k] = self.loss_dic[k](s_cls_feat[idx], t_cls_feat[idx]) - feat_loss[f"neck_f_{idx}"] = self.loss_dic[k](stu_fea_list[idx], - tea_fea_list[idx]) + for idx, k in enumerate(self.distill_loss): + loss_dict[k] = self.distill_loss[k](s_cls_feat[idx], + t_cls_feat[idx]) + feat_loss[f"neck_f_{idx}"] = self.distill_loss[k](stu_fea_list[idx], + tea_fea_list[idx]) for k in feat_loss: loss['loss'] += feat_loss[k]