From 5c8fdb59265e7e22a4bd52629e0038180d494ff5 Mon Sep 17 00:00:00 2001
From: Zhou Wei <52485244+zhouwei25@users.noreply.github.com>
Date: Thu, 24 Sep 2020 10:21:44 +0800
Subject: [PATCH] Fix GradientClipByGlobalNorm dtype bug (#27437)

* Fix dtype of GradientClipByGlobalNorm

* Fix dtype bug of GradientClipByGlobalNorm
---
 python/paddle/fluid/clip.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py
index 7b301ac19d..04e4906868 100644
--- a/python/paddle/fluid/clip.py
+++ b/python/paddle/fluid/clip.py
@@ -590,7 +590,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
         global_norm_var = layers.reduce_sum(global_norm_var)
         global_norm_var = layers.sqrt(global_norm_var)
         max_global_norm = layers.fill_constant(
-            shape=[1], dtype='float32', value=self.clip_norm)
+            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm)
         clip_var = layers.elementwise_div(
             x=max_global_norm,
             y=layers.elementwise_max(
@@ -635,7 +635,9 @@ class GradientClipByGlobalNorm(GradientClipBase):
             global_norm_var = layers.sums(sum_square_list)
             global_norm_var = layers.sqrt(x=global_norm_var)
             max_global_norm = layers.fill_constant(
-                shape=[1], dtype="float32", value=self.clip_norm)
+                shape=[1],
+                dtype=global_norm_var.dtype,
+                value=self.clip_norm)
             scale_var = layers.elementwise_div(
                 x=max_global_norm,
                 y=layers.elementwise_max(
@@ -663,7 +665,7 @@ class GradientClipByGlobalNorm(GradientClipBase):
             context[self.group_name] = []
             context[self.group_name + "_clip_value"] = self.clip_norm
             context[self.group_name + "_clip"] = layers.fill_constant(
-                shape=[1], dtype="float32", value=self.clip_norm)
+                shape=[1], dtype=grad.dtype, value=self.clip_norm)
         else:
             if not self.clip_norm == context[self.group_name + "_clip_value"]:
                 raise ValueError(
--
GitLab