diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index b86185bf5bf9589ac207141ab6c6eb1e4b8e4696..b4ed9c4335ab2a293b62fb523d9627046db42f19 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -13,10 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/pybind/protobuf.h" +#include namespace paddle { namespace pybind { +using namespace paddle::framework; // NOLINT + template inline std::vector RepeatedToVector( const google::protobuf::RepeatedField &repeated_field) { @@ -36,45 +39,154 @@ inline void VectorToRepeated(const std::vector &vec, } } +class ProgramDescBind; +class OpDescBind; +class BlockDescBind; + +class OpDescBind { +public: + explicit OpDescBind(BlockDescBind *block) : block_(block) {} + + operator OpDesc *() { return &op_desc_; } + +private: + BlockDescBind *block_; + OpDesc op_desc_; +}; + +class BlockDescBind { +public: + BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) + : prog_(prog), desc_(desc), need_update_(false) {} + + ~BlockDescBind() { + std::cerr << "dtor " << this << "," << desc_ << std::endl; + } + + int32_t id() const { + std::cerr << "desc ptr " << desc_ << std::endl; + return desc_->idx(); + } + + int32_t Parent() const { return desc_->parent_idx(); } + + OpDescBind *AppendOp() { + need_update_ = true; + ops_.emplace_back(this); + return &ops_.back(); + } + + void Sync() { + if (need_update_) { + auto &op_field = *this->desc_->mutable_ops(); + op_field.Clear(); + op_field.Reserve(static_cast(ops_.size())); + for (auto &op_desc : ops_) { + op_field.AddAllocated(op_desc); + } + } + } + +private: + ProgramDescBind *prog_; // not_own + BlockDesc *desc_; // not_own + bool need_update_; + + std::deque ops_; +}; + +using ProgDescMap = + std::unordered_map>; +static ProgDescMap *g_bind_map = nullptr; + +class ProgramDescBind { +public: + static ProgramDescBind &Instance(ProgramDesc *prog) { + if (g_bind_map == nullptr) { + g_bind_map = new ProgDescMap(); + } + auto &map = *g_bind_map; + auto &ptr = map[prog]; + + if (ptr == nullptr) { + ptr.reset(new ProgramDescBind(prog)); + } + return *ptr; + } + + BlockDescBind *AppendBlock(BlockDescBind *parent) { + auto *b = prog_->add_blocks(); + std::cerr << "block ptr " << b << std::endl; + std::cerr << "pass ptr " << parent << std::endl; + b->set_parent_idx(parent->id()); + b->set_idx(prog_->blocks_size() - 1); + blocks_.emplace_back(this, b); + return &blocks_.back(); + } + + BlockDescBind *Root() { return &blocks_.front(); } + + BlockDescBind *Block(size_t idx) { return &blocks_[idx]; } + + std::string DebugString() { return Proto()->DebugString(); } + + size_t Size() const { return blocks_.size(); } + + ProgramDesc *Proto() { + for (auto &block : blocks_) { + block.Sync(); + } + return prog_; + } + +private: + explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) { + for (auto &block : *prog->mutable_blocks()) { + blocks_.emplace_back(this, &block); + } + } + + // Not owned + ProgramDesc *prog_; + + std::vector blocks_; +}; + void BindProgramDesc(py::module &m) { - using namespace paddle::framework; // NOLINT - py::class_(m, "ProgramDesc", "") + py::class_(m, "ProgramDesc", "") .def_static("instance", - [] { return &GetProgramDesc(); }, + []() -> ProgramDescBind * { + return &ProgramDescBind::Instance(&GetProgramDesc()); + }, py::return_value_policy::reference) .def_static("__create_program_desc__", - [] { + []() -> ProgramDescBind * { // Only used for unit-test auto *prog_desc = new ProgramDesc; auto *block = prog_desc->mutable_blocks()->Add(); block->set_idx(0); block->set_parent_idx(-1); - return prog_desc; - }) + return &ProgramDescBind::Instance(prog_desc); + }, + py::return_value_policy::reference) .def("append_block", - [](ProgramDesc &self, BlockDesc &parent) { - auto desc = self.add_blocks(); - desc->set_idx(self.mutable_blocks()->size() - 1); - desc->set_parent_idx(parent.idx()); - return desc; - }, + &ProgramDescBind::AppendBlock, py::return_value_policy::reference) .def("root_block", - [](ProgramDesc &self) { return self.mutable_blocks()->Mutable(0); }, + &ProgramDescBind::Root, py::return_value_policy::reference) - .def("block", - [](ProgramDesc &self, int id) { return self.blocks(id); }, - py::return_value_policy::reference) - .def("__str__", [](ProgramDesc &self) { return self.DebugString(); }); + .def("block", &ProgramDescBind::Block, py::return_value_policy::reference) + .def("__str__", &ProgramDescBind::DebugString) + .def("num_blocks", &ProgramDescBind::Size); } void BindBlockDesc(py::module &m) { using namespace paddle::framework; // NOLINT - py::class_(m, "BlockDesc", "") - .def("id", [](BlockDesc &self) { return self.idx(); }) - .def("parent", [](BlockDesc &self) { return self.parent_idx(); }) + py::class_(m, "BlockDesc", "") + .def_property_readonly("id", &BlockDescBind::id) + .def_property_readonly("parent", &BlockDescBind::Parent) .def("append_op", - [](BlockDesc &self) { return self.add_ops(); }, + &BlockDescBind::AppendOp, py::return_value_policy::reference) .def("new_var", [](BlockDesc &self) { return self.add_vars(); }, @@ -82,73 +194,76 @@ void BindBlockDesc(py::module &m) { } void BindVarDsec(py::module &m) { - using namespace paddle::framework; // NOLINT - py::class_(m, "VarDesc", "") - .def(py::init<>()) - .def("set_name", - [](VarDesc &self, const std::string &name) { self.set_name(name); }) - .def("set_shape", - [](VarDesc &self, const std::vector &dims) { - VectorToRepeated(dims, self.mutable_lod_tensor()->mutable_dims()); - }) - .def("set_data_type", - [](VarDesc &self, int type_id) { - LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor(); - lod_tensor_desc->set_data_type(static_cast(type_id)); - }) - .def("shape", [](VarDesc &self) { - const LoDTensorDesc &lod_tensor_desc = self.lod_tensor(); - return RepeatedToVector(lod_tensor_desc.dims()); - }); + py::class_(m, "VarDesc", ""); + // using namespace paddle::framework; // NOLINT + // py::class_(m, "VarDesc", "") + // .def(py::init<>()) + // .def("set_name", + // [](VarDesc &self, const std::string &name) { self.set_name(name); + // }) + // .def("set_shape", + // [](VarDesc &self, const std::vector &dims) { + // VectorToRepeated(dims, + // self.mutable_lod_tensor()->mutable_dims()); + // }) + // .def("set_data_type", + // [](VarDesc &self, int type_id) { + // LoDTensorDesc *lod_tensor_desc = self.mutable_lod_tensor(); + // lod_tensor_desc->set_data_type(static_cast(type_id)); + // }) + // .def("shape", [](VarDesc &self) { + // const LoDTensorDesc &lod_tensor_desc = self.lod_tensor(); + // return RepeatedToVector(lod_tensor_desc.dims()); + // }); } void BindOpDesc(py::module &m) { - using namespace paddle::framework; // NOLINT - auto op_desc_set_var = [](OpDesc::Var *var, - const std::string ¶meter, - const std::vector &arguments) { - var->set_parameter(parameter); - VectorToRepeated(arguments, var->mutable_arguments()); - }; - - auto op_desc_set_attr = [](OpDesc &desc, const std::string &name) { - auto attr = desc.add_attrs(); - attr->set_name(name); - return attr; - }; - - py::class_(m, "OpDesc", "") - .def("type", [](OpDesc &op) { return op.type(); }) - .def("set_input", - [op_desc_set_var](OpDesc &self, - const std::string ¶meter, - const std::vector &arguments) { - auto ipt = self.add_inputs(); - op_desc_set_var(ipt, parameter, arguments); - }) - .def("input_names", - [](OpDesc &self) { - std::vector ret_val; - ret_val.reserve(static_cast(self.inputs().size())); - std::transform( - self.inputs().begin(), - self.inputs().end(), - std::back_inserter(ret_val), - [](const OpDesc::Var &var) { return var.parameter(); }); - return ret_val; - }) - .def("__str__", [](OpDesc &self) { return self.DebugString(); }) - .def("set_output", - [op_desc_set_var](OpDesc &self, - const std::string ¶meter, - const std::vector &arguments) { - auto opt = self.add_outputs(); - op_desc_set_var(opt, parameter, arguments); - }) - .def("set_attr", - [op_desc_set_attr](OpDesc &self, const std::string &name, int i) { - op_desc_set_attr(self, name)->set_i(i); - }); + // auto op_desc_set_var = [](OpDesc::Var *var, + // const std::string ¶meter, + // const std::vector &arguments) { + // var->set_parameter(parameter); + // VectorToRepeated(arguments, var->mutable_arguments()); + // }; + // + // auto op_desc_set_attr = [](OpDesc &desc, const std::string &name) { + // auto attr = desc.add_attrs(); + // attr->set_name(name); + // return attr; + // }; + py::class_(m, "OpDesc", ""); + + // .def("type", [](OpDesc &op) { return op.type(); }) + // .def("set_input", + // [op_desc_set_var](OpDesc &self, + // const std::string ¶meter, + // const std::vector &arguments) { + // auto ipt = self.add_inputs(); + // op_desc_set_var(ipt, parameter, arguments); + // }) + // .def("input_names", + // [](OpDesc &self) { + // std::vector ret_val; + // ret_val.reserve(static_cast(self.inputs().size())); + // std::transform( + // self.inputs().begin(), + // self.inputs().end(), + // std::back_inserter(ret_val), + // [](const OpDesc::Var &var) { return var.parameter(); }); + // return ret_val; + // }) + // .def("__str__", [](OpDesc &self) { return self.DebugString(); }) + // .def("set_output", + // [op_desc_set_var](OpDesc &self, + // const std::string ¶meter, + // const std::vector &arguments) { + // auto opt = self.add_outputs(); + // op_desc_set_var(opt, parameter, arguments); + // }) + // .def("set_attr", + // [op_desc_set_attr](OpDesc &self, const std::string &name, int i) + // { + // op_desc_set_attr(self, name)->set_i(i); + // }); } } // namespace pybind } // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index d0192814ef2396aac6767e945e87a6d6e3953a8e..b5ff2d4c36be7f7cfd0e3a794d7f50f27618ea21 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -9,21 +9,28 @@ class TestProgramDesc(unittest.TestCase): del program_desc program_desc = core.ProgramDesc.instance() self.assertIsNotNone(program_desc) - self.assertIsNotNone(program_desc.root_block()) + self.assertIsNotNone(program_desc.block(0)) del program_desc def test_append_block(self): prog_desc = core.ProgramDesc.__create_program_desc__() self.assertIsNotNone(prog_desc) - block_root = prog_desc.root_block() - self.assertEqual(block_root.id(), 0) + block_root = prog_desc.block(0) + self.assertIsNotNone(block_root) + print 'here' + self.assertEqual(block_root.id, 0) block1 = prog_desc.append_block(block_root) block2 = prog_desc.append_block(block1) - self.assertEqual(block1.id(), block2.parent()) - self.assertEqual(block_root.id(), block1.parent()) + self.assertIsNotNone(block1) + print 'here' + self.assertEqual(block1.id, block2.parent) + print 'here' + self.assertEqual(block_root.id, block1.parent) + print 'here' block3 = prog_desc.append_block(block_root) - self.assertEqual(block3.parent(), block_root.id()) - self.assertEqual(prog_desc.block(1).id(), 1) + self.assertEqual(block3.parent, block_root.id) + self.assertEqual(prog_desc.block(1).id, 1) + self.assertEqual(4, prog_desc.num_blocks()) class TestVarDesc(unittest.TestCase):