提交 efe53811 编写于 作者: Y Yu Yang

complete serialize

* Test gzip
上级 fb74ae36
......@@ -7,3 +7,4 @@ train.log
.ipynb_checkpoints
params.pkl
params.tar
params.tar.gz
import paddle.v2 as paddle
import cPickle
import gzip
def softmax_regression(img):
......@@ -73,8 +73,8 @@ def main():
cost = paddle.layer.classification_cost(input=predict, label=label)
try:
with open('params.pkl', 'r') as f:
parameters = cPickle.load(f)
with gzip.open('params.tar.gz', 'r') as f:
parameters = paddle.parameters.Parameters.from_tar(f)
except IOError:
parameters = paddle.parameters.create(cost)
......@@ -99,12 +99,8 @@ def main():
event.pass_id, event.batch_id, event.cost, event.metrics,
result.metrics)
with open('params.pkl', 'w') as f:
cPickle.dump(
parameters, f, protocol=cPickle.HIGHEST_PROTOCOL)
with open('params.tar', 'w') as f:
parameters.serialize_to_tar(f)
with gzip.open('params.tar.gz', 'w') as f:
parameters.to_tar(f)
elif isinstance(event, paddle.event.EndPass):
result = trainer.test(reader=paddle.reader.batched(
......
......@@ -224,19 +224,6 @@ class Parameters(object):
self.__gradient_machines__.append(gradient_machine)
def __getstate__(self):
params = {}
for name in self.names():
params[name] = self.get(name)
param_conf = {}
for name in self.__param_conf__:
conf = self.__param_conf__[name]
assert isinstance(conf, ParameterConfig)
param_conf[name] = conf.SerializeToString()
return {'conf': param_conf, 'params': params}
def serialize(self, name, f):
"""
......@@ -260,10 +247,10 @@ class Parameters(object):
:return:
"""
f.read(16) # header
arr = np.fromfile(f, dtype=np.float32)
arr = np.frombuffer(f.read(), dtype=np.float32)
self.set(name, arr.reshape(self.get_shape(name)))
def serialize_to_tar(self, f):
def to_tar(self, f):
tar = tarfile.TarFile(fileobj=f, mode='w')
for nm in self.names():
buf = cStringIO.StringIO()
......@@ -273,19 +260,30 @@ class Parameters(object):
tarinfo.size = len(buf.getvalue())
tar.addfile(tarinfo, buf)
def __setstate__(self, obj):
Parameters.__init__(self)
def __impl__(conf, params):
for name in conf:
p = ParameterConfig()
p.ParseFromString(conf[name])
self.__append_config__(p)
for name in params:
shape = self.get_shape(name)
self.set(name, params[name].reshape(shape))
__impl__(**obj)
conf = self.__param_conf__[nm]
confStr = conf.SerializeToString()
tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm)
tarinfo.size = len(confStr)
buf = cStringIO.StringIO(confStr)
buf.seek(0)
tar.addfile(tarinfo, fileobj=buf)
@staticmethod
def from_tar(f):
params = Parameters()
tar = tarfile.TarFile(fileobj=f, mode='r')
for finfo in tar:
assert isinstance(finfo, tarfile.TarInfo)
if finfo.name.endswith('.protobuf'):
f = tar.extractfile(finfo)
conf = ParameterConfig()
conf.ParseFromString(f.read())
params.__append_config__(conf)
for param_name in params.names():
f = tar.extractfile(param_name)
params.deserialize(param_name, f)
return params
def __get_parameter_in_gradient_machine__(gradient_machine, name):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册