diff --git a/docs/tutorials/GradCAM_cn.md b/docs/tutorials/GradCAM_cn.md
index 2dcd06e05da753dd3c4622748dae114058a07e68..e5a28d142e11c95ef4ee73e2a31b52a06f31464d 100644
--- a/docs/tutorials/GradCAM_cn.md
+++ b/docs/tutorials/GradCAM_cn.md
@@ -2,22 +2,23 @@
 ## 1.简介
-基于backbone特征图计算物体预测框的cam(类激活图)
+基于backbone/roi特征图计算物体预测框的cam(类激活图)
 
 ## 2.使用方法
 * 以PP-YOLOE为例,准备好数据之后,指定网络配置文件、模型权重地址和图片路径以及输出文件夹路径,使用脚本调用tools/cam_ppdet.py计算图片中物体预测框的grad_cam热力图。下面为运行脚本示例。
 ```shell
-python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
+python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
 ```
 
 * **参数**
 
-| FLAG | 用途 |
-| :----------------------: |:-----------------------------------------------------------------------------------------------------:|
-| -c | 指定配置文件 |
-| --infer_img | 用于预测的图片路径 |
-| --cam_out | 指定输出路径 |
-| -o | 设置或更改配置文件里的参数内容, 如 -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
+| FLAG | 用途 |
+|:--------------------------:|:--------------------------------------------------------------------------------------------------------------------------:|
+| -c | 指定配置文件 |
+| --infer_img | 用于预测的图片路径 |
+| --cam_out | 指定输出路径 |
+| --target_feature_layer_name | 计算cam的特征图位置, 如model.backbone、model.bbox_head.roi_extractor |
+| -o | 设置或更改配置文件里的参数内容, 如 -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
 
 * 运行效果
 
@@ -26,12 +27,43 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
cam_ppyoloe/225.jpg

-## 3. 目前支持基于FasterRCNN和YOLOv3系列的网络。
-* FasterRCNN网络热图可视化脚本
+## 3. 目前支持基于FasterRCNN/MaskRCNN, PPYOLOE系列以及BlazeFace, SSD, Retinanet网络。
+* PPYOLOE网络热图可视化脚本
 ```bash
-python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
 ```
-* PPYOLOE网络热图可视化脚本
+
+* MaskRCNN网络roi特征热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
+```
+
+* MaskRCNN网络backbone特征的热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
+```
+
+* FasterRCNN网络基于roi特征的热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+```
+
+* FasterRCNN网络基于backbone特征的热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+```
+
+* BlazeFace网络backbone特征热图可视化脚本
 ```bash
-python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
-```
\ No newline at end of file
+python tools/cam_ppdet.py -c configs/face_detection/blazeface_1000e.yml --infer_img demo/hrnet_demo.jpg --cam_out cam_blazeface --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams
+```
+
+* SSD网络backbone特征热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml --infer_img demo/000000014439.jpg --cam_out cam_ssd --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ssd_mobilenet_v1_300_120e_voc.pdparams
+```
+
+* Retinanet网络backbone特征热图可视化脚本
+```bash
+python tools/cam_ppdet.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_retinanet --target_feature_layer_name model.backbone -o weights=https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_2x_coco.pdparams
+```
diff --git a/docs/tutorials/GradCAM_en.md b/docs/tutorials/GradCAM_en.md
index 7551c709b98e101620c0b533d30e412df8ba4381..29487592ccc7bf3a2b3c8fbd82424851563be33a 100644
--- a/docs/tutorials/GradCAM_en.md
+++ b/docs/tutorials/GradCAM_en.md
@@ -1,12 +1,12 @@
 # Object detection grad_cam heatmap
 
 ## 1.Introduction
-Calculate the cam (class activation map) of the object predict bbox based on the backbone feature map
+Calculate the cam (class activation map) of the predicted object bbox based on the backbone/roi feature map
 
 ## 2.Usage
 * Taking PP-YOLOE as an example, after preparing the data, specify the network configuration file, model weight address, image path and output folder path, and then use the script to call tools/cam_ppdet.py to calculate the grad_cam heat map of the prediction box. Below is an example run script.
 ```shell
-python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
+python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
 ```
 
 * **Arguments**
 
@@ -16,6 +16,7 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
 | -c | Select config file |
 | --infer_img | Image path |
 | --cam_out | Directory for output |
+| --target_feature_layer_name | The feature map location used to compute grad_cam, for example: model.backbone, model.bbox_head.roi_extractor |
 | -o | Set parameters in configure file, for example: -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams |
 
 * result
 
@@ -26,12 +27,45 @@ python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer
cam_ppyoloe/225.jpg

-## 3.Currently supports networks based on FasterRCNN and YOLOv3 series.
-* FasterRCNN bbox heat map visualization script
+## 3.Currently supports networks based on FasterRCNN/MaskRCNN, PPYOLOE series and BlazeFace, SSD, Retinanet.
+* PPYOLOE bbox heat map visualization script (with backbone featuremap)
 ```bash
-python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
 ```
-* PPYOLOE bbox heat map visualization script
+
+* MaskRCNN bbox heat map visualization script (with roi featuremap)
 ```bash
-python tools/cam_ppdet.py -c configs/ppyoloe/ppyoloe_crn_l_300e_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_ppyoloe -o weights=https://paddledet.bj.bcebos.com/models/ppyoloe_crn_l_300e_coco.pdparams
-```
\ No newline at end of file
+python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
+```
+
+* MaskRCNN bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/mask_rcnn/mask_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_mask_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/mask_rcnn_r50_vd_fpn_2x_coco.pdparams
+```
+
+* FasterRCNN bbox heat map visualization script (with roi featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_roi --target_feature_layer_name model.bbox_head.roi_extractor -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+```
+
+* FasterRCNN bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/faster_rcnn/faster_rcnn_r50_vd_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_faster_rcnn_backbone --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/faster_rcnn_r50_vd_fpn_ssld_2x_coco.pdparams
+```
+
+* BlazeFace bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/face_detection/blazeface_1000e.yml --infer_img demo/hrnet_demo.jpg --cam_out cam_blazeface --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/blazeface_1000e.pdparams
+```
+
+* SSD bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/ssd/ssd_mobilenet_v1_300_120e_voc.yml --infer_img demo/000000014439.jpg --cam_out cam_ssd --target_feature_layer_name model.backbone -o weights=https://paddledet.bj.bcebos.com/models/ssd_mobilenet_v1_300_120e_voc.pdparams
+```
+
+* Retinanet bbox heat map visualization script (with backbone featuremap)
+```bash
+python tools/cam_ppdet.py -c configs/retinanet/retinanet_r50_fpn_2x_coco.yml --infer_img demo/000000014439.jpg --cam_out cam_retinanet --target_feature_layer_name model.backbone -o weights=https://bj.bcebos.com/v1/paddledet/models/retinanet_r50_fpn_2x_coco.pdparams
+```
+
+
diff --git a/ppdet/modeling/architectures/blazeface.py b/ppdet/modeling/architectures/blazeface.py
index 2f5d0d238ebd5d4b92916e79ff192932f4baff33..477732d95eec16448b42b062918d099344a81a10 100644
--- a/ppdet/modeling/architectures/blazeface.py
+++ b/ppdet/modeling/architectures/blazeface.py
@@ -18,6 +18,8 @@ from __future__ import print_function
 
 from ppdet.core.workspace import register, create
 from .meta_arch import BaseArch
+import paddle
+import paddle.nn.functional as F
 
 __all__ = ['BlazeFace']
 
@@ -74,18 +76,42 @@ class BlazeFace(BaseArch):
                                         self.inputs['gt_class'])
         else:
             preds, anchors = self.blaze_head(neck_feats, self.inputs['image'])
-            bbox, bbox_num, before_nms_indexes = self.post_process(
+            bbox, bbox_num, nms_keep_idx = self.post_process(
                 preds, anchors, self.inputs['im_shape'],
                 self.inputs['scale_factor'])
-            return bbox, bbox_num
+
+            if self.use_extra_data:
+                extra_data = {}  # record the bbox output before nms, such as scores and nms_keep_idx
+                """extra_data:{
+                            'scores': predict scores,
+                            'nms_keep_idx': bbox index before nms,
+                           }
+                """
+                preds_logits = preds[1]  # list of [1, num_bbox, num_class] tensors per level
+                extra_data['scores'] = F.softmax(paddle.concat(
+                    preds_logits, axis=1)).transpose([0, 2, 1])
+                extra_data['logits'] = paddle.concat(
+                    preds_logits, axis=1).transpose([0, 2, 1])
+                extra_data['nms_keep_idx'] = nms_keep_idx  # bbox index before nms
+                return bbox, bbox_num, extra_data
+            else:
+                return bbox, bbox_num
 
     def get_loss(self, ):
         return {"loss": self._forward()}
 
     def get_pred(self):
-        bbox_pred, bbox_num = self._forward()
-        output = {
-            "bbox": bbox_pred,
-            "bbox_num": bbox_num,
-        }
+        if self.use_extra_data:
+            bbox_pred, bbox_num, extra_data = self._forward()
+            output = {
+                "bbox": bbox_pred,
+                "bbox_num": bbox_num,
+                "extra_data": extra_data
+            }
+        else:
+            bbox_pred, bbox_num = self._forward()
+            output = {
+                "bbox": bbox_pred,
+                "bbox_num": bbox_num,
+            }
         return output
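The tutorial commands above ultimately read the `extra_data` dict that the architecture changes below attach to the model output when `use_extra_data` is enabled. A minimal, illustrative sketch of the index bookkeeping, with random arrays standing in for real model outputs (shapes follow the FasterRCNN-style docstrings in this patch):

```python
import numpy as np

# Shapes as documented below for FasterRCNN-style models:
#   scores:       [num_bboxes_before_nms, num_classes], e.g. [1000, 80]
#   nms_keep_idx: [num_bboxes_after_nms, 1],            e.g. [300, 1]
scores = np.random.rand(1000, 80)
nms_keep_idx = np.random.randint(0, 1000, size=(300, 1))

k = 0                                # the k-th detection kept by NMS
row = int(nms_keep_idx[k])           # its index in the pre-NMS bbox array
score_out = scores[row]              # per-class scores for that bbox
predicted_label = int(np.argmax(score_out))
print(predicted_label, float(score_out[predicted_label]))
```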
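BlazeFace (above) and SSD (below) expose their raw head outputs through the same pattern: per-level logits of shape `[1, num_boxes_i, num_classes]` are concatenated along the box axis, soft-maxed over classes, and transposed to `[1, num_classes, num_boxes]`. A shape-only sketch of that transformation, with made-up sizes:

```python
import paddle
import paddle.nn.functional as F

# Two feature levels with 100 and 50 candidate boxes, 2 classes each.
preds_logits = [paddle.rand([1, 100, 2]), paddle.rand([1, 50, 2])]

logits = paddle.concat(preds_logits, axis=1)      # [1, 150, 2]
scores = F.softmax(logits).transpose([0, 2, 1])   # [1, 2, 150], as in extra_data['scores']
print(logits.shape, scores.shape)
```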
diff --git a/ppdet/modeling/architectures/cascade_rcnn.py b/ppdet/modeling/architectures/cascade_rcnn.py
index e8d330c695e59ef02e5e81907c98f369d0f23cd6..c5d454f4948891ed1400d09d0e24490ce46fb361 100644
--- a/ppdet/modeling/architectures/cascade_rcnn.py
+++ b/ppdet/modeling/architectures/cascade_rcnn.py
@@ -108,7 +108,7 @@ class CascadeRCNN(BaseArch):
 
             im_shape = self.inputs['im_shape']
             scale_factor = self.inputs['scale_factor']
-            bbox, bbox_num, before_nms_indexes = self.bbox_post_process(
+            bbox, bbox_num, nms_keep_idx = self.bbox_post_process(
                 preds, (refined_rois, rois_num), im_shape, scale_factor)
 
             # rescale the prediction back to origin image
             bbox, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
diff --git a/ppdet/modeling/architectures/faster_rcnn.py b/ppdet/modeling/architectures/faster_rcnn.py
index a520d3a164ba951d5089e54b86ce53fa00265333..41c286fe02ef9d25f5b086ed99931fdd2aa70062 100644
--- a/ppdet/modeling/architectures/faster_rcnn.py
+++ b/ppdet/modeling/architectures/faster_rcnn.py
@@ -82,21 +82,31 @@ class FasterRCNN(BaseArch):
                                        self.inputs)
             return rpn_loss, bbox_loss
         else:
-            cam_data = {}  # record bbox scores and index before nms
             rois, rois_num, _ = self.rpn_head(body_feats, self.inputs)
             preds, _ = self.bbox_head(body_feats, rois, rois_num, None)
-            cam_data['scores'] = preds[1]
-
             im_shape = self.inputs['im_shape']
             scale_factor = self.inputs['scale_factor']
-            bbox, bbox_num, before_nms_indexes = self.bbox_post_process(preds, (rois, rois_num),
+            bbox, bbox_num, nms_keep_idx = self.bbox_post_process(preds, (rois, rois_num),
                                                      im_shape, scale_factor)
-            cam_data['before_nms_indexes'] = before_nms_indexes  # bbox index before nms, for cam
 
             # rescale the prediction back to origin image
             bboxes, bbox_pred, bbox_num = self.bbox_post_process.get_pred(
                 bbox, bbox_num, im_shape, scale_factor)
-            return bbox_pred, bbox_num, cam_data
+
+            if self.use_extra_data:
+                extra_data = {}  # record the bbox output before nms, such as scores and nms_keep_idx
+                """extra_data:{
+                            'scores': predict scores,
+                            'nms_keep_idx': bbox index before nms,
+                           }
+                """
+                extra_data['scores'] = preds[1]  # predict scores (probability)
+                # TODO: get logits output
+                extra_data['nms_keep_idx'] = nms_keep_idx  # bbox index before nms
+                return bbox_pred, bbox_num, extra_data
+            else:
+                return bbox_pred, bbox_num
+
 
     def get_loss(self, ):
         rpn_loss, bbox_loss = self._forward()
@@ -108,8 +118,12 @@ class FasterRCNN(BaseArch):
         return loss
 
     def get_pred(self):
-        bbox_pred, bbox_num, cam_data = self._forward()
-        output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'cam_data': cam_data}
+        if self.use_extra_data:
+            bbox_pred, bbox_num, extra_data = self._forward()
+            output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'extra_data': extra_data}
+        else:
+            bbox_pred, bbox_num = self._forward()
+            output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
         return output
 
     def target_bbox_forward(self, data):
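For the RCNN family the pre-NMS bbox array is flattened per (roi, class) pair, which is why cam_utils.py (further below) divides `nms_keep_idx` by the class count for these architectures. The arithmetic, with a hypothetical index value:

```python
num_classes = 80

# Entry i of the flattened pre-NMS array corresponds to
# roi index i // num_classes and class index i % num_classes.
nms_keep_idx = 8013                     # hypothetical index returned by NMS
roi_idx = nms_keep_idx // num_classes   # 100
cls_idx = nms_keep_idx % num_classes    # 13
print(roi_idx, cls_idx)
```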
diff --git a/ppdet/modeling/architectures/mask_rcnn.py b/ppdet/modeling/architectures/mask_rcnn.py
index bea6a3eea0e0b4d301a99eb97fc6f159f6345db4..4f6a9ce10f76801ca4bff4e3dc3e304b8e3567f5 100644
--- a/ppdet/modeling/architectures/mask_rcnn.py
+++ b/ppdet/modeling/architectures/mask_rcnn.py
@@ -106,7 +106,7 @@ class MaskRCNN(BaseArch):
 
             im_shape = self.inputs['im_shape']
             scale_factor = self.inputs['scale_factor']
-            bbox, bbox_num, before_nms_indexes = self.bbox_post_process(
+            bbox, bbox_num, nms_keep_idx = self.bbox_post_process(
                 preds, (rois, rois_num), im_shape, scale_factor)
             mask_out = self.mask_head(
                 body_feats, bbox, bbox_num, self.inputs, feat_func=feat_func)
@@ -117,7 +117,20 @@ class MaskRCNN(BaseArch):
             origin_shape = self.bbox_post_process.get_origin_shape()
             mask_pred = self.mask_post_process(mask_out, bbox_pred, bbox_num,
                                                origin_shape)
-            return bbox_pred, bbox_num, mask_pred
+
+            if self.use_extra_data:
+                extra_data = {}  # record the bbox output before nms, such as scores and nms_keep_idx
+                """extra_data:{
+                            'scores': predict scores,
+                            'nms_keep_idx': bbox index before nms,
+                           }
+                """
+                extra_data['scores'] = preds[1]  # predict scores (probability)
+                # TODO: get logits output
+                extra_data['nms_keep_idx'] = nms_keep_idx  # bbox index before nms
+                return bbox_pred, bbox_num, mask_pred, extra_data
+            else:
+                return bbox_pred, bbox_num, mask_pred
 
     def get_loss(self, ):
         bbox_loss, mask_loss, rpn_loss = self._forward()
@@ -130,6 +143,10 @@ class MaskRCNN(BaseArch):
         return loss
 
     def get_pred(self):
-        bbox_pred, bbox_num, mask_pred = self._forward()
-        output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred}
+        if self.use_extra_data:
+            bbox_pred, bbox_num, mask_pred, extra_data = self._forward()
+            output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred, 'extra_data': extra_data}
+        else:
+            bbox_pred, bbox_num, mask_pred = self._forward()
+            output = {'bbox': bbox_pred, 'bbox_num': bbox_num, 'mask': mask_pred}
         return output
diff --git a/ppdet/modeling/architectures/meta_arch.py b/ppdet/modeling/architectures/meta_arch.py
index 4ff84a97a61739e06f215f56a64daf0459e4a971..370b2b124bfc1f5477a942f972731f2857e1641c 100644
--- a/ppdet/modeling/architectures/meta_arch.py
+++ b/ppdet/modeling/architectures/meta_arch.py
@@ -15,11 +15,12 @@ __all__ = ['BaseArch']
 
 @register
 class BaseArch(nn.Layer):
-    def __init__(self, data_format='NCHW'):
+    def __init__(self, data_format='NCHW', use_extra_data=False):
         super(BaseArch, self).__init__()
         self.data_format = data_format
         self.inputs = {}
         self.fuse_norm = False
+        self.use_extra_data = use_extra_data
 
     def load_meanstd(self, cfg_transform):
         scale = 1.
diff --git a/ppdet/modeling/architectures/ppyoloe.py b/ppdet/modeling/architectures/ppyoloe.py
index 7c2b3a81575a6c527eed6def5ee702320b1e7e76..330542b8ab20d04fb433eddffa14bf05afd8e6a2 100644
--- a/ppdet/modeling/architectures/ppyoloe.py
+++ b/ppdet/modeling/architectures/ppyoloe.py
@@ -106,21 +106,30 @@ class PPYOLOE(BaseArch):
                 raise ValueError
             return yolo_losses
         else:
-            cam_data = {}  # record bbox scores and index before nms
+
             yolo_head_outs = self.yolo_head(neck_feats)
-            cam_data['scores'] = yolo_head_outs[0]
 
             if self.post_process is not None:
-                bbox, bbox_num, before_nms_indexes = self.post_process(
+                bbox, bbox_num, nms_keep_idx = self.post_process(
                     yolo_head_outs, self.yolo_head.mask_anchors,
                     self.inputs['im_shape'], self.inputs['scale_factor'])
-                cam_data['before_nms_indexes'] = before_nms_indexes
+
             else:
-                bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
+                bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
                     yolo_head_outs, self.inputs['scale_factor'])
-                # data for cam
-                cam_data['before_nms_indexes'] = before_nms_indexes
-            output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
+
+            if self.use_extra_data:
+                extra_data = {}  # record the bbox output before nms, such as scores and nms_keep_idx
+                """extra_data:{
+                            'scores': predict scores,
+                            'nms_keep_idx': bbox index before nms,
+                           }
+                """
+                extra_data['scores'] = yolo_head_outs[0]  # predict scores (probability)
+                extra_data['nms_keep_idx'] = nms_keep_idx
+                output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
+            else:
+                output = {'bbox': bbox, 'bbox_num': bbox_num}
 
             return output
 
@@ -218,21 +227,29 @@ class PPYOLOEWithAuxHead(BaseArch):
                 aux_pred=[aux_cls_scores, aux_bbox_preds])
             return loss
         else:
-            cam_data = {}  # record bbox scores and index before nms
             yolo_head_outs = self.yolo_head(neck_feats)
-            cam_data['scores'] = yolo_head_outs[0]
             if self.post_process is not None:
-                bbox, bbox_num, before_nms_indexes = self.post_process(
+                bbox, bbox_num, nms_keep_idx = self.post_process(
                     yolo_head_outs, self.yolo_head.mask_anchors,
                     self.inputs['im_shape'], self.inputs['scale_factor'])
-                cam_data['before_nms_indexes'] = before_nms_indexes
             else:
-                bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
+                bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
                     yolo_head_outs, self.inputs['scale_factor'])
-                # data for cam
-                cam_data['before_nms_indexes'] = before_nms_indexes
-            output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
+
+            if self.use_extra_data:
+                extra_data = {}  # record the bbox output before nms, such as scores and nms_keep_idx
+                """extra_data:{
+                            'scores': predict scores,
+                            'nms_keep_idx': bbox index before nms,
+                           }
+                """
+                extra_data['scores'] = yolo_head_outs[0]  # predict scores (probability)
+                # TODO: get logits output
+                extra_data['nms_keep_idx'] = nms_keep_idx
+                output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
+            else:
+                output = {'bbox': bbox, 'bbox_num': bbox_num}
 
             return output
diff --git a/ppdet/modeling/architectures/retinanet.py b/ppdet/modeling/architectures/retinanet.py
index e774430a03dfebf74c1e91138ed57f2ee52f1c9d..fc49f0e97365ef10d14c9214133d53db304b880b 100644
--- a/ppdet/modeling/architectures/retinanet.py
+++ b/ppdet/modeling/architectures/retinanet.py
@@ -19,6 +19,7 @@ from __future__ import print_function
 from ppdet.core.workspace import register, create
 from .meta_arch import BaseArch
 import paddle
+import paddle.nn.functional as F
 
 __all__ = ['RetinaNet']
 
@@ -57,9 +58,24 @@ class RetinaNet(BaseArch):
             return self.head(neck_feats, self.inputs)
         else:
             head_outs = self.head(neck_feats)
-            bbox, bbox_num = self.head.post_process(
+            bbox, bbox_num, nms_keep_idx = self.head.post_process(
                 head_outs, self.inputs['im_shape'], self.inputs['scale_factor'])
-            return {'bbox': bbox, 'bbox_num': bbox_num}
+
+            if self.use_extra_data:
+                extra_data = {}  # record the bbox output before nms, such as scores and nms_keep_idx
+                """extra_data:{
+                            'scores': predict scores,
+                            'nms_keep_idx': bbox index before nms,
+                           }
+                """
+                preds_logits = self.head.decode_cls_logits(head_outs[0])
+                preds_scores = F.sigmoid(preds_logits)
+                extra_data['logits'] = preds_logits
+                extra_data['scores'] = preds_scores
+                extra_data['nms_keep_idx'] = nms_keep_idx  # bbox index before nms
+                return {'bbox': bbox, 'bbox_num': bbox_num, "extra_data": extra_data}
+            else:
+                return {'bbox': bbox, 'bbox_num': bbox_num}
 
     def get_loss(self):
         return self._forward()
diff --git a/ppdet/modeling/architectures/ssd.py b/ppdet/modeling/architectures/ssd.py
index 2d2bf1dd2ce4b4b273ff420239f2f388ae05e015..b8669b7cf127857662f0f78e9c52e43f29fbfcfe 100644
--- a/ppdet/modeling/architectures/ssd.py
+++ b/ppdet/modeling/architectures/ssd.py
@@ -18,6 +18,8 @@ from __future__ import print_function
 
 from ppdet.core.workspace import register, create
 from .meta_arch import BaseArch
+import paddle
+import paddle.nn.functional as F
 
 __all__ = ['SSD']
 
@@ -75,18 +77,42 @@ class SSD(BaseArch):
                                  self.inputs['gt_class'])
         else:
             preds, anchors = self.ssd_head(body_feats, self.inputs['image'])
-            bbox, bbox_num, before_nms_indexes = self.post_process(
+            bbox, bbox_num, nms_keep_idx = self.post_process(
                 preds, anchors, self.inputs['im_shape'],
                 self.inputs['scale_factor'])
-            return bbox, bbox_num
+
+            if self.use_extra_data:
+                extra_data = {}  # record the bbox output before nms, such as scores and nms_keep_idx
+                """extra_data:{
+                            'scores': predict scores,
+                            'nms_keep_idx': bbox index before nms,
+                           }
+                """
+                preds_logits = preds[1]  # list of [1, num_bbox, num_class] tensors per level
+                extra_data['scores'] = F.softmax(paddle.concat(
+                    preds_logits, axis=1)).transpose([0, 2, 1])
+                extra_data['logits'] = paddle.concat(
+                    preds_logits, axis=1).transpose([0, 2, 1])
+                extra_data['nms_keep_idx'] = nms_keep_idx  # bbox index before nms
+                return bbox, bbox_num, extra_data
+            else:
+                return bbox, bbox_num
 
     def get_loss(self, ):
         return {"loss": self._forward()}
 
     def get_pred(self):
-        bbox_pred, bbox_num = self._forward()
-        output = {
-            "bbox": bbox_pred,
-            "bbox_num": bbox_num,
-        }
+        if self.use_extra_data:
+            bbox_pred, bbox_num, extra_data = self._forward()
+            output = {
+                "bbox": bbox_pred,
+                "bbox_num": bbox_num,
+                "extra_data": extra_data
+            }
+        else:
+            bbox_pred, bbox_num = self._forward()
+            output = {
+                "bbox": bbox_pred,
+                "bbox_num": bbox_num,
+            }
         return output
diff --git a/ppdet/modeling/architectures/yolo.py b/ppdet/modeling/architectures/yolo.py
index 4d9ec4d33fc12253846720db7adc10ce2c1954ed..b004935654ed0ec2290af758133f479147231dbd 100644
--- a/ppdet/modeling/architectures/yolo.py
+++ b/ppdet/modeling/architectures/yolo.py
@@ -98,9 +98,7 @@ class YOLOv3(BaseArch):
             return yolo_losses
         else:
-            cam_data = {}  # record bbox scores and index before nms
             yolo_head_outs = self.yolo_head(neck_feats)
-            cam_data['scores'] = yolo_head_outs[0]
 
             if self.for_mot:
                 # the detection part of JDE MOT model
@@ -116,21 +114,32 @@ class YOLOv3(BaseArch):
             else:
                 if self.return_idx:
                     # the detection part of JDE MOT model
-                    _, bbox, bbox_num, _ = self.post_process(
+                    _, bbox, bbox_num, nms_keep_idx = self.post_process(
                         yolo_head_outs, self.yolo_head.mask_anchors)
                 elif self.post_process is not None:
                     # anchor based YOLOs: YOLOv3,PP-YOLO,PP-YOLOv2 use mask_anchors
-                    bbox, bbox_num, before_nms_indexes = self.post_process(
+                    bbox, bbox_num, nms_keep_idx = self.post_process(
                         yolo_head_outs, self.yolo_head.mask_anchors,
                         self.inputs['im_shape'], self.inputs['scale_factor'])
-                    cam_data['before_nms_indexes'] = before_nms_indexes
                 else:
                     # anchor free YOLOs: PP-YOLOE, PP-YOLOE+
-                    bbox, bbox_num, before_nms_indexes = self.yolo_head.post_process(
+                    bbox, bbox_num, nms_keep_idx = self.yolo_head.post_process(
                         yolo_head_outs, self.inputs['scale_factor'])
-                # data for cam
-                cam_data['before_nms_indexes'] = before_nms_indexes
-                output = {'bbox': bbox, 'bbox_num': bbox_num, 'cam_data': cam_data}
+
+                if self.use_extra_data:
+                    extra_data = {}  # record the bbox output before nms, such as scores and nms_keep_idx
+                    """extra_data:{
+                                'scores': predict scores,
+                                'nms_keep_idx': bbox index before nms,
+                               }
+                    """
+                    extra_data['scores'] = yolo_head_outs[0]  # predict scores (probability)
+                    # TODO: get logits output
+                    extra_data['nms_keep_idx'] = nms_keep_idx
+                    # TODO: support extra_data for mask_anchors-based YOLOs
+                    output = {'bbox': bbox, 'bbox_num': bbox_num, 'extra_data': extra_data}
+                else:
+                    output = {'bbox': bbox, 'bbox_num': bbox_num}
 
             return output
diff --git a/ppdet/modeling/heads/ppyoloe_head.py b/ppdet/modeling/heads/ppyoloe_head.py
index 5784f70f878ad8f44561cca09c109a234331aa4e..60d7bbc2f83fb95dbb7a1afe7be891df98ce5e76 100644
--- a/ppdet/modeling/heads/ppyoloe_head.py
+++ b/ppdet/modeling/heads/ppyoloe_head.py
@@ -525,9 +525,9 @@ class PPYOLOEHead(nn.Layer):
                 # `exclude_nms=True` just use in benchmark
                 return pred_bboxes, pred_scores, None
             else:
-                bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
-                                                                   pred_scores)
-                return bbox_pred, bbox_num, before_nms_indexes
+                bbox_pred, bbox_num, nms_keep_idx = self.nms(pred_bboxes,
+                                                             pred_scores)
+                return bbox_pred, bbox_num, nms_keep_idx
 
 
 def get_activation(name="LeakyReLU"):
diff --git a/ppdet/modeling/heads/ppyoloe_r_head.py b/ppdet/modeling/heads/ppyoloe_r_head.py
index fba401d7e7a7ff725938760be09b478d924dc705..e7cf772f56991152f138dfbf7f5297d01c0e0b0f 100644
--- a/ppdet/modeling/heads/ppyoloe_r_head.py
+++ b/ppdet/modeling/heads/ppyoloe_r_head.py
@@ -420,6 +420,6 @@ class PPYOLOERHead(nn.Layer):
             pred_bboxes /= scale_factor
         if self.export_onnx:
             return pred_bboxes, pred_scores, None
-        bbox_pred, bbox_num, before_nms_indexes = self.nms(pred_bboxes,
-                                                           pred_scores)
-        return bbox_pred, bbox_num, before_nms_indexes
+        bbox_pred, bbox_num, nms_keep_idx = self.nms(pred_bboxes,
+                                                     pred_scores)
+        return bbox_pred, bbox_num, nms_keep_idx
diff --git a/ppdet/modeling/heads/retina_head.py b/ppdet/modeling/heads/retina_head.py
index 8705e86febb30d06fcbbd06187a76548450c9600..67a51265d1decee3d3077a504771f5be050101f3 100644
--- a/ppdet/modeling/heads/retina_head.py
+++ b/ppdet/modeling/heads/retina_head.py
@@ -245,5 +245,34 @@ class RetinaHead(nn.Layer):
         bboxes, scores = self.decode(anchors, cls_logits, bboxes_reg, im_shape,
                                      scale_factor)
-        bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
-        return bbox_pred, bbox_num
+        bbox_pred, bbox_num, nms_keep_idx = self.nms(bboxes, scores)
+        return bbox_pred, bbox_num, nms_keep_idx
+
+    def get_scores_single(self, cls_scores_list):
+        mlvl_logits = []
+        for cls_score in cls_scores_list:
+            cls_score = cls_score.reshape([-1, self.num_classes])
+            if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
+                max_score = cls_score.max(axis=1)
+                _, topk_inds = max_score.topk(self.nms_pre)
+                cls_score = cls_score.gather(topk_inds)
+            mlvl_logits.append(cls_score)
+
+        mlvl_logits = paddle.concat(mlvl_logits)
+        mlvl_logits = mlvl_logits.transpose([1, 0])
+        return mlvl_logits
+
+    def decode_cls_logits(self, cls_logits_list):
+        cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits_list]
+        batch_logits = []
+        for img_id in range(cls_logits[0].shape[0]):
+            num_lvls = len(cls_logits)
+            cls_scores_list = [cls_logits[i][img_id] for i in range(num_lvls)]
+            logits = self.get_scores_single(cls_scores_list)
+            batch_logits.append(logits)
+        batch_logits = paddle.stack(batch_logits, axis=0)
+        return batch_logits
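`decode_cls_logits` above flattens the per-level `[N, C, H, W]` classification maps into a single `[batch, num_classes, num_anchors]` tensor so RetinaNet can expose pre-NMS logits in `extra_data`. A simplified, shape-only sketch of that flattening (one anchor per location for brevity, no `nms_pre` top-k):

```python
import paddle

batch, num_classes = 1, 80
# Two FPN levels of [N, C, H, W] classification logits.
cls_logits_list = [paddle.rand([batch, num_classes, 20, 20]),
                   paddle.rand([batch, num_classes, 10, 10])]

levels = [l.transpose([0, 2, 3, 1]).reshape([batch, -1, num_classes])
          for l in cls_logits_list]        # [N, H*W, C] per level
flat = paddle.concat(levels, axis=1)       # [1, 500, 80]
flat = flat.transpose([0, 2, 1])           # [1, 80, 500], as in extra_data['logits']
print(flat.shape)
```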
diff --git a/ppdet/modeling/heads/roi_extractor.py b/ppdet/modeling/heads/roi_extractor.py
index b96bb4e91a0ca720d099562c3cf51d95080051fa..6c2f5c81904bf1e55e7799c78a76edc2034e447b 100644
--- a/ppdet/modeling/heads/roi_extractor.py
+++ b/ppdet/modeling/heads/roi_extractor.py
@@ -15,6 +15,7 @@
 import paddle
 from ppdet.core.workspace import register
 from ppdet.modeling import ops
+import paddle.nn as nn
 
 
 def _to_list(v):
@@ -24,7 +25,7 @@ def _to_list(v):
 
 
 @register
-class RoIAlign(object):
+class RoIAlign(nn.Layer):
     """
     RoI Align module
 
@@ -73,7 +74,7 @@ class RoIAlign(object):
     def from_config(cls, cfg, input_shape):
         return {'spatial_scale': [1. / i.stride for i in input_shape]}
 
-    def __call__(self, feats, roi, rois_num):
+    def forward(self, feats, roi, rois_num):
         roi = paddle.concat(roi) if len(roi) > 1 else roi[0]
         if len(feats) == 1:
             rois_feat = paddle.vision.ops.roi_align(
diff --git a/ppdet/utils/cam_utils.py b/ppdet/utils/cam_utils.py
index 7aa4b174320d49f3655c15c87d406a85e0ecc920..d98350e2c3dda5ec86aa2e83a0d1b0ee0e13b7c0 100644
--- a/ppdet/utils/cam_utils.py
+++ b/ppdet/utils/cam_utils.py
@@ -4,6 +4,7 @@ import os
 import sys
 import glob
 from ppdet.utils.logger import setup_logger
+import copy
 logger = setup_logger('ppdet_cam')
 
 import paddle
@@ -119,7 +120,12 @@ class BBoxCAM:
         # num_class
         self.num_class = cfg.num_classes
         # set hook for extraction of featuremaps and grads
-        self.set_hook()
+        self.set_hook(cfg)
+        self.nms_idx_need_divid_numclass_arch = ['FasterRCNN', 'MaskRCNN', 'CascadeRCNN']
+        """
+        In these networks, the bbox array shape before nms contains num_class,
+        so the nms_keep_idx of a bbox needs to be divided by num_class;
+        """
 
         # cam image output_dir
         try:
@@ -134,9 +140,10 @@ class BBoxCAM:
 
         # load weights
         trainer.load_weights(cfg.weights)
-        # record the bbox index before nms
-        # Todo: hard code for nms return index
-        if cfg.architecture == 'FasterRCNN':
+        # set to get extra_data before nms
+        trainer.model.use_extra_data = True
+        # set to record the bbox index before nms
+        if cfg.architecture == 'FasterRCNN' or cfg.architecture == 'MaskRCNN':
             trainer.model.bbox_post_process.nms.return_index = True
         elif cfg.architecture == 'YOLOv3':
             if trainer.model.post_process is not None:
@@ -145,24 +152,39 @@ class BBoxCAM:
             else:
                 # anchor free YOLOs: PP-YOLOE, PP-YOLOE+
                 trainer.model.yolo_head.nms.return_index = True
+        elif cfg.architecture == 'BlazeFace' or cfg.architecture == 'SSD':
+            trainer.model.post_process.nms.return_index = True
+        elif cfg.architecture == 'RetinaNet':
+            trainer.model.head.nms.return_index = True
         else:
             print(
-                'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
+                cfg.architecture + ' is not supported for cam temporarily!'
             )
             sys.exit()
+        # TODO: unify the head/post_process naming across models
 
         return trainer
 
-    def set_hook(self):
+    def set_hook(self, cfg):
         # set hook for extraction of featuremaps and grads
         self.target_feats = {}
-        self.target_layer_name = 'trainer.model.backbone'
+        self.target_layer_name = cfg.target_feature_layer_name
+        # such as trainer.model.backbone, trainer.model.bbox_head.roi_extractor
 
         def hook(layer, input, output):
             self.target_feats[layer._layer_name_for_hook] = output
 
-        self.trainer.model.backbone._layer_name_for_hook = self.target_layer_name
-        self.trainer.model.backbone.register_forward_post_hook(hook)
+        try:
+            exec('self.trainer.' + self.target_layer_name + '._layer_name_for_hook = self.target_layer_name')
+            # i.e. self.trainer.<target_layer_name>._layer_name_for_hook = self.target_layer_name
+            exec('self.trainer.' + self.target_layer_name + '.register_forward_post_hook(hook)')
+            # i.e. self.trainer.<target_layer_name>.register_forward_post_hook(hook)
+        except Exception:
+            print('Error! The target_feature_layer_name ' + self.target_layer_name +
+                  ' is not in the model! Please check the spelling and '
+                  "the network's architecture!")
+            sys.exit()
 
     def get_bboxes(self):
         # get inference images
@@ -187,31 +209,27 @@ class BBoxCAM:
         img = np.array(Image.open(self.cfg.infer_img))
 
         # data for calaulate bbox grad_cam
-        cam_data = inference_result['cam_data']
+        extra_data = inference_result['extra_data']
         """
-        if Faster_RCNN based architecture:
-            cam_data: {'scores': tensor with shape [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
-                       'before_nms_indexes': tensor with shape [num_of_bboxes_after_nms, 1], for example: [300, 1]
-                      }
-        elif YOLOv3 based architecture:
-            cam_data: {'scores': tensor with shape [1, num_classes, num_of_yolo_bboxes_before_nms], # for example: [1, 80, 8400]
-                       'before_nms_indexes': tensor with shape [num_of_yolo_bboxes_after_nms, 1], # for example: [300, 1]
-                      }
+        Example of a Faster_RCNN based architecture:
+            extra_data: {'scores': tensor with shape [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
+                         'nms_keep_idx': tensor with shape [num_of_bboxes_after_nms, 1], for example: [300, 1]
+                        }
+        Example of a YOLOv3 based architecture:
+            extra_data: {'scores': tensor with shape [1, num_classes, num_of_yolo_bboxes_before_nms], for example: [1, 80, 8400]
+                         'nms_keep_idx': tensor with shape [num_of_yolo_bboxes_after_nms, 1], for example: [300, 1]
+                        }
         """
 
         # array index of the predicted bbox before nms
-        if self.cfg.architecture == 'FasterRCNN':
-            # the bbox array shape of FasterRCNN before nms is [num_of_bboxes_before_nms, num_classes, 4],
+        if self.cfg.architecture in self.nms_idx_need_divid_numclass_arch:
+            # some networks' bbox array shape before nms is [num_of_bboxes_before_nms, num_classes, 4],
             # we need to divide num_classes to get the before_nms_index;
-            before_nms_indexes = cam_data['before_nms_indexes'].cpu().numpy(
+            # currently this covers the rcnn architectures (FasterRCNN, MaskRCNN, CascadeRCNN);
+            before_nms_indexes = extra_data['nms_keep_idx'].cpu().numpy(
             ) // self.num_class  # num_class
-        elif self.cfg.architecture == 'YOLOv3':
-            before_nms_indexes = cam_data['before_nms_indexes'].cpu().numpy()
-        else:
-            print(
-                'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
-            )
-            sys.exit()
+        else:
+            before_nms_indexes = extra_data['nms_keep_idx'].cpu().numpy()
 
         # Calculate and visualize the heatmap of per predict bbox
         for index, target_bbox in enumerate(inference_result['bbox']):
@@ -222,20 +240,16 @@ class BBoxCAM:
 
             target_bbox_before_nms = int(before_nms_indexes[index])
 
-            # bbox score vector
-            if self.cfg.architecture == 'FasterRCNN':
-                # the shape of faster_rcnn scores tensor is
-                # [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
-                score_out = cam_data['scores'][target_bbox_before_nms]
-            elif self.cfg.architecture == 'YOLOv3':
-                # the shape of yolov3 scores tensor is
-                # [1, num_classes, num_of_yolo_bboxes_before_nms]
-                score_out = cam_data['scores'][0, :, target_bbox_before_nms]
+            if len(extra_data['scores'].shape) == 2:
+                score_out = extra_data['scores'][target_bbox_before_nms]
             else:
-                print(
-                    'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
-                )
-                sys.exit()
+                score_out = extra_data['scores'][0, :, target_bbox_before_nms]
+            """
+            There are two kinds of bbox score output shapes:
+            1) [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
+            2) [num_of_images, num_classes, num_of_yolo_bboxes_before_nms], for example: [1, 80, 1000]
+            """
 
             # construct one_hot label and do backward to get the gradients
             predicted_label = paddle.argmax(score_out)
@@ -245,34 +259,76 @@ class BBoxCAM:
             target = paddle.sum(score_out * label_onehot)
             target.backward(retain_graph=True)
 
-            if isinstance(self.target_feats[self.target_layer_name], list):
-                # when the backbone output contains features of multiple scales,
-                # take the featuremap of the last scale
-                # Todo: fuse the cam result from multisclae featuremaps
-                backbone_grad = self.target_feats[self.target_layer_name][
-                    -1].grad.squeeze().cpu().numpy()
-                backbone_feat = self.target_feats[self.target_layer_name][
-                    -1].squeeze().cpu().numpy()
-            else:
-                backbone_grad = self.target_feats[
-                    self.target_layer_name].grad.squeeze().cpu().numpy()
-                backbone_feat = self.target_feats[
-                    self.target_layer_name].squeeze().cpu().numpy()
-
-            # grad_cam:
-            exp = grad_cam(backbone_feat, backbone_grad)
-            # reshape the cam image to the input image size
-            resized_exp = resize_cam(exp, (img.shape[1], img.shape[0]))
+            if 'backbone' in self.target_layer_name or \
+                    'neck' in self.target_layer_name:  # backbone/neck level feature
+                if isinstance(self.target_feats[self.target_layer_name], list):
+                    # when the featuremap contains multiple scales,
+                    # take the featuremap of the last scale
+                    # TODO: fuse the cam result from multiscale featuremaps
+                    if self.target_feats[self.target_layer_name][-1].shape[-1] == 1:
+                        """
+                        if the last level featuremap is 1x1 size,
+                        we take the second last one
+                        """
+                        cam_grad = self.target_feats[self.target_layer_name][
+                            -2].grad.squeeze().cpu().numpy()
+                        cam_feat = self.target_feats[self.target_layer_name][
+                            -2].squeeze().cpu().numpy()
+                    else:
+                        cam_grad = self.target_feats[self.target_layer_name][
+                            -1].grad.squeeze().cpu().numpy()
+                        cam_feat = self.target_feats[self.target_layer_name][
+                            -1].squeeze().cpu().numpy()
+                else:
+                    cam_grad = self.target_feats[
+                        self.target_layer_name].grad.squeeze().cpu().numpy()
+                    cam_feat = self.target_feats[
+                        self.target_layer_name].squeeze().cpu().numpy()
+            else:  # roi level feature
+                cam_grad = self.target_feats[
+                    self.target_layer_name].grad.squeeze().cpu().numpy()[target_bbox_before_nms]
+                cam_feat = self.target_feats[
+                    self.target_layer_name].squeeze().cpu().numpy()[target_bbox_before_nms]
 
-            # set the area outside the predic bbox to 0
-            mask = np.zeros((img.shape[0], img.shape[1], 3))
-            mask[int(target_bbox[3]):int(target_bbox[5]), int(target_bbox[2]):
-                 int(target_bbox[4]), :] = 1
-            resized_exp = resized_exp * mask
+            # grad_cam:
+            exp = grad_cam(cam_feat, cam_grad)
+
+            if 'backbone' in self.target_layer_name or \
+                    'neck' in self.target_layer_name:
+                """
+                when using a backbone/neck featuremap,
+                we first compute the cam on the whole image,
+                and then set the area outside the predicted bbox to 0
+                """
+                # reshape the cam image to the input image size
+                resized_exp = resize_cam(exp, (img.shape[1], img.shape[0]))
+                mask = np.zeros((img.shape[0], img.shape[1], 3))
+                mask[int(target_bbox[3]):int(target_bbox[5]), int(target_bbox[2]):
+                     int(target_bbox[4]), :] = 1
+                resized_exp = resized_exp * mask
+                # add the bbox cam back to the input image
+                overlay_vis = np.uint8(resized_exp * 0.4 + img * 0.6)
+            elif 'roi' in self.target_layer_name:
+                # get the bbox part of the image
+                bbox_img = copy.deepcopy(img[int(target_bbox[3]):int(target_bbox[5]),
+                                             int(target_bbox[2]):int(target_bbox[4]), :])
+                # reshape the cam image to the bbox size
+                resized_exp = resize_cam(exp, (bbox_img.shape[1], bbox_img.shape[0]))
+                # add the bbox cam back to the bbox image
+                bbox_overlay_vis = np.uint8(resized_exp * 0.4 + bbox_img * 0.6)
+                # put the bbox_cam image back into the original image
+                overlay_vis = copy.deepcopy(img)
+                overlay_vis[int(target_bbox[3]):int(target_bbox[5]),
+                            int(target_bbox[2]):int(target_bbox[4]), :] = bbox_overlay_vis
+            else:
+                print(
+                    'Only cam for backbone/neck features and roi features is supported, '
+                    'the others are not supported temporarily!'
+                )
+                sys.exit()
 
-            # add the bbox cam back to the input image
-            overlay_vis = np.uint8(resized_exp * 0.4 + img * 0.6)
+            # put the bbox rectangle on the image
             cv2.rectangle(
                 overlay_vis, (int(target_bbox[2]), int(target_bbox[3])),
                 (int(target_bbox[4]), int(target_bbox[5])), (0, 0, 255), 2)
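The `len(shape) == 2` branch above replaces the old per-architecture if/elif: RCNN-style models hand back 2-D score tensors, while YOLO/SSD-style models hand back 3-D ones. The same dispatch, sketched on dummy arrays:

```python
import numpy as np

def pick_score_row(scores, idx):
    # mirrors the shape-based branch in cam_utils.py
    if len(scores.shape) == 2:   # [num_bboxes_before_nms, num_classes]
        return scores[idx]
    return scores[0, :, idx]     # [1, num_classes, num_bboxes_before_nms]

rcnn_scores = np.random.rand(1000, 80)
yolo_scores = np.random.rand(1, 80, 8400)
print(pick_score_row(rcnn_scores, 5).shape)  # (80,)
print(pick_score_row(yolo_scores, 5).shape)  # (80,)
```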
diff --git a/tools/cam_ppdet.py b/tools/cam_ppdet.py
index e1c3019959146e8b70928b23731259c756fc2634..e73f082bb2951ee72c686ad92f10c6d0add07262 100644
--- a/tools/cam_ppdet.py
+++ b/tools/cam_ppdet.py
@@ -58,34 +58,16 @@ def parse_args():
         default=False,
         help="Whether to save inference results to output_dir.")
     parser.add_argument(
-        "--slice_infer",
-        action='store_true',
-        help="Whether to slice the image and merge the inference results for small object detection."
-    )
-    parser.add_argument(
-        '--slice_size',
-        nargs='+',
-        type=int,
-        default=[640, 640],
-        help="Height of the sliced image.")
-    parser.add_argument(
-        "--overlap_ratio",
-        nargs='+',
-        type=float,
-        default=[0.25, 0.25],
-        help="Overlap height ratio of the sliced image.")
-    parser.add_argument(
-        "--combine_method",
+        "--target_feature_layer_name",
         type=str,
-        default='nms',
-        help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
-    )
+        default='model.backbone',  # the featuremap to compute grad_cam on, e.g. model.backbone, model.bbox_head.roi_extractor
+        help="The feature map location used to compute grad_cam, e.g. model.backbone, model.bbox_head.roi_extractor.")
     args = parser.parse_args()
     return args
 
 
 def run(FLAGS, cfg):
-    assert cfg.architecture in ['FasterRCNN', 'YOLOv3'], 'Only supported cam for faster_rcnn based and yolov3 based architecture for now, the others are not supported temporarily!'
+    assert cfg.architecture in ['FasterRCNN', 'MaskRCNN', 'YOLOv3', 'BlazeFace', 'SSD', 'RetinaNet'], 'Only cam for FasterRCNN, MaskRCNN, YOLOv3, BlazeFace, SSD and RetinaNet based architectures is supported for now, the others are not supported temporarily!'
     bbox_cam = BBoxCAM(FLAGS, cfg)
     bbox_cam.get_bboxes_cams()
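`set_hook` in cam_utils.py resolves the dotted `--target_feature_layer_name` string with `exec`. An equivalent way to attach the same forward hook, sketched with `operator.attrgetter` on a toy paddle model (the toy layer names are hypothetical, not part of this patch):

```python
import operator
import paddle
import paddle.nn as nn

class ToyDetector(nn.Layer):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2D(3, 8, 3, padding=1)
        self.head = nn.Conv2D(8, 1, 1)

    def forward(self, x):
        return self.head(self.backbone(x))

model = ToyDetector()
target_feats = {}

def hook(layer, inputs, output):
    # stash the feature map emitted by the hooked layer, as cam_utils does
    target_feats['backbone'] = output

# 'backbone' plays the role of the dotted path after the 'model.' prefix
operator.attrgetter('backbone')(model).register_forward_post_hook(hook)

model(paddle.rand([1, 3, 32, 32]))
print(target_feats['backbone'].shape)  # [1, 8, 32, 32]
```

`attrgetter` handles nested paths such as `bbox_head.roi_extractor` the same way, without building code strings.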