diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 3388b5cfdc0640237a3ef586535ccf1812845c4e..0a6020d6492f0f6db2d4d9b4769a4dcd2117391e 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -85,6 +85,7 @@ namespace pybind { using namespace paddle::framework; // NOLINT +// convert between std::vector and protobuf repeated. template inline std::vector RepeatedToVector( const google::protobuf::RepeatedField &repeated_field) { @@ -104,6 +105,7 @@ inline void VectorToRepeated(const std::vector &vec, } } +// Specialize vector. template inline void VectorToRepeated(const std::vector &vec, RepeatedField *repeated_field) { @@ -118,13 +120,16 @@ class OpDescBind; class BlockDescBind; class VarDescBind; +// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize +// read/write speed. Only when we want the protobuf message, the local changes +// will be synchronized (by `Sync` method). class VarDescBind { public: explicit VarDescBind(const std::string &name) { desc_.set_name(name); } VarDesc *Proto() { return &desc_; } - py::bytes Name() { return desc_.name(); } + py::bytes Name() const { return desc_.name(); } void SetShape(const std::vector &dims) { VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); @@ -134,11 +139,13 @@ public: desc_.mutable_lod_tensor()->set_data_type(data_type); } - std::vector Shape() { + std::vector Shape() const { return RepeatedToVector(desc_.lod_tensor().dims()); } - framework::DataType DataType() { return desc_.lod_tensor().data_type(); } + framework::DataType DataType() const { + return desc_.lod_tensor().data_type(); + } private: VarDesc desc_; @@ -283,16 +290,16 @@ public: void SetBlockAttr(const std::string &name, BlockDescBind &block); - int GetBlockAttr(const std::string &name) const { + Attribute GetAttr(const std::string &name) const { auto it = attrs_.find(name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); - return boost::get(it->second)->idx(); + return it->second; } - Attribute GetAttr(const std::string &name) const { + int GetBlockAttr(const std::string &name) const { auto it = attrs_.find(name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); - return it->second; + return boost::get(it->second)->idx(); } private: @@ -312,7 +319,7 @@ public: BlockDescBind(const BlockDescBind &o) = delete; BlockDescBind &operator=(const BlockDescBind &o) = delete; - int32_t id() const { return desc_->idx(); } + int32_t ID() const { return desc_->idx(); } int32_t Parent() const { return desc_->parent_idx(); } @@ -410,7 +417,7 @@ public: BlockDescBind *AppendBlock(const BlockDescBind &parent) { auto *b = prog_->add_blocks(); - b->set_parent_idx(parent.id()); + b->set_parent_idx(parent.ID()); b->set_idx(prog_->blocks_size() - 1); blocks_.emplace_back(new BlockDescBind(this, b)); return blocks_.back().get(); @@ -454,6 +461,7 @@ void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { this->attrs_[name] = desc; } +// Bind Methods void BindProgramDesc(py::module &m) { py::class_(m, "ProgramDesc", "") .def_static("instance", @@ -481,7 +489,7 @@ void BindProgramDesc(py::module &m) { void BindBlockDesc(py::module &m) { py::class_(m, "BlockDesc", "") - .def_property_readonly("id", &BlockDescBind::id) + .def_property_readonly("id", &BlockDescBind::ID) .def_property_readonly("parent", &BlockDescBind::Parent) .def("append_op", &BlockDescBind::AppendOp,