Unverified commit ffce1218, authored by xs1997zju, committed by GitHub

Support GradCAM (#7626)

* fix slice infer one image save_results (#7654)

* Support GradCAM Cascade_rcnn forward bugfix

* code style fix

* BBoxCAM class name fix

* Add gradcam tutorial and demo

---------
Co-authored-by: NFeng Ni <nemonameless@qq.com>
Parent d56cf3f7
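# Object detection heatmap
## 1. Introduction
Computes the CAM (class activation map) of each predicted object box from the backbone feature map.
## 2. Usage
* Taking PP-YOLOE as an example: after preparing the data, specify the network config file, the model weights path, the image path, and the output folder path, then run tools/cam_ppdet.py to compute the grad_cam heatmap of each predicted box in the image. An example command follows.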
```shell
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
* **Arguments**
| FLAG | Purpose |
| :----------------------: |:-----------------------------------------------------------------------------------------------------:|
| -c | Config file to use |
| --infer_img | Path of the image to run prediction on |
| --cam_out | Output directory |
| -o | Set or override parameters in the config file, e.g. -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
* Result
<center>
<img src="../images/grad_cam_ppyoloe_demo.jpg" width="500" >
</center>
<br><center>cam_ppyoloe/225.jpg</center></br>
## 3. Currently supported: networks based on FasterRCNN and the YOLOv3 series
* FasterRCNN heatmap visualization script
```bash
python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
```
* PPYOLOE heatmap visualization script
```bash
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
\ No newline at end of file
# Object detection grad_cam heatmap
## 1. Introduction
Computes the class activation map (CAM) of each predicted bounding box from the backbone feature map.
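For orientation, the textbook Grad-CAM formulation is sketched below. This is the general definition, not a transcription of tools/cam_ppdet.py, which may differ in details such as how channels are reduced. The notation is assumed for illustration: $A^k$ is channel $k$ of the backbone feature map with spatial size $H \times W$, and $y^c$ is the score of class $c$ for one predicted box.

$$
\alpha_k^c = \frac{1}{HW}\sum_{i=1}^{H}\sum_{j=1}^{W}\frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c = \mathrm{ReLU}\Big(\sum_k \alpha_k^c A^k\Big)
$$

The map $L^c$ has the spatial size of the feature map and is resized to the input image before being overlaid.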
## 2. Usage
* Taking PP-YOLOE as an example: after preparing the data, specify the network configuration file, the model weights path, the image path, and the output folder, then run tools/cam_ppdet.py to compute the grad_cam heatmap of each predicted box. An example command follows.
```shell
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
* **Arguments**
| FLAG | Description |
| :----------------------: |:---------------------------------------------------------------------------------------------------------------------------------:|
| -c | Select config file |
| --infer_img | Image path |
| --cam_out | Directory for output |
| -o | Set or override parameters in the config file, for example: -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
* Result
<center>
<img src="../images/grad_cam_ppyoloe_demo.jpg" width="500" >
</center>
<br><center>cam_ppyoloe/225.jpg</center></br>
## 3. Currently supported: networks based on FasterRCNN and the YOLOv3 series
* FasterRCNN bbox heatmap visualization script
```bash
python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
```
* PPYOLOE bbox heatmap visualization script
```bash
python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
```
\ No newline at end of file
......@@ -1039,7 +1039,7 @@ class Trainer(object):
image.save(save_name, quality=95)
start = end
return results
def _get_save_image_name(self, output_dir, image_path):
"""
Get save image name from source image path.
......
......@@ -108,7 +108,7 @@ class CascadeRCNN(BaseArch):
             im_shape = self.inputs['im_shape']
             scale_factor = self.inputs['scale_factor']
-            bbox, bbox_num = self.bbox_post_process(
+            bbox, bbox_num, before_nms_indexes = self.bbox_post_process(
                 preds, (refined_rois, rois_num), im_shape, scale_factor)
             # rescale the prediction back to origin image
             bbox, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
......
......@@ -82,18 +82,21 @@ class FasterRCNN(BaseArch):
                                           self.inputs)
             return rpn_loss, bbox_loss
         else:
+            cam_data = {}  # record bbox scores and index before nms
             rois, rois_num, _ = self.rpn_head(body_feats, self.inputs)
             preds, _ = self.bbox_head(body_feats, rois, rois_num, None)
+            cam_data['scores'] = preds[1]
             im_shape = self.inputs['im_shape']
             scale_factor = self.inputs['scale_factor']
-            bbox, bbox_num = self.bbox_post_process(preds, (rois, rois_num),
+            bbox, bbox_num, before_nms_indexes = self.bbox_post_process(preds, (rois, rois_num),
                                                     im_shape, scale_factor)
+            cam_data['before_nms_indexes'] = before_nms_indexes  # bbox index before nms, for cam
             # rescale the prediction back to origin image
             bboxes, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
                 bbox, bbox_num, im_shape, scale_factor)
-            return bbox_pred, bbox_num
+            return bbox_pred, bbox_num, cam_data
     def get_loss(self, ):
         rpn_loss, bbox_loss = self._forward()
......@@ -105,8 +108,8 @@ class FasterRCNN(BaseArch):
         return loss
     def get_pred(self):
-        bbox_pred, bbox_num = self._forward()
-        output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
+        bbox_pred, bbox_num, cam_data = self._forward()
+        output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'cam_data': cam_data}
         return output
     def target_bbox_forward(self, data):
......
......@@ -98,7 +98,9 @@ class YOLOv3(BaseArch):
             return yolo_losses
         else:
+            cam_data = {}  # record bbox scores and index before nms
             yolo_head_outs = self.yolo_head(neck_feats)
+            cam_data['scores'] = yolo_head_outs[0]
             if self.for_mot:
                 # the detection part of JDE MOT model
......@@ -118,14 +120,17 @@ class YOLOv3(BaseArch):
                         yolo_head_outs, self.yolo_head.mask_anchors)
                 elif self.post_process is not None:
                     # anchor based YOLOs: YOLOv3,PP-YOLO,PP-YOLOv2 use mask_anchors
-                    bbox, bbox_num = self.post_process(
+                    bbox, bbox_num, before_nms_indexes = self.post_process(
                         yolo_head_outs, self.yolo_head.mask_anchors,
                         self.inputs['im_shape'], self.inputs['scale_factor'])
+                    cam_data['before_nms_indexes'] = before_nms_indexes
                 else:
                     # anchor free YOLOs: PP-YOLOE, PP-YOLOE+
-                    bbox, bbox_num = self.yolo_head.post_process(
+                    bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
                         yolo_head_outs, self.inputs['scale_factor'])
-                output = {'bbox': bbox, 'bbox_num': bbox_num}
+                    # data for cam
+                    cam_data['before_nms_indexes'] = before_nms_indexes
+                output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
             return output
......
......@@ -462,8 +462,8 @@ class PPYOLOEHead(nn.Layer):
                 # `exclude_nms=True` just use in benchmark
                 return pred_bboxes, pred_scores
             else:
-                bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
-                return bbox_pred, bbox_num
+                bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes, pred_scores)
+                return bbox_pred, bbox_num, before_nms_indexes
 def get_activation(name="LeakyReLU"):
......
......@@ -67,7 +67,7 @@ class BBoxPostProcess(object):
         """
         if self.nms is not None:
             bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
-            bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
+            bbox_pred, bbox_num, before_nms_indexes = self.nms(bboxes, score, self.num_classes)
         else:
             bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
......@@ -82,6 +82,9 @@ class BBoxPostProcess(object):
             bbox_pred = paddle.concat([bbox_pred, fake_bboxes])
             bbox_num = bbox_num + 1
+        if self.nms is not None:
+            return bbox_pred, bbox_num, before_nms_indexes
+        else:
             return bbox_pred, bbox_num
     def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
......
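The hunks above thread `before_nms_indexes` out of NMS so that each kept box can be traced back to its row in the pre-NMS score tensor. As a rough illustration of that idea only, here is a minimal NumPy sketch of greedy single-class NMS that also returns the original indices. The names `iou_one_to_many` and `nms_with_index` are hypothetical, and this is not PaddleDetection's multi-class NMS operator.

```python
import numpy as np


def iou_one_to_many(box, boxes):
    """IoU between one box and an (M, 4) array, format (xmin, ymin, xmax, ymax)."""
    lu = np.maximum(box[:2], boxes[:, :2])
    rd = np.minimum(box[2:], boxes[:, 2:])
    wh = np.maximum(0.0, rd - lu)
    inter = wh[:, 0] * wh[:, 1]
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / np.maximum(area + areas - inter, 1e-8)


def nms_with_index(boxes, scores, iou_threshold=0.5):
    """Greedy NMS that also returns the pre-NMS index of every kept box."""
    order = np.argsort(-scores)  # highest score first
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        ious = iou_one_to_many(boxes[best], boxes[rest])
        order = rest[ious <= iou_threshold]  # drop boxes overlapping the kept one
    keep = np.array(keep, dtype=np.int64)
    return boxes[keep], scores[keep], keep


boxes = np.array(
    [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]], dtype=np.float32)
scores = np.array([0.6, 0.9, 0.8], dtype=np.float32)
kept_boxes, kept_scores, before_nms_indexes = nms_with_index(boxes, scores)
```

In this toy input the first box overlaps the higher-scoring second box and is suppressed, so the returned index array points at the second and third input rows.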
import numpy as np
import cv2
import os
import sys
import glob
from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet_cam')
import paddle
from ppdet.engine import Trainer
def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)
    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]
    images = set()
    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)
    assert len(images) > 0, "no image found in {}".format(infer_dir)
    logger.info("Found {} inference images in total.".format(len(images)))
    return images
def compute_ious(boxes1, boxes2):
    """Compute the pairwise IoU matrix for two sets of boxes.
    Args:
        boxes1 (numpy ndarray with shape (N, 4)): bounding boxes in (xmin, ymin, xmax, ymax) format
        boxes2 (numpy ndarray with shape (M, 4)): bounding boxes in (xmin, ymin, xmax, ymax) format
    Returns:
        pairwise IoU matrix with shape (N, M), where the value at row i, column j holds the IoU
        between the i-th box of boxes1 and the j-th box of boxes2.
    """
    # boxes1[:, None, :2] has shape (N, 1, 2) and boxes2[:, :2] has shape (M, 2),
    # so lu (the top-left corner of the intersection) has shape (N, M, 2)
    lu = np.maximum(boxes1[:, None, :2], boxes2[:, :2])
    # rd (the bottom-right corner of the intersection) has the same shape as lu
    rd = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:])
    intersection_wh = np.maximum(0.0, rd - lu)
    intersection_area = intersection_wh[:, :, 0] * intersection_wh[:, :, 1]  # shape (N, M)
    boxes1_wh = np.maximum(0.0, boxes1[:, 2:] - boxes1[:, :2])
    boxes1_area = boxes1_wh[:, 0] * boxes1_wh[:, 1]  # shape (N,)
    boxes2_wh = np.maximum(0.0, boxes2[:, 2:] - boxes2[:, :2])
    boxes2_area = boxes2_wh[:, 0] * boxes2_wh[:, 1]  # shape (M,)
    union_area = np.maximum(
        boxes1_area[:, None] + boxes2_area - intersection_area,
        1e-8)  # shape (N, M)
    ious = np.clip(intersection_area / union_area, 0.0, 1.0)
    return ious
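To make the broadcasting in `compute_ious` concrete, the following sketch repeats the same steps inline on three hand-checkable boxes. The input values are made up for illustration, and the inline code mirrors the function above rather than importing it.

```python
import numpy as np

boxes1 = np.array([[0.0, 0.0, 2.0, 2.0]])          # N = 1
boxes2 = np.array([[1.0, 1.0, 3.0, 3.0],
                   [0.0, 0.0, 2.0, 2.0],
                   [5.0, 5.0, 6.0, 6.0]])          # M = 3

# (N, 1, 2) against (M, 2) broadcasts to (N, M, 2)
lu = np.maximum(boxes1[:, None, :2], boxes2[:, :2])
rd = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:])
wh = np.maximum(0.0, rd - lu)                      # negative extents mean no overlap
intersection = wh[..., 0] * wh[..., 1]             # (N, M)

area1 = np.prod(boxes1[:, 2:] - boxes1[:, :2], axis=1)  # (N,)
area2 = np.prod(boxes2[:, 2:] - boxes2[:, :2], axis=1)  # (M,)
union = area1[:, None] + area2 - intersection
ious = intersection / np.maximum(union, 1e-8)
```

The three columns cover a partial overlap (intersection 1, union 7), an identical box, and a disjoint box.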
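def grad_cam(feat, grad):
    """
    Args:
        feat: CxHxW
        grad: CxHxW
    Returns:
        cam: HxW
    """
    # weight each channel by its spatially averaged gradient, then average over channels
    exp = (feat * grad.mean((1, 2), keepdims=True)).mean(axis=0)
    # NOTE: the map is negated before rectification here; textbook Grad-CAM
    # applies ReLU to the un-negated weighted sum.
    exp = np.maximum(-exp, 0)
    return exp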
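For comparison with `grad_cam` above, here is a minimal NumPy sketch of the textbook Grad-CAM reduction on a tiny hand-checkable input. `grad_cam_standard` is a hypothetical name. Unlike the function above, it sums over channels and rectifies the un-negated map, so it is a reference point and not a drop-in replacement.

```python
import numpy as np


def grad_cam_standard(feat, grad):
    """Textbook Grad-CAM: feat and grad are (C, H, W); returns an (H, W) map."""
    weights = grad.mean(axis=(1, 2), keepdims=True)  # (C, 1, 1), one weight per channel
    cam = (feat * weights).sum(axis=0)               # weighted sum over channels
    return np.maximum(cam, 0)                        # ReLU keeps positive evidence only


feat = np.array([[[1.0, 2.0], [3.0, 4.0]],
                 [[1.0, 1.0], [1.0, 1.0]]])          # C = 2, H = W = 2
grad = np.array([[[0.5, 0.5], [0.5, 0.5]],
                 [[-1.0, -1.0], [-1.0, -1.0]]])
cam = grad_cam_standard(feat, grad)
```

With channel weights 0.5 and -1, the weighted sum is `[[-0.5, 0], [0.5, 1]]` before rectification, and the ReLU zeroes the negative entry.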
def resize_cam(explanation, resize_shape) -> np.ndarray:
    """
    Args:
        explanation: 2D array with shape (height, width)
        resize_shape: (width, height), the order expected by cv2.resize
    Returns:
        RGB uint8 heatmap with shape (height, width, 3)
    """
    assert len(explanation.shape) == 2, f"{explanation.shape}. " \
        f"Currently support 2D explanation results for visualization. " \
        "Reduce higher dimensions to 2D for visualization."
    # min-max normalize to [0, 1]
    explanation = (explanation - explanation.min()) / (
        explanation.max() - explanation.min())
    explanation = cv2.resize(explanation, resize_shape)
    explanation = np.uint8(255 * explanation)
    explanation = cv2.applyColorMap(explanation, cv2.COLORMAP_JET)
    explanation = cv2.cvtColor(explanation, cv2.COLOR_BGR2RGB)
    return explanation
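The min-max normalization in `resize_cam` divides by `max - min`, which is zero for a constant map. One possible guard is sketched below in plain NumPy. `normalize_to_uint8` and its `eps` parameter are hypothetical and not part of this commit, and returning an all-zero map for constant input is just one reasonable choice.

```python
import numpy as np


def normalize_to_uint8(cam, eps=1e-8):
    """Min-max scale a 2D map to 0..255, returning zeros if the map is constant."""
    cam = cam.astype(np.float64)
    span = cam.max() - cam.min()
    if span < eps:
        # constant map: nothing to highlight, avoid dividing by zero
        return np.zeros(cam.shape, dtype=np.uint8)
    scaled = (cam - cam.min()) / span
    return np.uint8(255 * scaled)  # truncates toward zero


flat = normalize_to_uint8(np.full((2, 2), 3.0))
ramp = normalize_to_uint8(np.array([[0.0, 0.5], [1.0, 2.0]]))
```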
class BBoxCAM:
    def __init__(self, FLAGS, cfg):
        self.FLAGS = FLAGS
        self.cfg = cfg
        # build model
        self.trainer = self.build_trainer(cfg)
        # num_class
        self.num_class = cfg.num_classes
        # set hook for extraction of featuremaps and grads
        self.set_hook()
        # cam image output_dir; exist_ok avoids failing when it already exists
        os.makedirs(FLAGS.cam_out, exist_ok=True)
    def build_trainer(self, cfg):
        # build trainer
        trainer = Trainer(cfg, mode='test')
        # load weights
        trainer.load_weights(cfg.weights)
        # record the bbox index before nms
        # Todo: hard code for nms return index
        if cfg.architecture == 'FasterRCNN':
            trainer.model.bbox_post_process.nms.return_index = True
        elif cfg.architecture == 'YOLOv3':
            if trainer.model.post_process is not None:
                # anchor based YOLOs: YOLOv3,PP-YOLO
                trainer.model.post_process.nms.return_index = True
            else:
                # anchor free YOLOs: PP-YOLOE, PP-YOLOE+
                trainer.model.yolo_head.nms.return_index = True
        else:
            print(
                'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
            )
            sys.exit()
        return trainer
    def set_hook(self):
        # set hook for extraction of featuremaps and grads
        self.target_feats = {}
        self.target_layer_name = 'trainer.model.backbone'

        def hook(layer, input, output):
            self.target_feats[layer._layer_name_for_hook] = output

        self.trainer.model.backbone._layer_name_for_hook = self.target_layer_name
        self.trainer.model.backbone.register_forward_post_hook(hook)
    def get_bboxes(self):
        # get inference images
        images = get_test_images(self.FLAGS.infer_dir, self.FLAGS.infer_img)
        # inference
        result = self.trainer.predict(
            images,
            draw_threshold=self.FLAGS.draw_threshold,
            output_dir=self.FLAGS.output_dir,
            save_results=self.FLAGS.save_results,
            visualize=False)[0]
        return result
    def get_bboxes_cams(self):
        # Get the bbox predictions (after NMS) of the input
        inference_result = self.get_bboxes()
        # read input image
        # Todo: Support folder multi-images process
        from PIL import Image
        img = np.array(Image.open(self.cfg.infer_img))
        # data for calculating the bbox grad_cam
        cam_data = inference_result['cam_data']
        """
        if Faster_RCNN based architecture:
            cam_data: {'scores': tensor with shape [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
                       'before_nms_indexes': tensor with shape [num_of_bboxes_after_nms, 1], for example: [300, 1]
                      }
        elif YOLOv3 based architecture:
            cam_data: {'scores': tensor with shape [1, num_classes, num_of_yolo_bboxes_before_nms], for example: [1, 80, 8400]
                       'before_nms_indexes': tensor with shape [num_of_yolo_bboxes_after_nms, 1], for example: [300, 1]
                      }
        """
        # array index of the predicted bbox before nms
        if self.cfg.architecture == 'FasterRCNN':
            # the bbox array shape of FasterRCNN before nms is [num_of_bboxes_before_nms, num_classes, 4],
            # so we divide by num_classes to get the before_nms_index
            before_nms_indexes = cam_data['before_nms_indexes'].cpu().numpy() // self.num_class
        elif self.cfg.architecture == 'YOLOv3':
            before_nms_indexes = cam_data['before_nms_indexes'].cpu().numpy()
        else:
            print(
                'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
            )
            sys.exit()
        # Calculate and visualize the heatmap of each predicted bbox
        for index, target_bbox in enumerate(inference_result['bbox']):
            # target_bbox: [cls, score, x1, y1, x2, y2]
            # filter bboxes with low predicted scores
            if target_bbox[1] < self.FLAGS.draw_threshold:
                continue
            target_bbox_before_nms = int(before_nms_indexes[index])
            # bbox score vector
            if self.cfg.architecture == 'FasterRCNN':
                # the shape of faster_rcnn scores tensor is
                # [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
                score_out = cam_data['scores'][target_bbox_before_nms]
            elif self.cfg.architecture == 'YOLOv3':
                # the shape of yolov3 scores tensor is
                # [1, num_classes, num_of_yolo_bboxes_before_nms]
                score_out = cam_data['scores'][0, :, target_bbox_before_nms]
            else:
                print(
                    'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
                )
                sys.exit()
            # construct one_hot label and do backward to get the gradients
            predicted_label = paddle.argmax(score_out)
            label_onehot = paddle.nn.functional.one_hot(
                predicted_label, num_classes=len(score_out))
            label_onehot = label_onehot.squeeze()
            target = paddle.sum(score_out * label_onehot)
            target.backward(retain_graph=True)
            if isinstance(self.target_feats[self.target_layer_name], list):
                # when the backbone output contains features of multiple scales,
                # take the featuremap of the last scale
                # Todo: fuse the cam result from multiscale featuremaps
                backbone_grad = self.target_feats[self.target_layer_name][
                    -1].grad.squeeze().cpu().numpy()
                backbone_feat = self.target_feats[self.target_layer_name][
                    -1].squeeze().cpu().numpy()
            else:
                backbone_grad = self.target_feats[
                    self.target_layer_name].grad.squeeze().cpu().numpy()
                backbone_feat = self.target_feats[
                    self.target_layer_name].squeeze().cpu().numpy()
            # grad_cam:
            exp = grad_cam(backbone_feat, backbone_grad)
            # resize the cam image to the input image size
            resized_exp = resize_cam(exp, (img.shape[1], img.shape[0]))
            # set the area outside the predicted bbox to 0
            mask = np.zeros((img.shape[0], img.shape[1], 3))
            mask[int(target_bbox[3]):int(target_bbox[5]),
                 int(target_bbox[2]):int(target_bbox[4]), :] = 1
            resized_exp = resized_exp * mask
            # add the bbox cam back to the input image
            overlay_vis = np.uint8(resized_exp * 0.4 + img * 0.6)
            cv2.rectangle(
                overlay_vis, (int(target_bbox[2]), int(target_bbox[3])),
                (int(target_bbox[4]), int(target_bbox[5])), (0, 0, 255), 2)
            # save visualization result
            cam_image = Image.fromarray(overlay_vis)
            cam_image.save(self.FLAGS.cam_out + '/' + str(index) + '.jpg')
            # clear gradients after each bbox grad_cam
            target.clear_gradient()
            for n, v in self.trainer.model.named_sublayers():
                v.clear_gradients()
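The index handling in `get_bboxes_cams` can be checked on a toy example. The sketch below assumes, as the comments in the code state, that the FasterRCNN index refers to a flattened (box, class) grid and that the YOLO score tensor is laid out as `[1, num_classes, num_boxes]`. The modulo step that recovers the class is an addition for illustration and is not used in the commit.

```python
import numpy as np

num_classes = 3
num_boxes = 4

# FasterRCNN-style layout: scores[i, c] is the score of class c for pre-NMS box i
scores = np.arange(
    num_boxes * num_classes, dtype=np.float32).reshape(num_boxes, num_classes)

# indices reported for two kept boxes, into the flattened (box, class) grid
flat_indexes = np.array([[7], [2]])            # shape (num_kept, 1)
box_indexes = flat_indexes // num_classes      # which pre-NMS box
class_indexes = flat_indexes % num_classes     # which class (illustrative extra)

first_box = box_indexes[0].item()
score_row = scores[first_box]                  # class score vector of that box

# YOLO-style layout: [1, num_classes, num_boxes], indexed directly by box
yolo_scores = scores.T[None, :, :]             # shape (1, 3, 4)
yolo_row = yolo_scores[0, :, first_box]        # same class score vector
```

Both layouts yield the same per-box class score vector once the right axis is indexed, which is the vector the one-hot target is built from.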
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
# add python path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
sys.path.insert(0, parent_path)
# ignore warning log
import warnings
warnings.filterwarnings('ignore')
from ppdet.utils.cli import ArgsParser, merge_args
from ppdet.core.workspace import load_config, merge_config
from ppdet.utils.check import check_gpu, check_npu, check_xpu, check_version, check_config
from ppdet.utils.cam_utils import BBoxCAM
import paddle
def parse_args():
    parser = ArgsParser()
    parser.add_argument(
        "--infer_img",
        type=str,
        default='demo/000000014439.jpg',  # hxw: 404x640
        help="Image path, has higher priority over --infer_dir")
    parser.add_argument(
        "--weights",
        type=str,
        default='output/faster_rcnn_r50_vd_fpn_2x_coco_paddlejob/best_model.pdparams')
    parser.add_argument("--cam_out", type=str, default='cam_faster_rcnn')
    # NOTE: with type=bool, argparse turns any non-empty string into True,
    # so "--use_gpu False" still yields True (same for --save_results below).
    parser.add_argument("--use_gpu", type=bool, default=True)
    parser.add_argument(
        "--infer_dir",
        type=str,
        default=None,
        help="Directory for images to perform inference on.")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="Directory for storing the output visualization files.")
    parser.add_argument(
        "--draw_threshold",
        type=float,
        default=0.8,
        help="Threshold to reserve the result for visualization.")
    parser.add_argument(
        "--save_results",
        type=bool,
        default=False,
        help="Whether to save inference results to output_dir.")
    parser.add_argument(
        "--slice_infer",
        action='store_true',
        help="Whether to slice the image and merge the inference results for small object detection."
    )
    parser.add_argument(
        '--slice_size',
        nargs='+',
        type=int,
        default=[640, 640],
        help="Height of the sliced image.")
    parser.add_argument(
        "--overlap_ratio",
        nargs='+',
        type=float,
        default=[0.25, 0.25],
        help="Overlap height ratio of the sliced image.")
    parser.add_argument(
        "--combine_method",
        type=str,
        default='nms',
        help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
    )
    args = parser.parse_args()
    return args
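`parse_args` declares `--use_gpu` and `--save_results` with `type=bool`. Assuming `ArgsParser` handles `type=` the way the standard `argparse.ArgumentParser` does, the value on the command line is passed through `bool()`, and every non-empty string is truthy. The sketch below shows the effect with plain `argparse` and one possible alternative. `str2bool` is a hypothetical helper, not something ppdet provides.

```python
import argparse


def str2bool(value):
    """Parse common true/false spellings (hypothetical helper, not part of ppdet)."""
    lowered = value.lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean, got {!r}".format(value))


# type=bool calls bool("False"), and any non-empty string is True
naive = argparse.ArgumentParser()
naive.add_argument("--use_gpu", type=bool, default=True)
naive_args = naive.parse_args(["--use_gpu", "False"])

# an explicit converter parses the text instead of testing for emptiness
safe = argparse.ArgumentParser()
safe.add_argument("--use_gpu", type=str2bool, default=True)
safe_args = safe.parse_args(["--use_gpu", "False"])
```

An `action='store_true'` flag, as already used for `--slice_infer`, is another common way to avoid the problem when the default is `False`.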
def run(FLAGS, cfg):
    assert cfg.architecture in ['FasterRCNN', 'YOLOv3'], 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
    bbox_cam = BBoxCAM(FLAGS, cfg)
    bbox_cam.get_bboxes_cams()
    print('finish')
def main():
    FLAGS = parse_args()
    cfg = load_config(FLAGS.config)
    merge_args(cfg, FLAGS)
    merge_config(FLAGS.opt)
    # disable npu in config by default
    if 'use_npu' not in cfg:
        cfg.use_npu = False
    # disable xpu in config by default
    if 'use_xpu' not in cfg:
        cfg.use_xpu = False
    if cfg.use_gpu:
        place = paddle.set_device('gpu')
    elif cfg.use_npu:
        place = paddle.set_device('npu')
    elif cfg.use_xpu:
        place = paddle.set_device('xpu')
    else:
        place = paddle.set_device('cpu')
    check_config(cfg)
    check_gpu(cfg.use_gpu)
    check_npu(cfg.use_npu)
    check_xpu(cfg.use_xpu)
    check_version()
    run(FLAGS, cfg)


if __name__ == '__main__':
    main()