提交 14f83707 编写于 作者: Y Yu Yang

Add block.fwd_block_id

上级 78cc64a5
......@@ -46,11 +46,25 @@ VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
if (name == kEmptyVarName) return nullptr;
auto it = vars_.find(name);
if (it == vars_.end()) {
return Parent() == kNoneBlockIndex ? nullptr
: ParentBlock()->FindVarRecursive(name);
if (it != vars_.end()) {
return it->second.get();
}
return it->second.get();
BlockDesc *tmp = ParentBlock();
if (tmp != nullptr) {
auto ptr = tmp->FindVarRecursive(name);
if (ptr != nullptr) {
return ptr;
}
}
tmp = ForwardBlock();
if (tmp != nullptr) {
return tmp->FindVarRecursive(name);
}
return nullptr;
}
VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
......@@ -136,10 +150,7 @@ void BlockDesc::Flush() {
}
BlockDesc *BlockDesc::ParentBlock() const {
if (this->desc_->parent_idx() == kNoneBlockIndex) {
return nullptr;
}
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
return prog_->MutableBlock(static_cast<size_t>(desc_->parent_idx()));
}
proto::BlockDesc *BlockDesc::Proto() {
......@@ -186,5 +197,16 @@ void BlockDesc::ClearPBVars() {
}
}
void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
PADDLE_ENFORCE(!desc_->has_forward_block_idx(),
"Parent block ID has been set to %d. Cannot set to %d",
desc_->forward_block_idx(), forward_block_id);
desc_->set_forward_block_idx(forward_block_id);
}
BlockDesc *BlockDesc::ForwardBlock() const {
return prog_->MutableBlock(static_cast<size_t>(desc_->forward_block_idx()));
}
} // namespace framework
} // namespace paddle
......@@ -49,6 +49,8 @@ class BlockDesc {
int32_t Parent() const { return desc_->parent_idx(); }
int32_t ForwardBlockID() const { return desc_->forward_block_idx(); }
VarDesc *Var(const std::string &name_bytes);
VarDesc *FindVar(const std::string &name_bytes) const;
......@@ -73,6 +75,10 @@ class BlockDesc {
BlockDesc *ParentBlock() const;
BlockDesc *ForwardBlock() const;
void SetForwardBlockID(int32_t forward_block_id);
OpDesc *AppendOp();
void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
......@@ -91,7 +97,7 @@ class BlockDesc {
proto::BlockDesc *Proto();
ProgramDesc *Program() { return this->prog_; }
ProgramDesc *Program() const { return this->prog_; }
private:
void ClearPBOps();
......
......@@ -158,6 +158,7 @@ message BlockDesc {
required int32 parent_idx = 2;
repeated VarDesc vars = 3;
repeated OpDesc ops = 4;
optional int32 forward_block_idx = 5 [ default = -1 ];
}
// Please refer to
......
......@@ -38,7 +38,13 @@ class ProgramDesc {
BlockDesc *AppendBlock(const BlockDesc &parent);
BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); }
BlockDesc *MutableBlock(size_t idx) {
if (idx == static_cast<size_t>(kNoneBlockIndex)) {
return nullptr;
} else {
return blocks_[idx].get();
}
}
const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; }
......
......@@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
while_grad->SetInput(kStepScopes, Output(kStepScopes));
auto *grad_block = this->grad_block_[0];
auto *fwd_block = grad_block->ParentBlock();
auto *fwd_block = grad_block->ForwardBlock();
auto *parent_block = grad_block->ParentBlock();
// Not all of IGs will be generated by inner gradient operators of while op.
// Ignore IGs that is not generated by the inside block.
......@@ -265,8 +266,10 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
for (auto &input_name : op->InputArgumentNames()) {
// If the input of Op has been recorded or is generated by the forward
// block, do not make it as input again.
if (block_ins.find(input_name) != block_ins.end() ||
fwd_block->FindVar(input_name) != nullptr) {
fwd_block->FindVar(input_name) != nullptr ||
parent_block->FindVar(input_name) != nullptr) {
continue;
}
extra_inputs.insert(input_name);
......
......@@ -155,6 +155,8 @@ void BindBlockDesc(py::module &m) {
py::class_<BlockDesc>(m, "BlockDesc", "")
.def_property_readonly("id", &BlockDesc::ID)
.def_property_readonly("parent", &BlockDesc::Parent)
.def("get_forward_block_idx", &BlockDesc::ForwardBlockID)
.def("set_forward_block_idx", &BlockDesc::SetForwardBlockID)
.def("append_op", &BlockDesc::AppendOp,
py::return_value_policy::reference)
.def("prepend_op", &BlockDesc::PrependOp,
......
......@@ -298,7 +298,8 @@ def _append_backward_ops_(block,
# If the op has its own sub-block, deal with the sub-block first
if op.has_attr("sub_block"):
sub_block = program.block(op.block_attr("sub_block"))
grad_sub_block = program.create_block(parent_idx=sub_block.idx)
grad_sub_block = program.create_block()
grad_sub_block.set_forward_block_idx(sub_block.idx)
cb = _callback_lookup_(op)
if cb is not None:
if callbacks is None:
......@@ -310,6 +311,8 @@ def _append_backward_ops_(block,
else:
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var, callbacks)
program.rollback()
grad_sub_block_list.append(grad_sub_block.desc)
# Getting op's corresponding grad_op
......
......@@ -678,6 +678,13 @@ class Block(object):
def parent_idx(self):
return self.desc.parent
@property
def forward_block_idx(self):
return self.desc.get_forward_block_idx()
def set_forward_block_idx(self, idx):
self.desc.set_forward_block_idx(idx)
@property
def idx(self):
return self.desc.id
......@@ -695,11 +702,22 @@ class Block(object):
return self.var(name)
else:
if self.idx == 0:
raise ValueError("var %s is not in block(%d) nor its parents." %
name, self.idx)
raise ValueError(
"var {0} is not in block({1}) nor its parents.".format(
name, self.idx))
else:
parent_block = self.program.block(self.parent_idx)
return parent_block.var_recursive(name)
# DFS
try:
parent_block = self.program.block(self.parent_idx)
return parent_block.var_recursive(name)
except ValueError:
fwd_block = self.program.block(
self.forward_block_idx
) if self.forward_block_idx != -1 else None
if fwd_block is not None:
return fwd_block.var_recursive(name)
else:
raise
def all_parameters(self):
return list(self.iter_parameters())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册