提交 7a9098a6 编写于 作者: Y Yu Yang 提交者: Yang Yang(Tony)

Add block.fwd_block_id (#8489)

* Add block.fwd_block_id

* fix bug in memory optimization transpiler

* Change DFS to BFS

* Add comments
上级 8c0434c3
...@@ -16,6 +16,8 @@ limitations under the License. */ ...@@ -16,6 +16,8 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include <queue>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -64,12 +66,36 @@ VarDesc *BlockDesc::RenameVar(const std::string &old_name, ...@@ -64,12 +66,36 @@ VarDesc *BlockDesc::RenameVar(const std::string &old_name,
VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const { VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
if (name == kEmptyVarName) return nullptr; if (name == kEmptyVarName) return nullptr;
auto it = vars_.find(name); std::queue<const BlockDesc *> frontier;
if (it == vars_.end()) { std::unordered_set<const BlockDesc *> visited;
return Parent() == kNoneBlockIndex ? nullptr
: ParentBlock()->FindVarRecursive(name); frontier.push(this);
while (!frontier.empty()) { // BFS
auto cur = frontier.front();
frontier.pop();
if (visited.count(cur) != 0) {
continue;
}
auto var = cur->FindVar(name);
if (var != nullptr) {
return var;
}
auto fwd = cur->ForwardBlock();
auto parent = cur->ParentBlock();
if (fwd != nullptr) {
frontier.push(fwd);
}
if (parent != nullptr) {
frontier.push(parent);
}
visited.insert(cur);
} }
return it->second.get();
return nullptr;
} }
VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) { VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
...@@ -155,10 +181,7 @@ void BlockDesc::Flush() { ...@@ -155,10 +181,7 @@ void BlockDesc::Flush() {
} }
BlockDesc *BlockDesc::ParentBlock() const { BlockDesc *BlockDesc::ParentBlock() const {
if (this->desc_->parent_idx() == kNoneBlockIndex) { return prog_->MutableBlock(static_cast<size_t>(desc_->parent_idx()));
return nullptr;
}
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
} }
proto::BlockDesc *BlockDesc::Proto() { proto::BlockDesc *BlockDesc::Proto() {
...@@ -205,5 +228,16 @@ void BlockDesc::ClearPBVars() { ...@@ -205,5 +228,16 @@ void BlockDesc::ClearPBVars() {
} }
} }
void BlockDesc::SetForwardBlockID(int32_t forward_block_id) {
PADDLE_ENFORCE(!desc_->has_forward_block_idx(),
"Parent block ID has been set to %d. Cannot set to %d",
desc_->forward_block_idx(), forward_block_id);
desc_->set_forward_block_idx(forward_block_id);
}
BlockDesc *BlockDesc::ForwardBlock() const {
return prog_->MutableBlock(static_cast<size_t>(desc_->forward_block_idx()));
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -49,6 +49,8 @@ class BlockDesc { ...@@ -49,6 +49,8 @@ class BlockDesc {
int32_t Parent() const { return desc_->parent_idx(); } int32_t Parent() const { return desc_->parent_idx(); }
int32_t ForwardBlockID() const { return desc_->forward_block_idx(); }
VarDesc *Var(const std::string &name_bytes); VarDesc *Var(const std::string &name_bytes);
VarDesc *FindVar(const std::string &name_bytes) const; VarDesc *FindVar(const std::string &name_bytes) const;
...@@ -75,6 +77,10 @@ class BlockDesc { ...@@ -75,6 +77,10 @@ class BlockDesc {
BlockDesc *ParentBlock() const; BlockDesc *ParentBlock() const;
BlockDesc *ForwardBlock() const;
void SetForwardBlockID(int32_t forward_block_id);
OpDesc *AppendOp(); OpDesc *AppendOp();
void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc); void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
...@@ -93,7 +99,7 @@ class BlockDesc { ...@@ -93,7 +99,7 @@ class BlockDesc {
proto::BlockDesc *Proto(); proto::BlockDesc *Proto();
ProgramDesc *Program() { return this->prog_; } ProgramDesc *Program() const { return this->prog_; }
private: private:
void ClearPBOps(); void ClearPBOps();
......
...@@ -158,6 +158,7 @@ message BlockDesc { ...@@ -158,6 +158,7 @@ message BlockDesc {
required int32 parent_idx = 2; required int32 parent_idx = 2;
repeated VarDesc vars = 3; repeated VarDesc vars = 3;
repeated OpDesc ops = 4; repeated OpDesc ops = 4;
optional int32 forward_block_idx = 5 [ default = -1 ];
} }
// Please refer to // Please refer to
......
...@@ -38,7 +38,13 @@ class ProgramDesc { ...@@ -38,7 +38,13 @@ class ProgramDesc {
BlockDesc *AppendBlock(const BlockDesc &parent); BlockDesc *AppendBlock(const BlockDesc &parent);
BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); } BlockDesc *MutableBlock(size_t idx) {
if (idx == static_cast<size_t>(kNoneBlockIndex)) {
return nullptr;
} else {
return blocks_[idx].get();
}
}
const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; } const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; }
......
...@@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -231,7 +231,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
while_grad->SetInput(kStepScopes, Output(kStepScopes)); while_grad->SetInput(kStepScopes, Output(kStepScopes));
auto *grad_block = this->grad_block_[0]; auto *grad_block = this->grad_block_[0];
auto *fwd_block = grad_block->ParentBlock(); auto *fwd_block = grad_block->ForwardBlock();
auto *parent_block = grad_block->ParentBlock();
// Not all of IGs will be generated by inner gradient operators of while op. // Not all of IGs will be generated by inner gradient operators of while op.
// Ignore IGs that is not generated by the inside block. // Ignore IGs that is not generated by the inside block.
...@@ -260,33 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -260,33 +261,37 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
for (auto &o : Output(kOutputs)) { for (auto &o : Output(kOutputs)) {
block_ins.insert(o); block_ins.insert(o);
} }
std::unordered_set<std::string> extra_inputs; std::unordered_set<std::string> output_grads;
for (const auto *op : grad_block->AllOps()) { for (const auto *op : grad_block->AllOps()) {
for (auto &input_name : op->InputArgumentNames()) { for (auto &input_name : op->InputArgumentNames()) {
// If the input of Op has been recorded or is generated by the forward // If the input of Op has been recorded or is generated by the forward
// block, do not make it as input again. // block, do not make it as input again.
// The input is located in I/O or other op's outputs or the variable is
// located in grad_block's parents
if (block_ins.find(input_name) != block_ins.end() || if (block_ins.find(input_name) != block_ins.end() ||
fwd_block->FindVar(input_name) != nullptr) { (fwd_block->FindVarRecursive(input_name) != nullptr ||
parent_block->FindVarRecursive(input_name) != nullptr)) {
continue; continue;
} }
extra_inputs.insert(input_name); output_grads.insert(input_name);
} }
for (auto &output_name : op->OutputArgumentNames()) { for (auto &output_name : op->OutputArgumentNames()) {
block_ins.insert(output_name); block_ins.insert(output_name);
} }
} }
std::vector<std::string> extra_inputs_list; std::vector<std::string> output_grads_list;
extra_inputs_list.resize(extra_inputs.size()); output_grads_list.resize(output_grads.size());
std::copy(extra_inputs.begin(), extra_inputs.end(), std::copy(output_grads.begin(), output_grads.end(),
extra_inputs_list.begin()); output_grads_list.begin());
while_grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); while_grad->SetInput(framework::GradVarName(kOutputs), output_grads_list);
while_grad->SetAttrMap(this->Attrs()); while_grad->SetAttrMap(this->Attrs());
while_grad->SetBlockAttr(kStepBlock, *grad_block); while_grad->SetBlockAttr(kStepBlock, *grad_block);
// record the original output gradient names, since the gradient name of // record the original output gradient names, since the gradient name of
// while operator could be renamed. // while operator could be renamed.
while_grad->SetAttr("original_output_grad", extra_inputs_list); while_grad->SetAttr("original_output_grad", output_grads_list);
return std::unique_ptr<framework::OpDesc>(while_grad); return std::unique_ptr<framework::OpDesc>(while_grad);
} }
......
...@@ -155,6 +155,8 @@ void BindBlockDesc(py::module &m) { ...@@ -155,6 +155,8 @@ void BindBlockDesc(py::module &m) {
py::class_<BlockDesc>(m, "BlockDesc", "") py::class_<BlockDesc>(m, "BlockDesc", "")
.def_property_readonly("id", &BlockDesc::ID) .def_property_readonly("id", &BlockDesc::ID)
.def_property_readonly("parent", &BlockDesc::Parent) .def_property_readonly("parent", &BlockDesc::Parent)
.def("get_forward_block_idx", &BlockDesc::ForwardBlockID)
.def("set_forward_block_idx", &BlockDesc::SetForwardBlockID)
.def("append_op", &BlockDesc::AppendOp, .def("append_op", &BlockDesc::AppendOp,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("prepend_op", &BlockDesc::PrependOp, .def("prepend_op", &BlockDesc::PrependOp,
......
...@@ -298,7 +298,8 @@ def _append_backward_ops_(block, ...@@ -298,7 +298,8 @@ def _append_backward_ops_(block,
# If the op has its own sub-block, deal with the sub-block first # If the op has its own sub-block, deal with the sub-block first
if op.has_attr("sub_block"): if op.has_attr("sub_block"):
sub_block = program.block(op.block_attr("sub_block")) sub_block = program.block(op.block_attr("sub_block"))
grad_sub_block = program.create_block(parent_idx=sub_block.idx) grad_sub_block = program.create_block()
grad_sub_block.set_forward_block_idx(sub_block.idx)
cb = _callback_lookup_(op) cb = _callback_lookup_(op)
if cb is not None: if cb is not None:
if callbacks is None: if callbacks is None:
...@@ -310,6 +311,8 @@ def _append_backward_ops_(block, ...@@ -310,6 +311,8 @@ def _append_backward_ops_(block,
else: else:
_append_backward_ops_(sub_block, sub_block.ops, grad_sub_block, _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
no_grad_dict, grad_to_var, callbacks) no_grad_dict, grad_to_var, callbacks)
program.rollback()
grad_sub_block_list.append(grad_sub_block.desc) grad_sub_block_list.append(grad_sub_block.desc)
# Getting op's corresponding grad_op # Getting op's corresponding grad_op
......
...@@ -696,6 +696,13 @@ class Block(object): ...@@ -696,6 +696,13 @@ class Block(object):
def parent_idx(self): def parent_idx(self):
return self.desc.parent return self.desc.parent
@property
def forward_block_idx(self):
return self.desc.get_forward_block_idx()
def set_forward_block_idx(self, idx):
self.desc.set_forward_block_idx(idx)
@property @property
def idx(self): def idx(self):
return self.desc.id return self.desc.id
...@@ -709,15 +716,32 @@ class Block(object): ...@@ -709,15 +716,32 @@ class Block(object):
return v return v
def var_recursive(self, name): def var_recursive(self, name):
if self.has_var(name): frontier = list()
return self.var(name) visited = set()
else:
if self.idx == 0: frontier.append(self)
raise ValueError("var %s is not in block(%d) nor its parents." %
name, self.idx) prog = self.program
else:
parent_block = self.program.block(self.parent_idx) while len(frontier) != 0: # BFS
return parent_block.var_recursive(name) cur = frontier[0]
frontier = frontier[1:]
if id(cur) in visited:
continue
if cur.has_var(name):
return cur.var(name)
if cur.parent_idx != -1:
frontier.append(prog.block(cur.parent_idx))
if cur.forward_block_idx != -1:
frontier.append(prog.block(cur.forward_block_idx))
visited.add(id(cur))
raise ValueError("Var {0} is not found recursively".format(name))
def all_parameters(self): def all_parameters(self):
return list(self.iter_parameters()) return list(self.iter_parameters())
......
...@@ -223,15 +223,15 @@ def get_cfgs(input_program): ...@@ -223,15 +223,15 @@ def get_cfgs(input_program):
# Find while/while_grad block pair # Find while/while_grad block pair
for grad_id in while_grad_sub_block_ids: for grad_id in while_grad_sub_block_ids:
parent_id = pdesc.block(grad_id).parent forward_id = pdesc.block(grad_id).get_forward_block_idx()
if parent_id in while_sub_block_ids: if forward_id in while_sub_block_ids:
while_block_id_pair.append((parent_id, grad_id)) while_block_id_pair.append((forward_id, grad_id))
while_sub_block_ids.remove(parent_id) while_sub_block_ids.remove(forward_id)
# Get while/while_grad block ops # Get while/while_grad block ops
for parent_id, grad_id in while_block_id_pair: for forward_id, grad_id in while_block_id_pair:
while_block_ops = [] while_block_ops = []
while_block = pdesc.block(parent_id) while_block = pdesc.block(forward_id)
while_block_op_size = while_block.op_size() while_block_op_size = while_block.op_size()
for i in range(while_block_op_size): for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i)) while_block_ops.append(while_block.op(i))
...@@ -242,21 +242,21 @@ def get_cfgs(input_program): ...@@ -242,21 +242,21 @@ def get_cfgs(input_program):
while_block_ops.append(while_grad_block.op(i)) while_block_ops.append(while_grad_block.op(i))
while_op_output = set() while_op_output = set()
while_op_output.update(while_op_dict[parent_id].output_arg_names()) while_op_output.update(while_op_dict[forward_id].output_arg_names())
while_op_output.update(while_op_dict[grad_id].output_arg_names()) while_op_output.update(while_op_dict[grad_id].output_arg_names())
ops_list.append((while_block_ops, while_block_op_size, while_op_output)) ops_list.append((while_block_ops, while_block_op_size, while_op_output))
# Process rest while block ops # Process rest while block ops
for parent_id in while_sub_block_ids: for forward_id in while_sub_block_ids:
while_block_ops = [] while_block_ops = []
while_block = pdesc.block(parent_id) while_block = pdesc.block(forward_id)
while_block_op_size = while_block.op_size() while_block_op_size = while_block.op_size()
for i in range(while_block_op_size): for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i)) while_block_ops.append(while_block.op(i))
while_op_output = set() while_op_output = set()
while_op_output.update(while_op_dict[parent_id].output_arg_names()) while_op_output.update(while_op_dict[forward_id].output_arg_names())
ops_list.append((while_block_ops, while_block_op_size, while_op_output)) ops_list.append((while_block_ops, while_block_op_size, while_op_output))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册