PaddlePaddle / X2Paddle

Commit 43cf70b5
Authored Aug 28, 2019 by channingss

fix bug of softmax op

Parent: e6c908f6
Showing 8 changed files with 1398 additions and 96 deletions (+1398 −96)
x2paddle/convert.py                                              +6     -3
x2paddle/decoder/onnx_backend.py                                 +1074  -0
x2paddle/decoder/onnx_decoder.py                                 +73    -59
x2paddle/op_mapper/onnx_custom_layer/InstanceNormalization.py    +0     -1
x2paddle/op_mapper/onnx_custom_layer/__init__.py                 +0     -1
x2paddle/op_mapper/onnx_directly_map.py                          +30    -1
x2paddle/op_mapper/onnx_op_mapper.py                             +215   -30
x2paddle/optimizer/onnx_optimizer.py                             +0     -1
x2paddle/convert.py
@@ -110,14 +110,17 @@ def onnx2paddle(model_path, save_dir):
     except:
         print("onnx is not installed, use \"pip install onnx==1.5.0\".")
         return
+    print("Now translating model from onnx to paddle.")
     from x2paddle.decoder.onnx_decoder import ONNXDecoder
-    from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
-    from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
-    print("Now translating model from onnx to paddle.")
     model = ONNXDecoder(model_path)
+    from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
     mapper = ONNXOpMapper(model)
+    from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
     optimizer = ONNXOptimizer(mapper)
     optimizer.delete_redundance_code()
     mapper.save_inference_model(save_dir)
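Note: the staged imports above mean each converter stage is only loaded once the previous stage has succeeded (the op mapper, for instance, pulls in the new caffe2-style backend added below). A minimal sketch of the resulting pipeline; the model path 'model.onnx' and output directory 'pd_model' are hypothetical:

    from x2paddle.decoder.onnx_decoder import ONNXDecoder

    model = ONNXDecoder('model.onnx')        # load, check, and build the ONNX graph

    from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
    mapper = ONNXOpMapper(model)             # map ONNX ops onto fluid layers

    from x2paddle.optimizer.onnx_optimizer import ONNXOptimizer
    optimizer = ONNXOptimizer(mapper)
    optimizer.delete_redundance_code()       # drop redundant emitted code
    mapper.save_inference_model('pd_model')  # write the Paddle inference model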
x2paddle/decoder/onnx_backend.py
0 → 100644
This diff is collapsed.
x2paddle/decoder/onnx_decoder.py
@@ -23,6 +23,7 @@ from onnx.helper import get_attribute_value, make_attribute
 from onnx.shape_inference import infer_shapes
 from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
+from onnx.numpy_helper import to_array
 from onnx import AttributeProto, TensorProto, GraphProto
 from collections import OrderedDict as Dict
 import onnx
 import numpy as np
@@ -59,7 +60,6 @@ class ONNXGraphNode(GraphNode):
     @property
     def value(self):
         assert 'Constant' in self.layer_type, "Only Constant | ConstantOfShape node has value."
-        print(self.layer)
         attr = self.layer.attribute['value']
         if 'value' not in self.attr_map:
             return None
@@ -120,12 +120,15 @@ class ONNXGraphDataNode(GraphNode):
 class ONNXGraph(Graph):
-    def __init__(self, model):
-        super(ONNXGraph, self).__init__(model)
+    def __init__(self, graph, onnx_model):
+        super(ONNXGraph, self).__init__(graph)
+        self.onnx_model = onnx_model
         self.initializer = {}
         self.place_holder_nodes = list()
         self.get_place_holder_nodes()
-        self.value_infos = self.inferred_model_value_info(model)
+        self.value_infos = self.inferred_model_value_info(graph)
+        self.results_of_inference = dict()

     def get_inner_nodes(self):
         """
@@ -162,13 +165,22 @@ class ONNXGraph(Graph):
"""
build topo_sort of ONNX model
"""
data_node
=
self
.
place_holder_nodes
[
0
]
value_info
=
self
.
value_infos
[
data_node
]
input_shape
=
value_info
[
'shape'
]
self
.
get_results_of_inference
(
self
.
onnx_model
,
input_shape
)
for
layer
in
self
.
model
.
node
:
node
=
ONNXGraphNode
(
layer
)
self
.
node_map
[
layer
.
name
]
=
node
for
opt
in
layer
.
output
:
value_info
=
self
.
value_infos
[
opt
]
node
.
dtype
=
value_info
[
'dtype'
]
node
.
out_shapes
.
append
(
value_info
[
'shape'
])
if
opt
in
self
.
value_infos
:
value_info
=
self
.
value_infos
[
opt
]
node
.
dtype
=
value_info
[
'dtype'
]
node
.
out_shapes
.
append
(
value_info
[
'shape'
])
else
:
_
,
dtype
,
shape
=
self
.
get_dynamic_shape
(
opt
)
node
.
dtype
=
dtype
node
.
out_shapes
.
append
(
shape
)
for
layer
in
self
.
model
.
input
:
if
layer
.
name
not
in
self
.
node_map
:
...
...
@@ -195,10 +207,7 @@ class ONNXGraph(Graph):
                         format(in_node, layer_name))
                 else:
                     self.connect(in_node, layer_name)
-        # print([layer_name for layer_name, node in self.node_map.items()])
-        #generate topo
-
         #generate topo
         super(ONNXGraph, self).build()
         self.input_nodes = self.place_holder_nodes
@@ -229,7 +238,6 @@ class ONNXGraph(Graph):
"""
collect value/type info for an ONNX model
"""
assert
isinstance
(
graph
,
onnx
.
GraphProto
),
'model is not a ModelProto instance'
...
...
@@ -252,6 +260,7 @@ class ONNXGraph(Graph):
                 'external': True
             }
         for item in graph.output:
+            assert item.name not in value_info
             value_info[item.name] = {
                 'dtype': TENSOR_TYPE_TO_NP_TYPE[item.type.tensor_type.elem_type],
@@ -261,34 +270,74 @@ class ONNXGraph(Graph):
             }
         return value_info

+    def get_results_of_inference(self, model, shape):
+        try:
+            import torch
+            version = torch.__version__
+            if '1.1.0' not in version:
+                print("your model have dynamic graph, torch==1.1.0 is required")
+                return
+        except:
+            print(
+                "your model have dynamic graph, we use caff2 to inference graph, please use \"pip install torch==1.1.0\"."
+            )
+            return
+        from x2paddle.decoder.onnx_backend import prepare
+
+        np_images = np.random.rand(shape[0], shape[1], shape[2],
+                                   shape[3]).astype('float32')
+        outputs = []
+        for node in model.graph.node:
+            value_info = helper.make_tensor_value_info(node.name,
+                                                       TensorProto.UNDEFINED,
+                                                       [])
+            outputs.append(value_info)
+
+        while len(outputs) > 0:
+            tmp_outputs = outputs[:254]
+            model.graph.ClearField('output')
+            model.graph.output.MergeFrom(tmp_outputs)
+            prepared_backend = prepare(model,
+                                       device='CPU',
+                                       no_check_UNSAFE=True)
+            res = prepared_backend.run(inputs=np_images)
+            for idx, info in enumerate(tmp_outputs):
+                self.results_of_inference[info.name] = res[idx]
+            outputs = outputs[254:]
+        return
+
+    def get_dynamic_shape(self, layer):
+        """
+        get dynamic shape from caffe2.backend
+        """
+        output = self.results_of_inference[layer]
+        return output.tolist(), output.dtype, output.shape
+

 class ONNXDecoder(object):
     def __init__(self, onnx_model):
         model = onnx.load(onnx_model)
         print('model ir_version: {}, op version: {}'.format(
             model.ir_version, model.opset_import[0].version))
         if model.opset_import[0].version < 9:
             _logger.warning(
                 'Now, onnx2paddle main support convert onnx model opset_verison == 9,'
                 'opset_verison of your onnx model is %d < 9,'
                 'some operator may cannot convert.',
                 model.opset_import[0].version)
         check_model(model)
         model = polish_model(model)
         check_model(model)
         model = onnx.shape_inference.infer_shapes(model)
         model = self.optimize_model_skip_op_for_inference(model)
         model = self.optimize_model_strip_initializer(model)

         self.standardize_variable_name(model.graph)
         self.model = model
         graph_def = model.graph
-        self.onnx_graph = ONNXGraph(graph_def)
+        self.onnx_graph = ONNXGraph(graph_def, model)
         self.onnx_graph.build()
-        self.results_of_inference = dict()

     def build_value_refs(self, nodes):
         """
         build op reference of inputs and outputs
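The new ONNXGraph.get_results_of_inference above works around a cap on how many graph outputs the backend accepts: every node is promoted to a graph output in slices of 254 value_infos, the bundled caffe2-style backend (x2paddle.decoder.onnx_backend.prepare) is rerun on random input for each slice, and every intermediate tensor is cached in results_of_inference; get_dynamic_shape then answers dtype/shape queries from that cache. A hedged sketch of the same idea with onnxruntime swapped in for the caffe2-style backend (an illustrative substitute, not this commit's code):

    import onnx
    import onnxruntime as ort  # illustrative substitute for the caffe2-style backend

    def run_all_intermediates(model, feed):
        """Return {tensor_name: ndarray} for every inferable intermediate tensor."""
        inferred = onnx.shape_inference.infer_shapes(model)
        # Promote inferred value_infos to graph outputs, as the removed
        # ONNXDecoder.get_results_of_inference below does with MergeFrom.
        model.graph.output.MergeFrom(inferred.graph.value_info)
        sess = ort.InferenceSession(model.SerializeToString(),
                                    providers=['CPUExecutionProvider'])
        names = [o.name for o in sess.get_outputs()]
        return dict(zip(names, sess.run(None, feed)))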
@@ -369,9 +418,13 @@ class ONNXDecoder(object):
                                            output_name, output_refs)
             else:
                 processed = -1
             if processed > 0:
                 nodes_to_remove.append(node_idx)
+                for value_info in ret.graph.value_info:
+                    for output in node.output:
+                        if value_info.name == output:
+                            ret.graph.value_info.remove(value_info)
                 print('skip op {}: {} -> {} -> {}'.format(
                     node_idx, input_name, node.op_type, output_name))
             elif processed == 0:
@@ -431,7 +484,6 @@ class ONNXDecoder(object):
"""
standardize variable name for paddle's code
"""
for
initializer
in
graph
.
initializer
:
initializer
.
name
=
self
.
make_variable_name
(
initializer
.
name
)
for
ipt
in
graph
.
input
:
...
...
@@ -490,41 +542,3 @@ class ONNXDecoder(object):
             raise RuntimeError("Input mismatch {} != {}".format(
                 len(onnx_model.input), len(model.input)))
         return onnx_model
-
-    def get_results_of_inference(self, model, input_shapes):
-        try:
-            import torch
-            version = torch.__version__
-            if '1.1.0' not in version:
-                print("your model have dynamic graph, torch==1.1.0 is required")
-                return
-        except:
-            print(
-                "your model have dynamic graph, we use caff2 to inference graph, please use \"pip install torch==1.1.0\"."
-            )
-            return
-        from caffe2.python.onnx.backend import prepare
-        shape = input_shapes[0]
-        np_images = np.random.rand(shape[0], shape[1], shape[2],
-                                   shape[3]).astype('float32')
-        infer_shapes = onnx.shape_inference.infer_shapes(model)
-        model.graph.ClearField('output')
-        model.graph.output.MergeFrom(infer_shapes.graph.value_info)
-        prepared_backend = prepare(model, device='CPU')
-        output = prepared_backend.run(inputs=np_images)
-        for idx, value_info in enumerate(infer_shapes.graph.value_info):
-            self.results_of_inference[value_info.name] = output[idx]
-        return
-
-    def get_dynamic_shape_from_caffe2(self, layer, input_shapes):
-        """
-        get dynamic shape from caffe2.backend
-        """
-        if len(self.results_of_inference) == 0:
-            self.get_results_of_inference(self.model, input_shapes)
-        output = self.results_of_inference[layer]
-        return output.tolist()
x2paddle/op_mapper/onnx_custom_layer/InstanceNormalization.py
 from .register import register
-from x2paddle.core.util import *


 def InstanceNormalization_shape(input_shape):
x2paddle/op_mapper/onnx_custom_layer/__init__.py
 from .register import get_registered_layers
 #custom layer import begins

 from . import InstanceNormalization
x2paddle/op_mapper/onnx_directly_map.py
@@ -47,13 +47,42 @@ default_op_mapping = {
         dict(axes='dim', keepdims='keep_dim'),
         dict(keep_dim=1)
     ],
     'ReduceSum': [
         'reduce_sum', ['X'], ['Out'],
         dict(axes='dim', keepdims='keep_dim'),
         dict(keep_dim=1)
     ],
+    #active function
+    'Relu': ['relu', ['X'], ['Out']],
+    'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
+                  dict(), dict(alpha=.01)],
+    'Elu': ['elu', ['X'], ['Out'],
+            dict(), dict(alpha=1.)],
+    'ThresholdedRelu': ['thresholded_relu', ['X'], ['Out'],
+                        dict(alpha='threshold'), dict(alpha=1.)],
+    'Tanh': ['tanh', ['X'], ['Out']],
+    'Sigmoid': ['sigmoid', ['X'], ['Out']],
+    'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'],
+            dict(), dict(axis=-1)],  # TODO: pow for scalar exponent
+    'HardSigmoid': ['hard_sigmoid', ['X'], ['Out'],
+                    dict(alpha='slope', beta='offset'),
+                    dict(slope=.2, offset=.5)],
+    'Softsign': ['softsign', ['X'], ['Out']],
+    'Softplus': ['softplus', ['X'], ['Out']],
+    'Exp': ['exp', ['X'], ['Out']],
+    'Softmax': ['softmax', ['X'], ['Out'],
+                dict(axis=''), dict(axis=1)],
 }

 activefunc_op_mapping = {
     'Relu': ['relu', ['X'], ['Out']],
     'LeakyRelu': ['leaky_relu', ['X'], ['Out'],
                   dict(), dict(alpha=.01)],
 }
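Judging from the entries, each default_op_mapping value lists up to five fields: the fluid op name, the ONNX input names, the output names, an ONNX-to-fluid attribute rename dict, and a dict of fluid attribute defaults. The new 'Softmax' entry carries the axis attribute through with ONNX's opset-9 default of 1. A hedged sketch of how such an entry might be consumed (hypothetical helper; the real dispatch lives in onnx_op_mapper.py):

    def directly_map(node, mapping):
        # Pad missing fields so 3-element entries unpack like 5-element ones.
        info = list(mapping[node.layer_type]) + [dict(), dict()]
        fluid_op, inputs, outputs, attr_rename, attr_default = info[:5]
        attrs = dict(attr_default)                # start from fluid defaults
        for onnx_name, fluid_name in attr_rename.items():
            value = node.get_attr(onnx_name, None)
            if value is not None:
                attrs[fluid_name] = value         # e.g. ONNX 'axes' -> fluid 'dim'
        node.fluid_code.add_layer(fluid_op, inputs=inputs, output=node,
                                  param_attr=attrs)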
x2paddle/op_mapper/onnx_op_mapper.py
@@ -14,7 +14,6 @@
 from x2paddle.core.graph import GraphNode
 from x2paddle.core.op_mapper import OpMapper
-from x2paddle.core.util import *
 from x2paddle.core.fluid_code import Layer
 from x2paddle.core.fluid_code import FluidCode
 from x2paddle.decoder.onnx_decoder import ONNXGraph, ONNXGraphNode, ONNXGraphDataNode
@@ -22,6 +21,7 @@ from x2paddle.op_mapper.onnx_directly_map import default_op_mapping_field_values
 from x2paddle.op_mapper.onnx_directly_map import default_op_mapping
 from x2paddle.op_mapper.onnx_directly_map import default_ioa_constraint
 from x2paddle.op_mapper.onnx_custom_layer import *
+from x2paddle.core.util import string
 import numpy as np
 import onnx.numpy_helper as numpy_helper
 import logging as _logging
@@ -202,6 +202,48 @@ class ONNXOpMapper(OpMapper):
             val_padded = self.Pad(node, op_independent=False)
             return [0] * ndims, val_padded

+    def _interpolate(self, node):
+        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_scales = self.graph.get_node(node.layer.input[1], copy=True)
+        val_y = self.graph.get_node(node.layer.output[0], copy=True)
+
+        out_shape_ = val_y.out_shapes[0]
+        if out_shape_ is not None:
+            assert len(out_shape_) == 4, 'only 4-D Tensor as X and Y supported'
+            out_shape_ = out_shape_[2:]
+        scales = _const_weight_or_none(val_scales)
+        if scales is not None:
+            assert len(scales) == 4, 'only 4-D Tensor as X and Y supported'
+            assert scales[0] == 1 and scales[1] == 1, 'only scale on (NC)HW supported'
+            assert scales[2] == scales[3], 'only aspect-ratio-invariant scale supported'
+        scale = scales[2] if scales else None
+        if scale is None:
+            assert out_shape_, 'neither scales nor output shape is available'
+            out_shape = out_shape_
+        else:
+            out_shape = None
+            if out_shape_ is None:
+                in_shape = val_x.out_shapes[0]
+                assert in_shape is not None, 'out_shape required but not inferrable'
+                assert len(in_shape) == 4, 'only 4-D Tensor as X and Y supported'
+                out_shape_ = [in_shape[2] * scale, in_shape[3] * scale]
+        mode = node.get_attr('mode', 'nearest')
+        fluid_op = 'resize_{}'.format(mode)
+        attr = {
+            'scale': scale,
+            'out_shape': out_shape,
+            'name': string(node.layer_name)
+        }
+        node.fluid_code.add_layer(fluid_op,
+                                  inputs=val_x,
+                                  output=node,
+                                  param_attr=attr)
+
     def Pad(self, node, op_independent=True):
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
         pads = node.get_attr('pads')
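A worked example of the checks in _interpolate, with hypothetical values: for an NCHW input and ONNX Upsample/Resize scales [1, 1, 2, 2], only the shared spatial factor survives:

    scales = [1.0, 1.0, 2.0, 2.0]             # N and C must be unscaled
    in_shape = [1, 3, 224, 224]
    assert scales[0] == 1 and scales[1] == 1  # only (NC)HW scaling supported
    assert scales[2] == scales[3]             # aspect-ratio-invariant only
    scale = scales[2]
    out_hw = [int(in_shape[2] * scale), int(in_shape[3] * scale)]  # [448, 448]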
@@ -258,6 +300,17 @@ class ONNXOpMapper(OpMapper):
                                       output=node,
                                       param_attr=attr)

+    def Shrink(self, node):
+        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        bias = node.get_attr('bias')
+        lambd = node.get_attr('lambd')
+        assert bias == 0.0, 'not support bias!=0'
+        attr = {'threshold': lambd, 'name': node.layer_name}
+        node.fluid_code.add_layer('hard_shrink',
+                                  inputs=val_x,
+                                  output=node,
+                                  param_attr=attr)
+
     def Constant(self, node):
         val_output = self.graph.get_node(node.layer.output[0], copy=True)
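For reference, ONNX Shrink computes x - bias where x > lambd, x + bias where x < -lambd, and 0 otherwise; with bias == 0 (the only case the mapper accepts) this collapses to fluid's hard_shrink. A NumPy rendering of that special case:

    import numpy as np

    def shrink_bias0(x, lambd):
        # Keep values whose magnitude exceeds lambd; zero out the rest.
        return np.where(np.abs(x) > lambd, x, 0.0)

    print(shrink_bias0(np.array([-2.0, -0.1, 0.5, 3.0]), lambd=1.0))  # [-2. 0. 0. 3.]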
@@ -278,8 +331,8 @@ class ONNXOpMapper(OpMapper):
             'using value as 1-D tensor may lead to fails',
             val_output.layer_name, val_output.layer_name)

-        value = value.tolist()
         if len(value) == 1:  # scalar
+            value = value.tolist()
             shape = [1]
             value = value[0]
             if dtype.name == 'int64':
@@ -289,12 +342,25 @@ class ONNXOpMapper(OpMapper):
                                       inputs=None,
                                       output=node,
                                       param_attr=attr)
         else:
+            value = np.reshape(value, shape)
+            self.weights[node.layer_name] = value
+            attr = {
+                'dtype': string(dtype),
+                'shape': shape,
+                'name': string(node.layer_name),
+                'attr': string(node.layer_name),
+                'default_initializer': 'Constant(0.0)'
+            }
+            node.fluid_code.add_layer("create_parameter",
+                                      inputs=None,
+                                      output=node,
+                                      param_attr=attr)

     def Resize(self, node):
         # I/O
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
         val_scales = self.graph.get_node(node.layer.input[1], copy=True)
-        val_y, = self.graph.get_node(node.layer.output[0], copy=True)
+        val_y = self.graph.get_node(node.layer.output[0], copy=True)

         out_shape_ = val_y.out_shapes[0]
         if out_shape_ is not None:
@@ -322,8 +388,6 @@ class ONNXOpMapper(OpMapper):
         mode = node.get_attr('mode', 'nearest')
         fluid_op = 'resize_{}'.format(mode)
-        name_attr = ', name={}'.format(repr(name)) if name else ''
         attr = {
             'scale': scale,
             'out_shape': out_shape,
@@ -334,6 +398,33 @@ class ONNXOpMapper(OpMapper):
                                   output=node,
                                   param_attr=attr)

+    def Upsample(self, node):
+        self._interpolate(node)
+
+    def Slice(self, node):
+        val_x = self.graph.get_node(node.layer.input[0], copy=True)
+        val_y = self.graph.get_node(node.layer.output[0], copy=True)
+
+        axes = node.get_attr('axes')
+        starts = node.get_attr('starts')
+        ends = node.get_attr('ends')
+        shape = val_x.out_shapes[0]
+        if shape is not None:
+            for idx, value in enumerate(starts):
+                if value > 2**63 - 1 // 2:
+                    value = value - ONNX_INT_MAX
+                    starts[idx] = shape[axes[idx]] + value
+            for idx, value in enumerate(ends):
+                if value > 2**63 - 1 // 2:
+                    value = value - ONNX_INT_MAX
+                    ends[idx] = shape[axes[idx]] + value
+        attr = {"axes": axes, "starts": starts, "ends": ends}
+        node.fluid_code.add_layer('slice',
+                                  inputs=val_x,
+                                  output=node,
+                                  param_attr=attr)
+
     def ConstantOfShape(self, node):
         val_shape = self.graph.get_node(node.layer.input[0], copy=True)
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
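The new Slice mapper normalizes ONNX's "until the end" sentinel: exporters encode open-ended starts/ends as values near INT64_MAX, which the loop folds back into shape-relative indices via ONNX_INT_MAX (presumably 2**63 - 1; note that 2**63 - 1 // 2 parses as 2**63 - 0, so the guard likely intends (2**63 - 1) // 2). A hedged sketch of the intended normalization:

    INT64_MAX = 2**63 - 1  # assumed value of ONNX_INT_MAX

    def normalize_indices(indices, axes, shape):
        # Fold sentinel indices back into the valid range of each axis.
        out = list(indices)
        for i, value in enumerate(indices):
            if value > INT64_MAX // 2:       # explicit precedence
                out[i] = shape[axes[i]] + (value - INT64_MAX)
        return out

    # ends=[INT64_MAX] on axis 0 of a 224x224 tensor becomes [224]:
    print(normalize_indices([INT64_MAX], axes=[0], shape=[224, 224]))  # [224]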
@@ -384,8 +475,8 @@ class ONNXOpMapper(OpMapper):
         # catch dynamic graph shape
         if isinstance(val_shape, ONNXGraphNode):
-            shape = self.decoder.get_dynamic_shape_from_caffe2(
-                val_shape.layer_name, self.input_shapes)
+            shape, _, _ = self.decoder.onnx_graph.get_dynamic_shape(
+                val_shape.layer_name)
         if shape is None:
             shape = val_reshaped.out_shapes[0]
@@ -440,9 +531,10 @@ class ONNXOpMapper(OpMapper):
         pads = node.get_attr('pads', [0] * (poolnd * 2))
         fluid_op = 'pool{}d'.format(poolnd)
         assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
-        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
         input_shape = val_x.out_shapes[0]
+        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
             pad_h = get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
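The reorder above matters because _pad_if_asymmetric can replace val_x with a padded node, and SAME-style padding must be computed from the original spatial size, so input_shape is captured first. A hedged sketch of what get_same_padding presumably computes:

    import math

    def same_padding(in_size, kernel_size, stride):
        # SAME semantics: the output covers ceil(in_size / stride) positions.
        out_size = math.ceil(in_size / stride)
        pad = max((out_size - 1) * stride + kernel_size - in_size, 0)
        return [pad // 2, pad - pad // 2]   # [pad_before, pad_after]

    print(same_padding(224, 3, 2))  # [0, 1] -> output size 112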
@@ -597,14 +689,6 @@ class ONNXOpMapper(OpMapper):
                                   output=node,
                                   param_attr=attr)

-    def Softmax(self, node):
-        val_x = self.graph.get_node(node.layer.input[0], copy=True)
-        attr = {"name": string(node.layer_name)}
-        node.fluid_code.add_layer("softmax",
-                                  inputs=val_x,
-                                  output=node,
-                                  param_attr=attr)
-
     def Transpose(self, node):
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
         perm = node.get_attr('perm')
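This deletion is presumably the fix named in the commit title: the bespoke Softmax mapper above dropped the ONNX axis attribute entirely, whereas the new default_op_mapping['Softmax'] entry in onnx_directly_map.py carries axis through with ONNX's opset-9 default of 1. For reference, opset-9 Softmax flattens the input to 2-D around axis before the row-wise softmax:

    import numpy as np

    def onnx_softmax(x, axis=1):
        # Coerce to 2-D [prod(shape[:axis]), prod(shape[axis:])], softmax rows.
        flat = x.reshape(int(np.prod(x.shape[:axis], dtype=np.int64)), -1)
        e = np.exp(flat - flat.max(axis=1, keepdims=True))
        return (e / e.sum(axis=1, keepdims=True)).reshape(x.shape)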
@@ -614,15 +698,79 @@ class ONNXOpMapper(OpMapper):
                                   output=node,
                                   param_attr=attr)

     def Mul(self, node):
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
         val_y = self.graph.get_node(node.layer.input[1], copy=True)
+        val_x_shape = val_x.out_shapes[0]
+        val_y_shape = val_y.out_shapes[0]
+        slice_idx = 0
+        for dim in val_y_shape:
+            if dim == 1:
+                slice_idx += 1
+            else:
+                break
         attr = {"name": string(node.layer_name)}
+        if slice_idx < len(val_y_shape) and slice_idx > 0:
+            val_y_reshaped = val_y_shape[slice_idx:]
+            var_y_reshaped = val_y.layer_name + '_reshaped'
+            attr_reshaped = {
+                'shape': val_y_reshaped,
+                'name': string(var_y_reshaped)
+            }
+            node.fluid_code.add_layer('reshape',
+                                      inputs=val_y,
+                                      output=var_y_reshaped,
+                                      param_attr=attr_reshaped)
+            inputs = {'x': val_x, 'y': var_y_reshaped}
+            node.fluid_code.add_layer("elementwise_mul",
+                                      inputs=inputs,
+                                      output=node,
+                                      param_attr=attr)
+        else:
+            inputs = {'x': val_x, 'y': val_y}
+            node.fluid_code.add_layer("elementwise_mul",
+                                      inputs=inputs,
+                                      output=node,
+                                      param_attr=attr)

     def Div(self, node):
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
         val_y = self.graph.get_node(node.layer.input[1], copy=True)
-        inputs = {'x': val_x, 'y': val_y}
+        val_x_shape = val_x.out_shapes[0]
+        val_y_shape = val_y.out_shapes[0]
+        slice_idx = 0
+        for dim in val_y_shape:
+            if dim == 1:
+                slice_idx += 1
+            else:
+                break
         attr = {"name": string(node.layer_name)}
-        node.fluid_code.add_layer("elementwise_div",
-                                  inputs=inputs,
-                                  output=node,
-                                  param_attr=attr)
+        if slice_idx < len(val_y_shape) and slice_idx > 0:
+            val_y_reshaped = val_y_shape[slice_idx:]
+            var_y_reshaped = val_y.layer_name + '_reshaped'
+            attr_reshaped = {
+                'shape': val_y_reshaped,
+                'name': string(var_y_reshaped)
+            }
+            node.fluid_code.add_layer('reshape',
+                                      inputs=val_y,
+                                      output=var_y_reshaped,
+                                      param_attr=attr_reshaped)
+            inputs = {'x': val_x, 'y': var_y_reshaped}
+            node.fluid_code.add_layer("elementwise_div",
+                                      inputs=inputs,
+                                      output=node,
+                                      param_attr=attr)
+        else:
+            inputs = {'x': val_x, 'y': val_y}
+            node.fluid_code.add_layer("elementwise_div",
+                                      inputs=inputs,
+                                      output=node,
+                                      param_attr=attr)

     def Relu(self, node):
         val_x = self.graph.get_node(node.layer.input[0], copy=True)
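Both rewrites work around fluid's elementwise broadcasting rule: elementwise_mul/elementwise_div align y against the trailing dimensions of x, so a y whose shape carries leading size-1 dimensions is first reshaped to drop them. A worked example with hypothetical shapes:

    val_y_shape = [1, 1, 4, 5]    # e.g. x has shape [2, 3, 4, 5]

    slice_idx = 0
    for dim in val_y_shape:       # count leading 1s, as the mapper does
        if dim == 1:
            slice_idx += 1
        else:
            break

    assert slice_idx == 2
    assert val_y_shape[slice_idx:] == [4, 5]  # y is reshaped to [4, 5] first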
@@ -679,9 +827,10 @@ class ONNXOpMapper(OpMapper):
         pads = node.get_attr('pads', [0] * (poolnd * 2))  # optional
         fluid_op = 'pool{}d'.format(poolnd)
         assert 2 <= poolnd <= 3, 'only pool2d and pool3d is supported'
-        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
         input_shape = val_x.out_shapes[0]
+        paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
             pad_h = get_same_padding(input_shape[2], kernel_shape[0],
                                      strides[0])
@@ -731,7 +880,6 @@ class ONNXOpMapper(OpMapper):
         val_y = self.graph.get_node(node.layer.output[0], copy=True)
         self.omit_nodes.append(val_w.layer_name)
-        input_shape = val_x.out_shapes[0]
         has_bias = len(node.layer.input) == 3
         if has_bias:
@@ -752,6 +900,7 @@ class ONNXOpMapper(OpMapper):
         dilations = node.get_attr('dilations', [1] * convnd)  # optional
         pads = node.get_attr('pads', [0] * (convnd * 2))  # optional
+        input_shape = val_x.out_shapes[0]
         paddings, val_x = self._pad_if_asymmetric(node, pads, val_x)
         if auto_pad == "SAME_UPPER" or auto_pad == "SAME_LOWER":
@@ -796,14 +945,14 @@ class ONNXOpMapper(OpMapper):
         assert kernel_shape, 'kernel_shape not inferred'
         convnd = len(kernel_shape)
         assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose supported'
-        num_out_channels = val_w.out_shapes[0][1]  # IO...
+        num_out_channels = val_w.out_shapes[0][1]
         fluid_op = 'conv{}d_transpose'.format(convnd)

-        num_groups = node.get_attr('group', 1)  # optional
-        strides = node.get_attr('strides', [1] * convnd)  # optional
-        dilations = node.get_attr('dilations', [1] * convnd)  # optional
-        output_size = node.get_attr('output_shape', [])  # optional
-        pads = node.get_attr('pads', [0] * (convnd * 2))  # optional
+        num_groups = node.get_attr('group', 1)
+        strides = node.get_attr('strides', [1] * convnd)
+        dilations = node.get_attr('dilations', [1] * convnd)
+        output_size = node.get_attr('output_shape', [])
+        pads = node.get_attr('pads', [0] * (convnd * 2))

         paddings, var_x = self._pad_if_asymmetric(node, pads, val_x)
@@ -831,3 +980,39 @@ class ONNXOpMapper(OpMapper):
                                   inputs=val_x,
                                   output=node,
                                   param_attr=attr)
+
+    # def NonMaxSuppression(self, node):
+    #     boxes = self.graph.get_node(node.layer.input[0], copy=True)
+    #     scores = self.graph.get_node(node.layer.input[1], copy=True)
+    #     max_output_boxes_per_class = self.graph.get_node(node.layer.input[2], copy=True)
+    #     iou_threshold = self.graph.get_node(node.layer.input[3], copy=True)
+    #     score_threshold = self.graph.get_node(node.layer.input[4], copy=True)
+    #     self.omit_nodes.append(max_output_boxes_per_class)
+    #     self.omit_nodes.append(iou_threshold)
+    #     self.omit_nodes.append(score_threshold)
+    #     iou_threshold_val = iou_threshold.weight
+    #     center_point_box = node.get_attr('center_point_box', 0)
+    #     score_threshold_val = score_threshold.weight
+    #     attr = {
+    #         'num_filters': num_out_channels,
+    #         'output_size': output_size or None,
+    #         'filter_size': kernel_shape,
+    #         'padding': paddings,
+    #         'stride': strides,
+    #         'dilation': dilations,
+    #         'groups': num_groups,
+    #         'param_attr': string(val_w.layer_name),
+    #         'bias_attr': string(val_b.layer_name),
+    #         'name': string(node.layer_name),
+    #     }
+    #     node.fluid_code.add_layer('multiclass_nms',
+    #                               inputs= boxes.layer_name ',' + scores.layer_name,
+    #                               output=node,
+    #                               param_attr=attr)
+    #     pass
x2paddle/optimizer/onnx_optimizer.py
@@ -14,7 +14,6 @@
 # TODO useless node remove
 from x2paddle.op_mapper.onnx_op_mapper import ONNXOpMapper
-from x2paddle.core.util import *


 class ONNXOptimizer(object):