From ee547f6ac984b8880394acceb6fbec856f6a2dde Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Fri, 22 Sep 2017 10:37:34 -0700 Subject: [PATCH] Add unittests --- .../v2/framework/tests/test_protobuf_descs.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 8e94843662..71bdca8765 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -15,10 +15,25 @@ class TestProgramDesc(unittest.TestCase): def test_append_block(self): prog_desc = core.ProgramDesc.__create_program_desc__() self.assertIsNotNone(prog_desc) - block1 = prog_desc.append_block(prog_desc.root_block()) + block_root = prog_desc.root_block() + self.assertEqual(block_root.id(), 0) + block1 = prog_desc.append_block(block_root) block2 = prog_desc.append_block(block1) self.assertEqual(block1.id(), block2.parent()) - self.assertEqual(prog_desc.root_block().id(), block1.parent()) + self.assertEqual(block_root.id(), block1.parent()) + block3 = prog_desc.append_block(block_root) + self.assertEqual(block3.parent(), block_root.id()) + + +class TestVarDesc(unittest.TestCase): + def test_shape(self): + program_desc = core.ProgramDesc.instance() + block = program_desc.root_block() + var = block.new_var() + src_shape = [3, 2, 10, 8] + var.set_shape(src_shape) + res_shape = var.shape() + self.assertEqual(src_shape, res_shape) if __name__ == '__main__': -- GitLab