Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
826481c4
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 2 年 前同步成功
通知
329
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
826481c4
编写于
4月 29, 2019
作者:
M
Macrobull
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add type shape inference
上级
7c3e9379
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
401 addition
and
315 deletion
+401
-315
onnx2fluid/examples/convert_data_npz.py
onnx2fluid/examples/convert_data_npz.py
+3
-3
onnx2fluid/examples/convert_data_pb.py
onnx2fluid/examples/convert_data_pb.py
+3
-3
onnx2fluid/examples/onnx_model_zoo.sh
onnx2fluid/examples/onnx_model_zoo.sh
+1
-1
onnx2fluid/onnx2fluid/__main__.py
onnx2fluid/onnx2fluid/__main__.py
+8
-0
onnx2fluid/onnx2fluid/cmdline.py
onnx2fluid/onnx2fluid/cmdline.py
+13
-4
onnx2fluid/onnx2fluid/conversion.py
onnx2fluid/onnx2fluid/conversion.py
+17
-14
onnx2fluid/onnx2fluid/onnx_utils.py
onnx2fluid/onnx2fluid/onnx_utils.py
+35
-35
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+176
-163
onnx2fluid/onnx2fluid/torch_export_helper.py
onnx2fluid/onnx2fluid/torch_export_helper.py
+7
-6
onnx2fluid/onnx2fluid/validation.py
onnx2fluid/onnx2fluid/validation.py
+77
-28
onnx2fluid/onnx2fluid/writer.py
onnx2fluid/onnx2fluid/writer.py
+59
-56
onnx2fluid/setup.cfg
onnx2fluid/setup.cfg
+2
-2
未找到文件。
onnx2fluid/examples/convert_data_npz.py
浏览文件 @
826481c4
...
@@ -14,14 +14,14 @@ from collections import OrderedDict as Dict
...
@@ -14,14 +14,14 @@ from collections import OrderedDict as Dict
def
make_var_name
(
name
):
def
make_var_name
(
name
):
"""
"""
make a valid variable name in Python code
make a valid variable name in Python code
"""
"""
if
name
==
''
:
if
name
==
''
:
return
'_'
return
'_'
if
name
[
0
].
isdigit
():
if
name
[
0
].
isdigit
():
return
'var_'
+
name
return
'var_'
+
name
for
s
in
'
\\
|/:'
:
#
for
s
in
'
\\
|/:
-
'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
name
=
'var'
+
name
...
...
onnx2fluid/examples/convert_data_pb.py
浏览文件 @
826481c4
...
@@ -17,14 +17,14 @@ from glob import glob
...
@@ -17,14 +17,14 @@ from glob import glob
def
make_var_name
(
name
):
def
make_var_name
(
name
):
"""
"""
make a valid variable name in Python code
make a valid variable name in Python code
"""
"""
if
name
==
''
:
if
name
==
''
:
return
'_'
return
'_'
if
name
[
0
].
isdigit
():
if
name
[
0
].
isdigit
():
return
'var_'
+
name
return
'var_'
+
name
for
s
in
'
\\
|/:'
:
#
for
s
in
'
\\
|/:
-
'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
name
=
'var'
+
name
...
...
onnx2fluid/examples/onnx_model_zoo.sh
浏览文件 @
826481c4
...
@@ -311,7 +311,7 @@ resnet100_arcface()
...
@@ -311,7 +311,7 @@ resnet100_arcface()
echo
"extracting ..."
echo
"extracting ..."
tar
xf
"
$fn_tar
"
tar
xf
"
$fn_tar
"
python
-m
onnx2fluid
-o
/tmp/export/
"
$fn_model
"
-y
python
-m
onnx2fluid
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
"
/
*
/
for
pb_dir
in
"
$bn_tar
"
/
*
/
do
do
echo
"converting
$pb_dir
..."
echo
"converting
$pb_dir
..."
...
...
onnx2fluid/onnx2fluid/__main__.py
浏览文件 @
826481c4
...
@@ -95,6 +95,14 @@ parser.add_argument(
...
@@ -95,6 +95,14 @@ parser.add_argument(
default
=
1e-2
,
default
=
1e-2
,
help
=
'assertion relative tolerance for validation'
,
help
=
'assertion relative tolerance for validation'
,
)
)
parser
.
add_argument
(
'--infer_inputs'
,
'-i'
,
nargs
=
'?'
,
default
=
None
,
const
=
''
,
help
=
'perform type-shape inference with given input names and re-save model'
,
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logging_format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
...
...
onnx2fluid/onnx2fluid/cmdline.py
浏览文件 @
826481c4
...
@@ -60,19 +60,28 @@ def main(**kwargs):
...
@@ -60,19 +60,28 @@ def main(**kwargs):
# validate
# validate
passed
=
True
passed
=
True
golden_data_filename
=
kwargs
.
pop
(
'test_data'
,
''
)
golden_data_filename
=
kwargs
.
pop
(
'test_data'
,
''
)
if
golden_data_filename
:
infer_inputs
=
kwargs
.
pop
(
'infer_inputs'
,
None
)
if
golden_data_filename
or
infer_inputs
:
from
.validation
import
validate
from
.validation
import
validate
save_inference_model
=
infer_inputs
is
not
None
inference_input_names
=
infer_inputs
.
split
(
','
)
if
infer_inputs
else
None
logger
.
info
(
'starting validation on desc ...'
)
logger
.
info
(
'starting validation on desc ...'
)
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
'__model__'
),
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
'__model__'
),
golden_data_filename
,
**
kwargs
)
golden_data_filename
=
golden_data_filename
,
save_inference_model
=
save_inference_model
,
inference_input_names
=
inference_input_names
,
**
kwargs
)
logger
.
info
(
'starting validation on code ...'
)
logger
.
info
(
'starting validation on code ...'
)
# this re-generate desc proto with Python code when debug on
# this re-generate desc proto with Python code when debug on
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
model_basename
),
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
model_basename
),
golden_data_filename
,
golden_data_filename
=
golden_data_filename
,
model_func_name
=
model_func_name
,
model_func_name
=
model_func_name
,
save_inference_model
=
debug
,
save_inference_model
=
save_inference_model
,
inference_input_names
=
inference_input_names
,
**
kwargs
)
**
kwargs
)
if
not
passed
:
if
not
passed
:
...
...
onnx2fluid/onnx2fluid/conversion.py
浏览文件 @
826481c4
...
@@ -27,8 +27,8 @@ def convert(onnx_model_filename,
...
@@ -27,8 +27,8 @@ def convert(onnx_model_filename,
debug
=
False
,
debug
=
False
,
**
kwargs
):
**
kwargs
):
"""
"""
convert an ONNX model to Paddle fluid Python code and desc pb
convert an ONNX model to Paddle fluid Python code and desc pb
"""
"""
import
onnx
import
onnx
...
@@ -141,23 +141,22 @@ def convert(onnx_model_filename,
...
@@ -141,23 +141,22 @@ def convert(onnx_model_filename,
logger
.
info
(
'%d ops in, %d ops out'
,
len
(
onnx_graph
.
node
),
logger
.
info
(
'%d ops in, %d ops out'
,
len
(
onnx_graph
.
node
),
len
(
fluid_program
.
op_descs
))
len
(
fluid_program
.
op_descs
))
#
shape-
inference
#
type-shape
inference
for
name
,
value_info
in
graph_value_infos
.
items
():
for
name
,
value_info
in
graph_value_infos
.
items
():
var_name
=
make_var_name
(
name
)
var_name
=
make_var_name
(
name
)
fluid_program
.
VarTypeInfo
(
var_name
,
value_info
,
fluid_program
.
VarType
Shape
Info
(
var_name
,
value_info
,
remove_batch
=
False
)
# shape-infer only
remove_batch
=
False
)
# shape-infer only
bad_var_names
=
[]
bad_var_names
=
[]
for
var_name
,
var_desc
in
fluid_program
.
var_descs
.
items
():
for
var_name
,
var_desc
in
fluid_program
.
var_descs
.
items
():
if
not
var_desc
.
type
.
lod_tensor
.
HasField
(
'tensor'
):
if
not
var_desc
.
type
.
lod_tensor
.
HasField
(
'tensor'
):
bad_var_names
.
append
(
var_name
)
bad_var_names
.
append
(
var_name
)
if
len
(
bad_var_names
)
>
0
:
if
len
(
bad_var_names
)
>
0
:
logger
.
warning
(
'type
info
not infered for var %s ...'
,
logger
.
warning
(
'type
-shape
not infered for var %s ...'
,
', '
.
join
(
bad_var_names
[:
5
]))
', '
.
join
(
bad_var_names
[:
5
]))
logger
.
warning
(
'this causes little problem for PaddlePaddle, '
logger
.
warning
(
'this causes little problem for PaddlePaddle, '
'but Paddle Mobile may not infer correctly'
)
'but Paddle Mobile may not infer correctly'
)
logger
.
warning
(
logger
.
warning
(
'please consider running onnx2fluid.validation with -i '
'please consider adding option -d to invoke PaddlePaddle shape-inference'
'to invoke PaddlePaddle type-shape inference'
)
)
# weight writer
# weight writer
for
name
,
weight
in
graph_weights
(
onnx_graph
):
for
name
,
weight
in
graph_weights
(
onnx_graph
):
...
@@ -233,13 +232,9 @@ def convert(onnx_model_filename,
...
@@ -233,13 +232,9 @@ def convert(onnx_model_filename,
logger
.
info
(
'conversion finished'
)
logger
.
info
(
'conversion finished'
)
if
__name__
==
'__main__'
:
def
main
():
del
convert
import
argparse
import
argparse
from
onnx2fluid.conversion
import
convert
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
'onnx2fluid.convert'
,
description
=
'onnx2fluid.convert'
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
...
@@ -310,3 +305,11 @@ if __name__ == '__main__':
...
@@ -310,3 +305,11 @@ if __name__ == '__main__':
onnx_opset_pedantic
=
pedantic
,
onnx_opset_pedantic
=
pedantic
,
onnx_skip_version_conversion
=
skip_version_conversion
,
onnx_skip_version_conversion
=
skip_version_conversion
,
debug
=
debug
)
debug
=
debug
)
if
__name__
==
'__main__'
:
del
convert
from
onnx2fluid.conversion
import
convert
main
()
onnx2fluid/onnx2fluid/onnx_utils.py
浏览文件 @
826481c4
...
@@ -44,8 +44,8 @@ DEFAULT_OP_DOMAIN = 'ai.onnx'
...
@@ -44,8 +44,8 @@ DEFAULT_OP_DOMAIN = 'ai.onnx'
def
print_pb_structure
(
message
,
loop_iterative
=
False
,
depth
=
0
):
def
print_pb_structure
(
message
,
loop_iterative
=
False
,
depth
=
0
):
"""
"""
print pb fields in its structure
print pb fields in its structure
"""
"""
if
hasattr
(
message
,
'DESCRIPTOR'
)
and
hasattr
(
message
.
DESCRIPTOR
,
'fields'
):
if
hasattr
(
message
,
'DESCRIPTOR'
)
and
hasattr
(
message
.
DESCRIPTOR
,
'fields'
):
for
field
in
message
.
DESCRIPTOR
.
fields
:
for
field
in
message
.
DESCRIPTOR
.
fields
:
...
@@ -65,8 +65,8 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
...
@@ -65,8 +65,8 @@ def print_pb_structure(message, loop_iterative=False, depth=0):
def
build_value_refs
(
nodes
):
def
build_value_refs
(
nodes
):
"""
"""
build op reference of inputs and outputs
build op reference of inputs and outputs
"""
"""
input_refs
=
Dict
()
input_refs
=
Dict
()
output_refs
=
Dict
()
output_refs
=
Dict
()
...
@@ -80,8 +80,8 @@ def build_value_refs(nodes):
...
@@ -80,8 +80,8 @@ def build_value_refs(nodes):
def
get_attribute_value2
(
attr
):
def
get_attribute_value2
(
attr
):
"""
"""
get_attribute_value enhanced
get_attribute_value enhanced
"""
"""
if
attr
.
type
==
onnx
.
AttributeProto
.
TENSOR
:
if
attr
.
type
==
onnx
.
AttributeProto
.
TENSOR
:
dtype
=
np
.
dtype
(
TENSOR_TYPE_TO_NP_TYPE
[
attr
.
t
.
data_type
])
dtype
=
np
.
dtype
(
TENSOR_TYPE_TO_NP_TYPE
[
attr
.
t
.
data_type
])
...
@@ -99,24 +99,24 @@ def get_attribute_value2(attr):
...
@@ -99,24 +99,24 @@ def get_attribute_value2(attr):
def
tensor_dtype
(
tensor
):
def
tensor_dtype
(
tensor
):
"""
"""
get ONNX tensor in np.dtype
get ONNX tensor in np.dtype
"""
"""
return
TENSOR_TYPE_TO_NP_TYPE
[
tensor
.
type
.
tensor_type
.
elem_type
]
return
TENSOR_TYPE_TO_NP_TYPE
[
tensor
.
type
.
tensor_type
.
elem_type
]
def
tensor_shape
(
tensor
):
def
tensor_shape
(
tensor
):
"""
"""
get ONNX tensor shape
get ONNX tensor shape
"""
"""
return
[
dim
.
dim_value
for
dim
in
tensor
.
type
.
tensor_type
.
shape
.
dim
]
return
[
dim
.
dim_value
for
dim
in
tensor
.
type
.
tensor_type
.
shape
.
dim
]
def
node_attrs
(
node
):
def
node_attrs
(
node
):
"""
"""
convert ONNX node attributes to dict
convert ONNX node attributes to dict
"""
"""
return
{
attr
.
name
:
get_attribute_value2
(
attr
)
return
{
attr
.
name
:
get_attribute_value2
(
attr
)
for
attr
in
node
.
attribute
}
# dict
for
attr
in
node
.
attribute
}
# dict
...
@@ -124,8 +124,8 @@ def node_attrs(node):
...
@@ -124,8 +124,8 @@ def node_attrs(node):
def
node_topo
(
nodes
,
topo
=
'default'
):
def
node_topo
(
nodes
,
topo
=
'default'
):
"""
"""
build indices with given topology to an ONNX node graph
build indices with given topology to an ONNX node graph
"""
"""
if
topo
==
'default'
:
if
topo
==
'default'
:
return
list
(
range
(
len
(
nodes
)))
return
list
(
range
(
len
(
nodes
)))
...
@@ -192,8 +192,8 @@ def node_topo(nodes, topo='default'):
...
@@ -192,8 +192,8 @@ def node_topo(nodes, topo='default'):
def
node_iter
(
nodes
,
indices
=
None
):
def
node_iter
(
nodes
,
indices
=
None
):
"""
"""
generator for ONNX node graph with given indices
generator for ONNX node graph with given indices
"""
"""
if
indices
is
None
:
if
indices
is
None
:
indices
=
range
(
len
(
nodes
))
indices
=
range
(
len
(
nodes
))
...
@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None):
...
@@ -210,7 +210,7 @@ def node_iter(nodes, indices=None):
if
name
==
''
:
if
name
==
''
:
name
=
'op_'
+
str
(
index
)
name
=
'op_'
+
str
(
index
)
else
:
# make_op_name
else
:
# make_op_name
for
s
in
'
\\
|/:'
:
#
for
s
in
'
\\
|/:
-
'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
domain
==
''
:
if
domain
==
''
:
domain
=
DEFAULT_OP_DOMAIN
domain
=
DEFAULT_OP_DOMAIN
...
@@ -220,8 +220,8 @@ def node_iter(nodes, indices=None):
...
@@ -220,8 +220,8 @@ def node_iter(nodes, indices=None):
def
graph_ops
(
graph
,
topo
=
'default'
):
def
graph_ops
(
graph
,
topo
=
'default'
):
"""
"""
generator for ONNX node graph with given topology
generator for ONNX node graph with given topology
"""
"""
if
not
isinstance
(
graph
,
onnx
.
GraphProto
):
if
not
isinstance
(
graph
,
onnx
.
GraphProto
):
logger
.
error
(
'graph is not a GraphProto instance'
)
logger
.
error
(
'graph is not a GraphProto instance'
)
...
@@ -232,8 +232,8 @@ def graph_ops(graph, topo='default'):
...
@@ -232,8 +232,8 @@ def graph_ops(graph, topo='default'):
def
graph_weights
(
graph
):
def
graph_weights
(
graph
):
"""
"""
generator for weights of an ONNX model
generator for weights of an ONNX model
"""
"""
if
not
isinstance
(
graph
,
onnx
.
GraphProto
):
if
not
isinstance
(
graph
,
onnx
.
GraphProto
):
logger
.
error
(
'graph is not a GraphProto instance'
)
logger
.
error
(
'graph is not a GraphProto instance'
)
...
@@ -247,8 +247,8 @@ def graph_weights(graph):
...
@@ -247,8 +247,8 @@ def graph_weights(graph):
def
inferred_model_value_info
(
model
):
def
inferred_model_value_info
(
model
):
"""
"""
collect value/type info for an ONNX model
collect value/type info for an ONNX model
"""
"""
model
=
infer_shapes
(
model
)
model
=
infer_shapes
(
model
)
graph
=
model
.
graph
graph
=
model
.
graph
...
@@ -278,8 +278,8 @@ def inferred_model_value_info(model):
...
@@ -278,8 +278,8 @@ def inferred_model_value_info(model):
def
skip_node_forward
(
nodes
,
src_output_name
,
dst_input_name
,
input_refs
):
def
skip_node_forward
(
nodes
,
src_output_name
,
dst_input_name
,
input_refs
):
"""
"""
skip nodes between src_output_name -> dst_input_name and connect this pair
skip nodes between src_output_name -> dst_input_name and connect this pair
"""
"""
processed
=
0
processed
=
0
for
next_idx
in
input_refs
[
src_output_name
]:
for
next_idx
in
input_refs
[
src_output_name
]:
...
@@ -293,8 +293,8 @@ def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
...
@@ -293,8 +293,8 @@ def skip_node_forward(nodes, src_output_name, dst_input_name, input_refs):
def
skip_node_backward
(
nodes
,
src_input_name
,
dst_output_name
,
output_refs
):
def
skip_node_backward
(
nodes
,
src_input_name
,
dst_output_name
,
output_refs
):
"""
"""
skip nodes between dst_output_name -> src_input_name and connect this pair
skip nodes between dst_output_name -> src_input_name and connect this pair
"""
"""
processed
=
0
processed
=
0
for
prev_idx
in
output_refs
[
src_input_name
]:
for
prev_idx
in
output_refs
[
src_input_name
]:
...
@@ -308,8 +308,8 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
...
@@ -308,8 +308,8 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
def
optimize_model_skip_op_for_inference
(
model
,
op_list
=
None
):
def
optimize_model_skip_op_for_inference
(
model
,
op_list
=
None
):
"""
"""
skip ops can be bypassed for inference
skip ops can be bypassed for inference
"""
"""
if
op_list
is
None
:
if
op_list
is
None
:
op_list
=
(
'Dropout'
,
'Identity'
)
op_list
=
(
'Dropout'
,
'Identity'
)
...
@@ -369,8 +369,8 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
...
@@ -369,8 +369,8 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
def
optimize_model_strip_initializer
(
model
,
keep_input_only
=
True
):
def
optimize_model_strip_initializer
(
model
,
keep_input_only
=
True
):
"""
"""
strip weights for inference
strip weights for inference
"""
"""
nodes
=
model
.
graph
.
node
nodes
=
model
.
graph
.
node
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
...
@@ -410,8 +410,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
...
@@ -410,8 +410,8 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
def
optimize_model_cast
(
model
):
def
optimize_model_cast
(
model
):
"""
"""
strip cascade and unecessary onnx::Cast-9:
strip cascade and unecessary onnx::Cast-9:
"""
"""
nodes
=
model
.
graph
.
node
nodes
=
model
.
graph
.
node
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
...
@@ -467,8 +467,8 @@ def optimize_model_cast(model):
...
@@ -467,8 +467,8 @@ def optimize_model_cast(model):
def
optimize_model_slice
(
model
):
def
optimize_model_slice
(
model
):
"""
"""
strip cascade and unecessary onnx::Slice-1:9
strip cascade and unecessary onnx::Slice-1:9
"""
"""
nodes
=
model
.
graph
.
node
nodes
=
model
.
graph
.
node
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
input_refs
,
output_refs
=
build_value_refs
(
nodes
)
...
...
onnx2fluid/onnx2fluid/symbolic.py
浏览文件 @
826481c4
...
@@ -41,74 +41,74 @@ DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True
...
@@ -41,74 +41,74 @@ DEFAULT_OP_MAPPING_FIELD_VALUES['FILL_NAME_FIELD'] = True
DEFAULT_OP_MAPPING_VALUES
=
list
(
DEFAULT_OP_MAPPING_FIELD_VALUES
.
values
())
DEFAULT_OP_MAPPING_VALUES
=
list
(
DEFAULT_OP_MAPPING_FIELD_VALUES
.
values
())
DEFAULT_OP_MAPPING
=
{
DEFAULT_OP_MAPPING
=
{
## nil ops ##
## nil ops ##
'RandomUniform'
:
'RandomUniform'
:
[
'uniform_random'
,
[],
[
'Out'
],
dict
(
high
=
'max'
,
low
=
'min'
),
[
'uniform_random'
,
[],
[
'Out'
],
dict
(
high
=
'max'
,
low
=
'min'
),
dict
(),
None
,
None
,
False
],
dict
(),
None
,
None
,
False
],
'RandomNormal'
:
'RandomNormal'
:
[
'gaussian_random'
,
[],
[
'Out'
],
dict
(
scale
=
'std'
),
[
'gaussian_random'
,
[],
[
'Out'
],
dict
(
scale
=
'std'
),
dict
(),
None
,
None
,
False
],
dict
(),
None
,
None
,
False
],
## unary ops ##
## unary ops ##
'Abs'
:
[
'abs'
,
[
'X'
],
[
'Out'
]],
'Abs'
:
[
'abs'
,
[
'X'
],
[
'Out'
]],
'ArgMax'
:
[
'argmax'
,
[
'X'
],
[
'Out'
],
dict
(
keepdims
=
''
)],
'ArgMax'
:
[
'argmax'
,
[
'X'
],
[
'Out'
],
dict
(
keepdims
=
''
)],
'ArgMin'
:
[
'argmin'
,
[
'X'
],
[
'Out'
],
dict
(
keepdims
=
''
)],
'ArgMin'
:
[
'argmin'
,
[
'X'
],
[
'Out'
],
dict
(
keepdims
=
''
)],
'Ceil'
:
[
'ceil'
,
[
'X'
],
[
'Out'
]],
'Ceil'
:
[
'ceil'
,
[
'X'
],
[
'Out'
]],
'Clip'
:
[
'clip'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed
'Clip'
:
[
'clip'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed
'Cos'
:
[
'cos'
,
[
'X'
],
[
'Out'
]],
'Cos'
:
[
'cos'
,
[
'X'
],
[
'Out'
]],
'Elu'
:
[
'elu'
,
[
'X'
],
[
'Out'
]],
'Elu'
:
[
'elu'
,
[
'X'
],
[
'Out'
]],
'Exp'
:
[
'exp'
,
[
'X'
],
[
'Out'
]],
'Exp'
:
[
'exp'
,
[
'X'
],
[
'Out'
]],
'Flatten'
:
[
'flatten'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit flatten2
'Flatten'
:
[
'flatten'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit flatten2
'Floor'
:
[
'floor'
,
[
'X'
],
[
'Out'
]],
'Floor'
:
[
'floor'
,
[
'X'
],
[
'Out'
]],
'Gather'
:
[
'gather'
,
[
'X'
],
[
'Out'
],
dict
(
axis
=
''
)],
'Gather'
:
[
'gather'
,
[
'X'
],
[
'Out'
],
dict
(
axis
=
''
)],
'LeakyRelu'
:
[
'leaky_relu'
,
[
'X'
],
[
'Out'
]],
'LeakyRelu'
:
[
'leaky_relu'
,
[
'X'
],
[
'Out'
]],
'Log'
:
[
'log'
,
[
'X'
],
[
'Out'
]],
'Log'
:
[
'log'
,
[
'X'
],
[
'Out'
]],
'LRN'
:
[
'lrn'
,
[
'X'
],
[
'Out'
,
'MidOut'
],
dict
(
size
=
'n'
,
bias
=
'k'
)],
#
'LRN'
:
[
'lrn'
,
[
'X'
],
[
'Out'
,
'MidOut'
],
dict
(
size
=
'n'
,
bias
=
'k'
)],
#
'Reciprocal'
:
[
'reciprocal'
,
[
'X'
],
[
'Out'
]],
'Reciprocal'
:
[
'reciprocal'
,
[
'X'
],
[
'Out'
]],
'Relu'
:
[
'relu'
,
[
'X'
],
[
'Out'
]],
'Relu'
:
[
'relu'
,
[
'X'
],
[
'Out'
]],
'Selu'
:
[
'selu'
,
[
'X'
],
[
'Out'
],
dict
(
gamma
=
'scale'
)],
'Selu'
:
[
'selu'
,
[
'X'
],
[
'Out'
],
dict
(
gamma
=
'scale'
)],
'Shape'
:
[
'shape'
,
[
'X'
],
[
'Out'
]],
# FIXME: out is int64 vs int32
'Shape'
:
[
'shape'
,
[
'X'
],
[
'Out'
]],
# FIXME: out is int64 vs int32
'Shrink'
:
[
'softshrink'
,
[
'X'
],
[
'Out'
],
dict
(
bias
=
''
,
labmd
=
''
)],
'Shrink'
:
[
'softshrink'
,
[
'X'
],
[
'Out'
],
dict
(
bias
=
''
,
labmd
=
''
)],
'Sigmoid'
:
[
'sigmoid'
,
[
'X'
],
[
'Out'
]],
'Sigmoid'
:
[
'sigmoid'
,
[
'X'
],
[
'Out'
]],
'Sin'
:
[
'sin'
,
[
'X'
],
[
'Out'
]],
'Sin'
:
[
'sin'
,
[
'X'
],
[
'Out'
]],
'Squeeze'
:
[
'squeeze'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit squeeze2
'Squeeze'
:
[
'squeeze'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit squeeze2
'Softplus'
:
[
'softplus'
,
[
'X'
],
[
'Out'
]],
'Softplus'
:
[
'softplus'
,
[
'X'
],
[
'Out'
]],
# FIXME: default axis = -1, reshape required before and after
# FIXME: default axis = -1, reshape required before and after
'Softmax'
:
[
'softmax'
,
[
'X'
],
[
'Out'
],
dict
(
axis
=
''
)],
'Softmax'
:
[
'softmax'
,
[
'X'
],
[
'Out'
],
dict
(
axis
=
''
)],
'Softsign'
:
[
'softsign'
,
[
'X'
],
[
'Out'
]],
'Softsign'
:
[
'softsign'
,
[
'X'
],
[
'Out'
]],
'Sqrt'
:
[
'sqrt'
,
[
'X'
],
[
'Out'
]],
'Sqrt'
:
[
'sqrt'
,
[
'X'
],
[
'Out'
]],
'Tanh'
:
[
'tanh'
,
[
'X'
],
[
'Out'
]],
'Tanh'
:
[
'tanh'
,
[
'X'
],
[
'Out'
]],
'ThresholdedRelu'
:
[
'thresholded_relu'
,
[
'X'
],
[
'Out'
],
dict
(
alpha
=
'threshold'
)],
'ThresholdedRelu'
:
[
'thresholded_relu'
,
[
'X'
],
[
'Out'
],
dict
(
alpha
=
'threshold'
)],
#'Transpose': ['transpose', ['X'], ['Out']],
#'Transpose': ['transpose', ['X'], ['Out']],
'Unsqueeze'
:
[
'unsqueeze'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit unsqueeze2
'Unsqueeze'
:
[
'unsqueeze'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit unsqueeze2
## binary ops ##
## binary ops ##
'Add'
:
[
'elementwise_add'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Add'
:
[
'elementwise_add'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
#'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
#'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
'And'
:
[
'logical_and'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'And'
:
[
'logical_and'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Div'
:
[
'elementwise_div'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Div'
:
[
'elementwise_div'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Equal'
:
[
'equal'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'Equal'
:
[
'equal'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'Greater'
:
[
'less_than'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
[
1
,
0
],
None
,
False
],
'Greater'
:
[
'less_than'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
[
1
,
0
],
None
,
False
],
'Less'
:
[
'less_than'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'Less'
:
[
'less_than'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'MatMul'
:
[
'matmul'
,
[
'X'
,
'Y'
],
[
'Out'
]],
# defaults excluded for transpose_x vs transpose_X
'MatMul'
:
[
'matmul'
,
[
'X'
,
'Y'
],
[
'Out'
]],
# defaults excluded for transpose_x vs transpose_X
'Max'
:
[
'elementwise_max'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Max'
:
[
'elementwise_max'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Min'
:
[
'elementwise_min'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Min'
:
[
'elementwise_min'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Mul'
:
[
'elementwise_mul'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Mul'
:
[
'elementwise_mul'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Not'
:
[
'logical_not'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Not'
:
[
'logical_not'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'OneHot'
:
# assuming values=[0, 1], axis=-1 and drop them
'OneHot'
:
# assuming values=[0, 1], axis=-1 and drop them
[
'one_hot'
,
[
'Input'
,
'Depth'
],
[
'Out'
],
dict
(
axis
=
''
),
dict
(),
[
'one_hot'
,
[
'Input'
,
'Depth'
],
[
'Out'
],
dict
(
axis
=
''
),
dict
(),
[
0
,
1
],
None
,
False
],
[
0
,
1
],
None
,
False
],
'Or'
:
[
'logical_or'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Or'
:
[
'logical_or'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Pow'
:
[
'elementwise_pow'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
# TODO: pow for scalar exponent
'Pow'
:
[
'elementwise_pow'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
# TODO: pow for scalar exponent
'Sub'
:
[
'elementwise_sub'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Sub'
:
[
'elementwise_sub'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Xor'
:
[
'logical_xor'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Xor'
:
[
'logical_xor'
,
[
'X'
,
'Y'
],
[
'Out'
]],
# reduce ops
# reduce ops
'ReduceMax'
:
[
'reduce_max'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceMax'
:
[
'reduce_max'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceMean'
:
[
'reduce_mean'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceMean'
:
[
'reduce_mean'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceMin'
:
[
'reduce_min'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceMin'
:
[
'reduce_min'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceProd'
:
[
'reduce_prod'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceProd'
:
[
'reduce_prod'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceSum'
:
[
'reduce_sum'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceSum'
:
[
'reduce_sum'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
# other ops
# other ops
'Scatter'
:
[
'scatter'
,
[
'X'
,
'Index'
,
'Updates'
],
[
'Out'
]],
'Scatter'
:
[
'scatter'
,
[
'X'
,
'Index'
,
'Updates'
],
[
'Out'
]],
'TopK'
:
[
'topk'
,
[
'X'
,
'K'
],
[
'Out'
,
'Indices'
]],
'TopK'
:
[
'topk'
,
[
'X'
,
'K'
],
[
'Out'
,
'Indices'
]],
}
}
DEFAULT_IOA_CONSTRAINTS
=
{
DEFAULT_IOA_CONSTRAINTS
=
{
...
@@ -146,14 +146,14 @@ DEFAULT_IOA_CONSTRAINTS = {
...
@@ -146,14 +146,14 @@ DEFAULT_IOA_CONSTRAINTS = {
def
_make_var_name
(
name
):
def
_make_var_name
(
name
):
"""
"""
make a valid variable name in Python code and in filesystem
make a valid variable name in Python code and in filesystem
"""
"""
if
name
==
''
:
if
name
==
''
:
return
'_'
return
'_'
if
name
[
0
].
isdigit
():
if
name
[
0
].
isdigit
():
return
'var_'
+
name
return
'var_'
+
name
for
s
in
'
\\
|/:'
:
#
for
s
in
'
\\
|/:
-
'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
name
=
'var'
+
name
...
@@ -191,14 +191,24 @@ def _const_weight_or_none(value_infos, val_name):
...
@@ -191,14 +191,24 @@ def _const_weight_or_none(value_infos, val_name):
return
None
return
None
value_info
=
value_infos
[
val_name
]
value_info
=
value_infos
[
val_name
]
const_value
=
value_info
.
get
(
'const_value'
,
None
)
const_value
=
value_info
.
get
(
'const_value'
,
None
)
if
const_value
:
if
const_value
is
not
None
:
return
const_value
return
const_value
get_weight_func
=
value_info
.
get
(
'get_weight'
,
None
)
get_weight_func
=
value_info
.
get
(
'get_weight'
,
None
)
if
get_weight_func
:
if
get_weight_func
is
not
None
:
return
get_weight_func
()
return
get_weight_func
()
return
None
return
None
def
_check_embeddable
(
value_infos
,
*
val_names
):
keyword
=
'get_weight'
for
val_name
in
val_names
:
if
keyword
not
in
value_infos
[
val_name
]:
_logger
.
warning
(
'parameter %s not embeddable for some ops'
,
val_name
)
return
False
return
True
def
_default
(
prog
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
_default
(
prog
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
info
=
DEFAULT_OP_MAPPING
[
op_type
]
info
=
DEFAULT_OP_MAPPING
[
op_type
]
info
.
extend
(
DEFAULT_OP_MAPPING_VALUES
[
len
(
info
):])
info
.
extend
(
DEFAULT_OP_MAPPING_VALUES
[
len
(
info
):])
...
@@ -391,9 +401,9 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
...
@@ -391,9 +401,9 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
input_shape
=
_shape_or_none
(
value_infos
,
val_x
)
input_shape
=
_shape_or_none
(
value_infos
,
val_x
)
output_shape
=
_shape_or_none
(
value_infos
,
val_y
)
output_shape
=
_shape_or_none
(
value_infos
,
val_y
)
assert
input_shape
is
not
None
or
output_shape
is
not
None
,
'poolnd not inferred'
# NC...
assert
input_shape
is
not
None
or
output_shape
is
not
None
,
'poolnd not inferred'
# NC...
if
input_shape
:
if
input_shape
is
not
None
:
poolnd
=
len
(
input_shape
)
-
2
# NC...
poolnd
=
len
(
input_shape
)
-
2
# NC...
elif
output_shape
:
elif
output_shape
is
not
None
:
poolnd
=
len
(
output_shape
)
-
2
# NC...
poolnd
=
len
(
output_shape
)
-
2
# NC...
assert
2
<=
poolnd
<=
3
,
'only pool2d and pool3d is supported'
assert
2
<=
poolnd
<=
3
,
'only pool2d and pool3d is supported'
...
@@ -568,7 +578,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
...
@@ -568,7 +578,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
1
]
==
1
,
'only scale on (NC)HW supported'
1
]
==
1
,
'only scale on (NC)HW supported'
assert
scales
[
2
]
==
scales
[
assert
scales
[
2
]
==
scales
[
3
],
'only aspect-ratio-invariant scale supported'
3
],
'only aspect-ratio-invariant scale supported'
scale
=
scales
[
2
]
if
scales
else
None
scale
=
None
if
scales
is
None
else
scales
[
2
]
# try input shape
# try input shape
if
scale
is
None
:
if
scale
is
None
:
assert
out_shape_
,
'neither scales nor output shape is available'
assert
out_shape_
,
'neither scales nor output shape is available'
...
@@ -613,24 +623,24 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
...
@@ -613,24 +623,24 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
def
AdaptiveAveragePool
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
AdaptiveAveragePool
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
aten::adaptive_avg_poolnd
aten::adaptive_avg_poolnd
"""
"""
return
_adaptive_pool
(
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
name
=
name
)
return
_adaptive_pool
(
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
name
=
name
)
def
AdaptiveMaxPool
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
AdaptiveMaxPool
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
aten::adaptive_max_poolnd
aten::adaptive_max_poolnd
"""
"""
return
_adaptive_pool
(
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
name
=
name
)
return
_adaptive_pool
(
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
name
=
name
)
def
AffineGrid
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
AffineGrid
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
aten::affine_grid
aten::affine_grid
"""
"""
# I/O
# I/O
val_theta
,
=
inputs
val_theta
,
=
inputs
...
@@ -672,8 +682,8 @@ def AveragePool(prog,
...
@@ -672,8 +682,8 @@ def AveragePool(prog,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::AveragePool-10:
onnx::AveragePool-10:
"""
"""
return
_pool
(
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
return
_pool
(
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
...
@@ -688,8 +698,8 @@ def BatchNormalization(prog,
...
@@ -688,8 +698,8 @@ def BatchNormalization(prog,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::BatchNormalization-9:
onnx::BatchNormalization-9:
"""
"""
# I/O
# I/O
val_x
,
val_scale
,
val_b
,
val_mean
,
val_var
=
inputs
val_x
,
val_scale
,
val_b
,
val_mean
,
val_var
=
inputs
...
@@ -704,16 +714,19 @@ def BatchNormalization(prog,
...
@@ -704,16 +714,19 @@ def BatchNormalization(prog,
momentum
=
attrs
.
get
(
'momentum'
,
.
9
)
# optional
momentum
=
attrs
.
get
(
'momentum'
,
.
9
)
# optional
epsilon
=
attrs
.
get
(
'epsilon'
,
1e-5
)
# optional
epsilon
=
attrs
.
get
(
'epsilon'
,
1e-5
)
# optional
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
val_scale
,
val_b
,
val_mean
,
val_var
)
if
embed_params
:
if
embed_params
:
assert
name
!=
''
assert
name
!=
''
var_scale
=
name
+
'.w_0'
var_scale
=
name
+
'.w_0'
var_b
=
name
+
'.b_0'
var_b
=
name
+
'.b_0'
var_mean
=
name
+
'.w_1'
var_mean
=
name
+
'.w_1'
var_var
=
name
+
'.w_2'
var_var
=
name
+
'.w_2'
value_infos
[
val_scale
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_scale
)
value_infos
[
val_scale
]
[
'embeded_as'
]
.
append
(
var_scale
)
value_infos
[
val_b
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_b
)
value_infos
[
val_b
]
[
'embeded_as'
]
.
append
(
var_b
)
value_infos
[
val_mean
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_mean
)
value_infos
[
val_mean
]
[
'embeded_as'
]
.
append
(
var_mean
)
value_infos
[
val_var
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_var
)
value_infos
[
val_var
]
[
'embeded_as'
]
.
append
(
var_var
)
param_attr
=
''
param_attr
=
''
else
:
else
:
var_scale
=
_make_var_name
(
val_scale
)
var_scale
=
_make_var_name
(
val_scale
)
...
@@ -760,8 +773,8 @@ def BatchNormalization(prog,
...
@@ -760,8 +773,8 @@ def BatchNormalization(prog,
def
Cast
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
def
Cast
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
onnx::Cast-9:
onnx::Cast-9:
"""
"""
# I/O
# I/O
val_input
,
=
inputs
val_input
,
=
inputs
...
@@ -774,7 +787,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -774,7 +787,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
if
not
isinstance
(
dtype
,
_np
.
dtype
):
# additional: possible np.dtype
if
not
isinstance
(
dtype
,
_np
.
dtype
):
# additional: possible np.dtype
dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
dtype
]
dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
dtype
]
output_dtype
=
_dtype_or_none
(
value_infos
,
val_output
)
output_dtype
=
_dtype_or_none
(
value_infos
,
val_output
)
if
output_dtype
:
if
output_dtype
is
not
None
:
assert
dtype
==
output_dtype
,
'dtype of to unmatches output'
assert
dtype
==
output_dtype
,
'dtype of to unmatches output'
fluid_op
=
'cast'
fluid_op
=
'cast'
...
@@ -804,8 +817,8 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -804,8 +817,8 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
def
Concat
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
Concat
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
onnx::Concat-4:
onnx::Concat-4:
"""
"""
# I/O
# I/O
val_concat_result
,
=
outputs
val_concat_result
,
=
outputs
...
@@ -839,11 +852,11 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -839,11 +852,11 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
def
Constant
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
def
Constant
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
onnx::Constant-9:
onnx::Constant-9:
"""
"""
# I/O
# I/O
assert
len
(
inputs
)
==
0
assert
len
(
inputs
)
==
0
,
'constant op accept no inputs'
val_output
,
=
outputs
val_output
,
=
outputs
var_output
=
_make_var_name
(
val_output
)
var_output
=
_make_var_name
(
val_output
)
...
@@ -851,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -851,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
value
=
attrs
[
'value'
]
# required
value
=
attrs
[
'value'
]
# required
dtype
=
_np
.
dtype
(
value
.
dtype
)
dtype
=
_np
.
dtype
(
value
.
dtype
)
output_dtype
=
_dtype_or_none
(
value_infos
,
val_output
)
output_dtype
=
_dtype_or_none
(
value_infos
,
val_output
)
if
output_dtype
:
if
output_dtype
is
not
None
:
assert
dtype
==
output_dtype
,
'tensor dtype unmatches storage dtype'
assert
dtype
==
output_dtype
,
'tensor dtype unmatches storage dtype'
...
@@ -900,8 +913,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -900,8 +913,8 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
def
ConstantOfShape
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
def
ConstantOfShape
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
onnx::ConstantOfShape-9:
onnx::ConstantOfShape-9:
"""
"""
# I/O
# I/O
val_shape
,
=
inputs
val_shape
,
=
inputs
...
@@ -939,8 +952,8 @@ def Conv(prog,
...
@@ -939,8 +952,8 @@ def Conv(prog,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::Conv-1:
onnx::Conv-1:
"""
"""
# I/O
# I/O
val_x
,
val_w
=
inputs
[:
2
]
val_x
,
val_w
=
inputs
[:
2
]
...
@@ -970,13 +983,16 @@ def Conv(prog,
...
@@ -970,13 +983,16 @@ def Conv(prog,
paddings
,
val_x
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
paddings
,
val_x
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
var_x
=
_make_var_name
(
val_x
)
var_x
=
_make_var_name
(
val_x
)
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
embed_params
=
(
_check_embeddable
(
value_infos
,
val_w
)
and
not
has_bias
or
_check_embeddable
(
value_infos
,
val_b
))
if
embed_params
:
if
embed_params
:
assert
name
!=
''
assert
name
!=
''
var_w
=
name
+
'.w_0'
var_w
=
name
+
'.w_0'
value_infos
[
val_w
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_w
)
value_infos
[
val_w
]
[
'embeded_as'
]
.
append
(
var_w
)
if
has_bias
:
if
has_bias
:
var_b
=
name
+
'.b_0'
var_b
=
name
+
'.b_0'
value_infos
[
val_b
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_b
)
value_infos
[
val_b
]
[
'embeded_as'
]
.
append
(
var_b
)
param_attr
=
''
param_attr
=
''
else
:
else
:
param_attr
=
', bias_attr=False'
param_attr
=
', bias_attr=False'
...
@@ -1046,8 +1062,8 @@ def ConvTranspose(prog,
...
@@ -1046,8 +1062,8 @@ def ConvTranspose(prog,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::ConvTranspose-1:
onnx::ConvTranspose-1:
"""
"""
# I/O
# I/O
val_x
,
val_w
=
inputs
[:
2
]
val_x
,
val_w
=
inputs
[:
2
]
...
@@ -1080,13 +1096,16 @@ def ConvTranspose(prog,
...
@@ -1080,13 +1096,16 @@ def ConvTranspose(prog,
paddings
,
val_x
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
paddings
,
val_x
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
var_x
=
_make_var_name
(
val_x
)
var_x
=
_make_var_name
(
val_x
)
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
embed_params
=
(
_check_embeddable
(
value_infos
,
val_w
)
and
not
has_bias
or
_check_embeddable
(
value_infos
,
val_b
))
if
embed_params
:
if
embed_params
:
assert
name
!=
''
assert
name
!=
''
var_w
=
name
+
'.w_0'
var_w
=
name
+
'.w_0'
value_infos
[
val_w
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_w
)
value_infos
[
val_w
]
[
'embeded_as'
]
.
append
(
var_w
)
if
has_bias
:
if
has_bias
:
var_b
=
name
+
'.b_0'
var_b
=
name
+
'.b_0'
value_infos
[
val_b
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_b
)
value_infos
[
val_b
]
[
'embeded_as'
]
.
append
(
var_b
)
param_attr
=
''
param_attr
=
''
else
:
else
:
param_attr
=
', bias_attr=False'
param_attr
=
', bias_attr=False'
...
@@ -1167,8 +1186,8 @@ def ConvTranspose(prog,
...
@@ -1167,8 +1186,8 @@ def ConvTranspose(prog,
def
Gemm
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
def
Gemm
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
"""
onnx::Gemm-9:
onnx::Gemm-9:
"""
"""
# due to fluid fc don't support transposed weight, we use matmul + ew_add
# due to fluid fc don't support transposed weight, we use matmul + ew_add
val_a
,
val_b
,
val_c
=
inputs
val_a
,
val_b
,
val_c
=
inputs
...
@@ -1259,8 +1278,8 @@ def GlobalAveragePool(prog,
...
@@ -1259,8 +1278,8 @@ def GlobalAveragePool(prog,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::GlobalAveragePool-1:
onnx::GlobalAveragePool-1:
"""
"""
return
_global_pool
(
prog
,
return
_global_pool
(
prog
,
'avg'
,
'avg'
,
...
@@ -1280,8 +1299,8 @@ def GlobalMaxPool(prog,
...
@@ -1280,8 +1299,8 @@ def GlobalMaxPool(prog,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::GlobalMaxPool-1:
onnx::GlobalMaxPool-1:
"""
"""
return
_global_pool
(
prog
,
return
_global_pool
(
prog
,
'max'
,
'max'
,
...
@@ -1295,8 +1314,8 @@ def GlobalMaxPool(prog,
...
@@ -1295,8 +1314,8 @@ def GlobalMaxPool(prog,
def
MaxPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
def
MaxPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::MaxPool-10:
onnx::MaxPool-10:
"""
"""
return
_pool
(
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
return
_pool
(
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
...
@@ -1304,16 +1323,16 @@ def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
...
@@ -1304,16 +1323,16 @@ def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
def
MaxRoiPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
def
MaxRoiPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::MaxRoiPool-1:
onnx::MaxRoiPool-1:
"""
"""
_roi_pool
(
prog
,
'roi_pool'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
_roi_pool
(
prog
,
'roi_pool'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
def
Pad
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
def
Pad
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
"""
"""
onnx::Pad-2:
onnx::Pad-2:
"""
"""
# I/O
# I/O
val_data
,
=
inputs
val_data
,
=
inputs
...
@@ -1330,9 +1349,9 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
...
@@ -1330,9 +1349,9 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
assume_pad2d
=
False
assume_pad2d
=
False
if
len
(
pads
)
==
4
:
if
len
(
pads
)
==
4
:
assume_pad2d
|=
mode
!=
'constant'
assume_pad2d
|=
mode
!=
'constant'
if
data_shape
:
if
data_shape
is
not
None
:
assume_pad2d
|=
data_shape
and
len
(
data_shape
)
==
4
# NCHW
assume_pad2d
|=
data_shape
and
len
(
data_shape
)
==
4
# NCHW
if
output_shape
:
if
output_shape
is
not
None
:
assume_pad2d
|=
output_shape
and
len
(
output_shape
)
==
4
# NCHW
assume_pad2d
|=
output_shape
and
len
(
output_shape
)
==
4
# NCHW
od_attrs
=
{
'pad_value'
:
value
}
od_attrs
=
{
'pad_value'
:
value
}
if
assume_pad2d
:
if
assume_pad2d
:
...
@@ -1383,8 +1402,8 @@ def PRelu(prog,
...
@@ -1383,8 +1402,8 @@ def PRelu(prog,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::PRelu-9:
onnx::PRelu-9:
"""
"""
# I/O
# I/O
val_x
,
val_slope
=
inputs
val_x
,
val_slope
=
inputs
...
@@ -1404,10 +1423,12 @@ def PRelu(prog,
...
@@ -1404,10 +1423,12 @@ def PRelu(prog,
mode
=
'element'
mode
=
'element'
fluid_op
=
'prelu'
fluid_op
=
'prelu'
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
val_slope
)
if
embed_params
:
if
embed_params
:
assert
name
!=
''
assert
name
!=
''
var_slope
=
name
+
'.w_0'
var_slope
=
name
+
'.w_0'
value_infos
[
val_slope
]
.
setdefault
(
'embeded_as'
,
[])
.
append
(
var_slope
)
value_infos
[
val_slope
]
[
'embeded_as'
]
.
append
(
var_slope
)
param_attr
=
''
param_attr
=
''
else
:
else
:
var_slope
=
_make_var_name
(
val_slope
)
var_slope
=
_make_var_name
(
val_slope
)
...
@@ -1436,16 +1457,16 @@ def PRelu(prog,
...
@@ -1436,16 +1457,16 @@ def PRelu(prog,
def
PsRoiPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
def
PsRoiPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
"""
caffe2::PsRoiPool
caffe2::PsRoiPool
"""
"""
_roi_pool
(
prog
,
'psroi_pool'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
_roi_pool
(
prog
,
'psroi_pool'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
def
Reshape
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
def
Reshape
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
"""
onnx::Reshape-5:
onnx::Reshape-5:
"""
"""
# I/O
# I/O
val_data
,
val_shape
=
inputs
val_data
,
val_shape
=
inputs
...
@@ -1474,6 +1495,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1474,6 +1495,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
# generation
# generation
val_shape_int32
=
val_shape
+
'_int32'
# explicit variable
var_shape_int32
=
_make_var_name
(
val_shape_int32
)
prog
.
Code
(
'# shape:{}={} # const as literal'
.
format
(
var_shape
,
shape
))
prog
.
Code
(
'# shape:{}={} # const as literal'
.
format
(
var_shape
,
shape
))
if
is_const_shape
:
if
is_const_shape
:
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
...
@@ -1487,8 +1510,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1487,8 +1510,6 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr
,
name_attr
,
))
))
else
:
else
:
val_shape_int32
=
val_shape
+
'_int32'
# explicit variable
var_shape_int32
=
_make_var_name
(
val_shape_int32
)
prog
.
Op
(
prog
.
Op
(
''
,
''
,
'Cast'
,
'Cast'
,
...
@@ -1514,34 +1535,26 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1514,34 +1535,26 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
var_xshape
=
name
+
'.xshape'
# dummy output
var_xshape
=
name
+
'.xshape'
# dummy output
prog
.
VarDesc
(
var_reshaped
)
prog
.
VarDesc
(
var_reshaped
)
prog
.
VarDesc
(
var_xshape
)
prog
.
VarDesc
(
var_xshape
)
if
is_const_shape
:
prog
.
OpDesc
(
prog
.
OpDesc
(
fluid_op
,
fluid_op
,
([
var_data
,
var_shape_int32
],
'X'
,
'Shape'
),
([
var_data
],
'X'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
{
'shape'
:
shape
},
{
'shape'
:
shape
},
)
)
else
:
prog
.
OpDesc
(
fluid_op
,
([
var_data
,
var_shape_int32
],
'X'
,
'Shape'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
{
'shape'
:
shape
},
)
def
Resize
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
def
Resize
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
"""
"""
onnx::Resize-10:
onnx::Resize-10:
"""
"""
return
_interpolate
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
return
_interpolate
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
def
RoiAlign
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
def
RoiAlign
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
"""
caffe2::RoiAlign
caffe2::RoiAlign
"""
"""
_roi_pool
(
prog
,
'roi_align'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
_roi_pool
(
prog
,
'roi_align'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
...
@@ -1580,8 +1593,8 @@ def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
...
@@ -1580,8 +1593,8 @@ def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
def
Slice
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
def
Slice
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
onnx::Slice-1:9
onnx::Slice-1:9
"""
"""
# I/O
# I/O
val_data
,
=
inputs
val_data
,
=
inputs
...
@@ -1595,7 +1608,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -1595,7 +1608,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
starts
=
attrs
[
'starts'
]
# required
starts
=
attrs
[
'starts'
]
# required
ends
=
attrs
[
'ends'
]
# required
ends
=
attrs
[
'ends'
]
# required
shape
=
_shape_or_none
(
value_infos
,
val_data
)
shape
=
_shape_or_none
(
value_infos
,
val_data
)
if
shape
:
if
shape
is
not
None
:
# ndims = len(shape)
# ndims = len(shape)
# for idx, value in enumerate(axes):
# for idx, value in enumerate(axes):
# if value > ONNX_INT_MAX // 2:
# if value > ONNX_INT_MAX // 2:
...
@@ -1639,8 +1652,8 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -1639,8 +1652,8 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
def
Split
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
Split
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
onnx::Split-2:
onnx::Split-2:
"""
"""
# I/O
# I/O
val_input
,
=
inputs
val_input
,
=
inputs
...
@@ -1680,8 +1693,8 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -1680,8 +1693,8 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
def
Sum
(
prog
,
inputs
,
outputs
,
*
args
,
**
kwargs
):
def
Sum
(
prog
,
inputs
,
outputs
,
*
args
,
**
kwargs
):
"""
"""
onnx::Sum-8:
onnx::Sum-8:
"""
"""
# I/O
# I/O
val_sum
,
=
outputs
val_sum
,
=
outputs
...
@@ -1710,8 +1723,8 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
...
@@ -1710,8 +1723,8 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
def
Tile
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
def
Tile
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
"""
"""
onnx::Tile-1:
onnx::Tile-1:
"""
"""
# I/O
# I/O
val_input
,
val_repeats
=
inputs
val_input
,
val_repeats
=
inputs
...
@@ -1749,8 +1762,8 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
...
@@ -1749,8 +1762,8 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
def
Transpose
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
def
Transpose
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
onnx::Transpose-1:
onnx::Transpose-1:
"""
"""
# I/O
# I/O
val_data
,
=
inputs
val_data
,
=
inputs
...
@@ -1795,8 +1808,8 @@ def Upsample(prog,
...
@@ -1795,8 +1808,8 @@ def Upsample(prog,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
onnx::Upsample-9:9
onnx::Upsample-9:9
"""
"""
return
_interpolate
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
return
_interpolate
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
...
...
onnx2fluid/onnx2fluid/torch_export_helper.py
浏览文件 @
826481c4
...
@@ -25,7 +25,8 @@ def ensure_tuple(obj):
...
@@ -25,7 +25,8 @@ def ensure_tuple(obj):
def
flatten_list
(
obj
,
out
=
None
):
def
flatten_list
(
obj
,
out
=
None
):
assert
isinstance
(
obj
,
list
)
assert
isinstance
(
obj
,
list
),
'list type required'
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
for
item
in
obj
:
for
item
in
obj
:
...
@@ -38,11 +39,11 @@ def flatten_list(obj, out=None):
...
@@ -38,11 +39,11 @@ def flatten_list(obj, out=None):
def
export_data
(
state_dict
,
prefix
=
''
):
def
export_data
(
state_dict
,
prefix
=
''
):
"""
"""
export binary data with meta text for raw C++ inference engines
export binary data with meta text for raw C++ inference engines
"""
"""
def
str_
(
obj
):
def
str_
(
obj
):
if
isinstance
(
obj
,
(
tuple
,
list
)):
if
isinstance
(
obj
,
(
tuple
,
list
,
set
)):
return
str
(
obj
)[
1
:
-
1
].
replace
(
' '
,
''
)
return
str
(
obj
)[
1
:
-
1
].
replace
(
' '
,
''
)
return
str
(
obj
)
return
str
(
obj
)
...
@@ -72,8 +73,8 @@ def export_onnx_with_validation(model,
...
@@ -72,8 +73,8 @@ def export_onnx_with_validation(model,
*
args
,
*
args
,
**
kwargs
):
**
kwargs
):
"""
"""
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
"""
is_tuple_or_list
=
lambda
x
:
isinstance
(
x
,
(
tuple
,
list
))
is_tuple_or_list
=
lambda
x
:
isinstance
(
x
,
(
tuple
,
list
))
...
...
onnx2fluid/onnx2fluid/validation.py
浏览文件 @
826481c4
...
@@ -10,14 +10,15 @@ import importlib, logging, os, sys
...
@@ -10,14 +10,15 @@ import importlib, logging, os, sys
def
flatten_dict
(
obj
,
out
=
None
):
def
flatten_dict
(
obj
,
out
=
None
):
assert
isinstance
(
obj
,
dict
)
assert
isinstance
(
obj
,
dict
),
'dict type required'
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
for
key
,
value
in
obj
.
items
():
for
key
,
value
in
obj
.
items
():
if
isinstance
(
value
,
dict
):
if
isinstance
(
value
,
dict
):
flatten_dict
(
value
,
out
)
flatten_dict
(
value
,
out
)
else
:
else
:
assert
key
not
in
out
assert
key
not
in
out
,
'key conflicted'
out
[
key
]
=
value
out
[
key
]
=
value
return
out
return
out
...
@@ -29,15 +30,16 @@ def ensure_list(obj):
...
@@ -29,15 +30,16 @@ def ensure_list(obj):
def
validate
(
fluid_model_filename
,
def
validate
(
fluid_model_filename
,
golden_data_filename
,
golden_data_filename
=
''
,
model_func_name
=
'inference'
,
atol
=
1e-3
,
atol
=
1e-3
,
rtol
=
1e-3
,
rtol
=
1e-3
,
model_func_name
=
'inference'
,
save_inference_model
=
False
,
save_inference_model
=
False
,
inference_input_names
=
None
,
**
kwargs
):
**
kwargs
):
"""
"""
inference the converted Paddle fluid model, validate with given golden data
inference the converted Paddle fluid model, validate with given golden data
"""
"""
import
numpy
as
np
import
numpy
as
np
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
...
@@ -86,24 +88,50 @@ def validate(fluid_model_filename,
...
@@ -86,24 +88,50 @@ def validate(fluid_model_filename,
raise
ValueError
(
'unsupported Paddle fluid model filename'
)
raise
ValueError
(
'unsupported Paddle fluid model filename'
)
# load data
# load data
logger
.
info
(
'using golden data %s'
,
golden_data_filename
)
if
golden_data_filename
:
if
golden_data_filename
.
endswith
(
'.npz'
):
logger
.
info
(
'using golden data %s'
,
golden_data_filename
)
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
)
if
golden_data_filename
.
endswith
(
'.npz'
):
input_data
=
test_data
[
'inputs'
].
tolist
()
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
)
output_data
=
test_data
[
'outputs'
].
tolist
()
input_data
=
test_data
[
'inputs'
].
tolist
()
output_data
=
test_data
[
'outputs'
].
tolist
()
else
:
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
).
tolist
()
input_data
=
test_data
[
'inputs'
]
output_data
=
test_data
[
'outputs'
]
input_data
=
flatten_dict
(
input_data
)
output_data
=
flatten_dict
(
output_data
)
input_names
=
input_data
.
keys
()
logger
.
info
(
'found %d I/O golden data, starting test ...'
,
len
(
input_data
)
+
len
(
output_data
))
else
:
else
:
test_data
=
np
.
load
(
golden_data_filename
,
encoding
=
'bytes'
).
tolist
()
assert
inference_input_names
,
'input names required for type-shape inference'
input_data
=
test_data
[
'inputs'
]
output_data
=
test_data
[
'outputs'
]
input_names
=
inference_input_names
input_data
=
flatten_dict
(
input_data
)
logger
.
info
(
'using input names: %s'
,
', '
.
join
(
input_names
))
output_data
=
flatten_dict
(
output_data
)
logger
.
info
(
'found %d I/O golden data, starting test ...'
,
# type-shape inference and re-save
len
(
input_data
)
+
len
(
output_data
))
if
save_inference_model
:
for
block
in
prog
.
blocks
:
# DEBUG: reload test for Python code
block_desc
=
block
.
desc
if
basename
.
endswith
(
'.py'
)
and
save_inference_model
:
for
idx_op
in
range
(
block_desc
.
op_size
()):
op_desc
=
block_desc
.
op
(
idx_op
)
if
op_desc
.
type
()
in
(
'feed'
,
'fetch'
):
continue
op_desc
.
infer_var_type
(
block_desc
)
op_desc
.
infer_shape
(
block_desc
)
for
var_name
,
var
in
block
.
vars
.
items
():
var_desc
=
var
.
desc
if
var_desc
.
type
()
!=
fluid
.
core
.
VarDesc
.
VarType
.
LOD_TENSOR
:
continue
# WORKAROUND: dirty way to give dtype to partial-infered vars
# which could not be cleared!
try
:
var
.
to_string
(
True
)
except
ValueError
:
var_desc
.
set_dtype
(
fluid
.
core
.
VarDesc
.
VarType
.
FP32
)
fluid
.
io
.
save_inference_model
(
fluid_model_dir
,
fluid
.
io
.
save_inference_model
(
fluid_model_dir
,
input_
data
.
keys
()
,
input_
names
,
var_outs
,
var_outs
,
exe
,
exe
,
main_program
=
prog
,
main_program
=
prog
,
...
@@ -112,8 +140,12 @@ def validate(fluid_model_filename,
...
@@ -112,8 +140,12 @@ def validate(fluid_model_filename,
fluid
.
io
.
load_inference_model
(
fluid_model_dir
,
exe
)
fluid
.
io
.
load_inference_model
(
fluid_model_dir
,
exe
)
logger
.
info
(
'model re-load passed'
)
logger
.
info
(
'model re-load passed'
)
if
not
golden_data_filename
:
return
True
# execute
# execute
outputs
=
exe
.
run
(
prog
,
feed
=
input_data
,
fetch_list
=
out_names
)
outputs
=
exe
.
run
(
prog
,
feed
=
input_data
,
fetch_list
=
out_names
)
# out_names can be vars
logger
.
info
(
'execution passed'
)
logger
.
info
(
'execution passed'
)
# validate
# validate
...
@@ -134,11 +166,10 @@ def validate(fluid_model_filename,
...
@@ -134,11 +166,10 @@ def validate(fluid_model_filename,
logger
.
info
(
'accuracy passed'
)
logger
.
info
(
'accuracy passed'
)
else
:
else
:
logger
.
info
(
'accuracy not passed'
)
logger
.
info
(
'accuracy not passed'
)
return
passed
return
passed
if
__name__
==
'__main__'
:
def
main
()
:
import
argparse
import
argparse
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
...
@@ -160,6 +191,7 @@ if __name__ == '__main__':
...
@@ -160,6 +191,7 @@ if __name__ == '__main__':
'--test_data'
,
'--test_data'
,
'-t'
,
'-t'
,
type
=
str
,
type
=
str
,
default
=
''
,
help
=
'I/O golden data for validation, e.g. test.npy, test.npz'
,
help
=
'I/O golden data for validation, e.g. test.npy, test.npz'
,
)
)
parser
.
add_argument
(
parser
.
add_argument
(
...
@@ -175,19 +207,36 @@ if __name__ == '__main__':
...
@@ -175,19 +207,36 @@ if __name__ == '__main__':
default
=
1e-2
,
default
=
1e-2
,
help
=
'assertion relative tolerance for validation'
,
help
=
'assertion relative tolerance for validation'
,
)
)
parser
.
add_argument
(
'--infer_inputs'
,
'-i'
,
nargs
=
'?'
,
default
=
None
,
const
=
''
,
help
=
'perform type-shape inference with given input names and re-save model'
,
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logging_format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level
=
logging
.
DEBUG
if
args
.
debug
else
logging
.
INFO
logging_level
=
logging
.
DEBUG
if
args
.
debug
else
logging
.
INFO
logging
.
basicConfig
(
format
=
logging_format
,
level
=
logging_level
)
logging
.
basicConfig
(
format
=
logging_format
,
level
=
logging_level
)
debug
=
args
.
debug
#
debug = args.debug
fluid_model_filename
=
args
.
model
[
0
]
fluid_model_filename
=
args
.
model
[
0
]
golden_data_filename
=
args
.
test_data
golden_data_filename
=
args
.
test_data
atol
,
rtol
=
args
.
atol
,
args
.
rtol
atol
,
rtol
=
args
.
atol
,
args
.
rtol
save_inference_model
=
args
.
infer_inputs
is
not
None
inference_input_names
=
args
.
infer_inputs
.
split
(
','
)
if
args
.
infer_inputs
else
None
validate
(
fluid_model_filename
,
validate
(
fluid_model_filename
,
golden_data_filename
,
golden_data_filename
=
golden_data_filename
,
atol
=
atol
,
atol
=
atol
,
rtol
=
rtol
,
rtol
=
rtol
,
save_inference_model
=
debug
)
save_inference_model
=
save_inference_model
,
inference_input_names
=
inference_input_names
)
if
__name__
==
'__main__'
:
main
()
onnx2fluid/onnx2fluid/writer.py
浏览文件 @
826481c4
...
@@ -44,6 +44,8 @@ def irepr(obj, to='_'):
...
@@ -44,6 +44,8 @@ def irepr(obj, to='_'):
def
flatten_list
(
obj
,
out
=
None
):
def
flatten_list
(
obj
,
out
=
None
):
assert
isinstance
(
obj
,
list
),
'list type required'
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
for
item
in
obj
:
for
item
in
obj
:
...
@@ -56,12 +58,12 @@ def flatten_list(obj, out=None):
...
@@ -56,12 +58,12 @@ def flatten_list(obj, out=None):
def
make_attr_name
(
name
):
def
make_attr_name
(
name
):
"""
"""
make a valid code name for ParamAttr
make a valid code name for ParamAttr
"""
"""
if
name
==
''
:
assert
name
!=
''
,
'name should not be empty'
raise
ValueError
(
'name should not be empty'
)
for
s
in
'
\\
|/:'
:
#
for
s
in
'
\\
|/:
-
'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
not
name
.
startswith
(
'_'
):
if
not
name
.
startswith
(
'_'
):
name
=
'_'
+
name
name
=
'_'
+
name
...
@@ -70,8 +72,8 @@ def make_attr_name(name):
...
@@ -70,8 +72,8 @@ def make_attr_name(name):
class
Program
(
object
):
class
Program
(
object
):
"""
"""
fluid Python code and ProgramDesc wrapper
fluid Python code and ProgramDesc wrapper
"""
"""
DTYPE_TO_FRAMEWORK_DTYPE
=
{
DTYPE_TO_FRAMEWORK_DTYPE
=
{
'bool'
:
framework_pb2
.
VarType
.
BOOL
,
'bool'
:
framework_pb2
.
VarType
.
BOOL
,
...
@@ -88,8 +90,8 @@ class Program(object):
...
@@ -88,8 +90,8 @@ class Program(object):
@
staticmethod
@
staticmethod
def
Dtype
(
dtype
):
def
Dtype
(
dtype
):
"""
"""
convert dtype to fulid framework dtype
convert dtype to fulid framework dtype
"""
"""
dtype
=
np
.
dtype
(
dtype
).
name
dtype
=
np
.
dtype
(
dtype
).
name
return
Program
.
DTYPE_TO_FRAMEWORK_DTYPE
[
dtype
]
return
Program
.
DTYPE_TO_FRAMEWORK_DTYPE
[
dtype
]
...
@@ -97,8 +99,8 @@ class Program(object):
...
@@ -97,8 +99,8 @@ class Program(object):
@
staticmethod
@
staticmethod
def
OpDescVars
(
vals
,
*
keys
):
def
OpDescVars
(
vals
,
*
keys
):
"""
"""
make (OpDesc.Var)s
make (OpDesc.Var)s
"""
"""
od_vars
=
[]
od_vars
=
[]
for
idx
,
key
in
enumerate
(
keys
):
for
idx
,
key
in
enumerate
(
keys
):
...
@@ -112,8 +114,8 @@ class Program(object):
...
@@ -112,8 +114,8 @@ class Program(object):
@
staticmethod
@
staticmethod
def
OpDescAttrs
(
attrs
):
def
OpDescAttrs
(
attrs
):
"""
"""
make (OpDesc.Attr)s
make (OpDesc.Attr)s
"""
"""
od_attrs
=
[]
od_attrs
=
[]
for
key
,
value
in
attrs
.
items
():
for
key
,
value
in
attrs
.
items
():
...
@@ -178,8 +180,8 @@ class Program(object):
...
@@ -178,8 +180,8 @@ class Program(object):
def
Code
(
self
,
code
):
def
Code
(
self
,
code
):
"""
"""
add Python code
add Python code
"""
"""
if
self
.
code_mutable
:
if
self
.
code_mutable
:
self
.
codes
.
append
(
code
)
self
.
codes
.
append
(
code
)
...
@@ -190,16 +192,16 @@ class Program(object):
...
@@ -190,16 +192,16 @@ class Program(object):
output_val_keys
=
None
,
output_val_keys
=
None
,
attrs
=
None
):
attrs
=
None
):
"""
"""
add OpDesc
add OpDesc
"""
"""
desc
=
framework_pb2
.
OpDesc
()
desc
=
framework_pb2
.
OpDesc
()
desc
.
type
=
op_type
desc
.
type
=
op_type
if
input_val_keys
is
not
None
:
if
input_val_keys
:
desc
.
inputs
.
extend
(
self
.
OpDescVars
(
*
input_val_keys
))
desc
.
inputs
.
extend
(
self
.
OpDescVars
(
*
input_val_keys
))
if
output_val_keys
is
not
None
:
if
output_val_keys
:
desc
.
outputs
.
extend
(
self
.
OpDescVars
(
*
output_val_keys
))
desc
.
outputs
.
extend
(
self
.
OpDescVars
(
*
output_val_keys
))
if
attrs
is
not
None
:
if
attrs
:
desc
.
attrs
.
extend
(
self
.
OpDescAttrs
(
attrs
))
desc
.
attrs
.
extend
(
self
.
OpDescAttrs
(
attrs
))
self
.
op_descs
.
append
(
desc
)
self
.
op_descs
.
append
(
desc
)
return
desc
return
desc
...
@@ -210,8 +212,8 @@ class Program(object):
...
@@ -210,8 +212,8 @@ class Program(object):
value_info
=
None
,
value_info
=
None
,
remove_batch
=
None
):
remove_batch
=
None
):
"""
"""
add VarDesc,
add VarDesc,
"""
"""
assert
var_name
not
in
self
.
var_descs
,
'var naming conflicted'
assert
var_name
not
in
self
.
var_descs
,
'var naming conflicted'
...
@@ -220,13 +222,16 @@ class Program(object):
...
@@ -220,13 +222,16 @@ class Program(object):
var_desc
.
persistable
=
persistable
var_desc
.
persistable
=
persistable
var_desc
.
type
.
type
=
framework_pb2
.
VarType
.
LOD_TENSOR
var_desc
.
type
.
type
=
framework_pb2
.
VarType
.
LOD_TENSOR
self
.
var_descs
[
var_name
]
=
var_desc
self
.
var_descs
[
var_name
]
=
var_desc
if
value_info
:
if
value_info
:
self
.
VarTypeInfo
(
var_name
,
value_info
,
remove_batch
=
remove_batch
)
self
.
VarTypeShapeInfo
(
var_name
,
value_info
,
remove_batch
=
remove_batch
)
def
Op
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
def
Op
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
"""
"""
convert an ONNX op and add it to program
convert an ONNX op and add it to program
"""
"""
if
domain
!=
''
:
# TODO: symbolic file routing by domain
if
domain
!=
''
:
# TODO: symbolic file routing by domain
raise
ValueError
(
'only default domain supported'
)
raise
ValueError
(
'only default domain supported'
)
...
@@ -242,8 +247,8 @@ class Program(object):
...
@@ -242,8 +247,8 @@ class Program(object):
def
IntermediateOp
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
def
IntermediateOp
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
"""
"""
convert an intermediate ONNX op declaring in desc program only
convert an intermediate ONNX op declaring in desc program only
"""
"""
code_mutable
=
self
.
code_mutable
code_mutable
=
self
.
code_mutable
self
.
code_mutable
=
False
self
.
code_mutable
=
False
...
@@ -255,10 +260,10 @@ class Program(object):
...
@@ -255,10 +260,10 @@ class Program(object):
else
:
else
:
self
.
code_mutable
=
code_mutable
self
.
code_mutable
=
code_mutable
def
VarTypeInfo
(
self
,
var_name
,
value_info
,
remove_batch
=
None
):
def
VarTypeShapeInfo
(
self
,
var_name
,
value_info
,
remove_batch
=
None
):
"""
set value_info for var
"""
"""
set value_info for var
"""
if
var_name
not
in
self
.
var_descs
:
if
var_name
not
in
self
.
var_descs
:
return
return
...
@@ -284,8 +289,8 @@ class Program(object):
...
@@ -284,8 +289,8 @@ class Program(object):
class
Writer
(
object
):
class
Writer
(
object
):
"""
"""
fluid code and desc writter
fluid code and desc writter
"""
"""
# CODE_INDENT = ' ' * 4
# CODE_INDENT = ' ' * 4
CODE_INDENT
=
'
\t
'
CODE_INDENT
=
'
\t
'
...
@@ -293,8 +298,8 @@ class Writer(object):
...
@@ -293,8 +298,8 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
header_code
(
func_name
,
info
=
''
):
def
header_code
(
func_name
,
info
=
''
):
"""
"""
Python header codes
Python header codes
"""
"""
codes
=
[]
codes
=
[]
codes
.
append
(
'"""'
)
codes
.
append
(
'"""'
)
...
@@ -315,8 +320,8 @@ class Writer(object):
...
@@ -315,8 +320,8 @@ class Writer(object):
def
emit_op
(
prog
,
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
def
emit_op
(
prog
,
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
value_infos
,
*
args
,
**
kwargs
):
"""
"""
emit an ONNX op into program
emit an ONNX op into program
"""
"""
prog
.
Code
(
'# {}, {}::{}: {} -> {}, {}'
.
format
(
name
,
domain
,
op_type
,
prog
.
Code
(
'# {}, {}::{}: {} -> {}, {}'
.
format
(
name
,
domain
,
op_type
,
inputs
,
outputs
,
inputs
,
outputs
,
...
@@ -334,8 +339,8 @@ class Writer(object):
...
@@ -334,8 +339,8 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
emit_param
(
prog
,
name
,
value_info
):
def
emit_param
(
prog
,
name
,
value_info
):
"""
"""
emit an ONNX weight into program
emit an ONNX weight into program
"""
"""
if
value_info
.
get
(
'embeded_as'
,
[]):
if
value_info
.
get
(
'embeded_as'
,
[]):
var_names
=
value_info
[
'embeded_as'
]
var_names
=
value_info
[
'embeded_as'
]
...
@@ -359,8 +364,8 @@ class Writer(object):
...
@@ -359,8 +364,8 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
emit_inputs
(
prog
,
names
,
value_infos
,
remove_batch
=
None
):
def
emit_inputs
(
prog
,
names
,
value_infos
,
remove_batch
=
None
):
"""
"""
emit ONNX inputs into program
emit ONNX inputs into program
"""
"""
for
idx
,
name
in
enumerate
(
names
):
for
idx
,
name
in
enumerate
(
names
):
var_name
=
make_var_name
(
name
)
var_name
=
make_var_name
(
name
)
...
@@ -396,8 +401,8 @@ class Writer(object):
...
@@ -396,8 +401,8 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
emit_outputs
(
prog
,
names
):
#, value_infos
def
emit_outputs
(
prog
,
names
):
#, value_infos
"""
"""
emit ONNX outputs into program
emit ONNX outputs into program
"""
"""
code
=
'return '
code
=
'return '
for
idx
,
name
in
enumerate
(
names
):
for
idx
,
name
in
enumerate
(
names
):
...
@@ -416,8 +421,8 @@ class Writer(object):
...
@@ -416,8 +421,8 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
add_codes
(
codes
,
others
,
indent
):
def
add_codes
(
codes
,
others
,
indent
):
"""
"""
flatten codes in program
flatten codes in program
"""
"""
for
code
in
flatten_list
(
others
):
for
code
in
flatten_list
(
others
):
codes
.
append
(
Writer
.
CODE_INDENT
*
indent
+
code
)
codes
.
append
(
Writer
.
CODE_INDENT
*
indent
+
code
)
...
@@ -426,11 +431,10 @@ class Writer(object):
...
@@ -426,11 +431,10 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
write_weight
(
weight
,
filename
):
def
write_weight
(
weight
,
filename
):
"""
"""
write single weight in fluid desc
write single weight in fluid desc
"""
"""
if
not
isinstance
(
weight
,
np
.
ndarray
):
assert
isinstance
(
weight
,
np
.
ndarray
),
'weight is not an ndarray'
raise
TypeError
(
'weight is not an ndarray'
)
tensor_desc
=
framework_pb2
.
VarType
.
TensorDesc
()
tensor_desc
=
framework_pb2
.
VarType
.
TensorDesc
()
tensor_desc
.
data_type
=
Program
.
Dtype
(
weight
.
dtype
)
tensor_desc
.
data_type
=
Program
.
Dtype
(
weight
.
dtype
)
...
@@ -448,12 +452,11 @@ class Writer(object):
...
@@ -448,12 +452,11 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
write_weights
(
weights
,
save_dir
):
def
write_weights
(
weights
,
save_dir
):
"""
"""
write multiple weights in each fluid desc
write multiple weights in each fluid desc
"""
"""
for
name
,
weight
in
weights
.
items
():
for
name
,
weight
in
weights
.
items
():
if
not
isinstance
(
weights
,
dict
):
assert
isinstance
(
weights
,
dict
),
'dict type weights required'
raise
TypeError
(
'dict type weights required'
)
var_name
=
make_var_name
(
name
)
var_name
=
make_var_name
(
name
)
filename
=
os
.
path
.
join
(
save_dir
,
var_name
)
filename
=
os
.
path
.
join
(
save_dir
,
var_name
)
...
@@ -463,8 +466,8 @@ class Writer(object):
...
@@ -463,8 +466,8 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
write_code_file
(
filename
,
header_code
,
*
body_codes
):
def
write_code_file
(
filename
,
header_code
,
*
body_codes
):
"""
"""
write Python code to file
write Python code to file
"""
"""
codes
=
[]
codes
=
[]
Writer
.
add_codes
(
codes
,
header_code
,
0
)
Writer
.
add_codes
(
codes
,
header_code
,
0
)
...
@@ -481,8 +484,8 @@ class Writer(object):
...
@@ -481,8 +484,8 @@ class Writer(object):
@
staticmethod
@
staticmethod
def
write_desc_file
(
filename
,
op_descs
,
var_descs
):
def
write_desc_file
(
filename
,
op_descs
,
var_descs
):
"""
"""
write desc program to file
write desc program to file
"""
"""
prog_desc
=
framework_pb2
.
ProgramDesc
()
prog_desc
=
framework_pb2
.
ProgramDesc
()
block_desc
=
prog_desc
.
blocks
.
add
()
block_desc
=
prog_desc
.
blocks
.
add
()
...
...
onnx2fluid/setup.cfg
浏览文件 @
826481c4
...
@@ -54,8 +54,8 @@ zip_safe = True
...
@@ -54,8 +54,8 @@ zip_safe = True
[options.entry_points]
[options.entry_points]
console_scripts =
console_scripts =
onnx2fluid = onnx2fluid.__main__
onnx2fluid = onnx2fluid.__main__
onnx2fluid_convert = onnx2fluid.conversion
onnx2fluid_convert = onnx2fluid.conversion
:main
onnx2fluid_validate = onnx2fluid.validation
onnx2fluid_validate = onnx2fluid.validation
:main
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 仅支持文件,不支持目录,但可以使用通配
# 仅支持文件,不支持目录,但可以使用通配
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录