diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index cbb7b1cbff91ba191b15ddbc69a3b26be319962d..cae36713509b32627e4eecf9a6153d19a2f2b9b5 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -319,17 +319,30 @@ All parameter, weight, gradient are variables in Paddle. .def_static("instance", [] { return &GetProgramDesc(); }, py::return_value_policy::reference) + .def_static("__create_program_desc__", + [] { + // Only used for unit-test + auto *prog_desc = new ProgramDesc; + auto *block = prog_desc->mutable_blocks()->Add(); + block->set_idx(0); + block->set_parent_idx(-1); + return prog_desc; + }) .def("append_block", [](ProgramDesc &self, BlockDesc &parent) { - auto desc = self.mutable_blocks()->Add(); + auto desc = self.add_blocks(); desc->set_idx(self.mutable_blocks()->size() - 1); desc->set_parent_idx(parent.idx()); return desc; - }) + }, + py::return_value_policy::reference) .def("root_block", - [](ProgramDesc &self) { return self.mutable_blocks()[0]; }); + [](ProgramDesc &self) { return self.mutable_blocks()->Mutable(0); }, + py::return_value_policy::reference) + .def("__str__", [](ProgramDesc &self) { return self.DebugString(); }); + py::class_(m, "BlockDesc", "") - .def("idx", [](BlockDesc &self) { return self.idx(); }) + .def("id", [](BlockDesc &self) { return self.idx(); }) .def("parent", [](BlockDesc &self) { return self.parent_idx(); }) .def("append_op", [](BlockDesc &self) { return self.mutable_ops()->Add(); }); diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 945610ff4522311ddf36c1ef1a48550946667ed7..8e94843662639aa377616217cb15b4369f1ae4a5 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -9,8 +9,17 @@ class TestProgramDesc(unittest.TestCase): del program_desc program_desc = core.ProgramDesc.instance() self.assertIsNotNone(program_desc) + self.assertIsNotNone(program_desc.root_block()) del program_desc + def test_append_block(self): + prog_desc = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog_desc) + block1 = prog_desc.append_block(prog_desc.root_block()) + block2 = prog_desc.append_block(block1) + self.assertEqual(block1.id(), block2.parent()) + self.assertEqual(prog_desc.root_block().id(), block1.parent()) + if __name__ == '__main__': unittest.main()