未验证 提交 ffa32e44 编写于 作者: A Aurelius84 提交者: GitHub

[D2SCinn]Support deliver skip_gc_vars into Graph (#49411)

* [D2SCinn]Support deliver skip_gc_vars into Graph

* fix unittest

* fix copy
上级 a30e3602
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy
import numpy as np
import paddle
......@@ -699,19 +701,32 @@ class PartialProgramLayer:
def _get_forward_backward_program_form(
self, whole_program, forward_end_op_index
):
forward_builded_program = add_build_strategy_for(
whole_program, 0, forward_end_op_index, self._build_strategy
)
# NOTE(dev): We apply build_strategy for backward firstly to
# avoid skipping more gc variables.
backward_start_op_index = forward_end_op_index + 2 * len(
self._outputs.var_ids
)
backward_end_op_index = whole_program.desc.block(0).op_size()
backward_skip_vars = self._parse_skip_gc_vars(whole_program)
backward_builded_program = add_build_strategy_for(
whole_program,
backward_start_op_index,
backward_end_op_index,
self._build_strategy,
backward_skip_vars,
)
forward_skip_vars = self._parse_skip_gc_vars(
whole_program, backward_builded_program
)
forward_builded_program = add_build_strategy_for(
whole_program,
0,
forward_end_op_index,
self._build_strategy,
forward_skip_vars,
)
self._apply_inplace_pass(
forward_builded_program, backward_builded_program
)
......@@ -726,26 +741,10 @@ class PartialProgramLayer:
empty_startup_program = paddle.static.Program()
use_cuda = True if core.is_compiled_with_cuda() else False
# skip data var
forward_mem_opt_skip_vars = []
backward_mem_opt_skip_vars = []
for var_name, var in forward_program.global_block().vars.items():
if var.is_data:
forward_mem_opt_skip_vars.append(var_name)
for var_name, var in backward_program.global_block().vars.items():
if var.is_data:
backward_mem_opt_skip_vars.append(var_name)
for var in self._inputs:
if isinstance(var, paddle.fluid.framework.Variable):
forward_mem_opt_skip_vars.append(var.desc.name())
backward_mem_opt_skip_vars.append(var.desc.name())
for var in self._outputs:
if isinstance(var, paddle.fluid.framework.Variable):
forward_mem_opt_skip_vars.append(var.desc.name())
backward_mem_opt_skip_vars.append(var.desc.name())
for var_name in core.parse_safe_eager_deletion_skip_vars(
backward_program.desc
):
forward_mem_opt_skip_vars.append(var_name)
forward_mem_opt_skip_vars = self._parse_skip_gc_vars(
forward_program, backward_program
)
backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
attrs = {
"use_cuda": use_cuda,
"mem_opt_skip_vars": forward_mem_opt_skip_vars,
......@@ -771,6 +770,38 @@ class PartialProgramLayer:
attr_types,
)
@LazyInitialized
def _inout_var_names(self):
"""
Returns Variable Names from self._inputs and self.outputs
"""
var_names = []
for var in self._inputs:
if isinstance(var, paddle.fluid.framework.Variable):
var_names.append(var.desc.name())
for var in self._outputs:
if isinstance(var, paddle.fluid.framework.Variable):
var_names.append(var.desc.name())
return var_names
def _parse_skip_gc_vars(self, program, backward_program=None):
"""
Parse variables that need to skip GC after execute it.
If specify backward_program, it will keep the variables used in backward.
"""
# skip data var, DO NOT ignore this deepcopy
skip_vars = deepcopy(self._inout_var_names)
for var_name, var in program.global_block().vars.items():
if var.is_data:
skip_vars.append(var_name)
if backward_program:
for var_name in core.parse_safe_eager_deletion_skip_vars(
backward_program.desc
):
skip_vars.append(var_name)
return skip_vars
def _prepare(self, inputs):
"""
Prepare inputs, outputs, attrs.
......@@ -1055,13 +1086,16 @@ def partial_program_from(concrete_program):
@switch_to_static_graph
def add_build_strategy_for(
program, start_op_index, end_op_index, build_strategy=None
program, start_op_index, end_op_index, build_strategy=None, skip_vars=None
):
if start_op_index < end_op_index:
compiled_program = paddle.static.CompiledProgram(
core.Graph(program.desc, start_op_index, end_op_index),
build_strategy=build_strategy,
)
if skip_vars:
# TODO(Aurelius84): Need to unify name with C++, such as kSkipVarNames.
compiled_program._graph.set("skip_gc_vars", set(skip_vars))
compiled_program._compile(
core.Scope(), framework._current_expected_place()
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册