From 28b8adb1481d35d3546ac425014b2a68ff8995d6 Mon Sep 17 00:00:00 2001 From: Ruibiao Chen Date: Tue, 8 Aug 2023 16:57:15 +0800 Subject: [PATCH] Improve GC for pipeline parallel (#56022) * Improve GC for pipeline parallel * Delete print --- paddle/fluid/operators/nop_op.cc | 7 +- .../paddle/distributed/passes/pass_utils.py | 70 ++++++++----------- 2 files changed, 37 insertions(+), 40 deletions(-) diff --git a/paddle/fluid/operators/nop_op.cc b/paddle/fluid/operators/nop_op.cc index 2c148663656..66c475b31c2 100644 --- a/paddle/fluid/operators/nop_op.cc +++ b/paddle/fluid/operators/nop_op.cc @@ -45,9 +45,14 @@ establish the dependency between input and output tensors. } }; +DECLARE_NO_NEED_BUFFER_VARS_INFERER(NopNoNeedBufferVarsInferer, "X", "Out"); + } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(nop, + ops::NopOp, + ops::NopOpMaker, + ops::NopNoNeedBufferVarsInferer); diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py index ab0ab0feb41..afb93500816 100644 --- a/python/paddle/distributed/passes/pass_utils.py +++ b/python/paddle/distributed/passes/pass_utils.py @@ -190,15 +190,14 @@ class OpInOutInfo: return for slot_name in op.input_names: - if slot_name in self._no_need_buffer_slots: - continue - - for in_name in op.input(slot_name): - self._other_arg_names_set.add(in_name) + if slot_name not in self._no_need_buffer_slots: + for in_name in op.input(slot_name): + self._other_arg_names_set.add(in_name) for slot_name in op.output_names: - for out_name in op.output(slot_name): - self._other_arg_names_set.add(out_name) + if slot_name not in self._no_need_buffer_slots: + for out_name in op.output(slot_name): + self._other_arg_names_set.add(out_name) self._is_build = True @@ -209,16 +208,9 @@ class OpInOutInfo: ) -def var_can_be_deleted(var_name, program): - var = 
program.global_block()._find_var_recursive(var_name) - if var is None or var.persistable: - return False - - return var.type in [ - core.VarDesc.VarType.LOD_TENSOR, - core.VarDesc.VarType.SELECTED_ROWS, - core.VarDesc.VarType.LOD_TENSOR_ARRAY, - ] +def var_can_be_deleted(var_name, block): + var = block._find_var_recursive(var_name) + return var is not None and not var.persistable def get_skip_gc_vars(program_list: List[Program]): @@ -231,31 +223,31 @@ def get_skip_gc_vars(program_list: List[Program]): """ # step1: Get all vars of every sub_program of program_list that are non-persistable and not in op's no_need_buffer. - vars_list = [set() for _ in range(len(program_list))] - for ip, program in enumerate(program_list): - for op in program.global_block().ops: - op_info = OpInOutInfo() - for in_name in op.input_arg_names: - if not var_can_be_deleted(in_name, program): + required_vars = [set() for _ in range(len(program_list))] + for idx, program in enumerate(program_list): + for block in program.blocks: + for op in block.ops: + # NOTE(Ruibiao): Some vars may be the arguments of conditional_block op but no-need-buffer in the actual subblock, should not add them to the required_vars. 
+ if op.type == "conditional_block": continue - if not op_info.is_build: - op_info.build_info(op) - - if op_info.is_needed(in_name): - vars_list[ip].add(in_name) - - for out_name in op.output_arg_names: - if var_can_be_deleted(out_name, program): - vars_list[ip].add(out_name) + op_info = OpInOutInfo() + op_info.build_info(op) + for arg_name in op.input_arg_names + op.output_arg_names: + if var_can_be_deleted( + arg_name, block + ) and op_info.is_needed(arg_name): + required_vars[idx].add(arg_name) - # step2: get the `skip_gc_vars` that vars of current sub_program might be used in the later sub_program - union_set = set() + # step2: Get the `skip_gc_vars` that vars of current sub_program might be used in the later sub_program + suffixed_required_vars = set() skip_gc_vars = [set()] * len(program_list) - for idx, vars_set in reversed(list(enumerate(vars_list))): - if idx < len(vars_list) - 1: - union_set = union_set.union(vars_list[idx + 1]) - skip_gc_vars[idx] = vars_set & union_set + for idx, vars_set in reversed(list(enumerate(required_vars))): + if idx < len(required_vars) - 1: + suffixed_required_vars = suffixed_required_vars.union( + required_vars[idx + 1] + ) + skip_gc_vars[idx] = vars_set & suffixed_required_vars return skip_gc_vars @@ -279,7 +271,7 @@ def _create_param(dst_block, src_var): stop_gradient=src_var.stop_gradient, is_data=src_var.is_data, belong_to_optimizer=src_var.belong_to_optimizer, - **copied_kwargs + **copied_kwargs, ) -- GitLab