PaddlePaddle / X2Paddle

Commit 826481c4: add type shape inference
Author: Macrobull
Committed: Apr 29, 2019
Parent: 7c3e9379

Showing 12 changed files with 401 additions and 315 deletions (+401, -315)
Changed files:

  onnx2fluid/examples/convert_data_npz.py       +3    -3
  onnx2fluid/examples/convert_data_pb.py        +3    -3
  onnx2fluid/examples/onnx_model_zoo.sh         +1    -1
  onnx2fluid/onnx2fluid/__main__.py             +8    -0
  onnx2fluid/onnx2fluid/cmdline.py              +13   -4
  onnx2fluid/onnx2fluid/conversion.py           +17   -14
  onnx2fluid/onnx2fluid/onnx_utils.py           +35   -35
  onnx2fluid/onnx2fluid/symbolic.py             +176  -163
  onnx2fluid/onnx2fluid/torch_export_helper.py  +7    -6
  onnx2fluid/onnx2fluid/validation.py           +77   -28
  onnx2fluid/onnx2fluid/writer.py               +59   -56
  onnx2fluid/setup.cfg                          +2    -2
onnx2fluid/examples/convert_data_npz.py

@@ -21,7 +21,7 @@ def make_var_name(name):
         return '_'
     if name[0].isdigit():
         return 'var_' + name
-    for s in ' \\|/:':  #
+    for s in ' \\|/:-':  #
         name = name.replace(s, '_')
     if name.startswith('_'):
         name = 'var' + name
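The only change in the two example converters widens the character set that make_var_name sanitizes, so dashes in ONNX tensor names no longer leak into generated variable names. A minimal sketch of the effect, using a hypothetical tensor name:

    name = 'conv-1/weights:0'    # hypothetical ONNX tensor name
    for s in ' \\|/:-':          # '-' is newly replaced by this commit
        name = name.replace(s, '_')
    print(name)                  # conv_1_weights_0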
onnx2fluid/examples/convert_data_pb.py

@@ -24,7 +24,7 @@ def make_var_name(name):
         return '_'
     if name[0].isdigit():
         return 'var_' + name
-    for s in ' \\|/:':  #
+    for s in ' \\|/:-':  #
         name = name.replace(s, '_')
     if name.startswith('_'):
         name = 'var' + name
onnx2fluid/examples/onnx_model_zoo.sh

@@ -311,7 +311,7 @@ resnet100_arcface()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid -o /tmp/export/ "$fn_model" -y
+    python -m onnx2fluid $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar"/*/
     do
         echo "converting $pb_dir ..."
onnx2fluid/onnx2fluid/__main__.py

@@ -95,6 +95,14 @@ parser.add_argument(
     default=1e-2,
     help='assertion relative tolerance for validation',
 )
+parser.add_argument(
+    '--infer_inputs',
+    '-i',
+    nargs='?',
+    default=None,
+    const='',
+    help='perform type-shape inference with given input names and re-save model',
+)
 args = parser.parse_args()

 logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
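Because the new flag is declared with nargs='?' and const='', it can be passed bare (enable inference, take input names from the golden data) or with an explicit comma-separated name list. A small, self-contained sketch of that parsing behaviour (not part of the commit):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--infer_inputs', '-i', nargs='?', default=None, const='')

    print(parser.parse_args([]).infer_inputs)             # None  (flag omitted)
    print(parser.parse_args(['-i']).infer_inputs)         # ''    (bare flag)
    print(parser.parse_args(['-i', 'x,y']).infer_inputs)  # 'x,y' (explicit names)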
onnx2fluid/onnx2fluid/cmdline.py

@@ -60,19 +60,28 @@ def main(**kwargs):
     # validate
     passed = True
     golden_data_filename = kwargs.pop('test_data', '')
-    if golden_data_filename:
+    infer_inputs = kwargs.pop('infer_inputs', None)
+    if golden_data_filename or infer_inputs:
         from .validation import validate

+        save_inference_model = infer_inputs is not None
+        inference_input_names = infer_inputs.split(',') if infer_inputs else None
+
         logger.info('starting validation on desc ...')
         passed &= validate(
             shutil.os.path.join(save_dir, '__model__'),
-            golden_data_filename,
-            **kwargs)
+            golden_data_filename=golden_data_filename,
+            save_inference_model=save_inference_model,
+            inference_input_names=inference_input_names,
+            **kwargs)

         logger.info('starting validation on code ...')
         # this re-generate desc proto with Python code when debug on
         passed &= validate(
             shutil.os.path.join(save_dir, model_basename),
-            golden_data_filename,
+            golden_data_filename=golden_data_filename,
             model_func_name=model_func_name,
-            save_inference_model=debug,
+            save_inference_model=save_inference_model,
+            inference_input_names=inference_input_names,
             **kwargs)

     if not passed:
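In short, the mere presence of the option (even bare) switches on save_inference_model, and a non-empty value becomes the list of input names. A hedged restatement of that mapping with a hypothetical flag value:

    infer_inputs = 'data_0,label_0'   # hypothetical --infer_inputs value
    save_inference_model = infer_inputs is not None
    inference_input_names = infer_inputs.split(',') if infer_inputs else None
    print(save_inference_model, inference_input_names)
    # True ['data_0', 'label_0']   (a bare -i gives '', hence True and None)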
onnx2fluid/onnx2fluid/conversion.py

@@ -141,23 +141,22 @@ def convert(onnx_model_filename,
     logger.info('%d ops in, %d ops out', len(onnx_graph.node),
                 len(fluid_program.op_descs))

-    # shape-inference
+    # type-shape inference
     for name, value_info in graph_value_infos.items():
         var_name = make_var_name(name)
-        fluid_program.VarTypeInfo(var_name, value_info,
-                                  remove_batch=False)  # shape-infer only
+        fluid_program.VarTypeShapeInfo(var_name, value_info,
+                                       remove_batch=False)  # shape-infer only
     bad_var_names = []
     for var_name, var_desc in fluid_program.var_descs.items():
         if not var_desc.type.lod_tensor.HasField('tensor'):
             bad_var_names.append(var_name)
     if len(bad_var_names) > 0:
-        logger.warning('type info not infered for var %s ...',
+        logger.warning('type-shape not infered for var %s ...',
                        ', '.join(bad_var_names[:5]))
         logger.warning('this causes little problem for PaddlePaddle, '
                        'but Paddle Mobile may not infer correctly')
         logger.warning(
-            'please consider adding option -d to invoke PaddlePaddle shape-inference')
+            'please consider running onnx2fluid.validation with -i '
+            'to invoke PaddlePaddle type-shape inference')

     # weight writer
     for name, weight in graph_weights(onnx_graph):

@@ -233,13 +232,9 @@ def convert(onnx_model_filename,
     logger.info('conversion finished')


-if __name__ == '__main__':
-    del convert
+def main():
     import argparse

-    from onnx2fluid.conversion import convert
-
     parser = argparse.ArgumentParser(
         description='onnx2fluid.convert',
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,

@@ -310,3 +305,11 @@ if __name__ == '__main__':
         onnx_opset_pedantic=pedantic,
         onnx_skip_version_conversion=skip_version_conversion,
         debug=debug)
+
+
+if __name__ == '__main__':
+    del convert
+
+    from onnx2fluid.conversion import convert
+
+    main()
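The last warning now points users at the validation entry point instead of the old -d suggestion. As a hedged illustration (model path and input names below are hypothetical), the programmatic equivalent of running onnx2fluid.validation with -i would look roughly like:

    from onnx2fluid.validation import validate

    validate('/tmp/export/model/__model__',       # hypothetical model path
             save_inference_model=True,           # what -i switches on
             inference_input_names=['data_0'])    # hypothetical input name(s)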
onnx2fluid/onnx2fluid/onnx_utils.py

@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None):
         if name == '':
             name = 'op_' + str(index)
         else:  # make_op_name
-            for s in ' \\|/:':  #
+            for s in ' \\|/:-':  #
                 name = name.replace(s, '_')
         if domain == '':
             domain = DEFAULT_OP_DOMAIN
onnx2fluid/onnx2fluid/symbolic.py

@@ -153,7 +153,7 @@ def _make_var_name(name):
         return '_'
     if name[0].isdigit():
         return 'var_' + name
-    for s in ' \\|/:':  #
+    for s in ' \\|/:-':  #
         name = name.replace(s, '_')
     if name.startswith('_'):
         name = 'var' + name

@@ -191,14 +191,24 @@ def _const_weight_or_none(value_infos, val_name):
         return None
     value_info = value_infos[val_name]
     const_value = value_info.get('const_value', None)
-    if const_value:
+    if const_value is not None:
         return const_value
     get_weight_func = value_info.get('get_weight', None)
-    if get_weight_func:
+    if get_weight_func is not None:
         return get_weight_func()
     return None


+def _check_embeddable(value_infos, *val_names):
+    keyword = 'get_weight'
+    for val_name in val_names:
+        if keyword not in value_infos[val_name]:
+            _logger.warning('parameter %s not embeddable for some ops',
+                            val_name)
+            return False
+    return True
+
+
 def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
     info = DEFAULT_OP_MAPPING[op_type]
     info.extend(DEFAULT_OP_MAPPING_VALUES[len(info):])

@@ -391,9 +401,9 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
     input_shape = _shape_or_none(value_infos, val_x)
     output_shape = _shape_or_none(value_infos, val_y)
     assert input_shape is not None or output_shape is not None, 'poolnd not inferred'  # NC...
-    if input_shape:
+    if input_shape is not None:
         poolnd = len(input_shape) - 2  # NC...
-    elif output_shape:
+    elif output_shape is not None:
         poolnd = len(output_shape) - 2  # NC...
     assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'

@@ -568,7 +578,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
             1] == 1, 'only scale on (NC)HW supported'
         assert scales[2] == scales[
             3], 'only aspect-ratio-invariant scale supported'
-    scale = scales[2] if scales else None
+    scale = None if scales is None else scales[2]
     # try input shape
     if scale is None:
         assert out_shape_, 'neither scales nor output shape is available'

@@ -704,16 +714,19 @@ def BatchNormalization(prog,
     momentum = attrs.get('momentum', .9)  # optional
     epsilon = attrs.get('epsilon', 1e-5)  # optional
     name_attr = ', name={}'.format(repr(name)) if name else ''
+    if embed_params:
+        embed_params = _check_embeddable(value_infos, val_scale, val_b,
+                                         val_mean, val_var)
     if embed_params:
         assert name != ''
         var_scale = name + '.w_0'
         var_b = name + '.b_0'
         var_mean = name + '.w_1'
         var_var = name + '.w_2'
-        value_infos[val_scale]['embeded_as'].append(var_scale)
-        value_infos[val_b]['embeded_as'].append(var_b)
-        value_infos[val_mean]['embeded_as'].append(var_mean)
-        value_infos[val_var]['embeded_as'].append(var_var)
+        value_infos[val_scale].setdefault('embeded_as', []).append(var_scale)
+        value_infos[val_b].setdefault('embeded_as', []).append(var_b)
+        value_infos[val_mean].setdefault('embeded_as', []).append(var_mean)
+        value_infos[val_var].setdefault('embeded_as', []).append(var_var)
         param_attr = ''
     else:
         var_scale = _make_var_name(val_scale)

@@ -774,7 +787,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     if not isinstance(dtype, _np.dtype):  # additional: possible np.dtype
         dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
     output_dtype = _dtype_or_none(value_infos, val_output)
-    if output_dtype:
+    if output_dtype is not None:
         assert dtype == output_dtype, 'dtype of to unmatches output'

     fluid_op = 'cast'

@@ -843,7 +856,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     """

     # I/O
-    assert len(inputs) == 0
+    assert len(inputs) == 0, 'constant op accept no inputs'
     val_output, = outputs
     var_output = _make_var_name(val_output)

@@ -851,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     value = attrs['value']  # required
     dtype = _np.dtype(value.dtype)
     output_dtype = _dtype_or_none(value_infos, val_output)
-    if output_dtype:
+    if output_dtype is not None:
         assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'

@@ -970,13 +983,16 @@ def Conv(prog,
     paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
     var_x = _make_var_name(val_x)

     name_attr = ', name={}'.format(repr(name)) if name else ''
+    if embed_params:
+        embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
+                        or _check_embeddable(value_infos, val_b))
     if embed_params:
         assert name != ''
         var_w = name + '.w_0'
-        value_infos[val_w]['embeded_as'].append(var_w)
+        value_infos[val_w].setdefault('embeded_as', []).append(var_w)
         if has_bias:
             var_b = name + '.b_0'
-            value_infos[val_b]['embeded_as'].append(var_b)
+            value_infos[val_b].setdefault('embeded_as', []).append(var_b)
         param_attr = ''
     else:
         param_attr = ', bias_attr=False'

@@ -1080,13 +1096,16 @@ def ConvTranspose(prog,
     paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
     var_x = _make_var_name(val_x)

     name_attr = ', name={}'.format(repr(name)) if name else ''
+    if embed_params:
+        embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
+                        or _check_embeddable(value_infos, val_b))
     if embed_params:
         assert name != ''
         var_w = name + '.w_0'
-        value_infos[val_w]['embeded_as'].append(var_w)
+        value_infos[val_w].setdefault('embeded_as', []).append(var_w)
         if has_bias:
             var_b = name + '.b_0'
-            value_infos[val_b]['embeded_as'].append(var_b)
+            value_infos[val_b].setdefault('embeded_as', []).append(var_b)
         param_attr = ''
     else:
         param_attr = ', bias_attr=False'

@@ -1330,9 +1349,9 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
     assume_pad2d = False
     if len(pads) == 4:
         assume_pad2d |= mode != 'constant'
-        if data_shape:
+        if data_shape is not None:
             assume_pad2d |= data_shape and len(data_shape) == 4  # NCHW
-        if output_shape:
+        if output_shape is not None:
             assume_pad2d |= output_shape and len(output_shape) == 4  # NCHW
     od_attrs = {'pad_value': value}
     if assume_pad2d:

@@ -1404,10 +1423,12 @@ def PRelu(prog,
         mode = 'element'
     fluid_op = 'prelu'
     name_attr = ', name={}'.format(repr(name)) if name else ''
+    if embed_params:
+        embed_params = _check_embeddable(value_infos, val_slope)
     if embed_params:
         assert name != ''
         var_slope = name + '.w_0'
-        value_infos[val_slope]['embeded_as'].append(var_slope)
+        value_infos[val_slope].setdefault('embeded_as', []).append(var_slope)
         param_attr = ''
     else:
         var_slope = _make_var_name(val_slope)

@@ -1474,6 +1495,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     name_attr = ', name={}'.format(repr(name)) if name else ''

     # generation
+    val_shape_int32 = val_shape + '_int32'  # explicit variable
+    var_shape_int32 = _make_var_name(val_shape_int32)
     prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape))
     if is_const_shape:
         prog.Code('{} = layers.{}({}'

@@ -1487,8 +1510,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
                   name_attr,
               ))
     else:
-        val_shape_int32 = val_shape + '_int32'  # explicit variable
-        var_shape_int32 = _make_var_name(val_shape_int32)
         prog.Op(
             '',
             'Cast',

@@ -1514,14 +1535,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         var_xshape = name + '.xshape'  # dummy output
         prog.VarDesc(var_reshaped)
         prog.VarDesc(var_xshape)
-        if is_const_shape:
-            prog.OpDesc(
-                fluid_op,
-                ([var_data], 'X'),
-                ([var_reshaped, var_xshape], 'Out', 'XShape'),
-                {'shape': shape},
-            )
-        else:
-            prog.OpDesc(
-                fluid_op,
-                ([var_data, var_shape_int32], 'X', 'Shape'),
+        prog.OpDesc(
+            fluid_op,
+            ([var_data, var_shape_int32], 'X', 'Shape'),

@@ -1595,7 +1608,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     starts = attrs['starts']  # required
     ends = attrs['ends']  # required
     shape = _shape_or_none(value_infos, val_data)
-    if shape:
+    if shape is not None:
         #        ndims = len(shape)
         #        for idx, value in enumerate(axes):
         #            if value > ONNX_INT_MAX // 2:
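The recurring pattern in BatchNormalization, Conv, ConvTranspose and PRelu above is the new guard that downgrades embed_params when a referenced parameter has no 'get_weight' entry in value_infos. A self-contained sketch of the idea; the value_infos contents are hypothetical and the helper mirrors _check_embeddable without the logging:

    value_infos = {
        'conv1_W': {'get_weight': lambda: None},   # embeddable
        'conv1_B': {},                             # not embeddable
    }

    def check_embeddable(value_infos, *val_names):
        return all('get_weight' in value_infos[n] for n in val_names)

    embed_params = True
    if embed_params:
        embed_params = check_embeddable(value_infos, 'conv1_W', 'conv1_B')
    print(embed_params)   # False, so the op falls back to explicit var names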
onnx2fluid/onnx2fluid/torch_export_helper.py

@@ -25,7 +25,8 @@ def ensure_tuple(obj):
 def flatten_list(obj, out=None):
-    assert isinstance(obj, list)
+    assert isinstance(obj, list), 'list type required'
+
     if out is None:
         out = type(obj)()
     for item in obj:

@@ -42,7 +43,7 @@ def export_data(state_dict, prefix=''):
     """

     def str_(obj):
-        if isinstance(obj, (tuple, list)):
+        if isinstance(obj, (tuple, list, set)):
             return str(obj)[1:-1].replace(' ', '')
         return str(obj)
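str_() now also flattens sets when formatting exported tensor metadata. A tiny hedged demo of the difference, with made-up values:

    def str_(obj):
        if isinstance(obj, (tuple, list, set)):
            return str(obj)[1:-1].replace(' ', '')
        return str(obj)

    print(str_((2, 3)))    # 2,3
    print(str_({2, 3}))    # 2,3  (with tuple/list only, this stayed "{2, 3}")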
onnx2fluid/onnx2fluid/validation.py

@@ -10,14 +10,15 @@ import importlib, logging, os, sys
 def flatten_dict(obj, out=None):
-    assert isinstance(obj, dict)
+    assert isinstance(obj, dict), 'dict type required'
+
     if out is None:
         out = type(obj)()
     for key, value in obj.items():
         if isinstance(value, dict):
             flatten_dict(value, out)
         else:
-            assert key not in out
+            assert key not in out, 'key conflicted'
             out[key] = value
     return out

@@ -29,11 +30,12 @@ def ensure_list(obj):
 def validate(fluid_model_filename,
-             golden_data_filename,
-             model_func_name='inference',
+             golden_data_filename='',
              atol=1e-3,
              rtol=1e-3,
+             model_func_name='inference',
              save_inference_model=False,
+             inference_input_names=None,
              **kwargs):
     """
     inference the converted Paddle fluid model, validate with given golden data

@@ -86,6 +88,7 @@ def validate(fluid_model_filename,
         raise ValueError('unsupported Paddle fluid model filename')

     # load data
+    if golden_data_filename:
         logger.info('using golden data %s', golden_data_filename)
         if golden_data_filename.endswith('.npz'):
             test_data = np.load(golden_data_filename, encoding='bytes')

@@ -97,13 +100,38 @@ def validate(fluid_model_filename,
         output_data = test_data['outputs']
         input_data = flatten_dict(input_data)
         output_data = flatten_dict(output_data)
+        input_names = input_data.keys()
         logger.info('found %d I/O golden data, starting test ...',
                     len(input_data) + len(output_data))
+    else:
+        assert inference_input_names, 'input names required for type-shape inference'
+
+        input_names = inference_input_names
+        logger.info('using input names: %s', ', '.join(input_names))
+
+    # type-shape inference and re-save
+    if save_inference_model:
+        for block in prog.blocks:
+            block_desc = block.desc
+            for idx_op in range(block_desc.op_size()):
+                op_desc = block_desc.op(idx_op)
+                if op_desc.type() in ('feed', 'fetch'):
+                    continue
+                op_desc.infer_var_type(block_desc)
+                op_desc.infer_shape(block_desc)
+            for var_name, var in block.vars.items():
+                var_desc = var.desc
+                if var_desc.type() != fluid.core.VarDesc.VarType.LOD_TENSOR:
+                    continue
+                # WORKAROUND: dirty way to give dtype to partial-infered vars
+                # which could not be cleared!
+                try:
+                    var.to_string(True)
+                except ValueError:
+                    var_desc.set_dtype(fluid.core.VarDesc.VarType.FP32)
+
     # DEBUG: reload test for Python code
     if basename.endswith('.py') and save_inference_model:
         fluid.io.save_inference_model(fluid_model_dir,
-                                      input_data.keys(),
+                                      input_names,
                                       var_outs,
                                       exe,
                                       main_program=prog,

@@ -112,8 +140,12 @@ def validate(fluid_model_filename,
     fluid.io.load_inference_model(fluid_model_dir, exe)
     logger.info('model re-load passed')

+    if not golden_data_filename:
+        return True
+
     # execute
-    outputs = exe.run(prog, feed=input_data, fetch_list=out_names)
+    outputs = exe.run(prog, feed=input_data, fetch_list=out_names)  # out_names can be vars
     logger.info('execution passed')

     # validate

@@ -134,11 +166,10 @@ def validate(fluid_model_filename,
         logger.info('accuracy passed')
     else:
         logger.info('accuracy not passed')

     return passed


-if __name__ == '__main__':
+def main():
     import argparse

     parser = argparse.ArgumentParser(

@@ -160,6 +191,7 @@ if __name__ == '__main__':
         '--test_data',
         '-t',
         type=str,
+        default='',
         help='I/O golden data for validation, e.g. test.npy, test.npz',
     )
     parser.add_argument(

@@ -175,19 +207,36 @@ if __name__ == '__main__':
         default=1e-2,
         help='assertion relative tolerance for validation',
     )
+    parser.add_argument(
+        '--infer_inputs',
+        '-i',
+        nargs='?',
+        default=None,
+        const='',
+        help='perform type-shape inference with given input names and re-save model',
+    )
     args = parser.parse_args()

     logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
     logging_level = logging.DEBUG if args.debug else logging.INFO
     logging.basicConfig(format=logging_format, level=logging_level)

-    debug = args.debug
+    #    debug = args.debug
     fluid_model_filename = args.model[0]
     golden_data_filename = args.test_data
     atol, rtol = args.atol, args.rtol
+    save_inference_model = args.infer_inputs is not None
+    inference_input_names = args.infer_inputs.split(
+        ',') if args.infer_inputs else None

     validate(fluid_model_filename,
-             golden_data_filename,
+             golden_data_filename=golden_data_filename,
              atol=atol,
              rtol=rtol,
-             save_inference_model=debug)
+             save_inference_model=save_inference_model,
+             inference_input_names=inference_input_names)
+
+
+if __name__ == '__main__':
+    main()
onnx2fluid/onnx2fluid/writer.py

@@ -44,6 +44,8 @@ def irepr(obj, to='_'):
 def flatten_list(obj, out=None):
+    assert isinstance(obj, list), 'list type required'
+
     if out is None:
         out = type(obj)()
     for item in obj:

@@ -59,9 +61,9 @@ def make_attr_name(name):
     make a valid code name for ParamAttr
     """

-    if name == '':
-        raise ValueError('name should not be empty')
-    for s in ' \\|/:':  #
+    assert name != '', 'name should not be empty'
+
+    for s in ' \\|/:-':  #
         name = name.replace(s, '_')
     if not name.startswith('_'):
         name = '_' + name

@@ -195,11 +197,11 @@ class Program(object):
         desc = framework_pb2.OpDesc()
         desc.type = op_type
-        if input_val_keys is not None:
+        if input_val_keys:
             desc.inputs.extend(self.OpDescVars(*input_val_keys))
-        if output_val_keys is not None:
+        if output_val_keys:
             desc.outputs.extend(self.OpDescVars(*output_val_keys))
-        if attrs is not None:
+        if attrs:
             desc.attrs.extend(self.OpDescAttrs(attrs))
         self.op_descs.append(desc)
         return desc

@@ -220,8 +222,11 @@ class Program(object):
         var_desc.persistable = persistable
         var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
         self.var_descs[var_name] = var_desc
+
         if value_info:
-            self.VarTypeInfo(var_name, value_info, remove_batch=remove_batch)
+            self.VarTypeShapeInfo(var_name, value_info,
+                                  remove_batch=remove_batch)

     def Op(self, domain, op_type, *args, **kwargs):
         """

@@ -255,7 +260,7 @@ class Program(object):
         else:
             self.code_mutable = code_mutable

-    def VarTypeInfo(self, var_name, value_info, remove_batch=None):
+    def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None):
         """
         set value_info for var
         """

@@ -429,8 +434,7 @@ class Writer(object):
         write single weight in fluid desc
         """

-        if not isinstance(weight, np.ndarray):
-            raise TypeError('weight is not an ndarray')
+        assert isinstance(weight, np.ndarray), 'weight is not an ndarray'
         tensor_desc = framework_pb2.VarType.TensorDesc()
         tensor_desc.data_type = Program.Dtype(weight.dtype)

@@ -452,8 +456,7 @@ class Writer(object):
         """

         for name, weight in weights.items():
-            if not isinstance(weights, dict):
-                raise TypeError('dict type weights required')
+            assert isinstance(weights, dict), 'dict type weights required'
             var_name = make_var_name(name)
             filename = os.path.join(save_dir, var_name)
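writer.py also loosens the guards in Op(): input/output/attribute groups are now skipped when they are empty, not only when they are None. A hedged toy illustration of that truthiness check (the collect() helper below is a stand-in, not the real OpDesc code):

    def collect(keys=None):
        out = []
        if keys:                # truthiness: skips None and empty sequences alike
            out.extend(keys)
        return out

    print(collect(None))        # []
    print(collect(()))          # []  (an empty tuple is now skipped up front)
    print(collect(('X', 'Y')))  # ['X', 'Y']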
onnx2fluid/setup.cfg

@@ -54,8 +54,8 @@ zip_safe = True
 [options.entry_points]
 console_scripts =
     onnx2fluid = onnx2fluid.__main__
-    onnx2fluid_convert = onnx2fluid.conversion
-    onnx2fluid_validate = onnx2fluid.validation
+    onnx2fluid_convert = onnx2fluid.conversion:main
+    onnx2fluid_validate = onnx2fluid.validation:main

 # 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
 # 仅支持文件,不支持目录,但可以使用通配
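With the conversion.py and validation.py refactors above, these two console scripts now resolve to real entry functions. A hedged sketch of the programmatic equivalents; both main() functions parse sys.argv just like the installed commands:

    from onnx2fluid.conversion import main as convert_main
    from onnx2fluid.validation import main as validate_main

    # calling convert_main() / validate_main() behaves like running the
    # onnx2fluid_convert / onnx2fluid_validate commands installed by setup.cfg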