提交 b4a4ae1b 编写于 作者: Y Yu Yang

Add comments

上级 9e5de167
...@@ -85,6 +85,7 @@ namespace pybind { ...@@ -85,6 +85,7 @@ namespace pybind {
using namespace paddle::framework; // NOLINT using namespace paddle::framework; // NOLINT
// convert between std::vector and protobuf repeated.
template <typename T> template <typename T>
inline std::vector<T> RepeatedToVector( inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T> &repeated_field) { const google::protobuf::RepeatedField<T> &repeated_field) {
...@@ -104,6 +105,7 @@ inline void VectorToRepeated(const std::vector<T> &vec, ...@@ -104,6 +105,7 @@ inline void VectorToRepeated(const std::vector<T> &vec,
} }
} }
// Specialize vector<bool>.
template <typename RepeatedField> template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec, inline void VectorToRepeated(const std::vector<bool> &vec,
RepeatedField *repeated_field) { RepeatedField *repeated_field) {
...@@ -118,13 +120,16 @@ class OpDescBind; ...@@ -118,13 +120,16 @@ class OpDescBind;
class BlockDescBind; class BlockDescBind;
class VarDescBind; class VarDescBind;
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
// read/write speed. Only when we want the protobuf message, the local changes
// will be synchronized (by `Sync` method).
class VarDescBind { class VarDescBind {
public: public:
explicit VarDescBind(const std::string &name) { desc_.set_name(name); } explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
VarDesc *Proto() { return &desc_; } VarDesc *Proto() { return &desc_; }
py::bytes Name() { return desc_.name(); } py::bytes Name() const { return desc_.name(); }
void SetShape(const std::vector<int64_t> &dims) { void SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
...@@ -134,11 +139,13 @@ public: ...@@ -134,11 +139,13 @@ public:
desc_.mutable_lod_tensor()->set_data_type(data_type); desc_.mutable_lod_tensor()->set_data_type(data_type);
} }
std::vector<int64_t> Shape() { std::vector<int64_t> Shape() const {
return RepeatedToVector(desc_.lod_tensor().dims()); return RepeatedToVector(desc_.lod_tensor().dims());
} }
framework::DataType DataType() { return desc_.lod_tensor().data_type(); } framework::DataType DataType() const {
return desc_.lod_tensor().data_type();
}
private: private:
VarDesc desc_; VarDesc desc_;
...@@ -283,16 +290,16 @@ public: ...@@ -283,16 +290,16 @@ public:
void SetBlockAttr(const std::string &name, BlockDescBind &block); void SetBlockAttr(const std::string &name, BlockDescBind &block);
int GetBlockAttr(const std::string &name) const { Attribute GetAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return boost::get<BlockDesc *>(it->second)->idx(); return it->second;
} }
Attribute GetAttr(const std::string &name) const { int GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return it->second; return boost::get<BlockDesc *>(it->second)->idx();
} }
private: private:
...@@ -312,7 +319,7 @@ public: ...@@ -312,7 +319,7 @@ public:
BlockDescBind(const BlockDescBind &o) = delete; BlockDescBind(const BlockDescBind &o) = delete;
BlockDescBind &operator=(const BlockDescBind &o) = delete; BlockDescBind &operator=(const BlockDescBind &o) = delete;
int32_t id() const { return desc_->idx(); } int32_t ID() const { return desc_->idx(); }
int32_t Parent() const { return desc_->parent_idx(); } int32_t Parent() const { return desc_->parent_idx(); }
...@@ -410,7 +417,7 @@ public: ...@@ -410,7 +417,7 @@ public:
BlockDescBind *AppendBlock(const BlockDescBind &parent) { BlockDescBind *AppendBlock(const BlockDescBind &parent) {
auto *b = prog_->add_blocks(); auto *b = prog_->add_blocks();
b->set_parent_idx(parent.id()); b->set_parent_idx(parent.ID());
b->set_idx(prog_->blocks_size() - 1); b->set_idx(prog_->blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b)); blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get(); return blocks_.back().get();
...@@ -454,6 +461,7 @@ void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { ...@@ -454,6 +461,7 @@ void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
this->attrs_[name] = desc; this->attrs_[name] = desc;
} }
// Bind Methods
void BindProgramDesc(py::module &m) { void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "") py::class_<ProgramDescBind>(m, "ProgramDesc", "")
.def_static("instance", .def_static("instance",
...@@ -481,7 +489,7 @@ void BindProgramDesc(py::module &m) { ...@@ -481,7 +489,7 @@ void BindProgramDesc(py::module &m) {
void BindBlockDesc(py::module &m) { void BindBlockDesc(py::module &m) {
py::class_<BlockDescBind>(m, "BlockDesc", "") py::class_<BlockDescBind>(m, "BlockDesc", "")
.def_property_readonly("id", &BlockDescBind::id) .def_property_readonly("id", &BlockDescBind::ID)
.def_property_readonly("parent", &BlockDescBind::Parent) .def_property_readonly("parent", &BlockDescBind::Parent)
.def("append_op", .def("append_op",
&BlockDescBind::AppendOp, &BlockDescBind::AppendOp,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册