def has_var(self, name):
    """Return True if this block already holds a Python variable called *name*."""
    return name in self.vars


def sync_with_cpp(self):
    """Bring this Python block up to date with its C++ block description.

    Every variable reported by ``self.desc.all_vars()`` that the block does
    not know yet is wrapped through ``create_var``. Every operator reported
    by ``self.desc.all_ops()`` that lies before the first or after the last
    Python-side operator is wrapped in an ``Operator`` and added to the
    matching end of ``self.ops``.

    Raises:
        AssertionError: if the Python ops cannot be located inside the C++
            op list, or if the two lists still disagree after syncing.
    """
    # Sync variables from cpp.
    for var in self.desc.all_vars():
        if not self.has_var(var.name()):
            self.create_var(name=var.name(), desc=var, type=var.type())

    # Sync operators from cpp.
    ops_in_cpp = self.desc.all_ops()

    if len(self.ops) != 0:
        first_op_in_python = self.ops[0].desc
        last_op_in_python = self.ops[-1].desc
        start_index = None
        end_index = None
        for index, op_desc in enumerate(ops_in_cpp):
            if first_op_in_python == op_desc:
                start_index = index
            if last_op_in_python == op_desc:
                end_index = index
        assert start_index is not None
        assert end_index is not None
        assert start_index <= end_index
    else:
        # A block without Python ops (e.g. one that was just created for a
        # new C++ block) has nothing to locate: treat the known range as
        # empty so that every C++ op is appended by the tail loop below.
        start_index = 0
        end_index = -1

    # Ops that cpp holds *before* the first Python op. Walk backwards from
    # the op directly in front of start_index so that repeated appendleft
    # calls keep the original order.
    for index in range(start_index - 1, -1, -1):
        self.ops.appendleft(Operator(self, ops_in_cpp[index]))

    # Ops that cpp holds *after* the last Python op.
    for index in range(end_index + 1, len(ops_in_cpp)):
        self.ops.append(Operator(self, ops_in_cpp[index]))

    # Both sides must now list the same op descs in the same order.
    assert len(self.ops) == len(ops_in_cpp)
    for python_op, op_desc in zip(self.ops, ops_in_cpp):
        assert python_op.desc == op_desc
def append_backward(self, target, no_grad_set):
    """Append backward ops for *target* and refresh the Python mirror.

    Asserts that *target* is a ``Variable``, forwards ``target.desc`` and
    *no_grad_set* to ``self.desc.append_backward``, re-syncs the program
    with the C++ side, and returns whatever the desc call returned.
    """
    assert isinstance(target, Variable)
    grad_info = self.desc.append_backward(target.desc, no_grad_set)
    self.sync_with_cpp()
    return grad_info


def sync_with_cpp(self):
    """Mirror C++-side blocks into Python, then sync every block.

    A ``Block`` is created for each index between the current number of
    Python blocks and ``self.desc.num_blocks()``; afterwards every block
    (old and new) runs its own ``sync_with_cpp``.
    """
    next_idx = len(self.blocks)
    total_blocks = self.desc.num_blocks()
    while next_idx < total_blocks:
        self.blocks.append(Block(self, next_idx))
        next_idx += 1

    for blk in self.blocks:
        blk.sync_with_cpp()
def test_append_backward(self):
    """Append backward ops for a single ``mul`` op and check the op sync.

    Builds one parameter, one input and one output variable in the global
    block, appends a ``mul`` op over them, and calls
    ``Program.append_backward`` on the output.
    """
    prog = Program.instance()
    block = prog.global_block()

    mul_x = block.create_parameter(
        dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
    mul_y = block.create_var(
        dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
    mul_out = block.create_var(
        dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
    block.append_op(
        type="mul",
        inputs={"X": [mul_x],
                "Y": mul_y},
        outputs={"Out": [mul_out]},
        attrs={"x_num_col_dims": 1})

    # The structure of the returned mapping is not visible from this test,
    # so only the call itself is exercised here.
    prog.append_backward(mul_out, set())

    # append_backward re-syncs the program; afterwards the Python block
    # must mirror every op held by the C++ block description. Checking it
    # here keeps the guarantee even when plain asserts are stripped (-O).
    self.assertEqual(len(block.ops), len(block.desc.all_ops()))


if __name__ == '__main__':
    unittest.main()