Commit 63222501 authored by Qingsheng Li, committed by Yu Yang

[Do not merge] Fix global gradient clip by Yu Yang (#13516)

* Yuyang fix global gradient clip

* Share LoDs

* Revert unnecessary changes

* Fix bug in sequence_slice_op
Parent 2d00e658
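For context, a minimal sketch of how GradientClipByGlobalNorm is typically wired into a training program. This snippet is not part of the commit; it assumes the fluid 1.x API of this era (fluid.clip.set_gradient_clip and the SGD optimizer), and the function name and avg_loss variable are illustrative.

import paddle.fluid as fluid

def add_optimizer_with_global_clip(avg_loss, clip_norm=5.0):
    # Rescale all trainable gradients together so their combined L2 norm
    # stays at or below clip_norm, then let SGD apply the clipped gradients.
    fluid.clip.set_gradient_clip(
        fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm))
    sgd = fluid.optimizer.SGD(learning_rate=0.01)
    sgd.minimize(avg_loss)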
@@ -271,7 +271,8 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
                     "All parameters' 'clip_norm' of a same group should be the same"
                 )
 
-        local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
+        square = grad * grad
+        local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64')
         context[self.group_name].append(local_norm_var)
 
         self.context = context
@@ -281,6 +282,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
         if group_scale_name not in self.context:
             group_norm_var = layers.sums(input=self.context[self.group_name])
             group_norm_var = layers.sqrt(x=group_norm_var)
+            group_norm_var = layers.cast(group_norm_var, 'float32')
             clip_var = self.context[self.group_name + "_clip"]
             group_scale_var = layers.elementwise_div(
                 x=clip_var,
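The numerical idea behind the diff, as a minimal NumPy sketch (variable names are illustrative, not taken from the Paddle source): each parameter's sum of squared gradients is promoted to float64 before the cross-parameter accumulation, so the global norm does not overflow or lose precision in float32, and the result is cast back to float32 before it scales the gradients.

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # grads: list of float32 gradient arrays, one per parameter
    # per-parameter sum of squares, promoted to float64 for accumulation
    local_sums = [np.float64(np.sum(g * g)) for g in grads]
    global_norm = np.sqrt(np.sum(local_sums))   # accumulated in float64
    global_norm = np.float32(global_norm)       # cast back, as in the diff
    # scale <= 1.0: gradients shrink only when the norm exceeds clip_norm
    scale = clip_norm / max(clip_norm, global_norm)
    return [g * np.float32(scale) for g in grads]

With this scale, gradients whose global norm is already below clip_norm pass through unchanged, matching the elementwise_max / elementwise_div construction in the diff.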