Unverified · Commit 598d32d6 authored by WangXi, committed by GitHub

fix GradientClipByGlobalNorm in hybrid parallel (#35691)

Parent 04fdb10a
@@ -522,7 +522,9 @@ class ClipGradByGlobalNorm(ClipGradBase):
                 # fp64
                 global_norm_var_other_dtype = layers.sums(sum_square_list)
                 global_norm_var.append(global_norm_var_other_dtype)
-            global_norm_var = layers.sums(global_norm_var)
+            global_norm_var = layers.sums(global_norm_var) if len(
+                global_norm_var) > 1 else global_norm_var[0]
             global_norm_var = layers.sqrt(x=global_norm_var)
             max_global_norm = layers.fill_constant(
                 shape=[1],
......
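The change above makes the final reduction conditional: layers.sums is only emitted when more than one per-dtype partial norm was collected, so a single-dtype model no longer gets a redundant 'sum' op (and, under sharding, a matching extra 'c_allreduce_sum', as the test updates below show). A minimal standalone sketch of the same guarded-reduction pattern, using plain Python floats instead of Paddle tensors (illustrative only, not the Paddle API):

# Sketch of the guarded reduction from the diff above, with plain floats
# standing in for the per-dtype partial sums of squared gradients.
def global_grad_norm(per_dtype_sums):
    assert per_dtype_sums, "at least one dtype group must contribute"
    # Mirror `layers.sums(...) if len(...) > 1 else ...[0]`: only add an
    # extra sum when several dtype groups (fp16/fp32/fp64) are present.
    total = sum(per_dtype_sums) if len(per_dtype_sums) > 1 else per_dtype_sums[0]
    return total ** 0.5

print(global_grad_norm([9.0]))        # one dtype group -> no extra sum, 3.0
print(global_grad_norm([9.0, 16.0]))  # mixed dtypes -> summed first, 5.0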
@@ -266,10 +266,9 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
             'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
             'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
             'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
-            'c_allreduce_sum', 'sum', 'c_allreduce_sum', 'sqrt',
-            'fill_constant', 'elementwise_max', 'elementwise_div',
-            'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
-            'momentum', 'momentum'
+            'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
+            'elementwise_div', 'elementwise_mul', 'elementwise_mul',
+            'elementwise_mul', 'momentum', 'momentum', 'momentum'
         ])

     def test_sharding_clone_for_test(self):
......
@@ -216,7 +216,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
     def test_none_grad_fp32(self):
         ops = self._test_none_grad_helper("float32")
         self.assertListEqual(ops, [
-            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sum', 'sqrt',
+            'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
             'fill_constant', 'elementwise_max', 'elementwise_div',
             'elementwise_mul', 'elementwise_mul'
         ])
@@ -225,9 +225,8 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
         ops = self._test_none_grad_helper("float16")
         self.assertListEqual(ops, [
             'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', 'cast',
-            'sum', 'sqrt', 'fill_constant', 'elementwise_max',
-            'elementwise_div', 'cast', 'elementwise_mul', 'cast',
-            'elementwise_mul'
+            'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
+            'cast', 'elementwise_mul', 'cast', 'elementwise_mul'
         ])

     def _test_none_grad_helper(self, dtype):
......
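The expected lists in the tests above are the op types of the clip-and-update pass in program order; with this fix the fp32 and fp16 paths each lose exactly one redundant 'sum'. A minimal sketch of how such a list can be collected from a static program (the tests build theirs through _test_none_grad_helper, whose body is not part of this diff; main_program here is a hypothetical stand-in):

import paddle

paddle.enable_static()
# Hypothetical program; the real tests construct one with gradient clipping
# applied before listing its ops.
main_program = paddle.static.default_main_program()
op_types = [op.type for op in main_program.global_block().ops]
print(op_types)  # op type names in program order, comparable to the lists above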