test_parameters.py 1.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
import unittest
import sys

# Exit quietly (status 0) before anything else runs when the SWIG-built
# `py_paddle` module cannot be imported. Only importability is checked; the
# imported module object is discarded right away.
try:
    import py_paddle
except ImportError:
    # Same output as `print >> sys.stderr, <msg>`: the text, then a newline.
    sys.stderr.write("It seems swig of Paddle is not installed, this "
                     "unittest will not be run.")
    sys.stderr.write("\n")
    sys.exit(0)
else:
    del py_paddle

import paddle.v2.parameters as parameters
from paddle.proto.ParameterConfig_pb2 import ParameterConfig
import random
import cStringIO
import numpy


def __rand_param_config__(name, ndims=2, max_dim=1000):
    """Build a ParameterConfig with randomly sized dimensions.

    Each of the ``ndims`` dimensions is drawn with
    ``random.randint(1, max_dim)``. It is appended to ``conf.dims``, and
    ``conf.size`` is set to the product of all dimensions.

    :param name: value stored in ``conf.name``.
    :param ndims: number of dimensions to generate (default 2). Must be >= 1.
    :param max_dim: inclusive upper bound for each dimension (default 1000).
    :return: the populated ParameterConfig.
    :raises ValueError: if ``ndims`` < 1, or (from ``random.randint``) if
        ``max_dim`` < 1.
    """
    if ndims < 1:
        raise ValueError("ndims must be >= 1, got %r" % (ndims,))
    conf = ParameterConfig()
    conf.name = name
    size = 1
    for _ in range(ndims):
        dim = random.randint(1, max_dim)
        conf.dims.append(dim)
        size *= dim
    conf.size = size
    # Protobuf sanity check: fail fast if a required field was left unset.
    assert conf.IsInitialized()
    return conf


class TestParameters(unittest.TestCase):
    """Tests for ``paddle.v2.parameters.Parameters``."""

    def test_serialization(self):
        """Round-trip two randomly shaped parameters through to_tar/from_tar.

        The restored object must have the same parameter names, the same
        per-parameter shapes, and numerically close values.
        """
        params = parameters.Parameters()
        params.__append_config__(__rand_param_config__("param_0"))
        params.__append_config__(__rand_param_config__("param_1"))

        # Fill every parameter with random values in [-1, 1) so the round
        # trip compares freshly written data.
        for name in params.names():
            param = params.get(name)
            # Slice assignment writes into the object returned by get()
            # (presumably a numpy array); set() then stores it back.
            param[:] = numpy.random.uniform(
                -1.0, 1.0, size=params.get_shape(name))
            params.set(name, param)

        # Serialize into an in-memory buffer, rewind it, and deserialize.
        tmp_file = cStringIO.StringIO()
        params.to_tar(tmp_file)
        tmp_file.seek(0)
        params_dup = parameters.Parameters.from_tar(tmp_file)

        self.assertEqual(params_dup.names(), params.names())

        for name in params.names():
            self.assertEqual(params.get_shape(name), params_dup.get_shape(name))
            p0 = params.get(name)
            p1 = params_dup.get(name)
            # Uses the same tolerances as numpy.isclose's defaults. Unlike
            # assertTrue(isclose(...).all()), it rejects mismatched array
            # shapes instead of broadcasting, and it reports the differing
            # elements on failure.
            numpy.testing.assert_allclose(
                p0, p1, rtol=1e-05, atol=1e-08, equal_nan=False,
                err_msg="parameter %r differs after tar round trip" % (name,))


if __name__ == "__main__":
    # When executed directly as a script, run this module's test cases.
    unittest.main()