提交 fb74ae36 编写于 作者: Y Yu Yang

Refine serialize

上级 0d5b4acb
...@@ -6,3 +6,4 @@ train.log ...@@ -6,3 +6,4 @@ train.log
*pyc *pyc
.ipynb_checkpoints .ipynb_checkpoints
params.pkl params.pkl
params.tar
...@@ -103,6 +103,9 @@ def main(): ...@@ -103,6 +103,9 @@ def main():
cPickle.dump( cPickle.dump(
parameters, f, protocol=cPickle.HIGHEST_PROTOCOL) parameters, f, protocol=cPickle.HIGHEST_PROTOCOL)
with open('params.tar', 'w') as f:
parameters.serialize_to_tar(f)
elif isinstance(event, paddle.event.EndPass): elif isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=paddle.reader.batched( result = trainer.test(reader=paddle.reader.batched(
paddle.dataset.mnist.test(), batch_size=128)) paddle.dataset.mnist.test(), batch_size=128))
......
import numpy as np import numpy as np
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
from paddle.proto.ParameterConfig_pb2 import ParameterConfig from paddle.proto.ParameterConfig_pb2 import ParameterConfig
import struct
import tarfile
import cStringIO
from topology import Topology from topology import Topology
__all__ = ['Parameters', 'create'] __all__ = ['Parameters', 'create']
...@@ -235,6 +237,42 @@ class Parameters(object): ...@@ -235,6 +237,42 @@ class Parameters(object):
return {'conf': param_conf, 'params': params} return {'conf': param_conf, 'params': params}
def serialize(self, name, f):
    """
    Write one parameter to ``f`` as a fixed header followed by raw values.

    The header is ``struct.pack("IIQ", 0, 4, size)``: a leading 0, the
    per-value byte width (4, because values are stored as float32) and the
    number of values. The values follow as float32 in C order.

    :param name: name of the parameter to write.
    :param f: writable binary file-like object.
    :type f: file
    :return: None
    """
    # Convert first so the element count and the payload describe the very
    # same array.
    values = self.get(name).astype(np.float32)
    # ``ndarray.size`` is the product of the shape and is also defined for
    # 0-d arrays, where reducing over an empty shape tuple would raise.
    f.write(struct.pack("IIQ", 0, 4, values.size))
    f.write(values.tobytes())
def deserialize(self, name, f):
    """
    Read one parameter from ``f`` and store it under ``name``.

    The stream must start with the ``"IIQ"`` header that ``serialize``
    writes, followed by that many float32 values. Only the bytes belonging
    to this parameter are consumed, and only ``f.read`` is used, so
    in-memory buffers and tar members work as well as regular files.

    :param name: name of the parameter to overwrite.
    :param f: readable binary file-like object positioned at the header.
    :type f: file
    :return: None
    :raises ValueError: if the header or payload is truncated, if the
        header declares a value width other than float32, or if the value
        count does not fit the parameter's shape.
    """
    header_fmt = "IIQ"  # same layout that serialize() packs
    header_len = struct.calcsize(header_fmt)
    header = f.read(header_len)
    if len(header) != header_len:
        raise ValueError("truncated header for parameter %r" % (name, ))

    _, value_size, count = struct.unpack(header_fmt, header)
    item_size = np.dtype(np.float32).itemsize
    if value_size != item_size:
        raise ValueError("parameter %r stores %d-byte values, expected %d" %
                         (name, value_size, item_size))

    payload = f.read(count * item_size)
    if len(payload) != count * item_size:
        raise ValueError("truncated data for parameter %r" % (name, ))

    # frombuffer returns a read-only view of the bytes; copy so the stored
    # array is writable and independent of the input buffer.
    arr = np.frombuffer(payload, dtype=np.float32).copy()
    self.set(name, arr.reshape(self.get_shape(name)))
def serialize_to_tar(self, f):
    """
    Write every parameter into a tar archive, one member per parameter.

    Each member is named after its parameter and holds exactly the bytes
    produced by ``self.serialize`` for it.

    :param f: writable binary file-like object that receives the archive.
        It is left open; closing it remains the caller's job.
    :type f: file
    :return: None
    """
    import io  # local import: keeps this method self-contained

    tar = tarfile.TarFile(fileobj=f, mode='w')
    try:
        for nm in self.names():
            buf = io.BytesIO()
            self.serialize(nm, buf)
            # The member size must be known before it is added.
            tarinfo = tarfile.TarInfo(name=nm)
            tarinfo.size = len(buf.getvalue())
            buf.seek(0)
            tar.addfile(tarinfo, buf)
    finally:
        # Closing writes the end-of-archive blocks; without it the archive
        # is truncated. An externally supplied fileobj is not closed.
        tar.close()
def __setstate__(self, obj): def __setstate__(self, obj):
Parameters.__init__(self) Parameters.__init__(self)
......
...@@ -101,7 +101,7 @@ class SGD(): ...@@ -101,7 +101,7 @@ class SGD():
for each_param in self.__gradient_machine__.getNonStaticParameters( for each_param in self.__gradient_machine__.getNonStaticParameters(
): ):
updater.update(each_param) updater.update(each_param)
cost_sum = out_args.sumCosts() cost_sum = out_args.sum()
cost = cost_sum / len(data_batch) cost = cost_sum / len(data_batch)
updater.finishBatch(cost) updater.finishBatch(cost)
batch_evaluator.finish() batch_evaluator.finish()
...@@ -137,7 +137,7 @@ class SGD(): ...@@ -137,7 +137,7 @@ class SGD():
num_samples += len(data_batch) num_samples += len(data_batch)
self.__gradient_machine__.forward( self.__gradient_machine__.forward(
feeder(data_batch), out_args, api.PASS_TEST) feeder(data_batch), out_args, api.PASS_TEST)
total_cost += out_args.sumCosts() total_cost += out_args.sum()
self.__gradient_machine__.eval(evaluator) self.__gradient_machine__.eval(evaluator)
evaluator.finish() evaluator.finish()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册