Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
96102e2f
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
1 年多 前同步成功
通知
283
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
96102e2f
编写于
1月 09, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
save unique name generator and reload it when creating module
上级
7e3c5dda
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
31 addition
and
19 deletion
+31
-19
paddle_hub/module.py
paddle_hub/module.py
+25
-19
paddle_hub/module_creator.py
paddle_hub/module_creator.py
+6
-0
未找到文件。
paddle_hub/module.py
浏览文件 @
96102e2f
...
...
@@ -27,6 +27,7 @@ import pickle
from
collections
import
defaultdict
from
paddle_hub.downloader
import
download_and_uncompress
from
paddle_hub
import
module_desc_pb2
from
paddle_hub.config
import
RunConfig
,
ParamTrainConfig
__all__
=
[
"Module"
,
"ModuleConfig"
,
"ModuleUtils"
]
DICT_NAME
=
"dict.txt"
...
...
@@ -86,6 +87,13 @@ class Module(object):
self
.
config
=
ModuleConfig
(
self
.
module_dir
)
self
.
config
.
load
()
self
.
_process_parameter
()
#TODO(wuzewu): recover the default unique name generator someother where
self
.
_process_uqn
()
def
_process_uqn
(
self
):
filepath
=
os
.
path
.
join
(
self
.
module_dir
,
"uqn.pkl"
)
with
open
(
filepath
,
"rb"
)
as
file
:
fluid
.
unique_name
.
switch
(
pickle
.
load
(
file
))
def
_process_parameter
(
self
):
global_block
=
self
.
inference_program
.
global_block
()
...
...
@@ -116,27 +124,25 @@ class Module(object):
return
feed_dict
def
__call__
(
self
,
inputs
=
None
,
sign_name
=
"default"
):
def
__call__
(
self
,
sign_name
=
"default"
,
run_config
=
None
):
""" Call default signature and return results
"""
# word_ids_lod_tensor = self._preprocess_input(inputs)
feed_dict
=
self
.
_construct_feed_dict
(
inputs
)
print
(
"feed_dict"
,
feed_dict
)
ret_numpy
=
self
.
config
.
return_numpy
()
print
(
"ret_numpy"
,
ret_numpy
)
results
=
self
.
exe
.
run
(
self
.
inference_program
,
#feed={self.feed_target_names[0]: word_ids_lod_tensor},
feed
=
feed_dict
,
fetch_list
=
self
.
fetch_targets
,
return_numpy
=
ret_numpy
)
print
(
"module fetch_target_names"
,
self
.
feed_target_names
)
print
(
"module fetch_targets"
,
self
.
fetch_targets
)
np_result
=
np
.
array
(
results
[
0
])
return
np_result
def
_set_param_trainable
(
program
,
trainable
=
False
):
for
param
in
program
.
global_block
().
iter_parameters
():
param
.
trainable
=
trainable
if
not
run_config
:
run_config
=
RunConfig
()
program
=
self
.
get_inference_program
().
clone
()
if
run_config
.
param_train_config
==
ParamTrainConfig
.
PARAM_TRAIN_ALL
:
_set_param_trainable
(
program
=
program
,
trainable
=
True
)
elif
run_config
.
param_train_config
==
ParamTrainConfig
.
PARAM_TRAIN_ALL
:
_set_param_trainable
(
program
=
program
,
trainable
=
False
)
return
self
.
feed_target_names
,
self
.
fetch_targets
,
program
def
get_vars
(
self
):
"""
...
...
paddle_hub/module_creator.py
浏览文件 @
96102e2f
...
...
@@ -46,6 +46,12 @@ def create_module(sign_arr, program, path=None, assets=None):
module
.
contain_assets
=
True
os
.
makedirs
(
os
.
path
.
join
(
path
,
"assets"
))
# save the unique name object
generator
=
fluid
.
unique_name
.
generator
pklname
=
os
.
path
.
join
(
path
,
"uqn.pkl"
)
with
open
(
pklname
,
"wb"
)
as
file
:
pickle
.
dump
(
generator
,
file
)
# save fluid Parameter
param_arr
=
[]
for
param
in
program
.
global_block
().
iter_parameters
():
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录