Unverified commit 5731e99d, authored by xs1997zju, committed by GitHub

GradCAM Refinement (#7685)

* GradCAM Refinement

* ROI CAM bugfix

* Add support for blazeface, ssd, retinanet

* Change nms return_idx default value to False

* Add use_extra_data attr

* pre-commit code style

* fix yolo extra_data process
Parent 15fd9316
......@@ -2,21 +2,22 @@
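## 1. Introduction
Compute the CAM (class activation map) of predicted object bounding boxes from the backbone feature map
Compute the CAM (class activation map) of predicted object bounding boxes from the backbone/roi feature map
## 2. Usage
* Taking PP-YOLOE as an example: once the data is ready, specify the network config file, the model weights address, the image path, and the output folder path, then call tools/cam_ppdet.py to compute the grad_cam heat map of the predicted boxes in the image. An example run script follows.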
```shell
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
* **Arguments**
| FLAG | Purpose |
| :----------------------: |:-----------------------------------------------------------------------------------------------------:|
|:--------------------------:|:--------------------------------------------------------------------------------------------------------------------------:|
| -c | Select the config file |
| --infer_img | Path of the image used for prediction |
| --cam_out | Output directory |
| --target_feature_layer_name | Feature map location used to compute the CAM, e.g. model.backbone or model.bbox_head.roi_extractor |
| -o | Set or override parameters in the config file, e.g. -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
* Result
......@@ -26,12 +27,43 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
</center>
<br><center>cam_ppyoloe/225.jpg</center></br>
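## 3. Currently supports networks based on the FasterRCNN and YOLOv3 series.
* FasterRCNN heat map visualization script
## 3. Currently supports FasterRCNN/MaskRCNN, the PPYOLOE series, and the BlazeFace, SSD, and Retinanet networks.
* PPYOLOE heat map visualization script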
```bash
python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
* PPYOLOE heat map visualization script
* MaskRCNN heat map visualization script (with roi features)
```bash
python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
```
* MaskRCNN heat map visualization script (with backbone features)
```bash
python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
```
* FasterRCNN heat map visualization script (with roi features)
```bash
python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
```
* FasterRCNN heat map visualization script (with backbone features)
```bash
python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
```
* BlazeFace heat map visualization script (with backbone features)
```bash
python tools/cam_ppdet.py -c configs/face_detection/blazeface_1000e.yml --infer_img demo/hrnet_demo.jpg --cam_out cam_blazeface --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams
```
* SSD heat map visualization script (with backbone features)
```bash
python tools/cam_ppdet.py -c configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml --infer_img demo/000000014439.jpg --cam_out cam_ssd --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ssd_mobilenet_v1_300_120e_voc.pdparams
```
* Retinanet heat map visualization script (with backbone features)
```bash
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
python tools/cam_ppdet.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_retinanet --target_feature_layer_name model.backbone -o weights=https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_2x_coco.pdparams
```
# Object detection grad_cam heatmap
## 1.Introduction
Calculate the CAM (class activation map) of predicted object bounding boxes based on the backbone feature map
Calculate the CAM (class activation map) of predicted object bounding boxes based on the backbone/roi feature map
## 2.Usage
* Taking PP-YOLOE as an example, after preparing the data, specify the network configuration file, model weight address, image path and output folder path, and then use the script to call tools/cam_ppdet.py to calculate the grad_cam heat map of the prediction box. Below is an example run script.
```shell
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
* **Arguments**
......@@ -16,6 +16,7 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
| -c | Select config file |
| --infer_img | Image path |
| --cam_out | Directory for output |
| --target_feature_layer_name | Location of the feature map used for Grad-CAM, for example: model.backbone, model.bbox_head.roi_extractor |
| -o | Set parameters in configure file, for example: -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
* result
......@@ -26,12 +27,45 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
<br><center>cam_ppyoloe/225.jpg</center></br>
## 3.Currently supports networks based on FasterRCNN and YOLOv3 series.
* FasterRCNN bbox heat map visualization script
## 3.Currently supports networks based on FasterRCNN/MaskRCNN, PPYOLOE series and BlazeFace, SSD, Retinanet.
* PPYOLOE bbox heat map visualization script (with backbone featuremap)
```bash
python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
* PPYOLOE bbox heat map visualization script
* MaskRCNN bbox heat map visualization script (with roi featuremap)
```bash
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
```
* MaskRCNN bbox heat map visualization script (with backbone featuremap)
```bash
python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
```
* FasterRCNN bbox heat map visualization script (with roi featuremap)
```bash
python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
```
* FasterRCNN bbox heat map visualization script (with backbone featuremap)
```bash
python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
```
* BlazeFace bbox heat map visualization script (with backbone featuremap)
```bash
python tools/cam_ppdet.py -c configs/face_detection/blazeface_1000e.yml --infer_img demo/hrnet_demo.jpg --cam_out cam_blazeface --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams
```
* SSD bbox heat map visualization script (with backbone featuremap)
```bash
python tools/cam_ppdet.py -c configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml --infer_img demo/000000014439.jpg --cam_out cam_ssd --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ssd_mobilenet_v1_300_120e_voc.pdparams
```
* Retinanet bbox heat map visualization script (with backbone featuremap)
```bash
python tools/cam_ppdet.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_retinanet --target_feature_layer_name model.backbone -o weights=https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_2x_coco.pdparams
```
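The README above describes heat maps computed from a feature map and its gradients, but the body of the `grad_cam` helper called in `tools/cam_ppdet.py` is not part of this diff. As a rough illustration only, the standard Grad-CAM recipe can be sketched in NumPy as follows. The function name `grad_cam_sketch` and the `[C, H, W]` input layout are assumptions made for this example, not the repository's implementation.

```python
import numpy as np


def grad_cam_sketch(feat, grad):
    """Standard Grad-CAM on one feature map (hypothetical helper).

    feat, grad: arrays of shape [C, H, W] holding the activations and the
    gradients of the target score with respect to those activations.
    """
    # One weight per channel: the spatial mean of that channel's gradient.
    weights = grad.mean(axis=(1, 2))            # [C]
    # Weighted sum of the channels gives a single [H, W] map.
    cam = np.tensordot(weights, feat, axes=1)   # [H, W]
    # Keep only positive evidence for the class.
    cam = np.maximum(cam, 0.0)
    # Scale to [0, 1] when the map is not all zeros.
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    return cam
```

The resulting map still has the feature map's resolution, so it has to be resized to the image (or the box) before it is overlaid.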
......@@ -18,6 +18,8 @@ from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import paddle
import paddle.nn.functional as F
__all__ = ['BlazeFace']
......@@ -74,18 +76,42 @@ class BlazeFace(BaseArch):
self.inputs['gt_class'])
else:
preds, anchors = self.blaze_head(neck_feats, self.inputs['image'])
bbox, bbox_num, before_nms_indexes = self.post_process(
bbox, bbox_num, nms_keep_idx = self.post_process(
preds, anchors, self.inputs['im_shape'],
self.inputs['scale_factor'])
if self.use_extra_data:
extra_data = {} # record the bbox outputs before nms, such as scores and nms_keep_idx
"""extra_data:{
'scores': predict scores,
'nms_keep_idx': bbox index before nms,
}
"""
preds_logits = preds[1] # [[1xNumBBoxNumClass]]
extra_data['scores'] = F.softmax(paddle.concat(
preds_logits, axis=1)).transpose([0, 2, 1])
extra_data['logits'] = paddle.concat(
preds_logits, axis=1).transpose([0, 2, 1])
extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
return bbox, bbox_num, extra_data
else:
return bbox, bbox_num
def get_loss(self, ):
return {"loss": self._forward()}
def get_pred(self):
if self.use_extra_data:
bbox_pred, bbox_num, extra_data = self._forward()
output = {
"bbox": bbox_pred,
"bbox_num": bbox_num,
"extra_data": extra_data
}
else:
bbox_pred, bbox_num = self._forward()
output = {
"bbox": bbox_pred,
"bbox_num": bbox_num,
}
return output
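The `use_extra_data` branch above concatenates the per-level class logits, applies softmax, and transposes the result so that the class axis comes before the box axis. A NumPy sketch of that shape handling is given below. It assumes each level's logits have shape `[N, num_boxes_at_level, num_classes]` (as the `[[1xNumBBoxNumClass]]` comment suggests) and that softmax runs over the last axis. `softmax` and `build_extra_scores` are illustrative names, not functions from the repository.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def build_extra_scores(level_logits):
    """level_logits: list of arrays shaped [N, num_boxes_at_level, C] (assumed)."""
    # Join all levels along the box axis: [N, total_boxes, C].
    logits = np.concatenate(level_logits, axis=1)
    # Class probabilities per box.
    scores = softmax(logits, axis=-1)
    # Move classes ahead of boxes: [N, C, total_boxes].
    return scores.transpose(0, 2, 1), logits.transpose(0, 2, 1)
```

With this layout, `scores[0, :, k]` is the class-probability vector of box `k`, which is the indexing form the CAM tool uses for 3-D score tensors.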
......@@ -108,7 +108,7 @@ class CascadeRCNN(BaseArch):
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
bbox, bbox_num, before_nms_indexes = self.bbox_post_process(
bbox, bbox_num, nms_keep_idx = self.bbox_post_process(
preds, (refined_rois, rois_num), im_shape, scale_factor)
# rescale the prediction back to origin image
bbox, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
......
......@@ -82,21 +82,31 @@ class FasterRCNN(BaseArch):
self.inputs)
return rpn_loss, bbox_loss
else:
cam_data = {} # record bbox scores and index before nms
rois, rois_num, _ = self.rpn_head(body_feats, self.inputs)
preds, _ = self.bbox_head(body_feats, rois, rois_num, None)
cam_data['scores'] = preds[1]
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
bbox, bbox_num, before_nms_indexes = self.bbox_post_process(preds, (rois, rois_num),
bbox, bbox_num, nms_keep_idx = self.bbox_post_process(preds, (rois, rois_num),
im_shape, scale_factor)
cam_data['before_nms_indexes'] = before_nms_indexes # , bbox index before nms, for cam
# rescale the prediction back to origin image
bboxes, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
bbox, bbox_num, im_shape, scale_factor)
return bbox_pred, bbox_num, cam_data
if self.use_extra_data:
extra_data = {} # record the bbox outputs before nms, such as scores and nms_keep_idx
"""extra_data:{
'scores': predict scores,
'nms_keep_idx': bbox index before nms,
}
"""
extra_data['scores'] = preds[1] # predict scores (probability)
# Todo: get logits output
extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
return bbox_pred, bbox_num, extra_data
else:
return bbox_pred, bbox_num
def get_loss(self, ):
rpn_loss, bbox_loss = self._forward()
......@@ -108,8 +118,12 @@ class FasterRCNN(BaseArch):
return loss
def get_pred(self):
bbox_pred, bbox_num, cam_data = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'cam_data': cam_data}
if self.use_extra_data:
bbox_pred, bbox_num, extra_data = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'extra_data': extra_data}
else:
bbox_pred, bbox_num = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
return output
def target_bbox_forward(self, data):
......
......@@ -106,7 +106,7 @@ class MaskRCNN(BaseArch):
im_shape = self.inputs['im_shape']
scale_factor = self.inputs['scale_factor']
bbox, bbox_num, before_nms_indexes = self.bbox_post_process(
bbox, bbox_num, nms_keep_idx = self.bbox_post_process(
preds, (rois, rois_num), im_shape, scale_factor)
mask_out = self.mask_head(
body_feats, bbox, bbox_num, self.inputs, feat_func=feat_func)
......@@ -117,6 +117,19 @@ class MaskRCNN(BaseArch):
origin_shape = self.bbox_post_process.get_origin_shape()
mask_pred = self.mask_post_process(mask_out, bbox_pred, bbox_num,
origin_shape)
if self.use_extra_data:
extra_data = {} # record the bbox outputs before nms, such as scores and nms_keep_idx
"""extra_data:{
'scores': predict scores,
'nms_keep_idx': bbox index before nms,
}
"""
extra_data['scores'] = preds[1] # predict scores (probability)
# Todo: get logits output
extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
return bbox_pred, bbox_num, mask_pred, extra_data
else:
return bbox_pred, bbox_num, mask_pred
def get_loss(self, ):
......@@ -130,6 +143,10 @@ class MaskRCNN(BaseArch):
return loss
def get_pred(self):
if self.use_extra_data:
bbox_pred, bbox_num, mask_pred, extra_data = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred, 'extra_data': extra_data}
else:
bbox_pred, bbox_num, mask_pred = self._forward()
output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred}
return output
......@@ -15,11 +15,12 @@ __all__ = ['BaseArch']
@register
class BaseArch(nn.Layer):
def __init__(self, data_format='NCHW'):
def __init__(self, data_format='NCHW', use_extra_data=False):
super(BaseArch, self).__init__()
self.data_format = data_format
self.inputs = {}
self.fuse_norm = False
self.use_extra_data = use_extra_data
def load_meanstd(self, cfg_transform):
scale = 1.
......
......@@ -106,21 +106,30 @@ class PPYOLOE(BaseArch):
raise ValueError
return yolo_losses
else:
cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats)
cam_data['scores'] = yolo_head_outs[0]
if self.post_process is not None:
bbox, bbox_num, before_nms_indexes = self.post_process(
bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
cam_data['before_nms_indexes'] = before_nms_indexes
else:
bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
# data for cam
cam_data['before_nms_indexes'] = before_nms_indexes
output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
if self.use_extra_data:
extra_data = {} # record the bbox outputs before nms, such as scores and nms_keep_idx
"""extra_data:{
'scores': predict scores,
'nms_keep_idx': bbox index before nms,
}
"""
extra_data['scores'] = yolo_head_outs[0] # predict scores (probability)
extra_data['nms_keep_idx'] = nms_keep_idx
output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
else:
output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
......@@ -218,21 +227,29 @@ class PPYOLOEWithAuxHead(BaseArch):
aux_pred=[aux_cls_scores, aux_bbox_preds])
return loss
else:
cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats)
cam_data['scores'] = yolo_head_outs[0]
if self.post_process is not None:
bbox, bbox_num, before_nms_indexes = self.post_process(
bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
cam_data['before_nms_indexes'] = before_nms_indexes
else:
bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
# data for cam
cam_data['before_nms_indexes'] = before_nms_indexes
output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
if self.use_extra_data:
extra_data = {} # record the bbox outputs before nms, such as scores and nms_keep_idx
"""extra_data:{
'scores': predict scores,
'nms_keep_idx': bbox index before nms,
}
"""
extra_data['scores'] = yolo_head_outs[0] # predict scores (probability)
# Todo: get logits output
extra_data['nms_keep_idx'] = nms_keep_idx
output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
else:
output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import paddle
import paddle.nn.functional as F
__all__ = ['RetinaNet']
......@@ -57,8 +58,23 @@ class RetinaNet(BaseArch):
return self.head(neck_feats, self.inputs)
else:
head_outs = self.head(neck_feats)
bbox, bbox_num = self.head.post_process(
bbox, bbox_num, nms_keep_idx = self.head.post_process(
head_outs, self.inputs['im_shape'], self.inputs['scale_factor'])
if self.use_extra_data:
extra_data = {} # record the bbox outputs before nms, such as scores and nms_keep_idx
"""extra_data:{
'scores': predict scores,
'nms_keep_idx': bbox index before nms,
}
"""
preds_logits = self.head.decode_cls_logits(head_outs[0])
preds_scores = F.sigmoid(preds_logits)
extra_data['logits'] = preds_logits
extra_data['scores'] = preds_scores
extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
return {'bbox': bbox, 'bbox_num': bbox_num, "extra_data": extra_data}
else:
return {'bbox': bbox, 'bbox_num': bbox_num}
def get_loss(self):
......
......@@ -18,6 +18,8 @@ from __future__ import print_function
from ppdet.core.workspace import register, create
from .meta_arch import BaseArch
import paddle
import paddle.nn.functional as F
__all__ = ['SSD']
......@@ -75,15 +77,39 @@ class SSD(BaseArch):
self.inputs['gt_class'])
else:
preds, anchors = self.ssd_head(body_feats, self.inputs['image'])
bbox, bbox_num, before_nms_indexes = self.post_process(
bbox, bbox_num, nms_keep_idx = self.post_process(
preds, anchors, self.inputs['im_shape'],
self.inputs['scale_factor'])
if self.use_extra_data:
extra_data = {} # record the bbox outputs before nms, such as scores and nms_keep_idx
"""extra_data:{
'scores': predict scores,
'nms_keep_idx': bbox index before nms,
}
"""
preds_logits = preds[1] # [[1xNumBBoxNumClass]]
extra_data['scores'] = F.softmax(paddle.concat(
preds_logits, axis=1)).transpose([0, 2, 1])
extra_data['logits'] = paddle.concat(
preds_logits, axis=1).transpose([0, 2, 1])
extra_data['nms_keep_idx'] = nms_keep_idx # bbox index before nms
return bbox, bbox_num, extra_data
else:
return bbox, bbox_num
def get_loss(self, ):
return {"loss": self._forward()}
def get_pred(self):
if self.use_extra_data:
bbox_pred, bbox_num, extra_data = self._forward()
output = {
"bbox": bbox_pred,
"bbox_num": bbox_num,
"extra_data": extra_data
}
else:
bbox_pred, bbox_num = self._forward()
output = {
"bbox": bbox_pred,
......
......@@ -98,9 +98,7 @@ class YOLOv3(BaseArch):
return yolo_losses
else:
cam_data = {} # record bbox scores and index before nms
yolo_head_outs = self.yolo_head(neck_feats)
cam_data['scores'] = yolo_head_outs[0]
if self.for_mot:
# the detection part of JDE MOT model
......@@ -116,21 +114,32 @@ class YOLOv3(BaseArch):
else:
if self.return_idx:
# the detection part of JDE MOT model
_, bbox, bbox_num, _ = self.post_process(
_, bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors)
elif self.post_process is not None:
# anchor based YOLOs: YOLOv3,PP-YOLO,PP-YOLOv2 use mask_anchors
bbox, bbox_num, before_nms_indexes = self.post_process(
bbox, bbox_num, nms_keep_idx = self.post_process(
yolo_head_outs, self.yolo_head.mask_anchors,
self.inputs['im_shape'], self.inputs['scale_factor'])
cam_data['before_nms_indexes'] = before_nms_indexes
else:
# anchor free YOLOs: PP-YOLOE, PP-YOLOE+
bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
yolo_head_outs, self.inputs['scale_factor'])
# data for cam
cam_data['before_nms_indexes'] = before_nms_indexes
output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
if self.use_extra_data:
extra_data = {} # record the bbox outputs before nms, such as scores and nms_keep_idx
"""extra_data:{
'scores': predict scores,
'nms_keep_idx': bbox index before nms,
}
"""
extra_data['scores'] = yolo_head_outs[0] # predict scores (probability)
# Todo: get logits output
extra_data['nms_keep_idx'] = nms_keep_idx
# Todo support for mask_anchors yolo
output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
else:
output = {'bbox': bbox, 'bbox_num': bbox_num}
return output
......
......@@ -525,9 +525,9 @@ class PPYOLOEHead(nn.Layer):
# `exclude_nms=True` just use in benchmark
return pred_bboxes, pred_scores, None
else:
bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
bbox_pred, bbox_num, nms_keep_idx = self.nms(pred_bboxes,
pred_scores)
return bbox_pred, bbox_num, before_nms_indexes
return bbox_pred, bbox_num, nms_keep_idx
def get_activation(name="LeakyReLU"):
......
......@@ -420,6 +420,6 @@ class PPYOLOERHead(nn.Layer):
pred_bboxes /= scale_factor
if self.export_onnx:
return pred_bboxes, pred_scores, None
bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
bbox_pred, bbox_num, nms_keep_idx = self.nms(pred_bboxes,
pred_scores)
return bbox_pred, bbox_num, before_nms_indexes
return bbox_pred, bbox_num, nms_keep_idx
......@@ -245,5 +245,34 @@ class RetinaHead(nn.Layer):
bboxes, scores = self.decode(anchors, cls_logits, bboxes_reg, im_shape,
scale_factor)
bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
return bbox_pred, bbox_num
bbox_pred, bbox_num, nms_keep_idx = self.nms(bboxes, scores)
return bbox_pred, bbox_num, nms_keep_idx
def get_scores_single(self, cls_scores_list):
mlvl_logits = []
for cls_score in cls_scores_list:
cls_score = cls_score.reshape([-1, self.num_classes])
if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
max_score = cls_score.max(axis=1)
_, topk_inds = max_score.topk(self.nms_pre)
cls_score = cls_score.gather(topk_inds)
mlvl_logits.append(cls_score)
mlvl_logits = paddle.concat(mlvl_logits)
mlvl_logits = mlvl_logits.transpose([1, 0])
return mlvl_logits
def decode_cls_logits(self, cls_logits_list):
cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits_list]
batch_logits = []
for img_id in range(cls_logits[0].shape[0]):
num_lvls = len(cls_logits)
cls_scores_list = [cls_logits[i][img_id] for i in range(num_lvls)]
logits = self.get_scores_single(cls_scores_list)
batch_logits.append(logits)
batch_logits = paddle.stack(batch_logits, axis=0)
return batch_logits
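The two helpers above flatten each level's class logits per image, optionally keep only the `nms_pre` highest-scoring rows per level, and stack the result into a `[N, num_classes, total_boxes]` tensor. The NumPy sketch below mirrors that flow to make the shapes explicit. It assumes each level arrives as `[N, A * num_classes, H, W]`, and `decode_cls_logits_sketch` is a hypothetical stand-in rather than the Paddle code itself.

```python
import numpy as np


def decode_cls_logits_sketch(cls_logits_list, num_classes, nms_pre=None):
    """cls_logits_list: per-level arrays shaped [N, A * num_classes, H, W] (assumed)."""
    batch = []
    num_images = cls_logits_list[0].shape[0]
    for img_id in range(num_images):
        per_level = []
        for level in cls_logits_list:
            # [A*C, H, W] -> [H, W, A*C] -> one row of C logits per anchor.
            x = level[img_id].transpose(1, 2, 0).reshape(-1, num_classes)
            if nms_pre is not None and x.shape[0] > nms_pre:
                # Keep the nms_pre rows whose best class logit is largest.
                top = np.argsort(-x.max(axis=1))[:nms_pre]
                x = x[top]
            per_level.append(x)
        # [total_boxes, C] -> [C, total_boxes]
        batch.append(np.concatenate(per_level).T)
    return np.stack(batch, axis=0)
```

Stacking works because every image keeps the same number of rows per level, so the per-image arrays share one shape.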
......@@ -15,6 +15,7 @@
import paddle
from ppdet.core.workspace import register
from ppdet.modeling import ops
import paddle.nn as nn
def _to_list(v):
......@@ -24,7 +25,7 @@ def _to_list(v):
@register
class RoIAlign(object):
class RoIAlign(nn.Layer):
"""
RoI Align module
......@@ -73,7 +74,7 @@ class RoIAlign(object):
def from_config(cls, cfg, input_shape):
return {'spatial_scale': [1. / i.stride for i in input_shape]}
def __call__(self, feats, roi, rois_num):
def forward(self, feats, roi, rois_num):
roi = paddle.concat(roi) if len(roi) > 1 else roi[0]
if len(feats) == 1:
rois_feat = paddle.vision.ops.roi_align(
......
......@@ -4,6 +4,7 @@ import os
import sys
import glob
from ppdet.utils.logger import setup_logger
import copy
logger = setup_logger('ppdet_cam')
import paddle
......@@ -119,7 +120,12 @@ class BBoxCAM:
# num_class
self.num_class = cfg.num_classes
# set hook for extraction of featuremaps and grads
self.set_hook()
self.set_hook(cfg)
self.nms_idx_need_divid_numclass_arch = ['FasterRCNN', 'MaskRCNN', 'CascadeRCNN']
"""
In these networks, the bbox array before nms has a num_class dimension,
so the nms_keep_idx of each bbox needs to be integer-divided by num_class.
"""
# cam image output_dir
try:
......@@ -134,9 +140,10 @@ class BBoxCAM:
# load weights
trainer.load_weights(cfg.weights)
# record the bbox index before nms
# Todo: hard code for nms return index
if cfg.architecture == 'FasterRCNN':
# set for get extra_data before nms
trainer.model.use_extra_data=True
# set for record the bbox index before nms
if cfg.architecture == 'FasterRCNN' or cfg.architecture == 'MaskRCNN':
trainer.model.bbox_post_process.nms.return_index = True
elif cfg.architecture == 'YOLOv3':
if trainer.model.post_process is not None:
......@@ -145,24 +152,39 @@ class BBoxCAM:
else:
# anchor free YOLOs: PP-YOLOE, PP-YOLOE+
trainer.model.yolo_head.nms.return_index = True
elif cfg.architecture=='BlazeFace' or cfg.architecture=='SSD':
trainer.model.post_process.nms.return_index = True
elif cfg.architecture=='RetinaNet':
trainer.model.head.nms.return_index = True
else:
print(
'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
cfg.architecture+' is not supported for cam temporarily!'
)
sys.exit()
# Todo: Unify the head/post_process name in each model
return trainer
def set_hook(self):
def set_hook(self, cfg):
# set hook for extraction of featuremaps and grads
self.target_feats = {}
self.target_layer_name = 'trainer.model.backbone'
self.target_layer_name = cfg.target_feature_layer_name
# e.g. model.backbone or model.bbox_head.roi_extractor (resolved relative to self.trainer)
def hook(layer, input, output):
self.target_feats[layer._layer_name_for_hook] = output
self.trainer.model.backbone._layer_name_for_hook = self.target_layer_name
self.trainer.model.backbone.register_forward_post_hook(hook)
try:
exec('self.trainer.'+self.target_layer_name+'._layer_name_for_hook = self.target_layer_name')
# self.trainer.target_layer_name._layer_name_for_hook = self.target_layer_name
exec('self.trainer.'+self.target_layer_name+'.register_forward_post_hook(hook)')
# self.trainer.target_layer_name.register_forward_post_hook(hook)
except:
print("Error! "
"The target_layer_name--"+self.target_layer_name+" is not in model! "
"Please check the spelling and "
"the network's architecture!")
sys.exit()
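The hook registration above resolves the dotted `target_feature_layer_name` by building a statement string and passing it to `exec`. For comparison only, the same lookup can be written as a chain of `getattr` calls, which avoids evaluating a user-supplied string as code. This is an alternative sketch, not what the commit does. `resolve_layer` is a hypothetical helper, and the `SimpleNamespace` objects merely stand in for the trainer and its model.

```python
from functools import reduce
from types import SimpleNamespace


def resolve_layer(root, dotted_name):
    """Follow a dotted attribute path such as 'model.backbone' from root."""
    # reduce applies getattr one path component at a time and raises
    # AttributeError if any component is missing.
    return reduce(getattr, dotted_name.split("."), root)


# Stand-in objects for illustration; a real trainer holds Paddle layers.
trainer = SimpleNamespace(model=SimpleNamespace(backbone=SimpleNamespace()))
layer = resolve_layer(trainer, "model.backbone")
layer._layer_name_for_hook = "model.backbone"
```

On a real layer, the returned object would then receive `register_forward_post_hook(hook)`, and a misspelled name surfaces as an `AttributeError` that can be caught specifically.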
def get_bboxes(self):
# get inference images
......@@ -187,31 +209,27 @@ class BBoxCAM:
img = np.array(Image.open(self.cfg.infer_img))
# data for calculating bbox grad_cam
cam_data = inference_result['cam_data']
extra_data = inference_result['extra_data']
"""
if Faster_RCNN based architecture:
cam_data: {'scores': tensor with shape [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
'before_nms_indexes': tensor with shape [num_of_bboxes_after_nms, 1], for example: [300, 1]
Example of Faster_RCNN based architecture:
extra_data: {'scores': tensor with shape [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
'nms_keep_idx': tensor with shape [num_of_bboxes_after_nms, 1], for example: [300, 1]
}
elif YOLOv3 based architecture:
cam_data: {'scores': tensor with shape [1, num_classes, num_of_yolo_bboxes_before_nms], #for example: [1, 80, 8400]
'before_nms_indexes': tensor with shape [num_of_yolo_bboxes_after_nms, 1], # for example: [300, 1]
Example of YOLOv3 based architecture:
extra_data: {'scores': tensor with shape [1, num_classes, num_of_yolo_bboxes_before_nms], #for example: [1, 80, 8400]
'nms_keep_idx': tensor with shape [num_of_yolo_bboxes_after_nms, 1], # for example: [300, 1]
}
"""
# array index of the predicted bbox before nms
if self.cfg.architecture == 'FasterRCNN':
# the bbox array shape of FasterRCNN before nms is [num_of_bboxes_before_nms, num_classes, 4],
if self.cfg.architecture in self.nms_idx_need_divid_numclass_arch:
# in some networks the bbox array before nms has shape [num_of_bboxes_before_nms, num_classes, 4],
# so we integer-divide the index by num_classes to get the before_nms_index;
before_nms_indexes = cam_data['before_nms_indexes'].cpu().numpy(
# currently this covers only the rcnn architectures (fasterrcnn, maskrcnn, cascadercnn);
before_nms_indexes = extra_data['nms_keep_idx'].cpu().numpy(
) // self.num_class # num_class
elif self.cfg.architecture == 'YOLOv3':
before_nms_indexes = cam_data['before_nms_indexes'].cpu().numpy()
else:
print(
'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
)
sys.exit()
else :
before_nms_indexes = extra_data['nms_keep_idx'].cpu().numpy()
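The integer division above can be read as undoing a flattening: if the pre-NMS boxes are laid out as `[num_boxes, num_classes, 4]` and NMS reports indices into the flattened `[num_boxes * num_classes]` view, then dividing by `num_classes` recovers the box (RoI) index. The sketch below illustrates this under that row-major assumption, which the diff implies but does not state outright. `keep_idx_to_roi_idx` is a name invented for the example.

```python
import numpy as np


def keep_idx_to_roi_idx(nms_keep_idx, num_classes):
    """Map flattened (roi, class) indices back to roi indices.

    Assumes flat_index = roi_index * num_classes + class_index.
    """
    flat = np.asarray(nms_keep_idx).reshape(-1)
    return flat // num_classes


# Three kept boxes with 80 classes: flat indices 0, 81 and 165.
roi_idx = keep_idx_to_roi_idx([[0], [81], [165]], num_classes=80)
```

Architectures whose pre-NMS boxes carry no class dimension skip the division and use `nms_keep_idx` directly, which is what the `else` branch does.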
# Calculate and visualize the heatmap of per predict bbox
for index, target_bbox in enumerate(inference_result['bbox']):
......@@ -222,20 +240,16 @@ class BBoxCAM:
target_bbox_before_nms = int(before_nms_indexes[index])
# bbox score vector
if self.cfg.architecture == 'FasterRCNN':
# the shape of faster_rcnn scores tensor is
# [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
score_out = cam_data['scores'][target_bbox_before_nms]
elif self.cfg.architecture == 'YOLOv3':
# the shape of yolov3 scores tensor is
# [1, num_classes, num_of_yolo_bboxes_before_nms]
score_out = cam_data['scores'][0, :, target_bbox_before_nms]
if len(extra_data['scores'].shape)==2:
score_out = extra_data['scores'][target_bbox_before_nms]
else:
print(
'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
)
sys.exit()
score_out = extra_data['scores'][0, :, target_bbox_before_nms]
"""
The bbox score output has one of two array shapes:
1) [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
2) [num_of_image, num_classes, num_of_yolo_bboxes_before_nms], for example: [1, 80, 1000]
"""
# construct one_hot label and do backward to get the gradients
predicted_label = paddle.argmax(score_out)
......@@ -245,34 +259,76 @@ class BBoxCAM:
target = paddle.sum(score_out * label_onehot)
target.backward(retain_graph=True)
if 'backbone' in self.target_layer_name or \
'neck' in self.target_layer_name: # backbone/neck level feature
if isinstance(self.target_feats[self.target_layer_name], list):
# when the backbone output contains features of multiple scales,
# when the featuremap contains multiple scales,
# take the featuremap of the last scale
# Todo: fuse the cam result from multiscale featuremaps
backbone_grad = self.target_feats[self.target_layer_name][
if self.target_feats[self.target_layer_name][
-1].shape[-1]==1:
"""
if the last level featuremap is 1x1 size,
we take the second last one
"""
cam_grad = self.target_feats[self.target_layer_name][
-2].grad.squeeze().cpu().numpy()
cam_feat = self.target_feats[self.target_layer_name][
-2].squeeze().cpu().numpy()
else:
cam_grad = self.target_feats[self.target_layer_name][
-1].grad.squeeze().cpu().numpy()
backbone_feat = self.target_feats[self.target_layer_name][
cam_feat = self.target_feats[self.target_layer_name][
-1].squeeze().cpu().numpy()
else:
backbone_grad = self.target_feats[
cam_grad = self.target_feats[
self.target_layer_name].grad.squeeze().cpu().numpy()
backbone_feat = self.target_feats[
cam_feat = self.target_feats[
self.target_layer_name].squeeze().cpu().numpy()
else: # roi level feature
cam_grad = self.target_feats[
self.target_layer_name].grad.squeeze().cpu().numpy()[target_bbox_before_nms]
cam_feat = self.target_feats[
self.target_layer_name].squeeze().cpu().numpy()[target_bbox_before_nms]
# grad_cam:
exp = grad_cam(backbone_feat, backbone_grad)
exp = grad_cam(cam_feat, cam_grad)
if 'backbone' in self.target_layer_name or \
'neck' in self.target_layer_name:
"""
when using a backbone/neck featuremap,
we first compute the cam on the whole image,
and then set the area outside the predicted bbox to 0
"""
# reshape the cam image to the input image size
resized_exp = resize_cam(exp, (img.shape[1], img.shape[0]))
# set the area outside the predicted bbox to 0
mask = np.zeros((img.shape[0], img.shape[1], 3))
mask[int(target_bbox[3]):int(target_bbox[5]), int(target_bbox[2]):
int(target_bbox[4]), :] = 1
resized_exp = resized_exp * mask
# add the bbox cam back to the input image
overlay_vis = np.uint8(resized_exp * 0.4 + img * 0.6)
elif 'roi' in self.target_layer_name:
# get the bbox part of the image
bbox_img = copy.deepcopy(img[int(target_bbox[3]):int(target_bbox[5]),
int(target_bbox[2]):int(target_bbox[4]), :])
# reshape the cam image to the bbox size
resized_exp = resize_cam(exp, (bbox_img.shape[1], bbox_img.shape[0]))
# add the bbox cam back to the bbox image
bbox_overlay_vis = np.uint8(resized_exp * 0.4 + bbox_img * 0.6)
# put the bbox_cam image to the original image
overlay_vis = copy.deepcopy(img)
overlay_vis[int(target_bbox[3]):int(target_bbox[5]),
int(target_bbox[2]):int(target_bbox[4]), :] = bbox_overlay_vis
else:
print(
'CAM is currently supported only for backbone/neck features and roi features; other layers are not supported yet!'
)
sys.exit()
# put the bbox rectangle on image
cv2.rectangle(
overlay_vis, (int(target_bbox[2]), int(target_bbox[3])),
(int(target_bbox[4]), int(target_bbox[5])), (0, 0, 255), 2)
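The bodies of `grad_cam` and `resize_cam` are not part of this diff. As a rough illustration of what the backbone/neck branch does, here is a NumPy-only sketch of the standard Grad-CAM weighting followed by the bbox masking step. All three function names are hypothetical, the nearest-neighbour resize is only a stand-in for `resize_cam`, and the bbox is assumed to be `(x1, y1, x2, y2)`, which corresponds to `target_bbox[2:6]` in the code above:

```python
import numpy as np


def grad_cam_sketch(feat, grad):
    """Standard Grad-CAM on one feature map; feat and grad are [C, H, W]."""
    weights = grad.mean(axis=(1, 2))                 # one weight per channel
    cam = np.tensordot(weights, feat, axes=(0, 0))   # weighted sum -> [H, W]
    cam = np.maximum(cam, 0)                         # keep positive evidence
    peak = cam.max()
    return cam / peak if peak > 0 else cam           # scale to [0, 1]


def resize_nearest(cam, width, height):
    """Nearest-neighbour resize; a stand-in for the repo's resize_cam."""
    rows = np.arange(height) * cam.shape[0] // height
    cols = np.arange(width) * cam.shape[1] // width
    return cam[rows][:, cols]


def mask_to_bbox(cam, bbox):
    """Zero the map outside bbox = (x1, y1, x2, y2)."""
    x1, y1, x2, y2 = (int(v) for v in bbox)
    mask = np.zeros_like(cam)
    mask[y1:y2, x1:x2] = 1
    return cam * mask
```

For the roi branch no mask is needed: the map is computed from that box's own pooled feature, resized to the bbox size, and pasted back into the image region.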
......
......@@ -58,34 +58,16 @@ def parse_args():
default=False,
help="Whether to save inference results to output_dir.")
parser.add_argument(
"--slice_infer",
action='store_true',
help="Whether to slice the image and merge the inference results for small object detection."
)
parser.add_argument(
'--slice_size',
nargs='+',
type=int,
default=[640, 640],
help="Height of the sliced image.")
parser.add_argument(
"--overlap_ratio",
nargs='+',
type=float,
default=[0.25, 0.25],
help="Overlap height ratio of the sliced image.")
parser.add_argument(
"--combine_method",
"--target_feature_layer_name",
type=str,
default='nms',
help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
)
default='model.backbone', # define the featuremap to show grad cam, such as model.backbone, model.bbox_head.roi_extractor
help="Name of the feature map used to compute the grad cam, such as model.backbone or model.bbox_head.roi_extractor.")
args = parser.parse_args()
return args
def run(FLAGS, cfg):
assert cfg.architecture in ['FasterRCNN', 'YOLOv3'], 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
assert cfg.architecture in ['FasterRCNN', 'MaskRCNN', 'YOLOv3', 'BlazeFace', 'SSD', 'RetinaNet'], 'CAM is currently supported only for FasterRCNN, MaskRCNN, YOLOv3, BlazeFace, SSD and RetinaNet architectures; the others are not supported yet!'
bbox_cam = BBoxCAM(FLAGS, cfg)
bbox_cam.get_bboxes_cams()
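A standalone sketch of how the new flag could be parsed and mapped to the two feature levels handled in `BBoxCAM`. It uses only `argparse` and does not import PaddleDetection; `build_parser` and `feature_level` are hypothetical names, and the substring checks mirror the `'backbone'`, `'neck'` and `'roi'` tests in the diff:

```python
import argparse


def build_parser():
    """Minimal parser exposing only the flag introduced by this change."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target_feature_layer_name",
        type=str,
        default="model.backbone",
        help="Feature map used for Grad-CAM, e.g. model.backbone or "
        "model.bbox_head.roi_extractor.")
    return parser


def feature_level(layer_name):
    """Classify a layer name the way the CAM code branches on it."""
    if "backbone" in layer_name or "neck" in layer_name:
        return "backbone/neck"
    if "roi" in layer_name:
        return "roi"
    raise ValueError(
        "unsupported target feature layer: {}".format(layer_name))
```

Raising an exception here, rather than calling `sys.exit()` as the diff does, is a choice made for this sketch so that a caller can catch the unsupported case.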
......