未验证 提交 d89ff5b6 编写于 作者: Y Yu Yang 提交者: GitHub

Restore the param infos in Program.clone() (#5873)

* Restore the param infos in Program.clone()

The Program.clone only clones the variables and ops
in the program into a new program. However, the
information of Parameters is not cloned.

So we need to restore the information of Parameters.

Fix #5871

* Follow comments

* Fix CI

* Fix CI

* Fix CI
上级 c9a96575
...@@ -395,7 +395,11 @@ class Block(object): ...@@ -395,7 +395,11 @@ class Block(object):
return v return v
def all_parameters(self): def all_parameters(self):
return {v for k, v in self.vars.iteritems() if isinstance(v, Parameter)} return list(self.iter_parameters())
def iter_parameters(self):
return (item[1] for item in self.vars.iteritems()
if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
var = Variable(self, *args, **kwargs) var = Variable(self, *args, **kwargs)
...@@ -469,6 +473,37 @@ class Block(object): ...@@ -469,6 +473,37 @@ class Block(object):
for index in range(len(self.ops)): for index in range(len(self.ops)):
assert self.ops[index].desc == ops_in_cpp[index] assert self.ops[index].desc == ops_in_cpp[index]
def copy_param_info_from(self, other):
"""
Copy the information of parameters from other block
Args:
other(Block): other block
Returns:
None
"""
if not isinstance(other, Block):
raise TypeError("copy_param_info_from should be invoked with Block")
for p in other.iter_parameters():
assert isinstance(p, Parameter)
v = self.vars.get(p.name, None)
if v is None:
raise ValueError("copy_param_info_from should be invoked with "
"same topology")
assert isinstance(v, Variable)
new_p = Parameter(
block=self,
shape=v.shape,
dtype=v.dtype,
type=v.type,
lod_level=v.lod_level,
stop_gradient=p.stop_gradient,
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
name=v.name)
self.vars[new_p.name] = new_p
class Program(object): class Program(object):
def __init__(self): def __init__(self):
...@@ -489,6 +524,7 @@ class Program(object): ...@@ -489,6 +524,7 @@ class Program(object):
p.desc = core.ProgramDesc(self.desc) p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())] p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
p.sync_with_cpp() p.sync_with_cpp()
p.copy_param_info_from(self)
return p return p
def prune(self, targets): def prune(self, targets):
...@@ -572,6 +608,24 @@ class Program(object): ...@@ -572,6 +608,24 @@ class Program(object):
for block in self.blocks: for block in self.blocks:
block.sync_with_cpp() block.sync_with_cpp()
def copy_param_info_from(self, other):
"""
Copy the information of parameters from other program.
Args:
other(Program): Other program
Returns:
None
"""
if not isinstance(other, Program):
raise TypeError("copy_param_info_from should be invoked with "
"Program")
if len(self.blocks) != len(other.blocks):
raise ValueError("copy_param_info_from should be invoked with two "
"program, with represent the same topology")
self.global_block().copy_param_info_from(other.global_block())
def list_vars(self): def list_vars(self):
for each_block in self.blocks: for each_block in self.blocks:
for each_var in each_block.vars.itervalues(): for each_var in each_block.vars.itervalues():
......
from __future__ import print_function
import unittest import unittest
from paddle.v2.fluid.framework import Program from paddle.v2.fluid.framework import Program
from paddle.v2.fluid.framework import g_main_program from paddle.v2.fluid.framework import g_main_program
import paddle.v2.fluid.layers as layers
class TestProgram(unittest.TestCase): class TestProgram(unittest.TestCase):
...@@ -48,8 +50,8 @@ class TestProgram(unittest.TestCase): ...@@ -48,8 +50,8 @@ class TestProgram(unittest.TestCase):
# FIXME(yuyang18): We manual compare the output string, since the order # FIXME(yuyang18): We manual compare the output string, since the order
# of variable could be changed. # of variable could be changed.
print prog print(prog)
print prog.clone() print(prog.clone())
def test_parse_program_from_string(self): def test_parse_program_from_string(self):
prog = Program() prog = Program()
...@@ -67,8 +69,8 @@ class TestProgram(unittest.TestCase): ...@@ -67,8 +69,8 @@ class TestProgram(unittest.TestCase):
binary_str = prog.desc.serialize_to_string() binary_str = prog.desc.serialize_to_string()
prog_restored = Program.parse_from_string(binary_str) prog_restored = Program.parse_from_string(binary_str)
print prog print(prog)
print prog_restored print(prog_restored)
def test_append_backward(self): def test_append_backward(self):
prog = Program() prog = Program()
...@@ -123,6 +125,20 @@ class TestProgram(unittest.TestCase): ...@@ -123,6 +125,20 @@ class TestProgram(unittest.TestCase):
actual_ops.append(op.type) actual_ops.append(op.type)
self.assertEqual(actual_ops, expect_ops) self.assertEqual(actual_ops, expect_ops)
def test_program_clone_with_parameter(self):
main_program = Program()
startup_program = Program()
kwargs = {
'main_program': main_program,
'startup_program': startup_program
}
d = layers.data(name='x', shape=[784], dtype='float32', **kwargs)
hidden = layers.fc(input=d, size=100, **kwargs)
layers.fc(input=hidden, size=100, **kwargs)
new_program = main_program.clone()
self.assertNotEqual(0, len(new_program.blocks[0].all_parameters()))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册