提交 88b9202c 编写于 作者: Q Qiao Longfei 提交者: GitHub

Python cpp sync (#4816)

* add sync_with_cpp to Python Program and Block

* sync vars and ops in block from cpp

* optimize code and add some comment

* add more check for sync
上级 6729f32c
...@@ -308,6 +308,9 @@ class Block(object): ...@@ -308,6 +308,9 @@ class Block(object):
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
return Variable(self, *args, **kwargs) return Variable(self, *args, **kwargs)
def has_var(self, name):
return name in self.vars
def create_parameter(self, *args, **kwargs): def create_parameter(self, *args, **kwargs):
global_block = self.program.global_block() global_block = self.program.global_block()
return Parameter(global_block, *args, **kwargs) return Parameter(global_block, *args, **kwargs)
...@@ -324,6 +327,43 @@ class Block(object): ...@@ -324,6 +327,43 @@ class Block(object):
self.ops.appendleft(op) self.ops.appendleft(op)
return op return op
def sync_with_cpp(self):
# sync variables from cpp
for var in self.desc.all_vars():
if not self.has_var(var.name()):
self.create_var(name=var.name(), desc=var, type=var.type())
# sync operators from cpp
ops_in_cpp = self.desc.all_ops()
first_op_in_python = self.ops[0].desc
last_op_in_python = self.ops[len(self.ops) - 1].desc
start_index = None
end_index = None
for index in range(len(ops_in_cpp)):
if first_op_in_python == ops_in_cpp[index]:
start_index = index
if last_op_in_python == ops_in_cpp[index]:
end_index = index
assert start_index is not None
assert end_index is not None
assert start_index <= end_index
# sync ops append to the head of cpp_ops
for index in range((start_index - 1 - 1), -1, -1):
op_desc = ops_in_cpp[index]
op = Operator(self, op_desc)
self.ops.appendleft(op)
# sync ops append to the end of cpp_ops
for index in range((end_index + 1), len(ops_in_cpp)):
op_desc = ops_in_cpp[index]
op = Operator(self, op_desc)
self.ops.append(op)
assert len(self.ops) == len(ops_in_cpp)
for index in range(len(self.ops)):
assert self.ops[index].desc == ops_in_cpp[index]
class Program(object): class Program(object):
@classmethod @classmethod
...@@ -354,6 +394,12 @@ class Program(object): ...@@ -354,6 +394,12 @@ class Program(object):
def current_block(self): def current_block(self):
return self.blocks[self.current_block_idx] return self.blocks[self.current_block_idx]
def append_backward(self, target, no_grad_set):
assert isinstance(target, Variable)
param_to_grad_info = self.desc.append_backward(target.desc, no_grad_set)
self.sync_with_cpp()
return param_to_grad_info
def create_block(self): def create_block(self):
new_block_idx = len(self.blocks) new_block_idx = len(self.blocks)
self.desc.append_block(self.current_block().desc) self.desc.append_block(self.current_block().desc)
...@@ -364,6 +410,12 @@ class Program(object): ...@@ -364,6 +410,12 @@ class Program(object):
def rollback(self): def rollback(self):
self.current_block_idx = self.current_block().parent_idx self.current_block_idx = self.current_block().parent_idx
def sync_with_cpp(self):
for block_idx in range(len(self.blocks), self.desc.num_blocks()):
self.blocks.append(Block(self, block_idx))
for block in self.blocks:
block.sync_with_cpp()
class Parameter(Variable): class Parameter(Variable):
def __init__(self, block, shape, dtype, **kwargs): def __init__(self, block, shape, dtype, **kwargs):
......
import unittest import unittest
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from paddle.v2.framework.framework import Program
from paddle.v2.framework.framework import g_program from paddle.v2.framework.framework import g_program
...@@ -33,7 +34,7 @@ class TestProgram(unittest.TestCase): ...@@ -33,7 +34,7 @@ class TestProgram(unittest.TestCase):
self.assertEqual(1, b.idx) self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx) self.assertEqual(0, b.parent_idx)
def test_append_backward(self): def test_desc_append_backward(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
...@@ -71,6 +72,24 @@ class TestProgram(unittest.TestCase): ...@@ -71,6 +72,24 @@ class TestProgram(unittest.TestCase):
actual_ops.append(op.type()) actual_ops.append(op.type())
self.assertEqual(actual_ops, expect_ops) self.assertEqual(actual_ops, expect_ops)
def test_append_backward(self):
prog = Program.instance()
block = prog.global_block()
mul_x = block.create_parameter(
dtype="float32", shape=[5, 10], lod_level=0, name="mul.x")
mul_y = block.create_var(
dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
mul_out = block.create_var(
dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
mul_op = block.append_op(
type="mul",
inputs={"X": [mul_x],
"Y": mul_y},
outputs={"Out": [mul_out]},
attrs={"x_num_col_dims": 1})
param_to_grad = prog.append_backward(mul_out, set())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册