未验证 提交 28b8adb1 编写于 作者: R Ruibiao Chen 提交者: GitHub

Improve GC for pipeline parallel (#56022)

* Improve GC for pipeline parallel

* Delete print
上级 fbb3a34f
......@@ -45,9 +45,14 @@ establish the dependency between input and output tensors.
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERER(NopNoNeedBufferVarsInferer, "X", "Out");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker);
REGISTER_OP_WITHOUT_GRADIENT(nop,
ops::NopOp,
ops::NopOpMaker,
ops::NopNoNeedBufferVarsInferer);
......@@ -190,15 +190,14 @@ class OpInOutInfo:
return
for slot_name in op.input_names:
if slot_name in self._no_need_buffer_slots:
continue
for in_name in op.input(slot_name):
self._other_arg_names_set.add(in_name)
if slot_name not in self._no_need_buffer_slots:
for in_name in op.input(slot_name):
self._other_arg_names_set.add(in_name)
for slot_name in op.output_names:
for out_name in op.output(slot_name):
self._other_arg_names_set.add(out_name)
if slot_name not in self._no_need_buffer_slots:
for out_name in op.output(slot_name):
self._other_arg_names_set.add(out_name)
self._is_build = True
......@@ -209,16 +208,9 @@ class OpInOutInfo:
)
def var_can_be_deleted(var_name, program):
var = program.global_block()._find_var_recursive(var_name)
if var is None or var.persistable:
return False
return var.type in [
core.VarDesc.VarType.LOD_TENSOR,
core.VarDesc.VarType.SELECTED_ROWS,
core.VarDesc.VarType.LOD_TENSOR_ARRAY,
]
def var_can_be_deleted(var_name, block):
var = block._find_var_recursive(var_name)
return var is not None and not var.persistable
def get_skip_gc_vars(program_list: List[Program]):
......@@ -231,31 +223,31 @@ def get_skip_gc_vars(program_list: List[Program]):
"""
# step1: Get all vars of every sub_program of program_list that are non-persistable and not in op's no_need_buffer.
vars_list = [set() for _ in range(len(program_list))]
for ip, program in enumerate(program_list):
for op in program.global_block().ops:
op_info = OpInOutInfo()
for in_name in op.input_arg_names:
if not var_can_be_deleted(in_name, program):
required_vars = [set() for _ in range(len(program_list))]
for idx, program in enumerate(program_list):
for block in program.blocks:
for op in block.ops:
# NOTE(Ruibiao): Some vars maybe be the arguements of conditional_block op but no-need-buffer in the actual subblock, should not add them to the required_vars.
if op.type == "conditional_block":
continue
if not op_info.is_build:
op_info.build_info(op)
if op_info.is_needed(in_name):
vars_list[ip].add(in_name)
for out_name in op.output_arg_names:
if var_can_be_deleted(out_name, program):
vars_list[ip].add(out_name)
op_info = OpInOutInfo()
op_info.build_info(op)
for arg_name in op.input_arg_names + op.output_arg_names:
if var_can_be_deleted(
arg_name, block
) and op_info.is_needed(arg_name):
required_vars[idx].add(arg_name)
# step2: get the `skip_gc_vars` that vars of current sub_program might be used in the later sub_program
union_set = set()
# step2: Get the `skip_gc_vars` that vars of current sub_program might be used in the later sub_program
suffixed_required_vars = set()
skip_gc_vars = [set()] * len(program_list)
for idx, vars_set in reversed(list(enumerate(vars_list))):
if idx < len(vars_list) - 1:
union_set = union_set.union(vars_list[idx + 1])
skip_gc_vars[idx] = vars_set & union_set
for idx, vars_set in reversed(list(enumerate(required_vars))):
if idx < len(required_vars) - 1:
suffixed_required_vars = suffixed_required_vars.union(
required_vars[idx + 1]
)
skip_gc_vars[idx] = vars_set & suffixed_required_vars
return skip_gc_vars
......@@ -279,7 +271,7 @@ def _create_param(dst_block, src_var):
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs
**copied_kwargs,
)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册