From fb74ae36d4b9ba21ccf98bd45b04c361029e7406 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 4 Mar 2017 17:11:49 +0800 Subject: [PATCH] Refine serialize --- demo/mnist/.gitignore | 1 + demo/mnist/api_train_v2.py | 3 +++ python/paddle/v2/parameters.py | 40 +++++++++++++++++++++++++++++++++- python/paddle/v2/trainer.py | 4 ++-- 4 files changed, 45 insertions(+), 3 deletions(-) diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore index 9c552159be6..ed074b09e7f 100644 --- a/demo/mnist/.gitignore +++ b/demo/mnist/.gitignore @@ -6,3 +6,4 @@ train.log *pyc .ipynb_checkpoints params.pkl +params.tar diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index a72ebfa9805..7a1f6613188 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -103,6 +103,9 @@ def main(): cPickle.dump( parameters, f, protocol=cPickle.HIGHEST_PROTOCOL) + with open('params.tar', 'w') as f: + parameters.serialize_to_tar(f) + elif isinstance(event, paddle.event.EndPass): result = trainer.test(reader=paddle.reader.batched( paddle.dataset.mnist.test(), batch_size=128)) diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index d8c3a73b0ea..6a7b8835008 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -1,7 +1,9 @@ import numpy as np import py_paddle.swig_paddle as api from paddle.proto.ParameterConfig_pb2 import ParameterConfig - +import struct +import tarfile +import cStringIO from topology import Topology __all__ = ['Parameters', 'create'] @@ -235,6 +237,42 @@ class Parameters(object): return {'conf': param_conf, 'params': params} + def serialize(self, name, f): + """ + + :param name: + :param f: + :type f: file + :return: + """ + param = self.get(name) + size = reduce(lambda a, b: a * b, param.shape) + f.write(struct.pack("IIQ", 0, 4, size)) + param = param.astype(np.float32) + f.write(param.tobytes()) + + def deserialize(self, name, f): + """ + + :param name: + :param f: + :type f: file + :return: + """ + f.read(16) # header + arr = np.fromfile(f, dtype=np.float32) + self.set(name, arr.reshape(self.get_shape(name))) + + def serialize_to_tar(self, f): + tar = tarfile.TarFile(fileobj=f, mode='w') + for nm in self.names(): + buf = cStringIO.StringIO() + self.serialize(nm, buf) + tarinfo = tarfile.TarInfo(name=nm) + buf.seek(0) + tarinfo.size = len(buf.getvalue()) + tar.addfile(tarinfo, buf) + def __setstate__(self, obj): Parameters.__init__(self) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index e878ea6e3b3..7da97d79a85 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -101,7 +101,7 @@ class SGD(): for each_param in self.__gradient_machine__.getNonStaticParameters( ): updater.update(each_param) - cost_sum = out_args.sumCosts() + cost_sum = out_args.sum() cost = cost_sum / len(data_batch) updater.finishBatch(cost) batch_evaluator.finish() @@ -137,7 +137,7 @@ class SGD(): num_samples += len(data_batch) self.__gradient_machine__.forward( feeder(data_batch), out_args, api.PASS_TEST) - total_cost += out_args.sumCosts() + total_cost += out_args.sum() self.__gradient_machine__.eval(evaluator) evaluator.finish() -- GitLab