diff --git a/demo/mnist/.gitignore b/demo/mnist/.gitignore index 9c552159be6f73e06233956ba0ed1e077ac1d02a..7e61d5e3a0cabd46d4185454d46610ac2ee2e63f 100644 --- a/demo/mnist/.gitignore +++ b/demo/mnist/.gitignore @@ -6,3 +6,5 @@ train.log *pyc .ipynb_checkpoints params.pkl +params.tar +params.tar.gz diff --git a/demo/mnist/api_train_v2.py b/demo/mnist/api_train_v2.py index 9a49274e9a0d7be79047c479c2437c65193c590b..072b2a08da6db1f6ae7b84ee66dbc88aef487deb 100644 --- a/demo/mnist/api_train_v2.py +++ b/demo/mnist/api_train_v2.py @@ -1,5 +1,5 @@ import paddle.v2 as paddle -import cPickle +import gzip def softmax_regression(img): @@ -73,8 +73,8 @@ def main(): cost = paddle.layer.classification_cost(input=predict, label=label) try: - with open('params.pkl', 'r') as f: - parameters = cPickle.load(f) + with gzip.open('params.tar.gz', 'r') as f: + parameters = paddle.parameters.Parameters.from_tar(f) except IOError: parameters = paddle.parameters.create(cost) @@ -91,10 +91,18 @@ def main(): def event_handler(event): if isinstance(event, paddle.event.EndIteration): - if event.batch_id % 100 == 0: - print "Pass %d, Batch %d, Cost %f, %s" % ( - event.pass_id, event.batch_id, event.cost, event.metrics) - if isinstance(event, paddle.event.EndPass): + if event.batch_id % 1000 == 0: + result = trainer.test(reader=paddle.reader.batched( + paddle.dataset.mnist.test(), batch_size=256)) + + print "Pass %d, Batch %d, Cost %f, %s, Testing metrics %s" % ( + event.pass_id, event.batch_id, event.cost, event.metrics, + result.metrics) + + with gzip.open('params.tar.gz', 'w') as f: + parameters.to_tar(f) + + elif isinstance(event, paddle.event.EndPass): result = trainer.test(reader=paddle.reader.batched( paddle.dataset.mnist.test(), batch_size=128)) print "Test with Pass %d, Cost %f, %s\n" % ( diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index f01185c14e2c47876ab8d16365372ceb2f44d69a..05dc5c68dd97b00fb15b74564a32313430c45345 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -1,7 +1,9 @@ import numpy as np import py_paddle.swig_paddle as api from paddle.proto.ParameterConfig_pb2 import ParameterConfig - +import struct +import tarfile +import cStringIO from topology import Topology __all__ = ['Parameters', 'create'] @@ -122,6 +124,12 @@ class Parameters(object): if len(self.__gradient_machines__) == 0: # create new parameter in python numpy. + if len(self.__tmp_params__) != 0: + ret_list = [ + mat for name, mat in self.__tmp_params__ if name == key + ] + if len(ret_list) == 1: + return ret_list[0] return np.ndarray(shape=shape, dtype=np.float32) else: for each_gradient_machine in self.__gradient_machines__: @@ -228,32 +236,66 @@ class Parameters(object): self.__gradient_machines__.append(gradient_machine) - def __getstate__(self): - params = {} - for name in self.names(): - params[name] = self.get(name) - - param_conf = {} - for name in self.__param_conf__: - conf = self.__param_conf__[name] - assert isinstance(conf, ParameterConfig) - param_conf[name] = conf.SerializeToString() - - return {'conf': param_conf, 'params': params} + def serialize(self, name, f): + """ - def __setstate__(self, obj): - Parameters.__init__(self) + :param name: + :param f: + :type f: file + :return: + """ + param = self.get(name) + size = reduce(lambda a, b: a * b, param.shape) + f.write(struct.pack("IIQ", 0, 4, size)) + param = param.astype(np.float32) + f.write(param.tobytes()) - def __impl__(conf, params): - for name in conf: - p = ParameterConfig() - p.ParseFromString(conf[name]) - self.__append_config__(p) - for name in params: - shape = self.get_shape(name) - self.set(name, params[name].reshape(shape)) + def deserialize(self, name, f): + """ - __impl__(**obj) + :param name: + :param f: + :type f: file + :return: + """ + f.read(16) # header + arr = np.frombuffer(f.read(), dtype=np.float32) + self.set(name, arr.reshape(self.get_shape(name))) + + def to_tar(self, f): + tar = tarfile.TarFile(fileobj=f, mode='w') + for nm in self.names(): + buf = cStringIO.StringIO() + self.serialize(nm, buf) + tarinfo = tarfile.TarInfo(name=nm) + buf.seek(0) + tarinfo.size = len(buf.getvalue()) + tar.addfile(tarinfo, buf) + + conf = self.__param_conf__[nm] + confStr = conf.SerializeToString() + tarinfo = tarfile.TarInfo(name="%s.protobuf" % nm) + tarinfo.size = len(confStr) + buf = cStringIO.StringIO(confStr) + buf.seek(0) + tar.addfile(tarinfo, fileobj=buf) + + @staticmethod + def from_tar(f): + params = Parameters() + tar = tarfile.TarFile(fileobj=f, mode='r') + for finfo in tar: + assert isinstance(finfo, tarfile.TarInfo) + if finfo.name.endswith('.protobuf'): + f = tar.extractfile(finfo) + conf = ParameterConfig() + conf.ParseFromString(f.read()) + params.__append_config__(conf) + + for param_name in params.names(): + f = tar.extractfile(param_name) + params.deserialize(param_name, f) + return params def __get_parameter_in_gradient_machine__(gradient_machine, name): diff --git a/python/paddle/v2/tests/run_tests.sh b/python/paddle/v2/tests/run_tests.sh index b96f54fe9cc78a436bc67e6c542b6e842aba997b..dda1b1bd222a9f226db1a4bd730e9637ab882196 100755 --- a/python/paddle/v2/tests/run_tests.sh +++ b/python/paddle/v2/tests/run_tests.sh @@ -22,7 +22,7 @@ cd $SCRIPTPATH $1 -m pip install ../../../../paddle/dist/*.whl -test_list="test_data_feeder.py" +test_list="test_data_feeder.py test_parameters.py" export PYTHONPATH=$PWD/../../../../python/ diff --git a/python/paddle/v2/tests/test_parameters.py b/python/paddle/v2/tests/test_parameters.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb182caab6430862a8e4da2ae4ea6b1e72f726c --- /dev/null +++ b/python/paddle/v2/tests/test_parameters.py @@ -0,0 +1,60 @@ +import unittest +import sys + +try: + import py_paddle + + del py_paddle +except ImportError: + print >> sys.stderr, "It seems swig of Paddle is not installed, this " \ + "unittest will not be run." + sys.exit(0) + +import paddle.v2.parameters as parameters +from paddle.proto.ParameterConfig_pb2 import ParameterConfig +import random +import cStringIO +import numpy + + +def __rand_param_config__(name): + conf = ParameterConfig() + conf.name = name + size = 1 + for i in xrange(2): + dim = random.randint(1, 1000) + conf.dims.append(dim) + size *= dim + conf.size = size + assert conf.IsInitialized() + return conf + + +class TestParameters(unittest.TestCase): + def test_serialization(self): + params = parameters.Parameters() + params.__append_config__(__rand_param_config__("param_0")) + params.__append_config__(__rand_param_config__("param_1")) + + for name in params.names(): + param = params.get(name) + param[:] = numpy.random.uniform( + -1.0, 1.0, size=params.get_shape(name)) + params.set(name, param) + + tmp_file = cStringIO.StringIO() + params.to_tar(tmp_file) + tmp_file.seek(0) + params_dup = parameters.Parameters.from_tar(tmp_file) + + self.assertEqual(params_dup.names(), params.names()) + + for name in params.names(): + self.assertEqual(params.get_shape(name), params_dup.get_shape(name)) + p0 = params.get(name) + p1 = params_dup.get(name) + self.assertTrue(numpy.isclose(p0, p1).all()) + + +if __name__ == '__main__': + unittest.main()