未验证 提交 7049af57 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Add NoNeedBufferVarsInferer to reduce memory usage (#34177)

* Add NoNeedBufferVarsInferer

* fix code style
上级 44bdbe93
...@@ -77,6 +77,7 @@ void ParseSafeEagerDeletionSkipVars( ...@@ -77,6 +77,7 @@ void ParseSafeEagerDeletionSkipVars(
const std::vector<std::string> &output_var_names, const std::vector<std::string> &output_var_names,
std::vector<std::string> *skip_eager_delete_vars) { std::vector<std::string> *skip_eager_delete_vars) {
auto all_ops = program.Block(0).AllOps(); auto all_ops = program.Block(0).AllOps();
auto &op_info_map = OpInfoMap::Instance();
// NOTE: skip `shape` and `fill_constant` op created by // NOTE: skip `shape` and `fill_constant` op created by
// fluid.backward.gradients, one forward output will generate one `shape` // fluid.backward.gradients, one forward output will generate one `shape`
// and `fill_constant`. // and `fill_constant`.
...@@ -86,16 +87,33 @@ void ParseSafeEagerDeletionSkipVars( ...@@ -86,16 +87,33 @@ void ParseSafeEagerDeletionSkipVars(
// step 2: parse the necessary variable of backward op // step 2: parse the necessary variable of backward op
std::unordered_set<std::string> op_outputs; std::unordered_set<std::string> op_outputs;
std::unordered_set<std::string> op_inputs; std::unordered_set<std::string> op_inputs;
std::unordered_set<std::string> no_need_buffer_ins;
for (auto i = backward_op_start_index; i < all_ops.size(); ++i) { for (auto i = backward_op_start_index; i < all_ops.size(); ++i) {
framework::OpDesc *op = all_ops[i]; framework::OpDesc *op = all_ops[i];
for (const std::string &in_arg_name : op->InputArgumentNames()) { // NOTE: skip NoNeedBufferVars of grad_op and GC its memory in advance.
op_inputs.emplace(in_arg_name); auto &op_info = op_info_map.Get(op->Type());
auto &inferer = op_info.NoNeedBufferVarsInferer();
no_need_buffer_ins.clear();
if (inferer != nullptr) {
no_need_buffer_ins =
inferer(op->Inputs(), op->Outputs(), op->GetAttrMap());
}
for (auto &in_names : op->Inputs()) {
if (no_need_buffer_ins.count(in_names.first) == 0) {
for (auto &in_name : in_names.second) {
op_inputs.emplace(in_name);
}
} else {
VLOG(2) << op->Type() << " has no_need_buffer_in: " << in_names.first
<< " , skip it.";
}
} }
for (const std::string &out_arg_name : op->OutputArgumentNames()) { for (const std::string &out_arg_name : op->OutputArgumentNames()) {
op_outputs.emplace(out_arg_name); op_outputs.emplace(out_arg_name);
} }
} }
// For the grad op input variables, if it is not output of grad_op, it may // For the grad op input variables, if it is not output of grad_op, it may
// be output of forward op and we should set the variables as skip_var to // be output of forward op and we should set the variables as skip_var to
// prevent it being deleted when grad op is called multiple times. // prevent it being deleted when grad op is called multiple times.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册