diff --git a/python/paddle/jit/dy2static/partial_program.py b/python/paddle/jit/dy2static/partial_program.py
index 40f276d269f96a2371a870a4b884a1c4aea4ca82..7fcdc51ff756bce491d2cc58c6573ddebd5ae81c 100644
--- a/python/paddle/jit/dy2static/partial_program.py
+++ b/python/paddle/jit/dy2static/partial_program.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from copy import deepcopy
+
 import numpy as np
 
 import paddle
@@ -699,19 +701,32 @@ class PartialProgramLayer:
     def _get_forward_backward_program_form(
         self, whole_program, forward_end_op_index
     ):
-        forward_builded_program = add_build_strategy_for(
-            whole_program, 0, forward_end_op_index, self._build_strategy
-        )
+        # NOTE(dev): We apply build_strategy for backward firstly to
+        # avoid skipping more gc variables.
         backward_start_op_index = forward_end_op_index + 2 * len(
             self._outputs.var_ids
         )
         backward_end_op_index = whole_program.desc.block(0).op_size()
+        backward_skip_vars = self._parse_skip_gc_vars(whole_program)
         backward_builded_program = add_build_strategy_for(
             whole_program,
             backward_start_op_index,
             backward_end_op_index,
             self._build_strategy,
+            backward_skip_vars,
+        )
+
+        forward_skip_vars = self._parse_skip_gc_vars(
+            whole_program, backward_builded_program
+        )
+        forward_builded_program = add_build_strategy_for(
+            whole_program,
+            0,
+            forward_end_op_index,
+            self._build_strategy,
+            forward_skip_vars,
         )
+
         self._apply_inplace_pass(
             forward_builded_program, backward_builded_program
         )
@@ -726,26 +741,10 @@ class PartialProgramLayer:
         empty_startup_program = paddle.static.Program()
         use_cuda = True if core.is_compiled_with_cuda() else False
         # skip data var
-        forward_mem_opt_skip_vars = []
-        backward_mem_opt_skip_vars = []
-        for var_name, var in forward_program.global_block().vars.items():
-            if var.is_data:
-                forward_mem_opt_skip_vars.append(var_name)
-        for var_name, var in backward_program.global_block().vars.items():
-            if var.is_data:
-                backward_mem_opt_skip_vars.append(var_name)
-        for var in self._inputs:
-            if isinstance(var, paddle.fluid.framework.Variable):
-                forward_mem_opt_skip_vars.append(var.desc.name())
-                backward_mem_opt_skip_vars.append(var.desc.name())
-        for var in self._outputs:
-            if isinstance(var, paddle.fluid.framework.Variable):
-                forward_mem_opt_skip_vars.append(var.desc.name())
-                backward_mem_opt_skip_vars.append(var.desc.name())
-        for var_name in core.parse_safe_eager_deletion_skip_vars(
-            backward_program.desc
-        ):
-            forward_mem_opt_skip_vars.append(var_name)
+        forward_mem_opt_skip_vars = self._parse_skip_gc_vars(
+            forward_program, backward_program
+        )
+        backward_mem_opt_skip_vars = self._parse_skip_gc_vars(forward_program)
         attrs = {
             "use_cuda": use_cuda,
             "mem_opt_skip_vars": forward_mem_opt_skip_vars,
@@ -771,6 +770,38 @@ class PartialProgramLayer:
             attr_types,
         )
 
+    @LazyInitialized
+    def _inout_var_names(self):
+        """
+        Returns Variable Names from self._inputs and self.outputs
+        """
+        var_names = []
+        for var in self._inputs:
+            if isinstance(var, paddle.fluid.framework.Variable):
+                var_names.append(var.desc.name())
+        for var in self._outputs:
+            if isinstance(var, paddle.fluid.framework.Variable):
+                var_names.append(var.desc.name())
+        return var_names
+
+    def _parse_skip_gc_vars(self, program, backward_program=None):
+        """
+        Parse variables that need to skip GC after execute it.
+        If specify backward_program, it will keep the variables used in backward.
+        """
+        # skip data var, DO NOT ignore this deepcopy
+        skip_vars = deepcopy(self._inout_var_names)
+        for var_name, var in program.global_block().vars.items():
+            if var.is_data:
+                skip_vars.append(var_name)
+
+        if backward_program:
+            for var_name in core.parse_safe_eager_deletion_skip_vars(
+                backward_program.desc
+            ):
+                skip_vars.append(var_name)
+        return skip_vars
+
     def _prepare(self, inputs):
         """
         Prepare inputs, outputs, attrs.
@@ -1055,13 +1086,16 @@ def partial_program_from(concrete_program):
 
 @switch_to_static_graph
 def add_build_strategy_for(
-    program, start_op_index, end_op_index, build_strategy=None
+    program, start_op_index, end_op_index, build_strategy=None, skip_vars=None
 ):
     if start_op_index < end_op_index:
         compiled_program = paddle.static.CompiledProgram(
             core.Graph(program.desc, start_op_index, end_op_index),
             build_strategy=build_strategy,
         )
+        if skip_vars:
+            # TODO(Aurelius84): Need to unify name with C++, such as kSkipVarNames.
+            compiled_program._graph.set("skip_gc_vars", set(skip_vars))
         compiled_program._compile(
             core.Scope(), framework._current_expected_place()
         )