diff --git a/python/paddle/fluid/tests/unittests/test_clip_grad_norm_.py b/python/paddle/fluid/tests/unittests/test_clip_grad_norm_.py new file mode 100644 index 0000000000000000000000000000000000000000..308c59d094ec51ba6838f04cf6c1f8e4bcafb745 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_clip_grad_norm_.py @@ -0,0 +1,121 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +from paddle.nn.utils.clip_grad_norm_ import clip_grad_norm_ + + +class TestClipGradNorm(unittest.TestCase): + def test_basic(self): + run_test_equal( + self, + shape=[16, 16], + dtype=np.float32, + max_norm=5, + norm_type=2, + ) + run_test_equal( + self, + shape=(100,), + dtype=np.float32, + max_norm=1e20, + norm_type=2, + ) + run_test_equal( + self, + shape=[4, 8, 16], + dtype=np.float32, + max_norm=1.0, + norm_type=float("inf"), + ) + + def test_errors(self): + def TestValueError(): + input_pd = paddle.to_tensor( + np.random.random([1, 2]).astype(np.float32) + ) + input_pd.grad = paddle.to_tensor( + np.random.random([1, 2]).astype(np.float32) + ) + clip_grad_norm_(input_pd, max_norm=2, norm_type=float("-inf")) + + self.assertRaises(ValueError, TestValueError) + + def TestRuntimeError(): + input_pd = paddle.to_tensor( + np.random.random([1, 2]).astype(np.float32) + ) + input_pd.grad = paddle.full([1, 2], float("inf")) + clip_grad_norm_( + input_pd, max_norm=2, norm_type=2, error_if_nonfinite=True + ) + + self.assertRaises(RuntimeError, TestRuntimeError) + + def TestRuntimeErrorStaticMode(): + paddle.enable_static() + input_pd = paddle.to_tensor( + np.random.random([1, 2]).astype(np.float32) + ) + input_pd.grad = paddle.to_tensor( + np.random.random([1, 2]).astype(np.float32) + ) + clip_grad_norm_(input_pd, max_norm=2, norm_type=float("inf")) + paddle.disable_static() + + self.assertRaises(RuntimeError, TestRuntimeErrorStaticMode) + + +def run_test_equal( + self, + shape, + dtype, + max_norm, + norm_type: float = 2.0, + error_if_nonfinite: bool = False, +): + input = np.random.random(shape).astype(dtype) + grad = np.random.random(shape).astype(dtype) + input_pd = paddle.to_tensor(input) + input_pd.grad = paddle.to_tensor(grad) + + if norm_type == 2: + grad = grad.reshape(1, grad.size) + output = np.linalg.norm(grad, 'fro') + elif norm_type == np.inf: + output = np.amax(np.abs(grad)) + else: + output = np.linalg.norm(grad, norm_type) + clip_grad_norm_result = clip_grad_norm_( + input_pd, + max_norm=max_norm, + norm_type=norm_type, + error_if_nonfinite=error_if_nonfinite, + ) + + np.testing.assert_allclose( + clip_grad_norm_result.numpy(), + output, + rtol=1e-05, + atol=1e-05, + equal_nan=False, + ) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/utils/__init__.py b/python/paddle/nn/utils/__init__.py index 23e1e233cc0dcd9d064ac1f6fa0211c2d6961648..82b17c8c05d245e2d95f4f5a2d176aaf0285c8f2 100644 --- a/python/paddle/nn/utils/__init__.py +++ b/python/paddle/nn/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ from .transform_parameters import ( vector_to_parameters, _stride_column, ) # noqa: F401 +from .clip_grad_norm_ import clip_grad_norm_ # noqa: F401 __all__ = [ # noqa 'weight_norm', @@ -26,4 +27,5 @@ __all__ = [ # noqa 'spectral_norm', 'parameters_to_vector', 'vector_to_parameters', + 'clip_grad_norm_', ] diff --git a/python/paddle/nn/utils/clip_grad_norm_.py b/python/paddle/nn/utils/clip_grad_norm_.py new file mode 100644 index 0000000000000000000000000000000000000000..3a3ecb38b4428259ccb2cbd8faa5a1bf9ebf1ffa --- /dev/null +++ b/python/paddle/nn/utils/clip_grad_norm_.py @@ -0,0 +1,107 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + +__all__ = ['clip_grad_norm_'] + + +def clip_grad_norm_( + parameters, + max_norm, + norm_type=2.0, + error_if_nonfinite=False, +): + r"""Clips gradient norm of the iteratable parameters. + + Norms are calculated together on all gradients, just as they are + connected into one vector. The gradient will be modified in place. + + This API can only run in dynamic graph mode, not static graph mode. + + Args: + parameters (Iterable[paddle.Tensor] or paddle.Tensor): Tensors or a single Tensor + that will be normalized gradients + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be `inf` for + infinity norm. + error_if_nonfinite (bool): if True, throw an error if the total + norm of the gradients from :attr:`parameters` is `nan`, + `inf`, or `-inf`. + + Returns: + Total norm of the parameter gradients (treated as a single vector). + Example: + .. code-block:: python + import paddle + + x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32') + max_norm = float(5.0) + linear = paddle.nn.Linear(in_features=10, out_features=10) + out = linear(x) + loss = paddle.mean(out) + loss.backward() + + paddle.nn.utils.clip_grad_norm_(linear.parameters(), max_norm) + + sdg = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters()) + sdg.step() + """ + if not paddle.in_dynamic_mode(): + raise RuntimeError('this API can only run in dynamic mode.') + + if isinstance(parameters, paddle.Tensor): + parameters = [parameters] + + support_norm_type = [float("inf"), 0, 1, 2] + if norm_type not in support_norm_type: + raise ValueError(f'norm_type only support {support_norm_type}') + + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return paddle.to_tensor(0.0) + if norm_type == float("inf"): + norms = [g.detach().abs().max() for g in grads] + total_norm = ( + norms[0] if len(norms) == 1 else paddle.max(paddle.stack(norms)) + ) + else: + total_norm = paddle.linalg.norm( + paddle.stack( + [paddle.linalg.norm(g.detach(), norm_type) for g in grads] + ), + norm_type, + ) + + if error_if_nonfinite and paddle.logical_or( + total_norm.isnan(), total_norm.isinf() + ): + raise RuntimeError( + f'The total norm of {norm_type} order of the gradients from ' + '`parameters` is non-finite, so it cannot be clipped. In any case, ' + 'disable this error and scale the gradient by non-finite norm, ' + 'set `error_if_nonfinite=False`' + ) + clip_coef = max_norm / (total_norm + 1e-6) + # Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this + # avoids the `if clip_coef < 1:` condition. + clip_coef_clamped = paddle.clip(clip_coef, max=1.0) + with paddle.no_grad(): + for _, p in enumerate(parameters): + g = p.grad + if g is not None: + p.grad = paddle.multiply(x=g, y=clip_coef_clamped) + return total_norm