From efe53811c569182df71b14c46aa5a7238038cba1 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 4 Mar 2017 18:32:39 +0800 Subject: [PATCH] complete serialize * Test gzip --- demo/mnist/.gitignore | 1 + demo/mnist/api_train_v2.py | 14 ++++----- python/paddle/v2/parameters.py | 54 ++++++++++++++++------------------ 3 files changed, 32 insertions(+), 37 deletions(-) diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore index ed074b09e7f..7e61d5e3a0c 100644 --- a/demo/mnist/.gitignore +++ b/demo/mnist/.gitignore @@ -7,3 +7,4 @@ train.log .ipynb_checkpoints params.pkl params.tar +params.tar.gz diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 7a1f6613188..a11260d91b3 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -1,5 +1,5 @@ import paddle.v2 as paddle -import cPickle +import gzip def softmax_regression(img): @@ -73,8 +73,8 @@ def main(): cost = paddle.layer.classification_cost(input=predict, label=label) try: - with open('params.pkl', 'r') as f: - parameters = cPickle.load(f) + with gzip.open('params.tar.gz', 'r') as f: + parameters = paddle.parameters.Parameters.from_tar(f) except IOError: parameters = paddle.parameters.create(cost) @@ -99,12 +99,8 @@ def main(): event.pass_id, event.batch_id, event.cost, event.metrics, result.metrics) - with open('params.pkl', 'w') as f: - cPickle.dump( - parameters, f, protocol=cPickle.HIGHEST_PROTOCOL) - - with open('params.tar', 'w') as f: - parameters.serialize_to_tar(f) + with gzip.open('params.tar.gz', 'w') as f: + parameters.to_tar(f) elif isinstance(event, paddle.event.EndPass): result = trainer.test(reader=paddle.reader.batched( diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index 6a7b8835008..58be5234072 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -224,19 +224,6 @@ class Parameters(object): self.__gradient_machines__.append(gradient_machine) - def __getstate__(self): - params = {} - for name in self.names(): - params[name] = self.get(name) - - param_conf = {} - for name in self.__param_conf__: - conf = self.__param_conf__[name] - assert isinstance(conf, ParameterConfig) - param_conf[name] = conf.SerializeToString() - - return {'conf': param_conf, 'params': params} - def serialize(self, name, f): """ @@ -260,10 +247,10 @@ class Parameters(object): :return: """ f.read(16) # header - arr = np.fromfile(f, dtype=np.float32) + arr = np.frombuffer(f.read(), dtype=np.float32) self.set(name, arr.reshape(self.get_shape(name))) - def serialize_to_tar(self, f): + def to_tar(self, f): tar = tarfile.TarFile(fileobj=f, mode='w') for nm in self.names(): buf = cStringIO.StringIO() @@ -273,19 +260,30 @@ class Parameters(object): tarinfo.size = len(buf.getvalue()) tar.addfile(tarinfo, buf) - def __setstate__(self, obj): - Parameters.__init__(self) - - def __impl__(conf, params): - for name in conf: - p = ParameterConfig() - p.ParseFromString(conf[name]) - self.__append_config__(p) - for name in params: - shape = self.get_shape(name) - self.set(name, params[name].reshape(shape)) - - __impl__(**obj) + conf = self.__param_conf__[nm] + confStr = conf.SerializeToString() + tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm) + tarinfo.size = len(confStr) + buf = cStringIO.StringIO(confStr) + buf.seek(0) + tar.addfile(tarinfo, fileobj=buf) + + @staticmethod + def from_tar(f): + params = Parameters() + tar = tarfile.TarFile(fileobj=f, mode='r') + for finfo in tar: + assert isinstance(finfo, tarfile.TarInfo) + if finfo.name.endswith('.protobuf'): + f = tar.extractfile(finfo) + conf = ParameterConfig() + conf.ParseFromString(f.read()) + params.__append_config__(conf) + + for param_name in params.names(): + f = tar.extractfile(param_name) + params.deserialize(param_name, f) + return params def __get_parameter_in_gradient_machine__(gradient_machine, name): -- GitLab