Unverified commit b56dbe08, authored by Yuang Liu, committed by GitHub

fix the allreduce fused bug, test=develop (#34446)

Parent 76f94f88
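In short, this commit raises the default of fuse_grad_size_in_num from 1 to 8, only takes the fused-allreduce path when that value is greater than 1, and rewrites _allreduce_fusion_program to pack gradients into dtype-homogeneous segments that are coalesced and all-reduced as single buffers. A minimal sketch of how these strategy fields are typically driven from user code follows; the field names come from the proto and optimizer code in this diff, but the surrounding fleet setup is illustrative and not part of the commit.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

# Illustrative settings only; the fields mirror the options touched in this
# commit (fuse_grad_size_in_num now defaults to 8, and fusion is skipped
# unless it is > 1).
strategy = fleet.DistributedStrategy()
strategy.without_graph_optimization = True  # use RawProgramOptimizer
strategy.fuse_all_reduce_ops = True         # allow gradient fusion
strategy.fuse_grad_size_in_num = 8          # grads packed per coalesced buffer

fleet.init(is_collective=True, strategy=strategy)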
@@ -188,7 +188,7 @@ message DistributedStrategy {
   optional bool find_unused_parameters = 28 [ default = false ];
   optional bool tensor_parallel = 29 [ default = false ];
   optional bool without_graph_optimization = 30 [ default = false ];
-  optional int32 fuse_grad_size_in_num = 31 [ default = 1 ];
+  optional int32 fuse_grad_size_in_num = 31 [ default = 8 ];
   optional bool calc_comm_same_stream = 32 [ default = false ];
   optional bool asp = 33 [ default = false ];
@@ -131,7 +131,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
     def _transpile_main_program(self, loss):
         self._insert_loss_grad_ops(loss)
-        if self.fuse_all_reduce_ops:
+        if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1:
             self._allreduce_fusion_program()
         else:
             self._insert_allreduce_ops()
@@ -216,11 +216,10 @@ class RawProgramOptimizer(MetaOptimizerBase):
     def _allreduce_fusion_program(self):
         block = self.main_program.global_block()
         ring_id = self.global_ring_id
-        record_idx, allreduce_input_vars, allreduce_output_vars = [], [], []
-        ops = list(enumerate(block.ops))
-        for idx, op in reversed(ops):
-            # we travers the ops reversely
+        param_grads = []
+        # find all grad params
+        for op in reversed(block.ops):
             if is_backward_op(op) and \
                     OP_ROLE_VAR_KEY in op.attr_names:
                 op_role_var = op.attr(OP_ROLE_VAR_KEY)
@@ -229,214 +228,88 @@ class RawProgramOptimizer(MetaOptimizerBase):
                 assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \
                                                   "but got odd number of vars"
                 for i in range(0, len(op_role_var), 2):
-                    # handle vars in each op, each time handle a param and a grad
                     param_name = op_role_var[i]
                     param = block.var(param_name)
                     grad_name = op_role_var[i + 1]
                     grad = block.var(grad_name)
                     if param.is_distributed:
                         continue
-                    if ".cast_fp16@GRAD" in grad_name:
-                        # when amp=True get the fp16 param
-                        param_name = param_name + ".cast_fp16"
-                        if not block.has_var(param_name):
-                            raise ValueError("op cast name error {}".format(
-                                op.type))
-                        else:
-                            param = block.var(param_name)
-                    if len(allreduce_output_vars) == 0 or \
-                            len(allreduce_output_vars[-1]) == \
-                            self.fuse_grad_size_in_num:
-                        # start of the fusion or last group meets the config size
-                        allreduce_output_vars.append([grad])
-                        allreduce_input_vars.append([param])
-                        # add the start and end idx to the record idx
-                        record_idx.append([idx, idx])
-                    else:
-                        # Current group's size is below the config size
-                        # append grad and param to the last group (current group)
-                        # update the start idx to current op's idx
-                        # Since we travers the ops reversely, the idx is descending
-                        # we update the first entry of each entry for record_idx
-                        allreduce_output_vars[-1].append(grad)
-                        allreduce_input_vars[-1].append(param)
-                        record_idx[-1][0] = idx
-        assert len(allreduce_output_vars) == len(
-            record_idx
-        ), "It has different lens between the allreduce_output_vars and record_idx."
-        if not allreduce_output_vars or not allreduce_input_vars:
-            # nothing needs to be allreduced
-            return
-        self.vars = collections.OrderedDict()
-        index, pos, offset = 0, 0, 0
-        start, end = record_idx[index]
-        for idx, op in reversed(ops):
-            if idx == start:
-                pos = 0
-                done_output_vars, done_input_vars = self._split_fuction(
-                    allreduce_output_vars[index],  # grad
-                    allreduce_input_vars[index]  # param
-                )
-                for id_, done_output_var in enumerate(done_output_vars):
-                    tmp_var = block.create_var(
-                        name=unique_name.generate('FusedOutput_{}'.format(
-                            done_output_var[0].name)),
-                        dtype=done_output_var[0].dtype,
-                        persistable=False,
-                        stop_gradient=True)
-                    self.vars['FusedOutput_{}'.format(done_output_var[0]
-                                                      .name)] = tmp_var
-                    block._insert_op(
-                        idx + id_,
-                        type="coalesce_tensor",
-                        inputs={"Input": done_input_vars[id_]},
-                        outputs={
-                            "Output": done_output_var,
-                            "FusedOutput": tmp_var
-                        },
-                        attrs={
-                            "copy_data": False,
-                            "use_align": True,
-                            "dtype": done_output_var[0].dtype,
-                            OP_ROLE_KEY: OpRole.Backward
-                        })
-                    pos += 1
-                for id_ in range(len(done_output_vars)):
-                    x = self.vars['FusedOutput_{}'.format(done_output_vars[id_][
-                        0].name)]
-                    out = x
-                    # NOTE: there still some optimize space if use EVENT instead of sync
-                    if not self.calc_comm_same_stream:
-                        # need sync if the calc and comm stream are not the same
-                        block._insert_op(
-                            end + id_ + pos + 1,
-                            type='c_sync_calc_stream',
-                            inputs={'X': x},
-                            outputs={'Out': out},
-                            attrs={OP_ROLE_KEY: OpRole.Backward})
-                    block._insert_op(
-                        end + id_ + pos + 1
-                        if self.calc_comm_same_stream else end + id_ + pos + 2,
-                        type='c_allreduce_sum',
-                        inputs={'X': x},
-                        outputs={'Out': out},
-                        attrs={
-                            'ring_id': ring_id,
-                            'use_calc_stream': self.calc_comm_same_stream,
-                            OP_ROLE_KEY: OpRole.Backward
-                        })
-                index += 1
-                if len(record_idx) == index:
-                    break
-                start, end = record_idx[index]
-        if not self.calc_comm_same_stream:
-            # need sync if the calc and comm stream are not the same
-            for idx, op in enumerate(block.ops):
-                if is_optimizer_op(op):
-                    block._insert_op(
-                        idx,
-                        type='c_sync_comm_stream',
-                        inputs={'X': block.create_var()},
-                        outputs={'Out': block.create_var()},
-                        attrs={
-                            'ring_id': ring_id,
-                            OP_ROLE_KEY: OpRole.Backward
-                        })
-                    break
-
-    # Integrate grads of the same type to form a combination.
-    # If combination is selected, will return grads of the same type in a groups.
-    # For example:[(fp16, fp16), (fp32), (fp16)] -> [(fp16, fp16, fp16), (fp32)]
-    def _split_fuction(self,
-                       allreduce_output_vars,
-                       allreduce_input_vars,
-                       combination=True):
-        input_vars, final_input_vars, output_vars, final_output_vars = [], [], [], []
-        if len(allreduce_output_vars) == 1:
-            # only have one var to handle
-            final_output_vars.append(allreduce_output_vars)
-            final_input_vars.append(allreduce_input_vars)
-            return final_output_vars, final_input_vars
-        for idx in range(len(allreduce_input_vars) - 1):
-            # the last var needs to be handled differently
-            if allreduce_input_vars[idx].dtype == allreduce_input_vars[idx +
-                                                                       1].dtype:
-                # if current var and next var are in same type
-                # append current var to input_vars
-                input_vars.append(allreduce_input_vars[idx])
-                if idx == len(allreduce_input_vars) - 2:
-                    # if current var is the second last var
-                    # append the last var to input_vars
-                    # and update the final_input_vars
-                    input_vars.append(allreduce_input_vars[idx + 1])
-                    final_input_vars.append(input_vars)
-            else:
-                # the current var and next var are in different types
-                # append current var to input_vars
-                # update the final_input_vars
-                # reset input_vars to receive a new type
-                input_vars.append(allreduce_input_vars[idx])
-                final_input_vars.append(input_vars)
-                input_vars = []
-                if idx == len(allreduce_input_vars) - 2:
-                    # if current var is the second last var
-                    # append the last var to a reset input_vars since they are in different types
-                    # and update the final_input_vars
-                    input_vars.append(allreduce_input_vars[idx + 1])
-                    final_input_vars.append(input_vars)
-        for idx in range(len(allreduce_output_vars) - 1):
-            # the procedure for the output vars is the same with that for the input vars
-            if allreduce_output_vars[idx].dtype == allreduce_output_vars[
-                    idx + 1].dtype:
-                output_vars.append(allreduce_output_vars[idx])
-                if idx == len(allreduce_output_vars) - 2:
-                    output_vars.append(allreduce_output_vars[idx + 1])
-                    final_output_vars.append(output_vars)
-            else:
-                output_vars.append(allreduce_output_vars[idx])
-                final_output_vars.append(output_vars)
-                output_vars = []
-                if idx == len(allreduce_output_vars) - 2:
-                    output_vars.append(allreduce_output_vars[idx + 1])
-                    final_output_vars.append(output_vars)
-        # at this time, all vars in each group in final_input_vars and final_output_vars are in the same type
-        if combination:
-            input_fp16_vars, input_fp32_vars, output_fp16_vars, output_fp32_vars = [], [], [], []
-            for final_input_var in final_input_vars:
-                if final_input_var[0].dtype == core.VarDesc.VarType.FP16:
-                    # extend the group
-                    input_fp16_vars.extend(final_input_var)
-                else:
-                    input_fp32_vars.extend(final_input_var)
-            for final_output_var in final_output_vars:
-                if final_output_var[0].dtype == core.VarDesc.VarType.FP16:
-                    output_fp16_vars.extend(final_output_var)
-                else:
-                    output_fp32_vars.extend(final_output_var)
-            final_output_vars, final_input_vars = [], []
-            if output_fp16_vars:
-                final_output_vars.append(output_fp16_vars)
-            if output_fp32_vars:
-                final_output_vars.append(output_fp32_vars)
-            if input_fp16_vars:
-                final_input_vars.append(input_fp16_vars)
-            if input_fp32_vars:
-                final_input_vars.append(input_fp32_vars)
-            return final_output_vars, final_input_vars
+                    param_grads.append(grad)
+        segments = []
+        last_dtype = None
+        # split the grad based on dtype and fused size
+        for var in param_grads:
+            if len(segments) == 0 \
+                    or len(segments[-1]) == self.fuse_grad_size_in_num \
+                    or var.dtype != last_dtype:
+                segments.append([var])
+                last_dtype = var.dtype
+            else:
+                segments[-1].append(var)
+        fused_vars = []
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                for segment in segments:
+                    # insert coalesce tensor
+                    tmp_var = block.create_var(
+                        name=unique_name.generate('FusedOutput_{}'.format(
+                            segment[0].name)),
+                        dtype=segment[0].dtype,
+                        persistable=True,
+                        stop_gradient=True)
+                    fused_vars.append(tmp_var)
+                    block._insert_op_without_sync(
+                        idx,
+                        type="coalesce_tensor",
+                        inputs={"Input": segment},
+                        outputs={"Output": segment,
+                                 "FusedOutput": tmp_var},
+                        attrs={
+                            "copy_data": True,
+                            "use_align": True,
+                            "dtype": segment[0].dtype,
+                            OP_ROLE_KEY: OpRole.Backward
+                        })
+                break
+        # insert the allreduce_sum op
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                for fused_var in fused_vars:
+                    block._insert_op_without_sync(
+                        idx,
+                        type='c_allreduce_sum',
+                        inputs={'X': fused_var},
+                        outputs={'Out': fused_var},
+                        attrs={
+                            'ring_id': ring_id,
+                            'use_calc_stream': self.calc_comm_same_stream,
+                            OP_ROLE_KEY: OpRole.Backward
+                        })
+                    if not self.calc_comm_same_stream:
+                        block._insert_op_without_sync(
+                            idx,
+                            type='c_sync_calc_stream',
+                            inputs={'X': fused_var},
+                            outputs={'Out': fused_var},
+                            attrs={OP_ROLE_KEY: OpRole.Backward})
+                break
+        if len(fused_vars) == 0:
+            block._sync_with_cpp()
+            return
+        # insert the sync comm op
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                block._insert_op_without_sync(
+                    idx,
+                    type='c_sync_comm_stream',
+                    inputs={'X': fused_vars[0]},
+                    outputs={'Out': fused_vars[0]},
+                    attrs={'ring_id': ring_id,
+                           OP_ROLE_KEY: OpRole.Backward})
+                break
+        block._sync_with_cpp()
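The heart of the new _allreduce_fusion_program is the grouping loop: gradients are collected in reverse op order and packed into segments of at most fuse_grad_size_in_num tensors, and a segment is also closed whenever the dtype changes, so each coalesce_tensor op only ever fuses tensors of a single dtype. A standalone sketch of that grouping rule follows (plain Python, with dtypes stood in by strings; the helper name is made up for illustration):

from typing import List, Tuple


def split_into_segments(grads: List[Tuple[str, str]],
                        fuse_grad_size_in_num: int = 8) -> List[List[str]]:
    """Group (name, dtype) grads the same way the loop above does:
    start a new segment when the current one is full or the dtype changes."""
    segments: List[List[str]] = []
    last_dtype = None
    for name, dtype in grads:
        if (not segments or len(segments[-1]) == fuse_grad_size_in_num
                or dtype != last_dtype):
            segments.append([name])
            last_dtype = dtype
        else:
            segments[-1].append(name)
    return segments


# Mixed-precision grads are never fused across dtypes:
grads = [("w0@GRAD", "fp16"), ("w1@GRAD", "fp16"),
         ("w2@GRAD", "fp32"), ("w3@GRAD", "fp16")]
print(split_into_segments(grads, fuse_grad_size_in_num=2))
# [['w0@GRAD', 'w1@GRAD'], ['w2@GRAD'], ['w3@GRAD']]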