未验证 提交 5c8fdb59 编写于 作者: Z Zhou Wei 提交者: GitHub

Fix GradientClipByGlobalNorm dtype bug (#27437)

* fix dtype of gradientclipbyglobalnorm

* fix dtype bug of GradientClipbyGlobalnorm
上级 4a9d21de
...@@ -590,7 +590,7 @@ class GradientClipByGlobalNorm(GradientClipBase): ...@@ -590,7 +590,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
global_norm_var = layers.reduce_sum(global_norm_var) global_norm_var = layers.reduce_sum(global_norm_var)
global_norm_var = layers.sqrt(global_norm_var) global_norm_var = layers.sqrt(global_norm_var)
max_global_norm = layers.fill_constant( max_global_norm = layers.fill_constant(
shape=[1], dtype='float32', value=self.clip_norm) shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
clip_var = layers.elementwise_div( clip_var = layers.elementwise_div(
x=max_global_norm, x=max_global_norm,
y=layers.elementwise_max( y=layers.elementwise_max(
...@@ -635,7 +635,9 @@ class GradientClipByGlobalNorm(GradientClipBase): ...@@ -635,7 +635,9 @@ class GradientClipByGlobalNorm(GradientClipBase):
global_norm_var = layers.sums(sum_square_list) global_norm_var = layers.sums(sum_square_list)
global_norm_var = layers.sqrt(x=global_norm_var) global_norm_var = layers.sqrt(x=global_norm_var)
max_global_norm = layers.fill_constant( max_global_norm = layers.fill_constant(
shape=[1], dtype="float32", value=self.clip_norm) shape=[1],
dtype=global_norm_var.dtype,
value=self.clip_norm)
scale_var = layers.elementwise_div( scale_var = layers.elementwise_div(
x=max_global_norm, x=max_global_norm,
y=layers.elementwise_max( y=layers.elementwise_max(
...@@ -663,7 +665,7 @@ class GradientClipByGlobalNorm(GradientClipBase): ...@@ -663,7 +665,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
context[self.group_name] = [] context[self.group_name] = []
context[self.group_name + "_clip_value"] = self.clip_norm context[self.group_name + "_clip_value"] = self.clip_norm
context[self.group_name + "_clip"] = layers.fill_constant( context[self.group_name + "_clip"] = layers.fill_constant(
shape=[1], dtype="float32", value=self.clip_norm) shape=[1], dtype=grad.dtype, value=self.clip_norm)
else: else:
if not self.clip_norm == context[self.group_name + "_clip_value"]: if not self.clip_norm == context[self.group_name + "_clip_value"]:
raise ValueError( raise ValueError(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册