Commit 9a4dd1bc authored by zhongpu, committed by liym27

support float64 for GradClipByGlobalNorm in dygraph, test=develop (#21401)

* support float64 for GradClipByGlobalNorm in dygraph, test=develop

* fix comment for GradClipByGlobalNorm, test=develop
Parent 8777e8c1
@@ -192,15 +192,14 @@ class GradClipByGlobalNorm(GradClipBase):
     """
     Clips values of multiple tensors by the ratio of the sum of their norms.
 
-    Given a list of tensors t_list, and a clipping ratio clip_norm, this
-    operation returns a list of clipped tensors list_clipped and the global
-    norm (global_norm) of all tensors in t_list.
+    Given a list of tensors t_list, and a clipping ratio max_global_norm, this
+    operation returns a list of clipped tensors list_clipped.
 
     To perform the clipping, the values :math:`t\_list[i]` are set to:
 
     .. math::
 
-        t\_list[i] = t\_list[i] * \\frac{clip\_norm}{\max(global\_norm, clip\_norm)}
+        t\_list[i] = t\_list[i] * \\frac{max\_global\_norm}{\max(global\_norm, max\_global\_norm)}
 
     where:
 
@@ -208,12 +207,12 @@ class GradClipByGlobalNorm(GradClipBase):
         global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
 
-    If :math:`clip\_norm > global\_norm` then the entries in t_list remain as they are,
+    If :math:`max\_global\_norm > global\_norm` then the entries in t_list remain as they are,
     otherwise they're all shrunk by the global ratio.
 
     Args:
-        clip_norm (float): The maximum norm value
-        group_name (str, optional): The group name for this clip.
+        max_global_norm (float): The maximum norm value.
+        dtype (str, optional): The type of max_global_norm. Default: "float32".
 
     Examples:
         .. code-block:: python
@@ -225,7 +224,7 @@ class GradClipByGlobalNorm(GradClipBase):
             from paddle.fluid.dygraph.base import to_variable
             from paddle.fluid.dygraph.nn import FC
-            from paddle.fluid.clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
+            from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
             from paddle.fluid.optimizer import SGDOptimizer
@@ -248,9 +247,9 @@ class GradClipByGlobalNorm(GradClipBase):
     """
 
     @imperative_base.no_grad
-    def __init__(self, max_global_norm):
+    def __init__(self, max_global_norm, dtype='float32'):
         self.max_global_norm = layers.fill_constant(
-            shape=[1], dtype='float32', value=max_global_norm)
+            shape=[1], dtype=dtype, value=max_global_norm)
 
     def __str__(self):
         return "ClipByGlobalNorm, max_global_norm=%f" % (self.max_global_norm)
...
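For reference, the clipping rule documented in the hunks above reduces to a few lines of plain NumPy. This is only an illustration of the formula (the helper name and the sample gradient values are made up), not PaddlePaddle's implementation:

```python
import numpy as np

def clip_by_global_norm(t_list, max_global_norm):
    # global_norm = sqrt(sum_i ||t_list[i]||_2^2), as in the docstring formula
    global_norm = np.sqrt(sum(np.sum(np.square(t)) for t in t_list))
    # scale = max_global_norm / max(global_norm, max_global_norm):
    # identity when global_norm <= max_global_norm, uniform shrink otherwise
    scale = max_global_norm / max(global_norm, max_global_norm)
    return [t * scale for t in t_list]

grads = [np.array([3.0, 4.0]), np.array([12.0])]   # global_norm = sqrt(25 + 144) = 13
clipped = clip_by_global_norm(grads, max_global_norm=1.0)
# every gradient is scaled by 1/13, so the clipped global norm is exactly 1.0
```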
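And a minimal usage sketch of the new dtype argument. The GradClipByGlobalNorm construction and the dygraph_grad_clip import path come straight from the diff; the FC/SGDOptimizer setup and the minimize(loss, grad_clip=...) call are assumed from this Paddle version's dygraph examples and are not part of the hunks above:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import FC
from paddle.fluid.dygraph_grad_clip import GradClipByGlobalNorm
from paddle.fluid.optimizer import SGDOptimizer

with fluid.dygraph.guard():
    # the new dtype argument lets the clip threshold be float64
    # instead of the previously hard-coded float32
    clip = GradClipByGlobalNorm(max_global_norm=1.0, dtype='float64')
    fc = FC('fc', size=10, dtype='float64')
    sgd = SGDOptimizer(learning_rate=1e-3)
    x = to_variable(np.random.uniform(-1, 1, (8, 10)).astype('float64'))
    loss = fluid.layers.reduce_mean(fc(x))
    loss.backward()
    # gradients are clipped by global norm before the parameter update
    sgd.minimize(loss, grad_clip=clip)
```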