Unverified commit 88402cdb, authored by zqw_1997, committed by GitHub

add clip_grad_value_ api (#54603)

* add clip_grad_value_ api

* add test for ClipGradByValue

* typo fix

* refine and modify clip_grad_norm_

* no_grad

* clip_

* remove g=p.grad

* bug: AssertionError: When Variable is used as the condition of if/while , Variable can only contain one element.
Parent 6b6d4090
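In short, this commit adds `paddle.nn.utils.clip_grad_value_`, which clamps every element of each parameter's gradient in place to the range [-clip_value, clip_value] before the optimizer step, complementing the existing `clip_grad_norm_`. A minimal usage sketch, assuming dynamic graph mode; the layer sizes, batch size, and learning rate below are illustrative and not part of the commit:

import paddle
from paddle.nn.utils import clip_grad_value_

linear = paddle.nn.Linear(in_features=10, out_features=10)
x = paddle.uniform([4, 10], min=-1.0, max=1.0)

loss = paddle.mean(linear(x))
loss.backward()

# Clamp every gradient element into [-0.5, 0.5] in place.
clip_grad_value_(linear.parameters(), clip_value=0.5)

opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
opt.step()
opt.clear_grad()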
@@ -20,6 +20,7 @@ from .transform_parameters import (
     _stride_column,
 )  # noqa: F401
 from .clip_grad_norm_ import clip_grad_norm_  # noqa: F401
+from .clip_grad_value_ import clip_grad_value_  # noqa: F401
 
 __all__ = [  # noqa
     'weight_norm',
@@ -28,4 +29,5 @@ __all__ = [  # noqa
     'parameters_to_vector',
     'vector_to_parameters',
     'clip_grad_norm_',
+    'clip_grad_value_',
 ]
@@ -14,9 +14,10 @@
 import paddle
 
-__all__ = ['clip_grad_norm_']
+__all__ = []
 
 
+@paddle.autograd.no_grad()
 def clip_grad_norm_(
     parameters,
     max_norm,
@@ -98,10 +99,9 @@ def clip_grad_norm_(
     clip_coef = max_norm / (total_norm + 1e-6)
     # Note: when the coef is clamped to 1, it is redundant to multiply the clamped coef, but this
     # avoids the `if clip_coef < 1:` condition.
-    clip_coef_clamped = paddle.clip(clip_coef, max=1.0)
-    with paddle.no_grad():
-        for _, p in enumerate(parameters):
-            g = p.grad
-            if g is not None:
-                p.grad = paddle.multiply(x=g, y=clip_coef_clamped)
+    clip_coef_clamped = clip_coef.clip_(max=1.0)
+    for _, p in enumerate(parameters):
+        if p.grad is not None:
+            p.grad = paddle.multiply(x=p.grad, y=clip_coef_clamped)
     return total_norm
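The `clip_grad_norm_` changes swap the out-of-place `paddle.clip` for the in-place `Tensor.clip_`, replace the explicit `with paddle.no_grad():` block with the `@paddle.autograd.no_grad()` decorator, and drop the temporary `g = p.grad` alias. A small sketch of the in-place vs. out-of-place difference; the tensor value here is illustrative:

import paddle

coef = paddle.to_tensor(2.5)

# Out-of-place: returns a new clamped tensor, `coef` itself is unchanged.
clamped = paddle.clip(coef, max=1.0)

# In-place: clamps `coef` directly and returns it, avoiding an extra allocation.
coef.clip_(max=1.0)

print(float(clamped), float(coef))  # 1.0 1.0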
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

__all__ = []


@paddle.autograd.no_grad()
def clip_grad_value_(
    parameters,
    clip_value,
):
    r"""Clips gradient of an iterable of parameters at specified value.
    The gradient will be modified in place.
    This API can only run in dynamic graph mode, not static graph mode.

    Args:
        parameters (Iterable[paddle.Tensor] or paddle.Tensor): Tensors or a single Tensor
            whose gradients will be clipped
        clip_value (float or int): maximum allowed value of the gradients.
            The gradients are clipped in the range
            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`

    Example:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-10.0, max=10.0, dtype='float32')
            clip_value = float(5.0)
            linear = paddle.nn.Linear(in_features=10, out_features=10)
            out = linear(x)
            loss = paddle.mean(out)
            loss.backward()
            paddle.nn.utils.clip_grad_value_(linear.parameters(), clip_value)
            sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters())
            sgd.step()
    """
    if not paddle.in_dynamic_mode():
        raise RuntimeError('this API can only run in dynamic mode.')

    if isinstance(parameters, paddle.Tensor):
        parameters = [parameters]

    clip_value = float(clip_value)

    for _, p in enumerate(parameters):
        if p.grad is not None:
            p.grad.clip_(min=-clip_value, max=clip_value)
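Per element, the new helper matches what `paddle.nn.ClipGradByValue(max=clip_value, min=-clip_value)` produces for each (parameter, gradient) pair, which is exactly what the test file below verifies. A hedged sketch of that equivalence; the shape and clip value here are arbitrary:

import numpy as np
import paddle
from paddle.nn.utils import clip_grad_value_

p = paddle.to_tensor(np.random.random([8, 8]).astype('float32'))
p.grad = paddle.to_tensor(np.random.uniform(-2, 2, [8, 8]).astype('float32'))

# Reference result from the existing ClipGradByValue strategy (out-of-place).
clip = paddle.nn.ClipGradByValue(max=0.5, min=-0.5)
expected = clip([(p, p.grad)])[0][1]

# In-place result from the new API.
clip_grad_value_(p, clip_value=0.5)

np.testing.assert_allclose(p.grad.numpy(), expected.numpy(), rtol=1e-5)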
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle
from paddle.nn.utils.clip_grad_value_ import clip_grad_value_


class TestClipGradValue(unittest.TestCase):
    def test_basic(self):
        run_test_equal_np(
            self,
            shape=[16, 16],
            dtype=np.float32,
            clip_value=1,
        )
        run_test_equal_np(
            self,
            shape=(100,),
            dtype=np.float32,
            clip_value=0.1,
        )
        run_test_equal_np(
            self, shape=[4, 8, 16], dtype=np.float32, clip_value=0
        )
        run_test_equal_ClipGradByValue(
            self,
            shape=[16, 16],
            dtype=np.float32,
            clip_value=1,
        )
        run_test_equal_ClipGradByValue(
            self,
            shape=(100,),
            dtype=np.float32,
            clip_value=0.1,
        )
        run_test_equal_ClipGradByValue(
            self, shape=[4, 8, 16], dtype=np.float32, clip_value=0
        )

    def test_errors(self):
        def TestValueError():
            input_pd = paddle.to_tensor(
                np.random.random([1, 2]).astype(np.float32)
            )
            input_pd.grad = paddle.to_tensor(
                np.random.random([1, 2]).astype(np.float32)
            )
            clip_grad_value_(input_pd, clip_value=-1)

        self.assertRaises(ValueError, TestValueError)

        def TestRuntimeErrorStaticMode():
            paddle.enable_static()
            input_pd = paddle.to_tensor(
                np.random.random([1, 2]).astype(np.float32)
            )
            input_pd.grad = paddle.to_tensor(
                np.random.random([1, 2]).astype(np.float32)
            )
            clip_grad_value_(input_pd, clip_value=1)
            paddle.disable_static()

        self.assertRaises(RuntimeError, TestRuntimeErrorStaticMode)


def run_test_equal_np(
    self,
    shape,
    dtype,
    clip_value,
):
    input = np.random.random(shape).astype(dtype)
    grad = np.random.random(shape).astype(dtype)
    input_pd = paddle.to_tensor(input)
    input_pd.grad = paddle.to_tensor(grad)

    output = np.clip(grad, a_min=-clip_value, a_max=clip_value)
    clip_grad_value_(
        input_pd,
        clip_value=clip_value,
    )

    np.testing.assert_allclose(
        input_pd.grad.numpy(),
        output,
        rtol=1e-05,
        atol=1e-05,
        equal_nan=False,
    )


def run_test_equal_ClipGradByValue(
    self,
    shape,
    dtype,
    clip_value,
):
    input = np.random.random(shape).astype(dtype)
    grad = np.random.random(shape).astype(dtype)
    input_pd = paddle.to_tensor(input)
    input_pd.grad = paddle.to_tensor(grad)

    clip = paddle.nn.ClipGradByValue(max=clip_value, min=-clip_value)
    output = clip([(input_pd, input_pd.grad)])[0][1]
    clip_grad_value_(
        input_pd,
        clip_value=clip_value,
    )

    np.testing.assert_allclose(
        input_pd.grad,
        output,
        rtol=1e-05,
        atol=1e-05,
        equal_nan=False,
    )


if __name__ == '__main__':
    unittest.main()