From 7049af576224763999f4e7c951f2fc1237599c26 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 16 Jul 2021 11:28:02 +0800 Subject: [PATCH] [Dy2Stat]Add NoNeedBufferVarsInferer to reduce memory usage (#34177) * Add NoNeedBufferVarsInferer * fix code style --- paddle/fluid/framework/executor_cache.cc | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc index 8ec75e992b3..17eac8a7228 100644 --- a/paddle/fluid/framework/executor_cache.cc +++ b/paddle/fluid/framework/executor_cache.cc @@ -77,6 +77,7 @@ void ParseSafeEagerDeletionSkipVars( const std::vector &output_var_names, std::vector *skip_eager_delete_vars) { auto all_ops = program.Block(0).AllOps(); + auto &op_info_map = OpInfoMap::Instance(); // NOTE: skip `shape` and `fill_constant` op created by // fluid.backward.gradients, one forward output will generate one `shape` // and `fill_constant`. @@ -86,16 +87,33 @@ void ParseSafeEagerDeletionSkipVars( // step 2: parse the necessary variable of backward op std::unordered_set op_outputs; std::unordered_set op_inputs; + std::unordered_set no_need_buffer_ins; + for (auto i = backward_op_start_index; i < all_ops.size(); ++i) { framework::OpDesc *op = all_ops[i]; - for (const std::string &in_arg_name : op->InputArgumentNames()) { - op_inputs.emplace(in_arg_name); + // NOTE: skip NoNeedBufferVars of grad_op and GC its memory in advance. + auto &op_info = op_info_map.Get(op->Type()); + auto &inferer = op_info.NoNeedBufferVarsInferer(); + no_need_buffer_ins.clear(); + if (inferer != nullptr) { + no_need_buffer_ins = + inferer(op->Inputs(), op->Outputs(), op->GetAttrMap()); + } + for (auto &in_names : op->Inputs()) { + if (no_need_buffer_ins.count(in_names.first) == 0) { + for (auto &in_name : in_names.second) { + op_inputs.emplace(in_name); + } + } else { + VLOG(2) << op->Type() << " has no_need_buffer_in: " << in_names.first + << " , skip it."; + } } + for (const std::string &out_arg_name : op->OutputArgumentNames()) { op_outputs.emplace(out_arg_name); } } - // For the grad op input variables, if it is not output of grad_op, it may // be output of forward op and we should set the variables as skip_var to // prevent it being deleted when grad op is called multiple times. -- GitLab