diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore index 9c552159be6f73e06233956ba0ed1e077ac1d02a..ed074b09e7fba156f87412306e44931e1b464a1d 100644 --- a/demo/mnist/.gitignore +++ b/demo/mnist/.gitignore @@ -6,3 +6,4 @@ train.log *pyc .ipynb_checkpoints params.pkl +params.tar diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index a72ebfa9805935933fba7367f7891637f567cff6..7a1f6613188f4f066aaaae28b2f2ade8f5c37306 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -103,6 +103,9 @@ def main(): cPickle.dump( parameters, f, protocol=cPickle.HIGHEST_PROTOCOL) + with open('params.tar', 'w') as f: + parameters.serialize_to_tar(f) + elif isinstance(event, paddle.event.EndPass): result = trainer.test(reader=paddle.reader.batched( paddle.dataset.mnist.test(), batch_size=128)) diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index d8c3a73b0ea9e5df11dd0933a1c87cf13e0b5ded..6a7b883500892789b88becb42e1f4a58b2dfe233 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -1,7 +1,9 @@ import numpy as np import py_paddle.swig_paddle as api from paddle.proto.ParameterConfig_pb2 import ParameterConfig - +import struct +import tarfile +import cStringIO from topology import Topology __all__ = ['Parameters', 'create'] @@ -235,6 +237,42 @@ class Parameters(object): return {'conf': param_conf, 'params': params} + def serialize(self, name, f): + """ + + :param name: + :param f: + :type f: file + :return: + """ + param = self.get(name) + size = reduce(lambda a, b: a * b, param.shape) + f.write(struct.pack("IIQ", 0, 4, size)) + param = param.astype(np.float32) + f.write(param.tobytes()) + + def deserialize(self, name, f): + """ + + :param name: + :param f: + :type f: file + :return: + """ + f.read(16) # header + arr = np.fromfile(f, dtype=np.float32) + self.set(name, arr.reshape(self.get_shape(name))) + + def serialize_to_tar(self, f): + tar = tarfile.TarFile(fileobj=f, mode='w') + for nm in self.names(): + buf = cStringIO.StringIO() + self.serialize(nm, buf) + tarinfo = tarfile.TarInfo(name=nm) + buf.seek(0) + tarinfo.size = len(buf.getvalue()) + tar.addfile(tarinfo, buf) + def __setstate__(self, obj): Parameters.__init__(self) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index e878ea6e3b39bc21d0d50163d404cbd43f954df7..7da97d79a85b04cecae8f2d247775f7bfdba815d 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -101,7 +101,7 @@ class SGD(): for each_param in self.__gradient_machine__.getNonStaticParameters( ): updater.update(each_param) - cost_sum = out_args.sumCosts() + cost_sum = out_args.sum() cost = cost_sum / len(data_batch) updater.finishBatch(cost) batch_evaluator.finish() @@ -137,7 +137,7 @@ class SGD(): num_samples += len(data_batch) self.__gradient_machine__.forward( feeder(data_batch), out_args, api.PASS_TEST) - total_cost += out_args.sumCosts() + total_cost += out_args.sum() self.__gradient_machine__.eval(evaluator) evaluator.finish()