Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
d6636846
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
d6636846
编写于
1月 13, 2021
作者:
C
Channingss
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize code
上级
ae5a1811
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
21 addition
and
24 deletion
+21
-24
x2paddle/core/program.py
x2paddle/core/program.py
+2
-2
x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py
x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py
+19
-22
未找到文件。
x2paddle/core/program.py
浏览文件 @
d6636846
...
...
@@ -521,7 +521,7 @@ class PaddleGraph(object):
gen_codes
(
comment_list
,
indent
=
1
))
use_structured_name
=
False
if
self
.
source_type
in
[
"tf"
,
"onnx"
]
else
True
use_structured_name
=
False
if
self
.
source_type
in
[
"tf"
]
else
True
self
.
run_func
.
extend
(
gen_codes
([
"paddle.disable_static()"
,
"params = paddle.load('{}/model.pdparams')"
.
format
(
osp
.
abspath
(
code_dir
)),
...
...
@@ -673,7 +673,7 @@ class PaddleGraph(object):
paddle
.
disable_static
()
restore
=
paddle
.
load
(
osp
.
join
(
save_dir
,
"model.pdparams"
))
model
=
getattr
(
x2paddle_code
,
self
.
name
)()
if
self
.
source_type
in
[
"tf"
,
"onnx"
]:
if
self
.
source_type
in
[
"tf"
]:
model
.
set_dict
(
restore
,
use_structured_name
=
False
)
else
:
model
.
set_dict
(
restore
)
...
...
x2paddle/op_mapper/dygraph/onnx2paddle/opset9/opset.py
浏览文件 @
d6636846
...
...
@@ -1898,7 +1898,7 @@ class OpSet9():
reform_permutation
=
[(
0
,
1
),
(
2
,
4
),
(
1
,
2
)]
input_weight_np
,
hidden_weight_np
,
input_bias_np
,
hidden_bias_np
=
transform_weight_with_bias
(
weights
=
transform_weight_with_bias
(
[
input_weight_np
,
hidden_weight_np
,
input_bias_np
,
hidden_bias_np
],
hidden_size
,
reform_permutation
)
...
...
@@ -1907,30 +1907,27 @@ class OpSet9():
yh_out
=
node
.
output
(
1
)
yc_out
=
node
.
output
(
2
)
direction
=
node
.
get_attr
(
'direction'
,
'forward'
)
if
direction
==
'backward'
:
raise
Exception
(
"LSTM support 'forward' or 'bidirectional', except '{}'."
.
format
(
direction
))
elif
direction
==
'forward'
:
self
.
weights
[
input_weight
.
name
]
=
input_weight_np
.
squeeze
(
0
)
self
.
weights
[
hidden_weight
.
name
]
=
hidden_weight_np
.
squeeze
(
0
)
self
.
weights
[
input_bias_name
]
=
input_bias_np
.
squeeze
(
0
)
self
.
weights
[
hidden_bias_name
]
=
hidden_bias_np
.
squeeze
(
0
)
else
:
def
generate_paddle_param_names
(
op_name
,
suffix
=
''
):
param_names
=
[]
for
direct
in
range
(
2
):
suffix
=
'_reverse'
if
direct
==
1
else
''
param_names
.
extend
([
'{}.weight_ih_l0{}'
,
'{}.weight_hh_l0{}'
])
if
have_bias
!=
False
:
param_names
.
append
(
'{}.bias_ih_l0{}'
)
if
have_bias
!=
False
:
param_names
.
append
(
'{}.bias_hh_l0{}'
)
param_names
=
[
x
.
format
(
op_name
,
suffix
)
for
x
in
param_names
]
return
param_names
self
.
weights
[
param_names
[
0
]]
=
input_weight_np
[
0
]
self
.
weights
[
param_names
[
4
]]
=
input_weight_np
[
1
]
self
.
weights
[
param_names
[
1
]]
=
hidden_weight_np
[
0
]
self
.
weights
[
param_names
[
5
]]
=
hidden_weight_np
[
1
]
self
.
weights
[
param_names
[
2
]]
=
input_bias_np
[
0
]
self
.
weights
[
param_names
[
6
]]
=
input_bias_np
[
1
]
self
.
weights
[
param_names
[
3
]]
=
hidden_bias_np
[
0
]
self
.
weights
[
param_names
[
7
]]
=
hidden_bias_np
[
1
]
def
assign_params
(
op_name
,
weights
,
weight_idx
=
0
,
suffix
=
''
):
param_names
=
generate_paddle_param_names
(
op_name
,
suffix
)
print
(
param_names
)
for
param_name
,
weight
in
zip
(
param_names
,
weights
):
self
.
weights
[
param_name
]
=
weight
[
weight_idx
]
if
direction
==
'backward'
:
raise
Exception
(
"LSTM support 'forward' or 'bidirectional', except '{}'."
.
format
(
direction
))
else
:
assign_params
(
op_name
,
weights
)
if
direction
==
'bidirectional'
:
assign_params
(
op_name
,
weights
,
1
,
'_reverse'
)
self
.
paddle_graph
.
add_layer
(
'paddle.nn.LSTM'
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录