未验证 提交 161998f7 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] Recompute Support New Graph Executor (#47846)

* add depend

* fp16 pass distinguish None & False

* engine log
上级 ae256454
...@@ -1407,6 +1407,27 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( ...@@ -1407,6 +1407,27 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def naive_set_dist_op_attr_for_program_by_mesh(
new_op, process_mesh, ctx, is_recompute=False
):
assert process_mesh is not None
new_op_dist_attr = OperatorDistributedAttribute()
for input_varname in new_op.desc.input_arg_names():
var = ctx.serial_main_program.global_block().var(input_varname)
mapping = ctx.get_tensor_dist_attr_for_program(var).dims_mapping
new_op_dist_attr.set_input_dims_mapping(input_varname, mapping)
for output_varname in new_op.desc.output_arg_names():
var = ctx.serial_main_program.global_block().var(output_varname)
mapping = ctx.get_tensor_dist_attr_for_program(var).dims_mapping
new_op_dist_attr.set_output_dims_mapping(output_varname, mapping)
new_op_dist_attr.process_mesh = process_mesh
new_op_dist_attr.is_recompute = is_recompute
ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr)
def update_op_dims_mapping_by_default_dist_impl(dist_op): def update_op_dims_mapping_by_default_dist_impl(dist_op):
changed = False changed = False
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
...@@ -2102,3 +2123,73 @@ def _copy_dist_attr_from_cpp_for_graph(dist_context): ...@@ -2102,3 +2123,73 @@ def _copy_dist_attr_from_cpp_for_graph(dist_context):
py_dist_attr = dist_context.get_op_dist_attr_for_graph(node) py_dist_attr = dist_context.get_op_dist_attr_for_graph(node)
cpp_dist_attr = node.op().dist_attr cpp_dist_attr = node.op().dist_attr
_copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr) _copy_op_dist_attr_from_cpp(cpp_dist_attr, py_dist_attr)
def insert_dependencies_for_two_ops(
block,
idx,
prior_op,
posterior,
dist_context,
is_recompute=False,
sync=False,
):
"""
dependency: prior_op should be run before posterior
"""
assert (
len(prior_op.output_arg_names) >= 1
), "first op of dependency should at least have one output. [{}]".format(
str(prior_op)
)
assert (
len(posterior.input_arg_names) >= 1
), "second op of dependency should at least have one input. [{}]".format(
str(posterior)
)
prior_op_mesh = dist_context.get_op_dist_attr_for_program(
prior_op
).process_mesh
posterior_mesh = dist_context.get_op_dist_attr_for_program(
posterior
).process_mesh
assert (
prior_op_mesh == posterior_mesh
), "two ops of dependency should have same mesh but got [{}] and [{}]".format(
str(prior_op_mesh), str(posterior_mesh)
)
def _select_best_depend_var(vars):
vars_with_numels = [(var, get_var_numel(var)) for var in vars]
vars_with_numels.sort(key=lambda x: x[1])
return vars_with_numels[-1][0]
first_var = _select_best_depend_var(
[block.var(name) for name in prior_op.output_arg_names]
)
second_var = _select_best_depend_var(
[block.var(name) for name in posterior.input_arg_names]
)
depend_op = block._insert_op_without_sync(
idx,
type='nop',
inputs={
"X": first_var,
},
outputs={"Out": second_var},
)
# depend_op.desc.set_type("depend")
depend_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
# depend_op.desc.set_input("Dep", [first_var.name])
# self.desc.set_output(out_proto.name, out_arg_names)
naive_set_dist_op_attr_for_program_by_mesh(
depend_op, prior_op_mesh, dist_context, is_recompute
)
if sync:
block._sync_with_cpp()
...@@ -29,6 +29,7 @@ from paddle.distributed.auto_parallel.utils import ( ...@@ -29,6 +29,7 @@ from paddle.distributed.auto_parallel.utils import (
) )
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping, naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
insert_dependencies_for_two_ops,
) )
...@@ -449,6 +450,7 @@ class RecomputePass(PassBase): ...@@ -449,6 +450,7 @@ class RecomputePass(PassBase):
while idx - 1 >= 0 and ops[idx - 1].type == "sum": while idx - 1 >= 0 and ops[idx - 1].type == "sum":
idx -= 1 idx -= 1
segment_descs = ckpt_ops_dict[fwd_op_id][1] segment_descs = ckpt_ops_dict[fwd_op_id][1]
rc_op = None
for _, op_desc in reversed(list(enumerate(segment_descs))): for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_op = main_block._insert_op_without_sync( rc_op = main_block._insert_op_without_sync(
idx, type='nop' idx, type='nop'
...@@ -466,7 +468,15 @@ class RecomputePass(PassBase): ...@@ -466,7 +468,15 @@ class RecomputePass(PassBase):
) )
ckpt_ops_dict[fwd_op_id][0] = False ckpt_ops_dict[fwd_op_id][0] = False
if rc_op:
insert_dependencies_for_two_ops(
main_block,
idx,
main_block.ops[rc_op.idx - 1],
rc_op,
self._dist_context,
sync=False,
)
main_program._sync_with_cpp() main_program._sync_with_cpp()
def reset_op_dist_attr(self, op, var_name_dict): def reset_op_dist_attr(self, op, var_name_dict):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册