未验证 提交 b56dbe08 编写于 作者: Y Yuang Liu 提交者: GitHub

fix the allreduce fused bug, test=develop (#34446)

上级 76f94f88
...@@ -188,7 +188,7 @@ message DistributedStrategy { ...@@ -188,7 +188,7 @@ message DistributedStrategy {
optional bool find_unused_parameters = 28 [ default = false ]; optional bool find_unused_parameters = 28 [ default = false ];
optional bool tensor_parallel = 29 [ default = false ]; optional bool tensor_parallel = 29 [ default = false ];
optional bool without_graph_optimization = 30 [ default = false ]; optional bool without_graph_optimization = 30 [ default = false ];
optional int32 fuse_grad_size_in_num = 31 [ default = 1 ]; optional int32 fuse_grad_size_in_num = 31 [ default = 8 ];
optional bool calc_comm_same_stream = 32 [ default = false ]; optional bool calc_comm_same_stream = 32 [ default = false ];
optional bool asp = 33 [ default = false ]; optional bool asp = 33 [ default = false ];
......
...@@ -131,7 +131,7 @@ class RawProgramOptimizer(MetaOptimizerBase): ...@@ -131,7 +131,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
def _transpile_main_program(self, loss): def _transpile_main_program(self, loss):
self._insert_loss_grad_ops(loss) self._insert_loss_grad_ops(loss)
if self.fuse_all_reduce_ops: if self.fuse_all_reduce_ops and self.fuse_grad_size_in_num > 1:
self._allreduce_fusion_program() self._allreduce_fusion_program()
else: else:
self._insert_allreduce_ops() self._insert_allreduce_ops()
...@@ -216,11 +216,10 @@ class RawProgramOptimizer(MetaOptimizerBase): ...@@ -216,11 +216,10 @@ class RawProgramOptimizer(MetaOptimizerBase):
def _allreduce_fusion_program(self): def _allreduce_fusion_program(self):
block = self.main_program.global_block() block = self.main_program.global_block()
ring_id = self.global_ring_id ring_id = self.global_ring_id
record_idx, allreduce_input_vars, allreduce_output_vars = [], [], [] param_grads = []
ops = list(enumerate(block.ops))
for idx, op in reversed(ops): # find all grad params
# we travers the ops reversely for op in reversed(block.ops):
if is_backward_op(op) and \ if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names: OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.attr(OP_ROLE_VAR_KEY) op_role_var = op.attr(OP_ROLE_VAR_KEY)
...@@ -229,214 +228,88 @@ class RawProgramOptimizer(MetaOptimizerBase): ...@@ -229,214 +228,88 @@ class RawProgramOptimizer(MetaOptimizerBase):
assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \ assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \
"but got odd number of vars" "but got odd number of vars"
for i in range(0, len(op_role_var), 2): for i in range(0, len(op_role_var), 2):
# handle vars in each op, each time handle a param and a grad
param_name = op_role_var[i] param_name = op_role_var[i]
param = block.var(param_name) param = block.var(param_name)
grad_name = op_role_var[i + 1] grad_name = op_role_var[i + 1]
grad = block.var(grad_name) grad = block.var(grad_name)
if param.is_distributed: if param.is_distributed:
continue continue
if ".cast_fp16@GRAD" in grad_name: param_grads.append(grad)
# when amp=True get the fp16 param
param_name = param_name + ".cast_fp16" segments = []
if not block.has_var(param_name): last_dtype = None
raise ValueError("op cast name error {}".format( # split the grad based on dtype and fused size
op.type)) for var in param_grads:
else: if len(segments) == 0 \
param = block.var(param_name) or len(segments[-1]) == self.fuse_grad_size_in_num \
or var.dtype != last_dtype:
if len(allreduce_output_vars) == 0 or \ segments.append([var])
len(allreduce_output_vars[-1]) == \ last_dtype = var.dtype
self.fuse_grad_size_in_num: else:
# start of the fusion or last group meets the config size segments[-1].append(var)
allreduce_output_vars.append([grad])
allreduce_input_vars.append([param])
# add the start and end idx to the record idx
record_idx.append([idx, idx])
else:
# Current group's size is below the config size
# append grad and param to the last group (current group)
# update the start idx to current op's idx
# Since we travers the ops reversely, the idx is descending
# we update the first entry of each entry for record_idx
allreduce_output_vars[-1].append(grad)
allreduce_input_vars[-1].append(param)
record_idx[-1][0] = idx
assert len(allreduce_output_vars) == len(
record_idx
), "It has different lens between the allreduce_output_vars and record_idx."
if not allreduce_output_vars or not allreduce_input_vars:
# nothing needs to be allreduced
return
self.vars = collections.OrderedDict() fused_vars = []
index, pos, offset = 0, 0, 0 for idx, op in enumerate(block.ops):
start, end = record_idx[index] if is_optimizer_op(op):
for idx, op in reversed(ops): for segment in segments:
if idx == start: # insert coalesce tensor
pos = 0
done_output_vars, done_input_vars = self._split_fuction(
allreduce_output_vars[index], # grad
allreduce_input_vars[index] # param
)
for id_, done_output_var in enumerate(done_output_vars):
tmp_var = block.create_var( tmp_var = block.create_var(
name=unique_name.generate('FusedOutput_{}'.format( name=unique_name.generate('FusedOutput_{}'.format(
done_output_var[0].name)), segment[0].name)),
dtype=done_output_var[0].dtype, dtype=segment[0].dtype,
persistable=False, persistable=True,
stop_gradient=True) stop_gradient=True)
self.vars['FusedOutput_{}'.format(done_output_var[0] fused_vars.append(tmp_var)
.name)] = tmp_var block._insert_op_without_sync(
idx,
block._insert_op(
idx + id_,
type="coalesce_tensor", type="coalesce_tensor",
inputs={"Input": done_input_vars[id_]}, inputs={"Input": segment},
outputs={ outputs={"Output": segment,
"Output": done_output_var, "FusedOutput": tmp_var},
"FusedOutput": tmp_var
},
attrs={ attrs={
"copy_data": False, "copy_data": True,
"use_align": True, "use_align": True,
"dtype": done_output_var[0].dtype, "dtype": segment[0].dtype,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward
}) })
pos += 1 break
for id_ in range(len(done_output_vars)):
x = self.vars['FusedOutput_{}'.format(done_output_vars[id_][
0].name)]
out = x
# NOTE: there still some optimize space if use EVENT instead of sync
if not self.calc_comm_same_stream:
# need sync if the calc and comm stream are not the same
block._insert_op(
end + id_ + pos + 1,
type='c_sync_calc_stream',
inputs={'X': x},
outputs={'Out': out},
attrs={OP_ROLE_KEY: OpRole.Backward})
block._insert_op( # insert the allreduce_sum op
end + id_ + pos + 1 for idx, op in enumerate(block.ops):
if self.calc_comm_same_stream else end + id_ + pos + 2, if is_optimizer_op(op):
for fused_var in fused_vars:
block._insert_op_without_sync(
idx,
type='c_allreduce_sum', type='c_allreduce_sum',
inputs={'X': x}, inputs={'X': fused_var},
outputs={'Out': out}, outputs={'Out': fused_var},
attrs={ attrs={
'ring_id': ring_id, 'ring_id': ring_id,
'use_calc_stream': self.calc_comm_same_stream, 'use_calc_stream': self.calc_comm_same_stream,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward
}) })
if not self.calc_comm_same_stream:
block._insert_op_without_sync(
idx,
type='c_sync_calc_stream',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={OP_ROLE_KEY: OpRole.Backward})
break
index += 1 if len(fused_vars) == 0:
if len(record_idx) == index: block._sync_with_cpp()
break return
start, end = record_idx[index]
if not self.calc_comm_same_stream:
# need sync if the calc and comm stream are not the same
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
block._insert_op(
idx,
type='c_sync_comm_stream',
inputs={'X': block.create_var()},
outputs={'Out': block.create_var()},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward
})
break
# Integrate grads of the same type to form a combination.
# If combination is selected, will return grads of the same type in a groups.
# For example:[(fp16, fp16), (fp32), (fp16)] -> [(fp16, fp16, fp16), (fp32)]
def _split_fuction(self,
allreduce_output_vars,
allreduce_input_vars,
combination=True):
input_vars, final_input_vars, output_vars, final_output_vars = [], [], [], []
if len(allreduce_output_vars) == 1:
# only have one var to handle
final_output_vars.append(allreduce_output_vars)
final_input_vars.append(allreduce_input_vars)
return final_output_vars, final_input_vars
for idx in range(len(allreduce_input_vars) - 1):
# the last var needs to be handled differently
if allreduce_input_vars[idx].dtype == allreduce_input_vars[idx +
1].dtype:
# if current var and next var are in same type
# append current var to input_vars
input_vars.append(allreduce_input_vars[idx])
if idx == len(allreduce_input_vars) - 2:
# if current var is the second last var
# append the last var to input_vars
# and update the final_input_vars
input_vars.append(allreduce_input_vars[idx + 1])
final_input_vars.append(input_vars)
else:
# the current var and next var are in different types
# append current var to input_vars
# update the final_input_vars
# reset input_vars to receive a new type
input_vars.append(allreduce_input_vars[idx])
final_input_vars.append(input_vars)
input_vars = []
if idx == len(allreduce_input_vars) - 2:
# if current var is the second last var
# append the last var to a reset input_vars since they are in different types
# and update the final_input_vars
input_vars.append(allreduce_input_vars[idx + 1])
final_input_vars.append(input_vars)
for idx in range(len(allreduce_output_vars) - 1):
# the procedure for the output vars is the same with that for the input vars
if allreduce_output_vars[idx].dtype == allreduce_output_vars[
idx + 1].dtype:
output_vars.append(allreduce_output_vars[idx])
if idx == len(allreduce_output_vars) - 2:
output_vars.append(allreduce_output_vars[idx + 1])
final_output_vars.append(output_vars)
else:
output_vars.append(allreduce_output_vars[idx])
final_output_vars.append(output_vars)
output_vars = []
if idx == len(allreduce_output_vars) - 2:
output_vars.append(allreduce_output_vars[idx + 1])
final_output_vars.append(output_vars)
# at this time, all vars in each group in final_input_vars and final_output_vars are in the same type
if combination:
input_fp16_vars, input_fp32_vars, output_fp16_vars, output_fp32_vars = [], [], [], []
for final_input_var in final_input_vars:
if final_input_var[0].dtype == core.VarDesc.VarType.FP16:
# extend the group
input_fp16_vars.extend(final_input_var)
else:
input_fp32_vars.extend(final_input_var)
for final_output_var in final_output_vars:
if final_output_var[0].dtype == core.VarDesc.VarType.FP16:
output_fp16_vars.extend(final_output_var)
else:
output_fp32_vars.extend(final_output_var)
final_output_vars, final_input_vars = [], []
if output_fp16_vars:
final_output_vars.append(output_fp16_vars)
if output_fp32_vars:
final_output_vars.append(output_fp32_vars)
if input_fp16_vars:
final_input_vars.append(input_fp16_vars)
if input_fp32_vars:
final_input_vars.append(input_fp32_vars)
return final_output_vars, final_input_vars # insert the sync comm op
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
block._insert_op_without_sync(
idx,
type='c_sync_comm_stream',
inputs={'X': fused_vars[0]},
outputs={'Out': fused_vars[0]},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
break
block._sync_with_cpp()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册