From 618884dd69af0f2e7ea7c0527ec2ba8131ec5a07 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Thu, 21 Sep 2017 21:39:42 -0700 Subject: [PATCH] Complete unittest for ProgramDesc --- paddle/pybind/pybind.cc | 21 +++++++++++++++---- .../v2/framework/tests/test_protobuf_descs.py | 9 ++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index cbb7b1cbff9..cae36713509 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -319,17 +319,30 @@ All parameter, weight, gradient are variables in Paddle. .def_static("instance", [] { return &GetProgramDesc(); }, py::return_value_policy::reference) + .def_static("__create_program_desc__", + [] { + // Only used for unit-test + auto *prog_desc = new ProgramDesc; + auto *block = prog_desc->mutable_blocks()->Add(); + block->set_idx(0); + block->set_parent_idx(-1); + return prog_desc; + }) .def("append_block", [](ProgramDesc &self, BlockDesc &parent) { - auto desc = self.mutable_blocks()->Add(); + auto desc = self.add_blocks(); desc->set_idx(self.mutable_blocks()->size() - 1); desc->set_parent_idx(parent.idx()); return desc; - }) + }, + py::return_value_policy::reference) .def("root_block", - [](ProgramDesc &self) { return self.mutable_blocks()[0]; }); + [](ProgramDesc &self) { return self.mutable_blocks()->Mutable(0); }, + py::return_value_policy::reference) + .def("__str__", [](ProgramDesc &self) { return self.DebugString(); }); + py::class_(m, "BlockDesc", "") - .def("idx", [](BlockDesc &self) { return self.idx(); }) + .def("id", [](BlockDesc &self) { return self.idx(); }) .def("parent", [](BlockDesc &self) { return self.parent_idx(); }) .def("append_op", [](BlockDesc &self) { return self.mutable_ops()->Add(); }); diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 945610ff452..8e948436626 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -9,8 +9,17 @@ class TestProgramDesc(unittest.TestCase): del program_desc program_desc = core.ProgramDesc.instance() self.assertIsNotNone(program_desc) + self.assertIsNotNone(program_desc.root_block()) del program_desc + def test_append_block(self): + prog_desc = core.ProgramDesc.__create_program_desc__() + self.assertIsNotNone(prog_desc) + block1 = prog_desc.append_block(prog_desc.root_block()) + block2 = prog_desc.append_block(block1) + self.assertEqual(block1.id(), block2.parent()) + self.assertEqual(prog_desc.root_block().id(), block1.parent()) + if __name__ == '__main__': unittest.main() -- GitLab