提交 618884dd 编写于 作者: Y Yu Yang

Complete unittest for ProgramDesc

上级 65bec3be
...@@ -319,17 +319,30 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -319,17 +319,30 @@ All parameter, weight, gradient are variables in Paddle.
.def_static("instance", .def_static("instance",
[] { return &GetProgramDesc(); }, [] { return &GetProgramDesc(); },
py::return_value_policy::reference) py::return_value_policy::reference)
.def_static("__create_program_desc__",
[] {
// Only used for unit-test
auto *prog_desc = new ProgramDesc;
auto *block = prog_desc->mutable_blocks()->Add();
block->set_idx(0);
block->set_parent_idx(-1);
return prog_desc;
})
.def("append_block", .def("append_block",
[](ProgramDesc &self, BlockDesc &parent) { [](ProgramDesc &self, BlockDesc &parent) {
auto desc = self.mutable_blocks()->Add(); auto desc = self.add_blocks();
desc->set_idx(self.mutable_blocks()->size() - 1); desc->set_idx(self.mutable_blocks()->size() - 1);
desc->set_parent_idx(parent.idx()); desc->set_parent_idx(parent.idx());
return desc; return desc;
}) },
py::return_value_policy::reference)
.def("root_block", .def("root_block",
[](ProgramDesc &self) { return self.mutable_blocks()[0]; }); [](ProgramDesc &self) { return self.mutable_blocks()->Mutable(0); },
py::return_value_policy::reference)
.def("__str__", [](ProgramDesc &self) { return self.DebugString(); });
py::class_<BlockDesc>(m, "BlockDesc", "") py::class_<BlockDesc>(m, "BlockDesc", "")
.def("idx", [](BlockDesc &self) { return self.idx(); }) .def("id", [](BlockDesc &self) { return self.idx(); })
.def("parent", [](BlockDesc &self) { return self.parent_idx(); }) .def("parent", [](BlockDesc &self) { return self.parent_idx(); })
.def("append_op", .def("append_op",
[](BlockDesc &self) { return self.mutable_ops()->Add(); }); [](BlockDesc &self) { return self.mutable_ops()->Add(); });
......
...@@ -9,8 +9,17 @@ class TestProgramDesc(unittest.TestCase): ...@@ -9,8 +9,17 @@ class TestProgramDesc(unittest.TestCase):
del program_desc del program_desc
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc.instance()
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
self.assertIsNotNone(program_desc.root_block())
del program_desc del program_desc
def test_append_block(self):
prog_desc = core.ProgramDesc.__create_program_desc__()
self.assertIsNotNone(prog_desc)
block1 = prog_desc.append_block(prog_desc.root_block())
block2 = prog_desc.append_block(block1)
self.assertEqual(block1.id(), block2.parent())
self.assertEqual(prog_desc.root_block().id(), block1.parent())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册