Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
48031dd4
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
48031dd4
编写于
3月 06, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'feature/serialize_deserialize_in_parameters' into feature/recommendation_v2_api
上级
d4327b68
c36a3f46
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
144 addition
and
32 deletion
+144
-32
demo/mnist/.gitignore
demo/mnist/.gitignore
+2
-0
demo/mnist/api_train_v2.py
demo/mnist/api_train_v2.py
+15
-7
python/paddle/v2/parameters.py
python/paddle/v2/parameters.py
+66
-24
python/paddle/v2/tests/run_tests.sh
python/paddle/v2/tests/run_tests.sh
+1
-1
python/paddle/v2/tests/test_parameters.py
python/paddle/v2/tests/test_parameters.py
+60
-0
未找到文件。
demo/mnist/.gitignore
浏览文件 @
48031dd4
...
@@ -6,3 +6,5 @@ train.log
...
@@ -6,3 +6,5 @@ train.log
*pyc
*pyc
.ipynb_checkpoints
.ipynb_checkpoints
params.pkl
params.pkl
params.tar
params.tar.gz
demo/mnist/api_train_v2.py
浏览文件 @
48031dd4
import
paddle.v2
as
paddle
import
paddle.v2
as
paddle
import
cPickle
import
gzip
def
softmax_regression
(
img
):
def
softmax_regression
(
img
):
...
@@ -73,8 +73,8 @@ def main():
...
@@ -73,8 +73,8 @@ def main():
cost
=
paddle
.
layer
.
classification_cost
(
input
=
predict
,
label
=
label
)
cost
=
paddle
.
layer
.
classification_cost
(
input
=
predict
,
label
=
label
)
try
:
try
:
with
open
(
'params.pkl
'
,
'r'
)
as
f
:
with
gzip
.
open
(
'params.tar.gz
'
,
'r'
)
as
f
:
parameters
=
cPickle
.
load
(
f
)
parameters
=
paddle
.
parameters
.
Parameters
.
from_tar
(
f
)
except
IOError
:
except
IOError
:
parameters
=
paddle
.
parameters
.
create
(
cost
)
parameters
=
paddle
.
parameters
.
create
(
cost
)
...
@@ -91,10 +91,18 @@ def main():
...
@@ -91,10 +91,18 @@ def main():
def
event_handler
(
event
):
def
event_handler
(
event
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
isinstance
(
event
,
paddle
.
event
.
EndIteration
):
if
event
.
batch_id
%
100
==
0
:
if
event
.
batch_id
%
1000
==
0
:
print
"Pass %d, Batch %d, Cost %f, %s"
%
(
result
=
trainer
.
test
(
reader
=
paddle
.
reader
.
batched
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
)
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
256
))
if
isinstance
(
event
,
paddle
.
event
.
EndPass
):
print
"Pass %d, Batch %d, Cost %f, %s, Testing metrics %s"
%
(
event
.
pass_id
,
event
.
batch_id
,
event
.
cost
,
event
.
metrics
,
result
.
metrics
)
with
gzip
.
open
(
'params.tar.gz'
,
'w'
)
as
f
:
parameters
.
to_tar
(
f
)
elif
isinstance
(
event
,
paddle
.
event
.
EndPass
):
result
=
trainer
.
test
(
reader
=
paddle
.
reader
.
batched
(
result
=
trainer
.
test
(
reader
=
paddle
.
reader
.
batched
(
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
128
))
paddle
.
dataset
.
mnist
.
test
(),
batch_size
=
128
))
print
"Test with Pass %d, Cost %f, %s
\n
"
%
(
print
"Test with Pass %d, Cost %f, %s
\n
"
%
(
...
...
python/paddle/v2/parameters.py
浏览文件 @
48031dd4
import
numpy
as
np
import
numpy
as
np
import
py_paddle.swig_paddle
as
api
import
py_paddle.swig_paddle
as
api
from
paddle.proto.ParameterConfig_pb2
import
ParameterConfig
from
paddle.proto.ParameterConfig_pb2
import
ParameterConfig
import
struct
import
tarfile
import
cStringIO
from
topology
import
Topology
from
topology
import
Topology
__all__
=
[
'Parameters'
,
'create'
]
__all__
=
[
'Parameters'
,
'create'
]
...
@@ -122,6 +124,12 @@ class Parameters(object):
...
@@ -122,6 +124,12 @@ class Parameters(object):
if
len
(
self
.
__gradient_machines__
)
==
0
:
if
len
(
self
.
__gradient_machines__
)
==
0
:
# create new parameter in python numpy.
# create new parameter in python numpy.
if
len
(
self
.
__tmp_params__
)
!=
0
:
ret_list
=
[
mat
for
name
,
mat
in
self
.
__tmp_params__
if
name
==
key
]
if
len
(
ret_list
)
==
1
:
return
ret_list
[
0
]
return
np
.
ndarray
(
shape
=
shape
,
dtype
=
np
.
float32
)
return
np
.
ndarray
(
shape
=
shape
,
dtype
=
np
.
float32
)
else
:
else
:
for
each_gradient_machine
in
self
.
__gradient_machines__
:
for
each_gradient_machine
in
self
.
__gradient_machines__
:
...
@@ -228,32 +236,66 @@ class Parameters(object):
...
@@ -228,32 +236,66 @@ class Parameters(object):
self
.
__gradient_machines__
.
append
(
gradient_machine
)
self
.
__gradient_machines__
.
append
(
gradient_machine
)
def
__getstate__
(
self
):
def
serialize
(
self
,
name
,
f
):
params
=
{}
"""
for
name
in
self
.
names
():
params
[
name
]
=
self
.
get
(
name
)
param_conf
=
{}
for
name
in
self
.
__param_conf__
:
conf
=
self
.
__param_conf__
[
name
]
assert
isinstance
(
conf
,
ParameterConfig
)
param_conf
[
name
]
=
conf
.
SerializeToString
()
return
{
'conf'
:
param_conf
,
'params'
:
params
}
def
__setstate__
(
self
,
obj
):
:param name:
Parameters
.
__init__
(
self
)
:param f:
:type f: file
:return:
"""
param
=
self
.
get
(
name
)
size
=
reduce
(
lambda
a
,
b
:
a
*
b
,
param
.
shape
)
f
.
write
(
struct
.
pack
(
"IIQ"
,
0
,
4
,
size
))
param
=
param
.
astype
(
np
.
float32
)
f
.
write
(
param
.
tobytes
())
def
__impl__
(
conf
,
params
):
def
deserialize
(
self
,
name
,
f
):
for
name
in
conf
:
"""
p
=
ParameterConfig
()
p
.
ParseFromString
(
conf
[
name
])
self
.
__append_config__
(
p
)
for
name
in
params
:
shape
=
self
.
get_shape
(
name
)
self
.
set
(
name
,
params
[
name
].
reshape
(
shape
))
__impl__
(
**
obj
)
:param name:
:param f:
:type f: file
:return:
"""
f
.
read
(
16
)
# header
arr
=
np
.
frombuffer
(
f
.
read
(),
dtype
=
np
.
float32
)
self
.
set
(
name
,
arr
.
reshape
(
self
.
get_shape
(
name
)))
def
to_tar
(
self
,
f
):
tar
=
tarfile
.
TarFile
(
fileobj
=
f
,
mode
=
'w'
)
for
nm
in
self
.
names
():
buf
=
cStringIO
.
StringIO
()
self
.
serialize
(
nm
,
buf
)
tarinfo
=
tarfile
.
TarInfo
(
name
=
nm
)
buf
.
seek
(
0
)
tarinfo
.
size
=
len
(
buf
.
getvalue
())
tar
.
addfile
(
tarinfo
,
buf
)
conf
=
self
.
__param_conf__
[
nm
]
confStr
=
conf
.
SerializeToString
()
tarinfo
=
tarfile
.
TarInfo
(
name
=
"%s.protobuf"
%
nm
)
tarinfo
.
size
=
len
(
confStr
)
buf
=
cStringIO
.
StringIO
(
confStr
)
buf
.
seek
(
0
)
tar
.
addfile
(
tarinfo
,
fileobj
=
buf
)
@
staticmethod
def
from_tar
(
f
):
params
=
Parameters
()
tar
=
tarfile
.
TarFile
(
fileobj
=
f
,
mode
=
'r'
)
for
finfo
in
tar
:
assert
isinstance
(
finfo
,
tarfile
.
TarInfo
)
if
finfo
.
name
.
endswith
(
'.protobuf'
):
f
=
tar
.
extractfile
(
finfo
)
conf
=
ParameterConfig
()
conf
.
ParseFromString
(
f
.
read
())
params
.
__append_config__
(
conf
)
for
param_name
in
params
.
names
():
f
=
tar
.
extractfile
(
param_name
)
params
.
deserialize
(
param_name
,
f
)
return
params
def
__get_parameter_in_gradient_machine__
(
gradient_machine
,
name
):
def
__get_parameter_in_gradient_machine__
(
gradient_machine
,
name
):
...
...
python/paddle/v2/tests/run_tests.sh
浏览文件 @
48031dd4
...
@@ -22,7 +22,7 @@ cd $SCRIPTPATH
...
@@ -22,7 +22,7 @@ cd $SCRIPTPATH
$1
-m
pip
install
../../../../paddle/dist/
*
.whl
$1
-m
pip
install
../../../../paddle/dist/
*
.whl
test_list
=
"test_data_feeder.py"
test_list
=
"test_data_feeder.py
test_parameters.py
"
export
PYTHONPATH
=
$PWD
/../../../../python/
export
PYTHONPATH
=
$PWD
/../../../../python/
...
...
python/paddle/v2/tests/test_parameters.py
0 → 100644
浏览文件 @
48031dd4
import
unittest
import
sys
try
:
import
py_paddle
del
py_paddle
except
ImportError
:
print
>>
sys
.
stderr
,
"It seems swig of Paddle is not installed, this "
\
"unittest will not be run."
sys
.
exit
(
0
)
import
paddle.v2.parameters
as
parameters
from
paddle.proto.ParameterConfig_pb2
import
ParameterConfig
import
random
import
cStringIO
import
numpy
def
__rand_param_config__
(
name
):
conf
=
ParameterConfig
()
conf
.
name
=
name
size
=
1
for
i
in
xrange
(
2
):
dim
=
random
.
randint
(
1
,
1000
)
conf
.
dims
.
append
(
dim
)
size
*=
dim
conf
.
size
=
size
assert
conf
.
IsInitialized
()
return
conf
class
TestParameters
(
unittest
.
TestCase
):
def
test_serialization
(
self
):
params
=
parameters
.
Parameters
()
params
.
__append_config__
(
__rand_param_config__
(
"param_0"
))
params
.
__append_config__
(
__rand_param_config__
(
"param_1"
))
for
name
in
params
.
names
():
param
=
params
.
get
(
name
)
param
[:]
=
numpy
.
random
.
uniform
(
-
1.0
,
1.0
,
size
=
params
.
get_shape
(
name
))
params
.
set
(
name
,
param
)
tmp_file
=
cStringIO
.
StringIO
()
params
.
to_tar
(
tmp_file
)
tmp_file
.
seek
(
0
)
params_dup
=
parameters
.
Parameters
.
from_tar
(
tmp_file
)
self
.
assertEqual
(
params_dup
.
names
(),
params
.
names
())
for
name
in
params
.
names
():
self
.
assertEqual
(
params
.
get_shape
(
name
),
params_dup
.
get_shape
(
name
))
p0
=
params
.
get
(
name
)
p1
=
params_dup
.
get
(
name
)
self
.
assertTrue
(
numpy
.
isclose
(
p0
,
p1
).
all
())
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录