提交 c9865824 编写于 作者: D dangqingqing

Support to init partial network parameters from the tar file.

上级 1a0fdb9e
......@@ -51,7 +51,7 @@ class Parameters(object):
def __init__(self):
self.__param_conf__ = dict()
self.__gradient_machines__ = []
self.__tmp_params__ = []
self.__tmp_params__ = dict()
def __append_config__(self, param_conf):
"""
......@@ -128,13 +128,10 @@ class Parameters(object):
if len(self.__gradient_machines__) == 0:
# create new parameter in python numpy.
if len(self.__tmp_params__) != 0:
ret_list = [
mat for name, mat in self.__tmp_params__ if name == key
]
if len(ret_list) == 1:
return ret_list[0]
return np.ndarray(shape=shape, dtype=np.float32)
if key in self.__tmp_params__:
return self.__tmp_params__[key]
else:
return np.ndarray(shape=shape, dtype=np.float32)
else:
for each_gradient_machine in self.__gradient_machines__:
param = __get_parameter_in_gradient_machine__(
......@@ -187,7 +184,7 @@ class Parameters(object):
(shape, value.shape))
if len(self.__gradient_machines__) == 0:
self.__tmp_params__.append((key, value))
self.__tmp_params__[key] = value
else:
for each_gradient_machine in self.__gradient_machines__:
__copy_parameter_to_gradient_machine__(each_gradient_machine,
......@@ -231,7 +228,7 @@ class Parameters(object):
raise ValueError("gradient_machine should be api.GradientMachine")
if len(self.__tmp_params__) != 0:
for name, val in self.__tmp_params__:
for name, val in self.__tmp_params__.iteritems():
try:
__copy_parameter_to_gradient_machine__(gradient_machine,
name, val)
......@@ -302,6 +299,12 @@ class Parameters(object):
params.deserialize(param_name, f)
return params
def init_from_tar(self, f):
tar_param = self.from_tar(f)
for pname in tar_param.names():
if pname in self.names():
self.set(pname, tar_param.get(pname))
def __get_parameter_in_gradient_machine__(gradient_machine, name):
"""
......
......@@ -20,14 +20,17 @@ import cStringIO
import numpy
def __rand_param_config__(name):
def __rand_param_config__(name, psize=None):
conf = ParameterConfig()
conf.name = name
size = 1
for i in xrange(2):
dim = random.randint(1, 1000)
conf.dims.append(dim)
size *= dim
if psize is None:
for i in xrange(2):
dim = random.randint(1, 1000)
conf.dims.append(dim)
size *= dim
else:
size = psize
conf.size = size
assert conf.IsInitialized()
return conf
......@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase):
expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32)
assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6))
def test_init_from_tar(self):
def get_param(names, size):
p = parameters.Parameters()
for k, v in zip(names, size):
p.__append_config__(__rand_param_config__(k, v))
for name in p.names():
param = p.get(name)
param[:] = numpy.random.uniform(
-1.0, 1.0, size=p.get_shape(name))
p.set(name, param)
return p
def get_parames():
name1 = ['param_0', 'param_1']
size1 = [128, 256]
p1 = get_param(name1, size1)
file1 = cStringIO.StringIO()
p1.to_tar(file1)
file1.seek(0)
name2 = ['param_0', 'param_1', 'param_2']
size2 = [128, 256, 288]
p2 = get_param(name2, size2)
file2 = cStringIO.StringIO()
p2.to_tar(file2)
file2.seek(0)
return p1, file1, p2, file2
p1, file1, p2, file2 = get_parames()
p2.init_from_tar(file1)
for name in p1.names():
self.assertEqual(p1.get_shape(name), p2.get_shape(name))
v1 = p1.get(name)
v2 = p2.get(name)
self.assertTrue(numpy.isclose(v1, v2).all())
p1, file1, p2, file2 = get_parames()
p1.init_from_tar(file2)
for name in p1.names():
self.assertEqual(p1.get_shape(name), p2.get_shape(name))
v1 = p1.get(name)
v2 = p2.get(name)
self.assertTrue(numpy.isclose(v1, v2).all())
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册