diff --git a/docs/tutorials/GradCAM_cn.md b/docs/tutorials/GradCAM_cn.md
index 2dcd06e05da753dd3c4622748dae114058a07e68..e5a28d142e11c95ef4ee73e2a31b52a06f31464d 100644
--- a/docs/tutorials/GradCAM_cn.md
+++ b/docs/tutorials/GradCAM_cn.md
@@ -2,22 +2,23 @@
## 1.简介
-基于backbone特征图计算物体预测框的cam(类激活图)
+基于backbone/roi特征图计算物体预测框的cam(类激活图)
## 2.使用方法
* 以PP-YOLOE为例,准备好数据之后,指定网络配置文件、模型权重地址和图片路径以及输出文件夹路径,使用脚本调用tools/cam_ppdet.py计算图片中物体预测框的grad_cam热力图。下面为运行脚本示例。
```shell
-python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
+python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
* **参数**
-| FLAG | 用途 |
-| :----------------------: |:-----------------------------------------------------------------------------------------------------:|
-| -c | 指定配置文件 |
-| --infer_img | 用于预测的图片路径 |
-| --cam_out | 指定输出路径 |
-| -o | 设置或更改配置文件里的参数内容, 如 -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
+| FLAG | 用途 |
+|:--------------------------:|:--------------------------------------------------------------------------------------------------------------------------:|
+| -c | 指定配置文件 |
+| --infer_img | 用于预测的图片路径 |
+| --cam_out | 指定输出路径 |
+| --target_feature_layer_name | 计算cam的特征图位置, 如model.backbone、 model.bbox_head.roi_extractor |
+| -o | 设置或更改配置文件里的参数内容, 如 -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
* 运行效果
@@ -26,12 +27,43 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
cam_ppyoloe/225.jpg
-## 3. 目前支持基于FasterRCNN和YOLOv3系列的网络。
-* FasterRCNN网络热图可视化脚本
+## 3. 目前支持基于FasterRCNN/MaskRCNN, PPYOLOE系列以及BlazeFace, SSD, Retinanet网络。
+* PPYOLOE网络热图可视化脚本
```bash
-python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
-* PPYOLOE网络热图可视化脚本
+
+* MaskRCNN网络roi特征热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
+```
+
+* MaskRCNN网络backbone特征的热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
+```
+
+* FasterRCNN网络基于roi特征的热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+```
+
+* FasterRCNN网络基于backbone特征的热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+```
+
+* BlazeFace网络backbone特征热图可视化脚本
```bash
-python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
-```
\ No newline at end of file
+python tools/cam_ppdet.py -c configs/face_detection/blazeface_1000e.yml --infer_img demo/hrnet_demo.jpg --cam_out cam_blazeface --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams
+```
+
+* SSD网络backbone特征热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml --infer_img demo/000000014439.jpg --cam_out cam_ssd --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ssd_mobilenet_v1_300_120e_voc.pdparams
+```
+
+* Retinanet网络backbone特征热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_retinanet --target_feature_layer_name model.backbone -o weights=https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_2x_coco.pdparams
+```
diff --git a/docs/tutorials/GradCAM_en.md b/docs/tutorials/GradCAM_en.md
index 7551c709b98e101620c0b533d30e412df8ba4381..29487592ccc7bf3a2b3c8fbd82424851563be33a 100644
--- a/docs/tutorials/GradCAM_en.md
+++ b/docs/tutorials/GradCAM_en.md
@@ -1,12 +1,12 @@
# Object detection grad_cam heatmap
## 1.Introduction
-Calculate the cam (class activation map) of the object predict bbox based on the backbone feature map
+Compute the CAM (class activation map) of each predicted bounding box from the backbone or RoI feature map
## 2.Usage
* Taking PP-YOLOE as an example, after preparing the data, specify the network configuration file, model weight address, image path and output folder path, and then use the script to call tools/cam_ppdet.py to calculate the grad_cam heat map of the prediction box. Below is an example run script.
```shell
-python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
+python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
* **Arguments**
@@ -16,6 +16,7 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
| -c | Select config file |
| --infer_img | Image path |
| --cam_out | Directory for output |
+| --target_feature_layer_name | Layer whose feature map is used to compute Grad-CAM, for example: model.backbone, model.bbox_head.roi_extractor |
| -o | Set parameters in configure file, for example: -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
* result
@@ -26,12 +27,45 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
cam_ppyoloe/225.jpg
-## 3.Currently supports networks based on FasterRCNN and YOLOv3 series.
-* FasterRCNN bbox heat map visualization script
+## 3.Currently supports networks based on FasterRCNN/MaskRCNN, PPYOLOE series and BlazeFace, SSD, Retinanet.
+* PPYOLOE bbox heat map visualization script (with backbone featuremap)
```bash
-python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
-* PPYOLOE bbox heat map visualization script
+
+* MaskRCNN bbox heat map visualization script (with roi featuremap)
```bash
-python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
-```
\ No newline at end of file
+python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
+```
+
+* MaskRCNN bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
+```
+
+* FasterRCNN bbox heat map visualization script (with roi featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+```
+
+* FasterRCNN bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+```
+
+* BlazeFace bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/face_detection/blazeface_1000e.yml --infer_img demo/hrnet_demo.jpg --cam_out cam_blazeface --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams
+```
+
+* SSD bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml --infer_img demo/000000014439.jpg --cam_out cam_ssd --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ssd_mobilenet_v1_300_120e_voc.pdparams
+```
+
+* Retinanet bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_retinanet --target_feature_layer_name model.backbone -o weights=https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_2x_coco.pdparams
+```
+
+
diff --git a/ppdet/modeling/architectures/blazeface.py b/ppdet/modeling/architectures/blazeface.py
index 2f5d0d238ebd5d4b92916e79ff192932f4baff33..477732d95eec16448b42b062918d099344a81a10 100644
--- a/ppdet/modeling/architectures/blazeface.py
+++ b/ppdet/modeling/architectures/blazeface.py
@@ -18,6 +18,8 @@ from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
+import paddle
+import paddle.nn.functional as F
__all__ = ['BlazeFace']
@@ -74,18 +76,42 @@ class BlazeFace(BaseArch):
self.inputs['gt_class'])
else:
preds, anchors = self.blaze_head(neck_feats, self.inputs['image'])
- bbox, bbox_num, before_nms_indexes = self.post_process(
+ bbox, bbox_num, nms_keep_idx = self.post_process(
preds, anchors, self.inputs['im_shape'],
self.inputs['scale_factor'])
- return bbox, bbox_num
+ if self.use_extra_data:
+ extra_data = {} # record bbox outputs before NMS, such as scores and nms_keep_idx
+ """extra_data:{
+ 'scores': predict scores,
+ 'nms_keep_idx': bbox index before nms,
+ }
+ """
+ preds_logits = preds[1] # list of tensors, each [1, NumBBox, NumClass]
+ extra_data['scores'] = F.softmax(paddle.concat(
+ preds_logits, axis=1)).transpose([0, 2, 1])
+ extra_data['logits'] = paddle.concat(
+ preds_logits, axis=1).transpose([0, 2, 1])
+ extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
+ return bbox, bbox_num, extra_data
+ else:
+ return bbox, bbox_num
def get_loss(self, ):
return {"loss": self._forward()}
def get_pred(self):
- bbox_pred, bbox_num = self._forward()
- output = {
- "bbox": bbox_pred,
- "bbox_num": bbox_num,
- }
+ if self.use_extra_data:
+ bbox_pred, bbox_num, extra_data = self._forward()
+ output = {
+ "bbox": bbox_pred,
+ "bbox_num": bbox_num,
+ "extra_data": extra_data
+ }
+ else:
+ bbox_pred, bbox_num = self._forward()
+ output = {
+ "bbox": bbox_pred,
+ "bbox_num": bbox_num,
+ }
+
return output
diff --git a/ppdet/modeling/architectures/cascade_rcnn.py b/ppdet/modeling/architectures/cascade_rcnn.py
index e8d330c695e59ef02e5e81907c98f369d0f23cd6..c5d454f4948891ed1400d09d0e24490ce46fb361 100644
--- a/ppdet/modeling/architectures/cascade_rcnn.py
+++ b/ppdet/modeling/architectures/cascade_rcnn.py
@@ -108,7 +108,7 @@ class CascadeRCNN(BaseArch):
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
- bbox, bbox_num, before_nms_indexes = self.bbox_post_process(
+ bbox, bbox_num, nms_keep_idx = self.bbox_post_process(
preds, (refined_rois, rois_num), im_shape, scale_factor)
# rescale the prediction back to origin image
bbox, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
diff --git a/ppdet/modeling/architectures/faster_rcnn.py b/ppdet/modeling/architectures/faster_rcnn.py
index a520d3a164ba951d5089e54b86ce53fa00265333..41c286fe02ef9d25f5b086ed99931fdd2aa70062 100644
--- a/ppdet/modeling/architectures/faster_rcnn.py
+++ b/ppdet/modeling/architectures/faster_rcnn.py
@@ -82,21 +82,31 @@ class FasterRCNN(BaseArch):
self.inputs)
return rpn_loss, bbox_loss
else:
- cam_data = {} # record bbox scores and index before nms
rois, rois_num, _ = self.rpn_head(body_feats, self.inputs)
preds, _ = self.bbox_head(body_feats, rois, rois_num, None)
- cam_data['scores'] = preds[1]
-
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
- bbox, bbox_num, before_nms_indexes = self.bbox_post_process(preds, (rois, rois_num),
+ bbox, bbox_num, nms_keep_idx = self.bbox_post_process(preds, (rois, rois_num),
im_shape, scale_factor)
- cam_data['before_nms_indexes'] = before_nms_indexes # , bbox index before nms, for cam
# rescale the prediction back to origin image
bboxes, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
bbox, bbox_num, im_shape, scale_factor)
- return bbox_pred, bbox_num, cam_data
+
+ if self.use_extra_data:
+ extra_data = {} # record bbox outputs before NMS, such as scores and nms_keep_idx
+ """extra_data:{
+ 'scores': predict scores,
+ 'nms_keep_idx': bbox index before nms,
+ }
+ """
+ extra_data['scores'] = preds[1] # predict scores (probability)
+ # Todo: get logits output
+ extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
+ return bbox_pred, bbox_num, extra_data
+ else:
+ return bbox_pred, bbox_num
+
def get_loss(self, ):
rpn_loss, bbox_loss = self._forward()
@@ -108,8 +118,12 @@ class FasterRCNN(BaseArch):
return loss
def get_pred(self):
- bbox_pred, bbox_num, cam_data = self._forward()
- output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'cam_data': cam_data}
+ if self.use_extra_data:
+ bbox_pred, bbox_num, extra_data = self._forward()
+ output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'extra_data': extra_data}
+ else:
+ bbox_pred, bbox_num = self._forward()
+ output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
def target_bbox_forward(self, data):
diff --git a/ppdet/modeling/architectures/mask_rcnn.py b/ppdet/modeling/architectures/mask_rcnn.py
index bea6a3eea0e0b4d301a99eb97fc6f159f6345db4..4f6a9ce10f76801ca4bff4e3dc3e304b8e3567f5 100644
--- a/ppdet/modeling/architectures/mask_rcnn.py
+++ b/ppdet/modeling/architectures/mask_rcnn.py
@@ -106,7 +106,7 @@ class MaskRCNN(BaseArch):
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
- bbox, bbox_num, before_nms_indexes = self.bbox_post_process(
+ bbox, bbox_num, nms_keep_idx = self.bbox_post_process(
preds, (rois, rois_num), im_shape, scale_factor)
mask_out = self.mask_head(
body_feats, bbox, bbox_num, self.inputs, feat_func=feat_func)
@@ -117,7 +117,20 @@ class MaskRCNN(BaseArch):
origin_shape = self.bbox_post_process.get_origin_shape()
mask_pred = self.mask_post_process(mask_out, bbox_pred, bbox_num,
origin_shape)
- return bbox_pred, bbox_num, mask_pred
+
+ if self.use_extra_data:
+ extra_data = {} # record bbox outputs before NMS, such as scores and nms_keep_idx
+ """extra_data:{
+ 'scores': predict scores,
+ 'nms_keep_idx': bbox index before nms,
+ }
+ """
+ extra_data['scores'] = preds[1] # predict scores (probability)
+ # Todo: get logits output
+ extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
+ return bbox_pred, bbox_num, mask_pred, extra_data
+ else:
+ return bbox_pred, bbox_num, mask_pred
def get_loss(self, ):
bbox_loss, mask_loss, rpn_loss = self._forward()
@@ -130,6 +143,10 @@ class MaskRCNN(BaseArch):
return loss
def get_pred(self):
- bbox_pred, bbox_num, mask_pred = self._forward()
- output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred}
+ if self.use_extra_data:
+ bbox_pred, bbox_num, mask_pred, extra_data = self._forward()
+ output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred, 'extra_data': extra_data}
+ else:
+ bbox_pred, bbox_num, mask_pred = self._forward()
+ output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred}
return output
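Review note on `nms_keep_idx` for the RCNN-style architectures: the `cam_utils.py` change later in this diff states that, for FasterRCNN, MaskRCNN and CascadeRCNN, the kept index has to be divided by `num_class`. A minimal sketch of that mapping follows. It assumes the pre-NMS boxes are flattened row-major from a `[num_rois, num_class]` layout; the diff does not confirm this layout, and the helper names are hypothetical.

```python
def keep_idx_to_roi_idx(nms_keep_idx, num_class):
    """Map flattened pre-NMS indices back to RoI rows by integer division."""
    return [idx // num_class for idx in nms_keep_idx]


def keep_idx_to_class_idx(nms_keep_idx, num_class):
    """Class column of each kept index (valid only under the row-major assumption)."""
    return [idx % num_class for idx in nms_keep_idx]


# Hypothetical values: 80 classes, three boxes kept by NMS.
num_class = 80
nms_keep_idx = [0, 85, 239]

roi_idx = keep_idx_to_roi_idx(nms_keep_idx, num_class)
class_idx = keep_idx_to_class_idx(nms_keep_idx, num_class)
```

With this mapping, `roi_idx` selects rows of the `[num_of_bboxes_before_nms, num_classes]` score tensor stored in `extra_data['scores']`.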
diff --git a/ppdet/modeling/architectures/meta_arch.py b/ppdet/modeling/architectures/meta_arch.py
index 4ff84a97a61739e06f215f56a64daf0459e4a971..370b2b124bfc1f5477a942f972731f2857e1641c 100644
--- a/ppdet/modeling/architectures/meta_arch.py
+++ b/ppdet/modeling/architectures/meta_arch.py
@@ -15,11 +15,12 @@ __all__ = ['BaseArch']
@register
class BaseArch(nn.Layer):
- def __init__(self, data_format='NCHW'):
+ def __init__(self, data_format='NCHW', use_extra_data=False):
super(BaseArch, self).__init__()
self.data_format = data_format
self.inputs = {}
self.fuse_norm = False
+ self.use_extra_data = use_extra_data
def load_meanstd(self, cfg_transform):
scale = 1.
diff --git a/ppdet/modeling/architectures/ppyoloe.py b/ppdet/modeling/architectures/ppyoloe.py
index 7c2b3a81575a6c527eed6def5ee702320b1e7e76..330542b8ab20d04fb433eddffa14bf05afd8e6a2 100644
--- a/ppdet/modeling/architectures/ppyoloe.py
+++ b/ppdet/modeling/architectures/ppyoloe.py
@@ -106,21 +106,30 @@ class PPYOLOE(BaseArch):
raise ValueError
return yolo_losses
else:
- cam_data = {} # record bbox scores and index before nms
+
yolo_head_outs = self.yolo_head(neck_feats)
- cam_data['scores'] = yolo_head_outs[0]
if self.post_process is not None:
- bbox, bbox_num, before_nms_indexes = self.post_process(
+ bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
- cam_data['before_nms_indexes'] = before_nms_indexes
+
else:
- bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
+ bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
- # data for cam
- cam_data['before_nms_indexes'] = before_nms_indexes
- output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
+
+ if self.use_extra_data:
+ extra_data = {} # record bbox outputs before NMS, such as scores and nms_keep_idx
+ """extra_data:{
+ 'scores': predict scores,
+ 'nms_keep_idx': bbox index before nms,
+ }
+ """
+ extra_data['scores'] = yolo_head_outs[0] # predict scores (probability)
+ extra_data['nms_keep_idx'] = nms_keep_idx
+ output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
+ else:
+ output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
@@ -218,21 +227,29 @@ class PPYOLOEWithAuxHead(BaseArch):
aux_pred=[aux_cls_scores, aux_bbox_preds])
return loss
else:
- cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats)
- cam_data['scores'] = yolo_head_outs[0]
if self.post_process is not None:
- bbox, bbox_num, before_nms_indexes = self.post_process(
+ bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
- cam_data['before_nms_indexes'] = before_nms_indexes
else:
- bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
+ bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
- # data for cam
- cam_data['before_nms_indexes'] = before_nms_indexes
- output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
+
+ if self.use_extra_data:
+ extra_data = {} # record bbox outputs before NMS, such as scores and nms_keep_idx
+ """extra_data:{
+ 'scores': predict scores,
+ 'nms_keep_idx': bbox index before nms,
+ }
+ """
+ extra_data['scores'] = yolo_head_outs[0] # predict scores (probability)
+ # Todo: get logits output
+ extra_data['nms_keep_idx'] = nms_keep_idx
+ output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
+ else:
+ output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
diff --git a/ppdet/modeling/architectures/retinanet.py b/ppdet/modeling/architectures/retinanet.py
index e774430a03dfebf74c1e91138ed57f2ee52f1c9d..fc49f0e97365ef10d14c9214133d53db304b880b 100644
--- a/ppdet/modeling/architectures/retinanet.py
+++ b/ppdet/modeling/architectures/retinanet.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import paddle
+import paddle.nn.functional as F
__all__ = ['RetinaNet']
@@ -57,9 +58,24 @@ class RetinaNet(BaseArch):
return self.head(neck_feats, self.inputs)
else:
head_outs = self.head(neck_feats)
- bbox, bbox_num = self.head.post_process(
+ bbox, bbox_num, nms_keep_idx = self.head.post_process(
head_outs, self.inputs['im_shape'], self.inputs['scale_factor'])
- return {'bbox': bbox, 'bbox_num': bbox_num}
+
+ if self.use_extra_data:
+ extra_data = {} # record bbox outputs before NMS, such as scores and nms_keep_idx
+ """extra_data:{
+ 'scores': predict scores,
+ 'nms_keep_idx': bbox index before nms,
+ }
+ """
+ preds_logits = self.head.decode_cls_logits(head_outs[0])
+ preds_scores = F.sigmoid(preds_logits)
+ extra_data['logits'] = preds_logits
+ extra_data['scores'] = preds_scores
+ extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
+ return {'bbox': bbox, 'bbox_num': bbox_num, "extra_data": extra_data}
+ else:
+ return {'bbox': bbox, 'bbox_num': bbox_num}
def get_loss(self):
return self._forward()
diff --git a/ppdet/modeling/architectures/ssd.py b/ppdet/modeling/architectures/ssd.py
index 2d2bf1dd2ce4b4b273ff420239f2f388ae05e015..b8669b7cf127857662f0f78e9c52e43f29fbfcfe 100644
--- a/ppdet/modeling/architectures/ssd.py
+++ b/ppdet/modeling/architectures/ssd.py
@@ -18,6 +18,8 @@ from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
+import paddle
+import paddle.nn.functional as F
__all__ = ['SSD']
@@ -75,18 +77,42 @@ class SSD(BaseArch):
self.inputs['gt_class'])
else:
preds, anchors = self.ssd_head(body_feats, self.inputs['image'])
- bbox, bbox_num, before_nms_indexes = self.post_process(
+ bbox, bbox_num, nms_keep_idx = self.post_process(
preds, anchors, self.inputs['im_shape'],
self.inputs['scale_factor'])
- return bbox, bbox_num
+
+ if self.use_extra_data:
+ extra_data = {} # record bbox outputs before NMS, such as scores and nms_keep_idx
+ """extra_data:{
+ 'scores': predict scores,
+ 'nms_keep_idx': bbox index before nms,
+ }
+ """
+ preds_logits = preds[1] # list of tensors, each [1, NumBBox, NumClass]
+ extra_data['scores'] = F.softmax(paddle.concat(
+ preds_logits, axis=1)).transpose([0, 2, 1])
+ extra_data['logits'] = paddle.concat(
+ preds_logits, axis=1).transpose([0, 2, 1])
+ extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
+ return bbox, bbox_num, extra_data
+ else:
+ return bbox, bbox_num
def get_loss(self, ):
return {"loss": self._forward()}
def get_pred(self):
- bbox_pred, bbox_num = self._forward()
- output = {
- "bbox": bbox_pred,
- "bbox_num": bbox_num,
- }
+ if self.use_extra_data:
+ bbox_pred, bbox_num, extra_data = self._forward()
+ output = {
+ "bbox": bbox_pred,
+ "bbox_num": bbox_num,
+ "extra_data": extra_data
+ }
+ else:
+ bbox_pred, bbox_num = self._forward()
+ output = {
+ "bbox": bbox_pred,
+ "bbox_num": bbox_num,
+ }
return output
diff --git a/ppdet/modeling/architectures/yolo.py b/ppdet/modeling/architectures/yolo.py
index 4d9ec4d33fc12253846720db7adc10ce2c1954ed..b004935654ed0ec2290af758133f479147231dbd 100644
--- a/ppdet/modeling/architectures/yolo.py
+++ b/ppdet/modeling/architectures/yolo.py
@@ -98,9 +98,7 @@ class YOLOv3(BaseArch):
return yolo_losses
else:
- cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats)
- cam_data['scores'] = yolo_head_outs[0]
if self.for_mot:
# the detection part of JDE MOT model
@@ -116,21 +114,32 @@ class YOLOv3(BaseArch):
else:
if self.return_idx:
# the detection part of JDE MOT model
- _, bbox, bbox_num, _ = self.post_process(
+ _, bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors)
elif self.post_process is not None:
# anchor based YOLOs: YOLOv3,PP-YOLO,PP-YOLOv2 use mask_anchors
- bbox, bbox_num, before_nms_indexes = self.post_process(
+ bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
- cam_data['before_nms_indexes'] = before_nms_indexes
else:
# anchor free YOLOs: PP-YOLOE, PP-YOLOE+
- bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
+ bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
- # data for cam
- cam_data['before_nms_indexes'] = before_nms_indexes
- output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
+
+ if self.use_extra_data:
+ extra_data = {} # record bbox outputs before NMS, such as scores and nms_keep_idx
+ """extra_data:{
+ 'scores': predict scores,
+ 'nms_keep_idx': bbox index before nms,
+ }
+ """
+ extra_data['scores'] = yolo_head_outs[0] # predict scores (probability)
+ # Todo: get logits output
+ extra_data['nms_keep_idx'] = nms_keep_idx
+ # Todo support for mask_anchors yolo
+ output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
+ else:
+ output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py
index 5784f70f878ad8f44561cca09c109a234331aa4e..60d7bbc2f83fb95dbb7a1afe7be891df98ce5e76 100644
--- a/ppdet/modeling/heads/ppyoloe_head.py
+++ b/ppdet/modeling/heads/ppyoloe_head.py
@@ -525,9 +525,9 @@ class PPYOLOEHead(nn.Layer):
# `exclude_nms=True` just use in benchmark
return pred_bboxes, pred_scores, None
else:
- bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
+ bbox_pred, bbox_num, nms_keep_idx = self.nms(pred_bboxes,
pred_scores)
- return bbox_pred, bbox_num, before_nms_indexes
+ return bbox_pred, bbox_num, nms_keep_idx
def get_activation(name="LeakyReLU"):
diff --git a/ppdet/modeling/heads/ppyoloe_r_head.py b/ppdet/modeling/heads/ppyoloe_r_head.py
index fba401d7e7a7ff725938760be09b478d924dc705..e7cf772f56991152f138dfbf7f5297d01c0e0b0f 100644
--- a/ppdet/modeling/heads/ppyoloe_r_head.py
+++ b/ppdet/modeling/heads/ppyoloe_r_head.py
@@ -420,6 +420,6 @@ class PPYOLOERHead(nn.Layer):
pred_bboxes /= scale_factor
if self.export_onnx:
return pred_bboxes, pred_scores, None
- bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
+ bbox_pred, bbox_num, nms_keep_idx = self.nms(pred_bboxes,
pred_scores)
- return bbox_pred, bbox_num, before_nms_indexes
+ return bbox_pred, bbox_num, nms_keep_idx
diff --git a/ppdet/modeling/heads/retina_head.py b/ppdet/modeling/heads/retina_head.py
index 8705e86febb30d06fcbbd06187a76548450c9600..67a51265d1decee3d3077a504771f5be050101f3 100644
--- a/ppdet/modeling/heads/retina_head.py
+++ b/ppdet/modeling/heads/retina_head.py
@@ -245,5 +245,34 @@ class RetinaHead(nn.Layer):
bboxes, scores = self.decode(anchors, cls_logits, bboxes_reg, im_shape,
scale_factor)
- bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
- return bbox_pred, bbox_num
+ bbox_pred, bbox_num, nms_keep_idx = self.nms(bboxes, scores)
+ return bbox_pred, bbox_num, nms_keep_idx
+
+
+ def get_scores_single(self, cls_scores_list):
+ mlvl_logits = []
+ for cls_score in cls_scores_list:
+ cls_score = cls_score.reshape([-1, self.num_classes])
+ if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
+ max_score = cls_score.max(axis=1)
+ _, topk_inds = max_score.topk(self.nms_pre)
+ cls_score = cls_score.gather(topk_inds)
+
+ mlvl_logits.append(cls_score)
+
+ mlvl_logits = paddle.concat(mlvl_logits)
+ mlvl_logits = mlvl_logits.transpose([1, 0])
+
+ return mlvl_logits
+
+ def decode_cls_logits(self, cls_logits_list):
+ cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits_list]
+ batch_logits = []
+ for img_id in range(cls_logits[0].shape[0]):
+ num_lvls = len(cls_logits)
+ cls_scores_list = [cls_logits[i][img_id] for i in range(num_lvls)]
+ logits = self.get_scores_single(cls_scores_list)
+ batch_logits.append(logits)
+ batch_logits = paddle.stack(batch_logits, axis=0)
+ return batch_logits
+
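Review note on `get_scores_single` / `decode_cls_logits` above: for each FPN level, the new code keeps at most `nms_pre` rows (ranked by their maximum class logit), concatenates the levels, and transposes to a class-major layout. The sketch below mirrors that selection with plain Python lists in place of Paddle tensors, for one image. The function names and sample values are illustrative only and are not part of the PR.

```python
def topk_rows_by_max(rows, k):
    """Keep the k rows with the largest per-row maximum, in descending order."""
    if k is None or len(rows) <= k:
        return list(rows)
    order = sorted(range(len(rows)), key=lambda i: max(rows[i]), reverse=True)
    return [rows[i] for i in order[:k]]


def gather_multilevel_logits(levels, nms_pre):
    """levels: per-level lists of rows, each row holding num_classes logits.

    Returns a nested list shaped [num_classes][num_kept], matching the
    final transpose([1, 0]) in get_scores_single.
    """
    kept = []
    for rows in levels:
        kept.extend(topk_rows_by_max(rows, nms_pre))
    num_classes = len(kept[0])
    return [[row[c] for row in kept] for c in range(num_classes)]


# Two levels, two classes; level 0 has more rows than nms_pre.
levels = [
    [[0.1, 0.9], [0.8, 0.2], [0.3, 0.4]],
    [[0.5, 0.6]],
]
logits = gather_multilevel_logits(levels, nms_pre=2)
```

Because the selection is per level, the column order of the result follows level order first and per-level rank second. The indices in `nms_keep_idx` would only line up with these logits if `decode` selects candidates in the same way, which is worth confirming in review.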
diff --git a/ppdet/modeling/heads/roi_extractor.py b/ppdet/modeling/heads/roi_extractor.py
index b96bb4e91a0ca720d099562c3cf51d95080051fa..6c2f5c81904bf1e55e7799c78a76edc2034e447b 100644
--- a/ppdet/modeling/heads/roi_extractor.py
+++ b/ppdet/modeling/heads/roi_extractor.py
@@ -15,6 +15,7 @@
import paddle
from ppdet.core.workspace import register
from ppdet.modeling import ops
+import paddle.nn as nn
def _to_list(v):
@@ -24,7 +25,7 @@ def _to_list(v):
@register
-class RoIAlign(object):
+class RoIAlign(nn.Layer):
"""
RoI Align module
@@ -73,7 +74,7 @@ class RoIAlign(object):
def from_config(cls, cfg, input_shape):
return {'spatial_scale': [1. / i.stride for i in input_shape]}
- def __call__(self, feats, roi, rois_num):
+ def forward(self, feats, roi, rois_num):
roi = paddle.concat(roi) if len(roi) > 1 else roi[0]
if len(feats) == 1:
rois_feat = paddle.vision.ops.roi_align(
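Review note on the `set_hook` change in `ppdet/utils/cam_utils.py` (next file in this diff): the target layer is looked up by building a string and running it through `exec`. A possible alternative, sketched here as a suggestion and not as what the PR does, is to walk the dotted name with `getattr`. The `trainer` object below is a hypothetical stand-in built from `SimpleNamespace`, not the real ppdet `Trainer`.

```python
from functools import reduce
from types import SimpleNamespace


def resolve_attr(root, dotted_name):
    """Follow 'a.b.c' from root via getattr; raises AttributeError on a missing part."""
    return reduce(getattr, dotted_name.split("."), root)


# Hypothetical stand-in for the trainer and its model.
trainer = SimpleNamespace(
    model=SimpleNamespace(
        backbone="backbone-layer",
        bbox_head=SimpleNamespace(roi_extractor="roi-layer"),
    )
)

layer = resolve_attr(trainer, "model.bbox_head.roi_extractor")

# A misspelled layer name surfaces as AttributeError, which can be caught narrowly.
try:
    resolve_attr(trainer, "model.neck")
    missing = False
except AttributeError:
    missing = True
```

With the real model, the resolved layer could then receive `_layer_name_for_hook` and `register_forward_post_hook(hook)` directly, which avoids evaluating user-supplied text as code.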
diff --git a/ppdet/utils/cam_utils.py b/ppdet/utils/cam_utils.py
index 7aa4b174320d49f3655c15c87d406a85e0ecc920..d98350e2c3dda5ec86aa2e83a0d1b0ee0e13b7c0 100644
--- a/ppdet/utils/cam_utils.py
+++ b/ppdet/utils/cam_utils.py
@@ -4,6 +4,7 @@ import os
import sys
import glob
from ppdet.utils.logger import setup_logger
+import copy
logger = setup_logger('ppdet_cam')
import paddle
@@ -119,7 +120,12 @@ class BBoxCAM:
# num_class
self.num_class = cfg.num_classes
# set hook for extraction of featuremaps and grads
- self.set_hook()
+ self.set_hook(cfg)
+ self.nms_idx_need_divid_numclass_arch = ['FasterRCNN', 'MaskRCNN', 'CascadeRCNN']
+ """
+ In these architectures the pre-NMS bbox array has a num_class dimension,
+ so nms_keep_idx must be divided by num_class to get the bbox index.
+ """
# cam image output_dir
try:
@@ -134,9 +140,10 @@ class BBoxCAM:
# load weights
trainer.load_weights(cfg.weights)
- # record the bbox index before nms
- # Todo: hard code for nms return index
- if cfg.architecture == 'FasterRCNN':
+ # enable extra_data output (pre-NMS scores and kept indices)
+ trainer.model.use_extra_data = True
+ # make NMS return the indices of the kept bboxes
+ if cfg.architecture == 'FasterRCNN' or cfg.architecture == 'MaskRCNN':
trainer.model.bbox_post_process.nms.return_index = True
elif cfg.architecture == 'YOLOv3':
if trainer.model.post_process is not None:
@@ -145,24 +152,39 @@ class BBoxCAM:
else:
# anchor free YOLOs: PP-YOLOE, PP-YOLOE+
trainer.model.yolo_head.nms.return_index = True
+ elif cfg.architecture=='BlazeFace' or cfg.architecture=='SSD':
+ trainer.model.post_process.nms.return_index = True
+ elif cfg.architecture=='RetinaNet':
+ trainer.model.head.nms.return_index = True
else:
print(
- 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
+ cfg.architecture+' is not supported for cam temporarily!'
)
sys.exit()
+ # Todo: Unify the head/post_process name in each model
return trainer
- def set_hook(self):
+ def set_hook(self, cfg):
# set hook for extraction of featuremaps and grads
self.target_feats = {}
- self.target_layer_name = 'trainer.model.backbone'
+ self.target_layer_name = cfg.target_feature_layer_name
+ # a path relative to self.trainer, such as model.backbone, model.bbox_head.roi_extractor
def hook(layer, input, output):
self.target_feats[layer._layer_name_for_hook] = output
- self.trainer.model.backbone._layer_name_for_hook = self.target_layer_name
- self.trainer.model.backbone.register_forward_post_hook(hook)
+ try:
+ exec('self.trainer.'+self.target_layer_name+'._layer_name_for_hook = self.target_layer_name')
+ # self.trainer.target_layer_name._layer_name_for_hook = self.target_layer_name
+ exec('self.trainer.'+self.target_layer_name+'.register_forward_post_hook(hook)')
+ # self.trainer.target_layer_name.register_forward_post_hook(hook)
+ except Exception:
+ print("Error! "
+ "The target_layer_name--"+self.target_layer_name+" is not in model! "
+ "Please check the spelling and "
+ "the network's architecture!")
+ sys.exit()
def get_bboxes(self):
# get inference images
@@ -187,31 +209,27 @@ class BBoxCAM:
img = np.array(Image.open(self.cfg.infer_img))
# data for calaulate bbox grad_cam
- cam_data = inference_result['cam_data']
+ extra_data = inference_result['extra_data']
"""
- if Faster_RCNN based architecture:
- cam_data: {'scores': tensor with shape [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
- 'before_nms_indexes': tensor with shape [num_of_bboxes_after_nms, 1], for example: [300, 1]
+ Example of Faster_RCNN based architecture:
+ extra_data: {'scores': tensor with shape [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
+ 'nms_keep_idx': tensor with shape [num_of_bboxes_after_nms, 1], for example: [300, 1]
}
- elif YOLOv3 based architecture:
- cam_data: {'scores': tensor with shape [1, num_classes, num_of_yolo_bboxes_before_nms], #for example: [1, 80, 8400]
- 'before_nms_indexes': tensor with shape [num_of_yolo_bboxes_after_nms, 1], # for example: [300, 1]
+ Example of YOLOv3 based architecture:
+ extra_data: {'scores': tensor with shape [1, num_classes, num_of_yolo_bboxes_before_nms], #for example: [1, 80, 8400]
+ 'nms_keep_idx': tensor with shape [num_of_yolo_bboxes_after_nms, 1], # for example: [300, 1]
}
"""
# array index of the predicted bbox before nms
- if self.cfg.architecture == 'FasterRCNN':
- # the bbox array shape of FasterRCNN before nms is [num_of_bboxes_before_nms, num_classes, 4],
+ if self.cfg.architecture in self.nms_idx_need_divid_numclass_arch:
+ # for some networks the bbox array before nms has a shape like [num_of_bboxes_before_nms, num_classes, 4],
# we need to divide num_classes to get the before_nms_index;
- before_nms_indexes = cam_data['before_nms_indexes'].cpu().numpy(
+ # currently this only applies to the rcnn architectures (fasterrcnn, maskrcnn, cascadercnn);
+ before_nms_indexes = extra_data['nms_keep_idx'].cpu().numpy(
) // self.num_class # num_class
- elif self.cfg.architecture == 'YOLOv3':
- before_nms_indexes = cam_data['before_nms_indexes'].cpu().numpy()
- else:
- print(
- 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
- )
- sys.exit()
+ else:
+ before_nms_indexes = extra_data['nms_keep_idx'].cpu().numpy()
# Calculate and visualize the heatmap of per predict bbox
for index, target_bbox in enumerate(inference_result['bbox']):
@@ -222,20 +240,16 @@ class BBoxCAM:
target_bbox_before_nms = int(before_nms_indexes[index])
- # bbox score vector
- if self.cfg.architecture == 'FasterRCNN':
- # the shape of faster_rcnn scores tensor is
- # [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
- score_out = cam_data['scores'][target_bbox_before_nms]
- elif self.cfg.architecture == 'YOLOv3':
- # the shape of yolov3 scores tensor is
- # [1, num_classes, num_of_yolo_bboxes_before_nms]
- score_out = cam_data['scores'][0, :, target_bbox_before_nms]
+ if len(extra_data['scores'].shape) == 2:
+ score_out = extra_data['scores'][target_bbox_before_nms]
else:
- print(
- 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
- )
- sys.exit()
+ score_out = extra_data['scores'][0, :, target_bbox_before_nms]
+ """
+ The bbox score output has one of two array shapes:
+ 1) [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
+ 2) [num_of_image, num_classes, num_of_yolo_bboxes_before_nms], for example: [1, 80, 1000]
+ """
+
# construct one_hot label and do backward to get the gradients
predicted_label = paddle.argmax(score_out)
@@ -245,34 +259,76 @@ class BBoxCAM:
target = paddle.sum(score_out * label_onehot)
target.backward(retain_graph=True)
- if isinstance(self.target_feats[self.target_layer_name], list):
- # when the backbone output contains features of multiple scales,
- # take the featuremap of the last scale
- # Todo: fuse the cam result from multisclae featuremaps
- backbone_grad = self.target_feats[self.target_layer_name][
- -1].grad.squeeze().cpu().numpy()
- backbone_feat = self.target_feats[self.target_layer_name][
- -1].squeeze().cpu().numpy()
- else:
- backbone_grad = self.target_feats[
- self.target_layer_name].grad.squeeze().cpu().numpy()
- backbone_feat = self.target_feats[
- self.target_layer_name].squeeze().cpu().numpy()
-
- # grad_cam:
- exp = grad_cam(backbone_feat, backbone_grad)
- # reshape the cam image to the input image size
- resized_exp = resize_cam(exp, (img.shape[1], img.shape[0]))
+ if 'backbone' in self.target_layer_name or \
+ 'neck' in self.target_layer_name: # backbone/neck level feature
+ if isinstance(self.target_feats[self.target_layer_name], list):
+ # when the featuremap contains multiple scales,
+ # take the featuremap of the last scale
+ # Todo: fuse the cam result from multiscale featuremaps
+ if self.target_feats[self.target_layer_name][
+ -1].shape[-1] == 1:
+ """
+ if the last level featuremap is 1x1 size,
+ we take the second last one
+ """
+ cam_grad = self.target_feats[self.target_layer_name][
+ -2].grad.squeeze().cpu().numpy()
+ cam_feat = self.target_feats[self.target_layer_name][
+ -2].squeeze().cpu().numpy()
+ else:
+ cam_grad = self.target_feats[self.target_layer_name][
+ -1].grad.squeeze().cpu().numpy()
+ cam_feat = self.target_feats[self.target_layer_name][
+ -1].squeeze().cpu().numpy()
+ else:
+ cam_grad = self.target_feats[
+ self.target_layer_name].grad.squeeze().cpu().numpy()
+ cam_feat = self.target_feats[
+ self.target_layer_name].squeeze().cpu().numpy()
+ else: # roi level feature
+ cam_grad = self.target_feats[
+ self.target_layer_name].grad.squeeze().cpu().numpy()[target_bbox_before_nms]
+ cam_feat = self.target_feats[
+ self.target_layer_name].squeeze().cpu().numpy()[target_bbox_before_nms]
- # set the area outside the predic bbox to 0
- mask = np.zeros((img.shape[0], img.shape[1], 3))
- mask[int(target_bbox[3]):int(target_bbox[5]), int(target_bbox[2]):
- int(target_bbox[4]), :] = 1
- resized_exp = resized_exp * mask
+ # grad_cam:
+ exp = grad_cam(cam_feat, cam_grad)
+
+ if 'backbone' in self.target_layer_name or \
+ 'neck' in self.target_layer_name:
+ """
+ when using the backbone/neck featuremap,
+ we first compute the cam on the whole image,
+ and then set the area outside the predicted bbox to 0
+ """
+ # reshape the cam image to the input image size
+ resized_exp = resize_cam(exp, (img.shape[1], img.shape[0]))
+ mask = np.zeros((img.shape[0], img.shape[1], 3))
+ mask[int(target_bbox[3]):int(target_bbox[5]), int(target_bbox[2]):
+ int(target_bbox[4]), :] = 1
+ resized_exp = resized_exp * mask
+ # add the bbox cam back to the input image
+ overlay_vis = np.uint8(resized_exp * 0.4 + img * 0.6)
+ elif 'roi' in self.target_layer_name:
+ # get the bbox part of the image
+ bbox_img = copy.deepcopy(img[int(target_bbox[3]):int(target_bbox[5]),
+ int(target_bbox[2]):int(target_bbox[4]), :])
+ # reshape the cam image to the bbox size
+ resized_exp = resize_cam(exp, (bbox_img.shape[1], bbox_img.shape[0]))
+ # add the bbox cam back to the bbox image
+ bbox_overlay_vis = np.uint8(resized_exp * 0.4 + bbox_img * 0.6)
+ # put the bbox_cam image to the original image
+ overlay_vis = copy.deepcopy(img)
+ overlay_vis[int(target_bbox[3]):int(target_bbox[5]),
+ int(target_bbox[2]):int(target_bbox[4]), :] = bbox_overlay_vis
+ else:
+ print(
+ 'CAM is only supported for backbone/neck features and roi features for now, other feature layers are not supported yet!'
+ )
+ sys.exit()
- # add the bbox cam back to the input image
- overlay_vis = np.uint8(resized_exp * 0.4 + img * 0.6)
+ # put the bbox rectangle on image
cv2.rectangle(
overlay_vis, (int(target_bbox[2]), int(target_bbox[3])),
(int(target_bbox[4]), int(target_bbox[5])), (0, 0, 255), 2)
diff --git a/tools/cam_ppdet.py b/tools/cam_ppdet.py
index e1c3019959146e8b70928b23731259c756fc2634..e73f082bb2951ee72c686ad92f10c6d0add07262 100644
--- a/tools/cam_ppdet.py
+++ b/tools/cam_ppdet.py
@@ -58,34 +58,16 @@ def parse_args():
default=False,
help="Whether to save inference results to output_dir.")
parser.add_argument(
- "--slice_infer",
- action='store_true',
- help="Whether to slice the image and merge the inference results for small object detection."
- )
- parser.add_argument(
- '--slice_size',
- nargs='+',
- type=int,
- default=[640, 640],
- help="Height of the sliced image.")
- parser.add_argument(
- "--overlap_ratio",
- nargs='+',
- type=float,
- default=[0.25, 0.25],
- help="Overlap height ratio of the sliced image.")
- parser.add_argument(
- "--combine_method",
+ "--target_feature_layer_name",
type=str,
- default='nms',
- help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
- )
+ default='model.backbone', # define the featuremap to show grad cam, such as model.backbone, model.bbox_head.roi_extractor
+ help="Name of the layer whose featuremap is used to compute the grad cam, such as model.backbone or model.bbox_head.roi_extractor.")
args = parser.parse_args()
return args
def run(FLAGS, cfg):
- assert cfg.architecture in ['FasterRCNN', 'YOLOv3'], 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
+ assert cfg.architecture in ['FasterRCNN', 'MaskRCNN', 'YOLOv3', 'BlazeFace', 'SSD', 'RetinaNet'], 'CAM is only supported for the FasterRCNN, MaskRCNN, YOLOv3, BlazeFace, SSD and RetinaNet architectures for now, the others are not supported temporarily!'
bbox_cam = BBoxCAM(FLAGS, cfg)
bbox_cam.get_bboxes_cams()
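
Review note on `--target_feature_layer_name`: `set_hook` in this change builds the attribute access with `exec` on a user-supplied string. A possible alternative, sketched below, walks the dotted path with `getattr`. This is only an illustration and not the code in this PR. `resolve_layer` and the `_Trainer`, `_Model`, `_Backbone` classes are hypothetical names introduced here; the real trainer and model objects come from ppdet.

```python
import functools


def resolve_layer(root, dotted_name):
    """Follow a dotted attribute path such as 'model.backbone', starting at root.

    Hypothetical helper (not part of ppdet): it avoids exec() and raises a
    ValueError when any component of the path does not exist.
    """
    try:
        return functools.reduce(getattr, dotted_name.split('.'), root)
    except AttributeError as err:
        raise ValueError(
            "target layer '{}' is not in the model: {}".format(dotted_name,
                                                               err)) from err


# Hypothetical stand-ins for the trainer and model, used only for this demo.
class _Backbone:
    pass


class _Model:
    def __init__(self):
        self.backbone = _Backbone()


class _Trainer:
    def __init__(self):
        self.model = _Model()


trainer = _Trainer()
layer = resolve_layer(trainer, 'model.backbone')
# Tag the resolved layer the same way set_hook does before registering a hook.
layer._layer_name_for_hook = 'model.backbone'
```

With a real model, the resolved layer could then be passed to `register_forward_post_hook(hook)` as `set_hook` already does. A misspelled path then surfaces as a `ValueError` naming the missing attribute, so no catch-all `except` is needed.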