未验证 提交 9f85f218 编写于 作者: H Hongyu Liu 提交者: GitHub

Add new gard clip [old gradient clip not support in dy graph] (#17523)

* add gradient clip in minimize; test=develop

* fix bug; test=develop

* fix format; test=develop

* move new grad clip to dygraph/grad_clip.py; test=develop

* fix lr decay and grad clip test; test=develop

* seperate dygraph grad clip; test=develop

* fix grad clip test; develop

* fix api spec bug; test=develop

* add blank line, test=develop,test=document_preview

to fix format problem
上级 4337009b
此差异已折叠。
...@@ -54,6 +54,7 @@ from .transpiler import DistributeTranspiler, \ ...@@ -54,6 +54,7 @@ from .transpiler import DistributeTranspiler, \
memory_optimize, release_memory, DistributeTranspilerConfig memory_optimize, release_memory, DistributeTranspilerConfig
from .lod_tensor import create_lod_tensor, create_random_int_lodtensor from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
from . import clip from . import clip
from . import dygraph_grad_clip
from . import profiler from . import profiler
from . import unique_name from . import unique_name
from . import recordio_writer from . import recordio_writer
...@@ -93,6 +94,7 @@ __all__ = framework.__all__ + executor.__all__ + \ ...@@ -93,6 +94,7 @@ __all__ = framework.__all__ + executor.__all__ + \
'WeightNormParamAttr', 'WeightNormParamAttr',
'DataFeeder', 'DataFeeder',
'clip', 'clip',
'dygraph_grad_clip',
'profiler', 'profiler',
'unique_name', 'unique_name',
'recordio_writer', 'recordio_writer',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import copy
import six
import functools
from . import layers
from . import framework
from . import core
__all__ = [
'GradClipByValue',
'GradClipByNorm',
'GradClipByGlobalNorm',
]
class GradClipBase(object):
def __str__(self):
raise NotImplementedError()
def _clip(self, para_and_grad):
raise NotImplementedError
def __call__(self, para_and_grad):
return self._clip(para_and_grad)
class GradClipByValue(GradClipBase):
"""
Clips gradient values to the range [min_value, max_value].
Given a gradient g, this operation clips its value to min_value and max_value.
- Any values less than min_value are set to min_value.
- Any values greater than max_value are set to max_value.
Args:
max_value (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. if not set by user, \
will be set to -max_value(max_value MUST be postive) by framework.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import FC
from paddle.fluid.clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.optimizer import SGDOptimizer
with fluid.dygraph.guard():
value_clip = GradClipByValue( -1.0, 1.0 )
sgd = SGDOptimizer(learning_rate=1.0)
init_value = np.random.uniform( -1, 1, (10, 10)).astype('float32')
fc = FC( "fc", 10)
out = fc( to_variable(init_value) )
loss = fluid.layers.reduce_mean( out )
loss.backward()
sgd.minimize(loss, grad_clip = value_clip)
"""
def __init__(self, min_value, max_value=None):
if min_value is None:
assert (max_value > 0.0)
min_value = -max_value
else:
min_value = float(min_value)
self.max_value = max_value
self.min_value = min_value
def __str__(self):
return "ClipByValue, min = %f, max=%f" % (self.min_value,
self.max_value)
def _clip(self, para_and_grad):
out = []
for p, g in para_and_grad:
if g is None:
out.append((p, g))
continue
new_grad = layers.clip(x=g, min=self.min_value, max=self.max_value)
out.append((p, new_grad))
return out
class GradClipByNorm(GradClipBase):
"""
Clips tensor values to a maximum L2-norm.
This operator limits the L2 norm of the input :math:`X` within :math:`max\_norm`.
If the L2 norm of :math:`X` is less than or equal to :math:`max\_norm`, :math:`Out`
will be the same as :math:`X`. If the L2 norm of :math:`X` is greater than
:math:`max\_norm`, :math:`X` will be linearly scaled to make the L2 norm of
:math:`Out` equal to :math:`max\_norm`, as shown in the following formula:
.. math::
Out = \\frac{max\_norm * X}{norm(X)},
where :math:`norm(X)` represents the L2 norm of :math:`X`.
Args:
clip_norm (float): The maximum norm value
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import FC
from paddle.fluid.clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.optimizer import SGDOptimizer
with fluid.dygraph.guard():
norm_clip = GradClipByNorm( 5.0 )
sgd = SGDOptimizer(learning_rate=1.0)
init_value = np.random.uniform( -1, 1, (10, 10)).astype('float32')
fc = FC( "fc", 10)
out = fc( to_variable(init_value) )
loss = fluid.layers.reduce_mean( out )
loss.backward()
sgd.minimize(loss, grad_clip = norm_clip)
"""
def __init__(self, clip_norm):
self.clip_norm = clip_norm
def __str__(self):
return "ClipByNorm, clip_norm=%f" % self.clip_norm
def _clip(self, para_and_grad):
out = []
for p, g in para_and_grad:
if g is None:
out.append((p, g))
continue
new_g = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
out.append((p, new_g))
return out
class GradClipByGlobalNorm(GradClipBase):
"""
Clips values of multiple tensors by the ratio of the sum of their norms.
Given a list of tensors t_list, and a clipping ratio clip_norm, this
operation returns a list of clipped tensors list_clipped and the global
norm (global_norm) of all tensors in t_list.
To perform the clipping, the values :math:`t\_list[i]` are set to:
.. math::
t\_list[i] = t\_list[i] * \\frac{clip\_norm}{\max(global\_norm, clip\_norm)}
where:
.. math::
global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
If :math:`clip\_norm > global\_norm` then the entries in t_list remain as they are,
otherwise they're all shrunk by the global ratio.
Args:
clip_norm (float): The maximum norm value
group_name (str, optional): The group name for this clip.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import FC
from paddle.fluid.clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.optimizer import SGDOptimizer
with fluid.dygraph.guard():
gloabl_norm_clip = GradClipByGlobalNorm( 5.0 )
sgd = SGDOptimizer(learning_rate=1.0)
init_value = np.random.uniform( -1, 1, (10, 10)).astype('float32')
fc = FC( "fc", 10)
out = fc( to_variable(init_value) )
loss = fluid.layers.reduce_mean( out )
loss.backward()
sgd.minimize(loss, grad_clip = gloabl_norm_clip)
"""
def __init__(self, max_global_norm):
self.max_global_norm = layers.fill_constant(
shape=[1], dtype='float32', value=max_global_norm)
def __str__(self):
return "ClipByGlobalNorm, max_global_norm=%f" % (self.max_global_norm)
def _clip(self, para_and_grad):
out = []
norm_arr = []
for p, g in para_and_grad:
if g is None:
continue
power = layers.square(g)
sum_t = layers.reduce_sum(power)
norm_arr.append(sum_t)
norm_global = layers.concat(norm_arr)
norm_global = layers.reduce_sum(norm_global)
norm_global = layers.sqrt(norm_global)
clip_scale = layers.elementwise_div(
x=self.max_global_norm,
y=layers.elementwise_max(
x=norm_global, y=self.max_global_norm))
for p, g in para_and_grad:
if g is None:
out.append((p, g))
continue
new_grad = g * clip_scale
out.append((p, new_grad))
return out
...@@ -463,6 +463,8 @@ class Optimizer(object): ...@@ -463,6 +463,8 @@ class Optimizer(object):
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
with program_guard(framework.default_main_program(), with program_guard(framework.default_main_program(),
framework.default_startup_program()): framework.default_startup_program()):
params_grads = append_regularization_ops(params_grads,
self.regularization)
optimize_ops = self._create_optimization_pass(params_grads) optimize_ops = self._create_optimization_pass(params_grads)
else: else:
program = loss.block.program program = loss.block.program
...@@ -474,7 +476,8 @@ class Optimizer(object): ...@@ -474,7 +476,8 @@ class Optimizer(object):
loss, loss,
startup_program=None, startup_program=None,
parameter_list=None, parameter_list=None,
no_grad_set=None): no_grad_set=None,
grad_clip=None):
""" """
Add operations to minimize `loss` by updating `parameter_list`. Add operations to minimize `loss` by updating `parameter_list`.
...@@ -487,6 +490,7 @@ class Optimizer(object): ...@@ -487,6 +490,7 @@ class Optimizer(object):
in `parameter_list`. in `parameter_list`.
parameter_list (list): list of Variables to update. parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored. no_grad_set (set|None): set of Variables should be ignored.
grad_clip (GradClipBase|None) : Gradient clip strategy
Returns: Returns:
tuple: (optimize_ops, params_grads) which are, list of operators appended; tuple: (optimize_ops, params_grads) which are, list of operators appended;
...@@ -497,6 +501,11 @@ class Optimizer(object): ...@@ -497,6 +501,11 @@ class Optimizer(object):
startup_program=startup_program, startup_program=startup_program,
parameter_list=parameter_list, parameter_list=parameter_list,
no_grad_set=no_grad_set) no_grad_set=no_grad_set)
if grad_clip is not None and framework.in_dygraph_mode():
# TODO(hongyu): FIX later, this is only for dygraph, should be work for static mode
params_grads = grad_clip(params_grads)
optimize_ops = self.apply_optimize( optimize_ops = self.apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads) loss, startup_program=startup_program, params_grads=params_grads)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import contextlib
import unittest
import numpy as np
import six
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
class TestGradClipByGlobalNorm(unittest.TestCase):
def init_value(self):
self.max_global_norm = 5.0
self.init_scale = 1.0
self.shape = (20, 20)
def generate_p_g(self):
self.para_and_grad = []
for i in range(10):
self.para_and_grad.append(
(np.random.uniform(-self.init_scale, self.init_scale,
self.shape).astype('float32'),
np.random.uniform(-self.init_scale, self.init_scale,
self.shape).astype('float32')))
def get_numpy_global_norm_result(self):
gloabl_norm = 0.0
for p, g in self.para_and_grad:
gloabl_norm += np.sum(np.square(g))
gloabl_norm_np = np.sqrt(gloabl_norm)
new_np_p_g = []
scale = 1.0
if gloabl_norm_np > self.max_global_norm:
scale = self.max_global_norm / gloabl_norm_np
for p, g in self.para_and_grad:
new_np_p_g.append((p, g * scale))
return new_np_p_g
def get_dygrap_global_norm_result(self):
with fluid.dygraph.guard():
gloabl_norm_clip = GradClipByGlobalNorm(self.max_global_norm)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
new_g = to_variable(g)
p_g_var.append((new_p, new_g))
new_p_g_var = gloabl_norm_clip(p_g_var)
p_g_dy_out = []
for p, g in new_p_g_var:
p_g_dy_out.append((p.numpy(), g.numpy()))
return p_g_dy_out
def test_clip_by_global_norm(self):
self.init_value()
self.generate_p_g()
np_p_g = self.get_numpy_global_norm_result()
dy_out_p_g = self.get_dygrap_global_norm_result()
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
def test_clip_by_global_norm_2(self):
self.init_value()
self.init_scale = 0.2
self.max_global_norm = 10
self.generate_p_g()
np_p_g = self.get_numpy_global_norm_result()
dy_out_p_g = self.get_dygrap_global_norm_result()
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
class TestGradClipByNorm(unittest.TestCase):
def init_value(self):
self.max_norm = 5.0
self.init_scale = 1.0
self.shape = (10, 10)
def generate_p_g(self):
self.para_and_grad = []
for i in range(10):
self.para_and_grad.append(
(np.random.uniform(-self.init_scale, self.init_scale,
self.shape).astype('float32'),
np.random.uniform(-self.init_scale, self.init_scale,
self.shape).astype('float32')))
def get_numpy_norm_result(self):
new_p_g = []
for p, g in self.para_and_grad:
norm = np.sqrt(np.sum(np.square(g)))
if norm > self.max_norm:
new_p_g.append((p, g * self.max_norm / norm))
else:
new_p_g.append((p, g))
return new_p_g
def get_dygrap_norm_result(self):
with fluid.dygraph.guard():
norm_clip = GradClipByNorm(self.max_norm)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
new_g = to_variable(g)
p_g_var.append((new_p, new_g))
new_p_g_var = norm_clip(p_g_var)
p_g_dy_out = []
for p, g in new_p_g_var:
p_g_dy_out.append((p.numpy(), g.numpy()))
return p_g_dy_out
def test_clip_by_norm(self):
self.init_value()
self.generate_p_g()
np_p_g = self.get_numpy_norm_result()
dy_out_p_g = self.get_dygrap_norm_result()
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
def test_clip_by_norm_2(self):
self.init_value()
self.init_scale = 0.2
self.max_norm = 10.0
self.generate_p_g()
np_p_g = self.get_numpy_norm_result()
dy_out_p_g = self.get_dygrap_norm_result()
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
class TestGradClipByValue(unittest.TestCase):
def init_value(self):
self.max_value = 0.8
self.min_value = -0.1
self.init_scale = 1.0
self.shape = (10, 10)
def generate_p_g(self):
self.para_and_grad = []
for i in range(10):
self.para_and_grad.append(
(np.random.uniform(-self.init_scale, self.init_scale,
self.shape).astype('float32'),
np.random.uniform(-self.init_scale, self.init_scale,
self.shape).astype('float32')))
def get_numpy_clip_result(self):
new_p_g = []
for p, g in self.para_and_grad:
new_p_g.append((p, np.clip(g, self.min_value, self.max_value)))
return new_p_g
def get_dygrap_clip_result(self):
with fluid.dygraph.guard():
value_clip = GradClipByValue(self.min_value, self.max_value)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
new_g = to_variable(g)
p_g_var.append((new_p, new_g))
new_p_g_var = value_clip(p_g_var)
p_g_dy_out = []
for p, g in new_p_g_var:
p_g_dy_out.append((p.numpy(), g.numpy()))
return p_g_dy_out
def test_clip_by_value(self):
self.init_value()
self.generate_p_g()
np_p_g = self.get_numpy_clip_result()
dy_out_p_g = self.get_dygrap_clip_result()
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
def test_clip_by_norm_2(self):
self.init_value()
self.init_scale = 0.2
self.generate_p_g()
np_p_g = self.get_numpy_clip_result()
dy_out_p_g = self.get_dygrap_clip_result()
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
def test_clip_by_norm_3(self):
self.init_value()
self.init_scale = 0.5
self.max_value = 0.6
self.min_value = None
self.generate_p_g()
np_p_g = self.get_numpy_clip_result()
dy_out_p_g = self.get_dygrap_clip_result()
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册