未验证 提交 a4eadd15 编写于 作者: Y Yuang Liu 提交者: GitHub

[hybrid] Fix mp multi gradient clip prob (#35713)

上级 4b683887
...@@ -142,32 +142,103 @@ class GradientClipHelper(object): ...@@ -142,32 +142,103 @@ class GradientClipHelper(object):
return return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism # TODO (JZ-LIANG) revise this for uniform mixed parallelism
def sync_global_norm(self, block, ring_ids): def sync_global_norm(self, block, ring_ids, mp_rank):
""" """
prune gradient_clip related ops for params that not belong to cur shard prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div keep: sum, sqrt, elementwise_max, elementwise_div
""" """
# FIXME(wangxi): mp should prune duplicated param_grads is_clip_grad_by_global_norm = False
for idx, op in list(enumerate(block.ops)):
if not self._is_gradient_clip_op(op):
continue
if op.type == 'sum':
is_clip_grad_by_global_norm = True
break
if not is_clip_grad_by_global_norm:
# TODO(Yuang Liu): need some extra handles when clip_grad_norm for mp
return
removed_op_idx = set()
removed_tmp_var = set()
for idx, op in list(enumerate(block.ops)):
if not self._is_gradient_clip_op(op):
continue
if op.type == 'sum':
break
for input_name in op.input_arg_names:
input_var = block.var(input_name)
# NOTE: when mp_degree > 1, some vars are split across the mp ranks.
# However, some vars, such as Scale and Bias, are not split.
# Vars that are not split should be counted only once during grad clip
# by global norm. Those vars either don't have the is_distributed attr
# or have it set to False.
# Therefore, we prune those duplicated vars for grad clip.
if mp_rank >= 1 and (not (hasattr(input_var, 'is_distributed')
and input_var.is_distributed)):
removed_op_idx.add(idx)
for output_name in op.output_arg_names:
removed_tmp_var.add(output_name)
for idx, op in reversed(list(enumerate(block.ops))): for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op): if not self._is_gradient_clip_op(op):
continue continue
if idx in removed_op_idx:
block._remove_op(idx, sync=False)
if op.type == "sum": for var_name in removed_tmp_var:
sum_res = op.desc.output_arg_names()[0] block._remove_var(var_name, sync=False)
for ring_id in ring_ids:
if ring_id == -1: continue
idx = idx + 1 for idx, op in list(enumerate(block.ops)):
block._insert_op_without_sync( if not self._is_gradient_clip_op(op):
idx, continue
type='c_allreduce_sum', if op.type == 'sum':
inputs={'X': sum_res}, # If mp_rank == 0, no extra handles, just allreduce
outputs={'Out': sum_res}, # If mp_rank >= 1, some extra handles is needed
attrs={ sum_rst_var = block.var(op.output_arg_names[0])
'ring_id': ring_id, if mp_rank >= 1:
'op_namescope': "/gradient_clip_model_parallelism", reserved_vars = []
'use_calc_stream': True, for input_name in op.input_arg_names:
OP_ROLE_KEY: OpRole.Optimize, if input_name not in removed_tmp_var:
}) reserved_vars.append(input_name)
return
if len(reserved_vars) > 0:
op.desc.set_input("X", reserved_vars)
else:
# If all inputs of the sum op are to be removed, remove the sum op itself
# and set the value of the sum's output to 0.
namescope = op.attr("op_namescope")
block._remove_op(idx, sync=False)
fill_constant_op = block._insert_op_without_sync(
idx,
type='fill_constant',
inputs={},
outputs={'Out': sum_rst_var},
attrs={
'shape': sum_rst_var.shape,
'dtype': sum_rst_var.dtype,
'value': 0.0,
OP_ROLE_KEY: OpRole.Optimize
})
fill_constant_op._set_attr('op_namescope', namescope)
self._insert_allreduce(block, ring_ids, idx, sum_rst_var)
break
@staticmethod
def _insert_allreduce(block, ring_ids, idx, var):
    """Insert in-place ``c_allreduce_sum`` ops on ``var`` after position ``idx``.

    One op is inserted for every entry of ``ring_ids`` that is not ``-1``;
    entries equal to ``-1`` are skipped. The first op goes to ``idx + 1``,
    the next to ``idx + 2`` and so on, so the ops end up in the same order
    as their ring ids. Each op uses ``var`` as both its input ``X`` and its
    output ``Out``.

    Args:
        block: object exposing ``_insert_op_without_sync``; the ops are
            inserted into it.
        ring_ids: iterable of ring ids to reduce over.
        idx: position after which the first op is inserted.
        var: variable that is reduced in place.
    """
    # Drop the skipped ids up front so the running offset counts only the
    # ops that are actually inserted.
    active_rings = [ring for ring in ring_ids if ring != -1]
    for offset, ring in enumerate(active_rings, start=1):
        op_attrs = {
            'ring_id': ring,
            'op_namescope': "/gradient_clip_model_parallelism",
            'use_calc_stream': True,
            OP_ROLE_KEY: OpRole.Optimize,
        }
        block._insert_op_without_sync(
            idx + offset,
            type='c_allreduce_sum',
            inputs={'X': var},
            outputs={'Out': var},
            attrs=op_attrs)
...@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -435,7 +435,6 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block = self._main_program.global_block() main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block() startup_block = self._startup_program.global_block()
# FIXME(wangxi): mp should prune duplicated param_grads when calc
# amp inf_var & clip global_norm_var # amp inf_var & clip global_norm_var
rings = [self.mp_ring_id, self.pp_ring_id] rings = [self.mp_ring_id, self.pp_ring_id]
...@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -446,7 +445,7 @@ class ShardingOptimizer(MetaOptimizerBase):
gradientclip_helper = GradientClipHelper(None) gradientclip_helper = GradientClipHelper(None)
gradientclip_helper.sync_global_norm( gradientclip_helper.sync_global_norm(
main_block, [self.mp_ring_id, self.pp_ring_id]) main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank)
def _insert_loss_grad_scale_op(self): def _insert_loss_grad_scale_op(self):
main_block = self._main_program.global_block() main_block = self._main_program.global_block()
......
...@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object): ...@@ -4381,7 +4381,7 @@ class PipelineOptimizer(object):
persistable=source_var.persistable) persistable=source_var.persistable)
else: else:
dest_var = block._clone_variable(source_var, False) dest_var = block._clone_variable(source_var, False)
dest_var.stop_gradient = source_var.stop_gradient self._clone_var_attr(dest_var, source_var)
# When use with sharding, allreduce_sum and allreduce_max # When use with sharding, allreduce_sum and allreduce_max
# used for global gradient clip and amp will be added by sharding. # used for global gradient clip and amp will be added by sharding.
op_idx += 1 op_idx += 1
...@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object): ...@@ -4547,9 +4547,14 @@ class PipelineOptimizer(object):
persistable=ref_var.persistable, persistable=ref_var.persistable,
is_data=ref_var.is_data, is_data=ref_var.is_data,
need_check_feed=ref_var.desc.need_check_feed()) need_check_feed=ref_var.desc.need_check_feed())
new_var.stop_gradient = ref_var.stop_gradient self._clone_var_attr(new_var, ref_var)
return new_var return new_var
def _clone_var_attr(self, dest, src):
dest.stop_gradient = src.stop_gradient
if hasattr(src, 'is_distributed'):
dest.is_distributed = src.is_distributed
def _strip_grad_suffix(self, name): def _strip_grad_suffix(self, name):
""" """
Strip the grad suffix from the given variable name Strip the grad suffix from the given variable name
...@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object): ...@@ -5209,6 +5214,8 @@ class PipelineOptimizer(object):
persistable=True, persistable=True,
stop_gradient=False) stop_gradient=False)
real_param = main_block.var(param) real_param = main_block.var(param)
if hasattr(real_param, 'is_distributed'):
merged_grad_var.is_distributed = real_param.is_distributed
tmp_size = self._get_var_size(real_grad) tmp_size = self._get_var_size(real_grad)
# two strategies for splitting the grad # two strategies for splitting the grad
# 1. the current segment's size reach the user defined grad_size_in_MB # 1. the current segment's size reach the user defined grad_size_in_MB
......
...@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer): ...@@ -658,6 +658,33 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
'c_gen_nccl_id', 'c_comm_init' 'c_gen_nccl_id', 'c_comm_init'
]) ])
self.assertEqual(main_prog_op_types, [
'partial_recv', 'partial_allgather', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'softmax', 'cast', 'cross_entropy2', 'mean',
'elementwise_mul', 'fill_constant', 'elementwise_mul_grad',
'mean_grad', 'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'c_sync_calc_stream',
'partial_send', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'fill_constant', 'cast', 'sum', 'fill_constant',
'cast', 'sum', 'c_sync_comm_stream', 'check_finite_and_unscale',
'cast', 'c_allreduce_max', 'c_allreduce_max', 'cast',
'update_loss_scaling', 'fill_constant', 'c_allreduce_sum',
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
'momentum', 'momentum', 'momentum', 'momentum', 'momentum',
'momentum', 'momentum'
])
# pp + mp, partial send recv # pp + mp, partial send recv
self.assertIn('partial_recv', main_prog_op_types) self.assertIn('partial_recv', main_prog_op_types)
self.assertIn('partial_allgather', main_prog_op_types) self.assertIn('partial_allgather', main_prog_op_types)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册