Unverified · Commit ffa32e44 · authored by Aurelius84, committed by GitHub

[D2SCinn]Support deliver skip_gc_vars into Graph (#49411)

* [D2SCinn]Support deliver skip_gc_vars into Graph

* fix unittest

* fix copy
Parent a30e3602
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from copy import deepcopy
+
 import numpy as np
 import paddle
@@ -699,19 +701,32 @@ class PartialProgramLayer:
     def _get_forward_backward_program_form(
         self, whole_program, forward_end_op_index
     ):
-        forward_builded_program = add_build_strategy_for(
-            whole_program, 0, forward_end_op_index, self._build_strategy
-        )
+        # NOTE(dev): We apply build_strategy for backward firstly to
+        # avoid skipping more gc variables.
         backward_start_op_index = forward_end_op_index + 2 * len(
             self._outputs.var_ids
         )
         backward_end_op_index = whole_program.desc.block(0).op_size()
+        backward_skip_vars = self._parse_skip_gc_vars(whole_program)
         backward_builded_program = add_build_strategy_for(
             whole_program,
             backward_start_op_index,
             backward_end_op_index,
             self._build_strategy,
+            backward_skip_vars,
+        )
+        forward_skip_vars = self._parse_skip_gc_vars(
+            whole_program, backward_builded_program
+        )
+        forward_builded_program = add_build_strategy_for(
+            whole_program,
+            0,
+            forward_end_op_index,
+            self._build_strategy,
+            forward_skip_vars,
         )
         self._apply_inplace_pass(
             forward_builded_program, backward_builded_program
         )
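Aside: the NOTE above is the crux of the reordering. The forward program's skip-GC set can only be computed after the backward sub-program exists, because it must include every intermediate the backward ops still read. A framework-free sketch of that dependency (the build() helper and variable names below are hypothetical, not Paddle APIs):

def build(op_range, skip_gc):
    """Toy stand-in for add_build_strategy_for: record which vars escape GC."""
    return {"ops": op_range, "skip_gc": set(skip_gc)}

whole_ops = ["matmul", "relu", "relu_grad", "matmul_grad"]

# 1) Build the backward slice first; it only needs the program-wide skip set.
backward = build(whole_ops[2:], skip_gc={"x", "out"})

# 2) Only now is it known which intermediates the backward ops still read,
#    so the forward skip set can include them (hypothetical var name below).
backward_reads = {"relu_tmp"}
forward = build(whole_ops[:2], skip_gc={"x", "out"} | backward_reads)

print(sorted(forward["skip_gc"]))  # ['out', 'relu_tmp', 'x']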
@@ -726,26 +741,10 @@ class PartialProgramLayer:
         empty_startup_program = paddle.static.Program()
         use_cuda = True if core.is_compiled_with_cuda() else False
         # skip data var
-        forward_mem_opt_skip_vars = []
-        backward_mem_opt_skip_vars = []
-        for var_name, var in forward_program.global_block().vars.items():
-            if var.is_data:
-                forward_mem_opt_skip_vars.append(var_name)
-        for var_name, var in backward_program.global_block().vars.items():
-            if var.is_data:
-                backward_mem_opt_skip_vars.append(var_name)
-        for var in self._inputs:
-            if isinstance(var, paddle.fluid.framework.Variable):
-                forward_mem_opt_skip_vars.append(var.desc.name())
-                backward_mem_opt_skip_vars.append(var.desc.name())
-        for var in self._outputs:
-            if isinstance(var, paddle.fluid.framework.Variable):
-                forward_mem_opt_skip_vars.append(var.desc.name())
-                backward_mem_opt_skip_vars.append(var.desc.name())
-        for var_name in core.parse_safe_eager_deletion_skip_vars(
-            backward_program.desc
-        ):
-            forward_mem_opt_skip_vars.append(var_name)
+        forward_mem_opt_skip_vars = self._parse_skip_gc_vars(
+            forward_program, backward_program
+        )
+        backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
         attrs = {
             "use_cuda": use_cuda,
             "mem_opt_skip_vars": forward_mem_opt_skip_vars,
@@ -771,6 +770,38 @@ class PartialProgramLayer:
             attr_types,
         )
 
+    @LazyInitialized
+    def _inout_var_names(self):
+        """
+        Returns variable names from self._inputs and self._outputs.
+        """
+        var_names = []
+        for var in self._inputs:
+            if isinstance(var, paddle.fluid.framework.Variable):
+                var_names.append(var.desc.name())
+        for var in self._outputs:
+            if isinstance(var, paddle.fluid.framework.Variable):
+                var_names.append(var.desc.name())
+        return var_names
+
+    def _parse_skip_gc_vars(self, program, backward_program=None):
+        """
+        Parse variables that should skip GC after executing the program.
+        If backward_program is given, variables it still uses are kept as well.
+        """
+        # skip data var, DO NOT ignore this deepcopy
+        skip_vars = deepcopy(self._inout_var_names)
+        for var_name, var in program.global_block().vars.items():
+            if var.is_data:
+                skip_vars.append(var_name)
+
+        if backward_program:
+            for var_name in core.parse_safe_eager_deletion_skip_vars(
+                backward_program.desc
+            ):
+                skip_vars.append(var_name)
+        return skip_vars
+
     def _prepare(self, inputs):
         """
         Prepare inputs, outputs, attrs.
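To make the new helper's contract concrete: called without backward_program it keeps only the cached in/out names plus data (feed) vars, while passing backward_program additionally keeps everything the backward ops still read, which is why the forward/backward skip sets above are asymmetric. A framework-free toy model of _parse_skip_gc_vars (the function and data below are illustrative assumptions, not Paddle code):

from copy import deepcopy

def parse_skip_gc_vars(inout_names, block_vars, backward_reads=None):
    # deepcopy mirrors the "DO NOT ignore this deepcopy" note: the cached
    # in/out name list must not be mutated by per-call extensions.
    skip_vars = deepcopy(inout_names)
    for name, is_data in block_vars.items():
        if is_data:                    # feed vars are never garbage-collected
            skip_vars.append(name)
    if backward_reads is not None:     # keep whatever backward still reads
        skip_vars.extend(backward_reads)
    return skip_vars

cached = ["x", "out"]
fwd_skip = parse_skip_gc_vars(cached, {"x": True, "tmp": False}, ["relu_tmp"])
print(fwd_skip)   # ['x', 'out', 'x', 'relu_tmp']
print(cached)     # ['x', 'out']  -- unchanged thanks to the deepcopy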
@@ -1055,13 +1086,16 @@ def partial_program_from(concrete_program):
 
 @switch_to_static_graph
 def add_build_strategy_for(
-    program, start_op_index, end_op_index, build_strategy=None
+    program, start_op_index, end_op_index, build_strategy=None, skip_vars=None
 ):
     if start_op_index < end_op_index:
         compiled_program = paddle.static.CompiledProgram(
             core.Graph(program.desc, start_op_index, end_op_index),
             build_strategy=build_strategy,
         )
+        if skip_vars:
+            # TODO(Aurelius84): Need to unify name with C++, such as kSkipVarNames.
+            compiled_program._graph.set("skip_gc_vars", set(skip_vars))
         compiled_program._compile(
             core.Scope(), framework._current_expected_place()
         )
...
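A hedged usage sketch of the new hook: the skip set rides on the IR Graph as a plain attribute before compilation, where the executor's GC pass can later read it. The calls below mirror what this diff itself does; the tiny program and the skipped variable name are assumptions, and the core import path may differ by Paddle version:

import paddle
from paddle.fluid import core  # import path is an assumption for this Paddle era

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
    y = paddle.nn.functional.relu(x)

# Slice the whole block into a Graph, as add_build_strategy_for does.
graph = core.Graph(main.desc, 0, main.desc.block(0).op_size())
compiled = paddle.static.CompiledProgram(graph)

# The commit's new hook: stash the skip set as a graph attribute
# (the attribute name may later be unified with C++, per the TODO above).
compiled._graph.set("skip_gc_vars", {"x"})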