Unverified commit 4d2994cb authored by Yuang Liu, committed by GitHub

Optimize fused allreduce in raw program (#34509)

Parent 6a9fac14
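For context on the diff below: instead of coalescing every gradient and launching all allreduce ops only at the optimizer boundary, the patch groups the (param, grad) pairs into segments by dtype and by `fuse_grad_size_in_num`, gives each segment one fused buffer, and inserts the `c_allreduce_sum` for a segment right after the op that produces its last gradient. The following is a minimal standalone sketch of the grouping step only; the `Param`/`Grad` namedtuples and the sample data are illustrative stand-ins, not part of the patch or the Paddle API.

```python
from collections import namedtuple

# illustrative stand-ins for Paddle variables; only .dtype matters here
Grad = namedtuple('Grad', ['name', 'dtype'])
Param = namedtuple('Param', ['name'])


def split_into_segments(param_grads, fuse_grad_size_in_num):
    """Group (param, grad) pairs into segments, starting a new segment
    whenever the dtype changes or the current segment is full."""
    grad_param_segments = []  # [([grads...], [params...]), ...]
    last_dtype = None
    for param, grad in param_grads:
        if (not grad_param_segments
                or len(grad_param_segments[-1][0]) == fuse_grad_size_in_num
                or grad.dtype != last_dtype):
            grad_param_segments.append(([grad], [param]))
            last_dtype = grad.dtype
        else:
            grad_param_segments[-1][0].append(grad)
            grad_param_segments[-1][1].append(param)
    return grad_param_segments


pairs = [(Param('w%d' % i), Grad('w%d@GRAD' % i, 'fp32')) for i in range(5)]
pairs.append((Param('b'), Grad('b@GRAD', 'fp16')))
segments = split_into_segments(pairs, fuse_grad_size_in_num=3)
print([[g.name for g in grads] for grads, _ in segments])
# [['w0@GRAD', 'w1@GRAD', 'w2@GRAD'], ['w3@GRAD', 'w4@GRAD'], ['b@GRAD']]
```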
@@ -217,9 +217,13 @@ class RawProgramOptimizer(MetaOptimizerBase):
block = self.main_program.global_block()
ring_id = self.global_ring_id
param_grads = []
first_backward_idx = -1
# find all grad params
for op in reversed(block.ops):
for idx, op in enumerate(block.ops):
if first_backward_idx == -1 and \
is_backward_op(op):
first_backward_idx = idx
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.attr(OP_ROLE_VAR_KEY)
@@ -234,70 +238,100 @@ class RawProgramOptimizer(MetaOptimizerBase):
grad = block.var(grad_name)
if param.is_distributed:
continue
param_grads.append(grad)
param_grads.append((param, grad))
outputs_name_to_idx = self.__get_ouputs_name_to_idx(first_backward_idx,
block)
segments = []
# structure of grad_param_segments is
# [([grad0, grad1], [param0, param1]), ([grad2, grad3], [param2, param3])]
        # each entry of the list is a tuple that stores the grad segment list and
        # the corresponding param segment list
grad_param_segments = []
last_dtype = None
        # split the grads into segments based on dtype and fuse size
for var in param_grads:
if len(segments) == 0 \
or len(segments[-1]) == self.fuse_grad_size_in_num \
or var.dtype != last_dtype:
segments.append([var])
last_dtype = var.dtype
for param, grad in param_grads:
if len(grad_param_segments) == 0 \
or len(grad_param_segments[-1][0]) == self.fuse_grad_size_in_num \
or grad.dtype != last_dtype:
grad_param_segments.append(([grad], [param]))
last_dtype = grad.dtype
else:
segments[-1].append(var)
grad_param_segments[-1][0].append(grad)
grad_param_segments[-1][1].append(param)
fused_vars = []
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
for segment in segments:
# insert coalesce tensor
tmp_var = block.create_var(
name=unique_name.generate('FusedOutput_{}'.format(
segment[0].name)),
dtype=segment[0].dtype,
persistable=True,
stop_gradient=True)
fused_vars.append(tmp_var)
block._insert_op_without_sync(
idx,
type="coalesce_tensor",
inputs={"Input": segment},
outputs={"Output": segment,
"FusedOutput": tmp_var},
attrs={
"copy_data": True,
"use_align": True,
"dtype": segment[0].dtype,
OP_ROLE_KEY: OpRole.Backward
})
break
if len(grad_param_segments) == 0:
return
# insert the allreduce_sum op
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
for fused_var in fused_vars:
block._insert_op_without_sync(
idx,
type='c_allreduce_sum',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={
'ring_id': ring_id,
'use_calc_stream': self.calc_comm_same_stream,
OP_ROLE_KEY: OpRole.Backward
})
if not self.calc_comm_same_stream:
block._insert_op_without_sync(
idx,
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: OpRole.Backward})
break
fused_vars = [None] * len(grad_param_segments)
for i in range(len(grad_param_segments) - 1, -1, -1):
            # traverse the grad_param_segments in reverse order;
            # do not use reversed() since the absolute index value is needed
grad_segment, param_segment = grad_param_segments[i]
# insert coalesce tensor
fused_var = block.create_var(
name=unique_name.generate('FusedOutput_{}'.format(grad_segment[
0].name)),
dtype=grad_segment[0].dtype,
persistable=False,
stop_gradient=True)
fused_vars[i] = fused_var
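            # grad_segment[-1] is the last grad of this segment to be produced;
            # inserting c_allreduce_sum right after its producing op lets the
            # communication for this segment start as soon as the segment is ready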
after_idx = outputs_name_to_idx[grad_segment[-1]][1]
block._insert_op_without_sync(
after_idx + 1,
type='c_allreduce_sum',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={
'ring_id': ring_id,
'use_calc_stream': self.calc_comm_same_stream,
OP_ROLE_KEY: OpRole.Backward
})
if not self.calc_comm_same_stream:
block._insert_op_without_sync(
after_idx + 1,
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: OpRole.Backward})
if len(fused_vars) == 0:
# update the outputs_name_to_idx after insertion of sync/allreduce ops
outputs_name_to_idx = self.__get_ouputs_name_to_idx(first_backward_idx,
block)
        # the before_idx values are not guaranteed to be sorted, so we have to
        # sort them to get the correct order in which to insert the coalesce ops
pos_for_coalesce = {}
for i in range(len(grad_param_segments) - 1, -1, -1):
            # We separate the insertion of the coalesce ops from the insertion of the sync/allreduce ops,
            # since inserting the coalesce ops may invalidate outputs_name_to_idx
grad_segment, param_segment = grad_param_segments[i]
before_idx = len(block.ops)
for grad in outputs_name_to_idx:
before_idx = min(before_idx, outputs_name_to_idx[grad][0])
pos_for_coalesce[i] = before_idx
# insert the coalesce op based on the sorted before_idx
pos_for_coalesce = sorted(
pos_for_coalesce.items(),
key=lambda kv: (kv[1], kv[0]),
reverse=True)
for i, before_idx in pos_for_coalesce:
grad_segment, param_segment = grad_param_segments[i]
fused_var = fused_vars[i]
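            # at before_idx the grads have not been produced yet, so coalesce_tensor
            # only needs to allocate the fused buffer and map the grad vars into it;
            # copy_data is therefore False (the old code coalesced right before the
            # optimizer ops and had to copy the already-computed grads with copy_data=True)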
block._insert_op_without_sync(
before_idx,
type="coalesce_tensor",
inputs={"Input": param_segment},
outputs={"Output": grad_segment,
"FusedOutput": fused_var},
attrs={
"copy_data": False,
"use_align": True,
"dtype": grad_segment[0].dtype,
OP_ROLE_KEY: OpRole.Backward
})
if self.calc_comm_same_stream:
block._sync_with_cpp()
return
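The descending sort of `pos_for_coalesce` above is what keeps the insertion indices valid: ops are inserted at absolute positions, so inserting at the largest index first means the pending, smaller indices are not shifted. Below is a small standalone illustration of that insertion-order trick, using a plain Python list as a stand-in for the program block (the op names are made up for the example):

```python
ops = ['backward_0', 'backward_1', 'backward_2', 'optimizer']
# insertion positions computed against the *current* list, like pos_for_coalesce
pos_for_insert = {'coalesce_A': 0, 'coalesce_B': 2}

# insert at descending positions so that each insertion leaves the
# still-pending (smaller) positions untouched
for name, pos in sorted(
        pos_for_insert.items(), key=lambda kv: kv[1], reverse=True):
    ops.insert(pos, name)

print(ops)
# ['coalesce_A', 'backward_0', 'backward_1', 'coalesce_B', 'backward_2', 'optimizer']
```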
@@ -307,9 +341,31 @@ class RawProgramOptimizer(MetaOptimizerBase):
block._insert_op_without_sync(
idx,
type='c_sync_comm_stream',
inputs={'X': fused_vars[0]},
outputs={'Out': fused_vars[0]},
inputs={'X': grad_segment[0]},
outputs={'Out': grad_segment[0]},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
break
block._sync_with_cpp()
def __get_ouputs_name_to_idx(self, first_backward_idx, block):
        # Each item of outputs_name_to_idx is a pair of indices.
        # The first entry of the pair is the idx of the first op that generates the grad,
        # which indicates the position to insert the coalesce op.
        # The second entry of the pair is the idx of the last op that generates the grad,
        # which indicates the position to insert the sync and allreduce ops.
outputs_name_to_idx = {}
for idx in range(first_backward_idx, len(block.ops)):
op = block.ops[idx]
if is_optimizer_op(op):
break
for name in op.output_arg_names:
var = block.var(name)
if not outputs_name_to_idx.get(var):
                    # if the grad is generated by only one op,
                    # the first idx and the last idx are identical
outputs_name_to_idx[var] = (idx, idx)
else:
outputs_name_to_idx[var] = (outputs_name_to_idx[var][0],
idx)
return outputs_name_to_idx
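As a toy illustration of the first/last-producer mapping that `__get_ouputs_name_to_idx` builds, the sketch below uses plain lists of output names in place of real ops; it is meant only to show how the two indices end up being used (coalesce before the first producer, sync/allreduce after the last producer):

```python
# each "op" is just the list of its output names here
ops = [
    ['x@GRAD'],   # idx 0: first op that writes x@GRAD
    ['y@GRAD'],   # idx 1
    ['x@GRAD'],   # idx 2: x@GRAD is written again (e.g. grad accumulation)
]

outputs_name_to_idx = {}
for idx, out_names in enumerate(ops):
    for name in out_names:
        if name not in outputs_name_to_idx:
            # first and last producer coincide until we see the name again
            outputs_name_to_idx[name] = (idx, idx)
        else:
            outputs_name_to_idx[name] = (outputs_name_to_idx[name][0], idx)

print(outputs_name_to_idx)
# {'x@GRAD': (0, 2), 'y@GRAD': (1, 1)}
# coalesce_tensor is placed before the first index,
# c_sync_calc_stream / c_allreduce_sum after the last index
```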