未验证 提交 a36cdd6b 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] add dist_attr in data_parallel optimization (#49744)

* fix dist_attr in data_parallel optimization

* fix grad_clip pass when pp2

* fix dist_attr
上级 3c121040
...@@ -221,7 +221,14 @@ class DistributedOperator: ...@@ -221,7 +221,14 @@ class DistributedOperator:
) )
for arg_name in self.serial_op.desc.input_arg_names(): for arg_name in self.serial_op.desc.input_arg_names():
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name) try:
dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name)
except IndexError:
raise IndexError(
"There is not input var '{}''s dist_attr in current op '{}'".format(
arg_name, self.serial_op.desc.type()
)
)
if self.dist_attr.is_annotated_input_dims_mapping(arg_name): if self.dist_attr.is_annotated_input_dims_mapping(arg_name):
annotated_str = "annotated" annotated_str = "annotated"
else: else:
...@@ -238,7 +245,14 @@ class DistributedOperator: ...@@ -238,7 +245,14 @@ class DistributedOperator:
) )
for arg_name in self.serial_op.desc.output_arg_names(): for arg_name in self.serial_op.desc.output_arg_names():
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name) try:
dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name)
except IndexError:
raise IndexError(
"There is not output var '{}''s dist_attr in current op '{}'".format(
arg_name, self.serial_op.desc.type()
)
)
if self.dist_attr.is_annotated_output_dims_mapping(arg_name): if self.dist_attr.is_annotated_output_dims_mapping(arg_name):
annotated_str = "annotated" annotated_str = "annotated"
else: else:
......
...@@ -1426,9 +1426,6 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping( ...@@ -1426,9 +1426,6 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
def naive_set_dist_op_attr_for_program_by_mesh( def naive_set_dist_op_attr_for_program_by_mesh(
new_op, process_mesh, ctx, is_recompute=False new_op, process_mesh, ctx, is_recompute=False
): ):
# hack to skip coalesce var for dist attr
if not is_recompute:
return
assert process_mesh is not None assert process_mesh is not None
new_op_dist_attr = OperatorDistAttr() new_op_dist_attr = OperatorDistAttr()
...@@ -2314,15 +2311,31 @@ def insert_dependencies_for_vars( ...@@ -2314,15 +2311,31 @@ def insert_dependencies_for_vars(
}, },
outputs={"Out": post_vars}, outputs={"Out": post_vars},
) )
# depend_op.desc.set_type("depend")
depend_op._set_attr(OP_ROLE_KEY, oprole) depend_op._set_attr(OP_ROLE_KEY, oprole)
# depend_op.desc.set_input("Dep", [first_var.name])
# self.desc.set_output(out_proto.name, out_arg_names)
naive_set_dist_op_attr_for_program_by_mesh( # TODO: condition can be removed when add correct dist_attr for coalesce vars and ops in sharding_pass
depend_op, process_mesh, dist_context, is_recompute if is_recompute or process_mesh != [-1]:
) depend_op_dist_attr = OperatorDistAttr()
depend_op_dist_attr.impl_idx = 0
depend_op_dist_attr.impl_type = "default"
depend_op_dist_attr.process_mesh = process_mesh
depend_op_dist_attr.is_recompute = is_recompute
for input_varname in depend_op.desc.input_arg_names():
var = block.var(input_varname)
mapping = dist_context.get_tensor_dist_attr_for_program(
var
).dims_mapping
depend_op_dist_attr.set_input_dims_mapping(input_varname, mapping)
for output_varname in depend_op.desc.output_arg_names():
var = block.var(output_varname)
mapping = dist_context.get_tensor_dist_attr_for_program(
var
).dims_mapping
depend_op_dist_attr.set_output_dims_mapping(output_varname, mapping)
dist_context.set_op_dist_attr_for_program(
depend_op, depend_op_dist_attr
)
if op_namescope is not None: if op_namescope is not None:
depend_op._set_attr('op_namescope', "/{}".format(op_namescope)) depend_op._set_attr('op_namescope', "/{}".format(op_namescope))
......
...@@ -15,10 +15,15 @@ ...@@ -15,10 +15,15 @@
from collections import OrderedDict from collections import OrderedDict
import paddle import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistAttr,
TensorDistAttr,
)
from paddle.distributed.auto_parallel.operators.common import ( from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_reduce_op, is_data_parallel_reduce_op,
is_data_parallel_scale_op, is_data_parallel_scale_op,
) )
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import ( from paddle.distributed.auto_parallel.utils import (
find_higher_order_backward_op, find_higher_order_backward_op,
get_var_numel, get_var_numel,
...@@ -463,6 +468,21 @@ class DataParallelOptimizationPass(PassBase): ...@@ -463,6 +468,21 @@ class DataParallelOptimizationPass(PassBase):
group.coalesce_var = group.gradients[0] group.coalesce_var = group.gradients[0]
continue continue
ref_process_mesh = set()
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
grad_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(grad_)
)
ref_process_mesh.update(
set(grad_dist_attr.process_mesh.process_ids)
)
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
# create coalesce tensor # create coalesce tensor
group.coalesce_var = block.create_var( group.coalesce_var = block.create_var(
name=unique_name.generate( name=unique_name.generate(
...@@ -473,6 +493,13 @@ class DataParallelOptimizationPass(PassBase): ...@@ -473,6 +493,13 @@ class DataParallelOptimizationPass(PassBase):
stop_gradient=True, stop_gradient=True,
) )
tensor_dist_attr = TensorDistAttr()
tensor_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
tensor_dist_attr.dims_mapping = []
self.dist_context.set_tensor_dist_attr_for_program(
group.coalesce_var, tensor_dist_attr
)
# update allreduce & scale op # update allreduce & scale op
if group.scale_op_idx != -1: if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx] scale_op = block.ops[group.scale_op_idx]
...@@ -492,11 +519,27 @@ class DataParallelOptimizationPass(PassBase): ...@@ -492,11 +519,27 @@ class DataParallelOptimizationPass(PassBase):
), "should found c_allreduce_sum op but found {}".format( ), "should found c_allreduce_sum op but found {}".format(
str(allreduce_op) str(allreduce_op)
) )
allreduce_op._rename_input( allreduce_op_dist_attr = (
allreduce_op.input_arg_names[0], group.coalesce_var.name self.dist_context.get_op_dist_attr_for_program(allreduce_op)
)
old_in_name = allreduce_op.input_arg_names[0]
new_in_name = group.coalesce_var.name
allreduce_op._rename_input(old_in_name, new_in_name)
input_dist_attr = allreduce_op_dist_attr.get_input_dist_attr(
old_in_name
) )
allreduce_op._rename_output( allreduce_op_dist_attr.set_input_dist_attr(
allreduce_op.output_arg_names[0], group.coalesce_var.name new_in_name, input_dist_attr
)
old_out_name = allreduce_op.output_arg_names[0]
new_out_name = group.coalesce_var.name
allreduce_op._rename_output(old_out_name, new_out_name)
out_dist_attr = allreduce_op_dist_attr.get_output_dist_attr(
old_out_name
)
allreduce_op_dist_attr.set_output_dist_attr(
new_out_name, out_dist_attr
) )
# remove unused op # remove unused op
...@@ -512,15 +555,8 @@ class DataParallelOptimizationPass(PassBase): ...@@ -512,15 +555,8 @@ class DataParallelOptimizationPass(PassBase):
block._remove_op(idx, False) block._remove_op(idx, False)
# insert coalesce op # insert coalesce op
concated_shapes = []
concated_ranks = []
for grad_ in group.gradients:
shape = grad_.shape
concated_shapes.extend(shape)
concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients] grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync( coalesce_op = block._insert_op_without_sync(
group.coalesce_op_idx, group.coalesce_op_idx,
type="coalesce_tensor", type="coalesce_tensor",
inputs={"Input": grad_names}, inputs={"Input": grad_names},
...@@ -538,8 +574,32 @@ class DataParallelOptimizationPass(PassBase): ...@@ -538,8 +574,32 @@ class DataParallelOptimizationPass(PassBase):
}, },
) )
op_dist_attr = OperatorDistAttr()
op_dist_attr.impl_idx = 0
op_dist_attr.impl_type = "default"
op_dist_attr.process_mesh = ProcessMesh(list(ref_process_mesh))
for in_name in coalesce_op.input_arg_names:
in_var = block.var(in_name)
in_var_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(in_var)
)
op_dist_attr.set_input_dims_mapping(
in_name, in_var_dist_attr.dims_mapping
)
for out_name in coalesce_op.output_arg_names:
out_var = block.var(out_name)
out_var_dist_attr = (
self.dist_context.get_tensor_dist_attr_for_program(out_var)
)
op_dist_attr.set_output_dims_mapping(
out_name, out_var_dist_attr.dims_mapping
)
self.dist_context.set_op_dist_attr_for_program(
coalesce_op, op_dist_attr
)
block._sync_with_cpp() block._sync_with_cpp()
# TODO update dist attr
def _add_dependencies(self, grad_groups): def _add_dependencies(self, grad_groups):
# NOTE Currently, auto_parallel need to adopt for two executors: Sequential executor (old exe) and Graph based # NOTE Currently, auto_parallel need to adopt for two executors: Sequential executor (old exe) and Graph based
...@@ -551,22 +611,12 @@ class DataParallelOptimizationPass(PassBase): ...@@ -551,22 +611,12 @@ class DataParallelOptimizationPass(PassBase):
block = default_main_program().global_block() block = default_main_program().global_block()
# Build maps # Build maps
vars_to_coalesce_map = {}
coalesce_to_vars_map = {} coalesce_to_vars_map = {}
for group in grad_groups: for group in grad_groups:
grad_names = [] coalesce_to_vars_map[group.coalesce_var.name] = group
coalesce_name = group.coalesce_var.name
for grad in group.gradients:
vars_to_coalesce_map[grad.name] = coalesce_name
grad_names.append(grad.name)
coalesce_to_vars_map[coalesce_name] = grad_names
# analyze dependencies # analyze dependencies
# Record ONLY the last grad that generated before allreduce dep_map = {}
# NOTE need to be update when we allow multiple calc stream for backward calc
not_sync_coalesces = []
prior_allreduce_deps = {}
for idx, op in reversed(list(enumerate(block.ops))): for idx, op in reversed(list(enumerate(block.ops))):
if is_forward_op(op): if is_forward_op(op):
break break
...@@ -575,86 +625,41 @@ class DataParallelOptimizationPass(PassBase): ...@@ -575,86 +625,41 @@ class DataParallelOptimizationPass(PassBase):
if is_data_parallel_reduce_op(op): if is_data_parallel_reduce_op(op):
coalesce_var_name = op.output_arg_names[0] coalesce_var_name = op.output_arg_names[0]
# NOTE only add extra deps for fused tensor, other tensor rely on
# data flow analysis of executor.
if self.coalesce_prefix in coalesce_var_name:
prior_allreduce_deps[coalesce_var_name] = [
idx,
None,
coalesce_var_name,
]
not_sync_coalesces.append(coalesce_var_name)
continue
for out_name in op.output_arg_names:
var_name = vars_to_coalesce_map.get(out_name, None)
if var_name in not_sync_coalesces:
prior_allreduce_deps[var_name][1] = out_name
not_sync_coalesces.remove(var_name)
assert (
len(not_sync_coalesces) == 0
), "Unexpected: {} has NOT been add prior Dep before allreduce.".format(
not_sync_coalesces
)
# Record ONLY the first grad that used after allreduce
# NOTE need to be update when we allow multiple calc stream for backward calc
not_sync_coalesces = []
post_allreduce_deps = {}
for idx, op in enumerate(block.ops):
if is_forward_op(op):
continue
if is_data_parallel_reduce_op(op):
coalesce_var_name = op.input_arg_names[0]
if self.coalesce_prefix in coalesce_var_name: if self.coalesce_prefix in coalesce_var_name:
post_allreduce_deps[coalesce_var_name] = [ group = coalesce_to_vars_map[coalesce_var_name]
None, dep_map[idx] = [
coalesce_var_name, (
None, idx,
group.gradients[-1],
group.coalesce_var,
op.attr(OP_ROLE_KEY),
)
] ]
not_sync_coalesces.append(coalesce_var_name) dep_map[idx].append(
continue (
idx + 1,
for out_name in op.input_arg_names: group.coalesce_var,
var_name = vars_to_coalesce_map.get(out_name, None) group.gradients,
if var_name in not_sync_coalesces: op.attr(OP_ROLE_KEY),
post_allreduce_deps[var_name][0] = idx )
post_allreduce_deps[var_name][2] = out_name )
not_sync_coalesces.remove(var_name)
assert (
len(not_sync_coalesces) == 0
), "Unexpected: {} has NOT been add post Dep after allreduce.".format(
not_sync_coalesces
)
# Update program IR insert dependencise op # insert dependency op
dep_var_pairs = [] indice = sorted(list(dep_map.keys()), reverse=True)
for deps in [prior_allreduce_deps, post_allreduce_deps]: for i in indice:
for pair in deps.values(): for idx, prior_vars, post_vars, op_role in dep_map[i][::-1]:
dep_var_pairs.append(pair) depend_op = insert_dependencies_for_vars(
block,
dep_var_pairs.sort(key=lambda x: x[0], reverse=True) idx,
for idx, prior_name, post_name in dep_var_pairs: prior_vars,
prior_var = block.var(prior_name) post_vars,
post_var = block.var(post_name) self.dist_context,
depend_op = insert_dependencies_for_vars( op_role,
block, is_recompute=False,
idx, sync=False,
prior_var, op_namescope="data_parallel_overlap_dep",
post_var, )
self.dist_context, depend_op.dist_attr.execution_stream = self.gradient_sync_stream
OpRole.Backward,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesce var
is_recompute=False,
sync=False,
op_namescope="data_parallel_overlap_dep",
)
depend_op.dist_attr.execution_stream = self.gradient_sync_stream
block._sync_with_cpp() block._sync_with_cpp()
# remove naive synchronization & assign allreduce stream # remove naive synchronization & assign allreduce stream
......
...@@ -254,6 +254,8 @@ class ClipHelper: ...@@ -254,6 +254,8 @@ class ClipHelper:
"c_allreduce_sum", "c_allreduce_sum",
] and not is_data_parallel_reduce_op(op): ] and not is_data_parallel_reduce_op(op):
return False return False
if op.type in ["send_v2", "recv_v2"]:
return False
return True return True
......
...@@ -150,9 +150,6 @@ class AutoParalSupplementDepPass(PassBase): ...@@ -150,9 +150,6 @@ class AutoParalSupplementDepPass(PassBase):
post_var, post_var,
self._dist_context, self._dist_context,
OpRole.Optimize, OpRole.Optimize,
process_mesh=[
-1
], # hack to avoid initialize the dist attr for coalesc var
is_recompute=False, is_recompute=False,
sync=False, sync=False,
op_namescope=op_namescope, op_namescope=op_namescope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册