Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
e397bf84
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e397bf84
编写于
8月 08, 2019
作者:
C
channingss
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update ONNX supported models info
上级
e2b1d343
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
46 addition
and
9 deletion
+46
-9
test_model_zoo.md
test_model_zoo.md
+27
-0
x2paddle/decoder/onnx_decoder.py
x2paddle/decoder/onnx_decoder.py
+12
-5
x2paddle/op_mapper/onnx_op_mapper.py
x2paddle/op_mapper/onnx_op_mapper.py
+7
-4
未找到文件。
test_model_zoo.md
浏览文件 @
e397bf84
...
@@ -26,3 +26,30 @@
...
@@ -26,3 +26,30 @@
| ShuffleNet |
[
code
](
https://github.com/miaow1988/ShuffleNet_V2_pytorch_caffe/releases/tag/v0.1.0
)
|
| ShuffleNet |
[
code
](
https://github.com/miaow1988/ShuffleNet_V2_pytorch_caffe/releases/tag/v0.1.0
)
|
| mNASNet |
[
code
](
https://github.com/LiJianfei06/MnasNet-caffe
)
|
| mNASNet |
[
code
](
https://github.com/LiJianfei06/MnasNet-caffe
)
|
| MTCNN |
[
code
](
https://github.com/kpzhang93/MTCNN_face_detection_alignment/tree/master/code/codes/MTCNNv1/model
)
|
| MTCNN |
[
code
](
https://github.com/kpzhang93/MTCNN_face_detection_alignment/tree/master/code/codes/MTCNNv1/model
)
|
# ONNX
| 模型 | 来源 | operator version|
|-------|--------|
| Resnet18 |
[
torchvison.model.resnet18
](
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
)
|9|
| Resnet34 |
[
torchvison.model.resnet34
](
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
)
|9|
| Resnet50 |
[
torchvison.model.resnet50
](
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
)
|9|
| Resnet101 |
[
torchvison.model.resnet101
](
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
)
|9|
| Vgg11 |
[
torchvison.model.vgg11
](
https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
)
|9|
| Vgg11_bn |
[
torchvison.model.vgg11_bn
](
https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
)
|9|
| Vgg19|
[
torchvison.model.vgg16_bn
](
https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py
)
|9|
| Densenet121 |
[
torchvison.model.densenet121
](
https://github.com/pytorch/vision/blob/master/torchvision/models/densenet.py
)
|9|
| Alexnet |
[
onnx official
](
https://github.com/pytorch/vision/blob/master/torchvision/models/alexnet.py
)
|9|
| Shufflenet |
[
onnx official
](
https://github.com/onnx/models/tree/master/vision/classification/shufflenet
)
|9|
| Inception_v2 |
[
onnx official
](
https://github.com/onnx/models/tree/master/vision/classification/inception_and_googlenet/inception_v2
)
|9|
目前onnx2paddle主要支持onnx operator version 9,关于如何使用torchvison的model:
```
import torch
import torchvision
dummy_input = torch.randn(1, 3, 224, 224) #根据不同模型调整shape
resnet18 = torchvision.models.resnet18(pretrained=True)
torch.onnx.export(resnet18, dummy_input, "resnet18.onnx",verbose=True)#"resnet18.onnx"为onnx model的存储路径
```
x2paddle/decoder/onnx_decoder.py
浏览文件 @
e397bf84
...
@@ -27,8 +27,10 @@ from collections import OrderedDict as Dict
...
@@ -27,8 +27,10 @@ from collections import OrderedDict as Dict
import
onnx
import
onnx
import
numpy
as
np
import
numpy
as
np
from
copy
import
deepcopy
from
copy
import
deepcopy
import
logging
as
_logging
default_op_domain
=
'ai.onnx'
default_op_domain
=
'ai.onnx'
_logger
=
_logging
.
getLogger
(
__name__
)
class
ONNXGraphNode
(
GraphNode
):
class
ONNXGraphNode
(
GraphNode
):
...
@@ -229,10 +231,17 @@ class ONNXDecoder(object):
...
@@ -229,10 +231,17 @@ class ONNXDecoder(object):
def
__init__
(
self
,
onnx_model
):
def
__init__
(
self
,
onnx_model
):
model
=
onnx
.
load
(
onnx_model
)
model
=
onnx
.
load
(
onnx_model
)
print
(
'model ir_version: {}, op version: {}'
.
format
(
print
(
'model ir_version: {}, op version: {}'
.
format
(
model
.
ir_version
,
model
.
opset_import
))
model
.
ir_version
,
model
.
opset_import
[
0
].
version
))
if
model
.
opset_import
[
0
].
version
<
9
:
_logger
.
warning
(
'Now, onnx2paddle main support convert onnx model opset_verison == 9,'
'opset_verison of your onnx model is %d < 9,'
'some operator may cannot convert.'
,
model
.
opset_import
[
0
].
version
)
check_model
(
model
)
check_model
(
model
)
model
=
convert_version
(
model
,
9
)
# model = convert_version(model, 10)
model
=
polish_model
(
model
)
model
=
polish_model
(
model
)
model
=
self
.
optimize_model_skip_op_for_inference
(
model
)
model
=
self
.
optimize_model_skip_op_for_inference
(
model
)
...
@@ -263,7 +272,6 @@ class ONNXDecoder(object):
...
@@ -263,7 +272,6 @@ class ONNXDecoder(object):
"""
"""
skip nodes between src_output_name -> dst_input_name and connect this pair
skip nodes between src_output_name -> dst_input_name and connect this pair
"""
"""
processed
=
0
processed
=
0
for
next_idx
in
input_refs
[
src_output_name
]:
for
next_idx
in
input_refs
[
src_output_name
]:
next_node
=
nodes
[
next_idx
]
next_node
=
nodes
[
next_idx
]
...
@@ -278,7 +286,6 @@ class ONNXDecoder(object):
...
@@ -278,7 +286,6 @@ class ONNXDecoder(object):
"""
"""
skip nodes between dst_output_name -> src_input_name and connect this pair
skip nodes between dst_output_name -> src_input_name and connect this pair
"""
"""
processed
=
0
processed
=
0
for
prev_idx
in
output_refs
[
src_input_name
]:
for
prev_idx
in
output_refs
[
src_input_name
]:
prev_node
=
nodes
[
prev_idx
]
prev_node
=
nodes
[
prev_idx
]
...
...
x2paddle/op_mapper/onnx_op_mapper.py
浏览文件 @
e397bf84
...
@@ -355,16 +355,16 @@ class ONNXOpMapper(OpMapper):
...
@@ -355,16 +355,16 @@ class ONNXOpMapper(OpMapper):
if
shape_dtype
is
None
:
if
shape_dtype
is
None
:
_logger
.
warning
(
_logger
.
warning
(
'in op %s(%s -> Reshape -> %s): '
'in op %s(%s -> Reshape -> %s): '
'dtype of input "shape" not inferred, int32 assumed'
,
name
,
'dtype of input "shape" not inferred, int32 assumed'
,
inputs
,
outputs
)
node
.
layer_name
,
val_x
.
layer_name
,
val_reshaped
.
layer_name
)
shape_dtype
=
_np
.
dtype
(
'int32'
)
shape_dtype
=
_np
.
dtype
(
'int32'
)
if
shape
is
None
:
if
shape
is
None
:
shape
=
[
1
,
-
1
]
shape
=
[
1
,
-
1
]
_logger
.
warning
(
_logger
.
warning
(
'in %s(%s -> Reshape -> %s): '
'in %s(%s -> Reshape -> %s): '
'input "shape" not inferred, use [1, -1] as dummy value, '
'input "shape" not inferred, use [1, -1] as dummy value, '
'the behavior of Paddle fluid maybe undefined'
,
n
ame
,
inputs
,
'the behavior of Paddle fluid maybe undefined'
,
n
ode
.
layer_name
,
outputs
)
val_x
.
layer_name
,
val_reshaped
.
layer_name
)
attr
=
{
'shape'
:
shape
,
'name'
:
string
(
node
.
layer_name
)}
attr
=
{
'shape'
:
shape
,
'name'
:
string
(
node
.
layer_name
)}
node
.
fluid_code
.
add_layer
(
'reshape'
,
node
.
fluid_code
.
add_layer
(
'reshape'
,
...
@@ -532,6 +532,8 @@ class ONNXOpMapper(OpMapper):
...
@@ -532,6 +532,8 @@ class ONNXOpMapper(OpMapper):
momentum
=
node
.
get_attr
(
'momentum'
,
.
9
)
momentum
=
node
.
get_attr
(
'momentum'
,
.
9
)
epsilon
=
node
.
get_attr
(
'epsilon'
,
1e-5
)
epsilon
=
node
.
get_attr
(
'epsilon'
,
1e-5
)
# Attribute: spatial is used in BatchNormalization-1,6,7
spatial
=
bool
(
node
.
get_attr
(
'spatial'
))
attr
=
{
attr
=
{
"momentum"
:
momentum
,
"momentum"
:
momentum
,
"epsilon"
:
epsilon
,
"epsilon"
:
epsilon
,
...
@@ -541,6 +543,7 @@ class ONNXOpMapper(OpMapper):
...
@@ -541,6 +543,7 @@ class ONNXOpMapper(OpMapper):
"bias_attr"
:
string
(
val_b
.
layer_name
),
"bias_attr"
:
string
(
val_b
.
layer_name
),
"moving_mean_name"
:
string
(
val_mean
.
layer_name
),
"moving_mean_name"
:
string
(
val_mean
.
layer_name
),
"moving_variance_name"
:
string
(
val_var
.
layer_name
),
"moving_variance_name"
:
string
(
val_var
.
layer_name
),
"use_global_stats"
:
spatial
,
"name"
:
string
(
node
.
layer_name
)
"name"
:
string
(
node
.
layer_name
)
}
}
node
.
fluid_code
.
add_layer
(
"batch_norm"
,
node
.
fluid_code
.
add_layer
(
"batch_norm"
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录