Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
7c3e9379
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
接近 2 年 前同步成功
通知
329
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7c3e9379
编写于
4月 28, 2019
作者:
M
Macrobull
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
bugfix
上级
816ac6e2
变更
16
显示空白变更内容
内联
并排
Showing
16 changed file
with
1605 addition
and
1127 deletion
+1605
-1127
onnx2fluid/examples/convert_data_npz.py
onnx2fluid/examples/convert_data_npz.py
+48
-0
onnx2fluid/examples/convert_data_pb.py
onnx2fluid/examples/convert_data_pb.py
+64
-0
onnx2fluid/examples/gen_some_samples.py
onnx2fluid/examples/gen_some_samples.py
+27
-33
onnx2fluid/examples/gen_unet.py
onnx2fluid/examples/gen_unet.py
+10
-11
onnx2fluid/examples/gen_yolov2.py
onnx2fluid/examples/gen_yolov2.py
+167
-190
onnx2fluid/examples/onnx_model_zoo.sh
onnx2fluid/examples/onnx_model_zoo.sh
+530
-233
onnx2fluid/onnx2fluid/__main__.py
onnx2fluid/onnx2fluid/__main__.py
+1
-1
onnx2fluid/onnx2fluid/cmdline.py
onnx2fluid/onnx2fluid/cmdline.py
+9
-11
onnx2fluid/onnx2fluid/conversion.py
onnx2fluid/onnx2fluid/conversion.py
+43
-23
onnx2fluid/onnx2fluid/framework_pb2.py
onnx2fluid/onnx2fluid/framework_pb2.py
+173
-81
onnx2fluid/onnx2fluid/onnx_utils.py
onnx2fluid/onnx2fluid/onnx_utils.py
+68
-64
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+280
-314
onnx2fluid/onnx2fluid/torch_export_helper.py
onnx2fluid/onnx2fluid/torch_export_helper.py
+33
-34
onnx2fluid/onnx2fluid/validation.py
onnx2fluid/onnx2fluid/validation.py
+35
-38
onnx2fluid/onnx2fluid/writer.py
onnx2fluid/onnx2fluid/writer.py
+108
-87
onnx2fluid/setup.cfg
onnx2fluid/setup.cfg
+9
-7
未找到文件。
onnx2fluid/examples/convert_data_npz.py
0 → 100644
浏览文件 @
7c3e9379
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019
@author: Macrobull
"""
import
sys
import
numpy
as
np
from
collections
import
OrderedDict
as
Dict
def
make_var_name
(
name
):
"""
make a valid variable name in Python code
"""
if
name
==
''
:
return
'_'
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
'
\\
|/:'
:
#
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
return
name
fn
=
sys
.
argv
[
1
]
input_names
=
sys
.
argv
[
2
].
split
(
','
)
output_names
=
sys
.
argv
[
3
].
split
(
','
)
squeeze_data
=
len
(
sys
.
argv
)
>
4
data
=
np
.
load
(
fn
,
encoding
=
'bytes'
)
input_data
=
data
[
'inputs'
]
output_data
=
data
[
'outputs'
]
while
squeeze_data
and
input_data
.
ndim
>
4
and
input_data
.
shape
[
0
]
==
1
:
input_data
=
input_data
.
squeeze
(
0
)
while
squeeze_data
and
output_data
.
ndim
>
2
and
output_data
.
shape
[
0
]
==
1
:
output_data
=
output_data
.
squeeze
(
0
)
inputs
=
Dict
(
zip
(
map
(
make_var_name
,
input_names
),
[
input_data
]))
outputs
=
Dict
(
zip
(
map
(
make_var_name
,
output_names
),
[
output_data
]))
np
.
savez
(
fn
,
inputs
=
inputs
,
outputs
=
outputs
)
# overwrite
onnx2fluid/examples/convert_data_pb.py
0 → 100644
浏览文件 @
7c3e9379
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019
@author: Macrobull
"""
import
os
,
sys
import
numpy
as
np
import
onnx
import
onnx.numpy_helper
as
numpy_helper
from
collections
import
OrderedDict
as
Dict
from
glob
import
glob
def
make_var_name
(
name
):
"""
make a valid variable name in Python code
"""
if
name
==
''
:
return
'_'
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
'
\\
|/:'
:
#
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
return
name
data_dir
=
os
.
path
.
dirname
(
sys
.
argv
[
1
])
input_names
=
sys
.
argv
[
2
].
split
(
','
)
output_names
=
sys
.
argv
[
3
].
split
(
','
)
squeeze_data
=
len
(
sys
.
argv
)
>
4
# Load inputs
inputs
=
[]
for
fn
in
glob
(
os
.
path
.
join
(
data_dir
,
'input_*.pb'
)):
tensor
=
onnx
.
TensorProto
()
with
open
(
fn
,
'rb'
)
as
f
:
tensor
.
ParseFromString
(
f
.
read
())
tensor
=
numpy_helper
.
to_array
(
tensor
)
while
squeeze_data
and
tensor
.
ndim
>
4
and
tensor
.
shape
[
0
]
==
1
:
tensor
=
tensor
.
squeeze
(
0
)
inputs
.
append
(
tensor
)
# Load outputs
outputs
=
[]
for
fn
in
glob
(
os
.
path
.
join
(
data_dir
,
'output_*.pb'
)):
tensor
=
onnx
.
TensorProto
()
with
open
(
fn
,
'rb'
)
as
f
:
tensor
.
ParseFromString
(
f
.
read
())
tensor
=
numpy_helper
.
to_array
(
tensor
)
while
squeeze_data
and
tensor
.
ndim
>
2
and
tensor
.
shape
[
0
]
==
1
:
tensor
=
tensor
.
squeeze
(
0
)
outputs
.
append
(
tensor
)
inputs
=
Dict
(
zip
(
map
(
make_var_name
,
input_names
),
inputs
))
outputs
=
Dict
(
zip
(
map
(
make_var_name
,
output_names
),
outputs
))
np
.
savez
(
data_dir
,
inputs
=
inputs
,
outputs
=
outputs
)
onnx2fluid/examples/gen_some_samples.py
浏览文件 @
7c3e9379
...
@@ -39,7 +39,7 @@ idx = 0
...
@@ -39,7 +39,7 @@ idx = 0
#yp = model(xb)
#yp = model(xb)
#idx += 1
#idx += 1
#print('index: ', idx)
#print('index: ', idx)
#export_onnx_with_validation(model,
(xb, )
, prefix + str(idx),
#export_onnx_with_validation(model,
[xb]
, prefix + str(idx),
# ['x'], ['y'],
# ['x'], ['y'],
# verbose=True, training=False)
# verbose=True, training=False)
...
@@ -61,7 +61,7 @@ idx = 0
...
@@ -61,7 +61,7 @@ idx = 0
#yp = model(xb)
#yp = model(xb)
#idx += 1
#idx += 1
#print('index: ', idx)
#print('index: ', idx)
#export_onnx_with_validation(model,
(xb, )
, prefix + str(idx),
#export_onnx_with_validation(model,
[xb]
, prefix + str(idx),
# ['x'], ['y'],
# ['x'], ['y'],
# verbose=True, training=False)
# verbose=True, training=False)
...
@@ -85,8 +85,7 @@ xb = torch.rand((2, 3))
...
@@ -85,8 +85,7 @@ xb = torch.rand((2, 3))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
export_onnx_with_validation
(
model
,
[
xb
],
model
,
(
xb
,
),
prefix
+
str
(
idx
),
[
'x'
],
[
'y'
],
prefix
+
str
(
idx
),
[
'x'
],
[
'y'
],
verbose
=
True
,
verbose
=
True
,
training
=
False
)
training
=
False
)
...
@@ -113,8 +112,7 @@ xb1 = torch.rand((2, 3))
...
@@ -113,8 +112,7 @@ xb1 = torch.rand((2, 3))
ya
,
yb
,
yc
=
model
(
xb0
,
xb1
)
ya
,
yb
,
yc
=
model
(
xb0
,
xb1
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
export_onnx_with_validation
(
model
,
[
xb0
,
xb1
],
model
,
(
xb0
,
xb1
),
prefix
+
str
(
idx
),
[
'x0'
,
'x1'
],
[
'ya'
,
'yb'
,
'yc'
],
prefix
+
str
(
idx
),
[
'x0'
,
'x1'
],
[
'ya'
,
'yb'
,
'yc'
],
verbose
=
True
,
verbose
=
True
,
training
=
False
)
training
=
False
)
...
@@ -137,8 +135,7 @@ theta = torch.rand((2, 2, 3))
...
@@ -137,8 +135,7 @@ theta = torch.rand((2, 2, 3))
grid
=
model
(
theta
)
grid
=
model
(
theta
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
export_onnx_with_validation
(
model
,
(
theta
,
),
model
,
(
theta
,
),
prefix
+
str
(
idx
),
[
'theta'
],
[
'grid'
],
prefix
+
str
(
idx
),
[
'theta'
],
[
'grid'
],
verbose
=
True
,
verbose
=
True
,
training
=
False
)
training
=
False
)
...
@@ -165,8 +162,7 @@ xb = torch.rand((2, 3, 4, 5))
...
@@ -165,8 +162,7 @@ xb = torch.rand((2, 3, 4, 5))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
export_onnx_with_validation
(
model
,
[
xb
],
model
,
(
xb
,
),
prefix
+
str
(
idx
),
[
'x'
],
[
'y'
],
prefix
+
str
(
idx
),
[
'x'
],
[
'y'
],
verbose
=
True
,
verbose
=
True
,
training
=
False
)
training
=
False
)
...
@@ -195,8 +191,7 @@ xb = torch.rand((2, 3, 4, 5))
...
@@ -195,8 +191,7 @@ xb = torch.rand((2, 3, 4, 5))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
export_onnx_with_validation
(
model
,
[
xb
],
model
,
(
xb
,
),
prefix
+
str
(
idx
),
[
'x'
],
[
'y'
],
prefix
+
str
(
idx
),
[
'x'
],
[
'y'
],
verbose
=
True
,
verbose
=
True
,
training
=
False
)
training
=
False
)
...
@@ -220,7 +215,7 @@ export_onnx_with_validation(
...
@@ -220,7 +215,7 @@ export_onnx_with_validation(
#yp = model(xb)
#yp = model(xb)
#idx += 1
#idx += 1
#print('index: ', idx)
#print('index: ', idx)
#export_onnx_with_validation(model,
(xb, )
, prefix + str(idx),
#export_onnx_with_validation(model,
[xb]
, prefix + str(idx),
# ['x'], ['y'],
# ['x'], ['y'],
# verbose=True, training=False)
# verbose=True, training=False)
...
@@ -241,8 +236,7 @@ xb = torch.rand((2, 3))
...
@@ -241,8 +236,7 @@ xb = torch.rand((2, 3))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
export_onnx_with_validation
(
model
,
[
xb
],
model
,
(
xb
,
),
prefix
+
str
(
idx
),
[
'y'
],
[
'y'
],
prefix
+
str
(
idx
),
[
'y'
],
[
'y'
],
verbose
=
True
,
verbose
=
True
,
training
=
False
)
training
=
False
)
onnx2fluid/examples/gen_unet.py
浏览文件 @
7c3e9379
...
@@ -21,9 +21,9 @@ class double_conv(nn.Module):
...
@@ -21,9 +21,9 @@ class double_conv(nn.Module):
def
__init__
(
self
,
in_ch
,
out_ch
):
def
__init__
(
self
,
in_ch
,
out_ch
):
super
(
double_conv
,
self
).
__init__
()
super
(
double_conv
,
self
).
__init__
()
self
.
conv
=
nn
.
Sequential
(
self
.
conv
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_ch
,
out_ch
,
3
,
padding
=
1
),
nn
.
Conv2d
(
in_ch
,
out_ch
,
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
out_ch
),
nn
.
BatchNorm2d
(
out_ch
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
ReLU
(
inplace
=
True
),
nn
.
Conv2d
(
out_ch
,
out_ch
,
3
,
padding
=
1
),
nn
.
Conv2d
(
out_ch
,
out_ch
,
3
,
padding
=
1
),
nn
.
BatchNorm2d
(
out_ch
),
nn
.
ReLU
(
inplace
=
True
))
nn
.
BatchNorm2d
(
out_ch
),
nn
.
ReLU
(
inplace
=
True
))
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
...
@@ -58,8 +58,8 @@ class up(nn.Module):
...
@@ -58,8 +58,8 @@ class up(nn.Module):
# would be a nice idea if the upsampling could be learned too,
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
# but my machine do not have enough memory to handle all those weights
if
bilinear
:
if
bilinear
:
self
.
up
=
nn
.
Upsample
(
self
.
up
=
nn
.
Upsample
(
scale_factor
=
2
,
scale_factor
=
2
,
mode
=
'bilinear'
)
#, align_corners=True)
mode
=
'bilinear'
)
#, align_corners=True)
else
:
else
:
self
.
up
=
nn
.
ConvTranspose2d
(
in_ch
//
2
,
in_ch
//
2
,
2
,
stride
=
2
)
self
.
up
=
nn
.
ConvTranspose2d
(
in_ch
//
2
,
in_ch
//
2
,
2
,
stride
=
2
)
...
@@ -131,8 +131,7 @@ model = UNet(3, 80)
...
@@ -131,8 +131,7 @@ model = UNet(3, 80)
model
.
eval
()
model
.
eval
()
xb
=
torch
.
rand
((
1
,
3
,
512
,
512
))
xb
=
torch
.
rand
((
1
,
3
,
512
,
512
))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
export_onnx_with_validation
(
export_onnx_with_validation
(
model
,
[
xb
],
model
,
(
xb
,
),
'sample_unet'
,
[
'image'
],
[
'pred'
],
'sample_unet'
,
[
'image'
],
[
'pred'
],
verbose
=
True
,
verbose
=
True
,
training
=
False
)
training
=
False
)
onnx2fluid/examples/gen_yolov2.py
浏览文件 @
7c3e9379
...
@@ -20,8 +20,7 @@ class Yolov2(nn.Module):
...
@@ -20,8 +20,7 @@ class Yolov2(nn.Module):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Yolov2
,
self
).
__init__
()
super
(
Yolov2
,
self
).
__init__
()
self
.
conv1
=
nn
.
Conv2d
(
self
.
conv1
=
nn
.
Conv2d
(
in_channels
=
3
,
in_channels
=
3
,
out_channels
=
32
,
out_channels
=
32
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
...
@@ -29,8 +28,7 @@ class Yolov2(nn.Module):
...
@@ -29,8 +28,7 @@ class Yolov2(nn.Module):
bias
=
False
)
bias
=
False
)
self
.
batchnorm1
=
nn
.
BatchNorm2d
(
32
)
self
.
batchnorm1
=
nn
.
BatchNorm2d
(
32
)
self
.
conv2
=
nn
.
Conv2d
(
self
.
conv2
=
nn
.
Conv2d
(
in_channels
=
32
,
in_channels
=
32
,
out_channels
=
64
,
out_channels
=
64
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
...
@@ -38,24 +36,21 @@ class Yolov2(nn.Module):
...
@@ -38,24 +36,21 @@ class Yolov2(nn.Module):
bias
=
False
)
bias
=
False
)
self
.
batchnorm2
=
nn
.
BatchNorm2d
(
64
)
self
.
batchnorm2
=
nn
.
BatchNorm2d
(
64
)
self
.
conv3
=
nn
.
Conv2d
(
self
.
conv3
=
nn
.
Conv2d
(
in_channels
=
64
,
in_channels
=
64
,
out_channels
=
128
,
out_channels
=
128
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
padding
=
1
,
padding
=
1
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm3
=
nn
.
BatchNorm2d
(
128
)
self
.
batchnorm3
=
nn
.
BatchNorm2d
(
128
)
self
.
conv4
=
nn
.
Conv2d
(
self
.
conv4
=
nn
.
Conv2d
(
in_channels
=
128
,
in_channels
=
128
,
out_channels
=
64
,
out_channels
=
64
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm4
=
nn
.
BatchNorm2d
(
64
)
self
.
batchnorm4
=
nn
.
BatchNorm2d
(
64
)
self
.
conv5
=
nn
.
Conv2d
(
self
.
conv5
=
nn
.
Conv2d
(
in_channels
=
64
,
in_channels
=
64
,
out_channels
=
128
,
out_channels
=
128
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
...
@@ -63,24 +58,21 @@ class Yolov2(nn.Module):
...
@@ -63,24 +58,21 @@ class Yolov2(nn.Module):
bias
=
False
)
bias
=
False
)
self
.
batchnorm5
=
nn
.
BatchNorm2d
(
128
)
self
.
batchnorm5
=
nn
.
BatchNorm2d
(
128
)
self
.
conv6
=
nn
.
Conv2d
(
self
.
conv6
=
nn
.
Conv2d
(
in_channels
=
128
,
in_channels
=
128
,
out_channels
=
256
,
out_channels
=
256
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
padding
=
1
,
padding
=
1
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm6
=
nn
.
BatchNorm2d
(
256
)
self
.
batchnorm6
=
nn
.
BatchNorm2d
(
256
)
self
.
conv7
=
nn
.
Conv2d
(
self
.
conv7
=
nn
.
Conv2d
(
in_channels
=
256
,
in_channels
=
256
,
out_channels
=
128
,
out_channels
=
128
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm7
=
nn
.
BatchNorm2d
(
128
)
self
.
batchnorm7
=
nn
.
BatchNorm2d
(
128
)
self
.
conv8
=
nn
.
Conv2d
(
self
.
conv8
=
nn
.
Conv2d
(
in_channels
=
128
,
in_channels
=
128
,
out_channels
=
256
,
out_channels
=
256
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
...
@@ -88,40 +80,35 @@ class Yolov2(nn.Module):
...
@@ -88,40 +80,35 @@ class Yolov2(nn.Module):
bias
=
False
)
bias
=
False
)
self
.
batchnorm8
=
nn
.
BatchNorm2d
(
256
)
self
.
batchnorm8
=
nn
.
BatchNorm2d
(
256
)
self
.
conv9
=
nn
.
Conv2d
(
self
.
conv9
=
nn
.
Conv2d
(
in_channels
=
256
,
in_channels
=
256
,
out_channels
=
512
,
out_channels
=
512
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
padding
=
1
,
padding
=
1
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm9
=
nn
.
BatchNorm2d
(
512
)
self
.
batchnorm9
=
nn
.
BatchNorm2d
(
512
)
self
.
conv10
=
nn
.
Conv2d
(
self
.
conv10
=
nn
.
Conv2d
(
in_channels
=
512
,
in_channels
=
512
,
out_channels
=
256
,
out_channels
=
256
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm10
=
nn
.
BatchNorm2d
(
256
)
self
.
batchnorm10
=
nn
.
BatchNorm2d
(
256
)
self
.
conv11
=
nn
.
Conv2d
(
self
.
conv11
=
nn
.
Conv2d
(
in_channels
=
256
,
in_channels
=
256
,
out_channels
=
512
,
out_channels
=
512
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
padding
=
1
,
padding
=
1
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm11
=
nn
.
BatchNorm2d
(
512
)
self
.
batchnorm11
=
nn
.
BatchNorm2d
(
512
)
self
.
conv12
=
nn
.
Conv2d
(
self
.
conv12
=
nn
.
Conv2d
(
in_channels
=
512
,
in_channels
=
512
,
out_channels
=
256
,
out_channels
=
256
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm12
=
nn
.
BatchNorm2d
(
256
)
self
.
batchnorm12
=
nn
.
BatchNorm2d
(
256
)
self
.
conv13
=
nn
.
Conv2d
(
self
.
conv13
=
nn
.
Conv2d
(
in_channels
=
256
,
in_channels
=
256
,
out_channels
=
512
,
out_channels
=
512
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
...
@@ -129,40 +116,35 @@ class Yolov2(nn.Module):
...
@@ -129,40 +116,35 @@ class Yolov2(nn.Module):
bias
=
False
)
bias
=
False
)
self
.
batchnorm13
=
nn
.
BatchNorm2d
(
512
)
self
.
batchnorm13
=
nn
.
BatchNorm2d
(
512
)
self
.
conv14
=
nn
.
Conv2d
(
self
.
conv14
=
nn
.
Conv2d
(
in_channels
=
512
,
in_channels
=
512
,
out_channels
=
1024
,
out_channels
=
1024
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
padding
=
1
,
padding
=
1
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm14
=
nn
.
BatchNorm2d
(
1024
)
self
.
batchnorm14
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv15
=
nn
.
Conv2d
(
self
.
conv15
=
nn
.
Conv2d
(
in_channels
=
1024
,
in_channels
=
1024
,
out_channels
=
512
,
out_channels
=
512
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm15
=
nn
.
BatchNorm2d
(
512
)
self
.
batchnorm15
=
nn
.
BatchNorm2d
(
512
)
self
.
conv16
=
nn
.
Conv2d
(
self
.
conv16
=
nn
.
Conv2d
(
in_channels
=
512
,
in_channels
=
512
,
out_channels
=
1024
,
out_channels
=
1024
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
padding
=
1
,
padding
=
1
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm16
=
nn
.
BatchNorm2d
(
1024
)
self
.
batchnorm16
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv17
=
nn
.
Conv2d
(
self
.
conv17
=
nn
.
Conv2d
(
in_channels
=
1024
,
in_channels
=
1024
,
out_channels
=
512
,
out_channels
=
512
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
padding
=
0
,
padding
=
0
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm17
=
nn
.
BatchNorm2d
(
512
)
self
.
batchnorm17
=
nn
.
BatchNorm2d
(
512
)
self
.
conv18
=
nn
.
Conv2d
(
self
.
conv18
=
nn
.
Conv2d
(
in_channels
=
512
,
in_channels
=
512
,
out_channels
=
1024
,
out_channels
=
1024
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
...
@@ -170,16 +152,14 @@ class Yolov2(nn.Module):
...
@@ -170,16 +152,14 @@ class Yolov2(nn.Module):
bias
=
False
)
bias
=
False
)
self
.
batchnorm18
=
nn
.
BatchNorm2d
(
1024
)
self
.
batchnorm18
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv19
=
nn
.
Conv2d
(
self
.
conv19
=
nn
.
Conv2d
(
in_channels
=
1024
,
in_channels
=
1024
,
out_channels
=
1024
,
out_channels
=
1024
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
padding
=
1
,
padding
=
1
,
bias
=
False
)
bias
=
False
)
self
.
batchnorm19
=
nn
.
BatchNorm2d
(
1024
)
self
.
batchnorm19
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv20
=
nn
.
Conv2d
(
self
.
conv20
=
nn
.
Conv2d
(
in_channels
=
1024
,
in_channels
=
1024
,
out_channels
=
1024
,
out_channels
=
1024
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
...
@@ -187,8 +167,7 @@ class Yolov2(nn.Module):
...
@@ -187,8 +167,7 @@ class Yolov2(nn.Module):
bias
=
False
)
bias
=
False
)
self
.
batchnorm20
=
nn
.
BatchNorm2d
(
1024
)
self
.
batchnorm20
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv21
=
nn
.
Conv2d
(
self
.
conv21
=
nn
.
Conv2d
(
in_channels
=
3072
,
in_channels
=
3072
,
out_channels
=
1024
,
out_channels
=
1024
,
kernel_size
=
3
,
kernel_size
=
3
,
stride
=
1
,
stride
=
1
,
...
@@ -196,8 +175,7 @@ class Yolov2(nn.Module):
...
@@ -196,8 +175,7 @@ class Yolov2(nn.Module):
bias
=
False
)
bias
=
False
)
self
.
batchnorm21
=
nn
.
BatchNorm2d
(
1024
)
self
.
batchnorm21
=
nn
.
BatchNorm2d
(
1024
)
self
.
conv22
=
nn
.
Conv2d
(
self
.
conv22
=
nn
.
Conv2d
(
in_channels
=
1024
,
in_channels
=
1024
,
out_channels
=
125
,
out_channels
=
125
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
1
,
stride
=
1
,
...
@@ -227,12 +205,12 @@ class Yolov2(nn.Module):
...
@@ -227,12 +205,12 @@ class Yolov2(nn.Module):
return
passthrough
return
passthrough
def
forward
(
self
,
x
):
def
forward
(
self
,
x
):
out
=
F
.
max_pool2d
(
out
=
F
.
max_pool2d
(
F
.
leaky_relu
(
self
.
batchnorm1
(
self
.
conv1
(
x
)),
F
.
leaky_relu
(
self
.
batchnorm1
(
self
.
conv1
(
x
)),
negative_slope
=
0.1
),
negative_slope
=
0.1
),
2
,
2
,
stride
=
2
)
stride
=
2
)
out
=
F
.
max_pool2d
(
out
=
F
.
max_pool2d
(
F
.
leaky_relu
(
self
.
batchnorm2
(
self
.
conv2
(
out
)),
F
.
leaky_relu
(
self
.
batchnorm2
(
self
.
conv2
(
out
)),
negative_slope
=
0.1
),
negative_slope
=
0.1
),
2
,
2
,
stride
=
2
)
stride
=
2
)
...
@@ -247,36 +225,36 @@ class Yolov2(nn.Module):
...
@@ -247,36 +225,36 @@ class Yolov2(nn.Module):
out
=
F
.
max_pool2d
(
out
,
2
,
stride
=
2
)
out
=
F
.
max_pool2d
(
out
,
2
,
stride
=
2
)
out
=
F
.
leaky_relu
(
self
.
batchnorm9
(
self
.
conv9
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
self
.
batchnorm9
(
self
.
conv9
(
out
)),
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm10
(
self
.
conv10
(
out
)),
self
.
batchnorm10
(
self
.
conv10
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm11
(
self
.
conv11
(
out
)),
self
.
batchnorm11
(
self
.
conv11
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm12
(
self
.
conv12
(
out
)),
self
.
batchnorm12
(
self
.
conv12
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm13
(
self
.
conv13
(
out
)),
self
.
batchnorm13
(
self
.
conv13
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
passthrough
=
self
.
reorg_layer
(
out
)
passthrough
=
self
.
reorg_layer
(
out
)
out
=
F
.
max_pool2d
(
out
,
2
,
stride
=
2
)
out
=
F
.
max_pool2d
(
out
,
2
,
stride
=
2
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm14
(
self
.
conv14
(
out
)),
self
.
batchnorm14
(
self
.
conv14
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm15
(
self
.
conv15
(
out
)),
self
.
batchnorm15
(
self
.
conv15
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm16
(
self
.
conv16
(
out
)),
self
.
batchnorm16
(
self
.
conv16
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm17
(
self
.
conv17
(
out
)),
self
.
batchnorm17
(
self
.
conv17
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm18
(
self
.
conv18
(
out
)),
self
.
batchnorm18
(
self
.
conv18
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm19
(
self
.
conv19
(
out
)),
self
.
batchnorm19
(
self
.
conv19
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm20
(
self
.
conv20
(
out
)),
self
.
batchnorm20
(
self
.
conv20
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
torch
.
cat
([
passthrough
,
out
],
1
)
out
=
torch
.
cat
([
passthrough
,
out
],
1
)
out
=
F
.
leaky_relu
(
out
=
F
.
leaky_relu
(
self
.
batchnorm21
(
self
.
conv21
(
out
)),
self
.
batchnorm21
(
self
.
conv21
(
out
)),
negative_slope
=
0.1
)
negative_slope
=
0.1
)
out
=
self
.
conv22
(
out
)
out
=
self
.
conv22
(
out
)
return
out
return
out
...
@@ -286,8 +264,7 @@ model = Yolov2()
...
@@ -286,8 +264,7 @@ model = Yolov2()
model
.
eval
()
model
.
eval
()
xb
=
torch
.
rand
((
1
,
3
,
224
,
224
))
xb
=
torch
.
rand
((
1
,
3
,
224
,
224
))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
export_onnx_with_validation
(
export_onnx_with_validation
(
model
,
[
xb
],
model
,
(
xb
,
),
'sample_yolov2'
,
[
'image'
],
[
'pred'
],
'sample_yolov2'
,
[
'image'
],
[
'pred'
],
verbose
=
True
,
verbose
=
True
,
training
=
False
)
training
=
False
)
onnx2fluid/examples/onnx_model_zoo.sh
浏览文件 @
7c3e9379
...
@@ -11,6 +11,7 @@ validate_flags2="/tmp/export/__model__"
...
@@ -11,6 +11,7 @@ validate_flags2="/tmp/export/__model__"
alias
http_get
=
"aria2c -c -s8 -x8"
alias
http_get
=
"aria2c -c -s8 -x8"
# alias python="python3" # if ...
# alias python="python3" # if ...
bvlc_alexnet
()
bvlc_alexnet
()
{
{
bn_tar
=
"bvlc_alexnet"
bn_tar
=
"bvlc_alexnet"
...
@@ -26,17 +27,19 @@ bvlc_alexnet()
...
@@ -26,17 +27,19 @@ bvlc_alexnet()
for
npz
in
"
$bn_tar
"
/
*
.npz
for
npz
in
"
$bn_tar
"
/
*
.npz
do
do
echo
"converting
$npz
..."
echo
"converting
$npz
..."
python convert_data_npz_0
.py
"
$npz
"
data_0 prob_1
-s
python convert_data_npz
.py
"
$npz
"
data_0 prob_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
done
done
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
..."
echo
"converting
$pb_dir
..."
python convert_data_pb_0
.py
"
$pb_dir
"
data_0 prob_1
python convert_data_pb
.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
bvlc_googlenet
()
bvlc_googlenet
()
...
@@ -54,10 +57,12 @@ bvlc_googlenet()
...
@@ -54,10 +57,12 @@ bvlc_googlenet()
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0
.py
"
$pb_dir
"
data_0 prob_1
python convert_data_pb
.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
bvlc_reference_caffenet
()
bvlc_reference_caffenet
()
...
@@ -75,10 +80,12 @@ bvlc_reference_caffenet()
...
@@ -75,10 +80,12 @@ bvlc_reference_caffenet()
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0
.py
"
$pb_dir
"
data_0 prob_1
python convert_data_pb
.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
bvlc_reference_rcnn_ilsvrc13
()
bvlc_reference_rcnn_ilsvrc13
()
...
@@ -96,10 +103,65 @@ bvlc_reference_rcnn_ilsvrc13()
...
@@ -96,10 +103,65 @@ bvlc_reference_rcnn_ilsvrc13()
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0
.py
"
$pb_dir
"
data_0 fc-rcnn_1
python convert_data_pb
.py
"
$pb_dir
"
data_0 fc-rcnn_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
densenet121
()
{
bn_tar
=
"densenet121"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/model.onnx"
http_get
"
$base_url$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
for
npz
in
"
$bn_tar
"
/
*
.npz
do
echo
"converting
$npz
..."
python convert_data_npz.py
"
$npz
"
data_0 fc6_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
done
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
data_0 fc6_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
emotion_ferplus
()
{
bn_tar
=
"emotion_ferplus"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/model.onnx"
http_get
"https://onnxzoo.blob.core.windows.net/models/opset_8/emotion_ferplus/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
Input3 Plus692_Output_0
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
}
inception_v1
()
inception_v1
()
...
@@ -117,17 +179,19 @@ inception_v1()
...
@@ -117,17 +179,19 @@ inception_v1()
for
npz
in
"
$bn_tar
"
/
*
.npz
for
npz
in
"
$bn_tar
"
/
*
.npz
do
do
echo
"converting
$npz
..."
echo
"converting
$npz
..."
python convert_data_npz_0
.py
"
$npz
"
data_0 prob_1
-s
python convert_data_npz
.py
"
$npz
"
data_0 prob_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
done
done
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
..."
echo
"converting
$pb_dir
..."
python convert_data_pb_0
.py
"
$pb_dir
"
data_0 prob_1
python convert_data_pb
.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
inception_v2
()
inception_v2
()
...
@@ -145,17 +209,65 @@ inception_v2()
...
@@ -145,17 +209,65 @@ inception_v2()
for
npz
in
"
$bn_tar
"
/
*
.npz
for
npz
in
"
$bn_tar
"
/
*
.npz
do
do
echo
"converting
$npz
..."
echo
"converting
$npz
..."
python convert_data_npz_0
.py
"
$npz
"
data_0 prob_1
-s
python convert_data_npz
.py
"
$npz
"
data_0 prob_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
done
done
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
..."
echo
"converting
$pb_dir
..."
python convert_data_pb_0
.py
"
$pb_dir
"
data_0 prob_1
python convert_data_pb
.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
mobilenet
()
{
bn_tar
=
"mobilenetv2-1.0"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/
$bn_tar
.onnx"
http_get
"https://s3.amazonaws.com/onnx-model-zoo/mobilenet/mobilenetv2-1.0/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data mobilenetv20_output_flatten0_reshape0
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
resnet18
()
{
bn_tar
=
"resnet18v1"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/
$bn_tar
.onnx"
http_get
"https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet18v1/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data resnetv15_dense0_fwd
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
}
resnet50
()
resnet50
()
...
@@ -173,17 +285,88 @@ resnet50()
...
@@ -173,17 +285,88 @@ resnet50()
for
npz
in
"
$bn_tar
"
/
*
.npz
for
npz
in
"
$bn_tar
"
/
*
.npz
do
do
echo
"converting
$npz
..."
echo
"converting
$npz
..."
python convert_data_npz_0
.py
"
$npz
"
gpu_0/data_0 gpu_0/softmaxout_1
-s
python convert_data_npz
.py
"
$npz
"
gpu_0/data_0 gpu_0/softmaxout_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
done
done
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
..."
echo
"converting
$pb_dir
..."
python convert_data_pb_0.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmaxout_1
python convert_data_pb.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmaxout_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
resnet100_arcface
()
{
bn_tar
=
"resnet100"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/
$bn_tar
.onnx"
http_get
"https://s3.amazonaws.com/onnx-model-zoo/arcface/resnet100/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
-o
/tmp/export/
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data fc1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
resnet101_duc
()
{
bn_tar
=
"ResNet101_DUC_HDC"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/
$bn_tar
.onnx"
http_get
"https://s3.amazonaws.com/onnx-model-zoo/duc/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data seg_loss
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
resnet152
()
{
bn_tar
=
"resnet152v2"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/
$bn_tar
.onnx"
http_get
"https://s3.amazonaws.com/onnx-model-zoo/resnet/resnet152v2/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data resnetv27_dense0_fwd
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
}
shufflenet
()
shufflenet
()
...
@@ -201,10 +384,12 @@ shufflenet()
...
@@ -201,10 +384,12 @@ shufflenet()
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
..."
echo
"converting
$pb_dir
..."
python convert_data_pb_0.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmaxout
_1
python convert_data_pb.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmax
_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
squeezenet
()
squeezenet
()
...
@@ -222,10 +407,59 @@ squeezenet()
...
@@ -222,10 +407,59 @@ squeezenet()
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
data_0 softmaxout_1
python convert_data_pb.py
"
$pb_dir
"
data_0 softmaxout_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
squeezenet1v1
()
{
bn_tar
=
"squeezenet1.1"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/
$bn_tar
.onnx"
http_get
"https://s3.amazonaws.com/onnx-model-zoo/squeezenet/squeezenet1.1/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data squeezenet0_flatten0_reshape0
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
ssd
()
{
bn_tar
=
"ssd"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/model.onnx"
http_get
"https://onnxzoo.blob.core.windows.net/models/opset_10/ssd/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
mkdir
"
$bn_tar
"
tar
xf
"
$fn_tar
"
-C
"
$bn_tar
"
/
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
image bboxes,labels,scores
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
tiny_yolov2
()
tiny_yolov2
()
...
@@ -239,14 +473,39 @@ tiny_yolov2()
...
@@ -239,14 +473,39 @@ tiny_yolov2()
echo
"extracting ..."
echo
"extracting ..."
tar
xf
"
$fn_tar
"
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-x
y
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-
y
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
image grid
python convert_data_pb.py
"
$pb_dir
"
image grid
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
vgg16bn
()
{
bn_tar
=
"vgg16-bn"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/
$bn_tar
.onnx"
http_get
"https://s3.amazonaws.com/onnx-model-zoo/vgg/vgg16-bn/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data vgg0_dense2_fwd
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
vgg19
()
vgg19
()
...
@@ -264,10 +523,35 @@ vgg19()
...
@@ -264,10 +523,35 @@ vgg19()
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
data_0 prob_1
python convert_data_pb.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
rm
-rf
"
$bn_tar
/"
}
yolov3
()
{
bn_tar
=
"yolov3"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/model.onnx"
http_get
"https://onnxzoo.blob.core.windows.net/models/opset_10/yolov3/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-x
#
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
zfnet512
()
zfnet512
()
...
@@ -285,10 +569,12 @@ zfnet512()
...
@@ -285,10 +569,12 @@ zfnet512()
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0
.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmax_1
python convert_data_pb
.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmax_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
rm
-rf
"
$bn_tar
/"
}
}
...
@@ -296,11 +582,22 @@ bvlc_alexnet
...
@@ -296,11 +582,22 @@ bvlc_alexnet
bvlc_googlenet
bvlc_googlenet
bvlc_reference_caffenet
bvlc_reference_caffenet
bvlc_reference_rcnn_ilsvrc13
bvlc_reference_rcnn_ilsvrc13
densenet121
emotion_ferplus
# not supported
inception_v1
inception_v1
inception_v2
inception_v2
mobilenet
resnet18
resnet50
resnet50
resnet100_arcface
resnet101_duc
resnet152
shufflenet
shufflenet
squeezenet
# softmax bug
squeezenet
# softmax bug
# tiny_yolov2 # not supported
squeezenet1v1
ssd
# version not supported
tiny_yolov2
# not supported
vgg16bn
vgg19
vgg19
yolov3
# malformed model ?
zfnet512
zfnet512
onnx2fluid/onnx2fluid/__main__.py
浏览文件 @
7c3e9379
...
@@ -92,7 +92,7 @@ parser.add_argument(
...
@@ -92,7 +92,7 @@ parser.add_argument(
parser
.
add_argument
(
parser
.
add_argument
(
'--rtol'
,
'--rtol'
,
type
=
float
,
type
=
float
,
default
=
1e-
4
,
default
=
1e-
2
,
help
=
'assertion relative tolerance for validation'
,
help
=
'assertion relative tolerance for validation'
,
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
onnx2fluid/onnx2fluid/cmdline.py
浏览文件 @
7c3e9379
...
@@ -22,7 +22,6 @@ __all__ = [
...
@@ -22,7 +22,6 @@ __all__ = [
'main'
,
'main'
,
]
]
DEFAULT_ONNX_OPSET_VERSION
=
9
DEFAULT_MODEL_MODULE
=
'model'
DEFAULT_MODEL_MODULE
=
'model'
DEFAULT_MODEL_FUNC
=
'inference'
DEFAULT_MODEL_FUNC
=
'inference'
...
@@ -30,6 +29,7 @@ DEFAULT_MODEL_FUNC = 'inference'
...
@@ -30,6 +29,7 @@ DEFAULT_MODEL_FUNC = 'inference'
def
main
(
**
kwargs
):
def
main
(
**
kwargs
):
"""主程序入口"""
"""主程序入口"""
from
.conversion
import
DEFAULT_ONNX_OPSET_VERSION
from
.conversion
import
convert
from
.conversion
import
convert
logger
=
logging
.
getLogger
(
'onnx2fluid'
)
logger
=
logging
.
getLogger
(
'onnx2fluid'
)
...
@@ -44,9 +44,9 @@ def main(**kwargs):
...
@@ -44,9 +44,9 @@ def main(**kwargs):
if
save_dir
else
basepath
)
+
shutil
.
os
.
sep
if
save_dir
else
basepath
)
+
shutil
.
os
.
sep
model_basename
=
DEFAULT_MODEL_MODULE
+
'.py'
model_basename
=
DEFAULT_MODEL_MODULE
+
'.py'
model_func_name
=
DEFAULT_MODEL_FUNC
model_func_name
=
DEFAULT_MODEL_FUNC
onnx_opset_version
=
DEFAULT_ONNX_OPSET_VERSION
onnx_opset_pedantic
=
kwargs
.
pop
(
'pedantic'
,
True
)
onnx_opset_pedantic
=
kwargs
.
pop
(
'pedantic'
,
True
)
onnx_skip_version_conversion
=
kwargs
.
pop
(
'skip_version_conversion'
,
False
)
skip_version_conversion
=
kwargs
.
pop
(
'skip_version_conversion'
,
False
)
onnx_opset_version
=
None
if
skip_version_conversion
else
DEFAULT_ONNX_OPSET_VERSION
# convert
# convert
convert
(
filename
,
convert
(
filename
,
...
@@ -55,7 +55,6 @@ def main(**kwargs):
...
@@ -55,7 +55,6 @@ def main(**kwargs):
model_func_name
=
model_func_name
,
model_func_name
=
model_func_name
,
onnx_opset_version
=
onnx_opset_version
,
onnx_opset_version
=
onnx_opset_version
,
onnx_opset_pedantic
=
onnx_opset_pedantic
,
onnx_opset_pedantic
=
onnx_opset_pedantic
,
onnx_skip_version_conversion
=
onnx_skip_version_conversion
,
**
kwargs
)
**
kwargs
)
# validate
# validate
...
@@ -69,12 +68,11 @@ def main(**kwargs):
...
@@ -69,12 +68,11 @@ def main(**kwargs):
golden_data_filename
,
**
kwargs
)
golden_data_filename
,
**
kwargs
)
logger
.
info
(
'starting validation on code ...'
)
logger
.
info
(
'starting validation on code ...'
)
passed
&=
validate
(
# this re-generate desc proto with Python code when debug on
shutil
.
os
.
path
.
join
(
save_dir
,
model_basename
),
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
model_basename
),
golden_data_filename
,
golden_data_filename
,
model_func_name
=
model_func_name
,
model_func_name
=
model_func_name
,
save_inference_model
=
save_inference_model
=
debug
,
debug
,
# re-generate desc proto with python code when debug on
**
kwargs
)
**
kwargs
)
if
not
passed
:
if
not
passed
:
...
...
onnx2fluid/onnx2fluid/conversion.py
浏览文件 @
7c3e9379
...
@@ -14,15 +14,16 @@ __all__ = [
...
@@ -14,15 +14,16 @@ __all__ = [
'convert'
,
'convert'
,
]
]
DEFAULT_ONNX_OPSET_VERSION
=
9
def
convert
(
onnx_model_filename
,
def
convert
(
onnx_model_filename
,
save_dir
,
save_dir
,
model_basename
=
'model.py'
,
model_basename
=
'model.py'
,
model_func_name
=
'inference'
,
model_func_name
=
'inference'
,
embed_params
=
False
,
embed_params
=
False
,
onnx_opset_version
=
9
,
onnx_opset_version
=
None
,
onnx_opset_pedantic
=
True
,
onnx_opset_pedantic
=
True
,
onnx_skip_version_conversion
=
False
,
debug
=
False
,
debug
=
False
,
**
kwargs
):
**
kwargs
):
"""
"""
...
@@ -50,11 +51,13 @@ def convert(onnx_model_filename,
...
@@ -50,11 +51,13 @@ def convert(onnx_model_filename,
# prepare onnx model
# prepare onnx model
logger
.
info
(
'loading model: %s ...'
,
onnx_model_filename
)
logger
.
info
(
'loading model: %s ...'
,
onnx_model_filename
)
onnx_model
=
onnx
.
load
(
onnx_model_filename
)
onnx_model
=
onnx
.
load
(
onnx_model_filename
)
try
:
try
:
logger
.
info
(
'checking model ...'
)
logger
.
info
(
'checking model ...'
)
check_model
(
onnx_model
)
check_model
(
onnx_model
)
if
onnx_skip_version_conversion
:
# WORKAROUND: RuntimeError: No Adapter For OP
if
onnx_opset_version
is
None
:
# WORKAROUND: RuntimeError: No Adapter For OP
logger
.
debug
(
'assumed opset version: %d'
,
onnx_opset_version
)
logger
.
debug
(
'assumed opset version: %d'
,
DEFAULT_ONNX_OPSET_VERSION
)
logger
.
warning
(
logger
.
warning
(
'opset conversion skipped for onnx_opset_pedantic is OFF'
)
'opset conversion skipped for onnx_opset_pedantic is OFF'
)
else
:
else
:
...
@@ -68,6 +71,7 @@ def convert(onnx_model_filename,
...
@@ -68,6 +71,7 @@ def convert(onnx_model_filename,
logger
.
warning
(
'due to onnx_opset_pedantic is OFF'
)
logger
.
warning
(
'due to onnx_opset_pedantic is OFF'
)
logger
.
warning
(
'the ONNX model sanity checking error is suppressed'
)
logger
.
warning
(
'the ONNX model sanity checking error is suppressed'
)
logger
.
warning
(
'value_info inferring may be uncompleted'
)
logger
.
warning
(
'value_info inferring may be uncompleted'
)
# onnx model optimization
# onnx model optimization
logger
.
info
(
'model has %d ops'
,
len
(
onnx_model
.
graph
.
node
))
logger
.
info
(
'model has %d ops'
,
len
(
onnx_model
.
graph
.
node
))
logger
.
info
(
'optimizing model ...'
)
logger
.
info
(
'optimizing model ...'
)
...
@@ -87,10 +91,7 @@ def convert(onnx_model_filename,
...
@@ -87,10 +91,7 @@ def convert(onnx_model_filename,
debug_model_filename
,
_
=
shutil
.
os
.
path
.
splitext
(
onnx_model_filename
)
debug_model_filename
,
_
=
shutil
.
os
.
path
.
splitext
(
onnx_model_filename
)
onnx
.
save
(
model
,
debug_model_filename
+
'.optimized_and_inffered.onnx'
)
onnx
.
save
(
model
,
debug_model_filename
+
'.optimized_and_inffered.onnx'
)
# I/O instances
# onnx.save(model, '/tmp/export/optimized_and_inffered.onnx')
# I/O instances
onnx_graph
=
onnx_model
.
graph
onnx_graph
=
onnx_model
.
graph
fluid_program
=
Program
()
fluid_program
=
Program
()
fluid_writer
=
Writer
()
fluid_writer
=
Writer
()
...
@@ -114,8 +115,8 @@ def convert(onnx_model_filename,
...
@@ -114,8 +115,8 @@ def convert(onnx_model_filename,
# op set conversion
# op set conversion
# topo = 'backward' if embed_params else 'forward'
# topo = 'backward' if embed_params else 'forward'
topo
=
'forward'
topo
=
'forward'
for
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
in
graph_ops
(
for
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
in
graph_ops
(
onnx_graph
,
onnx_graph
,
topo
=
topo
):
topo
=
topo
):
logger
.
debug
(
'translating op %s %s::%s ...'
,
name
,
domain
,
op_type
)
logger
.
debug
(
'translating op %s %s::%s ...'
,
name
,
domain
,
op_type
)
if
domain
==
DEFAULT_OP_DOMAIN
:
if
domain
==
DEFAULT_OP_DOMAIN
:
domain
=
''
domain
=
''
...
@@ -140,6 +141,24 @@ def convert(onnx_model_filename,
...
@@ -140,6 +141,24 @@ def convert(onnx_model_filename,
logger
.
info
(
'%d ops in, %d ops out'
,
len
(
onnx_graph
.
node
),
logger
.
info
(
'%d ops in, %d ops out'
,
len
(
onnx_graph
.
node
),
len
(
fluid_program
.
op_descs
))
len
(
fluid_program
.
op_descs
))
# shape-inference
for
name
,
value_info
in
graph_value_infos
.
items
():
var_name
=
make_var_name
(
name
)
fluid_program
.
VarTypeInfo
(
var_name
,
value_info
,
remove_batch
=
False
)
# shape-infer only
bad_var_names
=
[]
for
var_name
,
var_desc
in
fluid_program
.
var_descs
.
items
():
if
not
var_desc
.
type
.
lod_tensor
.
HasField
(
'tensor'
):
bad_var_names
.
append
(
var_name
)
if
len
(
bad_var_names
)
>
0
:
logger
.
warning
(
'type info not infered for var %s ...'
,
', '
.
join
(
bad_var_names
[:
5
]))
logger
.
warning
(
'this causes little problem for PaddlePaddle, '
'but Paddle Mobile may not infer correctly'
)
logger
.
warning
(
'please consider adding option -d to invoke PaddlePaddle shape-inference'
)
# weight writer
# weight writer
for
name
,
weight
in
graph_weights
(
onnx_graph
):
for
name
,
weight
in
graph_weights
(
onnx_graph
):
graph_params
.
append
(
name
)
graph_params
.
append
(
name
)
...
@@ -173,8 +192,9 @@ def convert(onnx_model_filename,
...
@@ -173,8 +192,9 @@ def convert(onnx_model_filename,
value_info
=
graph_value_infos
[
name
]
value_info
=
graph_value_infos
[
name
]
assert
value_info
[
'external'
]
assert
value_info
[
'external'
]
external_inputs
.
append
(
name
)
external_inputs
.
append
(
name
)
fluid_writer
.
emit_inputs
(
fluid_writer
.
emit_inputs
(
fluid_program
,
fluid_program
,
external_inputs
,
graph_value_infos
,
external_inputs
,
graph_value_infos
,
remove_batch
=
False
)
# TODO:
remove_batch
=
False
)
# TODO:
input_codes
=
fluid_program
.
codes
input_codes
=
fluid_program
.
codes
fluid_program
.
codes
=
[]
fluid_program
.
codes
=
[]
...
@@ -206,12 +226,13 @@ def convert(onnx_model_filename,
...
@@ -206,12 +226,13 @@ def convert(onnx_model_filename,
fluid_writer
.
write_desc_file
(
fluid_writer
.
write_desc_file
(
desc_filename
,
desc_filename
,
op_descs
=
fluid_program
.
op_descs
,
op_descs
=
fluid_program
.
op_descs
,
var_descs
=
fluid_program
.
var_descs
,
var_descs
=
list
(
fluid_program
.
var_descs
.
values
())
,
)
)
logger
.
info
(
'program saved to %s'
,
desc_filename
)
logger
.
info
(
'program saved to %s'
,
desc_filename
)
logger
.
info
(
'conversion finished'
)
logger
.
info
(
'conversion finished'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
del
convert
del
convert
...
@@ -283,8 +304,7 @@ if __name__ == '__main__':
...
@@ -283,8 +304,7 @@ if __name__ == '__main__':
pedantic
=
args
.
pedantic
pedantic
=
args
.
pedantic
skip_version_conversion
=
args
.
skip_version_conversion
skip_version_conversion
=
args
.
skip_version_conversion
convert
(
convert
(
model_filename
,
model_filename
,
save_dir
,
save_dir
,
embed_params
=
embed_params
,
embed_params
=
embed_params
,
onnx_opset_pedantic
=
pedantic
,
onnx_opset_pedantic
=
pedantic
,
...
...
onnx2fluid/onnx2fluid/framework_pb2.py
浏览文件 @
7c3e9379
...
@@ -28,30 +28,66 @@ _ATTRTYPE = _descriptor.EnumDescriptor(
...
@@ -28,30 +28,66 @@ _ATTRTYPE = _descriptor.EnumDescriptor(
filename
=
None
,
filename
=
None
,
file
=
DESCRIPTOR
,
file
=
DESCRIPTOR
,
values
=
[
values
=
[
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'INT'
,
name
=
'INT'
,
index
=
0
,
number
=
0
,
options
=
None
,
type
=
None
),
index
=
0
,
_descriptor
.
EnumValueDescriptor
(
number
=
0
,
name
=
'FLOAT'
,
index
=
1
,
number
=
1
,
options
=
None
,
type
=
None
),
options
=
None
,
_descriptor
.
EnumValueDescriptor
(
type
=
None
),
name
=
'STRING'
,
index
=
2
,
number
=
2
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FLOAT'
,
_descriptor
.
EnumValueDescriptor
(
index
=
1
,
name
=
'INTS'
,
index
=
3
,
number
=
3
,
options
=
None
,
type
=
None
),
number
=
1
,
_descriptor
.
EnumValueDescriptor
(
options
=
None
,
name
=
'FLOATS'
,
index
=
4
,
number
=
4
,
options
=
None
,
type
=
None
),
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'STRING'
,
name
=
'STRINGS'
,
index
=
5
,
number
=
5
,
options
=
None
,
type
=
None
),
index
=
2
,
_descriptor
.
EnumValueDescriptor
(
number
=
2
,
name
=
'BOOLEAN'
,
index
=
6
,
number
=
6
,
options
=
None
,
type
=
None
),
options
=
None
,
_descriptor
.
EnumValueDescriptor
(
type
=
None
),
name
=
'BOOLEANS'
,
index
=
7
,
number
=
7
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INTS'
,
_descriptor
.
EnumValueDescriptor
(
index
=
3
,
name
=
'BLOCK'
,
index
=
8
,
number
=
8
,
options
=
None
,
type
=
None
),
number
=
3
,
_descriptor
.
EnumValueDescriptor
(
options
=
None
,
name
=
'LONG'
,
index
=
9
,
number
=
9
,
options
=
None
,
type
=
None
),
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'FLOATS'
,
name
=
'BLOCKS'
,
index
=
10
,
number
=
10
,
options
=
None
,
type
=
None
),
index
=
4
,
_descriptor
.
EnumValueDescriptor
(
number
=
4
,
name
=
'LONGS'
,
index
=
11
,
number
=
11
,
options
=
None
,
type
=
None
),
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'STRINGS'
,
index
=
5
,
number
=
5
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOLEAN'
,
index
=
6
,
number
=
6
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOLEANS'
,
index
=
7
,
number
=
7
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BLOCK'
,
index
=
8
,
number
=
8
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LONG'
,
index
=
9
,
number
=
9
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BLOCKS'
,
index
=
10
,
number
=
10
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LONGS'
,
index
=
11
,
number
=
11
,
options
=
None
,
type
=
None
),
],
],
containing_type
=
None
,
containing_type
=
None
,
options
=
None
,
options
=
None
,
...
@@ -80,53 +116,111 @@ _VARTYPE_TYPE = _descriptor.EnumDescriptor(
...
@@ -80,53 +116,111 @@ _VARTYPE_TYPE = _descriptor.EnumDescriptor(
filename
=
None
,
filename
=
None
,
file
=
DESCRIPTOR
,
file
=
DESCRIPTOR
,
values
=
[
values
=
[
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOL'
,
name
=
'BOOL'
,
index
=
0
,
number
=
0
,
options
=
None
,
type
=
None
),
index
=
0
,
_descriptor
.
EnumValueDescriptor
(
number
=
0
,
name
=
'INT16'
,
index
=
1
,
number
=
1
,
options
=
None
,
type
=
None
),
options
=
None
,
_descriptor
.
EnumValueDescriptor
(
type
=
None
),
name
=
'INT32'
,
index
=
2
,
number
=
2
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT16'
,
_descriptor
.
EnumValueDescriptor
(
index
=
1
,
name
=
'INT64'
,
index
=
3
,
number
=
3
,
options
=
None
,
type
=
None
),
number
=
1
,
_descriptor
.
EnumValueDescriptor
(
options
=
None
,
name
=
'FP16'
,
index
=
4
,
number
=
4
,
options
=
None
,
type
=
None
),
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'INT32'
,
name
=
'FP32'
,
index
=
5
,
number
=
5
,
options
=
None
,
type
=
None
),
index
=
2
,
_descriptor
.
EnumValueDescriptor
(
number
=
2
,
name
=
'FP64'
,
index
=
6
,
number
=
6
,
options
=
None
,
type
=
None
),
options
=
None
,
_descriptor
.
EnumValueDescriptor
(
type
=
None
),
name
=
'SIZE_T'
,
index
=
7
,
number
=
19
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT64'
,
_descriptor
.
EnumValueDescriptor
(
index
=
3
,
name
=
'UINT8'
,
index
=
8
,
number
=
20
,
options
=
None
,
type
=
None
),
number
=
3
,
_descriptor
.
EnumValueDescriptor
(
options
=
None
,
name
=
'INT8'
,
index
=
9
,
number
=
21
,
options
=
None
,
type
=
None
),
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'FP16'
,
name
=
'LOD_TENSOR'
,
index
=
10
,
number
=
7
,
options
=
None
,
type
=
None
),
index
=
4
,
_descriptor
.
EnumValueDescriptor
(
number
=
4
,
name
=
'SELECTED_ROWS'
,
index
=
11
,
number
=
8
,
options
=
None
,
type
=
None
),
options
=
None
,
_descriptor
.
EnumValueDescriptor
(
type
=
None
),
name
=
'FEED_MINIBATCH'
,
index
=
12
,
number
=
9
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FP32'
,
_descriptor
.
EnumValueDescriptor
(
index
=
5
,
name
=
'FETCH_LIST'
,
index
=
13
,
number
=
10
,
options
=
None
,
type
=
None
),
number
=
5
,
_descriptor
.
EnumValueDescriptor
(
options
=
None
,
name
=
'STEP_SCOPES'
,
index
=
14
,
number
=
11
,
options
=
None
,
type
=
None
),
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'FP64'
,
name
=
'LOD_RANK_TABLE'
,
index
=
15
,
number
=
12
,
options
=
None
,
index
=
6
,
number
=
6
,
options
=
None
,
type
=
None
),
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'SIZE_T'
,
name
=
'LOD_TENSOR_ARRAY'
,
index
=
7
,
number
=
19
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'UINT8'
,
index
=
8
,
number
=
20
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT8'
,
index
=
9
,
number
=
21
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_TENSOR'
,
index
=
10
,
number
=
7
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'SELECTED_ROWS'
,
index
=
11
,
number
=
8
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FEED_MINIBATCH'
,
index
=
12
,
number
=
9
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FETCH_LIST'
,
index
=
13
,
number
=
10
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'STEP_SCOPES'
,
index
=
14
,
number
=
11
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_RANK_TABLE'
,
index
=
15
,
number
=
12
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_TENSOR_ARRAY'
,
index
=
16
,
index
=
16
,
number
=
13
,
number
=
13
,
options
=
None
,
options
=
None
,
type
=
None
),
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
_descriptor
.
EnumValueDescriptor
(
name
=
'PLACE_LIST'
,
name
=
'PLACE_LIST'
,
index
=
17
,
number
=
14
,
options
=
None
,
type
=
None
),
index
=
17
,
_descriptor
.
EnumValueDescriptor
(
number
=
14
,
name
=
'READER'
,
index
=
18
,
number
=
15
,
options
=
None
,
type
=
None
),
options
=
None
,
_descriptor
.
EnumValueDescriptor
(
type
=
None
),
name
=
'RAW'
,
index
=
19
,
number
=
17
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'READER'
,
_descriptor
.
EnumValueDescriptor
(
index
=
18
,
name
=
'TUPLE'
,
index
=
20
,
number
=
18
,
options
=
None
,
type
=
None
),
number
=
15
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'RAW'
,
index
=
19
,
number
=
17
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'TUPLE'
,
index
=
20
,
number
=
18
,
options
=
None
,
type
=
None
),
],
],
containing_type
=
None
,
containing_type
=
None
,
options
=
None
,
options
=
None
,
...
@@ -1480,8 +1574,7 @@ DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
...
@@ -1480,8 +1574,7 @@ DESCRIPTOR.enum_types_by_name['AttrType'] = _ATTRTYPE
Version
=
_reflection
.
GeneratedProtocolMessageType
(
Version
=
_reflection
.
GeneratedProtocolMessageType
(
'Version'
,
'Version'
,
(
_message
.
Message
,
),
(
_message
.
Message
,
),
dict
(
dict
(
DESCRIPTOR
=
_VERSION
,
DESCRIPTOR
=
_VERSION
,
__module__
=
'framework_pb2'
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
))
))
...
@@ -1601,8 +1694,7 @@ _sym_db.RegisterMessage(VarType.Tuple)
...
@@ -1601,8 +1694,7 @@ _sym_db.RegisterMessage(VarType.Tuple)
VarDesc
=
_reflection
.
GeneratedProtocolMessageType
(
VarDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'VarDesc'
,
'VarDesc'
,
(
_message
.
Message
,
),
(
_message
.
Message
,
),
dict
(
dict
(
DESCRIPTOR
=
_VARDESC
,
DESCRIPTOR
=
_VARDESC
,
__module__
=
'framework_pb2'
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
))
))
...
...
onnx2fluid/onnx2fluid/onnx_utils.py
浏览文件 @
7c3e9379
...
@@ -50,8 +50,7 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
...
@@ -50,8 +50,7 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
if
hasattr
(
message
,
'DESCRIPTOR'
)
and
hasattr
(
message
.
DESCRIPTOR
,
'fields'
):
if
hasattr
(
message
,
'DESCRIPTOR'
)
and
hasattr
(
message
.
DESCRIPTOR
,
'fields'
):
for
field
in
message
.
DESCRIPTOR
.
fields
:
for
field
in
message
.
DESCRIPTOR
.
fields
:
print
(
'
\t
'
*
depth
+
'-'
,
field
.
name
)
print
(
'
\t
'
*
depth
+
'-'
,
field
.
name
)
print_pb_structure
(
print_pb_structure
(
getattr
(
message
,
field
.
name
),
getattr
(
message
,
field
.
name
),
loop_iterative
=
loop_iterative
,
loop_iterative
=
loop_iterative
,
depth
=
(
depth
+
1
))
depth
=
(
depth
+
1
))
...
@@ -59,8 +58,9 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
...
@@ -59,8 +58,9 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
message
,
'__len__'
):
message
,
'__len__'
):
for
idx
,
item
in
enumerate
(
message
):
for
idx
,
item
in
enumerate
(
message
):
print
(
'
\t
'
*
depth
+
'-'
,
idx
)
print
(
'
\t
'
*
depth
+
'-'
,
idx
)
print_pb_structure
(
print_pb_structure
(
item
,
item
,
loop_iterative
=
loop_iterative
,
depth
=
(
depth
+
1
))
loop_iterative
=
loop_iterative
,
depth
=
(
depth
+
1
))
def
build_value_refs
(
nodes
):
def
build_value_refs
(
nodes
):
...
@@ -86,8 +86,9 @@ def get_attribute_value2(attr):
...
@@ -86,8 +86,9 @@ def get_attribute_value2(attr):
if
attr
.
type
==
onnx
.
AttributeProto
.
TENSOR
:
if
attr
.
type
==
onnx
.
AttributeProto
.
TENSOR
:
dtype
=
np
.
dtype
(
TENSOR_TYPE_TO_NP_TYPE
[
attr
.
t
.
data_type
])
dtype
=
np
.
dtype
(
TENSOR_TYPE_TO_NP_TYPE
[
attr
.
t
.
data_type
])
data
=
attr
.
t
.
raw_data
data
=
attr
.
t
.
raw_data
value
=
np
.
frombuffer
(
value
=
np
.
frombuffer
(
data
,
data
,
dtype
=
dtype
,
count
=
(
len
(
data
)
//
dtype
.
itemsize
))
dtype
=
dtype
,
count
=
(
len
(
data
)
//
dtype
.
itemsize
))
elif
attr
.
type
==
onnx
.
AttributeProto
.
STRING
:
elif
attr
.
type
==
onnx
.
AttributeProto
.
STRING
:
value
=
attr
.
s
value
=
attr
.
s
value
=
value
.
decode
()
if
isinstance
(
value
,
bytes
)
else
value
value
=
value
.
decode
()
if
isinstance
(
value
,
bytes
)
else
value
...
@@ -208,6 +209,9 @@ def node_iter(nodes, indices=None):
...
@@ -208,6 +209,9 @@ def node_iter(nodes, indices=None):
if
name
==
''
:
if
name
==
''
:
name
=
'op_'
+
str
(
index
)
name
=
'op_'
+
str
(
index
)
else
:
# make_op_name
for
s
in
'
\\
|/:'
:
#
name
=
name
.
replace
(
s
,
'_'
)
if
domain
==
''
:
if
domain
==
''
:
domain
=
DEFAULT_OP_DOMAIN
domain
=
DEFAULT_OP_DOMAIN
...
@@ -250,25 +254,25 @@ def inferred_model_value_info(model):
...
@@ -250,25 +254,25 @@ def inferred_model_value_info(model):
graph
=
model
.
graph
graph
=
model
.
graph
value_info
=
Dict
()
value_info
=
Dict
()
for
item
in
graph
.
value_info
:
for
item
in
graph
.
value_info
:
value_info
[
item
.
name
]
=
dict
(
value_info
[
item
.
name
]
=
{
dtype
=
tensor_dtype
(
item
),
'dtype'
:
tensor_dtype
(
item
),
shape
=
tensor_shape
(
item
),
'shape'
:
tensor_shape
(
item
),
external
=
False
,
'external'
:
False
,
)
}
for
item
in
graph
.
input
:
for
item
in
graph
.
input
:
assert
item
.
name
not
in
value_info
assert
item
.
name
not
in
value_info
value_info
[
item
.
name
]
=
dict
(
value_info
[
item
.
name
]
=
{
dtype
=
tensor_dtype
(
item
),
'dtype'
:
tensor_dtype
(
item
),
shape
=
tensor_shape
(
item
),
'shape'
:
tensor_shape
(
item
),
external
=
True
,
'external'
:
True
,
)
}
for
item
in
graph
.
output
:
for
item
in
graph
.
output
:
# assert item.name not in value_info, 'bypass-model not supported'
# assert item.name not in value_info, 'bypass-model not supported'
value_info
[
item
.
name
]
=
dict
(
value_info
[
item
.
name
]
=
{
dtype
=
tensor_dtype
(
item
),
'dtype'
:
tensor_dtype
(
item
),
shape
=
tensor_shape
(
item
),
'shape'
:
tensor_shape
(
item
),
external
=
True
,
'external'
:
True
,
)
}
return
value_info
return
value_info
...
@@ -307,7 +311,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
...
@@ -307,7 +311,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
skip ops can be bypassed for inference
skip ops can be bypassed for inference
"""
"""
if
op_list
is
None
:
if
op_list
is
None
:
op_list
=
[
'Dropout'
]
op_list
=
(
'Dropout'
,
'Identity'
)
nodes
=
model
.
graph
.
node
nodes
=
model
.
graph
.
node
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
...
@@ -325,7 +329,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
...
@@ -325,7 +329,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
if
not
(
op_type
in
op_list
):
if
not
(
op_type
in
op_list
):
continue
continue
if
op_type
in
[
'Dropout'
]
:
if
op_type
in
(
'Dropout'
,
)
:
input_name
=
node
.
input
[
0
]
input_name
=
node
.
input
[
0
]
output_name
=
node
.
output
[
0
]
output_name
=
node
.
output
[
0
]
elif
not
(
len
(
node
.
input
)
==
1
and
len
(
node
.
output
)
==
1
):
elif
not
(
len
(
node
.
input
)
==
1
and
len
(
node
.
output
)
==
1
):
...
@@ -406,7 +410,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
...
@@ -406,7 +410,7 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
def
optimize_model_cast
(
model
):
def
optimize_model_cast
(
model
):
"""
"""
strip cascade and unecessary onnx::Cast
strip cascade and unecessary onnx::Cast-9:
"""
"""
nodes
=
model
.
graph
.
node
nodes
=
model
.
graph
.
node
...
@@ -463,13 +467,13 @@ def optimize_model_cast(model):
...
@@ -463,13 +467,13 @@ def optimize_model_cast(model):
def
optimize_model_slice
(
model
):
def
optimize_model_slice
(
model
):
"""
"""
strip cascade and unecessary onnx::Slice
strip cascade and unecessary onnx::Slice-1:9
"""
"""
nodes
=
model
.
graph
.
node
nodes
=
model
.
graph
.
node
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
def
_
build_slice_node_chain
(
node_idx
):
def
build_slice_node_chain
(
node_idx
):
chain
=
[]
chain
=
[]
while
True
:
while
True
:
node
=
nodes
[
node_idx
]
node
=
nodes
[
node_idx
]
...
@@ -485,7 +489,7 @@ def optimize_model_slice(model):
...
@@ -485,7 +489,7 @@ def optimize_model_slice(model):
node_idx
=
list
(
input_refs
[
output_name
])[
0
]
node_idx
=
list
(
input_refs
[
output_name
])[
0
]
# axis: (start, end)
# axis: (start, end)
def
_
merge_slice
(
slice_chain
):
def
merge_slice
(
slice_chain
):
merged_slice
=
dict
()
merged_slice
=
dict
()
for
slice_node_idx
in
slice_chain
:
for
slice_node_idx
in
slice_chain
:
node
=
nodes
[
slice_node_idx
]
node
=
nodes
[
slice_node_idx
]
...
@@ -508,14 +512,14 @@ def optimize_model_slice(model):
...
@@ -508,14 +512,14 @@ def optimize_model_slice(model):
ret_nodes
=
ret
.
graph
.
node
ret_nodes
=
ret
.
graph
.
node
nodes_to_remove
=
[]
nodes_to_remove
=
[]
for
node_idx
in
range
(
len
(
nodes
)):
for
node_idx
in
range
(
len
(
nodes
)):
slice_chain
=
_
build_slice_node_chain
(
node_idx
)
slice_chain
=
build_slice_node_chain
(
node_idx
)
if
len
(
slice_chain
)
==
0
:
if
len
(
slice_chain
)
==
0
:
continue
continue
merged_slice
=
_
merge_slice
(
slice_chain
)
merged_slice
=
merge_slice
(
slice_chain
)
if
len
(
merged_slice
)
>
0
and
len
(
slice_chain
)
==
1
:
# no need to merge
if
len
(
merged_slice
)
>
0
and
len
(
slice_chain
)
==
1
:
# no need to merge
continue
continue
attrs
=
dict
(
axes
=
[],
starts
=
[],
ends
=
[])
attrs
=
{
'axes'
:
[],
'starts'
:
[],
'ends'
:
[]}
for
axis
,
(
start
,
end
)
in
merged_slice
.
items
():
for
axis
,
(
start
,
end
)
in
merged_slice
.
items
():
attrs
[
'axes'
].
append
(
axis
)
attrs
[
'axes'
].
append
(
axis
)
attrs
[
'starts'
].
append
(
start
)
attrs
[
'starts'
].
append
(
start
)
...
...
onnx2fluid/onnx2fluid/symbolic.py
浏览文件 @
7c3e9379
...
@@ -38,6 +38,7 @@ DEFAULT_OP_MAPPING_FIELD_VALUES[
...
@@ -38,6 +38,7 @@ DEFAULT_OP_MAPPING_FIELD_VALUES[
DEFAULT_OP_MAPPING_FIELD_VALUES
[
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'OUTPUT_PERM'
]
=
None
# sampler: [idx_onnx_arg...]
'OUTPUT_PERM'
]
=
None
# sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'FILL_NAME_FIELD'
]
=
True
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'FILL_NAME_FIELD'
]
=
True
DEFAULT_OP_MAPPING_VALUES
=
list
(
DEFAULT_OP_MAPPING_FIELD_VALUES
.
values
())
DEFAULT_OP_MAPPING
=
{
DEFAULT_OP_MAPPING
=
{
## nil ops ##
## nil ops ##
...
@@ -145,24 +146,20 @@ DEFAULT_IOA_CONSTRAINTS = {
...
@@ -145,24 +146,20 @@ DEFAULT_IOA_CONSTRAINTS = {
def
_make_var_name
(
name
):
def
_make_var_name
(
name
):
"""
"""
make a valid variable name in Python code
make a valid variable name in Python code and in filesystem
"""
"""
if
name
==
''
:
if
name
==
''
:
return
'_'
return
'_'
if
name
[
0
].
isdigit
():
if
name
[
0
].
isdigit
():
return
'var_'
+
name
return
'var_'
+
name
for
s
in
'
*?
\\
/-:'
:
for
s
in
'
\\
|/:'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
name
=
'var'
+
name
return
name
return
name
#def _value_info_or_none(value_infos, val_name):
# return value_infos.get(val_name, None)
def
_dtype
(
value_infos
,
val_name
):
def
_dtype
(
value_infos
,
val_name
):
return
_np
.
dtype
(
value_infos
[
val_name
][
'dtype'
])
return
_np
.
dtype
(
value_infos
[
val_name
][
'dtype'
])
...
@@ -204,7 +201,7 @@ def _const_weight_or_none(value_infos, val_name):
...
@@ -204,7 +201,7 @@ def _const_weight_or_none(value_infos, val_name):
def
_default
(
prog
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
_default
(
prog
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
info
=
DEFAULT_OP_MAPPING
[
op_type
]
info
=
DEFAULT_OP_MAPPING
[
op_type
]
info
.
extend
(
list
(
DEFAULT_OP_MAPPING_FIELD_VALUES
.
values
())
[
len
(
info
):])
info
.
extend
(
DEFAULT_OP_MAPPING_VALUES
[
len
(
info
):])
(
(
fluid_op
,
fluid_op
,
...
@@ -295,7 +292,7 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
...
@@ -295,7 +292,7 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
'Sub'
,
'Sub'
,
[
val_ref
,
val_ref
],
[
val_ref
,
val_ref
],
[
val_out
],
# val
[
val_out
],
# val
dict
(
axis
=
0
)
,
{
'axis'
:
0
}
,
value_infos
,
value_infos
,
)
)
...
@@ -317,11 +314,11 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
...
@@ -317,11 +314,11 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
'Pad'
,
'Pad'
,
[
val_name
],
[
val_name
],
[
val_padded
],
# val
[
val_padded
],
# val
dict
(
{
mode
=
'constant'
,
'mode'
:
'constant'
,
value
=
0.
,
'value'
:
0.
,
pads
=
pads
,
'pads'
:
pads
,
)
,
}
,
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
val_padded
,
name
=
val_padded
,
)
)
...
@@ -372,14 +369,14 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
...
@@ -372,14 +369,14 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
fluid_op
,
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
]
+
([
var_indices
]
if
has_indices
else
[]),
'Out'
,
'Indices'
),
([
var_y
]
+
([
var_indices
]
if
has_indices
else
[]),
'Out'
,
'Indices'
),
dict
(
{
global_pooling
=
False
,
'global_pooling'
:
False
,
adaptive
=
True
,
'adaptive'
:
True
,
exclusive
=
True
,
'exclusive'
:
True
,
require_index
=
has_indices
,
'require_index'
:
has_indices
,
pooling_type
=
pool_type
,
'pooling_type'
:
pool_type
,
ksize
=
pool_size
,
'ksize'
:
pool_size
,
)
,
}
,
)
)
...
@@ -419,12 +416,12 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
...
@@ -419,12 +416,12 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
fluid_op
,
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
],
'Out'
),
([
var_y
],
'Out'
),
dict
(
{
global_pooling
=
True
,
'global_pooling'
:
True
,
adaptive
=
False
,
'adaptive'
:
False
,
pooling_type
=
pool_type
,
'pooling_type'
:
pool_type
,
ksize
=
[
-
1
,
-
1
],
'ksize'
:
[
-
1
,
-
1
],
)
,
}
,
)
)
...
@@ -481,17 +478,17 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
...
@@ -481,17 +478,17 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
fluid_op
,
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
]
+
([
var_indices
]
if
has_indices
else
[]),
'Out'
,
'Indices'
),
([
var_y
]
+
([
var_indices
]
if
has_indices
else
[]),
'Out'
,
'Indices'
),
dict
(
{
global_pooling
=
False
,
'global_pooling'
:
False
,
adaptive
=
False
,
'adaptive'
:
False
,
exclusive
=
True
,
'exclusive'
:
True
,
require_index
=
has_indices
,
'require_index'
:
has_indices
,
pooling_type
=
pool_type
,
'pooling_type'
:
pool_type
,
ksize
=
pool_size
,
'ksize'
:
pool_size
,
strides
=
strides
,
'strides'
:
strides
,
paddings
=
paddings
,
'paddings'
:
paddings
,
ceil_mode
=
ceil_mode
,
'ceil_mode'
:
ceil_mode
,
)
,
}
,
)
)
...
@@ -506,11 +503,11 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
...
@@ -506,11 +503,11 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
# interpretation
# interpretation
spatial_scale
=
attrs
[
'spatial_scale'
]
# required
spatial_scale
=
attrs
[
'spatial_scale'
]
# required
pooled_height
,
pooled_width
=
attrs
[
'pooled_shape'
]
# required
pooled_height
,
pooled_width
=
attrs
[
'pooled_shape'
]
# required
od_attrs
=
dict
(
od_attrs
=
{
pooled_height
=
pooled_height
,
'pooled_height'
:
pooled_height
,
pooled_width
=
pooled_width
,
'pooled_width'
:
pooled_width
,
spatial_scale
=
spatial_scale
,
'spatial_scale'
:
spatial_scale
,
)
}
feature_attr
=
''
feature_attr
=
''
is_max_pool
=
fluid_op
==
'roi_pool'
is_max_pool
=
fluid_op
==
'roi_pool'
if
'sampling_ratio'
in
attrs
:
#
if
'sampling_ratio'
in
attrs
:
#
...
@@ -606,11 +603,11 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
...
@@ -606,11 +603,11 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
fluid_op
,
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
],
'Out'
),
([
var_y
],
'Out'
),
dict
(
{
interp_method
=
mode
,
'interp_method'
:
mode
,
out_h
=
out_shape_
[
0
],
'out_h '
:
out_shape_
[
0
],
out_w
=
out_shape_
[
1
],
'out_w '
:
out_shape_
[
1
],
)
,
}
,
)
)
...
@@ -662,7 +659,7 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -662,7 +659,7 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_op
,
fluid_op
,
([
var_theta
],
'Theta'
),
([
var_theta
],
'Theta'
),
([
var_grid
],
'Output'
),
([
var_grid
],
'Output'
),
dict
(
output_shape
=
size
)
,
# f**k you API
{
'output_shape'
:
size
}
,
# f**k you API
)
)
...
@@ -747,16 +744,17 @@ def BatchNormalization(prog,
...
@@ -747,16 +744,17 @@ def BatchNormalization(prog,
prog
.
VarDesc
(
var_saved_variance
)
prog
.
VarDesc
(
var_saved_variance
)
prog
.
OpDesc
(
prog
.
OpDesc
(
fluid_op
,
fluid_op
,
([
var_x
,
var_scale
,
var_b
,
var_mean
,
var_var
],
'X'
,
'Scale'
,
'Bias'
,
([
var_x
,
var_scale
,
var_b
,
var_mean
,
var_var
'Mean'
,
'Variance'
),
],
'X'
,
'Scale'
,
'Bias'
,
'Mean'
,
'Variance'
),
([
var_y
,
var_mean
,
var_saved_mean
,
var_saved_variance
,
var_var
],
'Y'
,
([
var_y
,
var_mean
,
var_saved_mean
,
var_saved_variance
,
var_var
'MeanOut'
,
'SavedMean'
,
'SavedVariance'
,
'VarianceOut'
),
],
'Y'
,
'MeanOut'
,
'SavedMean'
,
'SavedVariance'
,
'VarianceOut'
),
dict
(
{
is_test
=
1
,
'is_test'
:
1
,
data_layout
=
'NCHW'
,
'data_layout'
:
'NCHW'
,
use_global_stats
=
False
,
'use_global_stats'
:
False
,
momentum
=
momentum
,
'momentum'
:
momentum
,
epsilon
=
epsilon
),
'epsilon'
:
epsilon
,
},
)
)
...
@@ -796,11 +794,12 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -796,11 +794,12 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
fluid_op
,
fluid_op
,
([
var_input
],
'X'
),
([
var_input
],
'X'
),
([
var_output
],
'Out'
),
([
var_output
],
'Out'
),
dict
(
{
in_dtype
=
prog
.
Dtype
(
_dtype
(
value_infos
,
'in_dtype'
:
prog
.
Dtype
(
_dtype
(
value_infos
,
val_input
)),
# holy, required
val_input
)),
# holy, required
out_dtype
=
prog
.
Dtype
(
dtype
),
'out_dtype'
:
prog
.
Dtype
(
dtype
),
))
},
)
def
Concat
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
Concat
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
...
@@ -834,7 +833,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -834,7 +833,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_op
,
fluid_op
,
(
var_inps
,
*
([
'X'
]
*
len
(
var_inps
))),
(
var_inps
,
*
([
'X'
]
*
len
(
var_inps
))),
([
var_concat_result
],
'Out'
),
([
var_concat_result
],
'Out'
),
dict
(
axis
=
axis
)
,
{
'axis'
:
axis
}
,
)
)
...
@@ -886,11 +885,11 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -886,11 +885,11 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
fluid_op
,
fluid_op
,
([],
),
([],
),
([
var_output
],
'Out'
),
([
var_output
],
'Out'
),
dict
(
{
shape
=
shape
,
'shape'
:
shape
,
dtype
=
prog
.
Dtype
(
dtype
),
'dtype'
:
prog
.
Dtype
(
dtype
),
value
=
value
,
'value'
:
value
,
)
,
}
,
)
)
else
:
# list parameter -> const_value
else
:
# list parameter -> const_value
prog
.
Code
(
'# {} = {} # passed directly as literal'
.
format
(
prog
.
Code
(
'# {} = {} # passed directly as literal'
.
format
(
...
@@ -917,7 +916,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -917,7 +916,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'this is not supported'
)
'this is not supported'
)
dtype
=
attrs
[
'value'
].
dtype
dtype
=
attrs
[
'value'
].
dtype
attrs
=
attrs
.
copy
()
attrs
=
attrs
.
copy
()
attrs
.
update
(
dict
(
shape
=
shape
,
dtype
=
dtype
)
)
# pass const
attrs
.
update
(
{
'shape'
:
shape
,
'dtype'
:
dtype
}
)
# pass const
prog
.
Code
(
'# shape:{}={} # const as literal'
.
format
(
var_shape
,
shape
))
prog
.
Code
(
'# shape:{}={} # const as literal'
.
format
(
var_shape
,
shape
))
prog
.
Op
(
prog
.
Op
(
...
@@ -1015,12 +1014,13 @@ def Conv(prog,
...
@@ -1015,12 +1014,13 @@ def Conv(prog,
fluid_op
,
fluid_op
,
([
var_x
,
var_w
],
'Input'
,
'Filter'
),
# , 'Bias', 'ResidualData'
([
var_x
,
var_w
],
'Input'
,
'Filter'
),
# , 'Bias', 'ResidualData'
([
var_conv
if
has_bias
else
var_y
],
'Output'
),
([
var_conv
if
has_bias
else
var_y
],
'Output'
),
dict
(
{
strides
=
strides
,
'strides'
:
strides
,
paddings
=
paddings
,
'paddings'
:
paddings
,
dilations
=
dilations
,
'dilations'
:
dilations
,
groups
=
num_groups
,
'groups'
:
num_groups
,
))
},
)
if
has_bias
:
if
has_bias
:
prog
.
VarDesc
(
var_conv
)
prog
.
VarDesc
(
var_conv
)
prog
.
IntermediateOp
(
prog
.
IntermediateOp
(
...
@@ -1028,7 +1028,7 @@ def Conv(prog,
...
@@ -1028,7 +1028,7 @@ def Conv(prog,
'Add'
,
'Add'
,
[
var_conv
,
var_b
],
#
[
var_conv
,
var_b
],
#
[
val_y
],
[
val_y
],
dict
(
axis
=
1
)
,
{
'axis'
:
1
}
,
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
(
name
+
'.bias'
),
name
=
(
name
+
'.bias'
),
)
)
...
@@ -1125,13 +1125,14 @@ def ConvTranspose(prog,
...
@@ -1125,13 +1125,14 @@ def ConvTranspose(prog,
fluid_op
,
fluid_op
,
([
var_x
,
var_w
],
'Input'
,
'Filter'
),
# , 'Bias', 'ResidualData'
([
var_x
,
var_w
],
'Input'
,
'Filter'
),
# , 'Bias', 'ResidualData'
([
var_conv
if
has_bias
else
var_y
],
'Output'
),
([
var_conv
if
has_bias
else
var_y
],
'Output'
),
dict
(
{
strides
=
strides
,
'strides'
:
strides
,
paddings
=
paddings
,
'paddings'
:
paddings
,
dilations
=
dilations
,
'dilations'
:
dilations
,
# output_size=output_size,
# 'output_size': output_size,
groups
=
num_groups
,
'groups'
:
num_groups
,
))
},
)
if
has_bias
:
if
has_bias
:
prog
.
VarDesc
(
var_conv
)
prog
.
VarDesc
(
var_conv
)
prog
.
IntermediateOp
(
prog
.
IntermediateOp
(
...
@@ -1139,7 +1140,7 @@ def ConvTranspose(prog,
...
@@ -1139,7 +1140,7 @@ def ConvTranspose(prog,
'Add'
,
'Add'
,
[
var_conv
,
var_b
],
#
[
var_conv
,
var_b
],
#
[
val_y
],
[
val_y
],
dict
(
axis
=
1
)
,
{
'axis'
:
1
}
,
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
(
name
+
'.bias'
),
name
=
(
name
+
'.bias'
),
)
)
...
@@ -1184,19 +1185,19 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1184,19 +1185,19 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'MatMul'
,
'MatMul'
,
[
val_a
,
val_b
],
[
val_a
,
val_b
],
[
val_mm
],
# val
[
val_mm
],
# val
dict
(
{
transpose_x
=
trans_a
,
'transpose_x'
:
trans_a
,
transpose_y
=
trans_b
,
'transpose_y'
:
trans_b
,
alpha
=
alpha
,
'alpha'
:
alpha
,
)
,
}
,
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
val_mm
,
name
=
val_mm
,
)
)
prog
.
op_descs
[
-
1
].
attrs
.
extend
(
prog
.
op_descs
[
-
1
].
attrs
.
extend
(
prog
.
OpDescAttrs
(
dict
(
prog
.
OpDescAttrs
(
{
transpose_X
=
trans_a
,
'transpose_X'
:
trans_a
,
transpose_Y
=
trans_b
,
'transpose_Y'
:
trans_b
,
)
))
# f**k you API
}
))
# f**k you API
if
beta
!=
0
:
if
beta
!=
0
:
if
beta
==
1.
:
# exactly
if
beta
==
1.
:
# exactly
prog
.
Op
(
prog
.
Op
(
...
@@ -1204,7 +1205,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1204,7 +1205,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Add'
,
'Add'
,
[
val_mm
,
val_c
],
[
val_mm
,
val_c
],
[
val_y
],
# val
[
val_y
],
# val
dict
(
axis
=
1
)
,
{
'axis'
:
1
}
,
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
(
name
+
'_beta'
),
name
=
(
name
+
'_beta'
),
)
)
...
@@ -1226,7 +1227,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1226,7 +1227,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Constant'
,
'Constant'
,
[],
[],
[
val_beta
],
# val
[
val_beta
],
# val
dict
(
value
=
beta
)
,
{
'value'
:
beta
}
,
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
val_beta
,
name
=
val_beta
,
)
)
...
@@ -1244,7 +1245,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1244,7 +1245,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Add'
,
'Add'
,
[
val_mm
,
val_vm
],
[
val_mm
,
val_vm
],
[
val_y
],
# val
[
val_y
],
# val
dict
(
axis
=
1
)
,
{
'axis'
:
1
}
,
name
=
(
name
+
'_bias'
),
name
=
(
name
+
'_bias'
),
)
)
...
@@ -1261,8 +1262,13 @@ def GlobalAveragePool(prog,
...
@@ -1261,8 +1262,13 @@ def GlobalAveragePool(prog,
onnx::GlobalAveragePool-1:
onnx::GlobalAveragePool-1:
"""
"""
return
_global_pool
(
return
_global_pool
(
prog
,
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
def
GlobalMaxPool
(
prog
,
def
GlobalMaxPool
(
prog
,
...
@@ -1277,60 +1283,13 @@ def GlobalMaxPool(prog,
...
@@ -1277,60 +1283,13 @@ def GlobalMaxPool(prog,
onnx::GlobalMaxPool-1:
onnx::GlobalMaxPool-1:
"""
"""
return
_global_pool
(
return
_global_pool
(
prog
,
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
'max'
,
inputs
,
outputs
,
#def LRN(
attrs
,
# prog, inputs, outputs, attrs, value_infos, name, # name required
value_infos
,
# *args, **kwargs):
name
=
name
)
# """
# onnx::LRN-1:
# """
#
# # I/O
# val_x, = inputs
# val_y, = outputs
# var_x = _make_var_name(val_x)
# var_y = _make_var_name(val_y)
#
# # interpretation
# fluid_op = 'lrn'
# size = attrs['size'] # required
# alpha = attrs.get('alpha', 0.0001) # optional
# beta = attrs.get('beta', 0.75) # optional
# bias = attrs.get('bias', 1.0) # optional
# name_attr = ', name={}'.format(repr(name)) if name else ''
#
# # generation
# prog.Code('{} = layers.{}({}'
# ', n={}'
# ', k={}'
# ', alpha={}'
# ', beta={}'
# '{})'
# .format(var_y,
# fluid_op,
# var_x,
# # attrs
# size,
# bias,
# alpha,
# beta,
# name_attr,
# ))
# var_mid = name + '.mid' # hidden variable
# prog.VarDesc(var_y)
# prog.VarDesc(var_mid)
# prog.OpDesc(fluid_op,
# ([var_x], 'X'),
# ([var_y, var_mid], 'Out', 'MidOut'),
# dict(n=size,
# k=bias,
# alpha=alpha,
# beta=beta,
# ),
# )
def
MaxPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
def
MaxPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
...
@@ -1375,7 +1334,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
...
@@ -1375,7 +1334,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
assume_pad2d
|=
data_shape
and
len
(
data_shape
)
==
4
# NCHW
assume_pad2d
|=
data_shape
and
len
(
data_shape
)
==
4
# NCHW
if
output_shape
:
if
output_shape
:
assume_pad2d
|=
output_shape
and
len
(
output_shape
)
==
4
# NCHW
assume_pad2d
|=
output_shape
and
len
(
output_shape
)
==
4
# NCHW
od_attrs
=
dict
(
pad_value
=
value
)
od_attrs
=
{
'pad_value'
:
value
}
if
assume_pad2d
:
if
assume_pad2d
:
fluid_op
=
'pad2d'
fluid_op
=
'pad2d'
pad2d_attr
=
', mode={}, data_format="NCHW"'
.
format
(
repr
(
mode
))
pad2d_attr
=
', mode={}, data_format="NCHW"'
.
format
(
repr
(
mode
))
...
@@ -1434,11 +1393,20 @@ def PRelu(prog,
...
@@ -1434,11 +1393,20 @@ def PRelu(prog,
var_y
=
_make_var_name
(
val_y
)
var_y
=
_make_var_name
(
val_y
)
# interpretation
# interpretation
mode
=
'channel'
slope_shape
=
_shape_or_none
(
value_infos
,
val_slope
)
if
slope_shape
is
not
None
:
if
len
(
slope_shape
)
==
0
:
mode
=
'all'
elif
len
(
slope_shape
)
>=
2
:
if
slope_shape
[
1
]
!=
_np
.
product
(
slope_shape
):
# not channel broadcasting
mode
=
'element'
fluid_op
=
'prelu'
fluid_op
=
'prelu'
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
if
embed_params
:
assert
name
!=
''
assert
name
!=
''
var_slope
=
'{}.w_0'
.
format
(
val_slope
)
var_slope
=
name
+
'.w_0'
value_infos
[
val_slope
].
setdefault
(
'embeded_as'
,
[]).
append
(
var_slope
)
value_infos
[
val_slope
].
setdefault
(
'embeded_as'
,
[]).
append
(
var_slope
)
param_attr
=
''
param_attr
=
''
else
:
else
:
...
@@ -1446,21 +1414,23 @@ def PRelu(prog,
...
@@ -1446,21 +1414,23 @@ def PRelu(prog,
param_attr
=
', param_attr={}'
.
format
(
repr
(
var_slope
))
param_attr
=
', param_attr={}'
.
format
(
repr
(
var_slope
))
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}, mode="all"'
prog
.
Code
(
'{} = layers.{}({}'
', mode={}'
'{}{})'
.
format
(
'{}{})'
.
format
(
var_y
,
var_y
,
fluid_op
,
fluid_op
,
var_x
,
var_x
,
# attrs
# attrs
repr
(
mode
),
param_attr
,
param_attr
,
name_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
prog
.
OpDesc
(
prog
.
OpDesc
(
fluid_op
,
fluid_op
,
([
var_x
],
'X
'
),
([
var_x
,
var_slope
],
'X'
,
'Alpha
'
),
([
var_y
],
'Out'
),
([
var_y
],
'Out'
),
dict
(
mode
=
'all'
)
,
{
'mode'
:
mode
}
,
)
)
...
@@ -1524,7 +1494,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1524,7 +1494,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'Cast'
,
'Cast'
,
[
val_shape
],
[
val_shape
],
[
val_shape_int32
],
# var
[
val_shape_int32
],
# var
dict
(
to
=
_np
.
dtype
(
'int32'
))
,
# use np.dtype
{
'to'
:
_np
.
dtype
(
'int32'
)}
,
# use np.dtype
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
(
name
+
'_cast'
),
name
=
(
name
+
'_cast'
),
)
)
...
@@ -1549,14 +1519,14 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1549,14 +1519,14 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
fluid_op
,
fluid_op
,
([
var_data
],
'X'
),
([
var_data
],
'X'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
dict
(
shape
=
shape
)
,
{
'shape'
:
shape
}
,
)
)
else
:
else
:
prog
.
OpDesc
(
prog
.
OpDesc
(
fluid_op
,
fluid_op
,
([
var_data
,
var_shape_int32
],
'X'
,
'Shape'
),
([
var_data
,
var_shape_int32
],
'X'
,
'Shape'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
dict
(
shape
=
shape
)
,
{
'shape'
:
shape
}
,
)
)
...
@@ -1659,11 +1629,11 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -1659,11 +1629,11 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
fluid_op
,
fluid_op
,
([
var_data
],
'Input'
),
([
var_data
],
'Input'
),
([
var_output
],
'Out'
),
([
var_output
],
'Out'
),
dict
(
{
axes
=
axes
,
'axes'
:
axes
,
starts
=
starts
,
'starts'
:
starts
,
ends
=
ends
,
'ends'
:
ends
,
)
,
}
,
)
)
...
@@ -1701,10 +1671,10 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -1701,10 +1671,10 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_op
,
fluid_op
,
(
var_input
,
'X'
),
(
var_input
,
'X'
),
([
var_outs
],
*
([
'Out'
]
*
len
(
var_outs
))),
([
var_outs
],
*
([
'Out'
]
*
len
(
var_outs
))),
dict
(
{
axis
=
axis
,
'axis'
:
axis
,
sections
=
split
,
'sections'
:
split
,
)
,
}
,
)
)
...
@@ -1773,7 +1743,7 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
...
@@ -1773,7 +1743,7 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
fluid_op
,
fluid_op
,
([
var_input
],
'X'
),
([
var_input
],
'X'
),
([
var_output
],
'Out'
),
([
var_output
],
'Out'
),
dict
(
expand_times
=
repeats
)
,
{
'expand_times'
:
repeats
}
,
)
)
...
@@ -1812,7 +1782,7 @@ def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -1812,7 +1782,7 @@ def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_op
,
fluid_op
,
([
var_data
],
'X'
),
([
var_data
],
'X'
),
([
var_transposed
,
var_xshape
],
'Out'
,
'XShape'
),
([
var_transposed
,
var_xshape
],
'Out'
,
'XShape'
),
dict
(
axis
=
perm
)
,
# f**k you API
{
'axis'
:
perm
}
,
# f**k you API
)
)
...
@@ -1902,8 +1872,7 @@ if __name__ == '__main__':
...
@@ -1902,8 +1872,7 @@ if __name__ == '__main__':
[
'input'
],
[
'input'
],
[
'output'
],
[
'output'
],
dict
(
to
=
2
),
# TensorProto.UINT8
dict
(
to
=
2
),
# TensorProto.UINT8
dict
(
dict
(
input
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
float32
),
input
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
float32
),
output
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
uint8
)),
output
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
uint8
)),
)
)
logger
.
info
(
'Cast program:
\n
%s'
,
prog
)
logger
.
info
(
'Cast program:
\n
%s'
,
prog
)
...
@@ -2101,8 +2070,7 @@ if __name__ == '__main__':
...
@@ -2101,8 +2070,7 @@ if __name__ == '__main__':
logger
.
info
(
'Less program:
\n
%s'
,
prog
)
logger
.
info
(
'Less program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
_default
(
_default
(
prog
,
prog
,
'MatMul'
,
[
'A'
,
'B'
],
[
'Y'
],
'MatMul'
,
[
'A'
,
'B'
],
[
'Y'
],
dict
(),
dict
(),
dict
(
Y
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
)),
dict
(
Y
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
)),
...
@@ -2168,11 +2136,9 @@ if __name__ == '__main__':
...
@@ -2168,11 +2136,9 @@ if __name__ == '__main__':
logger
.
info
(
'PRelu program:
\n
%s'
,
prog
)
logger
.
info
(
'PRelu program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Tile
(
Tile
(
prog
,
[
'input'
,
'repeats'
],
[
'output'
],
prog
,
[
'input'
,
'repeats'
],
[
'output'
],
dict
(),
dict
(),
dict
(
dict
(
repeats
=
dict
(
const_value
=
[
1
,
2
]),
repeats
=
dict
(
const_value
=
[
1
,
2
]),
output
=
dict
(
shape
=
(
2
,
2
,
4
),
dtype
=
np
.
float32
)),
output
=
dict
(
shape
=
(
2
,
2
,
4
),
dtype
=
np
.
float32
)),
name
=
'Tile'
)
name
=
'Tile'
)
logger
.
info
(
'Tile program:
\n
%s'
,
prog
)
logger
.
info
(
'Tile program:
\n
%s'
,
prog
)
onnx2fluid/onnx2fluid/torch_export_helper.py
浏览文件 @
7c3e9379
...
@@ -12,25 +12,25 @@ import torch
...
@@ -12,25 +12,25 @@ import torch
from
collections
import
OrderedDict
as
Dict
from
collections
import
OrderedDict
as
Dict
def
_
ensure_list
(
obj
):
def
ensure_list
(
obj
):
if
isinstance
(
obj
,
(
list
,
set
,
tuple
)):
if
isinstance
(
obj
,
(
list
,
tuple
,
set
)):
return
list
(
obj
)
return
list
(
obj
)
return
[
obj
]
return
[
obj
]
def
_
ensure_tuple
(
obj
):
def
ensure_tuple
(
obj
):
if
isinstance
(
obj
,
(
list
,
set
,
tuple
)):
if
isinstance
(
obj
,
(
tuple
,
list
,
set
)):
return
tuple
(
obj
)
return
tuple
(
obj
)
return
(
obj
,
)
return
(
obj
,
)
def
_
flatten_list
(
obj
,
out
=
None
):
def
flatten_list
(
obj
,
out
=
None
):
assert
isinstance
(
obj
,
list
)
assert
isinstance
(
obj
,
list
)
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
for
item
in
obj
:
for
item
in
obj
:
if
isinstance
(
item
,
list
):
if
isinstance
(
item
,
list
):
_
flatten_list
(
item
,
out
)
flatten_list
(
item
,
out
)
else
:
else
:
out
.
append
(
item
)
out
.
append
(
item
)
return
out
return
out
...
@@ -41,7 +41,7 @@ def export_data(state_dict, prefix=''):
...
@@ -41,7 +41,7 @@ def export_data(state_dict, prefix=''):
export binary data with meta text for raw C++ inference engines
export binary data with meta text for raw C++ inference engines
"""
"""
def
_str
(
obj
):
def
str_
(
obj
):
if
isinstance
(
obj
,
(
tuple
,
list
)):
if
isinstance
(
obj
,
(
tuple
,
list
)):
return
str
(
obj
)[
1
:
-
1
].
replace
(
' '
,
''
)
return
str
(
obj
)[
1
:
-
1
].
replace
(
' '
,
''
)
return
str
(
obj
)
return
str
(
obj
)
...
@@ -52,14 +52,14 @@ def export_data(state_dict, prefix=''):
...
@@ -52,14 +52,14 @@ def export_data(state_dict, prefix=''):
data
=
None
data
=
None
if
torch
and
torch
.
is_tensor
(
value
):
if
torch
and
torch
.
is_tensor
(
value
):
data
=
value
.
data
.
cpu
().
numpy
()
data
=
value
.
data
.
cpu
().
numpy
()
elif
np
and
isinstance
(
value
,
np
.
ndarray
):
elif
isinstance
(
value
,
np
.
ndarray
):
data
=
value
data
=
value
if
data
is
not
None
:
if
data
is
not
None
:
data
.
tofile
(
'{}{}.bin'
.
format
(
prefix_
,
key
))
data
.
tofile
(
'{}{}.bin'
.
format
(
prefix_
,
key
))
fp
.
write
(
'{}.dtype={}
\n
'
.
format
(
key
,
_str
(
data
.
dtype
.
name
)))
fp
.
write
(
'{}.dtype={}
\n
'
.
format
(
key
,
str_
(
data
.
dtype
.
name
)))
fp
.
write
(
'{}.shape={}
\n
'
.
format
(
key
,
_str
(
data
.
shape
)))
fp
.
write
(
'{}.shape={}
\n
'
.
format
(
key
,
str_
(
data
.
shape
)))
else
:
else
:
fp
.
write
(
'{}={}
\n
'
.
format
(
key
,
_str
(
value
)))
fp
.
write
(
'{}={}
\n
'
.
format
(
key
,
str_
(
value
)))
fp
.
close
()
fp
.
close
()
...
@@ -75,43 +75,42 @@ def export_onnx_with_validation(model,
...
@@ -75,43 +75,42 @@ def export_onnx_with_validation(model,
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
"""
is_
list_or_tuple
=
lambda
x
:
isinstance
(
x
,
(
list
,
tuple
))
is_
tuple_or_list
=
lambda
x
:
isinstance
(
x
,
(
tuple
,
list
))
def
_
tensors_to_arrays
(
tensors
):
def
tensors_to_arrays
(
tensors
):
if
torch
.
is_tensor
(
tensors
):
if
torch
.
is_tensor
(
tensors
):
return
tensors
.
data
.
cpu
().
numpy
()
return
tensors
.
data
.
cpu
().
numpy
()
arrays
=
[]
arrays
=
[]
for
tensor
in
tensors
:
for
tensor
in
tensors
:
arrays
.
append
(
_
tensors_to_arrays
(
tensor
))
arrays
.
append
(
tensors_to_arrays
(
tensor
))
return
arrays
return
arrays
def
_
zip_dict
(
keys
,
values
):
def
zip_dict
(
keys
,
values
):
ret
=
Dict
()
ret
=
Dict
()
for
idx
,
(
key
,
value
)
in
enumerate
(
zip
(
keys
,
values
)):
for
idx
,
(
key
,
value
)
in
enumerate
(
zip
(
keys
,
values
)):
is_key_list
=
is_
list_or_tuple
(
key
)
is_key_list
=
is_
tuple_or_list
(
key
)
is_value_list
=
is_
list_or_tuple
(
value
)
is_value_list
=
is_
tuple_or_list
(
value
)
assert
is_key_list
==
is_value_list
,
'keys and values mismatch'
assert
is_key_list
==
is_value_list
,
'keys and values mismatch'
if
is_value_list
:
if
is_value_list
:
ret
[
str
(
idx
)]
=
_
zip_dict
(
key
,
value
)
ret
[
str
(
idx
)]
=
zip_dict
(
key
,
value
)
else
:
else
:
ret
[
key
]
=
value
ret
[
key
]
=
value
return
ret
return
ret
torch_inputs
=
_ensure_tuple
(
inputs
)
# WORKAROUND: for torch.onnx
torch_inputs
=
ensure_tuple
(
inputs
)
# WORKAROUND: for torch.onnx
outputs
=
torch
.
onnx
.
export
(
outputs
=
torch
.
onnx
.
export
(
model
,
model
,
torch_inputs
,
torch_inputs
,
export_basepath
+
'.onnx'
,
export_basepath
+
'.onnx'
,
input_names
=
_
flatten_list
(
input_names
),
input_names
=
flatten_list
(
input_names
),
output_names
=
_
flatten_list
(
output_names
),
output_names
=
flatten_list
(
output_names
),
*
args
,
*
args
,
**
kwargs
)
**
kwargs
)
if
outputs
is
None
:
# WORKAROUND: for torch.onnx
if
outputs
is
None
:
# WORKAROUND: for torch.onnx
outputs
=
model
(
*
inputs
)
outputs
=
model
(
*
inputs
)
torch_outputs
=
_
ensure_tuple
(
outputs
)
torch_outputs
=
ensure_tuple
(
outputs
)
inputs
=
_zip_dict
(
input_names
,
_
tensors_to_arrays
(
torch_inputs
))
inputs
=
zip_dict
(
input_names
,
tensors_to_arrays
(
torch_inputs
))
outputs
=
_zip_dict
(
output_names
,
_
tensors_to_arrays
(
torch_outputs
))
outputs
=
zip_dict
(
output_names
,
tensors_to_arrays
(
torch_outputs
))
if
use_npz
:
if
use_npz
:
np
.
savez
(
export_basepath
+
'.npz'
,
inputs
=
inputs
,
outputs
=
outputs
)
np
.
savez
(
export_basepath
+
'.npz'
,
inputs
=
inputs
,
outputs
=
outputs
)
else
:
else
:
...
...
onnx2fluid/onnx2fluid/validation.py
浏览文件 @
7c3e9379
...
@@ -9,22 +9,21 @@ Created on Fri Mar 22 12:17:19 2019
...
@@ -9,22 +9,21 @@ Created on Fri Mar 22 12:17:19 2019
import
importlib
,
logging
,
os
,
sys
import
importlib
,
logging
,
os
,
sys
def
_
flatten_dict
(
obj
,
out
=
None
):
def
flatten_dict
(
obj
,
out
=
None
):
assert
isinstance
(
obj
,
dict
)
assert
isinstance
(
obj
,
dict
)
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
for
key
,
value
in
obj
.
items
():
for
key
,
value
in
obj
.
items
():
if
isinstance
(
value
,
dict
):
if
isinstance
(
value
,
dict
):
_
flatten_dict
(
value
,
out
)
flatten_dict
(
value
,
out
)
else
:
else
:
assert
key
not
in
out
assert
key
not
in
out
out
[
key
]
=
value
out
[
key
]
=
value
return
out
return
out
def
_ensure_list
(
obj
):
def
ensure_list
(
obj
):
for
cls
in
[
list
,
set
,
tuple
]:
if
isinstance
(
obj
,
(
list
,
tuple
,
set
)):
if
isinstance
(
obj
,
cls
):
return
list
(
obj
)
return
list
(
obj
)
return
[
obj
]
return
[
obj
]
...
@@ -33,7 +32,7 @@ def validate(fluid_model_filename,
...
@@ -33,7 +32,7 @@ def validate(fluid_model_filename,
golden_data_filename
,
golden_data_filename
,
model_func_name
=
'inference'
,
model_func_name
=
'inference'
,
atol
=
1e-3
,
atol
=
1e-3
,
rtol
=
1e-
4
,
rtol
=
1e-
3
,
save_inference_model
=
False
,
save_inference_model
=
False
,
**
kwargs
):
**
kwargs
):
"""
"""
...
@@ -56,8 +55,8 @@ def validate(fluid_model_filename,
...
@@ -56,8 +55,8 @@ def validate(fluid_model_filename,
prog
,
_
,
var_outs
=
fluid
.
io
.
load_inference_model
(
fluid_model_dir
,
exe
)
prog
,
_
,
var_outs
=
fluid
.
io
.
load_inference_model
(
fluid_model_dir
,
exe
)
out_names
=
var_outs
# HINT: pass var if fetch ops already created
out_names
=
var_outs
# HINT: pass var if fetch ops already created
logger
.
info
(
'model load passed'
)
logger
.
info
(
'model load passed'
)
elif
basename
.
endswith
(
'.py'
):
# is
p
ython code
elif
basename
.
endswith
(
'.py'
):
# is
P
ython code
logger
.
debug
(
'using
python
code file %s'
,
basename
)
logger
.
debug
(
'using code file %s'
,
basename
)
module_name
,
_
=
os
.
path
.
splitext
(
basename
)
module_name
,
_
=
os
.
path
.
splitext
(
basename
)
sys_path
=
sys
.
path
.
copy
()
sys_path
=
sys
.
path
.
copy
()
sys
.
path
.
append
(
fluid_model_dir
)
sys
.
path
.
append
(
fluid_model_dir
)
...
@@ -73,14 +72,15 @@ def validate(fluid_model_filename,
...
@@ -73,14 +72,15 @@ def validate(fluid_model_filename,
func
)
func
)
var_outs
=
func
()
var_outs
=
func
()
var_outs
=
_
ensure_list
(
var_outs
)
var_outs
=
ensure_list
(
var_outs
)
out_names
=
[
var
.
name
for
var
in
var_outs
out_names
=
[
var
.
name
for
var
in
var_outs
]
# HINT: pass string to create fetch ops
]
# HINT: pass string to create fetch ops
logger
.
info
(
'import passed'
)
logger
.
info
(
'import passed'
)
prog
=
fluid
.
default_main_program
()
prog
=
fluid
.
default_main_program
()
fluid
.
io
.
load_persistables
(
fluid
.
io
.
load_persistables
(
executor
=
exe
,
executor
=
exe
,
dirname
=
fluid_model_dir
,
main_program
=
prog
)
dirname
=
fluid_model_dir
,
main_program
=
prog
)
logger
.
info
(
'weight load passed'
)
logger
.
info
(
'weight load passed'
)
else
:
else
:
raise
ValueError
(
'unsupported Paddle fluid model filename'
)
raise
ValueError
(
'unsupported Paddle fluid model filename'
)
...
@@ -95,15 +95,14 @@ def validate(fluid_model_filename,
...
@@ -95,15 +95,14 @@ def validate(fluid_model_filename,
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
).
tolist
()
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
).
tolist
()
input_data
=
test_data
[
'inputs'
]
input_data
=
test_data
[
'inputs'
]
output_data
=
test_data
[
'outputs'
]
output_data
=
test_data
[
'outputs'
]
input_data
=
_
flatten_dict
(
input_data
)
input_data
=
flatten_dict
(
input_data
)
output_data
=
_
flatten_dict
(
output_data
)
output_data
=
flatten_dict
(
output_data
)
logger
.
info
(
'found %d I/O golden data, starting test ...'
,
logger
.
info
(
'found %d I/O golden data, starting test ...'
,
len
(
input_data
)
+
len
(
output_data
))
len
(
input_data
)
+
len
(
output_data
))
# DEBUG: reload test for
p
ython code
# DEBUG: reload test for
P
ython code
if
basename
.
endswith
(
'.py'
)
and
save_inference_model
:
if
basename
.
endswith
(
'.py'
)
and
save_inference_model
:
fluid
.
io
.
save_inference_model
(
fluid
.
io
.
save_inference_model
(
fluid_model_dir
,
fluid_model_dir
,
input_data
.
keys
(),
input_data
.
keys
(),
var_outs
,
var_outs
,
exe
,
exe
,
...
@@ -122,8 +121,7 @@ def validate(fluid_model_filename,
...
@@ -122,8 +121,7 @@ def validate(fluid_model_filename,
for
(
name
,
truth
),
output
in
zip
(
output_data
.
items
(),
outputs
):
for
(
name
,
truth
),
output
in
zip
(
output_data
.
items
(),
outputs
):
logger
.
info
(
'testing output {} ...'
.
format
(
name
))
logger
.
info
(
'testing output {} ...'
.
format
(
name
))
try
:
try
:
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
output
,
output
,
truth
,
truth
,
rtol
=
rtol
,
rtol
=
rtol
,
atol
=
atol
,
atol
=
atol
,
...
@@ -174,7 +172,7 @@ if __name__ == '__main__':
...
@@ -174,7 +172,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
parser
.
add_argument
(
'--rtol'
,
'--rtol'
,
type
=
float
,
type
=
float
,
default
=
1e-
4
,
default
=
1e-
2
,
help
=
'assertion relative tolerance for validation'
,
help
=
'assertion relative tolerance for validation'
,
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -188,8 +186,7 @@ if __name__ == '__main__':
...
@@ -188,8 +186,7 @@ if __name__ == '__main__':
golden_data_filename
=
args
.
test_data
golden_data_filename
=
args
.
test_data
atol
,
rtol
=
args
.
atol
,
args
.
rtol
atol
,
rtol
=
args
.
atol
,
args
.
rtol
validate
(
validate
(
fluid_model_filename
,
fluid_model_filename
,
golden_data_filename
,
golden_data_filename
,
atol
=
atol
,
atol
=
atol
,
rtol
=
rtol
,
rtol
=
rtol
,
...
...
onnx2fluid/onnx2fluid/writer.py
浏览文件 @
7c3e9379
...
@@ -11,6 +11,8 @@ from __future__ import division
...
@@ -11,6 +11,8 @@ from __future__ import division
import
logging
,
os
import
logging
,
os
import
numpy
as
np
import
numpy
as
np
from
collections
import
OrderedDict
as
Dict
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
from
.
import
symbolic
from
.
import
symbolic
...
@@ -30,7 +32,7 @@ __all__ = [
...
@@ -30,7 +32,7 @@ __all__ = [
]
]
def
_
irepr
(
obj
,
to
=
'_'
):
def
irepr
(
obj
,
to
=
'_'
):
"""inline repr"""
"""inline repr"""
s
=
repr
(
obj
)
s
=
repr
(
obj
)
...
@@ -41,12 +43,12 @@ def _irepr(obj, to='_'):
...
@@ -41,12 +43,12 @@ def _irepr(obj, to='_'):
return
s
return
s
def
_
flatten_list
(
obj
,
out
=
None
):
def
flatten_list
(
obj
,
out
=
None
):
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
for
item
in
obj
:
for
item
in
obj
:
if
isinstance
(
item
,
list
):
if
isinstance
(
item
,
list
):
_
flatten_list
(
item
,
out
)
flatten_list
(
item
,
out
)
else
:
else
:
out
.
append
(
item
)
out
.
append
(
item
)
return
out
return
out
...
@@ -59,7 +61,7 @@ def make_attr_name(name):
...
@@ -59,7 +61,7 @@ def make_attr_name(name):
if
name
==
''
:
if
name
==
''
:
raise
ValueError
(
'name should not be empty'
)
raise
ValueError
(
'name should not be empty'
)
for
s
in
'
*?
\\
/-
:'
:
#
for
s
in
'
\\
|/
:'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
not
name
.
startswith
(
'_'
):
if
not
name
.
startswith
(
'_'
):
name
=
'_'
+
name
name
=
'_'
+
name
...
@@ -130,8 +132,8 @@ class Program(object):
...
@@ -130,8 +132,8 @@ class Program(object):
od_attr
.
type
=
framework_pb2
.
STRING
od_attr
.
type
=
framework_pb2
.
STRING
od_attr
.
s
=
value
od_attr
.
s
=
value
elif
isinstance
(
value
,
list
):
elif
isinstance
(
value
,
list
):
if
len
(
value
)
>
0
:
if
len
(
value
)
>
0
:
# TODO: test all items
if
isinstance
(
value
,
if
isinstance
(
value
[
0
]
,
bool
):
# bool.mro() = [bool, int, object]
bool
):
# bool.mro() = [bool, int, object]
od_attr
.
type
=
framework_pb2
.
BOOLEANS
od_attr
.
type
=
framework_pb2
.
BOOLEANS
od_attr
.
bools
.
extend
(
value
)
od_attr
.
bools
.
extend
(
value
)
...
@@ -164,14 +166,15 @@ class Program(object):
...
@@ -164,14 +166,15 @@ class Program(object):
self
.
code_mutable
=
True
self
.
code_mutable
=
True
self
.
codes
=
[]
self
.
codes
=
[]
self
.
op_descs
=
[]
self
.
op_descs
=
[]
self
.
var_descs
=
[]
self
.
var_descs
=
Dict
()
def
__repr__
(
self
):
def
__repr__
(
self
):
return
(
'Program(code mutable: {}) with:
\n
'
return
(
'Program(code mutable: {}) with:
\n
'
'codes: {}
\n
'
'codes: {}
\n
'
'op_descs: {}
\n
'
'op_descs: {}
\n
'
'var_descs: {}
\n
'
).
format
(
self
.
code_mutable
,
self
.
codes
,
'var_descs: {}
\n
'
).
format
(
self
.
code_mutable
,
self
.
codes
,
self
.
op_descs
,
self
.
var_descs
)
self
.
op_descs
,
list
(
self
.
var_descs
.
values
()))
def
Code
(
self
,
code
):
def
Code
(
self
,
code
):
"""
"""
...
@@ -182,7 +185,7 @@ class Program(object):
...
@@ -182,7 +185,7 @@ class Program(object):
self
.
codes
.
append
(
code
)
self
.
codes
.
append
(
code
)
def
OpDesc
(
self
,
def
OpDesc
(
self
,
nam
e
,
op_typ
e
,
input_val_keys
=
None
,
input_val_keys
=
None
,
output_val_keys
=
None
,
output_val_keys
=
None
,
attrs
=
None
):
attrs
=
None
):
...
@@ -191,7 +194,7 @@ class Program(object):
...
@@ -191,7 +194,7 @@ class Program(object):
"""
"""
desc
=
framework_pb2
.
OpDesc
()
desc
=
framework_pb2
.
OpDesc
()
desc
.
type
=
nam
e
desc
.
type
=
op_typ
e
if
input_val_keys
is
not
None
:
if
input_val_keys
is
not
None
:
desc
.
inputs
.
extend
(
self
.
OpDescVars
(
*
input_val_keys
))
desc
.
inputs
.
extend
(
self
.
OpDescVars
(
*
input_val_keys
))
if
output_val_keys
is
not
None
:
if
output_val_keys
is
not
None
:
...
@@ -202,7 +205,7 @@ class Program(object):
...
@@ -202,7 +205,7 @@ class Program(object):
return
desc
return
desc
def
VarDesc
(
self
,
def
VarDesc
(
self
,
name
,
var_
name
,
persistable
=
False
,
persistable
=
False
,
value_info
=
None
,
value_info
=
None
,
remove_batch
=
None
):
remove_batch
=
None
):
...
@@ -210,24 +213,15 @@ class Program(object):
...
@@ -210,24 +213,15 @@ class Program(object):
add VarDesc,
add VarDesc,
"""
"""
assert
var_name
not
in
self
.
var_descs
,
'var naming conflicted'
var_desc
=
framework_pb2
.
VarDesc
()
var_desc
=
framework_pb2
.
VarDesc
()
var_desc
.
name
=
name
var_desc
.
name
=
var_
name
var_desc
.
persistable
=
persistable
var_desc
.
persistable
=
persistable
var_desc
.
type
.
type
=
framework_pb2
.
VarType
.
LOD_TENSOR
var_desc
.
type
.
type
=
framework_pb2
.
VarType
.
LOD_TENSOR
self
.
var_descs
[
var_name
]
=
var_desc
if
value_info
and
'dtype'
in
value_info
:
if
value_info
:
tensor_desc
=
var_desc
.
type
.
lod_tensor
.
tensor
self
.
VarTypeInfo
(
var_name
,
value_info
,
remove_batch
=
remove_batch
)
tensor_desc
.
data_type
=
self
.
Dtype
(
value_info
[
'dtype'
])
# required
if
'shape'
in
value_info
:
tensor_desc
.
dims
.
extend
(
value_info
[
'shape'
])
if
len
(
value_info
[
'shape'
])
>
0
:
# skip scalars
if
remove_batch
is
None
:
remove_batch
=
value_info
.
get
(
'remove_batch'
,
not
persistable
)
if
remove_batch
:
tensor_desc
.
dims
[
0
]
=
-
1
self
.
var_descs
.
append
(
var_desc
)
def
Op
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
def
Op
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
"""
"""
...
@@ -261,13 +255,40 @@ class Program(object):
...
@@ -261,13 +255,40 @@ class Program(object):
else
:
else
:
self
.
code_mutable
=
code_mutable
self
.
code_mutable
=
code_mutable
def
VarTypeInfo
(
self
,
var_name
,
value_info
,
remove_batch
=
None
):
"""
set value_info for var
"""
if
var_name
not
in
self
.
var_descs
:
return
dtype
=
value_info
.
get
(
'dtype'
,
None
)
if
dtype
is
None
:
return
var_desc
=
self
.
var_descs
[
var_name
]
tensor_desc
=
var_desc
.
type
.
lod_tensor
.
tensor
tensor_desc
.
data_type
=
self
.
Dtype
(
dtype
)
# required
shape
=
value_info
.
get
(
'shape'
,
None
)
if
shape
is
not
None
:
tensor_desc
.
dims
.
extend
(
shape
)
if
len
(
shape
)
>
0
:
# skip scalars
if
remove_batch
is
None
:
remove_batch
=
value_info
.
get
(
'remove_batch'
,
False
)
#not persistable)
if
remove_batch
:
tensor_desc
.
dims
[
0
]
=
-
1
class
Writer
(
object
):
class
Writer
(
object
):
"""
"""
fluid code and desc writter
fluid code and desc writter
"""
"""
CODE_INDENT
=
' '
*
4
# CODE_INDENT = ' ' * 4
CODE_INDENT
=
'
\t
'
@
staticmethod
@
staticmethod
def
header_code
(
func_name
,
info
=
''
):
def
header_code
(
func_name
,
info
=
''
):
...
@@ -275,7 +296,7 @@ class Writer(object):
...
@@ -275,7 +296,7 @@ class Writer(object):
Python header codes
Python header codes
"""
"""
codes
=
list
()
codes
=
[]
codes
.
append
(
'"""'
)
codes
.
append
(
'"""'
)
codes
.
append
(
'This code is generated by onnx2fluid.'
)
codes
.
append
(
'This code is generated by onnx2fluid.'
)
codes
.
append
(
'{}'
.
format
(
info
))
codes
.
append
(
'{}'
.
format
(
info
))
...
@@ -299,9 +320,8 @@ class Writer(object):
...
@@ -299,9 +320,8 @@ class Writer(object):
prog
.
Code
(
'# {}, {}::{}: {} -> {}, {}'
.
format
(
name
,
domain
,
op_type
,
prog
.
Code
(
'# {}, {}::{}: {} -> {}, {}'
.
format
(
name
,
domain
,
op_type
,
inputs
,
outputs
,
inputs
,
outputs
,
_irepr
(
attrs
,
to
=
', '
)))
irepr
(
attrs
,
to
=
', '
)))
prog
.
Op
(
prog
.
Op
(
domain
,
domain
,
op_type
,
op_type
,
inputs
,
inputs
,
outputs
,
outputs
,
...
@@ -367,10 +387,11 @@ class Writer(object):
...
@@ -367,10 +387,11 @@ class Writer(object):
'feed'
,
'feed'
,
([
'feed'
],
'X'
),
([
'feed'
],
'X'
),
([
var_name
],
'Out'
),
([
var_name
],
'Out'
),
dict
(
col
=
idx
)
,
{
'col'
:
idx
}
,
)
)
prog
.
VarDesc
(
prog
.
VarDesc
(
var_name
,
var_name
,
value_info
=
value_info
,
remove_batch
=
remove_batch
)
value_info
=
value_info
,
remove_batch
=
remove_batch
)
@
staticmethod
@
staticmethod
def
emit_outputs
(
prog
,
names
):
#, value_infos
def
emit_outputs
(
prog
,
names
):
#, value_infos
...
@@ -387,7 +408,7 @@ class Writer(object):
...
@@ -387,7 +408,7 @@ class Writer(object):
'fetch'
,
'fetch'
,
([
var_name
],
'X'
),
([
var_name
],
'X'
),
([
'fetch'
],
'Out'
),
([
'fetch'
],
'Out'
),
dict
(
col
=
idx
)
,
{
'col'
:
idx
}
,
)
)
# var is emitted over ops
# var is emitted over ops
prog
.
Code
(
code
)
prog
.
Code
(
code
)
...
@@ -398,7 +419,7 @@ class Writer(object):
...
@@ -398,7 +419,7 @@ class Writer(object):
flatten codes in program
flatten codes in program
"""
"""
for
code
in
_
flatten_list
(
others
):
for
code
in
flatten_list
(
others
):
codes
.
append
(
Writer
.
CODE_INDENT
*
indent
+
code
)
codes
.
append
(
Writer
.
CODE_INDENT
*
indent
+
code
)
return
codes
return
codes
...
@@ -451,7 +472,7 @@ class Writer(object):
...
@@ -451,7 +472,7 @@ class Writer(object):
Writer
.
add_codes
(
codes
,
body_code
,
1
)
Writer
.
add_codes
(
codes
,
body_code
,
1
)
fp
=
open
(
filename
,
'w'
)
fp
=
open
(
filename
,
'w'
)
for
code
in
_
flatten_list
(
codes
):
for
code
in
flatten_list
(
codes
):
fp
.
write
(
code
)
fp
.
write
(
code
)
fp
.
write
(
'
\n
'
)
fp
.
write
(
'
\n
'
)
fp
.
close
()
fp
.
close
()
...
...
onnx2fluid/setup.cfg
浏览文件 @
7c3e9379
...
@@ -54,6 +54,8 @@ zip_safe = True
...
@@ -54,6 +54,8 @@ zip_safe = True
[options.entry_points]
[options.entry_points]
console_scripts =
console_scripts =
onnx2fluid = onnx2fluid.__main__
onnx2fluid = onnx2fluid.__main__
onnx2fluid_convert = onnx2fluid.conversion
onnx2fluid_validate = onnx2fluid.validation
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 仅支持文件,不支持目录,但可以使用通配
# 仅支持文件,不支持目录,但可以使用通配
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录