提交 9a4dd1bc 编写于 作者: Z zhongpu 提交者: liym27

support float64 for GradClipByGlobalNorm in dygraph, test=develop (#21401)

* support float64 for GradClipByGlobalNorm in dygraph, test=develop

* fix comment for GradClipByGlobalNorm, test=develop
上级 8777e8c1
......@@ -192,15 +192,14 @@ class GradClipByGlobalNorm(GradClipBase):
"""
Clips values of multiple tensors by the ratio of the sum of their norms.
Given a list of tensors t_list, and a clipping ratio clip_norm, this
operation returns a list of clipped tensors list_clipped and the global
norm (global_norm) of all tensors in t_list.
Given a list of tensors t_list, and a clipping ratio max_global_norm, this
operation returns a list of clipped tensors list_clipped.
To perform the clipping, the values :math:`t\_list[i]` are set to:
.. math::
t\_list[i] = t\_list[i] * \\frac{clip\_norm}{\max(global\_norm, clip\_norm)}
t\_list[i] = t\_list[i] * \\frac{max\_global\_norm}{\max(global\_norm, max\_global\_norm)}
where:
......@@ -208,12 +207,12 @@ class GradClipByGlobalNorm(GradClipBase):
global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
If :math:`clip\_norm > global\_norm` then the entries in t_list remain as they are,
If :math:`max\_global\_norm > global\_norm` then the entries in t_list remain as they are,
otherwise they're all shrunk by the global ratio.
Args:
clip_norm (float): The maximum norm value
group_name (str, optional): The group name for this clip.
max_global_norm (float): The maximum norm value.
dtype (str, optional): The type of max_global_norm. Default: "float32".
Examples:
.. code-block:: python
......@@ -225,7 +224,7 @@ class GradClipByGlobalNorm(GradClipBase):
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import FC
from paddle.fluid.clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.optimizer import SGDOptimizer
......@@ -248,9 +247,9 @@ class GradClipByGlobalNorm(GradClipBase):
"""
@imperative_base.no_grad
def __init__(self, max_global_norm):
def __init__(self, max_global_norm, dtype='float32'):
self.max_global_norm = layers.fill_constant(
shape=[1], dtype='float32', value=max_global_norm)
shape=[1], dtype=dtype, value=max_global_norm)
def __str__(self):
return "ClipByGlobalNorm, max_global_norm=%f" % (self.max_global_norm)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册