Unverified · Commit 54bb8ebd authored by WJJ1995, committed by GitHub

Add Support for MMdetection RetinaNet&&FSAF&&SSD&&Faster R-CNN (#614)

* fix flatten op bug for retinanet&&fsaf

* fix RoiAlign op for fasterrcnn

* fix NMS op bugs for ap test

* update model_zoo.md
Parent 399dc06a
@@ -73,7 +73,11 @@
 |GPT2| [pytorch(huggingface)](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb)|11|The input shape must be specified at conversion time; see [Q3 of the FAQ](../inference_model_convertor/FAQ.md)|
 |CifarNet | [tensorflow](https://github.com/tensorflow/models/blob/master/research/slim/nets/cifarnet.py)|9|
 |Fcos | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco.py)|11|
-|Yolov3 | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py)|11||
+|Yolov3 | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py)|11|
+|RetinaNet | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/retinanet/retinanet_r50_fpn_1x_coco.py)|11|
+|FSAF | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/fsaf/fsaf_r50_fpn_1x_coco.py)|11|
+|SSD | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/ssd/ssd300_coco.py)|11|
+|Faster R-CNN | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py)|11||
 ## PyTorch Inference Models
@@ -83,18 +83,18 @@ def multiclass_nms(bboxes,
 class NMS(object):
-    def __init__(self, score_threshold, nms_top_k, nms_threshold):
+    def __init__(self, score_threshold, keep_top_k, nms_threshold):
         self.score_threshold = score_threshold
-        self.nms_top_k = nms_top_k
+        self.keep_top_k = keep_top_k
         self.nms_threshold = nms_threshold
 
     def __call__(self, bboxes, scores):
         attrs = {
             'background_label': -1,
             'score_threshold': self.score_threshold,
-            'nms_top_k': self.nms_top_k,
+            'nms_top_k': -1,
             'nms_threshold': self.nms_threshold,
-            'keep_top_k': -1,
+            'keep_top_k': self.keep_top_k,
             'nms_eta': 1.0,
             'normalized': False,
             'return_index': True
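Note (my annotation, not part of the patch): in Paddle's multiclass NMS, `nms_top_k` caps the per-class candidates kept *before* NMS, while `keep_top_k` caps the total detections kept per image *after* NMS. The wrapper now takes `keep_top_k` from its caller and leaves `nms_top_k` at -1 (keep all candidates). A minimal sketch of the resulting attribute dictionary, with illustrative values:

```python
# Hedged sketch, not from the patch: the attribute dictionary NMS.__call__
# would build for illustrative constructor values (score_threshold=0.05,
# keep_top_k=100, nms_threshold=0.5).
example_attrs = {
    'background_label': -1,   # no class is treated as background
    'score_threshold': 0.05,  # drop boxes below this confidence before NMS
    'nms_top_k': -1,          # do not pre-truncate per-class candidates
    'nms_threshold': 0.5,     # IoU threshold used to suppress overlapping boxes
    'keep_top_k': 100,        # total detections retained per image after NMS
    'nms_eta': 1.0,           # 1.0 disables adaptive threshold decay
    'normalized': False,      # boxes are in absolute (pixel) coordinates
    'return_index': True,     # also return indices of the kept boxes
}
```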
@@ -465,12 +465,20 @@ class OpSet9():
             inputs={"input": val_rois.name},
             outputs=[val_rois_shape])
         val_rois_num = val_rois.name + '_num'
-        self.paddle_graph.add_layer(
-            'paddle.split',
-            inputs={"x": val_rois_shape},
-            outputs=[val_rois_num, '_', '_', '_'],
-            num_or_sections=[1, 1, 1, 1],
-            axis=0)
+        if len(val_rois.out_shapes[0]) == 4:
+            self.paddle_graph.add_layer(
+                'paddle.split',
+                inputs={"x": val_rois_shape},
+                outputs=[val_rois_num, ' _', ' _', ' _'],
+                num_or_sections=[1, 1, 1, 1],
+                axis=0)
+        elif len(val_rois.out_shapes[0]) == 2:
+            self.paddle_graph.add_layer(
+                'paddle.split',
+                inputs={"x": val_rois_shape},
+                outputs=[val_rois_num, ' _'],
+                num_or_sections=[1, 1],
+                axis=0)
         layer_attrs = {
             'pooled_height': pooled_height,
             'pooled_width': pooled_width,
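Note (my annotation): the ROI tensor fed to RoiAlign can be rank 2 (`[num_rois, 4]`, the layout in the ONNX RoiAlign spec) or rank 4 (which is what the old code assumed), so the generated `paddle.split` now adapts its section count to the rank. A minimal sketch of the rank-2 branch, assuming PaddlePaddle 2.x; names and sizes are illustrative:

```python
# Minimal sketch of what the generated code does for a rank-2 ROI tensor:
# paddle.shape() yields the 1-D tensor [num_rois, 4], and splitting it into
# [1, 1] sections along axis 0 isolates num_rois as a one-element tensor.
import paddle

rois = paddle.rand([8, 4])        # [num_rois, 4], ONNX RoiAlign layout
rois_shape = paddle.shape(rois)   # 1-D int tensor holding [8, 4]
rois_num, _ = paddle.split(rois_shape, num_or_sections=[1, 1], axis=0)
print(rois_num)                   # tensor holding [8]
```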
@@ -1329,7 +1337,7 @@ class OpSet9():
     @print_mapping_info
     def Flatten(self, node):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
-        output_shape = node.out_shapes[0]
+        output_shape = val_x.out_shapes[0]
         axis = node.get_attr('axis', 1)
         shape_list = [1, 1]
         if axis == 0:
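Note (my annotation): per the ONNX spec, Flatten reshapes its input to 2-D as `[prod(shape[:axis]), prod(shape[axis:])]` of the *input* shape, which is why the mapper now reads the shape from the input node `val_x` rather than from the Flatten node itself. A small illustration of the formula (plain Python/NumPy; the helper name is hypothetical and not part of the converter):

```python
# Illustration of the ONNX Flatten shape rule the mapper reproduces.
import numpy as np

def flatten_shape(input_shape, axis=1):
    outer = int(np.prod(input_shape[:axis])) if axis > 0 else 1
    inner = int(np.prod(input_shape[axis:]))
    return [outer, inner]

print(flatten_shape([2, 256, 13, 13], axis=1))  # [2, 43264]
print(flatten_shape([2, 256, 13, 13], axis=0))  # [1, 86528]
```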
@@ -2192,15 +2200,16 @@ class OpSet9():
         layer_outputs = [nn_op_name, output_name]
         boxes = self.graph.get_input_node(node, idx=0, copy=True)
         scores = self.graph.get_input_node(node, idx=1, copy=True)
+        num_classes = scores.out_shapes[0][1]
         inputs_len = len(node.layer.input)
         layer_attrs = dict()
         if inputs_len > 2:
             max_output_boxes_per_class = self.graph.get_input_node(
                 node, idx=2, copy=True)
-            layer_attrs["nms_top_k"] = _const_weight_or_none(
-                max_output_boxes_per_class).tolist()[0]
+            layer_attrs["keep_top_k"] = _const_weight_or_none(
+                max_output_boxes_per_class).tolist()[0] * num_classes
         else:
-            layer_attrs["nms_top_k"] = 0
+            layer_attrs["keep_top_k"] = 0
         if inputs_len > 3:
             iou_threshold = self.graph.get_input_node(node, idx=3, copy=True)
             layer_attrs["nms_threshold"] = _const_weight_or_none(