Unverified commit 23e5b25c, authored by JZ-LIANG, committed by GitHub

[Auto Parallel Performance] Optimizing data parallel Fuse-Allreduce-Overlapping (#48092)

* add depend

* add origin amp files

* fp16 distinguish None & False

* engine log

* dp add deps for graph exe

* add dep for grad clip

* dep ops in comm stream

* unittest
Parent: 0707c0af
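The core idea of fused-allreduce overlapping: instead of launching one small allreduce per gradient, the pass coalesces each group of gradients into one contiguous buffer, launches a single allreduce per buffer on a separate communication stream, and overlaps it with the remaining backward computation. A minimal sketch of the coalesce/split arithmetic with a simulated two-rank allreduce (numpy only; `coalesce`, `split_back`, and `simulated_allreduce` are illustrative helpers, not Paddle APIs):

```python
import numpy as np

def coalesce(grads):
    """Flatten and concatenate per-parameter gradients into one fused buffer."""
    shapes = [g.shape for g in grads]
    buf = np.concatenate([g.ravel() for g in grads])
    return buf, shapes

def split_back(buf, shapes):
    """Recover per-parameter tensors from the fused buffer."""
    out, offset = [], 0
    for s in shapes:
        n = int(np.prod(s))
        out.append(buf[offset:offset + n].reshape(s))
        offset += n
    return out

def simulated_allreduce(bufs):
    """Stand-in for ring allreduce: sum the fused buffer across ranks."""
    total = np.sum(bufs, axis=0)
    return [total.copy() for _ in bufs]

# two ranks, three small gradients each: one allreduce instead of three
rank_grads = [
    [np.ones((2, 2)) * (r + 1), np.ones(3) * (r + 1), np.ones((1, 4)) * (r + 1)]
    for r in range(2)
]
fused = [coalesce(gs) for gs in rank_grads]
reduced = simulated_allreduce([buf for buf, _ in fused])
for buf, (_, shapes) in zip(reduced, fused):
    print([g.tolist() for g in split_back(buf, shapes)])
```

The real pass uses the `coalesce_tensor` op, which also re-exposes the original gradients as views into the fused buffer; that aliasing is exactly what hides the ordering from a dataflow-driven executor and motivates the explicit `nop` dependency ops added in this commit.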
@@ -1410,6 +1410,9 @@ def naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
 def naive_set_dist_op_attr_for_program_by_mesh(
     new_op, process_mesh, ctx, is_recompute=False
 ):
+    # hack to skip coalesce var for dist attr
+    if not is_recompute:
+        return
     assert process_mesh is not None
     new_op_dist_attr = OperatorDistributedAttribute()
@@ -2129,13 +2132,13 @@ def insert_dependencies_for_two_ops(
     block,
     idx,
     prior_op,
-    posterior,
+    posterior_op,
     dist_context,
     is_recompute=False,
     sync=False,
 ):
     """
-    dependency: prior_op should be run before posterior
+    dependency: prior_op should be run before posterior_op
    """
    assert (
@@ -2144,15 +2147,15 @@ def insert_dependencies_for_two_ops(
        str(prior_op)
    )
    assert (
-        len(posterior.input_arg_names) >= 1
+        len(posterior_op.input_arg_names) >= 1
    ), "second op of dependency should at least have one input. [{}]".format(
-        str(posterior)
+        str(posterior_op)
    )
    prior_op_mesh = dist_context.get_op_dist_attr_for_program(
        prior_op
    ).process_mesh
    posterior_mesh = dist_context.get_op_dist_attr_for_program(
-        posterior
+        posterior_op
    ).process_mesh
    assert (
        prior_op_mesh == posterior_mesh
@@ -2171,25 +2174,72 @@ def insert_dependencies_for_two_ops(
        [block.var(name) for name in prior_op.output_arg_names]
    )
    second_var = _select_best_depend_var(
-        [block.var(name) for name in posterior.input_arg_names]
+        [block.var(name) for name in posterior_op.input_arg_names]
    )
+
+    return insert_dependencies_for_two_vars(
+        block,
+        idx,
+        first_var,
+        second_var,
+        dist_context,
+        OpRole.Backward,
+        prior_op_mesh,
+        is_recompute,
+        sync,
+    )
+
+
+def insert_dependencies_for_two_vars(
+    block,
+    idx,
+    prior_var,
+    post_var,
+    dist_context,
+    oprole,
+    process_mesh=None,
+    is_recompute=False,
+    sync=False,
+):
+    """
+    dependency: op that generates prior_var should be run before op that generates post_var
+    """
+    assert block.has_var(prior_var.name)
+    assert block.has_var(post_var.name)
+    if process_mesh is None:
+        process_mesh = dist_context.get_tensor_dist_attr_for_program(
+            post_var
+        ).process_mesh
+    assert process_mesh is not None
+
    depend_op = block._insert_op_without_sync(
        idx,
        type='nop',
        inputs={
-            "X": first_var,
+            "X": prior_var,
        },
-        outputs={"Out": second_var},
+        outputs={"Out": post_var},
    )
    # depend_op.desc.set_type("depend")
-    depend_op._set_attr(OP_ROLE_KEY, OpRole.Backward)
+    depend_op._set_attr(OP_ROLE_KEY, oprole)
    # depend_op.desc.set_input("Dep", [first_var.name])
    # self.desc.set_output(out_proto.name, out_arg_names)
    naive_set_dist_op_attr_for_program_by_mesh(
-        depend_op, prior_op_mesh, dist_context, is_recompute
+        depend_op, process_mesh, dist_context, is_recompute
    )
    if sync:
        block._sync_with_cpp()
+
+    return depend_op
+
+
+def use_standalone_executor():
+    return os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [
+        1,
+        '1',
+        True,
+        'True',
+        'true',
+    ]
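The hunk above carries the key mechanism: `insert_dependencies_for_two_vars` inserts a `nop` op whose input is `prior_var` and whose output is `post_var`. The standalone executor schedules by data flow, so the `nop` becomes the new last writer of `post_var`, forcing every later reader of `post_var` to wait until the producer of `prior_var` has run. A toy model of that last-writer dependency analysis (plain Python; the op tuples and names are hypothetical):

```python
from collections import defaultdict

def build_deps(ops):
    """Each op reading a var depends on that var's most recent writer."""
    last_writer, deps = {}, defaultdict(set)
    for name, ins, outs in ops:
        for v in ins:
            if v in last_writer:
                deps[name].add(last_writer[v])
        for v in outs:
            last_writer[v] = name
    return dict(deps)

# 'g' silently becomes a view of 'fused' after coalescing, so the executor
# sees no edge from the allreduce (writes 'fused') to the step (reads 'g').
ops = [
    ("backward",  [],        ["g"]),
    ("coalesce",  ["g"],     ["fused"]),
    ("allreduce", ["fused"], ["fused"]),   # issued on the comm stream
    ("step",      ["g"],     ["param"]),   # may race ahead of the allreduce
]
print(build_deps(ops))   # step depends only on backward: a hazard

# the nop trick: reads 'fused', re-declares 'g' as its output
ops.insert(3, ("dep_nop", ["fused"], ["g"]))
print(build_deps(ops))   # step -> dep_nop -> allreduce: race closed
```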
@@ -27,8 +27,11 @@ from paddle.distributed.auto_parallel.utils import (
     find_higher_order_backward_op,
     is_loss_grad_op,
     is_optimize_op,
+    is_forward_op,
     ring_id_to_process_group,
     get_var_numel,
+    use_standalone_executor,
+    insert_dependencies_for_two_vars,
 )
 
 # add new optimizers supporting rescale_grad here
@@ -87,16 +90,20 @@ class DataParallelOptimizationPass(PassBase):
         self.dist_context = self.get_attr("dist_context")
         self.global_rank = int(self.get_attr("global_rank"))
         self.use_sharding = self.get_attr("use_sharding")
+        self.coalesce_prefix = 'coalesce_grad'
+        if use_standalone_executor():
+            self.gradient_sync_stream = "gradient_sync_stream"
 
         with paddle.static.program_guard(main_program, startup_program):
             self._analyze_program()
+
+            # TODO refactor here to first fuse then overlap
             if self.is_data_parallel_applied():
                 self._prune_grad_scaling()
                 self._calc_comm_overlap()
                 grad_group = self._fuse_allreduce()
-                # self.summary(grad_group)
+                self._add_dependencies(grad_group)
+                self.summary(grad_group)
 
     def _prune_grad_scaling(self):
@@ -284,7 +291,6 @@ class DataParallelOptimizationPass(PassBase):
         # InterpreterCore has a different logic for overlapping
         # which is different from use_calc_stream
         block = default_main_program().global_block()
-        ops = block.ops
 
         # comm wait calc to finish
         for idx, op in reversed(list(enumerate(block.ops))):
@@ -294,7 +300,6 @@
                 op._set_attr('use_calc_stream', False)
                 ring_id = op.attr("ring_id")
-
                 block._insert_op_without_sync(
                     idx,
                     type='c_wait_compute',
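For contrast with the new scheme: under the legacy executor, overlap is expressed with explicit stream-sync ops. `c_wait_compute` (inserted above) makes the communication stream wait until the compute stream has produced the gradients, and `c_wait_comm` (in `_calc_wait_comms` below) makes the compute stream wait for outstanding allreduces before the optimizer runs. A toy two-stream timeline in plain Python, as a rough analogy to CUDA event waits rather than Paddle's actual implementation:

```python
class Stream:
    def __init__(self, name):
        self.name, self.timeline, self.clock = name, [], 0

    def run(self, op, dur):
        self.timeline.append((self.clock, self.clock + dur, op))
        self.clock += dur

    def wait(self, other):
        # like cudaStreamWaitEvent: may not proceed past other's current point
        self.clock = max(self.clock, other.clock)

calc, comm = Stream("calc"), Stream("comm")
calc.run("backward_layer2", 2)
comm.wait(calc)                 # c_wait_compute: grads ready before allreduce
comm.run("allreduce_grads2", 3)
calc.run("backward_layer1", 2)  # overlaps with the allreduce above
comm.wait(calc)
comm.run("allreduce_grads1", 3)
calc.wait(comm)                 # c_wait_comm: optimizer waits for all syncs
calc.run("optimizer_step", 1)
for s in (calc, comm):
    print(s.name, s.timeline)
```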
@@ -307,8 +312,10 @@ class DataParallelOptimizationPass(PassBase):
     def _calc_wait_comms(self):
+
+        if use_standalone_executor():
+            return
 
         block = default_main_program().global_block()
-        ops = block.ops
 
         # NOTE the naive overlap implement in static hybird parallel only sync comm stream
         # at the end of Backward phase, based on a strong constraint that
@@ -325,7 +332,7 @@
             ring_id_to_un_sync_grad_map[group.id] = []
 
         # analyze the where need to sync
-        for i, op in enumerate(ops):
+        for i, op in enumerate(block.ops):
             if is_data_parallel_reduce_op(op):
                 ring_id = op.attr("ring_id")
                 grad_name = op.output_arg_names[0]
@@ -365,6 +372,7 @@
                 outputs={'Out': []},
                 attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
             )
+            block._sync_with_cpp()
 
     def _could_be_fuse(self):
         # TODO support gradient fuse higher order gradient.
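A note on the `block._sync_with_cpp()` added above: in Paddle's static graph, the Python `Block` mirrors a C++ `BlockDesc`, and the `_without_sync` / `sync=False` variants used throughout this PR mutate only the C++ side. The idiomatic pattern, visible in several hunks here, is to batch the edits and resync once (a sketch of the pattern; `should_remove` is a hypothetical predicate):

```python
# batch-edit pattern: mutate the C++ BlockDesc without resyncing per edit,
# then reconcile the Python-side op list once at the end
for idx, op in reversed(list(enumerate(block.ops))):
    if should_remove(op):                 # hypothetical predicate
        block._remove_op(idx, sync=False)
block._sync_with_cpp()                    # single resync after all edits
```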
@@ -404,8 +412,6 @@ class DataParallelOptimizationPass(PassBase):
         def collect_group(cur_group, grad_var, ring_id, i):
             if len(cur_group.gradients) == 0:
                 cur_group = None
-            elif len(cur_group.gradients) == 1:
-                grouped_grad_names.remove(cur_group.gradients[0].name)
             else:
                 cur_group.finalize()
                 grad_groups.append(cur_group)
@@ -451,9 +457,16 @@
 
         for i, group in enumerate(grad_groups[::-1]):
+
+            # skip unfused big tensor
+            if len(group.gradients) <= 1:
+                group.coalesce_var = group.gradients[0]
+                continue
+
             # create coalecse tensor
             group.coalesce_var = block.create_var(
-                name=unique_name.generate('coalecse_grad_{}'.format(i)),
+                name=unique_name.generate(
+                    self.coalesce_prefix + '_{}'.format(i)
+                ),
                 dtype=group.dtype,
                 persistable=False,
                 stop_gradient=True,
@@ -497,7 +510,7 @@ class DataParallelOptimizationPass(PassBase):
             ), "Unexception: try to remove op {}".format(
                 str(block.ops[idx])
             )
-            block._remove_op(idx)
+            block._remove_op(idx, False)
 
         # insert coalecse op
         concated_shapes = []
@@ -529,6 +542,141 @@
         block._sync_with_cpp()
         # TODO update dist attr
+    def _add_dependencies(self, grad_groups):
+        # NOTE Currently, auto_parallel need to adopt for two executors: Sequential executor (old exe) and Graph based
+        # multiple stream executor(standalone exe). This function just for standalone exe. Refactor here
+        # in future when only one executor stay.
+
+        if not use_standalone_executor() or len(grad_groups) == 0:
+            return
+
+        block = default_main_program().global_block()
+
+        # Build maps
+        vars_to_coalesce_map = {}
+        coalesce_to_vars_map = {}
+        for group in grad_groups:
+            grad_names = []
+            coalesce_name = group.coalesce_var.name
+            for grad in group.gradients:
+                vars_to_coalesce_map[grad.name] = coalesce_name
+                grad_names.append(grad.name)
+            coalesce_to_vars_map[coalesce_name] = grad_names
+
+        # analyze dependencies
+        # Record ONLY the last grad that generated before allreduce
+        # NOTE need to be update when we allow multiple calc stream for backward calc
+        not_sync_coalesces = []
+        prior_allreduce_deps = {}
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if is_forward_op(op):
+                break
+            if is_optimize_op(op):
+                continue
+            if is_data_parallel_reduce_op(op):
+                coalesce_var_name = op.output_arg_names[0]
+
+                # NOTE only add extra deps for fused tensor, other tensor rely on
+                # data flow analysis of executor.
+                if self.coalesce_prefix in coalesce_var_name:
+                    prior_allreduce_deps[coalesce_var_name] = [
+                        idx,
+                        None,
+                        coalesce_var_name,
+                    ]
+                    not_sync_coalesces.append(coalesce_var_name)
+                continue
+
+            for out_name in op.output_arg_names:
+                var_name = vars_to_coalesce_map.get(out_name, None)
+                if var_name in not_sync_coalesces:
+                    prior_allreduce_deps[var_name][1] = out_name
+                    not_sync_coalesces.remove(var_name)
+        assert (
+            len(not_sync_coalesces) == 0
+        ), "Unexception: {} has NOT been add prior Dep before allreduce.".format(
+            not_sync_coalesces
+        )
+
+        # Record ONLY the first grad that used after allreduce
+        # NOTE need to be update when we allow multiple calc stream for backward calc
+        not_sync_coalesces = []
+        post_allreduce_deps = {}
+        for idx, op in enumerate(block.ops):
+            if is_forward_op(op):
+                continue
+            if is_data_parallel_reduce_op(op):
+                coalesce_var_name = op.input_arg_names[0]
+                if self.coalesce_prefix in coalesce_var_name:
+                    post_allreduce_deps[coalesce_var_name] = [
+                        None,
+                        coalesce_var_name,
+                        None,
+                    ]
+                    not_sync_coalesces.append(coalesce_var_name)
+                continue
+
+            for out_name in op.input_arg_names:
+                var_name = vars_to_coalesce_map.get(out_name, None)
+                if var_name in not_sync_coalesces:
+                    post_allreduce_deps[var_name][0] = idx
+                    post_allreduce_deps[var_name][2] = out_name
+                    not_sync_coalesces.remove(var_name)
+        assert (
+            len(not_sync_coalesces) == 0
+        ), "Unexception: {} has NOT been add post Dep after allreduce.".format(
+            not_sync_coalesces
+        )
+
+        # Update program IR insert dependencise op
+        dep_var_pairs = []
+        for deps in [prior_allreduce_deps, post_allreduce_deps]:
+            for pair in deps.values():
+                dep_var_pairs.append(pair)
+        dep_var_pairs.sort(key=lambda x: x[0], reverse=True)
+
+        for idx, prior_name, post_name in dep_var_pairs:
+            prior_var = block.var(prior_name)
+            post_var = block.var(post_name)
+            depend_op = insert_dependencies_for_two_vars(
+                block,
+                idx,
+                prior_var,
+                post_var,
+                self.dist_context,
+                OpRole.Backward,
+                process_mesh=[
+                    -1
+                ],  # hack to avoid initialize the dist attr for coalesc var
+                is_recompute=False,
+                sync=False,
+            )
+            depend_op.dist_attr.execution_stream = self.gradient_sync_stream
+        block._sync_with_cpp()
+
+        # remove naive synchronization & assign allreduce stream
+        def remove_cond(op):
+            if op.type != "c_wait_compute":
+                return False
+            if len(op.input_arg_names) != 0:
+                return False
+            if len(op.output_arg_names) != 0:
+                return False
+            return True
+
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if is_data_parallel_reduce_op(op):
+                op._set_attr('use_calc_stream', True)
+                op.dist_attr.execution_stream = self.gradient_sync_stream
+
+            if remove_cond(op):
+                block._remove_op(idx, sync=False)
+
+        block._sync_with_cpp()
+
     def summary(self, grad_groups=[]):
         # TODO: add logger module
         import logging
...
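The `_add_dependencies` hunk above does two linear scans over the block. A reverse scan records, for each fused allreduce, its index and the name of the last gradient written into its coalesce buffer before it (the prior dep); a forward scan records the first op that reads any of its gradients afterwards (the post dep). A condensed, simplified mock of the pairing logic (op tuples and variable names are illustrative; the real scan also filters by op role):

```python
# each op: (type, inputs, outputs); g0 and g1 are fused into coalesce_grad_0
ops = [
    ("matmul_grad",     ["x"],               ["g1"]),
    ("relu_grad",       ["y"],               ["g0"]),
    ("coalesce_tensor", ["g0", "g1"],        ["coalesce_grad_0"]),
    ("c_allreduce_sum", ["coalesce_grad_0"], ["coalesce_grad_0"]),
    ("scale",           ["g0"],              ["g0"]),
]
fused = {"coalesce_grad_0": ["g0", "g1"]}
var_to_fused = {g: c for c, gs in fused.items() for g in gs}

# reverse scan: last gradient written before each fused allreduce
prior, pending = {}, set()
for idx in range(len(ops) - 1, -1, -1):
    op_type, ins, outs = ops[idx]
    if op_type == "c_allreduce_sum" and outs[0] in fused:
        prior[outs[0]] = [idx, None]
        pending.add(outs[0])
        continue
    for out in outs:
        c = var_to_fused.get(out)
        if c in pending:
            prior[c][1] = out
            pending.discard(c)

# forward scan: first gradient read after each fused allreduce
post, pending = {}, set()
for idx, (op_type, ins, outs) in enumerate(ops):
    if op_type == "c_allreduce_sum" and ins[0] in fused:
        pending.add(ins[0])
        continue
    for v in ins:
        c = var_to_fused.get(v)
        if c in pending:
            post[c] = [idx, v]
            pending.discard(c)

print(prior)  # {'coalesce_grad_0': [3, 'g0']} -> nop(g0 -> fused) at the allreduce
print(post)   # {'coalesce_grad_0': [4, 'g0']} -> nop(fused -> g0) before 'scale'
```

Each recorded pair then becomes one `nop`: the prior dep (grad into coalesce var) is inserted at the allreduce's index so the allreduce waits for the last gradient write, and the post dep (coalesce var into grad) is inserted before the first consumer so it waits for the synchronized buffer.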
@@ -26,6 +26,8 @@ from ..auto_parallel.utils import (
     OP_ROLE_KEY,
     OpRole,
     _get_comm_group,
+    insert_dependencies_for_two_vars,
+    use_standalone_executor,
 )
 from ..auto_parallel.dist_attribute import (
     TensorDistributedAttribute,
@@ -334,6 +336,7 @@ class ClipGradByGloblNormPass(PassBase):
             if op.type == 'sqrt':
                 input_name = op.input("X")[0]
                 input_var = block.vars[input_name]
+                insert_leaf_fill_constant_node = False
                 if paddle.distributed.get_world_size() > 1:
                     offset = 0
                     if input_name in removed_tmp_var:
@@ -356,6 +359,7 @@ class ClipGradByGloblNormPass(PassBase):
                         )
                         offset += 1
                         self.clip_helper._init_dist_attr(fill_constant_op)
+                        insert_leaf_fill_constant_node = True
 
                     allreduce_op = block._insert_op(
                         idx + offset,
@@ -373,6 +377,45 @@ class ClipGradByGloblNormPass(PassBase):
                     )
                     self.clip_helper._init_dist_attr(allreduce_op)
 
+                    if (
+                        use_standalone_executor
+                        and insert_leaf_fill_constant_node
+                    ):
+
+                        # NOTE add naive deps for global norm sync in graph exe
+                        j = idx - 1
+                        prior_op = None
+                        while j > 0:
+                            op_type = block.ops[j].type
+                            if op_type in [
+                                'update_loss_scaling',
+                                'check_finite_and_unscale',
+                            ] or op_type.endswith("_grad"):
+                                prior_op = block.ops[j]
+                                break
+                            j -= 1
+                        print("here: ", block.ops[j])
+                        assert (
+                            prior_op is not None
+                        ), "Unexception: ClipByGlobalNorm could not find priory depend op"
+
+                        prior_var = block.vars[prior_op.output_arg_names[0]]
+                        assert (
+                            prior_var is not None
+                        ), "Unexception: ClipByGlobalNorm could not find priory depend var"
+                        insert_dependencies_for_two_vars(
+                            block,
+                            idx,
+                            prior_var,
+                            input_var,
+                            self.clip_helper.dist_context,
+                            OpRole.Optimize,
+                            process_mesh=[
+                                -1
+                            ],  # hack to avoid initialize the dist attr for coalesc var
+                            is_recompute=False,
+                            sync=False,
+                        )
+
             for varname in removed_tmp_var:
                 block._remove_var(varname, sync=False)
...
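For context on the hunk above: with distributed training, `ClipGradByGlobalNorm` computes each rank's local sum of squared gradient norms, allreduces that scalar, and takes `sqrt`. The `fill_constant` op that seeds the sum is a leaf with no inputs, so under the graph executor nothing orders it (or the allreduce consuming it) after the backward ops; the hunk walks backwards to the nearest `*_grad` or loss-scaling op and inserts an explicit dependency on its output. A plain-Python sketch of the numeric part only (simulated ranks holding partitioned gradients, no Paddle):

```python
import numpy as np

def global_norm_across_ranks(per_rank_grads):
    """global_norm = sqrt(allreduce_sum(local sum of squared norms))."""
    local_sq_sums = [
        sum(float(np.sum(g * g)) for g in grads) for grads in per_rank_grads
    ]
    reduced = sum(local_sq_sums)  # stand-in for c_allreduce_sum on a scalar
    return float(np.sqrt(reduced))

# two ranks holding disjoint shards of the gradients
rank0 = [np.array([3.0, 0.0]), np.array([0.0])]
rank1 = [np.array([0.0, 4.0]), np.array([0.0])]
print(global_norm_across_ranks([rank0, rank1]))  # 5.0
```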
@@ -203,6 +203,7 @@ class RecomputeState(ProgramStats):
                 if cur_op.attr("fix_seed") is False
                 else int(cur_op.attr("seed"))
             )
+            # TODO add dependency for seed op to ensure it be issued just before recompute.
             seed_op = self._block._insert_op_without_sync(
                 index=cur_op.idx,
                 type="seed",
@@ -490,6 +491,7 @@ class RecomputePass(PassBase):
                 prior_op,
                 posterior_op,
                 self._dist_context,
+                is_recompute=True,
                 sync=False,
             )
         main_program._sync_with_cpp()
...
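On the recompute hunks above: the forward pass records the dropout seed, and a `seed` op replays it during backward recomputation so the recomputed activations and dropout masks match the original forward exactly; the new `is_recompute=True` argument also lets the inserted dependency op keep its dist attr instead of being skipped by the coalesce-var hack added to `naive_set_dist_op_attr_for_program_by_mesh`. A toy demonstration of why replaying the seed matters (numpy stand-in for dropout, not Paddle's op):

```python
import numpy as np

def dropout_mask(shape, p, seed):
    """Deterministic inverted-dropout mask for a given seed."""
    rng = np.random.default_rng(seed)
    return (rng.random(shape) >= p).astype(np.float32) / (1.0 - p)

seed = 2021
m_forward = dropout_mask((4,), 0.5, seed)
m_recompute = dropout_mask((4,), 0.5, seed)  # same seed replayed in backward
m_fresh = dropout_mask((4,), 0.5, seed + 1)  # a fresh seed would diverge

assert np.array_equal(m_forward, m_recompute)
print(m_forward, m_fresh)
```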
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import os
 import random
 import sys
 import unittest
@@ -24,6 +25,9 @@ import paddle.distributed.fleet as fleet
 from paddle.distributed.auto_parallel.dist_context import (
     get_default_distributed_context,
 )
+from paddle.distributed.auto_parallel.operators.common import (
+    is_data_parallel_reduce_op,
+)
 from paddle.distributed.passes import PassContext, new_pass
 
 sys.path.append("..")
@@ -116,5 +120,63 @@ class TestDataParallelPassWithScale2(TestDataParallelPassWithScale1):
 
         return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data
+
+class TestDataParallelPassWithStandaloneEXE(TestDataParallelPassWithScale1):
+    def init(self):
+        if paddle.is_compiled_with_cuda():
+            os.environ['FLAGS_CONVERT_GRAPH_TO_PROGRAM'] = "1"
+            paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
+        self.rtol = 1e-5
+        self.atol = 1e-8
+        # NOTE a hack to compare pass apply or not, since there is no
+        # setting of this pass in dist_strategy
+        self._apply_pass = False
+        rank = paddle.distributed.get_rank()
+        paddle.seed(rank + 2021)
+        random.seed(rank + 2021)
+        np.random.seed(rank + 2021)
+
+    # test scaling with optimizer rescale_grad
+    def get_model(self, place, batch_size, sequence_len, vocab_size):
+        (
+            dist_main_prog,
+            dist_startup_prog,
+            data_holder,
+            [loss],
+            gen_data,
+        ) = self.get_gpt_model(
+            'dp',
+            place,
+            batch_size,
+            sequence_len,
+            vocab_size,
+            optimizer='LarsMomentum',
+        )
+        if self._apply_pass:
+            config = {}
+            config["dist_context"] = get_default_distributed_context()
+            config["global_rank"] = paddle.distributed.get_rank()
+            dp_pass = new_pass(
+                "auto_parallel_data_parallel_optimization", config
+            )
+            dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())
+
+            ops = dist_main_prog.global_block().ops
+            allreduce_op_idx = -1
+            for idx in range(len(ops)):
+                if is_data_parallel_reduce_op(ops[idx]):
+                    allreduce_op_idx = idx
+                    break
+            assert allreduce_op_idx > 0
+            allreduce_op = ops[allreduce_op_idx]
+            assert allreduce_op.attr('use_calc_stream') is True
+            assert allreduce_op.dist_attr.execution_stream is not None
+            assert ops[allreduce_op_idx - 1].type == "nop"
+            assert ops[allreduce_op_idx + 1].type == "nop"
+
+        return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data
+
 
 if __name__ == "__main__":
     unittest.main()