Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PARL
提交
6596320f
P
PARL
项目概览
PaddlePaddle
/
PARL
通知
67
Star
3
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
18
列表
看板
标记
里程碑
合并请求
3
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PARL
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
18
Issue
18
列表
看板
标记
里程碑
合并请求
3
合并请求
3
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6596320f
编写于
10月 16, 2019
作者:
H
Hongsheng Zeng
提交者:
Bo Zhou
10月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug in _get_parameter_names function (#159)
上级
2ddf4c11
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
63 addition
and
19 deletion
+63
-19
parl/core/fluid/model.py
parl/core/fluid/model.py
+19
-19
parl/core/fluid/tests/model_base_test_.py
parl/core/fluid/tests/model_base_test_.py
+44
-0
未找到文件。
parl/core/fluid/model.py
浏览文件 @
6596320f
...
...
@@ -273,33 +273,33 @@ class Model(ModelBase):
set_value
(
param_name
,
weight
,
is_gpu_available
)
def
_get_parameter_names
(
self
,
obj
):
""" Recursively get parameter names in a
model and its child attributes
.
""" Recursively get parameter names in a
n object
.
Args:
obj (
``parl.Model``): an instance of ``Model``
obj (
Object): any object
Returns:
parameter_names (list): all parameter names in this
model
.
parameter_names (list): all parameter names in this
object
.
"""
parameter_names
=
[]
for
attr
in
sorted
(
obj
.
__dict__
.
keys
()
):
val
=
getattr
(
obj
,
attr
)
if
isinstance
(
val
,
Model
):
if
isinstance
(
obj
,
Model
):
for
attr
in
sorted
(
obj
.
__dict__
.
keys
()):
val
=
getattr
(
obj
,
attr
)
parameter_names
.
extend
(
self
.
_get_parameter_names
(
val
))
elif
isinstance
(
val
,
LayerFunc
):
for
attr
in
val
.
attr_holder
.
sorted
():
if
attr
:
parameter_names
.
append
(
attr
.
name
)
elif
isinstance
(
val
,
tuple
)
or
isinstance
(
val
,
list
):
for
x
in
val
:
parameter_names
.
extend
(
self
.
_get_parameter_names
(
x
))
elif
isinstance
(
val
,
dict
):
for
x
in
list
(
val
.
values
()):
parameter_names
.
extend
(
self
.
_get_parameter_names
(
x
))
else
:
# for any other type, won't be handled. E.g. set
pass
elif
isinstance
(
obj
,
LayerFunc
):
for
attr
in
obj
.
attr_holder
.
sorted
():
if
attr
:
parameter_names
.
append
(
attr
.
name
)
elif
isinstance
(
obj
,
tuple
)
or
isinstance
(
obj
,
list
):
for
x
in
obj
:
parameter_names
.
extend
(
self
.
_get_parameter_names
(
x
))
elif
isinstance
(
obj
,
dict
):
for
x
in
list
(
obj
.
values
()):
parameter_names
.
extend
(
self
.
_get_parameter_names
(
x
))
else
:
# for any other type, won't be handled. E.g. set
pass
return
parameter_names
def
_get_parameter_pairs
(
self
,
src
,
target
):
...
...
parl/core/fluid/tests/model_base_test_.py
浏览文件 @
6596320f
...
...
@@ -84,6 +84,42 @@ class TestModel4(parl.Model):
return
out
class
TestModel6
(
parl
.
Model
):
def
__init__
(
self
):
self
.
fc1
=
layers
.
fc
(
size
=
256
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
'fc1.w'
),
bias_attr
=
ParamAttr
(
name
=
'fc1.b'
))
self
.
fc_tuple
=
(
layers
.
fc
(
size
=
128
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
'fc2.w'
),
bias_attr
=
ParamAttr
(
name
=
'fc2.b'
)),
(
layers
.
fc
(
size
=
1
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
'fc3.w'
),
bias_attr
=
ParamAttr
(
name
=
'fc3.b'
)),
10
),
10
)
self
.
fc_dict
=
{
'k1'
:
layers
.
fc
(
size
=
128
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
'fc4.w'
),
bias_attr
=
ParamAttr
(
name
=
'fc4.b'
)),
'k2'
:
{
'k22'
:
layers
.
fc
(
size
=
1
,
act
=
None
,
param_attr
=
ParamAttr
(
name
=
'fc5.w'
),
bias_attr
=
ParamAttr
(
name
=
'fc5.b'
))
},
'k3'
:
1
,
}
class
ModelBaseTest
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
model
=
TestModel
()
...
...
@@ -139,6 +175,14 @@ class ModelBaseTest(unittest.TestCase):
set
(
self
.
model
.
parameters
()),
set
([
'fc1.w'
,
'fc1.b'
,
'fc2.w'
,
'fc2.b'
,
'fc3.w'
,
'fc3.b'
]))
model2
=
TestModel6
()
self
.
assertSetEqual
(
set
(
model2
.
parameters
()),
set
([
'fc1.w'
,
'fc1.b'
,
'fc2.w'
,
'fc2.b'
,
'fc3.w'
,
'fc3.b'
,
'fc4.w'
,
'fc4.b'
,
'fc5.w'
,
'fc5.b'
]))
def
test_sync_weights_in_one_program
(
self
):
pred_program
=
fluid
.
Program
()
with
fluid
.
program_guard
(
pred_program
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录