PaddlePaddle / X2Paddle
Commit ba40d265: optimize symbolic
Authored Mar 29, 2019 by Macrobull
Parent: 52a502df
Showing 7 changed files with 177 additions and 144 deletions (+177, -144)
onnx2fluid/onnx2fluid/__main__.py      +16   -0
onnx2fluid/onnx2fluid/cmdline.py       +24   -12
onnx2fluid/onnx2fluid/conversion.py    +23   -13
onnx2fluid/onnx2fluid/onnx_utils.py    +21   -14
onnx2fluid/onnx2fluid/symbolic.py      +81   -93
onnx2fluid/onnx2fluid/validation.py    +6    -5
onnx2fluid/onnx2fluid/writer.py        +6    -7
onnx2fluid/onnx2fluid/__main__.py

@@ -69,6 +69,22 @@ parser.add_argument(
     dest='pedantic',
     help='process non-standard ONNX ops, this may lead to fails',
 )
+parser.add_argument(
+    '--skip-version-conversion',
+    '-y',
+    action='store_true',
+    default=False,
+    help='skip ONNX op version conversion, workaround for RumtimeErrors',
+)
+parser.add_argument(
+    '--archive',
+    '-z',
+    nargs='?',
+    type=str,
+    default=None,
+    const='',
+    help='compress outputs to ZIP file if conversion successed',
+)
 parser.add_argument(
     '--precision',
     '-p',
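The two new options follow stock argparse semantics. A minimal sketch of how they parse, using a stand-in parser (not the module's own entry point):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--skip-version-conversion', '-y', action='store_true',
                        default=False)
    parser.add_argument('--archive', '-z', nargs='?', type=str, default=None,
                        const='')

    # omitted        -> archive is None: no ZIP is written
    # bare '-z'      -> archive == '' (the const): auto-named ZIP next to the outputs
    # '-z model.zip' -> archive == 'model.zip'
    print(parser.parse_args(['-y', '-z']))
    print(parser.parse_args(['-z', 'model.zip']))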
onnx2fluid/onnx2fluid/cmdline.py

@@ -16,10 +16,10 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
-#import logging, shutil, zipfile
-import logging
-import shutil
-import zipfile
+import logging, shutil, zipfile
+#import logging
+#import shutil
+#import zipfile
 
 __all__ = [
     'main',

@@ -49,12 +49,14 @@ def main(**kwargs):
     basepath, _ = shutil.os.path.splitext(filename)
     save_dir = kwargs.get('output_dir', '')
     # model.onnx -> model/
-    save_dir = shutil.os.path.dirname(save_dir) if save_dir else basepath
+    save_dir = (save_dir.rstrip('/') if save_dir else basepath) + '/'
     model_basename = DEFAULT_MODEL_MODULE + '.py'
     model_func_name = DEFAULT_MODEL_FUNC
     embed_params = kwargs.get('embed_params', False)
     onnx_opset_version = DEFAULT_ONNX_OPSET_VERSION
     onnx_opset_pedantic = kwargs.get('pedantic', True)
+    onnx_skip_version_conversion = kwargs.get('skip_version_conversion', False)
+    archive = kwargs.get('archive', None)
 
     # convert
     convert(

@@ -65,6 +67,7 @@ def main(**kwargs):
         embed_params=embed_params,
         onnx_opset_version=onnx_opset_version,
         onnx_opset_pedantic=onnx_opset_pedantic,
+        onnx_skip_version_conversion=onnx_skip_version_conversion,
         debug=debug)
 
     # validate

@@ -104,13 +107,21 @@ def main(**kwargs):
         return
 
     # create zip file
-    fn_zip = save_dir.rstrip('/') + '.zip'
-    logger.info('compressing file to %s ...', fn_zip)
-    fz = zipfile.ZipFile(fn_zip, 'w', compression=zipfile.ZIP_LZMA)
-    for fn in shutil.os.listdir(save_dir):
-        fz.write(shutil.os.path.join(save_dir, fn), arcname=fn)
-    fz.close()
-    logger.info('compressing done')
+    if archive is not None:
+        if archive == '':
+            archive = save_dir.rstrip('/') + '.zip'
+        logger.info('compressing file to %s ...', archive)
+        shutil.sys.stderr.write('\n')
+        shutil.sys.stderr.flush()
+        file_list = shutil.os.listdir(save_dir)
+        fz = zipfile.ZipFile(archive, 'w', compression=zipfile.ZIP_LZMA)
+        for idx, fn in enumerate(file_list):
+            shutil.sys.stderr.write('\033[F\033[2K')
+            logger.info('file {}/{}: {}'.format(idx + 1, len(file_list), fn))
+            shutil.sys.stderr.flush()
+            fz.write(shutil.os.path.join(save_dir, fn), arcname=fn)
+        fz.close()
+        logger.info('compressing done')
 
 
 if __name__ == '__main__':

@@ -132,5 +143,6 @@ if __name__ == '__main__':
         output_dir='/tmp/export/',
         embed_params=True,
         pedantic=False,
+        skip_version_conversion=False,
         test_data='../examples/inception_v2/test_data_set_2.npz',
         debug=True)
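The rewritten zip loop prints per-file progress in place: `\033[F` moves the cursor up one line and `\033[2K` erases it, so each log line overwrites the previous file name instead of scrolling, and the initial `'\n'` seeds the line the first cursor-up lands on. The same trick in isolation (illustrative file list, plain stderr instead of the logger):

    import sys, time

    file_list = ['model.py', 'weights.bin', 'program.desc']   # illustrative
    sys.stderr.write('\n')
    sys.stderr.flush()
    for idx, fn in enumerate(file_list):
        sys.stderr.write('\033[F\033[2K')   # cursor up one line, then clear it
        sys.stderr.write('file {}/{}: {}\n'.format(idx + 1, len(file_list), fn))
        sys.stderr.flush()
        time.sleep(0.2)                     # stand-in for compressing the file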
onnx2fluid/onnx2fluid/conversion.py

@@ -8,9 +8,9 @@ Created on Mon Feb 25 09:50:35 2019
 from __future__ import division
 
-#import logging, shutil
-import logging
-import shutil
+import logging, shutil
+#import logging
+#import shutil
 
 __all__ = [
     'convert',

@@ -24,6 +24,7 @@ def convert(onnx_model_filename,
             embed_params=False,
             onnx_opset_version=9,
             onnx_opset_pedantic=True,
+            onnx_skip_version_conversion=False,
             debug=False):
     """
     convert an ONNX model to Paddle fluid Python code and desc pb

@@ -60,12 +61,13 @@ def convert(onnx_model_filename,
     try:
         logger.info('checking model ...')
         check_model(onnx_model)
-        logger.debug('using opset version: %d', onnx_opset_version)
-        if onnx_opset_pedantic:  # WORKAROUND: RuntimeError: No Adapter For OP
-            onnx_model = convert_version(onnx_model, onnx_opset_version)
-        else:  # TODO: add new argument for this option
-            logger.warning(
-                'opset conversion skipped for onnx_opset_pedantic is OFF')
+        if onnx_skip_version_conversion:  # WORKAROUND: RuntimeError: No Adapter For OP
+            logger.debug('assumed opset version: %d', onnx_opset_version)
+            logger.warning(
+                'opset conversion skipped for onnx_opset_pedantic is OFF')
+        else:
+            logger.debug('using opset version: %d', onnx_opset_version)
+            onnx_model = convert_version(onnx_model, onnx_opset_version)
         onnx_model = polish_model(onnx_model)
     except ValidationError as e:
         if onnx_opset_pedantic:

@@ -152,16 +154,15 @@ def convert(onnx_model_filename,
                 logger.info(
                     'weight %s is shared between ops, more disk space will be consumed',
                     name)
                 logger.debug(
-                    'saving weight %s(%s[%d], %dB) as %s ...', name,
-                    weight.dtype, weight.size, weight.nbytes, var_names)
+                    'saving weight %s with size of %d, in %d bytes, as %s ...',
+                    name, weight.size, weight.nbytes, var_names)
                 for var_name in var_names:  # multiple references
                     fluid_writer.write_weight(
                         weight, shutil.os.path.join(save_dir, var_name))
             else:
                 logger.debug(
-                    'saving weight %s(%s[%d], %dB) to %s ...', name,
-                    weight.dtype, weight.size, weight.nbytes,
+                    'saving weight %s with size of %d, in %d bytes, to %s ...',
+                    name, weight.size, weight.nbytes,
                     make_var_name(name))
                 fluid_writer.write_weight(
                     weight, shutil.os.path.join(save_dir, make_var_name(name)))
             fluid_writer.emit_param(fluid_program, name, value_info)

@@ -262,6 +263,13 @@ if __name__ == '__main__':
         dest='pedantic',
         help='process non-standard ONNX ops, this may lead to fails',
     )
+    parser.add_argument(
+        '--skip-version-conversion',
+        '-y',
+        action='store_true',
+        default=False,
+        help='skip ONNX op version conversion, workaround for RumtimeErrors',
+    )
     args = parser.parse_args()
 
     logging_format = '[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'

@@ -273,10 +281,12 @@ if __name__ == '__main__':
     save_dir = args.output_dir
     embed_params = args.embed_params
     pedantic = args.pedantic
+    skip_version_conversion = args.skip_version_conversion
     convert(
         model_filename,
         save_dir,
         embed_params=embed_params,
         onnx_opset_pedantic=pedantic,
+        onnx_skip_version_conversion=skip_version_conversion,
         debug=debug)
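`convert_version` here is ONNX's opset version converter; for graphs containing ops it has no adapter for, it raises "RuntimeError: No Adapter For OP", which is the failure the new `onnx_skip_version_conversion` flag routes around by trusting the model's declared opset. A hedged sketch of the same decision in isolation:

    from onnx.version_converter import convert_version

    def maybe_convert(onnx_model, opset_version=9, skip=False):
        # skip=True mirrors --skip-version-conversion: assume the declared
        # opset rather than risking 'RuntimeError: No Adapter For OP'
        if skip:
            return onnx_model
        return convert_version(onnx_model, opset_version)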
onnx2fluid/onnx2fluid/onnx_utils.py

@@ -26,6 +26,7 @@ __all__ = [
     'node_attrs',
     'node_topo',
     'node_iter',
+    'tensor_dtype',
     'tensor_shape',
     'graph_ops',
     'graph_weights',

@@ -92,13 +93,12 @@ def get_attribute_value2(attr):
     return value
 
 
-def node_attrs(node):
+def tensor_dtype(tensor):
     """
-    convert ONNX node attributes to dict
+    get ONNX tensor in np.dtype
     """
 
-    return {attr.name: get_attribute_value2(attr)
-            for attr in node.attribute}  # dict
+    return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]
 
 
 def tensor_shape(tensor):

@@ -109,6 +109,15 @@ def tensor_shape(tensor):
     return [dim.dim_value for dim in tensor.type.tensor_type.shape.dim]
 
 
+def node_attrs(node):
+    """
+    convert ONNX node attributes to dict
+    """
+
+    return {attr.name: get_attribute_value2(attr)
+            for attr in node.attribute}  # dict
+
+
 def node_topo(nodes, topo='default'):
     """
     build indices with given topology to an ONNX node graph

@@ -237,21 +246,21 @@ def inferred_model_value_info(model):
     value_info = Dict()
     for item in graph.value_info:
         value_info[item.name] = dict(
-            dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
+            dtype=tensor_dtype(item),
             shape=tensor_shape(item),
             external=False,
         )
     for item in graph.input:
         assert item.name not in value_info
         value_info[item.name] = dict(
-            dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
+            dtype=tensor_dtype(item),
             shape=tensor_shape(item),
             external=True,
         )
     for item in graph.output:
         # assert item.name not in value_info, 'bypass-model not supported'
         value_info[item.name] = dict(
-            dtype=TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
+            dtype=tensor_dtype(item),
             shape=tensor_shape(item),
             external=True,
         )

@@ -373,9 +382,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
         elif not keep_input_only and name in output_refs:
             ret_initializers.add().CopyFrom(initializer)
         else:
-            logger.debug('initializer %s(%s[%d]) stripped', name,
-                         TENSOR_TYPE_TO_NP_TYPE[initializer.data_type],
-                         len(initializer.raw_data))
+            dtype = TENSOR_TYPE_TO_NP_TYPE[initializer.data_type]
+            logger.debug('initializer %s(%s[%d]) stripped', name, dtype,
+                         len(initializer.raw_data) // dtype.itemsize)
 
     # strip inputs
     ret.graph.ClearField('input')

@@ -385,10 +394,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
         if name in input_refs or name in out_names:
             ret_inputs.add().CopyFrom(item)
         else:
-            logger.debug(
-                'input %s(%s%s) stripped', name,
-                TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
-                tensor_shape(item))
+            logger.debug('input %s(%s%s) stripped', name, tensor_dtype(item),
+                         tensor_shape(item))
     return ret
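The new `tensor_dtype` simply indexes ONNX's `TENSOR_TYPE_TO_NP_TYPE` mapping with the tensor's `elem_type`, pairing with the existing `tensor_shape`; both operate on ValueInfoProto-like messages. A small sketch of the expected input and outputs, on a hand-built value info (ONNX's own helper API):

    from onnx import helper, TensorProto

    item = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 3, 224, 224])
    # tensor_dtype(item)  -> numpy dtype('float32'), via TENSOR_TYPE_TO_NP_TYPE
    # tensor_shape(item)  -> [1, 3, 224, 224]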
onnx2fluid/onnx2fluid/symbolic.py

@@ -19,7 +19,7 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 _logger = _logging.getLogger(__name__)
 
 ONNX_INT_MAX = 2 ** 63 - 1
 FLUID_INT_MAX = 2 ** 31 - 1
-
+#
 
 DEFAULT_ONNX_OP_DOMAIN = ''
 DEFAULT_FLUID_OP_NAMESCOPE = '/'
@@ -186,13 +186,17 @@ def _shape_or_none(value_infos, val_name):
     return list(value_info['shape'])
 
 
-#def _maybe_const_value(value_infos, val_name):
-#    var_name = _make_var_name(val_name)
-#    if val_name not in value_infos:
-#        return var_name
-#    value_info = value_infos[val_name]
-#    assert value_info.get('remove_batch', False) == False, 'const value should not have batch dim'
-#    return value_info.get('const_value', var_name)
+def _const_weight_or_none(value_infos, val_name):
+    if val_name not in value_infos:
+        return None
+    value_info = value_infos[val_name]
+    const_value = value_info.get('const_value', None)
+    if const_value:
+        return const_value
+    get_weight_func = value_info.get('get_weight', None)
+    if get_weight_func:
+        return get_weight_func()
+    return None
 
 
 def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
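The new helper gives callers one lookup for "is this value known at conversion time": either a folded `const_value` recorded in `value_infos`, or a deferred `get_weight` callable that loads the backing weight on demand; anything else yields None. A toy illustration of the entry shapes it expects (keys as in this diff, values invented):

    value_infos = {
        'shape_a': {'const_value': [1, 3, 224, 224]},          # folded constant
        'shape_b': {'get_weight': lambda: [1, 3, 112, 112]},   # lazy weight
    }
    # _const_weight_or_none(value_infos, 'shape_a') -> [1, 3, 224, 224]
    # _const_weight_or_none(value_infos, 'shape_b') -> [1, 3, 112, 112]
    # _const_weight_or_none(value_infos, 'missing') -> None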
@@ -253,7 +257,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
     num_vars = len(var_outs)
     num_args = len(fluid_output_args)
     if num_vars < num_args:
-        assert fill_name_field, 'name required to naming dummy output variable'
+        assert fill_name_field, 'name required to name dummy output variables'
         for idx_out in range(num_vars, num_args):
             var_out = _make_var_name(name + '.' +
                                      fluid_output_args[idx_out].lower())
@@ -294,9 +298,8 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos):  # pads: SSEE
         if pads[idx_dim] != pads[ndims + idx_dim]:
             symmetric = False
             break
 
     if symmetric:
-        return pads[:ndims], None
+        return pads[:ndims], val_name
 
     val_padded = val_name + '_padded'
     prog.Op(
@@ -315,13 +318,7 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos):  # pads: SSEE
     return [0] * ndims, val_padded
 
 
-def _adaptive_pool(prog,
-                   pool_type,
-                   inputs,
-                   outputs,
-                   attrs,
-                   value_infos,
-                   name=''):
+def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
     # I/O
     val_x, = inputs
     val_y, = outputs[:1]
@@ -335,10 +332,6 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
     # interpretation
     pool_size = attrs['output_size']  # required
-    output_shape = _shape_or_none(value_infos, val_y)
-    if output_shape is not None:
-        assert pool_size == output_shape[2:], 'pool_size unmatches shape of Y'  # NC...
-
     poolnd = len(pool_size)
     assert 2 <= poolnd <= 3, 'only pool2d and pool3d supported'
@@ -445,11 +438,9 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
     fluid_op = 'pool{}d'.format(poolnd)
 
     strides = attrs.get('strides', [1] * poolnd)  # optional
-    pads = attrs.get('pads', [0] * len(pool_size * 2))  # optional
-    paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
-    if val_x_padded:
-        val_x = val_x_padded
     ceil_mode = bool(attrs.get('ceil_mode', 0))  # optional
+    pads = attrs.get('pads', [0] * (poolnd * 2))  # optional
+    paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
 
     var_x = _make_var_name(val_x)
     name_attr = ', name={}'.format(repr(name)) if name else ''
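Together with the `_pad_if_asymmetric` change above, this removes the old None-sentinel: the helper now always returns a usable value name (the original name for symmetric pads, the `_padded` alias otherwise), so call sites collapse from a three-line branch to a single unconditional unpack. A toy model of the new contract (the real helper also emits a Pad op via `prog.Op`):

    def pad_if_asymmetric(pads, val_name):
        # sketch only, pads laid out starts-then-ends as in ONNX
        ndims = len(pads) // 2
        if pads[:ndims] == pads[ndims:]:
            return pads[:ndims], val_name              # pass-through
        return [0] * ndims, val_name + '_padded'       # padded alias

    print(pad_if_asymmetric([1, 1, 1, 1], 'x'))   # ([1, 1], 'x')
    print(pad_if_asymmetric([0, 1, 1, 0], 'x'))   # ([0, 0], 'x_padded')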
@@ -506,17 +497,17 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
     spatial_scale = attrs['spatial_scale']  # required
     pooled_height, pooled_width = attrs['pooled_shape']  # required
     od_attrs = dict(
-        spatial_scale=spatial_scale,
         pooled_height=pooled_height,
         pooled_width=pooled_width,
+        spatial_scale=spatial_scale,
     )
     feature_attr = ''
     is_max_pool = fluid_op == 'roi_pool'
-    if 'sampling_ratio' in attrs:
+    if 'sampling_ratio' in attrs:  #
         sampling_ratio = attrs['sampling_ratio']
         od_attrs['sampling_ratio'] = sampling_ratio
         feature_attr += ', sampling_ratio={}'.format(sampling_ratio)
-    if 'output_channels' in attrs:
+    if 'output_channels' in attrs:  #
         output_channels = attrs['output_channels']
         od_attrs['output_channels'] = output_channels
         feature_attr += ', output_channels={}'.format(output_channels)
@@ -560,36 +551,20 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
     )
 
 
-def AdaptiveAveragePool(prog,
-                        inputs,
-                        outputs,
-                        attrs,
-                        value_infos,
-                        name='',
-                        *args,
-                        **kwargs):
+def AdaptiveAveragePool(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     """
     aten::adaptive_avg_poolnd
     """
 
-    return _adaptive_pool(
-        prog, 'avg', inputs, outputs, attrs, value_infos, name=name)
+    return _adaptive_pool(prog, 'avg', inputs, outputs, attrs, name=name)
 
 
-def AdaptiveMaxPool(prog,
-                    inputs,
-                    outputs,
-                    attrs,
-                    value_infos,
-                    name='',
-                    *args,
-                    **kwargs):
+def AdaptiveMaxPool(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     """
     aten::adaptive_max_poolnd
     """
 
-    return _adaptive_pool(
-        prog, 'max', inputs, outputs, attrs, value_infos, name=name)
+    return _adaptive_pool(prog, 'max', inputs, outputs, attrs, name=name)
 
 
 def AveragePool(prog,
@@ -734,9 +709,9 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     var_output = _make_var_name(val_output)
 
     # interpretation
-    dtype = attrs['to']
-    if not isinstance(dtype, np.dtype):
-        dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]  # required
+    dtype = attrs['to']  # required
+    if not isinstance(dtype, np.dtype):  # additional: possible np.dtype
+        dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
     output_dtype = _dtype_or_none(value_infos, val_output)
     if output_dtype:
         assert dtype == output_dtype, 'dtype of to unmatches output'
@@ -818,15 +793,16 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
-    # dtype = np.dtype('float32') # force to float32
+    # dtype = np.dtype('float32') # HINT: force to float32
     shape = attrs.get('shape', None)  # additional, maybe var_name
     if shape is None:
         shape = _shape_or_none(value_infos, val_output)
     if shape is None:
         shape = list(value.shape)
         _logger.warning(
-            'shape of %s not inferred, using value as 1-D tensor may lead to fails',
-            val_output)
+            'in (Constant -> %s): '
+            'shape of %s not inferred, '
+            'using value as 1-D tensor may lead to fails',
+            outputs,
+            val_output)
 
     # generation
     if value.size == 1:  # scalar
@@ -855,18 +831,27 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     """
 
     # I/O
-    val_input, = inputs
+    val_shape, = inputs
     val_output, = outputs
 
-    is_const_shape = 'const_value' in value_infos[val_input]
-    if is_const_shape:
-        shape = _make_var_name(val_input)
-    else:
-        shape = value_infos[val_input]['get_weight']()
+    shape = _const_weight_or_none(value_infos, val_shape)
+    if shape is None:
+        shape = _shape_or_none(value_infos, val_output)
+    assert shape is not None, (
+        'given shape is neither const value nor deductible from output, '
+        'this is not supported')
     dtype = attrs['value'].dtype
     attrs = attrs.copy()
     attrs.update(dict(shape=shape, dtype=dtype))  # pass var_name
 
-    Constant(prog, [], outputs, attrs, value_infos)
+    prog.Op(
+        '',
+        'Constant',
+        [],
+        outputs,  # val
+        attrs,
+        value_infos,
+    )
 
 
 def Conv(prog,
@@ -903,13 +888,11 @@ def Conv(prog,
     num_out_channels = _shape(value_infos, val_w)[0]  # OI...
     fluid_op = 'conv{}d'.format(convnd)
 
+    num_groups = attrs.get('group', 1)  # optional
     strides = attrs.get('strides', [1] * convnd)  # optional
-    pads = attrs.get('pads', [0] * convnd * 2)  # optional
-    paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
-    if val_x_padded:
-        val_x = val_x_padded
     dilations = attrs.get('dilations', [1] * convnd)  # optional
-    num_groups = attrs.get('group', 1)  # optional
+    pads = attrs.get('pads', [0] * (convnd * 2))  # optional
+    paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
 
     var_x = _make_var_name(val_x)
     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
@@ -1014,13 +997,11 @@ def ConvTranspose(prog,
     num_out_channels = _shape(value_infos, val_w)[1]  # IO...
     fluid_op = 'conv{}d_transpose'.format(convnd)
 
+    num_groups = attrs.get('group', 1)  # optional
     strides = attrs.get('strides', [1] * convnd)  # optional
-    pads = attrs.get('pads', [0] * convnd * 2)  # optional
-    paddings, val_x_padded = _pad_if_asymmetric(prog, pads, val_x, value_infos)
-    if val_x_padded:
-        val_x = val_x_padded
     dilations = attrs.get('dilations', [1] * convnd)  # optional
-    num_groups = attrs.get('group', 1)  # optional
+    pads = attrs.get('pads', [0] * (convnd * 2))  # optional
+    paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
 
     var_x = _make_var_name(val_x)
     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
@@ -1090,7 +1071,7 @@ def ConvTranspose(prog,
     prog.VarDesc(var_y)
 
 
-# should not appears
+# should not appear
 #def Dropout(
 #    prog, inputs, outputs, value_infos,
 #    *args, **kwargs):
@@ -1154,10 +1135,16 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     else:
         val_beta = name + '_beta'  # explicit variable
         val_vm = name + '_vm'  # explicit variable
-        vm_dtype = _dtype_or_none(value_infos, val_c)
-        if vm_dtype is None:
-            vm_dtype = np.dtype('float32')
-        beta = np.dtype(vm_dtype).type(beta)
+        if beta.is_integer():
+            vm_dtype = _dtype_or_none(value_infos, val_c)
+            if vm_dtype is None:
+                vm_dtype = np.dtype('float32')
+                _logger.warning(
+                    'in %s(%s -> Gemm -> %s): '
+                    'beta seems to be an interger, '
+                    'however dtype can not be inferred, '
+                    'still use float32', name, inputs, outputs)
+            beta = np.dtype(vm_dtype).type(beta)
         prog.Op(
             '',
             'Constant',
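The reworked branch only synthesizes the beta constant cast when `beta` is a whole number, and warns before falling back to float32 when the dtype of C cannot be inferred. The coercion itself is ordinary NumPy; a stand-alone sketch:

    import numpy as np

    beta = 1.0                       # ONNX attributes arrive as Python floats
    if beta.is_integer():
        vm_dtype = np.dtype('float32')        # fallback when val_c dtype is unknown
        beta = np.dtype(vm_dtype).type(beta)  # -> np.float32(1.0)
    print(repr(beta))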
@@ -1429,13 +1416,15 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     var_reshaped = _make_var_name(val_reshaped)
 
     # interpretation
-    fluid_op = 'reshape'
-    is_const_shape = 'const_value' in value_infos[val_shape]
+    shape = _const_weight_or_none(value_infos, val_shape)
+    is_const_shape = shape and 'const_value' in value_infos[val_shape]
     var_shape = _make_var_name(val_shape)  # for code
-    if is_const_shape:
-        shape = value_infos[val_shape]['const_value']  # for desc
-    else:
-        shape = value_infos[val_shape]['get_weight']()  # for desc
+    if shape is None:
+        shape = _shape_or_none(value_infos, var_reshaped)
+    assert shape is not None, (
+        'given shape is neither const value nor deductible from output, '
+        'this is not supported')
+    fluid_op = 'reshape'
     name_attr = ', name={}'.format(repr(name)) if name else ''
 
     # generation
@@ -1457,7 +1446,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         'Cast',
         [var_shape],
         [var_shape_int32],  # var
-        dict(to=np.dtype('int32')),
+        dict(to=np.dtype('int32')),  # use np.dtype
         value_infos=value_infos,
         name=(name + '_cast'),
     )
@@ -1593,26 +1582,25 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
     var_output = _make_var_name(val_output)
 
     # interpretation
+    repeats = _const_weight_or_none(value_infos, val_repeats)
+    assert repeats is not None, 'only const repeats is supported'
     fluid_op = 'expand'
-    is_const_repeats = 'const_value' in value_infos[val_repeats]
-    if is_const_repeats:
-        code_repeats = _make_var_name(val_repeats)  # for code
-        repeats = value_infos[val_repeats]['const_value']  # for desc
-    else:
-        repeats = value_infos[val_input]['get_weight']()  # for desc
-        code_repeats = repeats  # for code
     name_attr = ', name={}'.format(repr(name)) if name else ''
 
     # generation
     prog.Code('{} = layers.{}({}'
               ', expand_times={}'
-              '{})'.format(
+              '{})'
+              ' # {} = {}'.format(
                   var_output,
                   fluid_op,
                   var_input,
                   # attrs
-                  code_repeats,
+                  repeats,
                   name_attr,
+                  # comment
+                  _make_var_name(val_repeats),
+                  repeats,
               ))
     prog.VarDesc(var_output)
     prog.OpDesc(
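Fluid's `expand` takes `expand_times` as a build-time attribute, so the importer now insists, via `_const_weight_or_none` plus the assert, that the ONNX Tile `repeats` input resolves to a constant; the generated line also gains a trailing comment recording which variable the constant came from. Roughly what the emitted code line looks like (names invented, name_attr omitted):

    repeats = [1, 2, 2]     # what _const_weight_or_none must produce
    code = '{} = layers.{}({}, expand_times={})  # {} = {}'.format(
        'var_out', 'expand', 'var_in', repeats, 'var_repeats', repeats)
    print(code)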
onnx2fluid/onnx2fluid/validation.py

@@ -6,11 +6,12 @@ Created on Fri Mar 22 12:17:19 2019
 @author: Macrobull
 """
 
-# import importlib, logging, os, sys
-import importlib
-import logging
-import os
-import sys
+import importlib, logging, os, sys
+#import importlib
+#import logging
+#import os
+#import sys
+
 
 def _flatten_dict(obj, out=None):
onnx2fluid/onnx2fluid/writer.py

@@ -8,9 +8,9 @@ Created on Sun Feb 24 20:44:43 2019
 from __future__ import division
 
-#import logging, os
-import logging
-import os
+import logging, os
+#import logging
+#import os
 
 import numpy as np
 
 logger = logging.getLogger(__name__)

@@ -215,10 +215,6 @@ class Program(object):
         var_desc.persistable = persistable
         var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
-        # REMOVEIT: WORKAROUND: Netron: null.tensor error
-        tensor_desc = var_desc.type.lod_tensor.tensor
-        tensor_desc.data_type = self.Dtype(dummy_dtype)  # required
-
         if value_info and 'dtype' in value_info:
             tensor_desc = var_desc.type.lod_tensor.tensor
             tensor_desc.data_type = self.Dtype(value_info['dtype'])  # required

@@ -230,6 +226,9 @@ class Program(object):
                     not persistable)
             if remove_batch:
                 tensor_desc.dims[0] = -1
+        else:
+            # REMOVEIT: WORKAROUND: Netron: null.tensor error
+            tensor_desc = var_desc.type.lod_tensor.tensor
+            tensor_desc.data_type = self.Dtype(dummy_dtype)  # required
 
         self.var_descs.append(var_desc)