提交 808c3ef3 编写于 作者: T tensor-tang

fix sqrt

上级 f42a12da
...@@ -280,7 +280,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr): ...@@ -280,7 +280,7 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
group_scale_name = self.group_name + "_scale" group_scale_name = self.group_name + "_scale"
if group_scale_name not in self.context: if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name]) group_norm_var = layers.sums(input=self.context[self.group_name])
layers.sqrt(x=group_norm_var, out=group_norm_var) group_norm_var = layers.sqrt(x=group_norm_var)
clip_var = self.context[self.group_name + "_clip"] clip_var = self.context[self.group_name + "_clip"]
group_scale_var = layers.elementwise_div( group_scale_var = layers.elementwise_div(
x=clip_var, x=clip_var,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册