diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 47b3c43ebfccd0d26ac5a44d10911018895ba353..bfbe177e8f2a95c8835d32077d43819ef1ad00c5 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -15,9 +15,10 @@ limitations under the License. */ #include "paddle/pybind/protobuf.h" namespace paddle { -namespace framework { +namespace pybind { void BindProgramDesc(py::module &m) { + using namespace paddle::framework; // NOLINT py::class_(m, "ProgramDesc", "") .def_static("instance", [] { return &GetProgramDesc(); }, @@ -42,10 +43,14 @@ void BindProgramDesc(py::module &m) { .def("root_block", [](ProgramDesc &self) { return self.mutable_blocks()->Mutable(0); }, py::return_value_policy::reference) + .def("block", + [](ProgramDesc &self, int id) { return self.blocks(id); }, + py::return_value_policy::reference) .def("__str__", [](ProgramDesc &self) { return self.DebugString(); }); } void BindBlockDesc(py::module &m) { + using namespace paddle::framework; // NOLINT py::class_(m, "BlockDesc", "") .def("id", [](BlockDesc &self) { return self.idx(); }) .def("parent", [](BlockDesc &self) { return self.parent_idx(); }) @@ -58,6 +63,7 @@ void BindBlockDesc(py::module &m) { } void BindVarDsec(py::module &m) { + using namespace paddle::framework; // NOLINT py::class_(m, "VarDesc", "") .def(py::init<>()) .def("set_name", @@ -86,6 +92,7 @@ void BindVarDsec(py::module &m) { } void BindOpDesc(py::module &m) { + using namespace paddle::framework; // NOLINT auto op_desc_set_var = [](OpDesc::Var *var, const std::string ¶meter, const std::vector &arguments) { @@ -132,5 +139,5 @@ void BindOpDesc(py::module &m) { op_desc_set_attr(self, name)->set_i(i); }); } -} // namespace framework +} // namespace pybind } // namespace paddle diff --git a/paddle/pybind/protobuf.h b/paddle/pybind/protobuf.h index a32acfb03826e8cf0161f3a001e05d6b5a76ac2b..de9a008e25d1e7b076eabea4f83fc71dfca881da 100644 --- a/paddle/pybind/protobuf.h +++ b/paddle/pybind/protobuf.h @@ -25,7 +25,7 @@ limitations under the License. */ namespace py = pybind11; namespace paddle { -namespace framework { +namespace pybind { template inline std::vector RepeatedToVector( @@ -50,5 +50,5 @@ void BindProgramDesc(py::module& m); void BindBlockDesc(py::module& m); void BindVarDsec(py::module& m); void BindOpDesc(py::module& m); -} // namespace framework +} // namespace pybind } // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index 71bdca8765353416cc7450251a9cb9a0d74fdb26..d0192814ef2396aac6767e945e87a6d6e3953a8e 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -23,6 +23,7 @@ class TestProgramDesc(unittest.TestCase): self.assertEqual(block_root.id(), block1.parent()) block3 = prog_desc.append_block(block_root) self.assertEqual(block3.parent(), block_root.id()) + self.assertEqual(prog_desc.block(1).id(), 1) class TestVarDesc(unittest.TestCase):