未验证 提交 4a2df7e1 编写于 作者: W WJJ1995 提交者: GitHub

Add MMdetection FCOS && Yolov3 support (#595)

* add BatchToSpaceND and SpaceToBatchND op convert

* add Less and fixed Resize RoiAlign Greater Tile op

* add nms custom layer

* fix topk int64 bug

* update model_zoo

* deal with comments
Co-authored-by: Nchanningss <chen_lingchi@163.com>
上级 3903bade
......@@ -18,12 +18,12 @@
| ResNet_V2_101 | [code](https://github.com/tensorflow/models/tree/master/research/slim/nets) |
| UNet | [code1](https://github.com/jakeret/tf_unet)/[code2](https://github.com/lyatdawn/Unet-Tensorflow) |
| MTCNN | [code](https://github.com/AITTSMD/MTCNN-Tensorflow) |
| YOLO-V3| [code](https://github.com/YunYang1994/tensorflow-yolov3) |
| YOLO-V3| [code](https://github.com/YunYang1994/tensorflow-yolov3) |
| FALSR | [code](https://github.com/xiaomi-automl/FALSR) |
| DCSCN | [code](https://modelzoo.co/model/dcscn-super-resolution) |
| Bert(albert) | [code](https://github.com/google-research/albert#pre-trained-models) |
| Bert(chinese_L-12_H-768_A-12) | [code](https://github.com/google-research/bert#pre-trained-models) |
| Bert(multi_cased_L-12_H-768_A-12) | [code](https://github.com/google-research/bert#pre-trained-models) |
| Bert(chinese_L-12_H-768_A-12) | [code](https://github.com/google-research/bert#pre-trained-models) |
| Bert(multi_cased_L-12_H-768_A-12) | [code](https://github.com/google-research/bert#pre-trained-models) |
## Caffe预测模型
......@@ -71,7 +71,9 @@
|Ultra-Light-Fast-Generic-Face-Detector-1MB| [onnx_model](https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/tree/master/models/onnx)|9 |
|BERT| [pytorch(huggingface)](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb)|11|转换时需指定input shape,见[文档Q3](../inference_model_convertor/FAQ.md)|
|GPT2| [pytorch(huggingface)](https://github.com/huggingface/transformers/blob/master/notebooks/04-onnx-export.ipynb)|11|转换时需指定input shape,见[文档Q3](../inference_model_convertor/FAQ.md)|
|CifarNet | [tensorflow](https://github.com/tensorflow/models/blob/master/research/slim/nets/cifarnet.py)|9||
|CifarNet | [tensorflow](https://github.com/tensorflow/models/blob/master/research/slim/nets/cifarnet.py)|9|
|Fcos | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco.py)|11|
|Yolov3 | [pytorch(mmdetection)](https://github.com/open-mmlab/mmdetection/blob/master/configs/yolo/yolov3_d53_mstrain-608_273e_coco.py)|11||
## PyTorch预测模型
......
......@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .one_hot import OneHot
from .pad_two_input import PadWithTwoInput
from .pad_all_dim2 import PadAllDim2
from .pad_all_dim4 import PadAllDim4
from .pad_all_dim4_one_input import PadAllDim4WithOneInput
from .lrn import LocalResponseNorm
from .nms import NMS
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import core
from paddle.fluid.framework import Variable, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
def multiclass_nms(bboxes,
scores,
score_threshold,
nms_top_k,
keep_top_k,
nms_threshold=0.3,
normalized=True,
nms_eta=1.,
background_label=-1,
return_index=False,
return_rois_num=True,
rois_num=None,
name=None):
helper = LayerHelper('multiclass_nms3', **locals())
if in_dygraph_mode():
attrs = ('background_label', background_label, 'score_threshold',
score_threshold, 'nms_top_k', nms_top_k, 'nms_threshold',
nms_threshold, 'keep_top_k', keep_top_k, 'nms_eta', nms_eta,
'normalized', normalized)
output, index, nms_rois_num = core.ops.multiclass_nms3(bboxes, scores,
rois_num, *attrs)
if not return_index:
index = None
return output, nms_rois_num, index
else:
output = helper.create_variable_for_type_inference(dtype=bboxes.dtype)
index = helper.create_variable_for_type_inference(dtype='int')
inputs = {'BBoxes': bboxes, 'Scores': scores}
outputs = {'Out': output, 'Index': index}
if rois_num is not None:
inputs['RoisNum'] = rois_num
if return_rois_num:
nms_rois_num = helper.create_variable_for_type_inference(
dtype='int32')
outputs['NmsRoisNum'] = nms_rois_num
helper.append_op(
type="multiclass_nms3",
inputs=inputs,
attrs={
'background_label': background_label,
'score_threshold': score_threshold,
'nms_top_k': nms_top_k,
'nms_threshold': nms_threshold,
'keep_top_k': keep_top_k,
'nms_eta': nms_eta,
'normalized': normalized
},
outputs=outputs)
output.stop_gradient = True
index.stop_gradient = True
if not return_index:
index = None
if not return_rois_num:
nms_rois_num = None
return output, nms_rois_num, index
class NMS(object):
def __init__(self, score_threshold, nms_top_k, nms_threshold):
self.score_threshold = score_threshold
self.nms_top_k = nms_top_k
self.nms_threshold = nms_threshold
def __call__(self, bboxes, scores):
attrs = {
'background_label': -1,
'score_threshold': self.score_threshold,
'nms_top_k': self.nms_top_k,
'nms_threshold': self.nms_threshold,
'keep_top_k': -1,
'nms_eta': 1.0,
'normalized': False,
'return_index': True
}
output, nms_rois_num, index = multiclass_nms(bboxes, scores, **attrs)
clas = paddle.slice(output, axes=[1], starts=[0], ends=[1])
clas = paddle.cast(clas, dtype="int64")
index = paddle.cast(index, dtype="int64")
if bboxes.shape[0] == 1:
batch = paddle.zeros_like(clas, dtype="int64")
else:
bboxes_count = bboxes.shape[1]
batch = paddle.divide(index, bboxes_count)
index = paddle.mod(index, bboxes_count)
res = paddle.concat([batch, clas, index], axis=1)
return res
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册