Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
d6e4a4ba
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d6e4a4ba
编写于
7月 18, 2019
作者:
M
Macrobull
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add naive option
上级
9d147284
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
90 addition
and
54 deletion
+90
-54
onnx2fluid/examples/convert_data_npz.py
onnx2fluid/examples/convert_data_npz.py
+3
-3
onnx2fluid/examples/convert_data_pb.py
onnx2fluid/examples/convert_data_pb.py
+3
-3
onnx2fluid/onnx2fluid/__main__.py
onnx2fluid/onnx2fluid/__main__.py
+8
-1
onnx2fluid/onnx2fluid/cmdline.py
onnx2fluid/onnx2fluid/cmdline.py
+5
-1
onnx2fluid/onnx2fluid/conversion.py
onnx2fluid/onnx2fluid/conversion.py
+24
-10
onnx2fluid/onnx2fluid/onnx_utils.py
onnx2fluid/onnx2fluid/onnx_utils.py
+3
-3
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+20
-18
onnx2fluid/onnx2fluid/torch_export_helper.py
onnx2fluid/onnx2fluid/torch_export_helper.py
+4
-4
onnx2fluid/onnx2fluid/validation.py
onnx2fluid/onnx2fluid/validation.py
+5
-4
onnx2fluid/onnx2fluid/writer.py
onnx2fluid/onnx2fluid/writer.py
+15
-7
未找到文件。
onnx2fluid/examples/convert_data_npz.py
浏览文件 @
d6e4a4ba
...
@@ -19,12 +19,12 @@ def make_var_name(name):
...
@@ -19,12 +19,12 @@ def make_var_name(name):
assert
name
assert
name
if
name
[
0
].
isdigit
():
for
s
in
'
\\
|/:.-'
:
return
'var_'
+
name
for
s
in
'
\\
|/:-'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
name
=
'var'
+
name
elif
name
[
0
].
isdigit
():
name
=
'var_'
+
name
return
name
return
name
...
...
onnx2fluid/examples/convert_data_pb.py
浏览文件 @
d6e4a4ba
...
@@ -22,12 +22,12 @@ def make_var_name(name):
...
@@ -22,12 +22,12 @@ def make_var_name(name):
assert
name
assert
name
if
name
[
0
].
isdigit
():
for
s
in
'
\\
|/:.-'
:
return
'var_'
+
name
for
s
in
'
\\
|/:-'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
name
=
'var'
+
name
elif
name
[
0
].
isdigit
():
name
=
'var_'
+
name
return
name
return
name
...
...
onnx2fluid/onnx2fluid/__main__.py
浏览文件 @
d6e4a4ba
...
@@ -64,7 +64,14 @@ parser.add_argument(
...
@@ -64,7 +64,14 @@ parser.add_argument(
'-x'
,
'-x'
,
action
=
'store_false'
,
action
=
'store_false'
,
dest
=
'pedantic'
,
dest
=
'pedantic'
,
help
=
'process non-standard ONNX ops, this may lead to fails'
,
help
=
'process non-standard ONNX ops, this may lead to failures'
,
)
parser
.
add_argument
(
'--naive'
,
'-n'
,
action
=
'store_true'
,
default
=
False
,
help
=
'bypass ONNX op optimizations, especially for training purpose'
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'--skip-version-conversion'
,
'--skip-version-conversion'
,
...
...
onnx2fluid/onnx2fluid/cmdline.py
浏览文件 @
d6e4a4ba
...
@@ -18,6 +18,8 @@ from __future__ import unicode_literals
...
@@ -18,6 +18,8 @@ from __future__ import unicode_literals
import
logging
,
shutil
,
zipfile
import
logging
,
shutil
,
zipfile
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
__all__
=
[
'main'
,
'main'
,
]
]
...
@@ -45,6 +47,7 @@ def main(**kwargs):
...
@@ -45,6 +47,7 @@ def main(**kwargs):
model_basename
=
DEFAULT_MODEL_MODULE
+
'.py'
model_basename
=
DEFAULT_MODEL_MODULE
+
'.py'
model_func_name
=
DEFAULT_MODEL_FUNC
model_func_name
=
DEFAULT_MODEL_FUNC
onnx_opset_pedantic
=
kwargs
.
pop
(
'pedantic'
,
True
)
onnx_opset_pedantic
=
kwargs
.
pop
(
'pedantic'
,
True
)
onnx_skip_optimization
=
kwargs
.
pop
(
'naive'
,
False
)
skip_version_conversion
=
kwargs
.
pop
(
'skip_version_conversion'
,
False
)
skip_version_conversion
=
kwargs
.
pop
(
'skip_version_conversion'
,
False
)
onnx_opset_version
=
None
if
skip_version_conversion
else
DEFAULT_ONNX_OPSET_VERSION
onnx_opset_version
=
None
if
skip_version_conversion
else
DEFAULT_ONNX_OPSET_VERSION
...
@@ -55,6 +58,7 @@ def main(**kwargs):
...
@@ -55,6 +58,7 @@ def main(**kwargs):
model_func_name
=
model_func_name
,
model_func_name
=
model_func_name
,
onnx_opset_version
=
onnx_opset_version
,
onnx_opset_version
=
onnx_opset_version
,
onnx_opset_pedantic
=
onnx_opset_pedantic
,
onnx_opset_pedantic
=
onnx_opset_pedantic
,
onnx_skip_optimization
=
onnx_skip_optimization
,
**
kwargs
)
**
kwargs
)
# validate
# validate
...
@@ -65,7 +69,7 @@ def main(**kwargs):
...
@@ -65,7 +69,7 @@ def main(**kwargs):
if
golden_data_filename
or
save_inference_model
:
if
golden_data_filename
or
save_inference_model
:
from
.validation
import
validate
from
.validation
import
validate
if
save_inference_model
:
if
infer_inputs
:
inference_input_names
=
infer_inputs
.
split
(
','
)
inference_input_names
=
infer_inputs
.
split
(
','
)
else
:
else
:
inference_input_names
=
None
inference_input_names
=
None
...
...
onnx2fluid/onnx2fluid/conversion.py
浏览文件 @
d6e4a4ba
...
@@ -24,12 +24,12 @@ def make_var_name(name):
...
@@ -24,12 +24,12 @@ def make_var_name(name):
if
name
==
''
:
if
name
==
''
:
return
''
return
''
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
'
\\
|/:.-'
:
for
s
in
'
\\
|/:.-'
:
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
name
=
'var'
+
name
elif
name
[
0
].
isdigit
():
name
=
'var_'
+
name
return
name
return
name
...
@@ -40,6 +40,7 @@ def convert(onnx_model_filename,
...
@@ -40,6 +40,7 @@ def convert(onnx_model_filename,
embed_params
=
False
,
embed_params
=
False
,
onnx_opset_version
=
None
,
onnx_opset_version
=
None
,
onnx_opset_pedantic
=
True
,
onnx_opset_pedantic
=
True
,
onnx_skip_optimization
=
False
,
debug
=
False
,
debug
=
False
,
**
kwargs
):
**
kwargs
):
"""
"""
...
@@ -61,10 +62,10 @@ def convert(onnx_model_filename,
...
@@ -61,10 +62,10 @@ def convert(onnx_model_filename,
from
.onnx_utils
import
DEFAULT_OP_DOMAIN
from
.onnx_utils
import
DEFAULT_OP_DOMAIN
from
.onnx_utils
import
graph_ops
,
graph_weights
from
.onnx_utils
import
graph_ops
,
graph_weights
from
.onnx_utils
import
inferred_model_value_info
from
.onnx_utils
import
inferred_model_value_info
from
.onnx_utils
import
polish_model
from
.onnx_utils
import
polish_model
,
optimize_model_strip_initializer
from
.writer
import
Program
,
Writer
from
.writer
import
Program
,
Writer
logger
=
logging
.
getLogger
(
'
convert
'
)
logger
=
logging
.
getLogger
(
'
onnx2fluid
'
)
# prepare onnx model
# prepare onnx model
logger
.
info
(
'loading model: %s ...'
,
onnx_model_filename
)
logger
.
info
(
'loading model: %s ...'
,
onnx_model_filename
)
...
@@ -90,8 +91,12 @@ def convert(onnx_model_filename,
...
@@ -90,8 +91,12 @@ def convert(onnx_model_filename,
# onnx model optimization
# onnx model optimization
logger
.
info
(
'model has %d ops'
,
len
(
onnx_model
.
graph
.
node
))
logger
.
info
(
'model has %d ops'
,
len
(
onnx_model
.
graph
.
node
))
logger
.
info
(
'optimizing model ...'
)
if
onnx_skip_optimization
:
onnx_model
=
polish_model
(
onnx_model
,
checking
=
onnx_opset_pedantic
)
logger
.
info
(
'stripping model ...'
)
onnx_model
=
optimize_model_strip_initializer
(
onnx_model
)
else
:
logger
.
info
(
'optimizing model ...'
)
onnx_model
=
polish_model
(
onnx_model
,
checking
=
onnx_opset_pedantic
)
# prepare filesystem
# prepare filesystem
shutil
.
rmtree
(
save_dir
,
ignore_errors
=
True
)
shutil
.
rmtree
(
save_dir
,
ignore_errors
=
True
)
...
@@ -123,7 +128,7 @@ def convert(onnx_model_filename,
...
@@ -123,7 +128,7 @@ def convert(onnx_model_filename,
for
name
,
weight
in
graph_weights
(
onnx_graph
):
for
name
,
weight
in
graph_weights
(
onnx_graph
):
var_name
=
make_var_name
(
name
)
var_name
=
make_var_name
(
name
)
value_info
=
value_infos
[
var_name
]
value_info
=
value_infos
[
var_name
]
value_info
[
'lod'
]
=
[
0
]
value_info
[
'lod'
]
=
[]
value_info
[
'embedded_as'
]
=
[]
value_info
[
'embedded_as'
]
=
[]
value_info
[
'get_weight'
]
=
(
lambda
w
:
lambda
:
w
.
tolist
())(
value_info
[
'get_weight'
]
=
(
lambda
w
:
lambda
:
w
.
tolist
())(
weight
)
# lazy getter
weight
)
# lazy getter
...
@@ -306,7 +311,14 @@ def main():
...
@@ -306,7 +311,14 @@ def main():
'-x'
,
'-x'
,
action
=
'store_false'
,
action
=
'store_false'
,
dest
=
'pedantic'
,
dest
=
'pedantic'
,
help
=
'process non-standard ONNX ops, this may lead to fails'
,
help
=
'process non-standard ONNX ops, this may lead to failures'
,
)
parser
.
add_argument
(
'--naive'
,
'-n'
,
action
=
'store_true'
,
default
=
False
,
help
=
'bypass ONNX op optimizations, especially for training purpose'
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
'--skip-version-conversion'
,
'--skip-version-conversion'
,
...
@@ -329,13 +341,15 @@ def main():
...
@@ -329,13 +341,15 @@ def main():
if
save_dir
else
basepath
)
+
shutil
.
os
.
sep
if
save_dir
else
basepath
)
+
shutil
.
os
.
sep
embed_params
=
args
.
embed_params
embed_params
=
args
.
embed_params
pedantic
=
args
.
pedantic
pedantic
=
args
.
pedantic
skip_version_conversion
=
args
.
skip_version_conversion
skip_optimization
=
args
.
naive
onnx_opset_version
=
None
if
args
.
skip_version_conversion
else
DEFAULT_ONNX_OPSET_VERSION
convert
(
model_filename
,
convert
(
model_filename
,
save_dir
,
save_dir
,
embed_params
=
embed_params
,
embed_params
=
embed_params
,
onnx_opset_version
=
onnx_opset_version
,
onnx_opset_pedantic
=
pedantic
,
onnx_opset_pedantic
=
pedantic
,
onnx_skip_
version_conversion
=
skip_version_convers
ion
,
onnx_skip_
optimization
=
skip_optimizat
ion
,
debug
=
debug
)
debug
=
debug
)
...
...
onnx2fluid/onnx2fluid/onnx_utils.py
浏览文件 @
d6e4a4ba
...
@@ -356,16 +356,16 @@ def polish_model(model, internals=True, extras=True, checking=True):
...
@@ -356,16 +356,16 @@ def polish_model(model, internals=True, extras=True, checking=True):
def
polish_and_save
(
model_filename
,
def
polish_and_save
(
model_filename
,
save_filename
=
''
,
suffix
=
'.polished'
,
suffix
=
'.polished'
,
save_filename
=
None
,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
run polish_model and save
run polish_model and save
"""
"""
if
save_filename
is
None
:
save_filename
=
save_filename
or
model_filename
.
replace
(
save_filename
=
model_filename
.
replace
(
'.onnx'
,
suffix
+
'.onnx'
)
'.onnx'
,
suffix
+
'.onnx'
)
model
=
onnx
.
load
(
model_filename
)
model
=
onnx
.
load
(
model_filename
)
model
=
polish_model
(
model
,
*
args
,
**
kwargs
)
model
=
polish_model
(
model
,
*
args
,
**
kwargs
)
...
...
onnx2fluid/onnx2fluid/symbolic.py
浏览文件 @
d6e4a4ba
...
@@ -18,7 +18,8 @@ import numpy as _np
...
@@ -18,7 +18,8 @@ import numpy as _np
from
collections
import
OrderedDict
as
_dict
from
collections
import
OrderedDict
as
_dict
from
onnx.mapping
import
TENSOR_TYPE_TO_NP_TYPE
from
onnx.mapping
import
TENSOR_TYPE_TO_NP_TYPE
_logger
=
_logging
.
getLogger
(
__name__
)
# _logger = _logging.getLogger(__name__)
_logger
=
_logging
.
getLogger
(
'onnx2fluid'
)
ONNX_INT_MAX
=
2
**
63
-
1
ONNX_INT_MAX
=
2
**
63
-
1
FLUID_INT_MAX
=
2
**
31
-
1
#
FLUID_INT_MAX
=
2
**
31
-
1
#
...
@@ -58,8 +59,8 @@ DEFAULT_OP_MAPPING = {
...
@@ -58,8 +59,8 @@ DEFAULT_OP_MAPPING = {
'Ceil'
:
[
'ceil'
,
[
'X'
],
[
'Out'
]],
'Ceil'
:
[
'ceil'
,
[
'X'
],
[
'Out'
]],
'Clip'
:
'Clip'
:
[
'clip'
,
[
'X'
],
[
'Out'
],
dict
(),
dict
(
[
'clip'
,
[
'X'
],
[
'Out'
],
dict
(),
dict
(
min
=
(
_np
.
array
([
255
,
255
,
127
,
255
],
dtype
=
_np
.
uint8
).
view
(
_np
.
float32
)),
min
=
(
_np
.
a
sa
rray
([
255
,
255
,
127
,
255
],
dtype
=
_np
.
uint8
).
view
(
_np
.
float32
)),
max
=
(
_np
.
array
([
255
,
255
,
127
,
127
],
dtype
=
_np
.
uint8
).
view
(
_np
.
float32
)),
max
=
(
_np
.
a
sa
rray
([
255
,
255
,
127
,
127
],
dtype
=
_np
.
uint8
).
view
(
_np
.
float32
)),
)],
)],
'Cos'
:
[
'cos'
,
[
'X'
],
[
'Out'
]],
'Cos'
:
[
'cos'
,
[
'X'
],
[
'Out'
]],
'Elu'
:
[
'elu'
,
[
'X'
],
[
'Out'
],
dict
(),
dict
(
alpha
=
1.
)],
'Elu'
:
[
'elu'
,
[
'X'
],
[
'Out'
],
dict
(),
dict
(
alpha
=
1.
)],
...
@@ -449,7 +450,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name):
...
@@ -449,7 +450,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name):
# I/O
# I/O
var_x
,
=
inputs
var_x
,
=
inputs
var_y
,
var_indices
,
=
(
outputs
+
[
''
]
*
1
)[:
2
]
var_y
,
var_indices
,
=
(
outputs
+
[
''
]
*
1
)[:
2
]
assert
name
and
var_x
and
var_y
assert
name
and
all
(
inputs
)
and
var_y
# interpretation
# interpretation
assert
attrs
.
get
(
assert
attrs
.
get
(
...
@@ -512,7 +513,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, name):
...
@@ -512,7 +513,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, name):
# I/O
# I/O
var_x
,
var_rois
,
=
inputs
var_x
,
var_rois
,
=
inputs
var_y
,
=
outputs
var_y
,
=
outputs
assert
name
and
var_x
and
var_rois
and
var_y
assert
name
and
all
(
inputs
)
and
all
(
outputs
)
# interpretation
# interpretation
spatial_scale
=
attrs
[
'spatial_scale'
]
# required
spatial_scale
=
attrs
[
'spatial_scale'
]
# required
...
@@ -565,7 +566,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
...
@@ -565,7 +566,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
# I/O
# I/O
var_x
,
var_scales
,
=
inputs
var_x
,
var_scales
,
=
inputs
var_y
,
=
outputs
var_y
,
=
outputs
assert
var_x
and
var_scales
and
var_y
assert
all
(
inputs
)
and
all
(
outputs
)
# interpretation
# interpretation
# output shape
# output shape
...
@@ -701,7 +702,7 @@ def BatchNormalization(prog,
...
@@ -701,7 +702,7 @@ def BatchNormalization(prog,
var_x
,
var_scale
,
var_b
,
var_mean
,
var_var
,
=
inputs
var_x
,
var_scale
,
var_b
,
var_mean
,
var_var
,
=
inputs
var_y
,
var_mean_
,
var_var_
,
var_saved_mean
,
var_saved_variance
,
=
(
var_y
,
var_mean_
,
var_var_
,
var_saved_mean
,
var_saved_variance
,
=
(
outputs
+
[
''
]
*
4
)[:
5
]
outputs
+
[
''
]
*
4
)[:
5
]
assert
var_x
and
var_scale
and
var_b
and
var_mean
and
var_var
and
var_y
assert
all
(
inputs
)
and
var_y
assert
var_saved_mean
or
name
assert
var_saved_mean
or
name
assert
var_saved_variance
or
name
assert
var_saved_variance
or
name
var_saved_mean
=
var_saved_mean
or
(
name
+
'.saved_mean'
)
# dummy output
var_saved_mean
=
var_saved_mean
or
(
name
+
'.saved_mean'
)
# dummy output
...
@@ -879,7 +880,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -879,7 +880,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
_logger
.
warning
(
_logger
.
warning
(
'in op (Constant -> %s): '
'in op (Constant -> %s): '
'attribute "shape" of %s not inferred, '
'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails'
,
outputs
,
var_output
)
'using value as 1-D tensor may lead to failures'
,
outputs
,
var_output
)
# generation
# generation
if
not
shape
or
value
.
size
==
1
:
# scalar or 1-size
if
not
shape
or
value
.
size
==
1
:
# scalar or 1-size
...
@@ -929,7 +931,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -929,7 +931,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'given shape is neither const value nor deductible from output, '
'given shape is neither const value nor deductible from output, '
'this is not supported'
)
'this is not supported'
)
attrs
=
attrs
.
copy
()
attrs
=
attrs
.
copy
()
attrs
.
setdefault
(
'value'
,
_np
.
array
(
0
,
dtype
=
_np
.
float32
))
attrs
.
setdefault
(
'value'
,
_np
.
a
sa
rray
(
0
,
dtype
=
_np
.
float32
))
attrs
.
update
({
'shape'
:
shape
})
# pass const
attrs
.
update
({
'shape'
:
shape
})
# pass const
prog
.
Code
(
'# shape: {} = {} # const as literal'
.
format
(
var_shape
,
shape
))
prog
.
Code
(
'# shape: {} = {} # const as literal'
.
format
(
var_shape
,
shape
))
...
@@ -959,7 +961,7 @@ def Conv(prog,
...
@@ -959,7 +961,7 @@ def Conv(prog,
# I/O
# I/O
var_x
,
var_w
,
var_b
,
=
(
inputs
+
[
''
]
*
1
)[:
3
]
var_x
,
var_w
,
var_b
,
=
(
inputs
+
[
''
]
*
1
)[:
3
]
var_y
,
=
outputs
var_y
,
=
outputs
assert
name
and
var_x
and
var_w
and
var_y
assert
name
and
var_x
and
var_w
and
all
(
outputs
)
# interpretation
# interpretation
assert
attrs
.
get
(
assert
attrs
.
get
(
...
@@ -1066,7 +1068,7 @@ def ConvTranspose(prog,
...
@@ -1066,7 +1068,7 @@ def ConvTranspose(prog,
# I/O
# I/O
var_x
,
var_w
,
var_b
,
=
(
inputs
+
[
''
]
*
1
)[:
3
]
var_x
,
var_w
,
var_b
,
=
(
inputs
+
[
''
]
*
1
)[:
3
]
var_y
,
=
outputs
var_y
,
=
outputs
assert
name
and
var_x
and
var_w
and
var_y
assert
name
and
var_x
and
var_w
and
all
(
outputs
)
# interpretation
# interpretation
assert
attrs
.
get
(
assert
attrs
.
get
(
...
@@ -1174,7 +1176,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1174,7 +1176,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
# due to fluid fc don't support transposed weight, we use matmul + ew_add
# due to fluid fc don't support transposed weight, we use matmul + ew_add
var_a
,
var_b
,
var_c
,
=
inputs
var_a
,
var_b
,
var_c
,
=
inputs
var_y
,
=
outputs
var_y
,
=
outputs
assert
name
and
var_a
and
var_b
and
var_c
and
var_y
assert
name
and
all
(
inputs
)
and
all
(
outputs
)
alpha
=
attrs
.
get
(
'alpha'
,
1.
)
# optional
alpha
=
attrs
.
get
(
'alpha'
,
1.
)
# optional
beta
=
attrs
.
get
(
'beta'
,
1.
)
# optional
beta
=
attrs
.
get
(
'beta'
,
1.
)
# optional
...
@@ -1794,7 +1796,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
...
@@ -1794,7 +1796,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
mode
)
mode
)
fluid_op
=
'pad'
fluid_op
=
'pad'
pad2d_attr
=
''
pad2d_attr
=
''
paddings
=
_np
.
array
(
pads
).
reshape
(
paddings
=
_np
.
a
sa
rray
(
pads
).
reshape
(
(
-
1
,
2
)).
transpose
().
flatten
().
tolist
()
# SSEE -> SESE
(
-
1
,
2
)).
transpose
().
flatten
().
tolist
()
# SSEE -> SESE
od_attrs
[
'paddings'
]
=
paddings
od_attrs
[
'paddings'
]
=
paddings
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
...
@@ -1838,7 +1840,7 @@ def PRelu(prog,
...
@@ -1838,7 +1840,7 @@ def PRelu(prog,
# I/O
# I/O
var_x
,
var_slope
,
=
inputs
var_x
,
var_slope
,
=
inputs
var_y
,
=
outputs
var_y
,
=
outputs
assert
name
and
var_x
and
var_slope
and
var_y
assert
name
and
all
(
inputs
)
and
all
(
outputs
)
# interpretation
# interpretation
mode
=
'channel'
mode
=
'channel'
...
@@ -1904,7 +1906,7 @@ def Reshape(prog, inputs, outputs, attrs_, value_infos, name, *args, **kwargs):
...
@@ -1904,7 +1906,7 @@ def Reshape(prog, inputs, outputs, attrs_, value_infos, name, *args, **kwargs):
# I/O
# I/O
var_data
,
var_shape
,
=
inputs
var_data
,
var_shape
,
=
inputs
var_reshaped
,
=
outputs
var_reshaped
,
=
outputs
assert
name
and
var_data
and
var_shape
and
var_reshaped
assert
name
and
all
(
inputs
)
and
all
(
outputs
)
# interpretation
# interpretation
shape
=
_const_weight_or_none
(
value_infos
,
var_shape
)
shape
=
_const_weight_or_none
(
value_infos
,
var_shape
)
...
@@ -2015,7 +2017,7 @@ def Shape(prog, inputs, outputs, attrs_, name, **kwargs):
...
@@ -2015,7 +2017,7 @@ def Shape(prog, inputs, outputs, attrs_, name, **kwargs):
# I/O
# I/O
var_data
,
=
inputs
var_data
,
=
inputs
var_shape
,
=
outputs
var_shape
,
=
outputs
assert
name
and
var_data
and
var_shape
assert
name
and
all
(
inputs
)
and
all
(
outputs
)
# interpretation
# interpretation
fluid_op
=
'shape'
fluid_op
=
'shape'
...
@@ -2189,7 +2191,7 @@ def Tile(prog, inputs, outputs, attrs_, value_infos, name='', *args, **kwargs):
...
@@ -2189,7 +2191,7 @@ def Tile(prog, inputs, outputs, attrs_, value_infos, name='', *args, **kwargs):
# I/O
# I/O
var_input
,
var_repeats
,
=
inputs
var_input
,
var_repeats
,
=
inputs
var_output
,
=
outputs
var_output
,
=
outputs
assert
var_input
and
var_repeats
and
var_output
assert
all
(
inputs
)
and
all
(
outputs
)
# interpretation
# interpretation
repeats
=
_const_weight_or_none
(
value_infos
,
var_repeats
)
repeats
=
_const_weight_or_none
(
value_infos
,
var_repeats
)
...
@@ -2227,7 +2229,7 @@ def Transpose(prog, inputs, outputs, attrs, name, *args, **kwargs):
...
@@ -2227,7 +2229,7 @@ def Transpose(prog, inputs, outputs, attrs, name, *args, **kwargs):
# I/O
# I/O
var_data
,
=
inputs
var_data
,
=
inputs
var_transposed
,
=
outputs
var_transposed
,
=
outputs
assert
name
and
var_data
and
var_transposed
assert
name
and
all
(
inputs
)
and
all
(
outputs
)
# interpretation
# interpretation
fluid_op
=
'transpose'
fluid_op
=
'transpose'
...
...
onnx2fluid/onnx2fluid/torch_export_helper.py
浏览文件 @
d6e4a4ba
...
@@ -138,10 +138,10 @@ def export_onnx_with_validation(
...
@@ -138,10 +138,10 @@ def export_onnx_with_validation(
outputs
=
torch
.
onnx
.
export
(
model
,
outputs
=
torch
.
onnx
.
export
(
model
,
torch_inputs
,
torch_inputs
,
export_basepath
+
'.onnx'
,
export_basepath
+
'.onnx'
,
input_names
=
(
None
if
input_names
is
None
else
input_names
=
(
input_names
flatten_list
(
input_names
)),
and
flatten_list
(
input_names
)),
output_names
=
(
None
if
output_names
is
None
else
output_names
=
(
output_names
flatten_list
(
output_names
)),
and
flatten_list
(
output_names
)),
*
args
,
*
args
,
**
kwargs
)
**
kwargs
)
if
outputs
is
None
:
# WORKAROUND: for torch.onnx
if
outputs
is
None
:
# WORKAROUND: for torch.onnx
...
...
onnx2fluid/onnx2fluid/validation.py
浏览文件 @
d6e4a4ba
...
@@ -90,7 +90,7 @@ def validate(fluid_model_filename,
...
@@ -90,7 +90,7 @@ def validate(fluid_model_filename,
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
logger
=
logging
.
getLogger
(
'
validate
'
)
logger
=
logging
.
getLogger
(
'
onnx2fluid
'
)
place
=
fluid
.
CPUPlace
()
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
=
fluid
.
Executor
(
place
)
...
@@ -126,6 +126,7 @@ def validate(fluid_model_filename,
...
@@ -126,6 +126,7 @@ def validate(fluid_model_filename,
logger
.
info
(
'import passed'
)
logger
.
info
(
'import passed'
)
prog
=
fluid
.
default_main_program
()
prog
=
fluid
.
default_main_program
()
prog
=
prog
.
clone
(
for_test
=
True
)
# force inference mode
fluid
.
io
.
load_persistables
(
executor
=
exe
,
fluid
.
io
.
load_persistables
(
executor
=
exe
,
dirname
=
fluid_model_dir
,
dirname
=
fluid_model_dir
,
main_program
=
prog
)
main_program
=
prog
)
...
@@ -160,8 +161,7 @@ def validate(fluid_model_filename,
...
@@ -160,8 +161,7 @@ def validate(fluid_model_filename,
logger
.
info
(
'with %d inputs and %d outputs'
,
len
(
input_data
),
logger
.
info
(
'with %d inputs and %d outputs'
,
len
(
input_data
),
len
(
output_data
))
len
(
output_data
))
elif
save_inference_model
:
elif
save_inference_model
:
assert
inference_input_names
is
not
None
,
(
assert
inference_input_names
,
'input names required for type-shape inference'
'input names required for type-shape inference'
)
input_names
=
inference_input_names
input_names
=
inference_input_names
logger
.
info
(
'using input names: %s'
,
', '
.
join
(
input_names
))
logger
.
info
(
'using input names: %s'
,
', '
.
join
(
input_names
))
...
@@ -185,6 +185,7 @@ def validate(fluid_model_filename,
...
@@ -185,6 +185,7 @@ def validate(fluid_model_filename,
# execute
# execute
outputs
=
exe
.
run
(
prog
,
feed
=
input_data
,
outputs
=
exe
.
run
(
prog
,
feed
=
input_data
,
fetch_list
=
out_names
)
# out_names can be vars
fetch_list
=
out_names
)
# out_names can be vars
exe
.
close
()
logger
.
info
(
'execution passed'
)
logger
.
info
(
'execution passed'
)
# validate
# validate
...
@@ -264,7 +265,7 @@ def main():
...
@@ -264,7 +265,7 @@ def main():
atol
,
rtol
=
args
.
atol
,
args
.
rtol
atol
,
rtol
=
args
.
atol
,
args
.
rtol
save_inference_model
=
args
.
infer_inputs
is
not
None
save_inference_model
=
args
.
infer_inputs
is
not
None
inference_input_names
=
args
.
infer_inputs
.
split
(
inference_input_names
=
args
.
infer_inputs
.
split
(
','
)
if
args
.
infer_inputs
else
None
','
)
if
save_inference_model
else
None
validate
(
fluid_model_filename
,
validate
(
fluid_model_filename
,
golden_data_filename
=
golden_data_filename
,
golden_data_filename
=
golden_data_filename
,
...
...
onnx2fluid/onnx2fluid/writer.py
浏览文件 @
d6e4a4ba
...
@@ -372,7 +372,7 @@ class Writer(object):
...
@@ -372,7 +372,7 @@ class Writer(object):
prog
.
Code
(
'# input {}'
.
format
(
name
))
prog
.
Code
(
'# input {}'
.
format
(
name
))
prog
.
Code
((
prog
.
Code
((
'{} = layers.data(name={}, shape={}, dtype={}, '
'{} = layers.data(name={}, shape={}, dtype={}, '
'append_batch_size={})'
# , stop_gradient=True
'append_batch_size={}
, lod_level=1
)'
# , stop_gradient=True
).
format
(
).
format
(
name
,
name
,
repr
(
name
),
repr
(
name
),
...
@@ -427,20 +427,28 @@ class Writer(object):
...
@@ -427,20 +427,28 @@ class Writer(object):
assert
lod
is
None
or
isinstance
(
lod
,
assert
lod
is
None
or
isinstance
(
lod
,
list
),
'lod should be None or list'
list
),
'lod should be None or list'
if
lod
is
None
:
lod
=
lod
or
[]
lod
=
[
0
]
tensor_desc
=
framework_pb2
.
VarType
.
TensorDesc
()
tensor_desc
=
framework_pb2
.
VarType
.
TensorDesc
()
tensor_desc
.
data_type
=
Program
.
Dtype
(
weight
.
dtype
)
tensor_desc
.
data_type
=
Program
.
Dtype
(
weight
.
dtype
)
tensor_desc
.
dims
.
extend
(
weight
.
shape
)
tensor_desc
.
dims
.
extend
(
weight
.
shape
)
fp
=
open
(
filename
,
'wb'
)
fp
=
open
(
filename
,
'wb'
)
np
.
array
([
0
],
dtype
=
np
.
int32
).
tofile
(
fp
)
# version
np
.
array
(
lod
,
dtype
=
np
.
int64
).
tofile
(
fp
)
# LOD level
# lod_tensor.cc: SerializeToStream
np
.
array
([
0
],
dtype
=
np
.
int32
).
tofile
(
fp
)
# tensor version
np
.
asarray
([
0
],
dtype
=
np
.
uint32
).
tofile
(
fp
)
# version
np
.
array
([
tensor_desc
.
ByteSize
()],
dtype
=
np
.
int32
).
tofile
(
fp
)
np
.
asarray
([
len
(
lod
)],
dtype
=
np
.
int64
).
tofile
(
fp
)
# LOD levels
for
level
in
lod
:
np
.
asarray
([
len
(
level
)],
dtype
=
np
.
int64
).
tofile
(
fp
)
# level size
np
.
asarray
(
level
,
dtype
=
np
.
uint64
).
tofile
(
fp
)
# LOD: size_t
# tensor_util.cc: TensorToStream
np
.
asarray
([
0
],
dtype
=
np
.
uint32
).
tofile
(
fp
)
# tensor version
np
.
asarray
([
tensor_desc
.
ByteSize
()],
dtype
=
np
.
int32
).
tofile
(
fp
)
fp
.
write
(
tensor_desc
.
SerializeToString
())
fp
.
write
(
tensor_desc
.
SerializeToString
())
weight
.
tofile
(
fp
)
weight
.
tofile
(
fp
)
fp
.
close
()
fp
.
close
()
@
staticmethod
@
staticmethod
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录