From 3c39df197e2fbb0e8666bd8bb20e2a60e5a47d9b Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Mon, 9 Oct 2017 10:30:20 -0700 Subject: [PATCH] Init Python API Following the design * https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/python_api.md Just written `Program`, `Block` and unittest of program. --- python/paddle/v2/framework/graph.py | 45 +++++++++++++++++++ .../paddle/v2/framework/tests/test_program.py | 36 +++++++++++++++ 2 files changed, 81 insertions(+) create mode 100644 python/paddle/v2/framework/graph.py create mode 100644 python/paddle/v2/framework/tests/test_program.py diff --git a/python/paddle/v2/framework/graph.py b/python/paddle/v2/framework/graph.py new file mode 100644 index 000000000..5211b0f16 --- /dev/null +++ b/python/paddle/v2/framework/graph.py @@ -0,0 +1,45 @@ +import paddle.v2.framework.core as core + + +class Block(object): + def __init__(self, program, idx): + self.proto = program.proto.block(idx) + self.vars = dict() # var_name --> var + self.ops = list() # operator list + self.program = program + + @property + def parent_idx(self): + return self.proto.parent + + @property + def idx(self): + return self.proto.id + + +class Program(object): + def __init__(self): + self.proto = core.ProgramDesc.instance() + assert self.proto.num_blocks() == 1 + self.blocks = [Block(self, 0)] + self.current_block_idx = 0 + + def global_block(self): + return self.blocks[0] + + def current_block(self): + return self.blocks[self.current_block_idx] + + def create_block(self): + new_block_idx = len(self.blocks) + self.proto.append_block(self.current_block().proto) + self.current_block_idx = new_block_idx + self.blocks.append(Block(self, self.current_block_idx)) + return self.current_block() + + def rollback(self): + self.current_block_idx = self.current_block().parent_idx + + +# program is a global instance. +g_program = Program() diff --git a/python/paddle/v2/framework/tests/test_program.py b/python/paddle/v2/framework/tests/test_program.py new file mode 100644 index 000000000..b82d1760d --- /dev/null +++ b/python/paddle/v2/framework/tests/test_program.py @@ -0,0 +1,36 @@ +import unittest +from paddle.v2.framework.graph import g_program + + +class TestProgram(unittest.TestCase): + def test_program(self): + b = g_program.current_block() + self.assertEqual(-1, b.parent_idx) + self.assertEqual(0, b.idx) + + b = g_program.create_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + b = g_program.create_block() + self.assertEqual(2, b.idx) + self.assertEqual(1, b.parent_idx) + + g_program.rollback() + + b = g_program.current_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + b = g_program.create_block() + self.assertEqual(3, b.idx) + self.assertEqual(1, b.parent_idx) + + g_program.rollback() + b = g_program.current_block() + self.assertEqual(1, b.idx) + self.assertEqual(0, b.parent_idx) + + +if __name__ == '__main__': + unittest.main() -- GitLab