Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
a27f4307
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a27f4307
编写于
2月 21, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
record variable info
上级
b125aa1b
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
39 addition
and
14 deletion
+39
-14
paddle_hub/module.py
paddle_hub/module.py
+39
-13
paddle_hub/utils.py
paddle_hub/utils.py
+0
-1
未找到文件。
paddle_hub/module.py
浏览文件 @
a27f4307
...
...
@@ -30,7 +30,7 @@ from paddle_hub.downloader import download_and_uncompress
from
paddle_hub
import
module_desc_pb2
from
paddle_hub.logger
import
logger
from
paddle_hub.signature
import
Signature
from
paddle_hub.utils
import
to_list
,
mkdir
from
paddle_hub.utils
import
to_list
,
mkdir
,
from_pyobj_to_flexible_data
,
from_flexible_data_to_pyobj
from
paddle_hub.paddle_helper
import
from_param_to_flexible_data
,
get_variable_info
,
from_flexible_data_to_param
from
paddle_hub.version
import
__version__
...
...
@@ -91,6 +91,19 @@ class Module(object):
stop_gradient
=
var
.
stop_gradient
,
is_data
=
var
.
is_data
)
def
_process_variable_info
(
self
):
var_infos
=
self
.
config
.
desc
.
extra_info
.
map
.
data
[
'var_infos'
]
for
var_info
in
var_infos
.
map
.
data
:
idx
=
from_flexible_data_to_pyobj
(
var_infos
.
map
.
data
[
var_info
].
map
.
data
[
'block_id'
])
stop_gradient
=
from_flexible_data_to_pyobj
(
var_infos
.
map
.
data
[
var_info
].
map
.
data
[
'stop_gradient'
])
block
=
self
.
inference_program
.
blocks
[
idx
]
var_name
=
HUB_VAR_PREFIX
+
var_info
if
var_name
in
block
.
vars
:
var
=
block
.
vars
[
var_name
]
var
.
stop_gradient
=
stop_gradient
def
__call__
(
self
,
sign_name
=
"default"
,
trainable
=
False
):
""" Call default signature and return results
"""
...
...
@@ -139,21 +152,23 @@ class Module(object):
logger
.
info
(
"**feed_target_names**
\n
{}"
.
format
(
self
.
feed_target_names
))
logger
.
info
(
"**fetch_targets**
\n
{}"
.
format
(
self
.
fetch_targets
))
self
.
_process_parameter
()
self
.
_process_variable_info
()
program
=
self
.
get_inference_program
().
clone
()
_process_op_attr
(
program
=
program
,
is_test
=
False
)
_set_param_trainable
(
program
=
program
,
trainable
=
trainable
)
_process_op_attr
(
program
=
self
.
inference_program
,
is_test
=
False
)
_set_param_trainable
(
program
=
self
.
inference_program
,
trainable
=
trainable
)
for
key
,
value
in
feed_dict
.
items
():
var
=
program
.
global_block
().
var
(
HUB_VAR_PREFIX
+
value
)
var
=
self
.
inference_program
.
global_block
().
var
(
HUB_VAR_PREFIX
+
value
)
feed_dict
[
key
]
=
var
for
key
,
value
in
fetch_dict
.
items
():
var
=
program
.
global_block
().
var
(
HUB_VAR_PREFIX
+
value
)
var
=
self
.
inference_program
.
global_block
().
var
(
HUB_VAR_PREFIX
+
value
)
fetch_dict
[
key
]
=
var
return
feed_dict
,
fetch_dict
,
program
return
feed_dict
,
fetch_dict
,
self
.
inference_
program
def
get_inference_program
(
self
):
return
self
.
inference_program
...
...
@@ -256,7 +271,7 @@ class ModuleConfig(object):
return
os
.
path
.
join
(
meta_path
,
PARAM_FILENAME
)
def
create_module
(
sign_arr
,
module_dir
=
None
,
word_dict
=
None
,
plac
e
=
None
):
def
create_module
(
sign_arr
,
module_dir
=
None
,
word_dict
=
None
,
ex
e
=
None
):
""" Create a module from main program
"""
assert
sign_arr
,
"signature array should not be None"
...
...
@@ -291,7 +306,6 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
module_desc
.
auth_info
.
paddle_version
=
paddle
.
__version__
logger
.
info
(
"hub version is %s"
%
__version__
)
logger
.
info
(
"paddle version is %s"
%
paddle
.
__version__
)
program
=
program
.
clone
()
# save asset
if
word_dict
is
None
:
...
...
@@ -312,9 +326,20 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
param_attr
=
param_attrs
.
map
.
data
[
param
.
name
]
from_param_to_flexible_data
(
param
,
param_attr
)
# save Variable Info
var_infos
=
extra_info
.
map
.
data
[
'var_infos'
]
var_infos
.
type
=
module_desc_pb2
.
MAP
for
block
in
program
.
blocks
:
for
var
in
block
.
vars
.
values
():
var_info
=
var_infos
.
map
.
data
[
var
.
name
]
var_info
.
type
=
module_desc_pb2
.
MAP
from_pyobj_to_flexible_data
(
var
.
stop_gradient
,
var_info
.
map
.
data
[
'stop_gradient'
])
from_pyobj_to_flexible_data
(
block
.
idx
,
var_info
.
map
.
data
[
'block_id'
])
# save signarture info
sign_map
=
module_desc
.
sign2var
program
=
sign_arr
[
0
].
get_inputs
()[
0
].
block
.
program
for
sign
in
sign_arr
:
if
sign
.
get_name
()
in
sign_map
:
raise
"Error! sign_arr contains repeat signatrue %s"
%
sign
...
...
@@ -335,7 +360,8 @@ def create_module(sign_arr, module_dir=None, word_dict=None, place=None):
fetch_var
.
alias
=
fetch_names
[
index
]
# save inference program
if
not
place
:
program
=
program
.
clone
()
if
not
exe
:
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
=
place
)
save_model_dir
=
os
.
path
.
join
(
module_dir
,
"model"
)
...
...
paddle_hub/utils.py
浏览文件 @
a27f4307
...
...
@@ -66,7 +66,6 @@ def get_pykey(key, keyed_type):
#TODO(wuzewu): solving the problem of circular references
def
from_pyobj_to_flexible_data
(
pyobj
,
flexible_data
,
obj_filter
=
None
):
if
obj_filter
and
obj_filter
(
pyobj
):
logger
.
info
(
"filter python object"
)
return
if
isinstance
(
pyobj
,
bool
):
flexible_data
.
type
=
module_desc_pb2
.
BOOLEAN
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录