diff --git a/paddle/framework/prune.cc b/paddle/framework/prune.cc index 59947470e7d3582b4e853e5b3d0f00dc01235853..6117cbb79fcf4d5ef955640b156a68a80374d966 100644 --- a/paddle/framework/prune.cc +++ b/paddle/framework/prune.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include @@ -103,33 +104,31 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output, } } - // remove the vars in ProgramDesc that are not referenced in - // the pruned ops - std::unordered_map var_map; + // remove the VarDescs in BlockDesc that are not referenced in + // the pruned OpDescs + std::unordered_map var_map; auto* var_field = output->mutable_blocks(block_id)->mutable_vars(); - for (auto* var : *var_field) { - var_map[var->name()] = *var; + for (const auto& var : *var_field) { + var_map[var.name()] = var; } - // for (size_t i = 0; i < var_field->size(); ++i) { - // auto* var = (*var_field)[i]; - // var_map[var->name()] = *var; - // } - - // var_field->Clear(); - // for (size_t i = 0; i < op_field->size(); ++i) { - // auto* op = (*op_field)[i]; - - // auto* input_field = op->mutable_inputs(); - // for (size_t j = 0; j < input_field->size(); ++j) { - // auto* input_names = (*input_field)[j]->arguments(); - // for () - // *var_field->Add() = var_map[] - // } - // auto* ouput_field = op->mutable_outputs(); - // for (size_t k = 0; k < output_field->size(); ++k) { - // } - // } + var_field->Clear(); + for (const auto& op : *op_field) { + // add VarDescs of all input arguments for each OpDesc + auto& input_field = op.inputs(); + for (auto& input : input_field) { + for (auto& arg : input.arguments()) { + *var_field->Add() = var_map[arg]; + } + } + // add VarDescs of all output arguments for each OpDesc + auto& output_field = op.outputs(); + for (auto& output : output_field) { + for (auto& arg : output.arguments()) { + *var_field->Add() = var_map[arg]; + } + } + } } // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies