未验证 提交 501b58bd 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #13664 from reyoung/feature/use_double_merge_grads

fix(clip): use double to accumulate grad^2
......@@ -271,7 +271,8 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
"All parameters' 'clip_norm' of a same group should be the same"
)
local_norm_var = layers.reduce_sum(input=layers.pow(x=grad, factor=2.0))
square = grad * grad
local_norm_var = layers.cast(layers.reduce_sum(input=square), 'float64')
context[self.group_name].append(local_norm_var)
self.context = context
......@@ -281,6 +282,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name])
group_norm_var = layers.sqrt(x=group_norm_var)
group_norm_var = layers.cast(group_norm_var, 'float32')
clip_var = self.context[self.group_name + "_clip"]
group_scale_var = layers.elementwise_div(
x=clip_var,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册