diff --git a/docs/tutorials/GradCAM_cn.md b/docs/tutorials/GradCAM_cn.md
index e5a28d142e11c95ef4ee73e2a31b52a06f31464d..e214e48b18c982b0e9e3752e14bacd910616d36f 100644
--- a/docs/tutorials/GradCAM_cn.md
+++ b/docs/tutorials/GradCAM_cn.md
@@ -2,7 +2,7 @@
## 1.简介
-基于backbone/roi特征图计算物体预测框的cam(类激活图)
+基于backbone/roi特征图计算物体预测框的cam(类激活图), 目前支持基于FasterRCNN/MaskRCNN系列, PPYOLOE系列, 以及BlazeFace, SSD, Retinanet网络。
## 2.使用方法
* 以PP-YOLOE为例,准备好数据之后,指定网络配置文件、模型权重地址和图片路径以及输出文件夹路径,使用脚本调用tools/cam_ppdet.py计算图片中物体预测框的grad_cam热力图。下面为运行脚本示例。
@@ -27,7 +27,7 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
cam_ppyoloe/225.jpg
-## 3. 目前支持基于FasterRCNN/MaskRCNN, PPYOLOE系列以及BlazeFace, SSD, Retinanet网络。
+## 3. 目前支持基于FasterRCNN/MaskRCNN系列, PPYOLOE系列以及BlazeFace, SSD, Retinanet网络。
* PPYOLOE网络热图可视化脚本
```bash
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
diff --git a/docs/tutorials/GradCAM_en.md b/docs/tutorials/GradCAM_en.md
index 29487592ccc7bf3a2b3c8fbd82424851563be33a..e4a5fd40ad2efec25d116460f57bcde32b3c1aa1 100644
--- a/docs/tutorials/GradCAM_en.md
+++ b/docs/tutorials/GradCAM_en.md
@@ -1,7 +1,7 @@
# Object detection grad_cam heatmap
## 1.Introduction
-Calculate the cam (class activation map) of the object predict bbox based on the backbone/roi feature map
+Calculate the cam (class activation map) of the object predict bbox based on the backbone/roi feature map, currently supports networks based on FasterRCNN/MaskRCNN series, PPYOLOE series and BlazeFace, SSD, Retinanet.
## 2.Usage
* Taking PP-YOLOE as an example, after preparing the data, specify the network configuration file, model weight address, image path and output folder path, and then use the script to call tools/cam_ppdet.py to calculate the grad_cam heat map of the prediction box. Below is an example run script.
@@ -27,7 +27,7 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
cam_ppyoloe/225.jpg
-## 3.Currently supports networks based on FasterRCNN/MaskRCNN, PPYOLOE series and BlazeFace, SSD, Retinanet.
+## 3.Currently supports networks based on FasterRCNN/MaskRCNN series, PPYOLOE series and BlazeFace, SSD, Retinanet.
* PPYOLOE bbox heat map visualization script (with backbone featuremap)
```bash
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
diff --git a/ppdet/utils/cam_utils.py b/ppdet/utils/cam_utils.py
index d98350e2c3dda5ec86aa2e83a0d1b0ee0e13b7c0..d2f7a4732be1a8ba3157d0d28304b0bd1b71da02 100644
--- a/ppdet/utils/cam_utils.py
+++ b/ppdet/utils/cam_utils.py
@@ -143,9 +143,9 @@ class BBoxCAM:
# set for get extra_data before nms
trainer.model.use_extra_data=True
# set for record the bbox index before nms
- if cfg.architecture == 'FasterRCNN' or cfg.architecture == 'MaskRCNN':
+ if cfg.architecture in ['FasterRCNN', 'MaskRCNN']:
trainer.model.bbox_post_process.nms.return_index = True
- elif cfg.architecture == 'YOLOv3':
+ elif cfg.architecture in ['YOLOv3', 'PPYOLOE', 'PPYOLOEWithAuxHead']:
if trainer.model.post_process is not None:
# anchor based YOLOs: YOLOv3,PP-YOLO
trainer.model.post_process.nms.return_index = True
diff --git a/tools/cam_ppdet.py b/tools/cam_ppdet.py
index e73f082bb2951ee72c686ad92f10c6d0add07262..c8922b9353f90c5f40472ac8e1b7621bf903e620 100644
--- a/tools/cam_ppdet.py
+++ b/tools/cam_ppdet.py
@@ -67,7 +67,10 @@ def parse_args():
return args
def run(FLAGS, cfg):
- assert cfg.architecture in ['FasterRCNN', 'MaskRCNN', 'YOLOv3', 'BlazeFace', 'SSD', 'RetinaNet'], 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
+ assert cfg.architecture in ['FasterRCNN', 'MaskRCNN', 'YOLOv3', 'PPYOLOE',
+ 'PPYOLOEWithAuxHead', 'BlazeFace', 'SSD', 'RetinaNet'], \
+ 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, ' \
+ 'the others are not supported temporarily!'
bbox_cam = BBoxCAM(FLAGS, cfg)
bbox_cam.get_bboxes_cams()