Unverified commit 730dcaf4, authored by Haohongxiang, committed by GitHub

fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer (#36237)

* fix bugs in HybridParallelClipGrad of hybrid_parallel_optimizer

* update

* update
Parent: e9288340
@@ -50,7 +50,8 @@ class HybridParallelClipGrad:
     @imperative_base.no_grad
     def _dygraph_clip(self, params_grads):
         params_and_grads = []
-        sum_square_list = []
+        sum_square_list_dist = []
+        sum_square_list_not_dist = []
         for p, g in params_grads:
             if g is None:
                 continue
@@ -62,18 +63,33 @@ class HybridParallelClipGrad:
                 merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
             square = layers.square(merge_grad)
             sum_square = layers.reduce_sum(square)
-            sum_square_list.append(sum_square)
+            if p.is_distributed:
+                sum_square_list_dist.append(sum_square)
+            else:
+                sum_square_list_not_dist.append(sum_square)

         # all parameters have been filterd out
-        if len(sum_square_list) == 0:
+        if len(sum_square_list_dist) + len(sum_square_list_not_dist) == 0:
             return params_grads

-        global_norm_var = layers.concat(sum_square_list)
-        global_norm_var = layers.reduce_sum(global_norm_var)
-        # add all reduce to get global norm in world size
-        paddle.distributed.all_reduce(global_norm_var,
-                                      self._hcg.get_check_parallel_group())
-        global_norm_var = layers.sqrt(global_norm_var)
+        global_norm_var_dist = layers.concat(sum_square_list_dist) if len(
+            sum_square_list_dist) != 0 else layers.concat(
+                [paddle.to_tensor([0.])])
+        global_norm_var_dist = layers.reduce_sum(global_norm_var_dist)
+
+        global_norm_var_not_dist = layers.concat(
+            sum_square_list_not_dist) if len(
+                sum_square_list_not_dist) != 0 else layers.concat(
+                    [paddle.to_tensor([0.])])
+        global_norm_var_not_dist = layers.reduce_sum(global_norm_var_not_dist)
+
+        # add all reduce to get global norm of distributed params_and_grads in world size
+        # all reduce is not needed while getting global norm of non-distributed params_and_grads
+        paddle.distributed.all_reduce(
+            global_norm_var_dist, group=self._hcg.get_check_parallel_group())
+
+        global_norm_var = layers.sqrt(global_norm_var_dist +
+                                      global_norm_var_not_dist)
+
         max_global_norm = layers.fill_constant(
             shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
@@ -96,7 +112,7 @@ class HybridParallelClipGrad:
         return getattr(self._clip, item)

     def __call__(self, params_grads):
-        return self._clip(params_grads)
+        return self._dygraph_clip(params_grads)


 class HybridParallelOptimizer:
@@ -112,7 +128,7 @@ class HybridParallelOptimizer:
         self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)

         # NOTE(shenliang03): Because of the pure DataParallel mode, the gradient synchronization
-        # is achieved through reducer, so there is no need to call fuse_allreduce in oprimizer.
+        # is achieved through reducer, so there is no need to call fuse_allreduce in optimizer.
         self._dp_enable = not self._use_dp_mode and self._need_dp

         self._sharding_enable = (
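In short, the change splits the squared gradient norms into a distributed bucket (model-parallel-sharded parameters, whose values differ per rank) and a non-distributed bucket (replicated parameters), all-reduces only the distributed bucket over the check-parallel group, and then combines the two before taking the square root. The sketch below is a minimal plain-Python illustration of that idea, not the Paddle implementation: all_reduce_sum stands in for paddle.distributed.all_reduce over get_check_parallel_group(), gradients are plain lists of floats, and the final clipping scale uses the standard global-norm formula, which the hunks above do not show.

import math


def clip_by_global_norm_sketch(params_grads, clip_norm, all_reduce_sum):
    # Accumulate squared norms separately: distributed (sharded) parameters
    # hold different values on each rank, replicated ones are identical.
    sum_sq_dist = 0.0
    sum_sq_not_dist = 0.0
    for p, g in params_grads:
        if g is None:
            continue
        sq = sum(x * x for x in g)
        if p.is_distributed:
            sum_sq_dist += sq
        else:
            sum_sq_not_dist += sq

    # Only the distributed bucket is summed across ranks; reducing the
    # replicated bucket as well would count it once per rank in the group.
    sum_sq_dist = all_reduce_sum(sum_sq_dist)

    global_norm = math.sqrt(sum_sq_dist + sum_sq_not_dist)
    scale = clip_norm / max(global_norm, clip_norm)  # standard clip factor
    return [(p, None if g is None else [x * scale for x in g])
            for p, g in params_grads]

Passing all_reduce_sum=lambda x: x reproduces the single-process behaviour of this sketch.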