未验证 提交 28b8adb1 编写于 作者: R Ruibiao Chen 提交者: GitHub

Improve GC for pipeline parallel (#56022)

* Improve GC for pipeline parallel

* Delete print
上级 fbb3a34f
...@@ -45,9 +45,14 @@ establish the dependency between input and output tensors. ...@@ -45,9 +45,14 @@ establish the dependency between input and output tensors.
} }
}; };
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NopNoNeedBufferVarsInferer, "X", "Out");
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker); REGISTER_OP_WITHOUT_GRADIENT(nop,
ops::NopOp,
ops::NopOpMaker,
ops::NopNoNeedBufferVarsInferer);
...@@ -190,13 +190,12 @@ class OpInOutInfo: ...@@ -190,13 +190,12 @@ class OpInOutInfo:
return return
for slot_name in op.input_names: for slot_name in op.input_names:
if slot_name in self._no_need_buffer_slots: if slot_name not in self._no_need_buffer_slots:
continue
for in_name in op.input(slot_name): for in_name in op.input(slot_name):
self._other_arg_names_set.add(in_name) self._other_arg_names_set.add(in_name)
for slot_name in op.output_names: for slot_name in op.output_names:
if slot_name not in self._no_need_buffer_slots:
for out_name in op.output(slot_name): for out_name in op.output(slot_name):
self._other_arg_names_set.add(out_name) self._other_arg_names_set.add(out_name)
...@@ -209,16 +208,9 @@ class OpInOutInfo: ...@@ -209,16 +208,9 @@ class OpInOutInfo:
) )
def var_can_be_deleted(var_name, program): def var_can_be_deleted(var_name, block):
var = program.global_block()._find_var_recursive(var_name) var = block._find_var_recursive(var_name)
if var is None or var.persistable: return var is not None and not var.persistable
return False
return var.type in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]
def get_skip_gc_vars(program_list: List[Program]): def get_skip_gc_vars(program_list: List[Program]):
...@@ -231,31 +223,31 @@ def get_skip_gc_vars(program_list: List[Program]): ...@@ -231,31 +223,31 @@ def get_skip_gc_vars(program_list: List[Program]):
""" """
# step1: Get all vars of every sub_program of program_list that are non-persistable and not in op's no_need_buffer. # step1: Get all vars of every sub_program of program_list that are non-persistable and not in op's no_need_buffer.
vars_list = [set() for _ in range(len(program_list))] required_vars = [set() for _ in range(len(program_list))]
for ip, program in enumerate(program_list): for idx, program in enumerate(program_list):
for op in program.global_block().ops: for block in program.blocks:
op_info = OpInOutInfo() for op in block.ops:
for in_name in op.input_arg_names: # NOTE(Ruibiao): Some vars maybe be the arguements of conditional_block op but no-need-buffer in the actual subblock, should not add them to the required_vars.
if not var_can_be_deleted(in_name, program): if op.type == "conditional_block":
continue continue
if not op_info.is_build: op_info = OpInOutInfo()
op_info.build_info(op) op_info.build_info(op)
for arg_name in op.input_arg_names + op.output_arg_names:
if op_info.is_needed(in_name): if var_can_be_deleted(
vars_list[ip].add(in_name) arg_name, block
) and op_info.is_needed(arg_name):
for out_name in op.output_arg_names: required_vars[idx].add(arg_name)
if var_can_be_deleted(out_name, program):
vars_list[ip].add(out_name) # step2: Get the `skip_gc_vars` that vars of current sub_program might be used in the later sub_program
suffixed_required_vars = set()
# step2: get the `skip_gc_vars` that vars of current sub_program might be used in the later sub_program
union_set = set()
skip_gc_vars = [set()] * len(program_list) skip_gc_vars = [set()] * len(program_list)
for idx, vars_set in reversed(list(enumerate(vars_list))): for idx, vars_set in reversed(list(enumerate(required_vars))):
if idx < len(vars_list) - 1: if idx < len(required_vars) - 1:
union_set = union_set.union(vars_list[idx + 1]) suffixed_required_vars = suffixed_required_vars.union(
skip_gc_vars[idx] = vars_set & union_set required_vars[idx + 1]
)
skip_gc_vars[idx] = vars_set & suffixed_required_vars
return skip_gc_vars return skip_gc_vars
...@@ -279,7 +271,7 @@ def _create_param(dst_block, src_var): ...@@ -279,7 +271,7 @@ def _create_param(dst_block, src_var):
stop_gradient=src_var.stop_gradient, stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data, is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer, belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs **copied_kwargs,
) )
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册