PaddlePaddle / X2Paddle
Commit 324b75ee
Authored on Sep 07, 2019 by channingss
Parent: b6e359f1

    fix bug & support new op for ssd
Showing 7 changed files with 282 additions and 168 deletions (+282 −168)
x2paddle/convert.py                                             +3    −0
x2paddle/decoder/onnx_decoder.py                                +105  −78
x2paddle/op_mapper/onnx_custom_layer/InstanceNormalization.py   +6    −4
x2paddle/op_mapper/onnx_custom_layer/__init__.py                +12   −0
x2paddle/op_mapper/onnx_custom_layer/register.py                +2    −1
x2paddle/op_mapper/onnx_directly_map.py                         +10   −3
x2paddle/op_mapper/onnx_op_mapper.py                            +144  −82
x2paddle/convert.py  (view file @ 324b75ee)
...
...
@@ -89,6 +89,9 @@ def tf2paddle(model_path, save_dir):
     mapper.save_inference_model(save_dir)


 def caffe2paddle(proto, weight, save_dir, caffe_proto):
     from x2paddle.decoder.caffe_decoder import CaffeDecoder
     from x2paddle.op_mapper.caffe_op_mapper import CaffeOpMapper
...
...
x2paddle/decoder/onnx_decoder.py  (view file @ 324b75ee)
...
...
@@ -17,7 +17,6 @@ from x2paddle.core.fluid_code import FluidCode
 from onnx.checker import ValidationError
 from onnx.checker import check_model
 from onnx.utils import polish_model
 from onnx.version_converter import convert_version
 from onnx import helper
 from onnx.helper import get_attribute_value, make_attribute
 from onnx.shape_inference import infer_shapes
...
...
@@ -26,6 +25,7 @@ from onnx.numpy_helper import to_array
 from onnx import AttributeProto, TensorProto, GraphProto
 from collections import OrderedDict as Dict
 import onnx
+from onnx.helper import ValueInfoProto
 import numpy as np
 from copy import deepcopy
 import logging as _logging
...
...
@@ -47,6 +47,7 @@ class ONNXGraphNode(GraphNode):
         self.weight_inputs = list()
         self.out_shapes = list()
         self.dtype = None
+        self.which_child = {}

     def get_attr_map(self):
         """
...
...
@@ -60,10 +61,9 @@ class ONNXGraphNode(GraphNode):
     @property
     def value(self):
         assert 'Constant' in self.layer_type, "Only Constant | ConstantOfShape node has value."
-        attr = self.layer.attribute['value']
         if 'value' not in self.attr_map:
             return None
-        return self.attr_map[name]
+        return self.attr_map['value']

     def get_attribute_value2(self, attr):
         """
...
...
@@ -105,18 +105,29 @@ class ONNXGraphDataNode(GraphNode):
         self.fluid_code = FluidCode()
         self.weight = None
         self.embeded_as = None
+        self.which_child = {}

     @property
     def out_shapes(self):
+        if isinstance(self.layer, ValueInfoProto):
+            values = self.layer.type.tensor_type.shape.dim
+            out_shapes = list()
+            out_shapes.append([dim.dim_value for dim in values])
+            return out_shapes
+        else:
             values = self.layer.dims
             out_shapes = list()
             out_shapes.append(values)
             return out_shapes

     @property
     def dtype(self):
+        if isinstance(self.layer, ValueInfoProto):
+            dtype = self.layer.type.tensor_type.elem_type
+            return TENSOR_TYPE_TO_NP_TYPE[dtype]
+        else:
             dtype = self.layer.data_type
             return TENSOR_TYPE_TO_NP_TYPE[dtype]


 class ONNXGraph(Graph):
...
...
@@ -165,16 +176,21 @@ class ONNXGraph(Graph):
         """
         build topo_sort of ONNX model
         """
-        data_node = self.place_holder_nodes[0]
-        value_info = self.value_infos[data_node]
-        input_shape = value_info['shape']
-        self.get_results_of_inference(self.onnx_model, input_shape)
+        data_nodes = self.place_holder_nodes
+        self.get_results_of_inference_rt(self.onnx_model, data_nodes)
         for layer in self.model.node:
             node = ONNXGraphNode(layer)
             self.node_map[layer.name] = node
             for opt in layer.output:
                 if opt in self.value_infos:
                     value_info = self.value_infos[opt]
+                    if len(value_info['shape']) == 0 or value_info['dtype'] is None:
+                        _, dtype, shape = self.get_dynamic_shape(opt)
+                        node.dtype = dtype
+                        node.out_shapes.append(shape)
+                    else:
                         node.dtype = value_info['dtype']
                         node.out_shapes.append(value_info['shape'])
                 else:
...
...
@@ -191,20 +207,40 @@ class ONNXGraph(Graph):
                                            is_global_input=is_place_holder)
         #set data node's weight
-        for name, weight in self.graph_weights(self.model):
+        for initializer in self.model.initializer:
+            name = initializer.name
+            weight = to_array(initializer)
             if name in self.node_map:
                 if isinstance(self.node_map[name], ONNXGraphDataNode):
                     self.node_map[name].weight = weight
                     self.node_map[name].embeded_as = []
+            else:
+                self.node_map[name] = ONNXGraphDataNode(initializer,
+                                                        layer_name=name,
+                                                        is_global_input=False)
+                self.node_map[name].weight = weight
+                self.node_map[name].embeded_as = []

         #generate connection between nodes for topo
         for layer_name, node in self.node_map.items():
             if isinstance(node, ONNXGraphNode):
                 for idx, in_node in enumerate(node.layer.input):
                     if in_node not in self.node_map:
-                        raise Exception(
-                            'input[{}] of node[{}] does not exist in node_map'.
-                            format(in_node, layer_name))
+                        flag = 0
+                        for nd in self.model.node:
+                            for idx, opt in enumerate(nd.output):
+                                if opt == in_node:
+                                    self.connect(nd.name, layer_name)
+                                    flag = 1
+                                    print(nd.name + '->' + layer_name)
+                                    node.which_child[nd.name] = idx
+                                    break
+                            if flag == 1:
+                                break
+                        if flag == 0:
+                            raise Exception(
+                                'input[{}] of node[{}] does not exist in node_map'.
+                                format(in_node, layer_name))
                     else:
                         self.connect(in_node, layer_name)
         #generate topo
...
...
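Note: the weight-loading loop above now walks model.graph.initializer directly instead of going through the graph_weights generator, decoding each TensorProto with onnx.numpy_helper.to_array and creating a data node for initializers that never appeared as graph inputs. A minimal standalone sketch of that initializer-to-numpy mapping (not part of this commit; the 'model.onnx' path is illustrative):

import onnx
from onnx.numpy_helper import to_array

model = onnx.load('model.onnx')
weights = {}
for initializer in model.graph.initializer:
    # Each initializer is a TensorProto; to_array decodes it to a numpy array.
    weights[initializer.name] = to_array(initializer)

print({name: w.shape for name, w in weights.items()})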
@@ -212,13 +248,14 @@ class ONNXGraph(Graph):
         self.input_nodes = self.place_holder_nodes

     def get_nodes(self, names, copy=False):
         """
         get nodes by more than one name
         """
         nodes = []
         for name in names:
             nodes.add(self.get_node(name, copy=copy))

+    def get_input_node(self, node, idx=0, copy=False):
+        if len(node.which_child) == 0:
+            return super(ONNXGraph, self).get_node(node.inputs[idx], copy)
+        else:
+            ipt_node = super(ONNXGraph, self).get_node(node.inputs[idx], copy)
+            if ipt_node.layer_name in node.which_child:
+                ipt_node.index = node.which_child[ipt_node.layer_name]
+            return ipt_node

     def graph_weights(self, graph):
         """
...
...
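Note: get_input_node is the new resolution helper used throughout onnx_op_mapper.py below; the which_child map matters for producers with more than one output, where a consumer must know which output slot it reads. A toy illustration of the problem (not part of this commit):

import numpy as np

def topk(x, k):
    idx = np.argsort(-x)[:k]
    return x[idx], idx          # two outputs, like ONNX TopK

values, indices = topk(np.array([3., 1., 4., 1., 5.]), k=2)
# A consumer attached to output slot 1 must read `indices`, not `values`;
# ONNXGraph.which_child records exactly that slot per producer node.
print(values, indices)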
@@ -270,7 +307,7 @@ class ONNXGraph(Graph):
         }
         return value_info

-    def get_results_of_inference(self, model, shape):
+    def get_results_of_inference(self, model, data_nodes):
         try:
             import torch
             version = torch.__version__
...
...
@@ -284,9 +321,11 @@ class ONNXGraph(Graph):
                 return
         from x2paddle.decoder.onnx_backend import prepare

-        np_images = np.random.rand(shape[0], shape[1], shape[2],
-                                   shape[3]).astype('float32')
+        inputs = []
+        for data_node in data_nodes:
+            value_info = self.value_infos[data_node]
+            ipt = np.random.random(value_info['shape']).astype('float32')
+            inputs.append(ipt)

         outputs = []
         for node in model.graph.node:
             value_info = helper.make_tensor_value_info(node.name,
...
...
@@ -301,15 +340,46 @@ class ONNXGraph(Graph):
         prepared_backend = prepare(model, device='CPU', no_check_UNSAFE=True)
-        res = prepared_backend.run(inputs=np_images)
+        res = prepared_backend.run(inputs=inputs)
         for idx, info in enumerate(tmp_outputs):
             self.results_of_inference[info.name] = res[idx]
-        outputs = outputs[254:]
         return

+    def get_results_of_inference_rt(self, model, data_nodes):
+        import onnxruntime as rt
+
+        inputs = []
+        for data_node in data_nodes:
+            value_info = self.value_infos[data_node]
+            ipt = np.random.random(value_info['shape']).astype('float32')
+            inputs.append(ipt)
+
+        model = onnx.shape_inference.infer_shapes(model)
+        outputs = []
+        for value_info in model.graph.value_info:
+            outputs.append(value_info)
+
+        model.graph.ClearField('output')
+        model.graph.output.MergeFrom(outputs)
+        onnx.save(model, './onnx_model_infer.onnx')
+        sess = rt.InferenceSession('./onnx_model_infer.onnx')
+
+        inputs_dict = {}
+        for i, ipt in enumerate(inputs):
+            inputs_dict[sess.get_inputs()[i].name] = ipt
+        res = sess.run(None, input_feed=inputs_dict)
+        for idx, info in enumerate(outputs):
+            self.results_of_inference[info.name] = res[idx]
+        return

     def get_dynamic_shape(self, layer):
         """
-        get dynamic shape from caffe2.backend
+        get dynamic shape from infer_result
         """
         output = self.results_of_inference[layer]
         return output.tolist(), output.dtype, output.shape
...
...
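Note on the new shape-inference path: get_results_of_inference_rt promotes every tensor in model.graph.value_info to a graph output, saves the patched model, and runs it once through onnxruntime on random inputs, so get_dynamic_shape can later look up the concrete shape and dtype of any intermediate tensor. A condensed standalone sketch of the same trick (not part of this commit; assumes a 'model.onnx' whose inputs are all float32, paths and names illustrative):

import onnx
import numpy as np
import onnxruntime as rt

model = onnx.shape_inference.infer_shapes(onnx.load('model.onnx'))
# Promote every intermediate value_info to a graph output, so a single run
# of the model returns every intermediate tensor.
model.graph.output.extend(list(model.graph.value_info))
onnx.save(model, 'model_infer.onnx')

sess = rt.InferenceSession('model_infer.onnx')
feed = {}
for ipt in sess.get_inputs():
    dims = [d if isinstance(d, int) else 1 for d in ipt.shape]
    feed[ipt.name] = np.random.random(dims).astype('float32')
res = sess.run(None, input_feed=feed)

# sess.run(None, ...) returns one array per graph output, in order; mapping
# them back to names yields concrete shapes even where static inference failed.
names = [out.name for out in sess.get_outputs()]
shapes = dict(zip(names, (r.shape for r in res)))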
@@ -334,8 +404,8 @@ class ONNXDecoder(object):
         self.standardize_variable_name(model.graph)
         self.model = model
-        graph_def = model.graph
-        self.onnx_graph = ONNXGraph(graph_def, model)
+        graph = model.graph
+        self.onnx_graph = ONNXGraph(graph, model)
         self.onnx_graph.build()

     def build_value_refs(self, nodes):
...
...
@@ -476,7 +546,7 @@ class ONNXDecoder(object):
         if name == '':
             raise ValueError('name should not be empty')
-        for s in ' .*?\\/-:':
+        for s in ' .*?\\/-:':  #
             name = name.replace(s, '_')
         return '_' + name
...
...
@@ -499,46 +569,3 @@ class ONNXDecoder(object):
                 node.input[i] = self.make_variable_name(node.input[i])
             for i in range(len(node.output)):
                 node.output[i] = self.make_variable_name(node.output[i])
-
-    def split_model(self, model, outputs=None):
-        """
-        Takes a model and changes its outputs.
-        """
-        if outputs is None:
-            raise RuntimeError("outputs is None")
-        if outputs == model.graph.output[0].name:
-            return model
-        nodes = model.graph.node
-        keep_nodes = []
-
-        # all the nodes we need to keep.
-        for node in nodes:
-            if outputs in node.output:
-                keep_nodes.append(node)
-                break
-            keep_nodes.append(node)
-
-        infer_shapes = onnx.shape_inference.infer_shapes(model)
-        var_out = []
-        for value_info in infer_shapes.graph.value_info:
-            if value_info.name == outputs:
-                var_out.append(value_info)
-                break
-        graph = helper.make_graph(keep_nodes, model.graph.name,
-                                  model.graph.input, var_out,
-                                  model.graph.initializer)
-        onnx_model = helper.make_model(graph)
-        onnx_model.ir_version = model.ir_version
-        onnx_model.producer_name = model.producer_name
-        onnx_model.producer_version = model.producer_version
-        onnx_model.domain = model.domain
-        onnx_model.model_version = model.model_version
-        onnx_model.doc_string = model.doc_string
-
-        if len(onnx_model.graph.input) != len(model.graph.input):
-            raise RuntimeError("Input mismatch {} != {}".format(
-                len(onnx_model.input), len(model.input)))
-        return onnx_model
x2paddle/op_mapper/onnx_custom_layer/InstanceNormalization.py  (view file @ 324b75ee)
...
...
@@ -22,7 +22,8 @@ def InstanceNormalization_shape(input_shape):
 def InstanceNormalization_layer(inputs, name=None):
     # TODO(lvmengsi@baidu.com): Check the accuracy when using fluid.layers.layer_norm.
     epsilon = 1e-5
-    mean = fluid.layers.reduce_mean(inputs, dim=[2, 3], keep_dim=True)
+    input_ = inputs[0]
+    mean = fluid.layers.reduce_mean(input_, dim=[2, 3], keep_dim=True)
     var = fluid.layers.reduce_mean(fluid.layers.square(inputs - mean),
                                    dim=[2, 3],
                                    keep_dim=True)
...
...
@@ -36,13 +37,13 @@ def InstanceNormalization_layer(inputs, name=None):
         initializer=fluid.initializer.Constant(0.0), trainable=True)
     scale = fluid.layers.create_parameter(attr=scale_param,
-                                          shape=inputs.shape[1:2],
+                                          shape=input_.shape[1:2],
                                           dtype="float32")
     offset = fluid.layers.create_parameter(attr=offset_param,
-                                           shape=inputs.shape[1:2],
+                                           shape=input_.shape[1:2],
                                            dtype="float32")
-    tmp = fluid.layers.elementwise_mul(x=(inputs - mean), y=scale, axis=1)
+    tmp = fluid.layers.elementwise_mul(x=(input_ - mean), y=scale, axis=1)
     tmp = tmp / fluid.layers.sqrt(var + epsilon)
     tmp = fluid.layers.elementwise_add(tmp, offset, axis=1)
     return tmp
...
...
@@ -56,4 +57,5 @@ def InstanceNormalization_weights(name, data=None):
 register(kind='InstanceNormalization',
          shape=InstanceNormalization_shape,
          layer=InstanceNormalization_layer,
+         child_func=None,
          weights=InstanceNormalization_weights)
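Note: the layer above computes y = scale * (x - mean) / sqrt(var + epsilon) + offset, with mean and var taken per sample and per channel over the spatial dims. A numpy reference of that formula for an NCHW tensor (not part of this commit):

import numpy as np

def instance_norm(x, scale, offset, epsilon=1e-5):
    # Per-sample, per-channel statistics over the spatial dims (H, W).
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = ((x - mean) ** 2).mean(axis=(2, 3), keepdims=True)
    # scale/offset broadcast over the channel axis, matching axis=1 above.
    return (scale[None, :, None, None] * (x - mean) / np.sqrt(var + epsilon)
            + offset[None, :, None, None])

x = np.random.rand(2, 3, 4, 4).astype('float32')
y = instance_norm(x, scale=np.ones(3, 'float32'), offset=np.zeros(3, 'float32'))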
x2paddle/op_mapper/onnx_custom_layer/__init__.py  (view file @ 324b75ee)
...
...
@@ -16,6 +16,7 @@ from .register import get_registered_layers
 #custom layer import begins
 from . import InstanceNormalization
+from . import NonMaxSuppression
 #custom layer import ends

 custom_layers = get_registered_layers()
...
...
@@ -95,6 +96,17 @@ def make_custom_layer(node):
     return inspect.getsource(layer_func), layer_func


+def make_custom_child_func(node):
+    """ get the code which implement the custom layer function
+    """
+    layer_type = node.layer_type
+    assert layer_type in custom_layers, "layer[%s] not exist in custom layers" % (
+        layer_type)
+    child_func = custom_layers[layer_type]['child_func']
+    import inspect
+    return inspect.getsource(child_func), child_func
+
+
 def deal_weights(node, data=None):
     """ deal the weights of the custom layer
     """
...
...
x2paddle/op_mapper/onnx_custom_layer/register.py  (view file @ 324b75ee)
...
...
@@ -17,7 +17,7 @@
 g_custom_layers = {}


-def register(kind, shape, layer, weights):
+def register(kind, shape, layer, child_func, weights):
     """ register a custom layer or a list of custom layers

     Args:
...
...
@@ -48,6 +48,7 @@ def register(kind, shape, layer, weights):
         g_custom_layers[k] = {
             'shape': shape,
             'layer': layer,
+            'child_func': child_func,
             'weights': weights
         }
...
...
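Note: every registry entry now carries a child_func slot (None for InstanceNormalization above). A hypothetical registration showing the new five-argument signature (not part of this commit; 'MyOp' and its helpers are invented for illustration):

def MyOp_shape(input_shape):
    return input_shape

def MyOp_layer(inputs, name=None):
    return inputs[0]

def MyOp_weights(name, data=None):
    return []

register(kind='MyOp',
         shape=MyOp_shape,
         layer=MyOp_layer,
         child_func=None,   # or a helper whose source is emitted alongside
         weights=MyOp_weights)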
x2paddle/op_mapper/onnx_directly_map.py  (view file @ 324b75ee)
...
...
@@ -32,6 +32,9 @@ default_op_mapping = {
     'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
+    'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
     'Clip': ['clip', ['X'], ['Out'], dict(),
...
...
@@ -42,6 +45,7 @@ default_op_mapping = {
                  dtype=_np.uint8).view(_np.float32)),
         )
     ],
+    'Ceil': ['ceil', ['X'], ['Out']],
     'ReduceMean': ['reduce_mean', ['X'], ['Out'],
                    dict(axes='dim', keepdims='keep_dim'),
...
...
@@ -52,7 +56,11 @@ default_op_mapping = {
                   dict(axes='dim', keepdims='keep_dim'),
                   dict(keep_dim=1)
                   ],
+    'ReduceMin': ['reduce_min', ['X'], ['Out'],
+                  dict(axes='dim', keepdims='keep_dim'),
+                  dict(keep_dim=1)
+                  ],
     #active function
     'Relu': ['relu', ['X'], ['Out']],
     'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
...
...
@@ -78,8 +86,7 @@ default_op_mapping = {
     'Softplus': ['softplus', ['X'], ['Out']],
     'Exp': ['exp', ['X'], ['Out']],
-    'Softmax': ['softmax', ['X'], ['Out'], dict(axis=''), dict(axis=1)],
+    'Softmax': ['softmax', ['X'], ['Out'], dict(), dict(axis=1)],
 }

 activefunc_op_mapping = {
...
...
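Note: each default_op_mapping entry reads [fluid_op, input slots, output slots, ONNX-to-fluid attribute renames, fluid-side attribute defaults], with the last two optional. A sketch of how such a row is consumed, using the new 'ReduceMin' entry (not part of this commit; field names here are illustrative, directly_map() in onnx_op_mapper.py is the real consumer):

entry = ['reduce_min', ['X'], ['Out'],
         dict(axes='dim', keepdims='keep_dim'),  # ONNX attr -> fluid attr name
         dict(keep_dim=1)]                       # fluid-side default attrs

fluid_op, input_slots, output_slots = entry[0], entry[1], entry[2]
attr_renames = entry[3] if len(entry) > 3 else {}
default_attrs = entry[4] if len(entry) > 4 else {}

# An ONNX ReduceMin(axes=[1], keepdims=0) thus maps to
# fluid.layers.reduce_min(x, dim=[1], keep_dim=0).
print(fluid_op, attr_renames, default_attrs)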
x2paddle/op_mapper/onnx_op_mapper.py  (view file @ 324b75ee)
...
...
@@ -24,15 +24,17 @@ from x2paddle.op_mapper.onnx_custom_layer import *
 from x2paddle.core.util import string
 import numpy as np
 import onnx.numpy_helper as numpy_helper
 from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
 import logging as _logging
 from collections import OrderedDict as _dict
 import math

 _logger = _logging.getLogger(__name__)


 def _const_weight_or_none(node):
     if 'Constant' in node.layer_name:
-        return val.value
+        return node.value
     if isinstance(node, ONNXGraphDataNode):
         return node.weight
     return None
...
...
@@ -94,7 +96,7 @@ class ONNXOpMapper(OpMapper):
             print(op)
         return False

-    def directly_map(self, node, *args, name='', **kwargs):
+    def directly_map(self, node, name='', *args, **kwargs):
         inputs = node.layer.input
         outputs = node.layer.output
         op_type = node.layer_type
...
...
@@ -127,34 +129,38 @@ class ONNXOpMapper(OpMapper):
             mapped_attrs.pop('_')
         fluid_attrs = default_attrs.copy()
         fluid_attrs.update(mapped_attrs)
-        val_inps = inputs if input_perm is None else list(
+        inputs = inputs if input_perm is None else list(
             map(lambda i: inputs[i], input_perm))
+        val_inps = []
+        for idx, ipt in enumerate(inputs):
+            val_inps.append(self.graph.get_input_node(node, idx=idx, copy=True))
         val_outs = outputs if output_perm is None else list(
             map(lambda i: outputs[i], output_perm))
         attr = fluid_attrs
         if fluid_op not in ['shape', 'gather']:
             attr['name'] = string(node.layer_name)
         node.fluid_code.add_layer(fluid_op,
-                                  inputs=', '.join(val_inps),
+                                  inputs=val_inps,
                                   output=val_outs[0],
                                   param_attr=attr)

     def deal_custom_layer(self, node):
         op = node.layer_type
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
         custom_code, func = make_custom_layer(node)
+        child_func_code, child_func = make_custom_child_func(node)
         params = get_params(node.layer, node.layer_type)
         arg_names, kwargs = set_args(func, params)
         kwargs['name'] = string(node.layer_name)
-        inputs_node = []
-        inputs_node.append(node.inputs[0])
         node.fluid_code.add_layer(func.__code__.co_name,
-                                  inputs=inputs_node[0],
+                                  inputs=node.inputs,
                                   output=node,
                                   param_attr=kwargs,
                                   is_custom_layer=True)
         if op not in self.used_custom_layers:
             self.used_custom_layers[op] = custom_code
+        if op + '_child_func' not in self.used_custom_layers:
+            self.used_custom_layers[op + '_child_func'] = child_func_code

     def place_holder(self, node):
         self.input_shapes.append(node.out_shapes[0])
...
...
@@ -203,8 +209,8 @@ class ONNXOpMapper(OpMapper):
         return [0] * ndims, val_padded

     def _interpolate(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_scales = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_scales = self.graph.get_input_node(node, idx=1, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)

         out_shape_ = val_y.out_shapes[0]
...
...
@@ -245,7 +251,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Pad(self, node, op_independent=True):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         pads = node.get_attr('pads')
         mode = node.get_attr('mode', 'constant')
         value = node.get_attr('value', 0.)
...
...
@@ -291,8 +297,18 @@ class ONNXOpMapper(OpMapper):
                                       param_attr=attr)
             return node.layer_name + '_paded'

+    def TopK(self, node):
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        axes = node.get_attr('axes')
+        k = 10
+        attr = {'k': k, 'name': string(node.layer_name)}
+        node.fluid_code.add_layer('topk',
+                                  inputs=val_x,
+                                  output=node,
+                                  param_attr=attr)
+
     def Unsqueeze(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         axes = node.get_attr('axes')
         attr = {'axes': axes, 'name': string(node.layer_name)}
         node.fluid_code.add_layer('unsqueeze',
...
...
@@ -301,7 +317,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Shrink(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         bias = node.get_attr('bias')
         lambd = node.get_attr('lambd')
         assert bias == 0.0, 'not support bias!=0'
...
...
@@ -358,8 +374,8 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Resize(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_scales = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_scales = self.graph.get_input_node(node, idx=1, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)

         out_shape_ = val_y.out_shapes[0]
...
...
@@ -401,24 +417,66 @@ class ONNXOpMapper(OpMapper):
     def Upsample(self, node):
         self._interpolate(node)

     def Gather(self, node):
         val_x = self.graph.get_input_node(node, idx=0, copy=True)
         indices = self.graph.get_input_node(node, idx=1, copy=True)
         indices_shape = indices.out_shapes[0]
         axis = node.get_attr('axis')
+        print(indices.layer_name)
+        print(indices_shape)
         assert len(indices_shape) == 1, "Gather op don't support dim of indice >1 "
+        if axis == 0 and len(indices_shape) == 1:
+            node.fluid_code.add_layer('gather',
+                                      inputs=[val_x, indices],
+                                      output=node,
+                                      param_attr=None)
+        elif axis > 0 and len(indices_shape) == 1:
+            perm = [range(len(indices_shape))]
+            perm = [axis] + perm[:axis] + perm[axis + 1:]
+            attr_trans = {'perm': perm}
+            name_trans = val_x.layer_name + '_trans'
+            node.fluid_code.add_layer('transpose',
+                                      inputs=val_x,
+                                      output=name_trans,
+                                      param_attr=attr_trans)
+            node.fluid_code.add_layer('gather',
+                                      inputs=[name_trans, indices],
+                                      output=node,
+                                      param_attr=None)
+            node.fluid_code.add_layer('transpose',
+                                      inputs=node,
+                                      output=node,
+                                      param_attr=attr_trans)

     def Slice(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_starts = self.graph.get_input_node(node, idx=1, copy=True)
+        val_ends = self.graph.get_input_node(node, idx=2, copy=True)
+        val_axes = self.graph.get_input_node(node, idx=3, copy=True)
+        val_steps = self.graph.get_input_node(node, idx=4, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)

-        axes = node.get_attr('axes')
-        starts = node.get_attr('starts')
-        ends = node.get_attr('ends')
+        starts = _const_weight_or_none(val_starts).copy()
+        ends = _const_weight_or_none(val_ends).copy()
+        axes = _const_weight_or_none(val_axes)
+        steps = _const_weight_or_none(val_steps)
+        self.omit_nodes.append(val_starts.layer_name)
+        self.omit_nodes.append(val_ends.layer_name)
+        self.omit_nodes.append(val_axes.layer_name)
+        self.omit_nodes.append(val_steps.layer_name)
+
         shape = val_x.out_shapes[0]
         if shape is not None:
+            for idx, value in enumerate(starts):
+                if value > 2**63 - 1 // 2:
+                    value = value - ONNX_INT_MAX
+                    starts[idx] = shape[axes[idx]] + value
+                if value > shape[axes[idx]]:
+                    starts[idx] = shape[axes[idx]]
+            for idx, value in enumerate(ends):
+                if value > 2**63 - 1 // 2:
+                    value = value - ONNX_INT_MAX
+                    ends[idx] = shape[axes[idx]] + value
+                if value > shape[axes[idx]]:
+                    ends[idx] = shape[axes[idx]]
         attr = {"axes": axes, "starts": starts, "ends": ends}
         node.fluid_code.add_layer('slice',
                                   inputs=val_x,
...
...
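Note: Paddle's gather only indexes axis 0, so the new Gather branch for axis > 0 transposes the target axis to the front, gathers, and transposes back. A numpy sketch of the idea (not part of this commit; note that in general the second transpose needs the inverse permutation):

import numpy as np

def gather_axis(x, indices, axis):
    # Move `axis` to the front, matching the first transpose above.
    perm = [axis] + [i for i in range(x.ndim) if i != axis]
    y = x.transpose(perm)[indices]
    # Invert the permutation to restore the original layout.
    inv = np.argsort(perm)
    return y.transpose(inv)

x = np.arange(24).reshape(2, 3, 4)
assert (gather_axis(x, np.array([2, 0]), axis=1)
        == np.take(x, [2, 0], axis=1)).all()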
@@ -426,7 +484,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def ConstantOfShape(self, node):
-        val_shape = self.graph.get_node(node.layer.input[0], copy=True)
+        val_shape = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)

         shape = _const_weight_or_none(val_shape)
...
...
@@ -452,7 +510,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Split(self, node):
-        val_input = self.graph.get_node(node.layer.input[0], copy=True)
+        val_input = self.graph.get_input_node(node, idx=0, copy=True)
         var_outs = [val for val in node.layer.input]

         fluid_op = 'split'
...
...
@@ -466,10 +524,11 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Reshape(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_shape = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_shape = self.graph.get_input_node(node, idx=1, copy=True)
         val_reshaped = self.graph.get_node(node.layer.output[0], copy=True)
+        shape = None
         if isinstance(val_shape, ONNXGraphDataNode):
             self.omit_nodes.append(val_shape.layer_name)
...
...
@@ -503,7 +562,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Cast(self, node):
-        val_input = self.graph.get_node(node.layer.input[0], copy=True)
+        val_input = self.graph.get_input_node(node, idx=0, copy=True)
         val_output = self.graph.get_node(node.layer.output[0], copy=True)

         dtype = node.get_attr('to')
...
...
@@ -520,7 +579,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def AveragePool(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)

         auto_pad = node.get_attr('auto_pad', 'NOTSET')
         kernel_shape = node.get_attr("kernel_shape")
...
...
@@ -532,10 +591,10 @@ class ONNXOpMapper(OpMapper):
         fluid_op = 'pool{}d'.format(poolnd)
         assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'

-        input_shape = val_x.out_shapes[0]
         paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
+            input_shape = val_x.out_shapes[0]
             pad_h = get_same_padding(input_shape[2], kernel_shape[0], strides[0])
             pad_w = get_same_padding(input_shape[3], kernel_shape[1],
...
...
@@ -560,7 +619,7 @@ class ONNXOpMapper(OpMapper):
     def Concat(self, node):
         inputs = []
         for i in range(len(node.layer.input)):
-            ipt = self.graph.get_node(node.layer.input[i], copy=True)
+            ipt = self.graph.get_input_node(node, idx=i, copy=True)
             if isinstance(ipt, str):
                 inputs.append(ipt)
             else:
...
...
@@ -568,12 +627,12 @@ class ONNXOpMapper(OpMapper):
         axis = node.get_attr('axis')
         attr = {'axis': axis}
         node.fluid_code.add_layer('concat',
-                                  inputs='[' + ', '.join(inputs) + ']',
+                                  inputs=inputs,
                                   output=node,
                                   param_attr=attr)

     def Flatten(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         axis = node.get_attr('axis', 1)
         attr = {"axis": str(axis), "name": string(node.layer_name)}
         node.fluid_code.add_layer('flatten',
...
...
@@ -582,9 +641,9 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Gemm(self, node):
-        val_a = self.graph.get_node(node.layer.input[0], copy=True)
-        val_b = self.graph.get_node(node.layer.input[1], copy=True)
-        val_c = self.graph.get_node(node.layer.input[2], copy=True)
+        val_a = self.graph.get_input_node(node, idx=0, copy=True)
+        val_b = self.graph.get_input_node(node, idx=1, copy=True)
+        val_c = self.graph.get_input_node(node, idx=2, copy=True)

         alpha = node.get_attr('alpha', 1.)  # optional
         beta = node.get_attr('beta', 1.)  # optional
...
...
@@ -627,8 +686,8 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Add(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_y = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = self.graph.get_input_node(node, idx=1, copy=True)
         inputs = {
             "x": val_x,
             "y": val_y,
...
...
@@ -642,23 +701,24 @@ class ONNXOpMapper(OpMapper):
     def Sum(self, node):
         val_inps = node.layer.input
         inputs = {
-            "x": val_inps[0],
-            "y": val_inps[1],
+            "x": self.graph.get_input_node(node, idx=0, copy=True),
+            "y": self.graph.get_input_node(node, idx=1, copy=True),
         }
         node.fluid_code.add_layer("elementwise_add", inputs=inputs, output=node)

-        for ipt in val_inps[2:]:
+        for idx, ipt in enumerate(val_inps[2:]):
+            y = self.graph.get_input_node(node, idx=idx, copy=True)
             inputs = {
                 "x": node.layer_name,
-                "y": ipt,
+                "y": y,
             }
             node.fluid_code.add_layer("elementwise_add",
                                       inputs=inputs,
                                       output=node)

     def MatMul(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_y = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = self.graph.get_input_node(node, idx=1, copy=True)
         inputs = {"x": val_x, "y": val_y}
         attr = {"name": string(node.layer_name)}
         node.fluid_code.add_layer("matmul",
...
...
@@ -667,11 +727,11 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def BatchNormalization(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_scale = self.graph.get_node(node.layer.input[1], copy=True)
-        val_b = self.graph.get_node(node.layer.input[2], copy=True)
-        val_mean = self.graph.get_node(node.layer.input[3], copy=True)
-        val_var = self.graph.get_node(node.layer.input[4], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_scale = self.graph.get_input_node(node, idx=1, copy=True)
+        val_b = self.graph.get_input_node(node, idx=2, copy=True)
+        val_mean = self.graph.get_input_node(node, idx=3, copy=True)
+        val_var = self.graph.get_input_node(node, idx=4, copy=True)
         self.omit_nodes.append(val_scale.layer_name)
         self.omit_nodes.append(val_b.layer_name)
...
...
@@ -701,7 +761,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Transpose(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         perm = node.get_attr('perm')
         attr = {'perm': perm, "name": string(node.layer_name)}
         node.fluid_code.add_layer("transpose",
...
...
@@ -710,12 +770,9 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Mul(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_y = self.graph.get_node(node.layer.input[1], copy=True)
-        val_x_shape = val_x.out_shapes[0]
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = self.graph.get_input_node(node, idx=1, copy=True)
         val_y_shape = val_y.out_shapes[0]

         slice_idx = 0
         for dim in val_y_shape:
             if dim == 1:
...
...
@@ -747,12 +804,9 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Div(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_y = self.graph.get_node(node.layer.input[1], copy=True)
-        val_x_shape = val_x.out_shapes[0]
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_y = self.graph.get_input_node(node, idx=1, copy=True)
         val_y_shape = val_y.out_shapes[0]

         slice_idx = 0
         for dim in val_y_shape:
             if dim == 1:
...
...
@@ -784,7 +838,7 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Relu(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         attr = {"name": string(node.layer_name)}
         node.fluid_code.add_layer("relu",
                                   inputs=val_x,
...
...
@@ -792,8 +846,8 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def PRelu(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_slope = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_slope = self.graph.get_input_node(node, idx=1, copy=True)

         mode = 'channel'
         shape_slope = val_slope.out_shapes[0]
...
...
@@ -811,20 +865,20 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Squeeze(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        squeeze_dims = node.get_attr('squeeze_dims')
-        attr = {'axes': squeeze_dims, "name": string(node.layer_name)}
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        axes = node.get_attr('axes')
+        attr = {'axes': axes, "name": string(node.layer_name)}
         node.fluid_code.add_layer("squeeze",
                                   inputs=val_x,
                                   output=node,
                                   param_attr=attr)

     def Identity(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         node.fluid_code.add_layer("assign", inputs=val_x, output=node)

     def MaxPool(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         auto_pad = node.get_attr('auto_pad', 'NOTSET')
         assert node.get_attr(
...
...
@@ -839,10 +893,10 @@ class ONNXOpMapper(OpMapper):
         fluid_op = 'pool{}d'.format(poolnd)
         assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'

-        input_shape = val_x.out_shapes[0]
         paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)

         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
+            input_shape = val_x.out_shapes[0]
             pad_h = get_same_padding(input_shape[2], kernel_shape[0], strides[0])
             pad_w = get_same_padding(input_shape[3], kernel_shape[1],
...
...
@@ -863,8 +917,18 @@ class ONNXOpMapper(OpMapper):
                                   output=node,
                                   param_attr=attr)

+    # def Tile(self, node):
+    #     pass

+    # def Loop(self, node):
+    #     pass

+    # def NonMaxSuppression(self, node):
+    #     pass

     def GlobalAveragePool(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)

         input_shape = val_x.out_shapes[0]
         output_shape = val_y.out_shapes[0]
...
...
@@ -886,21 +950,19 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def Conv(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_w = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_w = self.graph.get_input_node(node, idx=1, copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)

         self.omit_nodes.append(val_w.layer_name)

         has_bias = len(node.layer.input) == 3
         if has_bias:
-            val_b = self.graph.get_node(node.layer.input[2], copy=True)
+            val_b = self.graph.get_input_node(node, idx=2, copy=True)
             self.omit_nodes.append(val_b.layer_name)
         auto_pad = node.get_attr('auto_pad', 'NOTSET')

-        kernel_shape = val_w.out_shapes[0][2:]  # OI...
-        assert kernel_shape == node.get_attr(
-            'kernel_shape'), 'kernel_shape in attr unmatches value_info'  # HW
+        kernel_shape = node.get_attr('kernel_shape')
         convnd = len(kernel_shape)
         assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
         num_out_channels = val_w.out_shapes[0][0]  # OI...
...
...
@@ -941,9 +1003,9 @@ class ONNXOpMapper(OpMapper):
                                   param_attr=attr)

     def ConvTranspose(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        val_w = self.graph.get_node(node.layer.input[1], copy=True)
-        val_b = self.graph.get_node(node.layer.input[2], copy=True)
+        val_x = self.graph.get_input_node(node, idx=0, copy=True)
+        val_w = self.graph.get_input_node(node, idx=1, copy=True)
+        val_b = self.graph.get_input_node(node, idx=2, copy=True)
         self.omit_nodes.append(val_w.layer_name)
         self.omit_nodes.append(val_b.layer_name)
...
...
@@ -952,7 +1014,7 @@ class ONNXOpMapper(OpMapper):
         auto_pad = node.get_attr('auto_pad', 'NOTSET')
         out_padding = node.get_attr('output_padding', [0, 0])
-        kernel_shape = node.get_attr('kernel_shape', val_w.out_shapes[0][2:])
+        kernel_shape = node.get_attr('kernel_shape')
         assert kernel_shape, 'kernel_shape not inferred'
         convnd = len(kernel_shape)
         assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
...
...