From 9a4dd1bc25c9a023e2dd5f4b6f5a415b22ed488a Mon Sep 17 00:00:00 2001 From: zhongpu <2013000149@qq.com> Date: Mon, 9 Dec 2019 22:08:10 +0800 Subject: [PATCH] support float64 for GradClipByGlobalNorm in dygraph, test=develop (#21401) * support float64 for GradClipByGlobalNorm in dygraph, test=develop * fix comment for GradClipByGlobalNorm, test=develop --- python/paddle/fluid/dygraph_grad_clip.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/python/paddle/fluid/dygraph_grad_clip.py b/python/paddle/fluid/dygraph_grad_clip.py index ad052694648..103322815b5 100644 --- a/python/paddle/fluid/dygraph_grad_clip.py +++ b/python/paddle/fluid/dygraph_grad_clip.py @@ -192,15 +192,14 @@ class GradClipByGlobalNorm(GradClipBase): """ Clips values of multiple tensors by the ratio of the sum of their norms. - Given a list of tensors t_list, and a clipping ratio clip_norm, this - operation returns a list of clipped tensors list_clipped and the global - norm (global_norm) of all tensors in t_list. + Given a list of tensors t_list, and a clipping ratio max_global_norm, this + operation returns a list of clipped tensors list_clipped. To perform the clipping, the values :math:`t\_list[i]` are set to: .. math:: - t\_list[i] = t\_list[i] * \\frac{clip\_norm}{\max(global\_norm, clip\_norm)} + t\_list[i] = t\_list[i] * \\frac{max\_global\_norm}{\max(global\_norm, max\_global\_norm)} where: @@ -208,12 +207,12 @@ class GradClipByGlobalNorm(GradClipBase): global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2} - If :math:`clip\_norm > global\_norm` then the entries in t_list remain as they are, + If :math:`max\_global\_norm > global\_norm` then the entries in t_list remain as they are, otherwise they're all shrunk by the global ratio. Args: - clip_norm (float): The maximum norm value - group_name (str, optional): The group name for this clip. + max_global_norm (float): The maximum norm value. + dtype (str, optional): The type of max_global_norm. Default: "float32". Examples: .. code-block:: python @@ -225,7 +224,7 @@ class GradClipByGlobalNorm(GradClipBase): from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.nn import FC - from paddle.fluid.clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm + from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm from paddle.fluid.optimizer import SGDOptimizer @@ -248,9 +247,9 @@ class GradClipByGlobalNorm(GradClipBase): """ @imperative_base.no_grad - def __init__(self, max_global_norm): + def __init__(self, max_global_norm, dtype='float32'): self.max_global_norm = layers.fill_constant( - shape=[1], dtype='float32', value=max_global_norm) + shape=[1], dtype=dtype, value=max_global_norm) def __str__(self): return "ClipByGlobalNorm, max_global_norm=%f" % (self.max_global_norm) -- GitLab