diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 8e94843662639aa377616217cb15b4369f1ae4a5..71bdca8765353416cc7450251a9cb9a0d74fdb26 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -15,10 +15,25 @@ class TestProgramDesc(unittest.TestCase): def test_append_block(self): prog_desc = core.ProgramDesc.__create_program_desc__() self.assertIsNotNone(prog_desc) - block1 = prog_desc.append_block(prog_desc.root_block()) + block_root = prog_desc.root_block() + self.assertEqual(block_root.id(), 0) + block1 = prog_desc.append_block(block_root) block2 = prog_desc.append_block(block1) self.assertEqual(block1.id(), block2.parent()) - self.assertEqual(prog_desc.root_block().id(), block1.parent()) + self.assertEqual(block_root.id(), block1.parent()) + block3 = prog_desc.append_block(block_root) + self.assertEqual(block3.parent(), block_root.id()) + + +class TestVarDesc(unittest.TestCase): + def test_shape(self): + program_desc = core.ProgramDesc.instance() + block = program_desc.root_block() + var = block.new_var() + src_shape = [3, 2, 10, 8] + var.set_shape(src_shape) + res_shape = var.shape() + self.assertEqual(src_shape, res_shape) if __name__ == '__main__':