Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
acbb1cac
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
acbb1cac
编写于
1月 17, 2022
作者:
W
wjj19950828
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Simplified code
上级
53f8175d
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
47 addition
and
41 deletion
+47
-41
x2paddle/op_mapper/onnx2paddle/opset9/opset.py
x2paddle/op_mapper/onnx2paddle/opset9/opset.py
+47
-41
未找到文件。
x2paddle/op_mapper/onnx2paddle/opset9/opset.py
浏览文件 @
acbb1cac
...
@@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False):
...
@@ -45,7 +45,8 @@ def _const_weight_or_none(node, necessary=False):
def
_rename_or_remove_weight
(
weights
,
def
_rename_or_remove_weight
(
weights
,
origin_name
,
origin_name
,
target_name
=
None
,
target_name
=
None
,
is_remove
=
True
):
is_remove
=
True
,
rename_mapper
=
None
):
'''
'''
Rename parameters by Paddle's naming rule of parameters.
Rename parameters by Paddle's naming rule of parameters.
...
@@ -56,9 +57,13 @@ def _rename_or_remove_weight(weights,
...
@@ -56,9 +57,13 @@ def _rename_or_remove_weight(weights,
{target_name:weights[origin_name]} to weights, and target_name must follow paddle's
{target_name:weights[origin_name]} to weights, and target_name must follow paddle's
naming rule of parameters. Default: None.
naming rule of parameters. Default: None.
is_remove: if is_remove is True, remove origin key-value pair. Default: True.
is_remove: if is_remove is True, remove origin key-value pair. Default: True.
rename_mapper: Solved the same data is used for multiple OPs, key is old_name, value is new_name.
Returns:
Returns:
None
None
'''
'''
if
origin_name
in
rename_mapper
:
origin_name
=
rename_mapper
[
origin_name
]
is_remove
=
False
if
origin_name
not
in
weights
:
if
origin_name
not
in
weights
:
raise
KeyError
(
'{} not a key in {}'
.
format
(
origin_name
,
weights
.
keys
()))
raise
KeyError
(
'{} not a key in {}'
.
format
(
origin_name
,
weights
.
keys
()))
if
is_remove
:
if
is_remove
:
...
@@ -69,6 +74,7 @@ def _rename_or_remove_weight(weights,
...
@@ -69,6 +74,7 @@ def _rename_or_remove_weight(weights,
if
target_name
is
not
None
:
if
target_name
is
not
None
:
# rename weight
# rename weight
weights
[
target_name
]
=
data
weights
[
target_name
]
=
data
rename_mapper
[
origin_name
]
=
target_name
def
_is_static_shape
(
shape
):
def
_is_static_shape
(
shape
):
...
@@ -1682,38 +1688,26 @@ class OpSet9():
...
@@ -1682,38 +1688,26 @@ class OpSet9():
c
=
val_x
.
out_shapes
[
0
][
1
]
c
=
val_x
.
out_shapes
[
0
][
1
]
# solved the same data is used as an argument to multiple OPs.
# solved the same data is used as an argument to multiple OPs.
if
val_scale
.
name
in
self
.
rename_mapper
:
_rename_or_remove_weight
(
new_name
=
self
.
rename_mapper
[
val_scale
.
name
]
self
.
weights
,
_rename_or_remove_weight
(
self
.
weights
,
new_name
,
val_scale
.
name
,
op_name
+
'.weight'
,
False
)
op_name
+
'.weight'
,
else
:
rename_mapper
=
self
.
rename_mapper
)
_rename_or_remove_weight
(
self
.
weights
,
val_scale
.
name
,
_rename_or_remove_weight
(
op_name
+
'.weight'
)
self
.
weights
,
self
.
rename_mapper
[
val_scale
.
name
]
=
op_name
+
'.weight'
val_b
.
name
,
if
val_b
.
name
in
self
.
rename_mapper
:
op_name
+
'.bias'
,
new_name
=
self
.
rename_mapper
[
val_b
.
name
]
rename_mapper
=
self
.
rename_mapper
)
_rename_or_remove_weight
(
self
.
weights
,
new_name
,
op_name
+
'.bias'
,
_rename_or_remove_weight
(
False
)
self
.
weights
,
else
:
val_var
.
name
,
_rename_or_remove_weight
(
self
.
weights
,
val_b
.
name
,
op_name
+
'._variance'
,
op_name
+
'.bias'
)
rename_mapper
=
self
.
rename_mapper
)
self
.
rename_mapper
[
val_b
.
name
]
=
op_name
+
'.bias'
_rename_or_remove_weight
(
if
val_var
.
name
in
self
.
rename_mapper
:
self
.
weights
,
new_name
=
self
.
rename_mapper
[
val_var
.
name
]
val_mean
.
name
,
_rename_or_remove_weight
(
self
.
weights
,
new_name
,
op_name
+
'._mean'
,
op_name
+
'._variance'
,
False
)
rename_mapper
=
self
.
rename_mapper
)
else
:
_rename_or_remove_weight
(
self
.
weights
,
val_var
.
name
,
op_name
+
'._variance'
)
self
.
rename_mapper
[
val_var
.
name
]
=
op_name
+
'._variance'
if
val_mean
.
name
in
self
.
rename_mapper
:
new_name
=
self
.
rename_mapper
[
val_mean
.
name
]
_rename_or_remove_weight
(
self
.
weights
,
new_name
,
op_name
+
'._mean'
,
False
)
else
:
_rename_or_remove_weight
(
self
.
weights
,
val_mean
.
name
,
op_name
+
'._mean'
)
self
.
rename_mapper
[
val_mean
.
name
]
=
op_name
+
'._mean'
# Attribute: spatial is used in BatchNormalization-1,6,7
# Attribute: spatial is used in BatchNormalization-1,6,7
spatial
=
bool
(
node
.
get_attr
(
'spatial'
))
spatial
=
bool
(
node
.
get_attr
(
'spatial'
))
...
@@ -2255,14 +2249,22 @@ class OpSet9():
...
@@ -2255,14 +2249,22 @@ class OpSet9():
remove_weight
=
True
if
val_w
.
name
in
self
.
done_weight_list
else
False
remove_weight
=
True
if
val_w
.
name
in
self
.
done_weight_list
else
False
if
remove_weight
:
if
remove_weight
:
self
.
done_weight_list
.
append
(
val_w
.
name
)
self
.
done_weight_list
.
append
(
val_w
.
name
)
_rename_or_remove_weight
(
self
.
weights
,
val_w
.
name
,
op_name
+
'.weight'
,
_rename_or_remove_weight
(
remove_weight
)
self
.
weights
,
val_w
.
name
,
op_name
+
'.weight'
,
remove_weight
,
rename_mapper
=
self
.
rename_mapper
)
if
has_bias
:
if
has_bias
:
remove_bias
=
True
if
val_b
.
name
in
self
.
done_weight_list
else
False
remove_bias
=
True
if
val_b
.
name
in
self
.
done_weight_list
else
False
if
remove_bias
:
if
remove_bias
:
self
.
done_weight_list
.
append
(
val_b_name
)
self
.
done_weight_list
.
append
(
val_b
.
name
)
_rename_or_remove_weight
(
self
.
weights
,
val_b
.
name
,
_rename_or_remove_weight
(
op_name
+
'.bias'
,
remove_bias
)
self
.
weights
,
val_b
.
name
,
op_name
+
'.bias'
,
remove_bias
,
rename_mapper
=
self
.
rename_mapper
)
else
:
else
:
layer_attrs
[
"bias_attr"
]
=
False
layer_attrs
[
"bias_attr"
]
=
False
if
reduce
(
lambda
x
,
y
:
x
*
y
,
if
reduce
(
lambda
x
,
y
:
x
*
y
,
...
@@ -2382,10 +2384,14 @@ class OpSet9():
...
@@ -2382,10 +2384,14 @@ class OpSet9():
_rename_or_remove_weight
(
_rename_or_remove_weight
(
self
.
weights
,
self
.
weights
,
val_w
.
name
,
val_w
.
name
,
op_name
+
'.weight'
,
)
op_name
+
'.weight'
,
rename_mapper
=
self
.
rename_mapper
)
if
val_b
is
not
None
:
if
val_b
is
not
None
:
_rename_or_remove_weight
(
self
.
weights
,
val_b
.
name
,
_rename_or_remove_weight
(
op_name
+
'.bias'
)
self
.
weights
,
val_b
.
name
,
op_name
+
'.bias'
,
rename_mapper
=
self
.
rename_mapper
)
else
:
else
:
layer_attrs
[
"bias_attr"
]
=
False
layer_attrs
[
"bias_attr"
]
=
False
self
.
paddle_graph
.
add_layer
(
self
.
paddle_graph
.
add_layer
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录