提交 aaee28bf 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #2664 from qingqing01/from_tar

 Init partial network parameters from another saved model.
...@@ -144,7 +144,7 @@ class DenseScanner(IScanner): ...@@ -144,7 +144,7 @@ class DenseScanner(IScanner):
if len(self.__shape__) > 1: if len(self.__shape__) > 1:
# The last-two dimenstions are the frame height and width. # The last-two dimenstions are the frame height and width.
# For example, the layout is CHW for 3-D feature of image. # For example, the layout is CHW for 3-D feature of image.
# The H and W are the fram height and width. # The H and W are the frame height and width.
h, w = self.__shape__[-2:] h, w = self.__shape__[-2:]
argument.setSlotFrameHeight(self.pos, h) argument.setSlotFrameHeight(self.pos, h)
argument.setSlotFrameWidth(self.pos, w) argument.setSlotFrameWidth(self.pos, w)
......
...@@ -51,7 +51,7 @@ class Parameters(object): ...@@ -51,7 +51,7 @@ class Parameters(object):
def __init__(self): def __init__(self):
self.__param_conf__ = dict() self.__param_conf__ = dict()
self.__gradient_machines__ = [] self.__gradient_machines__ = []
self.__tmp_params__ = [] self.__tmp_params__ = dict()
def __append_config__(self, param_conf): def __append_config__(self, param_conf):
""" """
...@@ -128,12 +128,9 @@ class Parameters(object): ...@@ -128,12 +128,9 @@ class Parameters(object):
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
# create new parameter in python numpy. # create new parameter in python numpy.
if len(self.__tmp_params__) != 0: if key in self.__tmp_params__:
ret_list = [ return self.__tmp_params__[key]
mat for name, mat in self.__tmp_params__ if name == key else:
]
if len(ret_list) == 1:
return ret_list[0]
return np.ndarray(shape=shape, dtype=np.float32) return np.ndarray(shape=shape, dtype=np.float32)
else: else:
for each_gradient_machine in self.__gradient_machines__: for each_gradient_machine in self.__gradient_machines__:
...@@ -187,7 +184,7 @@ class Parameters(object): ...@@ -187,7 +184,7 @@ class Parameters(object):
(shape, value.shape)) (shape, value.shape))
if len(self.__gradient_machines__) == 0: if len(self.__gradient_machines__) == 0:
self.__tmp_params__.append((key, value)) self.__tmp_params__[key] = value
else: else:
for each_gradient_machine in self.__gradient_machines__: for each_gradient_machine in self.__gradient_machines__:
__copy_parameter_to_gradient_machine__(each_gradient_machine, __copy_parameter_to_gradient_machine__(each_gradient_machine,
...@@ -231,7 +228,7 @@ class Parameters(object): ...@@ -231,7 +228,7 @@ class Parameters(object):
raise ValueError("gradient_machine should be api.GradientMachine") raise ValueError("gradient_machine should be api.GradientMachine")
if len(self.__tmp_params__) != 0: if len(self.__tmp_params__) != 0:
for name, val in self.__tmp_params__: for name, val in self.__tmp_params__.iteritems():
try: try:
__copy_parameter_to_gradient_machine__(gradient_machine, __copy_parameter_to_gradient_machine__(gradient_machine,
name, val) name, val)
...@@ -287,6 +284,18 @@ class Parameters(object): ...@@ -287,6 +284,18 @@ class Parameters(object):
@staticmethod @staticmethod
def from_tar(f): def from_tar(f):
"""
Create a `Parameters` object from the given file. And
the `Parameters` only contains the parameters in this
file. It is adapted the parameters are same in the
defined network and the given file. For example, it
can be used in the inference.
:param f: the initialized model file.
:type f: tar file
:return: A Parameters object.
:rtype: Parameters.
"""
params = Parameters() params = Parameters()
tar = tarfile.TarFile(fileobj=f, mode='r') tar = tarfile.TarFile(fileobj=f, mode='r')
for finfo in tar: for finfo in tar:
...@@ -302,6 +311,21 @@ class Parameters(object): ...@@ -302,6 +311,21 @@ class Parameters(object):
params.deserialize(param_name, f) params.deserialize(param_name, f)
return params return params
def init_from_tar(self, f):
"""
Different from `from_tar`, this interface can be used to
init partial network parameters from another saved model.
:param f: the initialized model file.
:type f: tar file
:return: Nothing.
"""
tar_param = Parameters.from_tar(f)
for pname in tar_param.names():
if pname in self.names():
self.set(pname, tar_param.get(pname))
def __get_parameter_in_gradient_machine__(gradient_machine, name): def __get_parameter_in_gradient_machine__(gradient_machine, name):
""" """
......
...@@ -20,14 +20,17 @@ import cStringIO ...@@ -20,14 +20,17 @@ import cStringIO
import numpy import numpy
def __rand_param_config__(name): def __rand_param_config__(name, psize=None):
conf = ParameterConfig() conf = ParameterConfig()
conf.name = name conf.name = name
size = 1 size = 1
if psize is None:
for i in xrange(2): for i in xrange(2):
dim = random.randint(1, 1000) dim = random.randint(1, 1000)
conf.dims.append(dim) conf.dims.append(dim)
size *= dim size *= dim
else:
size = psize
conf.size = size conf.size = size
assert conf.IsInitialized() assert conf.IsInitialized()
return conf return conf
...@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase): ...@@ -77,6 +80,50 @@ class TestParameters(unittest.TestCase):
expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32) expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32)
assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6)) assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6))
def test_init_from_tar(self):
def get_param(names, size):
p = parameters.Parameters()
for k, v in zip(names, size):
p.__append_config__(__rand_param_config__(k, v))
for name in p.names():
param = p.get(name)
param[:] = numpy.random.uniform(
-1.0, 1.0, size=p.get_shape(name))
p.set(name, param)
return p
def get_parames():
name1 = ['param_0', 'param_1']
size1 = [128, 256]
p1 = get_param(name1, size1)
file1 = cStringIO.StringIO()
p1.to_tar(file1)
file1.seek(0)
name2 = ['param_0', 'param_1', 'param_2']
size2 = [128, 256, 288]
p2 = get_param(name2, size2)
file2 = cStringIO.StringIO()
p2.to_tar(file2)
file2.seek(0)
return p1, file1, p2, file2
p1, file1, p2, file2 = get_parames()
p2.init_from_tar(file1)
for name in p1.names():
self.assertEqual(p1.get_shape(name), p2.get_shape(name))
v1 = p1.get(name)
v2 = p2.get(name)
self.assertTrue(numpy.isclose(v1, v2).all())
p1, file1, p2, file2 = get_parames()
p1.init_from_tar(file2)
for name in p1.names():
self.assertEqual(p1.get_shape(name), p2.get_shape(name))
v1 = p1.get(name)
v2 = p2.get(name)
self.assertTrue(numpy.isclose(v1, v2).all())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册