Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleHub
提交
c9ac67d6
P
PaddleHub
项目概览
PaddlePaddle
/
PaddleHub
大约 1 年 前同步成功
通知
282
Star
12117
Fork
2091
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
200
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleHub
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
200
Issue
200
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c9ac67d6
编写于
1月 21, 2019
作者:
W
wuzewu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
save param attr with pb format
上级
86470a5d
变更
3
展开全部
隐藏空白更改
内联
并排
Showing
3 changed file
with
848 addition
and
31 deletion
+848
-31
paddle_hub/module.py
paddle_hub/module.py
+81
-18
paddle_hub/module_desc.proto
paddle_hub/module_desc.proto
+43
-0
paddle_hub/module_desc_pb2.py
paddle_hub/module_desc_pb2.py
+724
-13
未找到文件。
paddle_hub/module.py
浏览文件 @
c9ac67d6
...
...
@@ -80,11 +80,48 @@ class Module(object):
def
_process_parameter
(
self
):
global_block
=
self
.
inference_program
.
global_block
()
param_path
=
ModuleConfig
.
meta_param_path
(
self
.
module_dir
)
with
open
(
param_path
,
"rb"
)
as
file
:
param_arr
=
pickle
.
load
(
file
)
for
param
in
param_arr
:
param
[
'name'
]
=
HUB_VAR_PREFIX
+
param
[
'name'
]
param_attrs
=
self
.
config
.
desc
.
param_attrs
for
key
,
param_attr
in
param_attrs
.
items
():
param
=
{}
param
[
'name'
]
=
HUB_VAR_PREFIX
+
key
param
[
'trainable'
]
=
param_attr
.
trainable
param
[
'do_model_average'
]
=
param_attr
.
do_model_average
param
[
'optimize_attr'
]
=
{}
param
[
'optimize_attr'
][
'learning_rate'
]
=
param_attr
.
optimize_attr
.
m
[
'learning_rate'
].
f
# TODO(wuzewu): recover the param attr with a more reliable way
if
param_attr
.
regularizer
.
type
==
"L2DecayRegularizer"
:
regularizer
=
fluid
.
regularizer
.
L2DecayRegularizer
(
regularization_coeff
=
param_attr
.
regularizer
.
regularization_coeff
)
elif
param_attr
.
regularizer
.
type
==
"L1DecayRegularizer"
:
regularizer
=
fluid
.
regularizer
.
L1DecayRegularizer
(
regularization_coeff
=
param_attr
.
regularizer
.
regularization_coeff
)
else
:
regularizer
=
None
param
[
'regularizer'
]
=
regularizer
if
param_attr
.
gradient_clip_attr
.
type
==
"ErrorClipByValue"
:
clip
=
fluid
.
clip
.
ErrorClipByValue
(
max
=
param_attr
.
gradient_clip_attr
.
max
,
min
=
param_attr
.
gradient_clip_attr
.
min
)
elif
param_attr
.
gradient_clip_attr
.
type
==
"GradientClipByValue"
:
clip
=
fluid
.
clip
.
GradientClipByValue
(
max
=
param_attr
.
gradient_clip_attr
.
max
,
min
=
param_attr
.
gradient_clip_attr
.
min
)
elif
param_attr
.
gradient_clip_attr
.
type
==
"GradientClipByNorm"
:
clip
=
fluid
.
clip
.
GradientClipByNorm
(
clip_norm
=
param_attr
.
gradient_clip_attr
.
clip_norm
)
elif
param_attr
.
gradient_clip_attr
.
type
==
"GradientClipByGlobalNorm"
:
clip
=
fluid
.
clip
.
GradientClipByNorm
(
clip_norm
=
param_attr
.
gradient_clip_attr
.
clip_norm
,
group_name
=
param_attr
.
gradient_clip_attr
.
group_name
)
else
:
clip
=
None
param
[
'gradient_clip_attr'
]
=
clip
if
(
param
[
'name'
]
not
in
global_block
.
vars
):
continue
var
=
global_block
.
var
(
param
[
'name'
])
...
...
@@ -309,20 +346,46 @@ def create_module(sign_arr, module_dir=None, word_dict=None):
fo
.
write
(
"{}
\t
{}
\n
"
.
format
(
w
,
w_id
))
# save fluid Parameter
param_a
rr
=
[]
param_a
ttrs
=
module_desc
.
param_attrs
for
param
in
program
.
global_block
().
iter_parameters
():
param_info
=
{
'name'
:
param
.
name
,
'regularizer'
:
param
.
regularizer
,
'gradient_clip_attr'
:
param
.
gradient_clip_attr
,
'trainable'
:
param
.
trainable
,
'optimize_attr'
:
param
.
optimize_attr
,
'do_model_average'
:
param
.
do_model_average
}
param_arr
.
append
(
param_info
)
with
open
(
ModuleConfig
.
meta_param_path
(
module_dir
),
"wb"
)
as
fo
:
pickle
.
dump
(
param_arr
,
fo
)
param_attr
=
param_attrs
[
param
.
name
]
param_attr
.
trainable
=
param
.
trainable
if
param
.
do_model_average
:
param_attr
.
do_model_average
=
param
.
do_model_average
# TODO(wuzewu): add a func to transfer python dict to fexiable data
param_attr
.
optimize_attr
.
type
=
module_desc_pb2
.
MAP
param_attr
.
optimize_attr
.
m
[
'learning_rate'
].
type
=
module_desc_pb2
.
FLOAT
param_attr
.
optimize_attr
.
m
[
'learning_rate'
].
f
=
param
.
optimize_attr
[
'learning_rate'
]
if
param
.
regularizer
:
if
isinstance
(
param
.
regularizer
,
fluid
.
regularizer
.
L2DecayRegularizer
):
param_attr
.
regularizer
.
type
=
"L2DecayRegularizer"
if
isinstance
(
param
.
regularizer
,
fluid
.
regularizer
.
L1DecayRegularizer
):
param_attr
.
regularizer
.
type
=
"L1DecayRegularizer"
param_attr
.
regularizer
.
regularization_coeff
=
param
.
regularizer
.
regularization_coeff
if
param
.
gradient_clip_attr
:
if
isinstance
(
param
.
gradient_clip_attr
,
fluid
.
clip
.
ErrorClipByValue
):
param_attr
.
gradient_clip_attr
.
max
=
param
.
gradient_clip_attr
.
max
param_attr
.
gradient_clip_attr
.
min
=
param
.
gradient_clip_attr
.
min
param_attr
.
gradient_clip_attr
.
type
=
"ErrorClipByValue"
if
isinstance
(
param
.
gradient_clip_attr
,
fluid
.
clip
.
GradientClipByValue
):
param_attr
.
gradient_clip_attr
.
max
=
param
.
gradient_clip_attr
.
max
param_attr
.
gradient_clip_attr
.
min
=
param
.
gradient_clip_attr
.
min
param_attr
.
gradient_clip_attr
.
type
=
"GradientClipByValue"
if
isinstance
(
param
.
gradient_clip_attr
,
fluid
.
clip
.
GradientClipByNorm
):
param_attr
.
gradient_clip_attr
.
clip_norm
=
param
.
gradient_clip_attr
.
clip_norm
param_attr
.
gradient_clip_attr
.
type
=
"GradientClipByNorm"
if
isinstance
(
param
.
gradient_clip_attr
,
fluid
.
clip
.
GradientClipByGlobalNorm
):
param_attr
.
gradient_clip_attr
.
clip_norm
=
param
.
gradient_clip_attr
.
clip_norm
param_attr
.
gradient_clip_attr
.
group_name
=
param
.
gradient_clip_attr
.
group_name
param_attr
.
gradient_clip_attr
.
type
=
"GradientClipByGlobalNorm"
# save signarture info
sign_map
=
module_desc
.
sign2var
...
...
paddle_hub/module_desc.proto
浏览文件 @
c9ac67d6
...
...
@@ -18,6 +18,26 @@ option optimize_for = LITE_RUNTIME;
package
paddle_hub
;
enum
DataType
{
INT
=
0
;
FLOAT
=
1
;
STRING
=
2
;
BOOLEAN
=
3
;
LIST
=
4
;
MAP
=
5
;
}
message
FlexibleData
{
DataType
type
=
1
;
string
name
=
2
;
int32
i
=
3
;
float
f
=
4
;
bool
b
=
5
;
string
s
=
6
;
map
<
string
,
FlexibleData
>
m
=
7
;
map
<
int32
,
FlexibleData
>
l
=
8
;
}
// Feed Variable Description
message
FeedDesc
{
string
var_name
=
1
;
...
...
@@ -41,6 +61,27 @@ message AuthInfo {
string
hub_version
=
2
;
}
message
ParamAttr
{
message
Regularizer
{
string
type
=
1
;
float
regularization_coeff
=
2
;
}
message
GradientClipAttr
{
string
type
=
1
;
float
min
=
2
;
float
max
=
3
;
float
clip_norm
=
4
;
string
group_name
=
5
;
}
Regularizer
regularizer
=
1
;
GradientClipAttr
gradient_clip_attr
=
2
;
FlexibleData
optimize_attr
=
3
;
bool
trainable
=
4
;
bool
do_model_average
=
5
;
}
// A Hub Module is stored in a directory with a file 'paddlehub.pb'
// containing a serialized protocol message of this type. The further contents
// of the directory depend on the storage format described by the message.
...
...
@@ -56,5 +97,7 @@ message ModuleDesc {
bool
contain_assets
=
4
;
AuthInfo
auth_info
=
5
;
map
<
string
,
ParamAttr
>
param_attrs
=
6
;
};
paddle_hub/module_desc_pb2.py
浏览文件 @
c9ac67d6
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录