Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c9865824
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c9865824
编写于
6月 29, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support to init partial network parameters from the tar file.
上级
1a0fdb9e
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
65 addition
and
15 deletion
+65
-15
python/paddle/v2/parameters.py
python/paddle/v2/parameters.py
+13
-10
python/paddle/v2/tests/test_parameters.py
python/paddle/v2/tests/test_parameters.py
+52
-5
未找到文件。
python/paddle/v2/parameters.py
浏览文件 @
c9865824
...
@@ -51,7 +51,7 @@ class Parameters(object):
...
@@ -51,7 +51,7 @@ class Parameters(object):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
__param_conf__
=
dict
()
self
.
__param_conf__
=
dict
()
self
.
__gradient_machines__
=
[]
self
.
__gradient_machines__
=
[]
self
.
__tmp_params__
=
[]
self
.
__tmp_params__
=
dict
()
def
__append_config__
(
self
,
param_conf
):
def
__append_config__
(
self
,
param_conf
):
"""
"""
...
@@ -128,12 +128,9 @@ class Parameters(object):
...
@@ -128,12 +128,9 @@ class Parameters(object):
if
len
(
self
.
__gradient_machines__
)
==
0
:
if
len
(
self
.
__gradient_machines__
)
==
0
:
# create new parameter in python numpy.
# create new parameter in python numpy.
if
len
(
self
.
__tmp_params__
)
!=
0
:
if
key
in
self
.
__tmp_params__
:
ret_list
=
[
return
self
.
__tmp_params__
[
key
]
mat
for
name
,
mat
in
self
.
__tmp_params__
if
name
==
key
else
:
]
if
len
(
ret_list
)
==
1
:
return
ret_list
[
0
]
return
np
.
ndarray
(
shape
=
shape
,
dtype
=
np
.
float32
)
return
np
.
ndarray
(
shape
=
shape
,
dtype
=
np
.
float32
)
else
:
else
:
for
each_gradient_machine
in
self
.
__gradient_machines__
:
for
each_gradient_machine
in
self
.
__gradient_machines__
:
...
@@ -187,7 +184,7 @@ class Parameters(object):
...
@@ -187,7 +184,7 @@ class Parameters(object):
(
shape
,
value
.
shape
))
(
shape
,
value
.
shape
))
if
len
(
self
.
__gradient_machines__
)
==
0
:
if
len
(
self
.
__gradient_machines__
)
==
0
:
self
.
__tmp_params__
.
append
((
key
,
value
))
self
.
__tmp_params__
[
key
]
=
value
else
:
else
:
for
each_gradient_machine
in
self
.
__gradient_machines__
:
for
each_gradient_machine
in
self
.
__gradient_machines__
:
__copy_parameter_to_gradient_machine__
(
each_gradient_machine
,
__copy_parameter_to_gradient_machine__
(
each_gradient_machine
,
...
@@ -231,7 +228,7 @@ class Parameters(object):
...
@@ -231,7 +228,7 @@ class Parameters(object):
raise
ValueError
(
"gradient_machine should be api.GradientMachine"
)
raise
ValueError
(
"gradient_machine should be api.GradientMachine"
)
if
len
(
self
.
__tmp_params__
)
!=
0
:
if
len
(
self
.
__tmp_params__
)
!=
0
:
for
name
,
val
in
self
.
__tmp_params__
:
for
name
,
val
in
self
.
__tmp_params__
.
iteritems
()
:
try
:
try
:
__copy_parameter_to_gradient_machine__
(
gradient_machine
,
__copy_parameter_to_gradient_machine__
(
gradient_machine
,
name
,
val
)
name
,
val
)
...
@@ -302,6 +299,12 @@ class Parameters(object):
...
@@ -302,6 +299,12 @@ class Parameters(object):
params
.
deserialize
(
param_name
,
f
)
params
.
deserialize
(
param_name
,
f
)
return
params
return
params
def
init_from_tar
(
self
,
f
):
tar_param
=
self
.
from_tar
(
f
)
for
pname
in
tar_param
.
names
():
if
pname
in
self
.
names
():
self
.
set
(
pname
,
tar_param
.
get
(
pname
))
def
__get_parameter_in_gradient_machine__
(
gradient_machine
,
name
):
def
__get_parameter_in_gradient_machine__
(
gradient_machine
,
name
):
"""
"""
...
...
python/paddle/v2/tests/test_parameters.py
浏览文件 @
c9865824
...
@@ -20,14 +20,17 @@ import cStringIO
...
@@ -20,14 +20,17 @@ import cStringIO
import
numpy
import
numpy
def
__rand_param_config__
(
name
):
def
__rand_param_config__
(
name
,
psize
=
None
):
conf
=
ParameterConfig
()
conf
=
ParameterConfig
()
conf
.
name
=
name
conf
.
name
=
name
size
=
1
size
=
1
if
psize
is
None
:
for
i
in
xrange
(
2
):
for
i
in
xrange
(
2
):
dim
=
random
.
randint
(
1
,
1000
)
dim
=
random
.
randint
(
1
,
1000
)
conf
.
dims
.
append
(
dim
)
conf
.
dims
.
append
(
dim
)
size
*=
dim
size
*=
dim
else
:
size
=
psize
conf
.
size
=
size
conf
.
size
=
size
assert
conf
.
IsInitialized
()
assert
conf
.
IsInitialized
()
return
conf
return
conf
...
@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase):
...
@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase):
expected
=
numpy
.
array
([[
1
,
1
],
[
1
,
2
],
[
1
,
1
]],
numpy
.
float32
)
expected
=
numpy
.
array
([[
1
,
1
],
[
1
,
2
],
[
1
,
1
]],
numpy
.
float32
)
assert
numpy
.
logical_and
.
reduce
(
numpy
.
reshape
(
val
==
expected
,
6
))
assert
numpy
.
logical_and
.
reduce
(
numpy
.
reshape
(
val
==
expected
,
6
))
def
test_init_from_tar
(
self
):
def
get_param
(
names
,
size
):
p
=
parameters
.
Parameters
()
for
k
,
v
in
zip
(
names
,
size
):
p
.
__append_config__
(
__rand_param_config__
(
k
,
v
))
for
name
in
p
.
names
():
param
=
p
.
get
(
name
)
param
[:]
=
numpy
.
random
.
uniform
(
-
1.0
,
1.0
,
size
=
p
.
get_shape
(
name
))
p
.
set
(
name
,
param
)
return
p
def
get_parames
():
name1
=
[
'param_0'
,
'param_1'
]
size1
=
[
128
,
256
]
p1
=
get_param
(
name1
,
size1
)
file1
=
cStringIO
.
StringIO
()
p1
.
to_tar
(
file1
)
file1
.
seek
(
0
)
name2
=
[
'param_0'
,
'param_1'
,
'param_2'
]
size2
=
[
128
,
256
,
288
]
p2
=
get_param
(
name2
,
size2
)
file2
=
cStringIO
.
StringIO
()
p2
.
to_tar
(
file2
)
file2
.
seek
(
0
)
return
p1
,
file1
,
p2
,
file2
p1
,
file1
,
p2
,
file2
=
get_parames
()
p2
.
init_from_tar
(
file1
)
for
name
in
p1
.
names
():
self
.
assertEqual
(
p1
.
get_shape
(
name
),
p2
.
get_shape
(
name
))
v1
=
p1
.
get
(
name
)
v2
=
p2
.
get
(
name
)
self
.
assertTrue
(
numpy
.
isclose
(
v1
,
v2
).
all
())
p1
,
file1
,
p2
,
file2
=
get_parames
()
p1
.
init_from_tar
(
file2
)
for
name
in
p1
.
names
():
self
.
assertEqual
(
p1
.
get_shape
(
name
),
p2
.
get_shape
(
name
))
v1
=
p1
.
get
(
name
)
v2
=
p2
.
get
(
name
)
self
.
assertTrue
(
numpy
.
isclose
(
v1
,
v2
).
all
())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录