Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
86470a5d
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
283
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
86470a5d
编写于
1月 21, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
solve the problem of duplicate name by rename program var when create module
上级
8ce73b13
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
33 addition
and
34 deletion
+33
-34
paddle_hub/module.py
paddle_hub/module.py
+33
-34
未找到文件。
paddle_hub/module.py
浏览文件 @
86470a5d
...
...
@@ -43,7 +43,8 @@ MODEL_DIRNAME = "model"
DICT_FILENAME
=
"vocab.txt"
PARAM_FILENAME
=
"param.pkl"
MODULE_DESC_PBNAME
=
"module_desc.pb"
GENERATOR_FILENAME
=
"unique_name_generator.pkl"
# paddle hub var prefix
HUB_VAR_PREFIX
=
"@HUB@"
def
mkdir
(
path
):
...
...
@@ -83,6 +84,7 @@ class Module(object):
with
open
(
param_path
,
"rb"
)
as
file
:
param_arr
=
pickle
.
load
(
file
)
for
param
in
param_arr
:
param
[
'name'
]
=
HUB_VAR_PREFIX
+
param
[
'name'
]
if
(
param
[
'name'
]
not
in
global_block
.
vars
):
continue
var
=
global_block
.
var
(
param
[
'name'
])
...
...
@@ -146,9 +148,6 @@ class Module(object):
print
(
"**feed_target_names**
\n
{}"
.
format
(
self
.
feed_target_names
))
print
(
"**fetch_targets**
\n
{}"
.
format
(
self
.
fetch_targets
))
self
.
_process_parameter
()
name_generator_path
=
ModuleConfig
.
name_generator_path
(
self
.
module_dir
)
with
open
(
name_generator_path
,
"rb"
)
as
data
:
generator
=
pickle
.
load
(
data
)
program
=
self
.
get_inference_program
().
clone
()
...
...
@@ -156,14 +155,14 @@ class Module(object):
_set_param_trainable
(
program
=
program
,
trainable
=
trainable
)
for
key
,
value
in
feed_dict
.
items
():
var
=
program
.
global_block
().
var
(
value
)
var
=
program
.
global_block
().
var
(
HUB_VAR_PREFIX
+
value
)
feed_dict
[
key
]
=
var
for
key
,
value
in
fetch_dict
.
items
():
var
=
program
.
global_block
().
var
(
value
)
var
=
program
.
global_block
().
var
(
HUB_VAR_PREFIX
+
value
)
fetch_dict
[
key
]
=
var
return
feed_dict
,
fetch_dict
,
program
,
generator
return
feed_dict
,
fetch_dict
,
program
def
get_inference_program
(
self
):
return
self
.
inference_program
...
...
@@ -253,12 +252,6 @@ class ModuleConfig(object):
def
module_desc_path
(
module_dir
):
return
os
.
path
.
join
(
module_dir
,
MODULE_DESC_PBNAME
)
@
staticmethod
def
name_generator_path
(
module_dir
):
meta_path
=
os
.
path
.
join
(
module_dir
,
META_DIRNAME
)
mkdir
(
meta_path
)
return
os
.
path
.
join
(
meta_path
,
GENERATOR_FILENAME
)
@
staticmethod
def
assets_dict_path
(
module_dir
):
assets_path
=
os
.
path
.
join
(
module_dir
,
ASSETS_DIRNAME
)
...
...
@@ -271,12 +264,6 @@ class ModuleConfig(object):
mkdir
(
meta_path
)
return
os
.
path
.
join
(
meta_path
,
PARAM_FILENAME
)
@
staticmethod
def
meta_name_generator_path
(
module_dir
):
meta_path
=
os
.
path
.
join
(
module_dir
,
META_DIRNAME
)
mkdir
(
meta_path
)
return
os
.
path
.
join
(
meta_path
,
GENERATOR_FILENAME
)
def
create_module
(
sign_arr
,
module_dir
=
None
,
word_dict
=
None
):
""" Create a module from main program
...
...
@@ -321,19 +308,6 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
w_id
=
word_dict
[
w
]
fo
.
write
(
"{}
\t
{}
\n
"
.
format
(
w
,
w_id
))
# save the unique name generator object
var_name_arr
=
[
'_'
.
join
(
var
.
split
(
'@'
)[
0
].
split
(
'.'
)[
0
].
split
(
'_'
)[
0
:
-
1
])
for
block
in
program
.
blocks
for
var
in
block
.
vars
]
with
fluid
.
unique_name
.
guard
():
for
var_name
in
var_name_arr
:
fluid
.
unique_name
.
generate
(
var_name
)
generator
=
fluid
.
unique_name
.
generator
with
open
(
ModuleConfig
.
name_generator_path
(
module_dir
),
"wb"
)
as
fo
:
pickle
.
dump
(
generator
,
fo
)
# save fluid Parameter
param_arr
=
[]
for
param
in
program
.
global_block
().
iter_parameters
():
...
...
@@ -386,6 +360,30 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
main_program
=
program
,
executor
=
exe
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
"__model__"
),
"rb"
)
as
file
:
program_desc_str
=
file
.
read
()
rename_program
=
fluid
.
framework
.
Program
.
parse_from_string
(
program_desc_str
)
varlist
=
{
var
:
block
for
block
in
rename_program
.
blocks
for
var
in
block
.
vars
if
HUB_VAR_PREFIX
not
in
var
}
for
var
,
block
in
varlist
.
items
():
old_name
=
var
new_name
=
HUB_VAR_PREFIX
+
old_name
block
.
_rename_var
(
old_name
,
new_name
)
mkdir
(
save_model_dir
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
"__model__"
),
"wb"
)
as
f
:
f
.
write
(
rename_program
.
desc
.
serialize_to_string
())
for
file
in
os
.
listdir
(
save_model_dir
):
if
(
file
==
"__model__"
or
HUB_VAR_PREFIX
in
file
):
continue
os
.
rename
(
os
.
path
.
join
(
save_model_dir
,
file
),
os
.
path
.
join
(
save_model_dir
,
HUB_VAR_PREFIX
+
file
))
# Serialize module_desc pb
module_pb
=
module_desc
.
SerializeToString
()
with
open
(
ModuleConfig
.
module_desc_path
(
module_dir
),
"wb"
)
as
f
:
...
...
@@ -410,7 +408,8 @@ class ModuleUtils(object):
for
index
in
need_to_remove_op_index
[::
-
1
]:
block
.
_remove_op
(
index
)
block
.
_remove_var
(
"feed"
)
block
.
_remove_var
(
"fetch"
)
# TODO(wuzewu): get feed and fetch var by other way
block
.
_remove_var
(
HUB_VAR_PREFIX
+
"feed"
)
block
.
_remove_var
(
HUB_VAR_PREFIX
+
"fetch"
)
program
.
desc
.
flush
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录