未验证 提交 0855d982 编写于 作者: Z zxcd 提交者: GitHub

add clip_grad_norm_ API (#49935)

* add clip_grad_norm_ api.

* fix docs and some details according to the comments.

* fix code style.

* fix no_grad problem, and fix doc.

* fix code style.

* fix doc and remove type information
上级 9ce8cfcf
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.nn.utils.clip_grad_norm_ import clip_grad_norm_
class TestClipGradNorm(unittest.TestCase):
def test_basic(self):
run_test_equal(
self,
shape=[16, 16],
dtype=np.float32,
max_norm=5,
norm_type=2,
)
run_test_equal(
self,
shape=(100,),
dtype=np.float32,
max_norm=1e20,
norm_type=2,
)
run_test_equal(
self,
shape=[4, 8, 16],
dtype=np.float32,
max_norm=1.0,
norm_type=float("inf"),
)
def test_errors(self):
def TestValueError():
input_pd = paddle.to_tensor(
np.random.random([1, 2]).astype(np.float32)
)
input_pd.grad = paddle.to_tensor(
np.random.random([1, 2]).astype(np.float32)
)
clip_grad_norm_(input_pd, max_norm=2, norm_type=float("-inf"))
self.assertRaises(ValueError, TestValueError)
def TestRuntimeError():
input_pd = paddle.to_tensor(
np.random.random([1, 2]).astype(np.float32)
)
input_pd.grad = paddle.full([1, 2], float("inf"))
clip_grad_norm_(
input_pd, max_norm=2, norm_type=2, error_if_nonfinite=True
)
self.assertRaises(RuntimeError, TestRuntimeError)
def TestRuntimeErrorStaticMode():
paddle.enable_static()
input_pd = paddle.to_tensor(
np.random.random([1, 2]).astype(np.float32)
)
input_pd.grad = paddle.to_tensor(
np.random.random([1, 2]).astype(np.float32)
)
clip_grad_norm_(input_pd, max_norm=2, norm_type=float("inf"))
paddle.disable_static()
self.assertRaises(RuntimeError, TestRuntimeErrorStaticMode)
def run_test_equal(
self,
shape,
dtype,
max_norm,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
):
input = np.random.random(shape).astype(dtype)
grad = np.random.random(shape).astype(dtype)
input_pd = paddle.to_tensor(input)
input_pd.grad = paddle.to_tensor(grad)
if norm_type == 2:
grad = grad.reshape(1, grad.size)
output = np.linalg.norm(grad, 'fro')
elif norm_type == np.inf:
output = np.amax(np.abs(grad))
else:
output = np.linalg.norm(grad, norm_type)
clip_grad_norm_result = clip_grad_norm_(
input_pd,
max_norm=max_norm,
norm_type=norm_type,
error_if_nonfinite=error_if_nonfinite,
)
np.testing.assert_allclose(
clip_grad_norm_result.numpy(),
output,
rtol=1e-05,
atol=1e-05,
equal_nan=False,
)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -19,6 +19,7 @@ from .transform_parameters import (
vector_to_parameters,
_stride_column,
) # noqa: F401
from .clip_grad_norm_ import clip_grad_norm_ # noqa: F401
__all__ = [ # noqa
'weight_norm',
......@@ -26,4 +27,5 @@ __all__ = [ # noqa
'spectral_norm',
'parameters_to_vector',
'vector_to_parameters',
'clip_grad_norm_',
]
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
__all__ = ['clip_grad_norm_']
def clip_grad_norm_(
parameters,
max_norm,
norm_type=2.0,
error_if_nonfinite=False,
):
r"""Clips gradient norm of the iteratable parameters.
Norms are calculated together on all gradients, just as they are
connected into one vector. The gradient will be modified in place.
This API can only run in dynamic graph mode, not static graph mode.
Args:
parameters (Iterable[paddle.Tensor] or paddle.Tensor): Tensors or a single Tensor
that will be normalized gradients
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be `inf` for
infinity norm.
error_if_nonfinite (bool): if True, throw an error if the total
norm of the gradients from :attr:`parameters` is `nan`,
`inf`, or `-inf`.
Returns:
Total norm of the parameter gradients (treated as a single vector).
Example:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
max_norm = float(5.0)
linear = paddle.nn.Linear(in_features=10, out_features=10)
out = linear(x)
loss = paddle.mean(out)
loss.backward()
paddle.nn.utils.clip_grad_norm_(linear.parameters(), max_norm)
sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
sdg.step()
"""
if not paddle.in_dynamic_mode():
raise RuntimeError('this API can only run in dynamic mode.')
if isinstance(parameters, paddle.Tensor):
parameters = [parameters]
support_norm_type = [float("inf"), 0, 1, 2]
if norm_type not in support_norm_type:
raise ValueError(f'norm_type only support {support_norm_type}')
grads = [p.grad for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return paddle.to_tensor(0.0)
if norm_type == float("inf"):
norms = [g.detach().abs().max() for g in grads]
total_norm = (
norms[0] if len(norms) == 1 else paddle.max(paddle.stack(norms))
)
else:
total_norm = paddle.linalg.norm(
paddle.stack(
[paddle.linalg.norm(g.detach(), norm_type) for g in grads]
),
norm_type,
)
if error_if_nonfinite and paddle.logical_or(
total_norm.isnan(), total_norm.isinf()
):
raise RuntimeError(
f'The total norm of {norm_type} order of the gradients from '
'`parameters` is non-finite, so it cannot be clipped. In any case, '
'disable this error and scale the gradient by non-finite norm, '
'set `error_if_nonfinite=False`'
)
clip_coef = max_norm / (total_norm + 1e-6)
# Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this
# avoids the `if clip_coef < 1:` condition.
clip_coef_clamped = paddle.clip(clip_coef, max=1.0)
with paddle.no_grad():
for _, p in enumerate(parameters):
g = p.grad
if g is not None:
p.grad = paddle.multiply(x=g, y=clip_coef_clamped)
return total_norm
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册