未验证 提交 9b5e0154 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto Parallel] DP Calc-Comm Overlapping Support Weight Sharing (#45443)

* bugfix (#45332)

* customize wait_comm
上级 a4d2878a
...@@ -93,8 +93,8 @@ class DataParallelOptimizationPass(PassBase): ...@@ -93,8 +93,8 @@ class DataParallelOptimizationPass(PassBase):
def _calc_comm_overlap(self): def _calc_comm_overlap(self):
if not self._could_be_overlap(): if not self._could_be_overlap():
return return
self._calc_overlap_comms() self._comms_overlap_calc()
self._update_wait_comms() self._calc_wait_comms()
def _fuse_allreduce(self): def _fuse_allreduce(self):
pass pass
...@@ -227,7 +227,7 @@ class DataParallelOptimizationPass(PassBase): ...@@ -227,7 +227,7 @@ class DataParallelOptimizationPass(PassBase):
return True return True
def _calc_overlap_comms(self): def _comms_overlap_calc(self):
# TODO support InterpreterCore executor for overlap. # TODO support InterpreterCore executor for overlap.
# InterpreterCore has a different logic for overlapping # InterpreterCore has a different logic for overlapping
# which is different from use_calc_stream # which is different from use_calc_stream
...@@ -254,23 +254,58 @@ class DataParallelOptimizationPass(PassBase): ...@@ -254,23 +254,58 @@ class DataParallelOptimizationPass(PassBase):
block._sync_with_cpp() block._sync_with_cpp()
def _update_wait_comms(self): def _calc_wait_comms(self):
block = default_main_program().global_block() block = default_main_program().global_block()
ops = block.ops ops = block.ops
# update wait comm to finish # NOTE the naive overlap implement in static hybird parallel only sync comm stream
first_optimize_op_idx = -1 # at the end of Backward phase, based on a strong constraint that
for idx, op in enumerate(ops): # all communicating gradient would NOT be used after communication in Backward phase.
if is_optimize_op(op): # BUT this constraint will fail for scenario like Weight-Sharing and Higher-Order Differentiation,
first_optimize_op_idx = idx # where gradient will be involved in other calculation between data-parallel allreduce kernel submmited
break # into comm streams and the synchronization of comm stream at the end of Backward phase.
# synchronization of comm stream should add according to the usage of communicating gradients
assert first_optimize_op_idx > -1, "Unexception: not found optimizer op in program" # to support Overlapping for Weight-Sharing and Higher-Order Differentiation.
ring_id_to_un_sync_grad_map = {}
op_idx_to_sync_ring_id_map = {}
for group in self._group_to_grad_name_map.keys(): for group in self._group_to_grad_name_map.keys():
ring_id = group.id ring_id_to_un_sync_grad_map[group.id] = []
block._insert_op_without_sync(first_optimize_op_idx,
# analyze the where need to sync
for i, op in enumerate(ops):
if is_data_parallel_reduce_op(op):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
ring_id_to_un_sync_grad_map[ring_id].append(grad_name)
elif is_data_parallel_scale_op(op):
continue
# other ops that might use communicating grad
else:
for input_var_name in op.input_arg_names:
for ring_id, unsync_grad_names in ring_id_to_un_sync_grad_map.items(
):
if input_var_name in unsync_grad_names:
# need to sync before op_i
if i in op_idx_to_sync_ring_id_map:
op_idx_to_sync_ring_id_map[i].append(ring_id)
else:
op_idx_to_sync_ring_id_map[i] = [ring_id]
# all grads in this comm stream are synced
ring_id_to_un_sync_grad_map[ring_id] = []
# insert synchronization
indices = list(op_idx_to_sync_ring_id_map.keys())
# TODO the synchronization could be optimized
# we should record the event of a gradient is communicating and
# only wait for that event to be completed.
# BUT paddle static currently not support op api for event record only, so
# here we try to wait for all kernel in that comm stream to be finish which is not that optimized.
for i in sorted(indices, reverse=True):
for ring_id in op_idx_to_sync_ring_id_map[i]:
block._insert_op_without_sync(i,
type='c_wait_comm', type='c_wait_comm',
inputs={'X': []}, inputs={'X': []},
outputs={'Out': []}, outputs={'Out': []},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册