提交 e05e27a7 编写于 作者: Y Yu Yang

Fix bug

上级 dc643a33
......@@ -42,15 +42,23 @@ inline void VectorToRepeated(const std::vector<T> &vec,
class ProgramDescBind;
class OpDescBind;
class BlockDescBind;
class VarDescBind;
class OpDescBind {
class VarDescBind {
public:
explicit OpDescBind(BlockDescBind *block) : block_(block) {}
explicit VarDescBind(const std::string &name) { var_desc_.set_name(name); }
VarDesc *Proto() { return &var_desc_; }
private:
VarDesc var_desc_;
};
operator OpDesc *() { return &op_desc_; }
class OpDescBind {
public:
OpDesc *Proto() { return &op_desc_; }
private:
BlockDescBind *block_;
OpDesc op_desc_;
};
......@@ -59,14 +67,28 @@ public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(const BlockDescBind &o) = delete;
BlockDescBind &operator=(const BlockDescBind &o) = delete;
int32_t id() const { return desc_->idx(); }
int32_t Parent() const { return desc_->parent_idx(); }
VarDescBind *NewVar(const std::string &name) {
need_update_ = true;
auto it = vars_.find(name);
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
auto var = new VarDescBind(name);
vars_[name].reset(var);
return var;
}
BlockDescBind *ParentBlock() const;
OpDescBind *AppendOp() {
need_update_ = true;
ops_.emplace_back(this);
return &ops_.back();
ops_.emplace_back(new OpDescBind());
return ops_.back().get();
}
void Sync() {
......@@ -75,8 +97,9 @@ public:
op_field.Clear();
op_field.Reserve(static_cast<int>(ops_.size()));
for (auto &op_desc : ops_) {
op_field.AddAllocated(op_desc);
op_field.AddAllocated(op_desc->Proto());
}
need_update_ = false;
}
}
......@@ -85,7 +108,8 @@ private:
BlockDesc *desc_; // not_own
bool need_update_;
std::deque<OpDescBind> ops_;
std::deque<std::unique_ptr<OpDescBind>> ops_;
std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
};
using ProgDescMap =
......@@ -106,18 +130,20 @@ public:
}
return *ptr;
}
ProgramDescBind(const ProgramDescBind &o) = delete;
ProgramDescBind &operator=(const ProgramDescBind &o) = delete;
BlockDescBind *AppendBlock(const BlockDescBind &parent) {
auto *b = prog_->add_blocks();
b->set_parent_idx(parent.id());
b->set_idx(prog_->blocks_size() - 1);
blocks_.emplace_back(this, b);
return &blocks_.back();
blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get();
}
BlockDescBind *Root() { return &blocks_.front(); }
BlockDescBind *Root() { return blocks_.front().get(); }
BlockDescBind *Block(size_t idx) { return &blocks_[idx]; }
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
std::string DebugString() { return Proto()->DebugString(); }
......@@ -125,25 +151,31 @@ public:
ProgramDesc *Proto() {
for (auto &block : blocks_) {
block.Sync();
block->Sync();
}
return prog_;
}
private:
explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) {
blocks_.reserve(100);
for (auto &block : *prog->mutable_blocks()) {
blocks_.emplace_back(this, &block);
blocks_.emplace_back(new BlockDescBind(this, &block));
}
}
// Not owned
ProgramDesc *prog_;
std::vector<BlockDescBind> blocks_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_;
};
BlockDescBind *BlockDescBind::ParentBlock() const {
if (this->desc_->parent_idx() == -1) {
return nullptr;
}
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
}
void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "")
.def_static("instance",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册