未验证 提交 598d32d6 编写于 作者: W WangXi 提交者: GitHub

fix GradientClipByGlobalNorm in hybrid parallel (#35691)

上级 04fdb10a
......@@ -522,7 +522,9 @@ class ClipGradByGlobalNorm(ClipGradBase):
# fp64
global_norm_var_other_dtype = layers.sums(sum_square_list)
global_norm_var.append(global_norm_var_other_dtype)
global_norm_var = layers.sums(global_norm_var)
global_norm_var = layers.sums(global_norm_var) if len(
global_norm_var) > 1 else global_norm_var[0]
global_norm_var = layers.sqrt(x=global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1],
......
......@@ -266,10 +266,9 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_sync_comm_stream',
'squared_l2_norm', 'squared_l2_norm', 'squared_l2_norm', 'sum',
'c_allreduce_sum', 'sum', 'c_allreduce_sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'elementwise_mul', 'elementwise_mul', 'momentum',
'momentum', 'momentum'
'c_allreduce_sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'elementwise_mul', 'elementwise_mul',
'elementwise_mul', 'momentum', 'momentum', 'momentum'
])
def test_sharding_clone_for_test(self):
......
......@@ -216,7 +216,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
def test_none_grad_fp32(self):
ops = self._test_none_grad_helper("float32")
self.assertListEqual(ops, [
'squared_l2_norm', 'squared_l2_norm', 'sum', 'sum', 'sqrt',
'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'elementwise_mul'
])
......@@ -225,9 +225,8 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
ops = self._test_none_grad_helper("float16")
self.assertListEqual(ops, [
'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', 'cast',
'sum', 'sqrt', 'fill_constant', 'elementwise_max',
'elementwise_div', 'cast', 'elementwise_mul', 'cast',
'elementwise_mul'
'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
'cast', 'elementwise_mul', 'cast', 'elementwise_mul'
])
def _test_none_grad_helper(self, dtype):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册