diff --git a/python/paddle/nn/utils/__init__.py b/python/paddle/nn/utils/__init__.py
index 82b17c8c05d245e2d95f4f5a2d176aaf0285c8f2..d1645deb905a98f45ae9e643c0eb004330ac030d 100644
--- a/python/paddle/nn/utils/__init__.py
+++ b/python/paddle/nn/utils/__init__.py
@@ -20,6 +20,7 @@ from .transform_parameters import (
     _stride_column,
 )  # noqa: F401
 from .clip_grad_norm_ import clip_grad_norm_  # noqa: F401
+from .clip_grad_value_ import clip_grad_value_  # noqa: F401
 
 __all__ = [  # noqa
     'weight_norm',
@@ -28,4 +29,5 @@ __all__ = [  # noqa
     'parameters_to_vector',
     'vector_to_parameters',
     'clip_grad_norm_',
+    'clip_grad_value_',
 ]
diff --git a/python/paddle/nn/utils/clip_grad_norm_.py b/python/paddle/nn/utils/clip_grad_norm_.py
index 3a3ecb38b4428259ccb2cbd8faa5a1bf9ebf1ffa..22fa7341e3f2e0c65752499292dfc93770ba5da6 100644
--- a/python/paddle/nn/utils/clip_grad_norm_.py
+++ b/python/paddle/nn/utils/clip_grad_norm_.py
@@ -14,9 +14,10 @@
 
 import paddle
 
-__all__ = ['clip_grad_norm_']
+__all__ = []
 
 
+@paddle.autograd.no_grad()
 def clip_grad_norm_(
     parameters,
     max_norm,
@@ -98,10 +99,9 @@ def clip_grad_norm_(
     clip_coef = max_norm / (total_norm + 1e-6)
     # Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this
     # avoids the `if clip_coef < 1:` condition.
-    clip_coef_clamped = paddle.clip(clip_coef, max=1.0)
-    with paddle.no_grad():
-        for _, p in enumerate(parameters):
-            g = p.grad
-            if g is not None:
-                p.grad = paddle.multiply(x=g, y=clip_coef_clamped)
+    clip_coef_clamped = clip_coef.clip_(max=1.0)
+
+    for _, p in enumerate(parameters):
+        if p.grad is not None:
+            p.grad = paddle.multiply(x=p.grad, y=clip_coef_clamped)
     return total_norm
diff --git a/python/paddle/nn/utils/clip_grad_value_.py b/python/paddle/nn/utils/clip_grad_value_.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c0fa10ed08ebe31ded4751766deb9dd8c2d5e91
--- /dev/null
+++ b/python/paddle/nn/utils/clip_grad_value_.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+__all__ = []
+
+
+@paddle.autograd.no_grad()
+def clip_grad_value_(
+    parameters,
+    clip_value,
+):
+    r"""Clips gradient of an iterable of parameters at specified value.
+    The gradient will be modified in place.
+    This API can only run in dynamic graph mode, not static graph mode.
+    Args:
+        parameters (Iterable[paddle.Tensor] or paddle.Tensor): Tensors or a single Tensor
+            whose gradients will be clipped
+        clip_value (float or int): maximum allowed value of the gradients.
+            The gradients are clipped in the range
+            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
+    Example:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.uniform([10, 10], min=-10.0, max=10.0, dtype='float32')
+            clip_value = float(5.0)
+            linear = paddle.nn.Linear(in_features=10, out_features=10)
+            out = linear(x)
+            loss = paddle.mean(out)
+            loss.backward()
+            paddle.nn.utils.clip_grad_value_(linear.parameters(), clip_value)
+            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
+            sgd.step()
+    """
+    if not paddle.in_dynamic_mode():
+        raise RuntimeError('this API can only run in dynamic mode.')
+
+    if isinstance(parameters, paddle.Tensor):
+        parameters = [parameters]
+
+    clip_value = float(clip_value)
+
+    for _, p in enumerate(parameters):
+        if p.grad is not None:
+            p.grad.clip_(min=-clip_value, max=clip_value)
diff --git a/test/legacy_test/test_clip_grad_value_.py b/test/legacy_test/test_clip_grad_value_.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b975cbb64979d95f7a2088bac4aed23ad402349
--- /dev/null
+++ b/test/legacy_test/test_clip_grad_value_.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+
+import paddle
+from paddle.nn.utils.clip_grad_value_ import clip_grad_value_
+
+
+class TestClipGradValue(unittest.TestCase):
+    def test_basic(self):
+        run_test_equal_np(
+            self,
+            shape=[16, 16],
+            dtype=np.float32,
+            clip_value=1,
+        )
+        run_test_equal_np(
+            self,
+            shape=(100,),
+            dtype=np.float32,
+            clip_value=0.1,
+        )
+        run_test_equal_np(
+            self, shape=[4, 8, 16], dtype=np.float32, clip_value=0
+        )
+        run_test_equal_ClipGradByValue(
+            self,
+            shape=[16, 16],
+            dtype=np.float32,
+            clip_value=1,
+        )
+        run_test_equal_ClipGradByValue(
+            self,
+            shape=(100,),
+            dtype=np.float32,
+            clip_value=0.1,
+        )
+        run_test_equal_ClipGradByValue(
+            self, shape=[4, 8, 16], dtype=np.float32, clip_value=0
+        )
+
+    def test_errors(self):
+        def TestValueError():
+            input_pd = paddle.to_tensor(
+                np.random.random([1, 2]).astype(np.float32)
+            )
+            input_pd.grad = paddle.to_tensor(
+                np.random.random([1, 2]).astype(np.float32)
+            )
+            clip_grad_value_(input_pd, clip_value=-1)
+
+        self.assertRaises(ValueError, TestValueError)
+
+        def TestRuntimeErrorStaticMode():
+            paddle.enable_static()
+            input_pd = paddle.to_tensor(
+                np.random.random([1, 2]).astype(np.float32)
+            )
+            input_pd.grad = paddle.to_tensor(
+                np.random.random([1, 2]).astype(np.float32)
+            )
+            clip_grad_value_(input_pd, clip_value=1)
+            paddle.disable_static()
+
+        self.assertRaises(RuntimeError, TestRuntimeErrorStaticMode)
+
+
+def run_test_equal_np(
+    self,
+    shape,
+    dtype,
+    clip_value,
+):
+    input = np.random.random(shape).astype(dtype)
+    grad = np.random.random(shape).astype(dtype)
+    input_pd = paddle.to_tensor(input)
+    input_pd.grad = paddle.to_tensor(grad)
+
+    output = np.clip(grad, a_min=-clip_value, a_max=clip_value)
+    clip_grad_value_(
+        input_pd,
+        clip_value=clip_value,
+    )
+
+    np.testing.assert_allclose(
+        input_pd.grad.numpy(),
+        output,
+        rtol=1e-05,
+        atol=1e-05,
+        equal_nan=False,
+    )
+
+
+def run_test_equal_ClipGradByValue(
+    self,
+    shape,
+    dtype,
+    clip_value,
+):
+    input = np.random.random(shape).astype(dtype)
+    grad = np.random.random(shape).astype(dtype)
+    input_pd = paddle.to_tensor(input)
+    input_pd.grad = paddle.to_tensor(grad)
+
+    clip = paddle.nn.ClipGradByValue(max=clip_value, min=-clip_value)
+    output = clip([(input_pd, input_pd.grad)])[0][1]
+    clip_grad_value_(
+        input_pd,
+        clip_value=clip_value,
+    )
+
+    np.testing.assert_allclose(
+        input_pd.grad,
+        output,
+        rtol=1e-05,
+        atol=1e-05,
+        equal_nan=False,
+    )
+
+
+if __name__ == '__main__':
+    unittest.main()
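The sketch below illustrates how the new paddle.nn.utils.clip_grad_value_ API is expected to be used once this patch is applied. It mirrors the docstring example and the NumPy check in the new test; the layer shape, learning rate, and clip_value=0.5 are illustrative choices, not part of the patch.

    # Minimal usage sketch for the clip_grad_value_ added in this patch (illustrative values).
    import numpy as np
    import paddle

    linear = paddle.nn.Linear(in_features=10, out_features=10)
    x = paddle.uniform([10, 10], min=-10.0, max=10.0, dtype='float32')
    loss = paddle.mean(linear(x))
    loss.backward()

    # Keep a copy of the raw gradient to check the in-place clipping against NumPy.
    raw_grad = linear.weight.grad.numpy().copy()

    clip_value = 0.5
    paddle.nn.utils.clip_grad_value_(linear.parameters(), clip_value)

    # Every gradient entry now lies inside [-clip_value, clip_value].
    np.testing.assert_allclose(
        linear.weight.grad.numpy(),
        np.clip(raw_grad, -clip_value, clip_value),
        rtol=1e-05,
    )

    # The clipped gradients are then consumed by the optimizer as usual.
    sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
    sgd.step()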