提交 efe53811 编写于 作者: Y Yu Yang

complete serialize

* Test gzip
上级 fb74ae36
...@@ -7,3 +7,4 @@ train.log ...@@ -7,3 +7,4 @@ train.log
.ipynb_checkpoints .ipynb_checkpoints
params.pkl params.pkl
params.tar params.tar
params.tar.gz
import paddle.v2 as paddle import paddle.v2 as paddle
import cPickle import gzip
def softmax_regression(img): def softmax_regression(img):
...@@ -73,8 +73,8 @@ def main(): ...@@ -73,8 +73,8 @@ def main():
cost = paddle.layer.classification_cost(input=predict, label=label) cost = paddle.layer.classification_cost(input=predict, label=label)
try: try:
with open('params.pkl', 'r') as f: with gzip.open('params.tar.gz', 'r') as f:
parameters = cPickle.load(f) parameters = paddle.parameters.Parameters.from_tar(f)
except IOError: except IOError:
parameters = paddle.parameters.create(cost) parameters = paddle.parameters.create(cost)
...@@ -99,12 +99,8 @@ def main(): ...@@ -99,12 +99,8 @@ def main():
event.pass_id, event.batch_id, event.cost, event.metrics, event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics) result.metrics)
with open('params.pkl', 'w') as f: with gzip.open('params.tar.gz', 'w') as f:
cPickle.dump( parameters.to_tar(f)
parameters, f, protocol=cPickle.HIGHEST_PROTOCOL)
with open('params.tar', 'w') as f:
parameters.serialize_to_tar(f)
elif isinstance(event, paddle.event.EndPass): elif isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=paddle.reader.batched( result = trainer.test(reader=paddle.reader.batched(
......
...@@ -224,19 +224,6 @@ class Parameters(object): ...@@ -224,19 +224,6 @@ class Parameters(object):
self.__gradient_machines__.append(gradient_machine) self.__gradient_machines__.append(gradient_machine)
def __getstate__(self):
params = {}
for name in self.names():
params[name] = self.get(name)
param_conf = {}
for name in self.__param_conf__:
conf = self.__param_conf__[name]
assert isinstance(conf, ParameterConfig)
param_conf[name] = conf.SerializeToString()
return {'conf': param_conf, 'params': params}
def serialize(self, name, f): def serialize(self, name, f):
""" """
...@@ -260,10 +247,10 @@ class Parameters(object): ...@@ -260,10 +247,10 @@ class Parameters(object):
:return: :return:
""" """
f.read(16) # header f.read(16) # header
arr = np.fromfile(f, dtype=np.float32) arr = np.frombuffer(f.read(), dtype=np.float32)
self.set(name, arr.reshape(self.get_shape(name))) self.set(name, arr.reshape(self.get_shape(name)))
def serialize_to_tar(self, f): def to_tar(self, f):
tar = tarfile.TarFile(fileobj=f, mode='w') tar = tarfile.TarFile(fileobj=f, mode='w')
for nm in self.names(): for nm in self.names():
buf = cStringIO.StringIO() buf = cStringIO.StringIO()
...@@ -273,19 +260,30 @@ class Parameters(object): ...@@ -273,19 +260,30 @@ class Parameters(object):
tarinfo.size = len(buf.getvalue()) tarinfo.size = len(buf.getvalue())
tar.addfile(tarinfo, buf) tar.addfile(tarinfo, buf)
def __setstate__(self, obj): conf = self.__param_conf__[nm]
Parameters.__init__(self) confStr = conf.SerializeToString()
tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm)
def __impl__(conf, params): tarinfo.size = len(confStr)
for name in conf: buf = cStringIO.StringIO(confStr)
p = ParameterConfig() buf.seek(0)
p.ParseFromString(conf[name]) tar.addfile(tarinfo, fileobj=buf)
self.__append_config__(p)
for name in params: @staticmethod
shape = self.get_shape(name) def from_tar(f):
self.set(name, params[name].reshape(shape)) params = Parameters()
tar = tarfile.TarFile(fileobj=f, mode='r')
__impl__(**obj) for finfo in tar:
assert isinstance(finfo, tarfile.TarInfo)
if finfo.name.endswith('.protobuf'):
f = tar.extractfile(finfo)
conf = ParameterConfig()
conf.ParseFromString(f.read())
params.__append_config__(conf)
for param_name in params.names():
f = tar.extractfile(param_name)
params.deserialize(param_name, f)
return params
def __get_parameter_in_gradient_machine__(gradient_machine, name): def __get_parameter_in_gradient_machine__(gradient_machine, name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册