diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index 159ed03b92bbc57ab79734de832845ef1f367de9..0a305e8a8ccdd58565f1d218c5f562c1507791b3 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -24,6 +24,9 @@ static ProgramDesc* g_program_desc = nullptr; ProgramDesc& GetProgramDesc() { if (g_program_desc == nullptr) { g_program_desc = new ProgramDesc(); + auto root_block = g_program_desc->mutable_blocks()->Add(); + root_block->set_idx(0); + root_block->set_parent_idx(-1); } return *g_program_desc; } diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 90b995decb9a6b5b684efc44abe8fc280ed73adc..835ea85aa117f07f2ffd8fd317318e976b9f011d 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -316,21 +316,75 @@ All parameter, weight, gradient are variables in Paddle. m.def("is_compile_gpu", IsCompileGPU); py::class_(m, "ProgramDesc", "") - .def_static("instance", [] { return &GetProgramDesc(); }) - .def("append_block", [](ProgramDesc &self) { - auto desc = self.mutable_blocks()->Add(); - desc->set_idx(self.mutable_blocks()->size() - 1); - return desc; - }); + .def_static("instance", + [] { return &GetProgramDesc(); }, + py::return_value_policy::reference) + .def("append_block", + [](ProgramDesc &self, BlockDesc &parent) { + auto desc = self.mutable_blocks()->Add(); + desc->set_idx(self.mutable_blocks()->size() - 1); + desc->set_parent_idx(parent.idx()); + return desc; + }) + .def("root_block", + [](ProgramDesc &self) { return self.mutable_blocks()[0]; }); py::class_(m, "BlockDesc", "") .def("idx", [](BlockDesc &self) { return self.idx(); }) - .def("set_parent", - [](BlockDesc &self, int32_t idx) { self.set_parent_idx(idx); }) - .def("parent", [](BlockDesc &self) { return self.parent_idx(); }); + .def("parent", [](BlockDesc &self) { return self.parent_idx(); }) + .def("append_op", + [](BlockDesc &self) { return self.mutable_ops()->Add(); }); py::class_(m, "VarDesc", ""); - py::class_(m, "OpDesc", ""); + auto op_desc_set_var = [](OpDesc::Var *var, + const std::string ¶meter, + const std::vector &arguments) { + var->set_parameter(parameter); + auto args = var->mutable_arguments(); + args->Reserve(static_cast(arguments.size())); + for (auto &arg : arguments) { + *args->Add() = arg; + } + }; + + auto op_desc_set_attr = [](OpDesc &desc, const std::string &name) { + auto attr = desc.mutable_attrs()->Add(); + attr->set_name(name); + return attr; + }; + + py::class_(m, "OpDesc", "") + .def("type", [](OpDesc &op) { return op.type(); }) + .def("set_input", + [op_desc_set_var](OpDesc &self, + const std::string ¶meter, + const std::vector &arguments) { + auto ipt = self.mutable_inputs()->Add(); + op_desc_set_var(ipt, parameter, arguments); + }) + .def("input_names", + [](OpDesc &self) { + std::vector ret_val; + ret_val.reserve(static_cast(self.inputs().size())); + std::transform( + self.inputs().begin(), + self.inputs().end(), + std::back_inserter(ret_val), + [](const OpDesc::Var &var) { return var.parameter(); }); + return ret_val; + }) + .def("__str__", [](OpDesc &self) { return self.DebugString(); }) + .def("set_output", + [op_desc_set_var](OpDesc &self, + const std::string ¶meter, + const std::vector &arguments) { + auto opt = self.mutable_outputs()->Add(); + op_desc_set_var(opt, parameter, arguments); + }) + .def("set_attr", + [op_desc_set_attr](OpDesc &self, const std::string &name, int i) { + op_desc_set_attr(self, name)->set_i(i); + }); return m.ptr(); } diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py new file mode 100644 index 0000000000000000000000000000000000000000..945610ff4522311ddf36c1ef1a48550946667ed7 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -0,0 +1,16 @@ +import unittest +import paddle.v2.framework.core as core + + +class TestProgramDesc(unittest.TestCase): + def test_instance(self): + program_desc = core.ProgramDesc.instance() + self.assertIsNotNone(program_desc) + del program_desc + program_desc = core.ProgramDesc.instance() + self.assertIsNotNone(program_desc) + del program_desc + + +if __name__ == '__main__': + unittest.main()