Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
6d11992d
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6d11992d
编写于
9月 11, 2019
作者:
C
channingss
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support transformer
上级
fa5bdff4
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
125 addition
and
90 deletion
+125
-90
setup.py
setup.py
+6
-1
x2paddle/convert.py
x2paddle/convert.py
+1
-4
x2paddle/decoder/onnx_decoder.py
x2paddle/decoder/onnx_decoder.py
+37
-31
x2paddle/op_mapper/onnx_custom_layer/InstanceNormalization.py
...ddle/op_mapper/onnx_custom_layer/InstanceNormalization.py
+1
-1
x2paddle/op_mapper/onnx_custom_layer/__init__.py
x2paddle/op_mapper/onnx_custom_layer/__init__.py
+2
-2
x2paddle/op_mapper/onnx_directly_map.py
x2paddle/op_mapper/onnx_directly_map.py
+1
-9
x2paddle/op_mapper/onnx_op_mapper.py
x2paddle/op_mapper/onnx_op_mapper.py
+77
-42
未找到文件。
setup.py
浏览文件 @
6d11992d
...
@@ -23,4 +23,9 @@ setuptools.setup(
...
@@ -23,4 +23,9 @@ setuptools.setup(
"Operating System :: OS Independent"
,
"Operating System :: OS Independent"
,
],
],
license
=
'Apache 2.0'
,
license
=
'Apache 2.0'
,
entry_points
=
{
'console_scripts'
:
[
'x2paddle=x2paddle.convert:main'
]})
entry_points
=
{
'console_scripts'
:
[
'x2paddle=x2paddle.convert:main'
,
'onnx_infer=x2paddle.decoder.onnx_infer:main'
]
})
x2paddle/convert.py
浏览文件 @
6d11992d
...
@@ -120,9 +120,6 @@ def tf2paddle(model_path,
...
@@ -120,9 +120,6 @@ def tf2paddle(model_path,
mapper
.
save_inference_model
(
save_dir
)
mapper
.
save_inference_model
(
save_dir
)
0
def
caffe2paddle
(
proto
,
weight
,
save_dir
,
caffe_proto
):
def
caffe2paddle
(
proto
,
weight
,
save_dir
,
caffe_proto
):
from
x2paddle.decoder.caffe_decoder
import
CaffeDecoder
from
x2paddle.decoder.caffe_decoder
import
CaffeDecoder
from
x2paddle.op_mapper.caffe_op_mapper
import
CaffeOpMapper
from
x2paddle.op_mapper.caffe_op_mapper
import
CaffeOpMapper
...
@@ -154,7 +151,7 @@ def onnx2paddle(model_path, save_dir):
...
@@ -154,7 +151,7 @@ def onnx2paddle(model_path, save_dir):
print
(
"Now translating model from onnx to paddle."
)
print
(
"Now translating model from onnx to paddle."
)
from
x2paddle.decoder.onnx_decoder
import
ONNXDecoder
from
x2paddle.decoder.onnx_decoder
import
ONNXDecoder
model
=
ONNXDecoder
(
model_path
)
model
=
ONNXDecoder
(
model_path
,
save_dir
)
from
x2paddle.op_mapper.onnx_op_mapper
import
ONNXOpMapper
from
x2paddle.op_mapper.onnx_op_mapper
import
ONNXOpMapper
mapper
=
ONNXOpMapper
(
model
)
mapper
=
ONNXOpMapper
(
model
)
...
...
x2paddle/decoder/onnx_decoder.py
浏览文件 @
6d11992d
...
@@ -29,6 +29,7 @@ from onnx.helper import ValueInfoProto
...
@@ -29,6 +29,7 @@ from onnx.helper import ValueInfoProto
import
numpy
as
np
import
numpy
as
np
from
copy
import
deepcopy
from
copy
import
deepcopy
import
logging
as
_logging
import
logging
as
_logging
import
os
default_op_domain
=
'ai.onnx'
default_op_domain
=
'ai.onnx'
_logger
=
_logging
.
getLogger
(
__name__
)
_logger
=
_logging
.
getLogger
(
__name__
)
...
@@ -131,15 +132,16 @@ class ONNXGraphDataNode(GraphNode):
...
@@ -131,15 +132,16 @@ class ONNXGraphDataNode(GraphNode):
class
ONNXGraph
(
Graph
):
class
ONNXGraph
(
Graph
):
def
__init__
(
self
,
graph
,
onnx_model
):
def
__init__
(
self
,
onnx_model
,
save_dir
):
super
(
ONNXGraph
,
self
).
__init__
(
graph
)
super
(
ONNXGraph
,
self
).
__init__
(
onnx_model
.
graph
)
self
.
onnx_model
=
onnx_model
self
.
onnx_model
=
onnx_model
self
.
initializer
=
{}
self
.
initializer
=
{}
self
.
place_holder_nodes
=
list
()
self
.
place_holder_nodes
=
list
()
self
.
get_place_holder_nodes
()
self
.
get_place_holder_nodes
()
self
.
save_dir
=
save_dir
self
.
value_infos
=
self
.
inferred_model_value_info
(
graph
)
self
.
value_infos
=
self
.
inferred_model_value_info
(
self
.
model
)
self
.
results_of_inference
=
dict
()
self
.
results_of_inference
=
dict
()
self
.
is_inference
=
False
def
get_inner_nodes
(
self
):
def
get_inner_nodes
(
self
):
"""
"""
...
@@ -176,9 +178,6 @@ class ONNXGraph(Graph):
...
@@ -176,9 +178,6 @@ class ONNXGraph(Graph):
"""
"""
build topo_sort of ONNX model
build topo_sort of ONNX model
"""
"""
data_nodes
=
self
.
place_holder_nodes
self
.
get_results_of_inference_rt
(
self
.
onnx_model
,
data_nodes
)
for
layer
in
self
.
model
.
node
:
for
layer
in
self
.
model
.
node
:
node
=
ONNXGraphNode
(
layer
)
node
=
ONNXGraphNode
(
layer
)
self
.
node_map
[
layer
.
name
]
=
node
self
.
node_map
[
layer
.
name
]
=
node
...
@@ -187,13 +186,21 @@ class ONNXGraph(Graph):
...
@@ -187,13 +186,21 @@ class ONNXGraph(Graph):
value_info
=
self
.
value_infos
[
opt
]
value_info
=
self
.
value_infos
[
opt
]
if
len
(
value_info
[
'shape'
]
if
len
(
value_info
[
'shape'
]
)
==
0
or
value_info
[
'dtype'
]
is
None
:
)
==
0
or
value_info
[
'dtype'
]
is
None
:
if
self
.
is_inference
==
False
:
self
.
get_results_of_inference_rt
(
self
.
onnx_model
,
self
.
place_holder_nodes
)
self
.
is_inference
=
True
_
,
dtype
,
shape
=
self
.
get_dynamic_shape
(
opt
)
_
,
dtype
,
shape
=
self
.
get_dynamic_shape
(
opt
)
node
.
dtype
=
dtype
node
.
out_shapes
.
append
(
shape
)
node
.
out_shapes
.
append
(
shape
)
node
.
dtype
=
dtype
else
:
else
:
node
.
dtype
=
value_info
[
'dtype'
]
node
.
dtype
=
value_info
[
'dtype'
]
node
.
out_shapes
.
append
(
value_info
[
'shape'
])
node
.
out_shapes
.
append
(
value_info
[
'shape'
])
else
:
else
:
if
self
.
is_inference
==
False
:
self
.
get_results_of_inference_rt
(
self
.
onnx_model
,
self
.
place_holder_nodes
)
self
.
is_inference
=
True
_
,
dtype
,
shape
=
self
.
get_dynamic_shape
(
opt
)
_
,
dtype
,
shape
=
self
.
get_dynamic_shape
(
opt
)
node
.
dtype
=
dtype
node
.
dtype
=
dtype
node
.
out_shapes
.
append
(
shape
)
node
.
out_shapes
.
append
(
shape
)
...
@@ -232,8 +239,8 @@ class ONNXGraph(Graph):
...
@@ -232,8 +239,8 @@ class ONNXGraph(Graph):
if
opt
==
in_node
:
if
opt
==
in_node
:
self
.
connect
(
nd
.
name
,
layer_name
)
self
.
connect
(
nd
.
name
,
layer_name
)
flag
=
1
flag
=
1
print
(
nd
.
name
+
'->'
+
layer_name
)
node
.
which_child
[
nd
.
name
]
=
idx
node
.
which_child
[
nd
.
name
]
=
idx
self
.
node_map
[
nd
.
name
].
index
=
0
break
break
if
flag
==
1
:
if
flag
==
1
:
break
break
...
@@ -250,7 +257,9 @@ class ONNXGraph(Graph):
...
@@ -250,7 +257,9 @@ class ONNXGraph(Graph):
def
get_input_node
(
self
,
node
,
idx
=
0
,
copy
=
False
):
def
get_input_node
(
self
,
node
,
idx
=
0
,
copy
=
False
):
if
len
(
node
.
which_child
)
==
0
:
if
len
(
node
.
which_child
)
==
0
:
return
super
(
ONNXGraph
,
self
).
get_node
(
node
.
inputs
[
idx
],
copy
)
ipt_node
=
super
(
ONNXGraph
,
self
).
get_node
(
node
.
inputs
[
idx
],
copy
)
return
ipt_node
else
:
else
:
ipt_node
=
super
(
ONNXGraph
,
self
).
get_node
(
node
.
inputs
[
idx
],
copy
)
ipt_node
=
super
(
ONNXGraph
,
self
).
get_node
(
node
.
inputs
[
idx
],
copy
)
if
ipt_node
.
layer_name
in
node
.
which_child
:
if
ipt_node
.
layer_name
in
node
.
which_child
:
...
@@ -312,11 +321,13 @@ class ONNXGraph(Graph):
...
@@ -312,11 +321,13 @@ class ONNXGraph(Graph):
import
torch
import
torch
version
=
torch
.
__version__
version
=
torch
.
__version__
if
'1.1.0'
not
in
version
:
if
'1.1.0'
not
in
version
:
print
(
"your model have dynamic graph, torch==1.1.0 is required"
)
print
(
"shape of somenode need inference, torch==1.1.0 is required"
)
return
return
except
:
except
:
print
(
print
(
"
your model have dynamic graph, we use caff
2 to inference graph, please use
\"
pip install torch==1.1.0
\"
."
"
shape of somenode need inference, we use caffe
2 to inference graph, please use
\"
pip install torch==1.1.0
\"
."
)
)
return
return
from
x2paddle.decoder.onnx_backend
import
prepare
from
x2paddle.decoder.onnx_backend
import
prepare
...
@@ -326,6 +337,7 @@ class ONNXGraph(Graph):
...
@@ -326,6 +337,7 @@ class ONNXGraph(Graph):
value_info
=
self
.
value_infos
[
data_node
]
value_info
=
self
.
value_infos
[
data_node
]
ipt
=
np
.
random
.
random
(
value_info
[
'shape'
]).
astype
(
'float32'
)
ipt
=
np
.
random
.
random
(
value_info
[
'shape'
]).
astype
(
'float32'
)
inputs
.
append
(
ipt
)
inputs
.
append
(
ipt
)
print
(
ipt
.
shape
)
outputs
=
[]
outputs
=
[]
for
node
in
model
.
graph
.
node
:
for
node
in
model
.
graph
.
node
:
value_info
=
helper
.
make_tensor_value_info
(
node
.
name
,
value_info
=
helper
.
make_tensor_value_info
(
node
.
name
,
...
@@ -347,13 +359,11 @@ class ONNXGraph(Graph):
...
@@ -347,13 +359,11 @@ class ONNXGraph(Graph):
return
return
def
get_results_of_inference_rt
(
self
,
model
,
data_nodes
):
def
get_results_of_inference_rt
(
self
,
model
,
data_nodes
):
import
onnxruntime
as
rt
inputs
=
[]
inputs
=
[]
for
data_node
in
data_nodes
:
for
data_node
in
data_nodes
:
value_info
=
self
.
value_infos
[
data_node
]
value_info
=
self
.
value_infos
[
data_node
]
ipt
=
np
.
random
.
random
(
value_info
[
'shape'
]).
astype
(
'float32'
)
ipt
=
np
.
random
.
random
(
value_info
[
'shape'
]).
astype
(
value_info
[
'dtype'
])
inputs
.
append
(
ipt
)
inputs
.
append
(
ipt
)
model
=
onnx
.
shape_inference
.
infer_shapes
(
model
)
model
=
onnx
.
shape_inference
.
infer_shapes
(
model
)
...
@@ -363,30 +373,26 @@ class ONNXGraph(Graph):
...
@@ -363,30 +373,26 @@ class ONNXGraph(Graph):
model
.
graph
.
ClearField
(
'output'
)
model
.
graph
.
ClearField
(
'output'
)
model
.
graph
.
output
.
MergeFrom
(
outputs
)
model
.
graph
.
output
.
MergeFrom
(
outputs
)
onnx
.
save
(
model
,
'./onnx_model_infer.onnx'
)
if
not
os
.
path
.
exists
(
self
.
save_dir
):
os
.
makedirs
(
self
.
save_dir
)
sess
=
rt
.
InferenceSession
(
'./onnx_model_infer.onnx'
)
onnx
.
save
(
model
,
os
.
path
.
join
(
self
.
save_dir
,
'onnx_model_infer.onnx'
))
inputs_dict
=
{}
np
.
save
(
os
.
path
.
join
(
self
.
save_dir
,
'input_data.npy'
),
inputs
)
os
.
system
(
'onnx_infer --save_dir='
+
self
.
save_dir
)
for
i
,
ipt
in
enumerate
(
inputs
):
# res = np.load(os.path.join(self.save_dir, 'results_of_inference.npy'),allow_pickle=True)
inputs_dict
[
sess
.
get_inputs
()[
i
].
name
]
=
ipt
# for idx, info in enumerate(outputs):
# self.results_of_inference[info.name] = res[idx]
res
=
sess
.
run
(
None
,
input_feed
=
inputs_dict
)
for
idx
,
info
in
enumerate
(
outputs
):
self
.
results_of_inference
[
info
.
name
]
=
res
[
idx
]
return
return
def
get_dynamic_shape
(
self
,
layer
):
def
get_dynamic_shape
(
self
,
layer
):
"""
"""
get dynamic shape from infer_result
get dynamic shape from infer_result
"""
"""
output
=
self
.
results_of_inference
[
layer
]
output
=
np
.
load
(
os
.
path
.
join
(
self
.
save_dir
,
layer
+
'.npy'
))
return
output
.
tolist
(),
output
.
dtype
,
output
.
shape
return
output
.
tolist
(),
output
.
dtype
,
output
.
shape
class
ONNXDecoder
(
object
):
class
ONNXDecoder
(
object
):
def
__init__
(
self
,
onnx_model
):
def
__init__
(
self
,
onnx_model
,
save_dir
):
model
=
onnx
.
load
(
onnx_model
)
model
=
onnx
.
load
(
onnx_model
)
print
(
'model ir_version: {}, op version: {}'
.
format
(
print
(
'model ir_version: {}, op version: {}'
.
format
(
model
.
ir_version
,
model
.
opset_import
[
0
].
version
))
model
.
ir_version
,
model
.
opset_import
[
0
].
version
))
...
@@ -405,7 +411,7 @@ class ONNXDecoder(object):
...
@@ -405,7 +411,7 @@ class ONNXDecoder(object):
self
.
model
=
model
self
.
model
=
model
graph
=
model
.
graph
graph
=
model
.
graph
self
.
onnx_graph
=
ONNXGraph
(
graph
,
model
)
self
.
onnx_graph
=
ONNXGraph
(
model
,
save_dir
)
self
.
onnx_graph
.
build
()
self
.
onnx_graph
.
build
()
def
build_value_refs
(
self
,
nodes
):
def
build_value_refs
(
self
,
nodes
):
...
...
x2paddle/op_mapper/onnx_custom_layer/InstanceNormalization.py
浏览文件 @
6d11992d
...
@@ -24,7 +24,7 @@ def InstanceNormalization_layer(inputs, name=None):
...
@@ -24,7 +24,7 @@ def InstanceNormalization_layer(inputs, name=None):
epsilon
=
1e-5
epsilon
=
1e-5
input_
=
inputs
[
0
]
input_
=
inputs
[
0
]
mean
=
fluid
.
layers
.
reduce_mean
(
input_
,
dim
=
[
2
,
3
],
keep_dim
=
True
)
mean
=
fluid
.
layers
.
reduce_mean
(
input_
,
dim
=
[
2
,
3
],
keep_dim
=
True
)
var
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
square
(
input
s
-
mean
),
var
=
fluid
.
layers
.
reduce_mean
(
fluid
.
layers
.
square
(
input
_
-
mean
),
dim
=
[
2
,
3
],
dim
=
[
2
,
3
],
keep_dim
=
True
)
keep_dim
=
True
)
if
name
is
not
None
:
if
name
is
not
None
:
...
...
x2paddle/op_mapper/onnx_custom_layer/__init__.py
浏览文件 @
6d11992d
...
@@ -100,9 +100,9 @@ def make_custom_child_func(node):
...
@@ -100,9 +100,9 @@ def make_custom_child_func(node):
""" get the code which implement the custom layer function
""" get the code which implement the custom layer function
"""
"""
layer_type
=
node
.
layer_type
layer_type
=
node
.
layer_type
assert
layer_type
in
custom_layers
,
"layer[%s] not exist in custom layers"
%
(
layer_type
)
child_func
=
custom_layers
[
layer_type
][
'child_func'
]
child_func
=
custom_layers
[
layer_type
][
'child_func'
]
if
child_func
is
None
:
return
None
,
child_func
import
inspect
import
inspect
return
inspect
.
getsource
(
child_func
),
child_func
return
inspect
.
getsource
(
child_func
),
child_func
...
...
x2paddle/op_mapper/onnx_directly_map.py
浏览文件 @
6d11992d
...
@@ -29,12 +29,6 @@ default_op_mapping = {
...
@@ -29,12 +29,6 @@ default_op_mapping = {
'Gather'
:
[
'gather'
,
[
'X'
],
[
'Out'
],
'Gather'
:
[
'gather'
,
[
'X'
],
[
'Out'
],
dict
(
axis
=
''
)],
dict
(
axis
=
''
)],
'Shape'
:
[
'shape'
,
[
'X'
],
[
'Out'
]],
'Shape'
:
[
'shape'
,
[
'X'
],
[
'Out'
]],
'Mul'
:
[
'elementwise_mul'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Sub'
:
[
'elementwise_sub'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Clip'
:
[
'Clip'
:
[
'clip'
,
[
'X'
],
[
'Out'
],
'clip'
,
[
'X'
],
[
'Out'
],
dict
(),
dict
(),
...
@@ -74,9 +68,6 @@ default_op_mapping = {
...
@@ -74,9 +68,6 @@ default_op_mapping = {
],
],
'Tanh'
:
[
'tanh'
,
[
'X'
],
[
'Out'
]],
'Tanh'
:
[
'tanh'
,
[
'X'
],
[
'Out'
]],
'Sigmoid'
:
[
'sigmoid'
,
[
'X'
],
[
'Out'
]],
'Sigmoid'
:
[
'sigmoid'
,
[
'X'
],
[
'Out'
]],
'Pow'
:
[
'elementwise_pow'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
# TODO: pow for scalar exponent
'HardSigmoid'
:
[
'HardSigmoid'
:
[
'hard_sigmoid'
,
[
'X'
],
[
'Out'
],
'hard_sigmoid'
,
[
'X'
],
[
'Out'
],
dict
(
alpha
=
'slope'
,
beta
=
'offset'
),
dict
(
alpha
=
'slope'
,
beta
=
'offset'
),
...
@@ -87,6 +78,7 @@ default_op_mapping = {
...
@@ -87,6 +78,7 @@ default_op_mapping = {
'Exp'
:
[
'exp'
,
[
'X'
],
[
'Out'
]],
'Exp'
:
[
'exp'
,
[
'X'
],
[
'Out'
]],
'Softmax'
:
[
'softmax'
,
[
'X'
],
[
'Out'
],
'Softmax'
:
[
'softmax'
,
[
'X'
],
[
'Out'
],
dict
(),
dict
(
axis
=
1
)],
dict
(),
dict
(
axis
=
1
)],
'Sqrt'
:
[
'sqrt'
,
[
'X'
],
[
'Out'
]],
}
}
activefunc_op_mapping
=
{
activefunc_op_mapping
=
{
...
...
x2paddle/op_mapper/onnx_op_mapper.py
浏览文件 @
6d11992d
...
@@ -101,7 +101,6 @@ class ONNXOpMapper(OpMapper):
...
@@ -101,7 +101,6 @@ class ONNXOpMapper(OpMapper):
outputs
=
node
.
layer
.
output
outputs
=
node
.
layer
.
output
op_type
=
node
.
layer_type
op_type
=
node
.
layer_type
attrs
=
node
.
attr_map
attrs
=
node
.
attr_map
info
=
default_op_mapping
[
op_type
]
info
=
default_op_mapping
[
op_type
]
info
.
extend
(
list
(
default_op_mapping_field_values
.
values
())[
len
(
info
):])
info
.
extend
(
list
(
default_op_mapping_field_values
.
values
())[
len
(
info
):])
(
(
...
@@ -138,10 +137,11 @@ class ONNXOpMapper(OpMapper):
...
@@ -138,10 +137,11 @@ class ONNXOpMapper(OpMapper):
val_outs
=
outputs
if
output_perm
is
None
else
list
(
val_outs
=
outputs
if
output_perm
is
None
else
list
(
map
(
lambda
i
:
outputs
[
i
],
output_perm
))
map
(
lambda
i
:
outputs
[
i
],
output_perm
))
attr
=
fluid_attrs
attr
=
fluid_attrs
if
fluid_op
not
in
[
'shape'
,
'gather'
]:
assert
len
(
val_inps
)
==
1
,
'directly_map error with multi inputs'
if
fluid_op
not
in
[
'shape'
]:
attr
[
'name'
]
=
string
(
node
.
layer_name
)
attr
[
'name'
]
=
string
(
node
.
layer_name
)
node
.
fluid_code
.
add_layer
(
fluid_op
,
node
.
fluid_code
.
add_layer
(
fluid_op
,
inputs
=
val_inps
,
inputs
=
val_inps
[
0
]
,
output
=
val_outs
[
0
],
output
=
val_outs
[
0
],
param_attr
=
attr
)
param_attr
=
attr
)
...
@@ -160,7 +160,9 @@ class ONNXOpMapper(OpMapper):
...
@@ -160,7 +160,9 @@ class ONNXOpMapper(OpMapper):
if
op
not
in
self
.
used_custom_layers
:
if
op
not
in
self
.
used_custom_layers
:
self
.
used_custom_layers
[
op
]
=
custom_code
self
.
used_custom_layers
[
op
]
=
custom_code
if
op
+
'_child_func'
not
in
self
.
used_custom_layers
:
if
op
+
'_child_func'
not
in
self
.
used_custom_layers
:
self
.
used_custom_layers
[
op
+
'_child_func'
]
=
child_func_code
if
child_func_code
is
not
None
:
self
.
used_custom_layers
[
op
+
'_child_func'
]
=
child_func_code
def
place_holder
(
self
,
node
):
def
place_holder
(
self
,
node
):
self
.
input_shapes
.
append
(
node
.
out_shapes
[
0
])
self
.
input_shapes
.
append
(
node
.
out_shapes
[
0
])
...
@@ -422,18 +424,21 @@ class ONNXOpMapper(OpMapper):
...
@@ -422,18 +424,21 @@ class ONNXOpMapper(OpMapper):
indices
=
self
.
graph
.
get_input_node
(
node
,
idx
=
1
,
copy
=
True
)
indices
=
self
.
graph
.
get_input_node
(
node
,
idx
=
1
,
copy
=
True
)
indices_shape
=
indices
.
out_shapes
[
0
]
indices_shape
=
indices
.
out_shapes
[
0
]
axis
=
node
.
get_attr
(
'axis'
)
axis
=
node
.
get_attr
(
'axis'
)
print
(
indices
.
layer_name
)
print
(
indices_shape
)
assert
len
(
assert
len
(
indices_shape
)
=
=
1
,
"Gather op don't support dim of indice >1 "
indices_shape
)
<
=
1
,
"Gather op don't support dim of indice >1 "
if
axis
==
0
and
len
(
indices_shape
)
=
=
1
:
if
axis
==
0
and
len
(
indices_shape
)
<
=
1
:
node
.
fluid_code
.
add_layer
(
'gather'
,
node
.
fluid_code
.
add_layer
(
'gather'
,
inputs
=
[
val_x
,
indices
],
inputs
=
{
'input'
:
val_x
,
'index'
:
indices
},
output
=
node
,
output
=
node
,
param_attr
=
None
)
param_attr
=
None
)
elif
axis
>
0
and
len
(
indices_shape
)
==
1
:
elif
axis
>
0
and
len
(
indices_shape
)
<=
1
:
perm
=
[
range
(
len
(
indices_shape
))]
perm
=
list
(
range
(
len
(
val_x
.
out_shapes
[
0
])))
print
(
val_x
.
out_shapes
[
0
])
perm
=
[
axis
]
+
perm
[:
axis
]
+
perm
[
axis
+
1
:]
perm
=
[
axis
]
+
perm
[:
axis
]
+
perm
[
axis
+
1
:]
# perm = [0]
attr_trans
=
{
'perm'
:
perm
}
attr_trans
=
{
'perm'
:
perm
}
name_trans
=
val_x
.
layer_name
+
'_trans'
name_trans
=
val_x
.
layer_name
+
'_trans'
node
.
fluid_code
.
add_layer
(
'transpose'
,
node
.
fluid_code
.
add_layer
(
'transpose'
,
...
@@ -441,7 +446,10 @@ class ONNXOpMapper(OpMapper):
...
@@ -441,7 +446,10 @@ class ONNXOpMapper(OpMapper):
output
=
name_trans
,
output
=
name_trans
,
param_attr
=
attr_trans
)
param_attr
=
attr_trans
)
node
.
fluid_code
.
add_layer
(
'gather'
,
node
.
fluid_code
.
add_layer
(
'gather'
,
inputs
=
[
name_trans
,
indices
],
inputs
=
{
'input'
:
name_trans
,
'index'
:
indices
},
output
=
node
,
output
=
node
,
param_attr
=
None
)
param_attr
=
None
)
node
.
fluid_code
.
add_layer
(
'transpose'
,
node
.
fluid_code
.
add_layer
(
'transpose'
,
...
@@ -451,23 +459,29 @@ class ONNXOpMapper(OpMapper):
...
@@ -451,23 +459,29 @@ class ONNXOpMapper(OpMapper):
def
Slice
(
self
,
node
):
def
Slice
(
self
,
node
):
val_x
=
self
.
graph
.
get_input_node
(
node
,
idx
=
0
,
copy
=
True
)
val_x
=
self
.
graph
.
get_input_node
(
node
,
idx
=
0
,
copy
=
True
)
val_starts
=
self
.
graph
.
get_input_node
(
node
,
idx
=
1
,
copy
=
True
)
val_starts
,
val_ends
,
val_axes
,
val_steps
=
None
,
None
,
None
,
None
val_ends
=
self
.
graph
.
get_input_node
(
node
,
idx
=
2
,
copy
=
True
)
if
len
(
node
.
inputs
)
>
1
:
val_axes
=
self
.
graph
.
get_input_node
(
node
,
idx
=
3
,
copy
=
True
)
starts
=
self
.
graph
.
get_input_node
(
node
,
idx
=
1
,
copy
=
True
)
val_steps
=
self
.
graph
.
get_input_node
(
node
,
idx
=
4
,
copy
=
True
)
ends
=
self
.
graph
.
get_input_node
(
node
,
idx
=
2
,
copy
=
True
)
axes
=
self
.
graph
.
get_input_node
(
node
,
idx
=
3
,
copy
=
True
)
steps
=
self
.
graph
.
get_input_node
(
node
,
idx
=
4
,
copy
=
True
)
self
.
omit_nodes
.
append
(
starts
.
layer_name
)
self
.
omit_nodes
.
append
(
ends
.
layer_name
)
self
.
omit_nodes
.
append
(
axes
.
layer_name
)
self
.
omit_nodes
.
append
(
steps
.
layer_name
)
starts
=
_const_weight_or_none
(
starts
).
copy
()
ends
=
_const_weight_or_none
(
ends
).
copy
()
axes
=
_const_weight_or_none
(
axes
)
steps
=
_const_weight_or_none
(
steps
)
else
:
starts
=
node
.
get_attr
(
'starts'
)
ends
=
node
.
get_attr
(
'ends'
)
axes
=
node
.
get_attr
(
'axes'
)
val_y
=
self
.
graph
.
get_node
(
node
.
layer
.
output
[
0
],
copy
=
True
)
val_y
=
self
.
graph
.
get_node
(
node
.
layer
.
output
[
0
],
copy
=
True
)
starts
=
_const_weight_or_none
(
val_starts
).
copy
()
ends
=
_const_weight_or_none
(
val_ends
).
copy
()
axes
=
_const_weight_or_none
(
val_axes
)
steps
=
_const_weight_or_none
(
val_steps
)
self
.
omit_nodes
.
append
(
val_starts
.
layer_name
)
self
.
omit_nodes
.
append
(
val_ends
.
layer_name
)
self
.
omit_nodes
.
append
(
val_axes
.
layer_name
)
self
.
omit_nodes
.
append
(
val_steps
.
layer_name
)
shape
=
val_x
.
out_shapes
[
0
]
shape
=
val_x
.
out_shapes
[
0
]
if
shape
is
not
None
:
if
shape
is
not
None
:
...
@@ -510,17 +524,21 @@ class ONNXOpMapper(OpMapper):
...
@@ -510,17 +524,21 @@ class ONNXOpMapper(OpMapper):
param_attr
=
attr
)
param_attr
=
attr
)
def
Split
(
self
,
node
):
def
Split
(
self
,
node
):
val_
input
=
self
.
graph
.
get_input_node
(
node
,
idx
=
0
,
copy
=
True
)
val_
x
=
self
.
graph
.
get_input_node
(
node
,
idx
=
0
,
copy
=
True
)
va
r_outs
=
[
val
for
val
in
node
.
layer
.
input
]
va
l_y
=
self
.
graph
.
get_node
(
node
.
layer
.
output
[
0
],
copy
=
True
)
fluid_op
=
'split'
fluid_op
=
'split'
split
=
node
.
get_attr
[
'split'
]
split
=
node
.
get_attr
(
'split'
)
axis
=
node
.
get_attr
(
'axis'
,
0
)
axis
=
node
.
get_attr
(
'axis'
,
0
)
attr
=
{
'split'
:
split
,
'axis'
:
axis
,
'name'
:
string
(
node
.
layer_name
)}
attr
=
{
'num_or_sections'
:
split
,
'dim'
:
axis
,
'name'
:
string
(
node
.
layer_name
)
}
# generation
# generation
node
.
fluid_code
.
add_layer
(
'split'
,
node
.
fluid_code
.
add_layer
(
'split'
,
inputs
=
val_
input
,
inputs
=
val_
x
,
output
=
va
r_outs
,
output
=
va
l_y
,
param_attr
=
attr
)
param_attr
=
attr
)
def
Reshape
(
self
,
node
):
def
Reshape
(
self
,
node
):
...
@@ -536,6 +554,7 @@ class ONNXOpMapper(OpMapper):
...
@@ -536,6 +554,7 @@ class ONNXOpMapper(OpMapper):
if
isinstance
(
val_shape
,
ONNXGraphNode
):
if
isinstance
(
val_shape
,
ONNXGraphNode
):
shape
,
_
,
_
=
self
.
decoder
.
onnx_graph
.
get_dynamic_shape
(
shape
,
_
,
_
=
self
.
decoder
.
onnx_graph
.
get_dynamic_shape
(
val_shape
.
layer_name
)
val_shape
.
layer_name
)
if
shape
is
None
:
if
shape
is
None
:
shape
=
val_reshaped
.
out_shapes
[
0
]
shape
=
val_reshaped
.
out_shapes
[
0
]
...
@@ -698,6 +717,32 @@ class ONNXOpMapper(OpMapper):
...
@@ -698,6 +717,32 @@ class ONNXOpMapper(OpMapper):
output
=
node
,
output
=
node
,
param_attr
=
attr
)
param_attr
=
attr
)
def
Sub
(
self
,
node
):
val_x
=
self
.
graph
.
get_input_node
(
node
,
idx
=
0
,
copy
=
True
)
val_y
=
self
.
graph
.
get_input_node
(
node
,
idx
=
1
,
copy
=
True
)
inputs
=
{
"x"
:
val_x
,
"y"
:
val_y
,
}
attr
=
{
"name"
:
string
(
node
.
layer_name
)}
node
.
fluid_code
.
add_layer
(
"elementwise_sub"
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
attr
)
def
Pow
(
self
,
node
):
val_x
=
self
.
graph
.
get_input_node
(
node
,
idx
=
0
,
copy
=
True
)
val_y
=
self
.
graph
.
get_input_node
(
node
,
idx
=
1
,
copy
=
True
)
inputs
=
{
"x"
:
val_x
,
"y"
:
val_y
,
}
attr
=
{
"name"
:
string
(
node
.
layer_name
)}
node
.
fluid_code
.
add_layer
(
"elementwise_pow"
,
inputs
=
inputs
,
output
=
node
,
param_attr
=
attr
)
def
Sum
(
self
,
node
):
def
Sum
(
self
,
node
):
val_inps
=
node
.
layer
.
input
val_inps
=
node
.
layer
.
input
inputs
=
{
inputs
=
{
...
@@ -917,16 +962,6 @@ class ONNXOpMapper(OpMapper):
...
@@ -917,16 +962,6 @@ class ONNXOpMapper(OpMapper):
output
=
node
,
output
=
node
,
param_attr
=
attr
)
param_attr
=
attr
)
# def Tile(self, node):
# pass
# def Loop(self, node):
# pass
# def NonMaxSuppression(self, node):
# pass
def
GlobalAveragePool
(
self
,
node
):
def
GlobalAveragePool
(
self
,
node
):
val_x
=
self
.
graph
.
get_input_node
(
node
,
idx
=
0
,
copy
=
True
)
val_x
=
self
.
graph
.
get_input_node
(
node
,
idx
=
0
,
copy
=
True
)
val_y
=
self
.
graph
.
get_node
(
node
.
layer
.
output
[
0
],
copy
=
True
)
val_y
=
self
.
graph
.
get_node
(
node
.
layer
.
output
[
0
],
copy
=
True
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录