diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc index bff8e0bceaca9749101b2c45edddba526d565624..6a3882f199e370526827fbfa0118e79f5972cce3 100644 --- a/paddle/framework/prune.cc +++ b/paddle/framework/prune.cc @@ -49,11 +49,28 @@ bool IsTarget(const proto::OpDesc& op_desc) { return false; } -void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, - int block_id) { - // TODO(tonyyang-svail): - // - will change to use multiple blocks for RNN op and Cond Op +int GetSubBlockIndex(const proto::OpDesc& op_desc) { + for (auto& attr : op_desc.attrs()) { + if (attr.type() == proto::AttrType::BLOCK) { + PADDLE_ENFORCE(attr.has_block_idx()); + return attr.block_idx(); + } + } + return -1; +} +bool HasSubBlock(const proto::OpDesc& op_desc) { + return GetSubBlockIndex(op_desc) > 0; +} + +// block_id is the idx of the current block in the input desc +// parent_block_id is the idx of the parent of the current block +// in the output desc, -1 means the current block is global block +// dependent_vars is passed recursively from the parent block to +// the child block to help pruning +void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, + int block_id, int parent_block_id, + std::set& dependent_vars) { auto& block = input.blocks(block_id); auto& ops = block.ops(); @@ -72,11 +89,9 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, expect_fetch = (op_desc.type() == kFetchOpType); } - std::set dependent_vars; std::vector should_run; for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { auto& op_desc = *op_iter; - if (IsTarget(op_desc) || HasDependentVar(op_desc, dependent_vars)) { // insert its input to the dependency graph for (auto& var : op_desc.inputs()) { @@ -84,7 +99,6 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, dependent_vars.insert(argu); } } - should_run.push_back(true); } else { should_run.push_back(false); @@ -95,19 +109,48 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, // we reverse the should_run vector std::reverse(should_run.begin(), should_run.end()); - *output = input; - auto* op_field = output->mutable_blocks(block_id)->mutable_ops(); + //*output = input; + // copy the current block from input to output + auto* block_field = output->mutable_blocks(); + *block_field->Add() = input.blocks(block_id); + + int output_block_id = output->blocks_size() - 1; + auto* output_block = output->mutable_blocks(output_block_id); + output_block->set_idx = output_block_id; + output_block->set_parent_idx = parent_block_id; + + auto* op_field = output_block->mutable_ops(); op_field->Clear(); for (size_t i = 0; i < should_run.size(); ++i) { if (should_run[i]) { - *op_field->Add() = input.blocks(block_id).ops(i); + auto* op = op_field->Add(); + *op = input.blocks(block_id).ops(i); + if (HasSubBlock(*op)) { + // create sub_block_dependent_vars here to help prune the sub block + std::set sub_block_dependent_vars; + for (auto& var : op.inputs()) { + for (auto& argu : var.arguments()) { + sub_block_dependent_vars.insert(argu); + } + } + for (auto& var : op.outputs()) { + for (auto& argu : var.arguments()) { + sub_block_dependent_vars.insert(argu); + } + } + + // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc + // output_block_id is the idx of the current block in the output desc + prune_impl(input, output, GetSubBlockIndex(*op), output_block_id, + sub_block_dependent_vars); + } } } // remove the VarDescs in BlockDesc that are not referenced in // the pruned OpDescs std::unordered_map var_map; - auto* var_field = output->mutable_blocks(block_id)->mutable_vars(); + auto* var_field = output->mutable_blocks(output_block_id)->mutable_vars(); for (const auto& var : *var_field) { var_map[var.name()] = var; } @@ -118,14 +161,14 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, auto& input_field = op.inputs(); for (auto& input_var : input_field) { for (auto& arg : input_var.arguments()) { - *var_field->Add() = var_map[arg]; + *var_field->Add() = var_map.at(arg); } } // add VarDescs of all output arguments for each OpDesc auto& output_field = op.outputs(); for (auto& output_var : output_field) { for (auto& arg : output_var.arguments()) { - *var_field->Add() = var_map[arg]; + *var_field->Add() = var_map.at(arg); } } } @@ -133,7 +176,7 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies void Prune(const proto::ProgramDesc& input, proto::ProgramDesc* output) { - prune_impl(input, output, 0); + prune_impl(input, output, 0, -1, {}); } void inference_optimize_impl(const proto::ProgramDesc& input,