Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
492e9661
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
492e9661
编写于
5月 26, 2019
作者:
M
Macrobull
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add new ONNX polish_model
上级
826481c4
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
244 addition
and
127 deletion
+244
-127
onnx2fluid/onnx2fluid/cmdline.py
onnx2fluid/onnx2fluid/cmdline.py
+3
-4
onnx2fluid/onnx2fluid/conversion.py
onnx2fluid/onnx2fluid/conversion.py
+9
-18
onnx2fluid/onnx2fluid/onnx_utils.py
onnx2fluid/onnx2fluid/onnx_utils.py
+58
-21
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+23
-25
onnx2fluid/onnx2fluid/torch_export_helper.py
onnx2fluid/onnx2fluid/torch_export_helper.py
+89
-29
onnx2fluid/onnx2fluid/validation.py
onnx2fluid/onnx2fluid/validation.py
+62
-30
未找到文件。
onnx2fluid/onnx2fluid/cmdline.py
浏览文件 @
492e9661
...
...
@@ -33,7 +33,7 @@ def main(**kwargs):
from
.conversion
import
convert
logger
=
logging
.
getLogger
(
'onnx2fluid'
)
debug
=
kwargs
.
get
(
'debug'
,
False
)
#
debug = kwargs.get('debug', False)
# prepare arguments
filename
=
kwargs
.
pop
(
'model'
)[
0
]
...
...
@@ -65,8 +65,7 @@ def main(**kwargs):
from
.validation
import
validate
save_inference_model
=
infer_inputs
is
not
None
inference_input_names
=
infer_inputs
.
split
(
','
)
if
infer_inputs
else
None
inference_input_names
=
infer_inputs
and
infer_inputs
.
split
(
','
)
logger
.
info
(
'starting validation on desc ...'
)
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
'__model__'
),
...
...
@@ -85,7 +84,7 @@ def main(**kwargs):
**
kwargs
)
if
not
passed
:
logger
.
error
(
'validation failed, exit'
)
logger
.
fatal
(
'validation failed, exit'
)
return
# create zip file
...
...
onnx2fluid/onnx2fluid/conversion.py
浏览文件 @
492e9661
...
...
@@ -34,15 +34,12 @@ def convert(onnx_model_filename,
from
onnx.checker
import
ValidationError
from
onnx.checker
import
check_model
from
onnx.utils
import
polish_model
from
onnx.version_converter
import
convert_version
from
.onnx_utils
import
DEFAULT_OP_DOMAIN
from
.onnx_utils
import
graph_ops
,
graph_weights
from
.onnx_utils
import
inferred_model_value_info
from
.onnx_utils
import
optimize_model_skip_op_for_inference
from
.onnx_utils
import
optimize_model_strip_initializer
from
.onnx_utils
import
optimize_model_cast
,
optimize_model_slice
from
.onnx_utils
import
polish_model
from
.writer
import
Program
,
Writer
from
.writer
import
make_var_name
...
...
@@ -56,14 +53,12 @@ def convert(onnx_model_filename,
logger
.
info
(
'checking model ...'
)
check_model
(
onnx_model
)
if
onnx_opset_version
is
None
:
# WORKAROUND: RuntimeError: No Adapter For OP
logger
.
debug
(
'assumed opset version: %d'
,
DEFAULT_ONNX_OPSET_VERSION
)
logger
.
warning
(
'opset conversion skipped for onnx_opset_pedantic is OFF'
)
logger
.
info
(
'assumed opset version: %d'
,
DEFAULT_ONNX_OPSET_VERSION
)
else
:
logger
.
debug
(
'using opset version: %d'
,
onnx_opset_version
)
logger
.
info
(
'using opset version: %d'
,
onnx_opset_version
)
onnx_model
=
convert_version
(
onnx_model
,
onnx_opset_version
)
onnx_model
=
polish_model
(
onnx_model
)
except
ValidationError
as
e
:
if
onnx_opset_pedantic
:
raise
e
...
...
@@ -75,10 +70,7 @@ def convert(onnx_model_filename,
# onnx model optimization
logger
.
info
(
'model has %d ops'
,
len
(
onnx_model
.
graph
.
node
))
logger
.
info
(
'optimizing model ...'
)
onnx_model
=
optimize_model_skip_op_for_inference
(
onnx_model
)
onnx_model
=
optimize_model_strip_initializer
(
onnx_model
)
onnx_model
=
optimize_model_cast
(
onnx_model
)
onnx_model
=
optimize_model_slice
(
onnx_model
)
onnx_model
=
polish_model
(
onnx_model
)
# prepare filesystem
shutil
.
rmtree
(
save_dir
,
ignore_errors
=
True
)
...
...
@@ -87,9 +79,8 @@ def convert(onnx_model_filename,
# DEBUG:
if
debug
:
model
=
onnx
.
shape_inference
.
infer_shapes
(
onnx_model
)
debug_model_filename
,
_
=
shutil
.
os
.
path
.
splitext
(
onnx_model_filename
)
onnx
.
save
(
model
,
debug_model_filename
+
'.optimized_and_inffer
ed.onnx'
)
onnx
.
save
(
onnx_model
,
debug_model_filename
+
'.polish
ed.onnx'
)
# I/O instances
onnx_graph
=
onnx_model
.
graph
...
...
@@ -141,11 +132,11 @@ def convert(onnx_model_filename,
logger
.
info
(
'%d ops in, %d ops out'
,
len
(
onnx_graph
.
node
),
len
(
fluid_program
.
op_descs
))
# type-shape inf
erence
# type-shape inf
o copy
for
name
,
value_info
in
graph_value_infos
.
items
():
var_name
=
make_var_name
(
name
)
fluid_program
.
VarTypeShapeInfo
(
var_name
,
value_info
,
remove_batch
=
False
)
#
shape-infer only
remove_batch
=
False
)
#
bad_var_names
=
[]
for
var_name
,
var_desc
in
fluid_program
.
var_descs
.
items
():
if
not
var_desc
.
type
.
lod_tensor
.
HasField
(
'tensor'
):
...
...
@@ -155,8 +146,8 @@ def convert(onnx_model_filename,
', '
.
join
(
bad_var_names
[:
5
]))
logger
.
warning
(
'this causes little problem for PaddlePaddle, '
'but Paddle Mobile may not infer correctly'
)
logger
.
warning
(
'please consider running
onnx2fluid.
validation with -i '
'to invoke
PaddlePaddle type-shape inferenc
e'
)
logger
.
warning
(
'please consider running validation with -i '
'to invoke
type-shape inference in PaddlePaddl
e'
)
# weight writer
for
name
,
weight
in
graph_weights
(
onnx_graph
):
...
...
onnx2fluid/onnx2fluid/onnx_utils.py
浏览文件 @
492e9661
...
...
@@ -11,9 +11,11 @@ from __future__ import division
import
logging
import
numpy
as
np
import
onnx
import
onnx.optimizer
as
optimizer
from
collections
import
OrderedDict
as
Dict
# as default dict
from
onnx.helper
import
get_attribute_value
,
make_attribute
from
onnx.checker
import
check_model
from
onnx.helper
import
get_attribute_value
,
make_attribute
,
strip_doc_string
from
onnx.mapping
import
TENSOR_TYPE_TO_NP_TYPE
from
onnx.numpy_helper
import
to_array
from
onnx.shape_inference
import
infer_shapes
...
...
@@ -23,14 +25,16 @@ logger = logging.getLogger(__name__)
__all__
=
[
'print_pb_structure'
,
'build_value_refs'
,
'tensor_dtype'
,
'tensor_shape'
,
'node_attrs'
,
'node_topo'
,
'node_iter'
,
'tensor_dtype'
,
'tensor_shape'
,
'graph_ops'
,
'graph_weights'
,
'inferred_model_value_info'
,
'polish_model'
,
'polish_and_save'
,
'optimize_model_skip_op_for_inference'
,
'optimize_model_strip_initializer'
,
'optimize_model_cast'
,
...
...
@@ -110,7 +114,7 @@ def tensor_shape(tensor):
get ONNX tensor shape
"""
return
[
dim
.
dim_value
for
dim
in
tensor
.
type
.
tensor_type
.
shape
.
dim
]
return
tuple
([
dim
.
dim_value
for
dim
in
tensor
.
type
.
tensor_type
.
shape
.
dim
])
def
node_attrs
(
node
):
...
...
@@ -195,10 +199,7 @@ def node_iter(nodes, indices=None):
generator for ONNX node graph with given indices
"""
if
indices
is
None
:
indices
=
range
(
len
(
nodes
))
for
index
in
indices
:
for
index
in
indices
or
range
(
len
(
nodes
)):
node
=
nodes
[
index
]
name
=
node
.
name
domain
=
node
.
domain
...
...
@@ -306,6 +307,48 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
return
processed
def
polish_model
(
model
,
extras
=
True
):
"""
polish_model enhanced for inference
"""
check_model
(
model
)
strip_doc_string
(
model
)
passes
=
optimizer
.
get_available_passes
()
passes
=
list
(
filter
(
lambda
name
:
not
name
.
startswith
(
'split_'
),
passes
))
#
logger
.
debug
(
'builtin optimizations to perform in ONNX:
\n\t
%s'
,
passes
)
model
=
optimizer
.
optimize
(
model
,
passes
=
passes
)
if
extras
:
for
optimize
in
(
optimize_model_skip_op_for_inference
,
optimize_model_strip_initializer
,
optimize_model_cast
,
optimize_model_slice
,
):
model
=
optimize
(
model
)
model
=
infer_shapes
(
model
)
check_model
(
model
)
return
model
def
polish_and_save
(
model_filename
,
suffix
=
'.polished'
,
save_filename
=
None
,
*
args
,
**
kwargs
):
"""
run polish_model and save
"""
model
=
onnx
.
load
(
model_filename
)
model
=
polish_model
(
model
,
*
args
,
**
kwargs
)
save_filename
=
save_filename
or
model_filename
.
replace
(
'.onnx'
,
suffix
+
'.onnx'
)
onnx
.
save
(
model
,
save_filename
)
logger
.
info
(
'polished model saved to: %s'
,
save_filename
)
return
save_filename
def
optimize_model_skip_op_for_inference
(
model
,
op_list
=
None
):
"""
skip ops can be bypassed for inference
...
...
@@ -326,7 +369,7 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
continue
op_type
=
node
.
op_type
if
not
(
op_type
in
op_list
)
:
if
op_type
not
in
op_list
:
continue
if
op_type
in
(
'Dropout'
,
):
...
...
@@ -590,22 +633,16 @@ if __name__ == '__main__':
level
=
logging
.
DEBUG
,
)
from
onnx.checker
import
check_model
from
onnx.utils
import
polish_model
from
onnx.version_converter
import
convert_version
model
=
onnx
.
load
(
'
../examples/t1
.onnx'
)
model
=
onnx
.
load
(
'
/tmp/export
.onnx'
)
print_pb_structure
(
model
,
loop_iterative
=
False
)
check_model
(
model
)
model
=
convert_version
(
model
,
9
)
model
=
optimize_model_skip_op_for_inference
(
model
)
model
=
optimize_model_strip_initializer
(
model
)
model
=
optimize_model_cast
(
model
)
model
=
optimize_model_slice
(
model
)
model
=
polish_model
(
model
)
onnx
.
save
(
model
,
'/tmp/optimized.onnx'
)
onnx
.
save
(
model
,
'/tmp/
export.
optimized.onnx'
)
graph
=
model
.
graph
value_info
=
inferred_model_value_info
(
model
)
...
...
@@ -617,23 +654,23 @@ if __name__ == '__main__':
logger
.
info
(
'ops:'
)
for
name
,
domain
,
op_type
,
_
,
_
,
attrs
in
graph_ops
(
graph
,
topo
=
'forward'
):
logger
.
info
(
'%s %s::%s: %s'
,
name
,
domain
,
op_type
,
attrs
)
logger
.
info
(
'
-
\t
%s %s::%s: %s'
,
name
,
domain
,
op_type
,
attrs
)
logger
.
info
(
'weights:'
)
for
name
,
array
in
graph_weights
(
graph
):
weights
.
append
(
name
)
logger
.
info
(
'%s: %s'
,
name
,
array
.
shape
)
logger
.
info
(
'
-
\t
%s: %s'
,
name
,
array
.
shape
)
logger
.
info
(
'inputs:'
)
external_inputs
=
[]
for
name
in
inputs
:
if
name
not
in
weights
:
external_inputs
.
append
(
name
)
logger
.
info
(
'%s: %s'
,
name
,
value_info
[
name
][
'shape'
])
logger
.
info
(
'
-
\t
%s: %s'
,
name
,
value_info
[
name
][
'shape'
])
logger
.
info
(
'outputs:'
)
external_outputs
=
[]
for
name
in
outputs
:
if
name
not
in
weights
:
external_outputs
.
append
(
name
)
logger
.
info
(
'%s: %s'
,
name
,
value_info
[
name
][
'shape'
])
logger
.
info
(
'
-
\t
%s: %s'
,
name
,
value_info
[
name
][
'shape'
])
onnx2fluid/onnx2fluid/symbolic.py
浏览文件 @
492e9661
...
...
@@ -203,8 +203,7 @@ def _check_embeddable(value_infos, *val_names):
keyword
=
'get_weight'
for
val_name
in
val_names
:
if
keyword
not
in
value_infos
[
val_name
]:
_logger
.
warning
(
'parameter %s not embeddable for some ops'
,
val_name
)
_logger
.
warning
(
'parameter %s not embeddable'
,
val_name
)
return
False
return
True
...
...
@@ -240,9 +239,9 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_attrs
=
default_attrs
.
copy
()
fluid_attrs
.
update
(
mapped_attrs
)
# as new attrs
val_inps
=
inputs
if
input_perm
is
None
else
map
(
lambda
i
:
inputs
[
i
]
,
val_inps
=
inputs
if
input_perm
is
None
else
map
(
inputs
.
__getitem__
,
input_perm
)
val_outs
=
outputs
if
output_perm
is
None
else
map
(
lambda
i
:
outputs
[
i
]
,
val_outs
=
outputs
if
output_perm
is
None
else
map
(
outputs
.
__getitem__
,
output_perm
)
var_inps
=
[
_make_var_name
(
val
)
for
val
in
val_inps
]
var_outs
=
[
_make_var_name
(
val
)
for
val
in
val_outs
]
...
...
@@ -578,7 +577,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
1
]
==
1
,
'only scale on (NC)HW supported'
assert
scales
[
2
]
==
scales
[
3
],
'only aspect-ratio-invariant scale supported'
scale
=
None
if
scales
is
None
else
scales
[
2
]
scale
=
scales
and
scales
[
2
]
# try input shape
if
scale
is
None
:
assert
out_shape_
,
'neither scales nor output shape is available'
...
...
@@ -717,6 +716,10 @@ def BatchNormalization(prog,
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
val_scale
,
val_b
,
val_mean
,
val_var
)
if
not
embed_params
and
name
:
_logger
.
warning
(
'for op %s(%s -> BatchNormalization -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
if
embed_params
:
assert
name
!=
''
var_scale
=
name
+
'.w_0'
...
...
@@ -875,7 +878,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
if
shape
is
None
:
shape
=
list
(
value
.
shape
)
_logger
.
warning
(
'in (Constant -> %s): '
'in
op
(Constant -> %s): '
'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails'
,
outputs
,
val_output
)
...
...
@@ -986,6 +989,10 @@ def Conv(prog,
if
embed_params
:
embed_params
=
(
_check_embeddable
(
value_infos
,
val_w
)
and
not
has_bias
or
_check_embeddable
(
value_infos
,
val_b
))
if
not
embed_params
and
name
:
_logger
.
warning
(
'for op %s(%s -> Conv -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
if
embed_params
:
assert
name
!=
''
var_w
=
name
+
'.w_0'
...
...
@@ -1099,6 +1106,10 @@ def ConvTranspose(prog,
if
embed_params
:
embed_params
=
(
_check_embeddable
(
value_infos
,
val_w
)
and
not
has_bias
or
_check_embeddable
(
value_infos
,
val_b
))
if
not
embed_params
and
name
:
_logger
.
warning
(
'for op %s(%s -> ConvTranspose -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
if
embed_params
:
assert
name
!=
''
var_w
=
name
+
'.w_0'
...
...
@@ -1167,23 +1178,6 @@ def ConvTranspose(prog,
prog
.
VarDesc
(
var_y
)
# should not appear
#def Dropout(
# prog, inputs, outputs, value_infos,
# *args, **kwargs):
# """
# onnx::Dropout-7:9
# """
#
# val_data, = inputs
# val_output, = outputs[:1]
#
# _assign(prog,
# dict([(val_output, val_data)]),
# value_infos,
# )
def
Gemm
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
onnx::Gemm-9:
...
...
@@ -1236,7 +1230,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
if
vm_dtype
is
None
:
vm_dtype
=
_np
.
dtype
(
'float32'
)
_logger
.
warning
(
'in %s(%s -> Gemm -> %s): '
'in
op
%s(%s -> Gemm -> %s): '
'attribute "beta" seems to be an interger, '
'however dtype can not be inferred, '
'still use float32'
,
name
,
inputs
,
outputs
)
...
...
@@ -1425,6 +1419,10 @@ def PRelu(prog,
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
val_slope
)
if
not
embed_params
and
name
:
_logger
.
warning
(
'for op %s(%s -> PRelu -> %s)'
,
name
,
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
if
embed_params
:
assert
name
!=
''
var_slope
=
name
+
'.w_0'
...
...
@@ -1487,7 +1485,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
if
shape
is
None
:
shape
=
[
1
,
-
1
]
# who knows
_logger
.
warning
(
'in %s(%s -> Reshape -> %s): '
'in
op
%s(%s -> Reshape -> %s): '
'input "shape" not inferred, use [1, -1] as dummy value, '
'the behavior of Paddle fluid maybe undefined'
,
name
,
inputs
,
outputs
)
...
...
onnx2fluid/onnx2fluid/torch_export_helper.py
浏览文件 @
492e9661
...
...
@@ -9,22 +9,50 @@ Created on Fri Mar 22 11:22:46 2019
import
numpy
as
np
import
torch
from
collections
import
OrderedDict
as
Dict
from
collections
import
OrderedDict
from
typing
import
(
TypeVar
,
Any
,
Generic
,
Iterable
,
List
,
Mapping
,
Optional
,
Sequence
,
Text
,
Tuple
,
Union
,
)
__all__
=
[
'export_data'
,
'export_onnx_with_validation'
,
]
def
ensure_list
(
obj
):
my_dict
=
OrderedDict
KT
=
TypeVar
(
'KT'
)
VT
=
TypeVar
(
'VT'
)
class
MyDict
(
my_dict
,
Generic
[
KT
,
VT
]):
pass
def
ensure_list
(
obj
:
Union
[
object
,
Sequence
[
object
]])
->
List
[
object
]:
if
isinstance
(
obj
,
(
list
,
tuple
,
set
)):
return
list
(
obj
)
return
[
obj
]
def
ensure_tuple
(
obj
)
:
def
ensure_tuple
(
obj
:
Union
[
object
,
Sequence
[
object
]])
->
Tuple
[
object
,
...]
:
if
isinstance
(
obj
,
(
tuple
,
list
,
set
)):
return
tuple
(
obj
)
return
(
obj
,
)
def
flatten_list
(
obj
,
out
=
None
):
def
flatten_list
(
obj
:
List
[
Union
[
object
,
List
[
object
]]],
out
:
Optional
[
List
[
object
]]
=
None
)
->
List
[
object
]:
assert
isinstance
(
obj
,
list
),
'list type required'
if
out
is
None
:
...
...
@@ -37,21 +65,21 @@ def flatten_list(obj, out=None):
return
out
def
export_data
(
state_dict
,
prefix
=
''
)
:
def
export_data
(
state_dict
:
Mapping
[
Text
,
Any
],
prefix
:
Text
=
''
)
->
None
:
"""
export binary data with meta text for raw C++ inference engines
"""
def
str_
(
obj
)
:
def
str_
(
obj
:
object
)
->
Text
:
if
isinstance
(
obj
,
(
tuple
,
list
,
set
)):
return
str
(
obj
)[
1
:
-
1
].
replace
(
' '
,
''
)
return
str
(
obj
)
prefix_
=
prefix
+
(
'_'
if
prefix
else
''
)
fp
=
open
(
'{}.txt'
.
format
(
prefix
if
prefix
else
'meta'
),
'w'
)
fp
=
open
(
'{}.txt'
.
format
(
prefix
or
'meta'
),
'w'
)
for
key
,
value
in
state_dict
.
items
():
data
=
None
if
torch
and
torch
.
is_tensor
(
value
):
if
torch
.
is_tensor
(
value
):
data
=
value
.
data
.
cpu
().
numpy
()
elif
isinstance
(
value
,
np
.
ndarray
):
data
=
value
...
...
@@ -64,30 +92,33 @@ def export_data(state_dict, prefix=''):
fp
.
close
()
def
export_onnx_with_validation
(
model
,
inputs
,
export_basepath
,
input_names
=
None
,
output_names
=
None
,
use_npz
=
True
,
*
args
,
**
kwargs
):
def
export_onnx_with_validation
(
model
:
torch
.
nn
.
Module
,
inputs
:
Sequence
[
Union
[
torch
.
Tensor
,
Sequence
[
object
]]],
export_basepath
:
Text
,
input_names
:
Optional
[
List
[
Text
]]
=
None
,
output_names
:
Optional
[
List
[
Text
]]
=
None
,
use_npz
:
bool
=
True
,
*
args
,
**
kwargs
)
->
Sequence
[
Union
[
torch
.
Tensor
,
Sequence
[
object
]]]:
"""
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
is_tuple_or_list
=
lambda
x
:
isinstance
(
x
,
(
tuple
,
list
))
def
tensors_to_arrays
(
tensors
):
def
tensors_to_arrays
(
tensors
:
Union
[
torch
.
Tensor
,
Iterable
[
Union
[
torch
.
Tensor
,
Iterable
[
Any
]]]],
)
->
List
[
np
.
ndarray
]:
if
torch
.
is_tensor
(
tensors
):
return
tensors
.
data
.
cpu
().
numpy
()
arrays
=
[]
for
tensor
in
tensors
:
arrays
.
append
(
tensors_to_arrays
(
tensor
))
return
arrays
def
zip_dict
(
keys
,
values
):
ret
=
Dict
()
return
list
(
map
(
tensors_to_arrays
,
tensors
))
def
zip_dict
(
keys
:
Union
[
Iterable
[
Any
],
None
],
values
:
Sequence
[
Union
[
Any
,
Sequence
[
Any
]]],
)
->
MyDict
[
Text
,
Union
[
object
,
MyDict
[
Text
,
object
]]]:
keys
=
keys
or
range
(
len
(
values
))
ret
=
my_dict
()
for
idx
,
(
key
,
value
)
in
enumerate
(
zip
(
keys
,
values
)):
is_key_list
=
is_tuple_or_list
(
key
)
is_value_list
=
is_tuple_or_list
(
value
)
...
...
@@ -102,19 +133,48 @@ def export_onnx_with_validation(model,
outputs
=
torch
.
onnx
.
export
(
model
,
torch_inputs
,
export_basepath
+
'.onnx'
,
input_names
=
flatten_list
(
input_names
),
output_names
=
flatten_list
(
output_names
),
input_names
=
(
None
if
input_names
is
None
else
flatten_list
(
input_names
)),
output_names
=
(
None
if
output_names
is
None
else
flatten_list
(
output_names
)),
*
args
,
**
kwargs
)
if
outputs
is
None
:
# WORKAROUND: for torch.onnx
outputs
=
model
(
*
inputs
)
training
=
kwargs
.
get
(
'training'
,
False
)
with
torch
.
onnx
.
set_training
(
model
,
training
):
outputs
=
model
(
*
inputs
)
torch_outputs
=
ensure_tuple
(
outputs
)
inputs
=
zip_dict
(
input_names
,
tensors_to_arrays
(
torch_inputs
))
outputs
=
zip_dict
(
output_names
,
tensors_to_arrays
(
torch_outputs
))
if
use_npz
:
np
.
savez
(
export_basepath
+
'.npz'
,
inputs
=
inputs
,
outputs
=
outputs
)
np
.
savez
(
export_basepath
+
'.npz'
,
inputs
=
inputs
,
outputs
=
outputs
,
)
else
:
np
.
save
(
export_basepath
+
'.npy'
,
np
.
array
(
Dict
(
inputs
=
inputs
,
outputs
=
outputs
)))
np
.
asarray
(
my_dict
(
inputs
=
inputs
,
outputs
=
outputs
)),
allow_pickle
=
True
)
return
torch_outputs
if
__name__
==
'__main__'
:
from
torchvision.models
import
resnet18
as
net
model
=
net
()
xb
=
torch
.
rand
((
1
,
3
,
224
,
224
))
export_onnx_with_validation
(
model
,
(
xb
,
),
'/tmp/export'
,
input_names
=
[
'image'
,
],
output_names
=
[
'prob'
,
],
use_npz
=
True
,
)
onnx2fluid/onnx2fluid/validation.py
浏览文件 @
492e9661
...
...
@@ -8,6 +8,13 @@ Created on Fri Mar 22 12:17:19 2019
import
importlib
,
logging
,
os
,
sys
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'fluid_prog_shape_infer'
,
'validate'
,
]
def
flatten_dict
(
obj
,
out
=
None
):
assert
isinstance
(
obj
,
dict
),
'dict type required'
...
...
@@ -29,6 +36,42 @@ def ensure_list(obj):
return
[
obj
]
def
fluid_prog_shape_infer
(
prog
):
"""
additional type-shape inference for fluid program
"""
import
paddle.fluid
as
fluid
assert
isinstance
(
prog
,
fluid
.
framework
.
Program
)
logger
.
info
(
'performing type-shape inference ...'
)
for
block
in
prog
.
blocks
:
block_desc
=
block
.
desc
for
idx_op
in
range
(
block_desc
.
op_size
()):
op_desc
=
block_desc
.
op
(
idx_op
)
if
op_desc
.
type
()
in
(
'feed'
,
'fetch'
):
continue
op_desc
.
infer_var_type
(
block_desc
)
op_desc
.
infer_shape
(
block_desc
)
for
var_name
,
var
in
block
.
vars
.
items
():
var_desc
=
var
.
desc
if
var_desc
.
type
()
!=
fluid
.
core
.
VarDesc
.
VarType
.
LOD_TENSOR
:
continue
# WORKAROUND: dirty way to give dtype to partial-infered vars
# which could not be cleared!
try
:
var
.
to_string
(
True
)
except
ValueError
:
var_desc
.
set_dtype
(
fluid
.
core
.
VarDesc
.
VarType
.
FP32
)
logger
.
debug
(
'dtype of var %s not inferred, float32 assumed'
,
var_name
)
def
validate
(
fluid_model_filename
,
golden_data_filename
=
''
,
atol
=
1e-3
,
...
...
@@ -53,12 +96,12 @@ def validate(fluid_model_filename,
# load model
fluid_model_dir
,
basename
=
os
.
path
.
split
(
fluid_model_filename
)
if
basename
==
'__model__'
:
# is desc program
logger
.
debug
(
'using desc file %s'
,
basename
)
logger
.
info
(
'using desc file %s'
,
basename
)
prog
,
_
,
var_outs
=
fluid
.
io
.
load_inference_model
(
fluid_model_dir
,
exe
)
out_names
=
var_outs
# HINT: pass var if fetch ops already created
logger
.
info
(
'model load passed'
)
elif
basename
.
endswith
(
'.py'
):
# is Python code
logger
.
debug
(
'using code file %s'
,
basename
)
logger
.
info
(
'using code file %s'
,
basename
)
module_name
,
_
=
os
.
path
.
splitext
(
basename
)
sys_path
=
sys
.
path
.
copy
()
sys
.
path
.
append
(
fluid_model_dir
)
...
...
@@ -91,18 +134,28 @@ def validate(fluid_model_filename,
if
golden_data_filename
:
logger
.
info
(
'using golden data %s'
,
golden_data_filename
)
if
golden_data_filename
.
endswith
(
'.npz'
):
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
)
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
,
allow_pickle
=
True
,
)
input_data
=
test_data
[
'inputs'
].
tolist
()
output_data
=
test_data
[
'outputs'
].
tolist
()
else
:
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
).
tolist
()
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
,
allow_pickle
=
True
,
).
tolist
()
input_data
=
test_data
[
'inputs'
]
output_data
=
test_data
[
'outputs'
]
input_data
=
flatten_dict
(
input_data
)
output_data
=
flatten_dict
(
output_data
)
input_names
=
input_data
.
keys
()
logger
.
info
(
'found %d I/O golden data, starting test ...'
,
len
(
input_data
)
+
len
(
output_data
))
output_names
=
output_data
.
keys
()
logger
.
info
(
'with %d inputs and %d outputs'
,
len
(
input_data
),
len
(
output_data
))
else
:
assert
inference_input_names
,
'input names required for type-shape inference'
...
...
@@ -111,25 +164,7 @@ def validate(fluid_model_filename,
# type-shape inference and re-save
if
save_inference_model
:
for
block
in
prog
.
blocks
:
block_desc
=
block
.
desc
for
idx_op
in
range
(
block_desc
.
op_size
()):
op_desc
=
block_desc
.
op
(
idx_op
)
if
op_desc
.
type
()
in
(
'feed'
,
'fetch'
):
continue
op_desc
.
infer_var_type
(
block_desc
)
op_desc
.
infer_shape
(
block_desc
)
for
var_name
,
var
in
block
.
vars
.
items
():
var_desc
=
var
.
desc
if
var_desc
.
type
()
!=
fluid
.
core
.
VarDesc
.
VarType
.
LOD_TENSOR
:
continue
# WORKAROUND: dirty way to give dtype to partial-infered vars
# which could not be cleared!
try
:
var
.
to_string
(
True
)
except
ValueError
:
var_desc
.
set_dtype
(
fluid
.
core
.
VarDesc
.
VarType
.
FP32
)
fluid_prog_shape_infer
(
prog
)
fluid
.
io
.
save_inference_model
(
fluid_model_dir
,
input_names
,
var_outs
,
...
...
@@ -151,7 +186,7 @@ def validate(fluid_model_filename,
# validate
passed
=
True
for
(
name
,
truth
),
output
in
zip
(
output_data
.
items
(),
outputs
):
logger
.
info
(
'testing output {} ...'
.
format
(
name
))
logger
.
info
(
'testing o
n o
utput {} ...'
.
format
(
name
))
try
:
np
.
testing
.
assert_allclose
(
output
,
truth
,
...
...
@@ -162,10 +197,7 @@ def validate(fluid_model_filename,
except
AssertionError
as
e
:
passed
=
False
logger
.
error
(
'failed: %s
\n
'
,
e
)
if
passed
:
logger
.
info
(
'accuracy passed'
)
else
:
logger
.
info
(
'accuracy not passed'
)
logger
.
info
(
'accuracy %spassed'
,
''
if
passed
else
'not '
)
return
passed
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录