From c36a3f46070e8ef5102b6fb34362c50193d5f529 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 6 Mar 2017 14:51:15 +0800 Subject: [PATCH] Add unittest for serialize/deserialize. --- python/paddle/v2/parameters.py | 6 +++ python/paddle/v2/tests/run_tests.sh | 2 +- python/paddle/v2/tests/test_parameters.py | 60 +++++++++++++++++++++++ 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 python/paddle/v2/tests/test_parameters.py diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index 1fed0b8a6a6..05dc5c68dd9 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -124,6 +124,12 @@ class Parameters(object): if len(self.__gradient_machines__) == 0: # create new parameter in python numpy. + if len(self.__tmp_params__) != 0: + ret_list = [ + mat for name, mat in self.__tmp_params__ if name == key + ] + if len(ret_list) == 1: + return ret_list[0] return np.ndarray(shape=shape, dtype=np.float32) else: for each_gradient_machine in self.__gradient_machines__: diff --git a/python/paddle/v2/tests/run_tests.sh b/python/paddle/v2/tests/run_tests.sh index b96f54fe9cc..dda1b1bd222 100755 --- a/python/paddle/v2/tests/run_tests.sh +++ b/python/paddle/v2/tests/run_tests.sh @@ -22,7 +22,7 @@ cd $SCRIPTPATH $1 -m pip install ../../../../paddle/dist/*.whl -test_list="test_data_feeder.py" +test_list="test_data_feeder.py test_parameters.py" export PYTHONPATH=$PWD/../../../../python/ diff --git a/python/paddle/v2/tests/test_parameters.py b/python/paddle/v2/tests/test_parameters.py new file mode 100644 index 00000000000..ebb182caab6 --- /dev/null +++ b/python/paddle/v2/tests/test_parameters.py @@ -0,0 +1,60 @@ +import unittest +import sys + +try: + import py_paddle + + del py_paddle +except ImportError: + print >> sys.stderr, "It seems swig of Paddle is not installed, this " \ + "unittest will not be run." + sys.exit(0) + +import paddle.v2.parameters as parameters +from paddle.proto.ParameterConfig_pb2 import ParameterConfig +import random +import cStringIO +import numpy + + +def __rand_param_config__(name): + conf = ParameterConfig() + conf.name = name + size = 1 + for i in xrange(2): + dim = random.randint(1, 1000) + conf.dims.append(dim) + size *= dim + conf.size = size + assert conf.IsInitialized() + return conf + + +class TestParameters(unittest.TestCase): + def test_serialization(self): + params = parameters.Parameters() + params.__append_config__(__rand_param_config__("param_0")) + params.__append_config__(__rand_param_config__("param_1")) + + for name in params.names(): + param = params.get(name) + param[:] = numpy.random.uniform( + -1.0, 1.0, size=params.get_shape(name)) + params.set(name, param) + + tmp_file = cStringIO.StringIO() + params.to_tar(tmp_file) + tmp_file.seek(0) + params_dup = parameters.Parameters.from_tar(tmp_file) + + self.assertEqual(params_dup.names(), params.names()) + + for name in params.names(): + self.assertEqual(params.get_shape(name), params_dup.get_shape(name)) + p0 = params.get(name) + p1 = params_dup.get(name) + self.assertTrue(numpy.isclose(p0, p1).all()) + + +if __name__ == '__main__': + unittest.main() -- GitLab