提交 c3d27b15 编写于 作者: K Kexin Zhao

modify prune.cc for multiple blocks

上级 dc8390d8
...@@ -49,11 +49,28 @@ bool IsTarget(const proto::OpDesc& op_desc) { ...@@ -49,11 +49,28 @@ bool IsTarget(const proto::OpDesc& op_desc) {
return false; return false;
} }
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, int GetSubBlockIndex(const proto::OpDesc& op_desc) {
int block_id) { for (auto& attr : op_desc.attrs()) {
// TODO(tonyyang-svail): if (attr.type() == proto::AttrType::BLOCK) {
// - will change to use multiple blocks for RNN op and Cond Op PADDLE_ENFORCE(attr.has_block_idx());
return attr.block_idx();
}
}
return -1;
}
bool HasSubBlock(const proto::OpDesc& op_desc) {
return GetSubBlockIndex(op_desc) > 0;
}
// block_id is the idx of the current block in the input desc
// parent_block_id is the idx of the parent of the current block
// in the output desc, -1 means the current block is global block
// dependent_vars is passed recursively from the parent block to
// the child block to help pruning
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
int block_id, int parent_block_id,
std::set<std::string>& dependent_vars) {
auto& block = input.blocks(block_id); auto& block = input.blocks(block_id);
auto& ops = block.ops(); auto& ops = block.ops();
...@@ -72,11 +89,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -72,11 +89,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
expect_fetch = (op_desc.type() == kFetchOpType); expect_fetch = (op_desc.type() == kFetchOpType);
} }
std::set<std::string> dependent_vars;
std::vector<bool> should_run; std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter; auto& op_desc = *op_iter;
if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) { if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) {
// insert its input to the dependency graph // insert its input to the dependency graph
for (auto& var : op_desc.inputs()) { for (auto& var : op_desc.inputs()) {
...@@ -84,7 +99,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -84,7 +99,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
dependent_vars.insert(argu); dependent_vars.insert(argu);
} }
} }
should_run.push_back(true); should_run.push_back(true);
} else { } else {
should_run.push_back(false); should_run.push_back(false);
...@@ -95,19 +109,48 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -95,19 +109,48 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// we reverse the should_run vector // we reverse the should_run vector
std::reverse(should_run.begin(), should_run.end()); std::reverse(should_run.begin(), should_run.end());
*output = input; //*output = input;
auto* op_field = output->mutable_blocks(block_id)->mutable_ops(); // copy the current block from input to output
auto* block_field = output->mutable_blocks();
*block_field->Add() = input.blocks(block_id);
int output_block_id = output->blocks_size() - 1;
auto* output_block = output->mutable_blocks(output_block_id);
output_block->set_idx = output_block_id;
output_block->set_parent_idx = parent_block_id;
auto* op_field = output_block->mutable_ops();
op_field->Clear(); op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) { for (size_t i = 0; i < should_run.size(); ++i) {
if (should_run[i]) { if (should_run[i]) {
*op_field->Add() = input.blocks(block_id).ops(i); auto* op = op_field->Add();
*op = input.blocks(block_id).ops(i);
if (HasSubBlock(*op)) {
// create sub_block_dependent_vars here to help prune the sub block
std::set<std::string> sub_block_dependent_vars;
for (auto& var : op.inputs()) {
for (auto& argu : var.arguments()) {
sub_block_dependent_vars.insert(argu);
}
}
for (auto& var : op.outputs()) {
for (auto& argu : var.arguments()) {
sub_block_dependent_vars.insert(argu);
}
}
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
sub_block_dependent_vars);
}
} }
} }
// remove the VarDescs in BlockDesc that are not referenced in // remove the VarDescs in BlockDesc that are not referenced in
// the pruned OpDescs // the pruned OpDescs
std::unordered_map<std::string, proto::VarDesc> var_map; std::unordered_map<std::string, proto::VarDesc> var_map;
auto* var_field = output->mutable_blocks(block_id)->mutable_vars(); auto* var_field = output->mutable_blocks(output_block_id)->mutable_vars();
for (const auto& var : *var_field) { for (const auto& var : *var_field) {
var_map[var.name()] = var; var_map[var.name()] = var;
} }
...@@ -118,14 +161,14 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -118,14 +161,14 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
auto& input_field = op.inputs(); auto& input_field = op.inputs();
for (auto& input_var : input_field) { for (auto& input_var : input_field) {
for (auto& arg : input_var.arguments()) { for (auto& arg : input_var.arguments()) {
*var_field->Add() = var_map[arg]; *var_field->Add() = var_map.at(arg);
} }
} }
// add VarDescs of all output arguments for each OpDesc // add VarDescs of all output arguments for each OpDesc
auto& output_field = op.outputs(); auto& output_field = op.outputs();
for (auto& output_var : output_field) { for (auto& output_var : output_field) {
for (auto& arg : output_var.arguments()) { for (auto& arg : output_var.arguments()) {
*var_field->Add() = var_map[arg]; *var_field->Add() = var_map.at(arg);
} }
} }
} }
...@@ -133,7 +176,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, ...@@ -133,7 +176,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) { void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) {
prune_impl(input, output, 0); prune_impl(input, output, 0, -1, {});
} }
void inference_optimize_impl(const proto::ProgramDesc& input, void inference_optimize_impl(const proto::ProgramDesc& input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册