Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
bfedfda8
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bfedfda8
编写于
2月 13, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
optimize the method of preserving model
上级
a5da11b6
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
42 addition
and
38 deletion
+42
-38
paddle_hub/module.py
paddle_hub/module.py
+41
-38
paddle_hub/utils.py
paddle_hub/utils.py
+1
-0
未找到文件。
paddle_hub/module.py
浏览文件 @
bfedfda8
...
...
@@ -166,7 +166,7 @@ class Module(object):
model_dir
=
os
.
path
.
join
(
self
.
module_dir
,
MODEL_DIRNAME
)
self
.
exe
=
fluid
.
Executor
(
fluid
.
CPUPlace
())
self
.
inference_program
,
self
.
feed_target_names
,
self
.
fetch_targets
=
fluid
.
io
.
load_inference_model
(
dirname
=
os
.
path
.
join
(
model_dir
,
sign_name
)
,
executor
=
self
.
exe
)
model_dir
,
executor
=
self
.
exe
)
feed_dict
,
fetch_dict
=
_process_input_output_key
(
self
.
config
.
desc
,
sign_name
)
...
...
@@ -293,7 +293,7 @@ class ModuleConfig(object):
return
os
.
path
.
join
(
meta_path
,
PARAM_FILENAME
)
def
create_module
(
sign_arr
,
module_dir
=
None
,
word_dict
=
None
):
def
create_module
(
sign_arr
,
module_dir
=
None
,
word_dict
=
None
,
place
=
None
):
""" Create a module from main program
"""
assert
sign_arr
,
"signature array should not be None"
...
...
@@ -301,15 +301,19 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
# check all variable
sign_arr
=
to_list
(
sign_arr
)
program
=
sign_arr
[
0
].
get_inputs
()[
0
].
block
.
program
feeded_var_names
=
set
()
target_vars
=
set
()
for
sign
in
sign_arr
:
assert
isinstance
(
sign
,
Signature
),
"sign_arr should be list of Signature"
for
input
in
sign
.
get_inputs
():
feeded_var_names
.
add
(
input
.
name
)
_tmp_program
=
input
.
block
.
program
assert
program
==
_tmp_program
,
"all the variable should come from the same program"
for
output
in
sign
.
get_outputs
():
target_vars
.
add
(
output
)
_tmp_program
=
output
.
block
.
program
assert
program
==
_tmp_program
,
"all the variable should come from the same program"
...
...
@@ -401,42 +405,41 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
fetch_var
.
alias
=
fetch_names
[
index
]
# save inference program
exe
=
fluid
.
Executor
(
place
=
fluid
.
CPUPlace
())
model_dir
=
os
.
path
.
join
(
module_dir
,
"model"
)
mkdir
(
model_dir
)
# TODO(wuzewu): save paddle model with a more effective way
for
sign
in
sign_arr
:
save_model_dir
=
os
.
path
.
join
(
model_dir
,
sign
.
get_name
())
fluid
.
io
.
save_inference_model
(
save_model_dir
,
feeded_var_names
=
[
var
.
name
for
var
in
sign
.
get_inputs
()],
target_vars
=
sign
.
get_outputs
(),
main_program
=
program
,
executor
=
exe
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
"__model__"
),
"rb"
)
as
file
:
program_desc_str
=
file
.
read
()
rename_program
=
fluid
.
framework
.
Program
.
parse_from_string
(
program_desc_str
)
varlist
=
{
var
:
block
for
block
in
rename_program
.
blocks
for
var
in
block
.
vars
if
HUB_VAR_PREFIX
not
in
var
}
for
var
,
block
in
varlist
.
items
():
old_name
=
var
new_name
=
HUB_VAR_PREFIX
+
old_name
block
.
_rename_var
(
old_name
,
new_name
)
mkdir
(
save_model_dir
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
"__model__"
),
"wb"
)
as
f
:
f
.
write
(
rename_program
.
desc
.
serialize_to_string
())
for
file
in
os
.
listdir
(
save_model_dir
):
if
(
file
==
"__model__"
or
HUB_VAR_PREFIX
in
file
):
continue
os
.
rename
(
os
.
path
.
join
(
save_model_dir
,
file
),
os
.
path
.
join
(
save_model_dir
,
HUB_VAR_PREFIX
+
file
))
if
not
place
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
=
place
)
save_model_dir
=
os
.
path
.
join
(
module_dir
,
"model"
)
mkdir
(
save_model_dir
)
fluid
.
io
.
save_inference_model
(
save_model_dir
,
feeded_var_names
=
list
(
feeded_var_names
),
target_vars
=
list
(
target_vars
),
main_program
=
program
,
executor
=
exe
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
"__model__"
),
"rb"
)
as
file
:
program_desc_str
=
file
.
read
()
rename_program
=
fluid
.
framework
.
Program
.
parse_from_string
(
program_desc_str
)
varlist
=
{
var
:
block
for
block
in
rename_program
.
blocks
for
var
in
block
.
vars
if
HUB_VAR_PREFIX
not
in
var
}
for
var
,
block
in
varlist
.
items
():
old_name
=
var
new_name
=
HUB_VAR_PREFIX
+
old_name
block
.
_rename_var
(
old_name
,
new_name
)
mkdir
(
save_model_dir
)
with
open
(
os
.
path
.
join
(
save_model_dir
,
"__model__"
),
"wb"
)
as
f
:
f
.
write
(
rename_program
.
desc
.
serialize_to_string
())
for
file
in
os
.
listdir
(
save_model_dir
):
if
(
file
==
"__model__"
or
HUB_VAR_PREFIX
in
file
):
continue
os
.
rename
(
os
.
path
.
join
(
save_model_dir
,
file
),
os
.
path
.
join
(
save_model_dir
,
HUB_VAR_PREFIX
+
file
))
# Serialize module_desc pb
module_pb
=
module_desc
.
SerializeToString
()
...
...
paddle_hub/utils.py
浏览文件 @
bfedfda8
...
...
@@ -19,6 +19,7 @@ from __future__ import division
from
__future__
import
print_function
import
paddle
import
paddle.fluid
as
fluid
import
os
def
to_list
(
input
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录