提交 e05e27a7 编写于 作者: Y Yu Yang

Fix bug

上级 dc643a33
...@@ -42,15 +42,23 @@ inline void VectorToRepeated(const std::vector<T> &vec, ...@@ -42,15 +42,23 @@ inline void VectorToRepeated(const std::vector<T> &vec,
class ProgramDescBind; class ProgramDescBind;
class OpDescBind; class OpDescBind;
class BlockDescBind; class BlockDescBind;
class VarDescBind;
class OpDescBind { class VarDescBind {
public: public:
explicit OpDescBind(BlockDescBind *block) : block_(block) {} explicit VarDescBind(const std::string &name) { var_desc_.set_name(name); }
VarDesc *Proto() { return &var_desc_; }
private:
VarDesc var_desc_;
};
operator OpDesc *() { return &op_desc_; } class OpDescBind {
public:
OpDesc *Proto() { return &op_desc_; }
private: private:
BlockDescBind *block_;
OpDesc op_desc_; OpDesc op_desc_;
}; };
...@@ -59,14 +67,28 @@ public: ...@@ -59,14 +67,28 @@ public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {} : prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(const BlockDescBind &o) = delete;
BlockDescBind &operator=(const BlockDescBind &o) = delete;
int32_t id() const { return desc_->idx(); } int32_t id() const { return desc_->idx(); }
int32_t Parent() const { return desc_->parent_idx(); } int32_t Parent() const { return desc_->parent_idx(); }
VarDescBind *NewVar(const std::string &name) {
need_update_ = true;
auto it = vars_.find(name);
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
auto var = new VarDescBind(name);
vars_[name].reset(var);
return var;
}
BlockDescBind *ParentBlock() const;
OpDescBind *AppendOp() { OpDescBind *AppendOp() {
need_update_ = true; need_update_ = true;
ops_.emplace_back(this); ops_.emplace_back(new OpDescBind());
return &ops_.back(); return ops_.back().get();
} }
void Sync() { void Sync() {
...@@ -75,8 +97,9 @@ public: ...@@ -75,8 +97,9 @@ public:
op_field.Clear(); op_field.Clear();
op_field.Reserve(static_cast<int>(ops_.size())); op_field.Reserve(static_cast<int>(ops_.size()));
for (auto &op_desc : ops_) { for (auto &op_desc : ops_) {
op_field.AddAllocated(op_desc); op_field.AddAllocated(op_desc->Proto());
} }
need_update_ = false;
} }
} }
...@@ -85,7 +108,8 @@ private: ...@@ -85,7 +108,8 @@ private:
BlockDesc *desc_; // not_own BlockDesc *desc_; // not_own
bool need_update_; bool need_update_;
std::deque<OpDescBind> ops_; std::deque<std::unique_ptr<OpDescBind>> ops_;
std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
}; };
using ProgDescMap = using ProgDescMap =
...@@ -106,18 +130,20 @@ public: ...@@ -106,18 +130,20 @@ public:
} }
return *ptr; return *ptr;
} }
ProgramDescBind(const ProgramDescBind &o) = delete;
ProgramDescBind &operator=(const ProgramDescBind &o) = delete;
BlockDescBind *AppendBlock(const BlockDescBind &parent) { BlockDescBind *AppendBlock(const BlockDescBind &parent) {
auto *b = prog_->add_blocks(); auto *b = prog_->add_blocks();
b->set_parent_idx(parent.id()); b->set_parent_idx(parent.id());
b->set_idx(prog_->blocks_size() - 1); b->set_idx(prog_->blocks_size() - 1);
blocks_.emplace_back(this, b); blocks_.emplace_back(new BlockDescBind(this, b));
return &blocks_.back(); return blocks_.back().get();
} }
BlockDescBind *Root() { return &blocks_.front(); } BlockDescBind *Root() { return blocks_.front().get(); }
BlockDescBind *Block(size_t idx) { return &blocks_[idx]; } BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
std::string DebugString() { return Proto()->DebugString(); } std::string DebugString() { return Proto()->DebugString(); }
...@@ -125,25 +151,31 @@ public: ...@@ -125,25 +151,31 @@ public:
ProgramDesc *Proto() { ProgramDesc *Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block.Sync(); block->Sync();
} }
return prog_; return prog_;
} }
private: private:
explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) { explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) {
blocks_.reserve(100);
for (auto &block : *prog->mutable_blocks()) { for (auto &block : *prog->mutable_blocks()) {
blocks_.emplace_back(this, &block); blocks_.emplace_back(new BlockDescBind(this, &block));
} }
} }
// Not owned // Not owned
ProgramDesc *prog_; ProgramDesc *prog_;
std::vector<BlockDescBind> blocks_; std::vector<std::unique_ptr<BlockDescBind>> blocks_;
}; };
BlockDescBind *BlockDescBind::ParentBlock() const {
if (this->desc_->parent_idx() == -1) {
return nullptr;
}
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
}
void BindProgramDesc(py::module &m) { void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "") py::class_<ProgramDescBind>(m, "ProgramDesc", "")
.def_static("instance", .def_static("instance",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册