未验证 提交 7fda333a 编写于 作者: Z Zhou Wei 提交者: GitHub

add new method of gradient_clip, better to use,test=develop (#23224)

上级 b7b0b359
......@@ -75,7 +75,6 @@ from .transpiler import DistributeTranspiler, \
memory_optimize, release_memory, DistributeTranspilerConfig
from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
from . import clip
from . import dygraph_grad_clip
from . import profiler
from . import unique_name
from . import parallel_executor
......@@ -122,7 +121,6 @@ __all__ = framework.__all__ + executor.__all__ + \
'WeightNormParamAttr',
'DataFeeder',
'clip',
'dygraph_grad_clip',
'profiler',
'unique_name',
'Scope',
......
......@@ -16,19 +16,18 @@ from __future__ import print_function
import copy
import six
import warnings
import functools
from . import layers
from . import framework
from . import core
from . import name_scope
from .dygraph import base as imperative_base
__all__ = [
'set_gradient_clip',
'ErrorClipByValue',
'GradientClipByValue',
'GradientClipByNorm',
'GradientClipByGlobalNorm',
'set_gradient_clip', 'ErrorClipByValue', 'GradientClipByValue',
'GradientClipByNorm', 'GradientClipByGlobalNorm'
]
......@@ -116,29 +115,51 @@ def error_clip_callback(block, context):
error_clip._append_clip_op(block, grad_n)
class BaseGradientClipAttr(object):
class GradientClipBase(object):
def __init__(self, need_clip=None):
if need_clip is not None and not callable(need_clip):
raise TypeError(
"The type of need_clip must be funciton, and it can filter out "
"parameter that does't need gradient clip. This function must return "
"True or False, and True means that clipping is required. Please refer to "
"API documention of GradientClipByGlobalNorm / GradientClipByNorm "
"/GradientClipByValue.")
self._need_clip_func = need_clip
def __str__(self):
raise NotImplementedError()
def _process_context(self, context, param, grad):
raise NotImplementedError()
def _create_operators(self, param, grad):
raise NotImplementedError()
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
raise NotImplementedError
def _static_clip(self, params_grads):
raise NotImplementedError
class NullGradientClipAttr(BaseGradientClipAttr):
def __str__(self):
return "Null"
def __call__(self, params_grads):
assert len(
params_grads
) > 0, "The number of trainable parameters should be greater than 0."
if framework.in_dygraph_mode():
return self._dygraph_clip(params_grads)
else:
for p, g in params_grads:
if getattr(p, 'gradient_clip_attr', None) is not None:
warnings.warn(
"'set_gradient_clip' will be ineffective, because you have "
"pass 'grad_clip' into 'minimize'. So, 'set_gradient_clip' "
"is redundant and you can remove it.")
break
return self._static_clip(params_grads)
def _process_context(self, context, param, grad):
pass
raise NotImplementedError()
def _create_operators(self, param, grad):
return param, grad
raise NotImplementedError()
class GradientClipByValue(BaseGradientClipAttr):
class GradientClipByValue(GradientClipBase):
"""
Clips gradient values to the range [min, max].
......@@ -168,17 +189,46 @@ class GradientClipByValue(BaseGradientClipAttr):
input=x, size=1, param_attr=w_param_attrs)
"""
def __init__(self, max, min=None):
max = float(max)
def __init__(self, max, min=None, need_clip=None):
super(GradientClipByValue, self).__init__(need_clip)
if min is None:
assert (max > 0.0)
min = -max
else:
min = float(min)
self.max = max
self.min = min
self.max = float(max)
self.min = float(min)
def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max)
return "Gradient Clip By Value, min = %f, max=%f" % (self.min, self.max)
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
if g is None:
continue
if self._need_clip_func is not None and not self._need_clip_func(p):
params_and_grads.append((p, g))
continue
new_grad = layers.clip(x=g, min=self.min, max=self.max)
params_and_grads.append((p, new_grad))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
continue
if self._need_clip_func is not None and not self._need_clip_func(
p):
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_grad = layers.clip(x=g, min=self.min, max=self.max)
params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads)
return params_and_grads
def _process_context(self, context, param, grad):
pass
......@@ -188,7 +238,7 @@ class GradientClipByValue(BaseGradientClipAttr):
return param, new_grad
class GradientClipByNorm(BaseGradientClipAttr):
class GradientClipByNorm(GradientClipBase):
"""
Convert the input multidimensional Tensor :math:`X` to a multidimensional Tensor whose L2 norm does not exceed the given two-norm maximum ( :math:`clip\_norm` ).
......@@ -268,11 +318,42 @@ class GradientClipByNorm(BaseGradientClipAttr):
"""
def __init__(self, clip_norm):
self.clip_norm = clip_norm
def __init__(self, clip_norm, need_clip=None):
super(GradientClipByNorm, self).__init__(need_clip)
self.clip_norm = float(clip_norm)
def __str__(self):
return "ByNorm, clip_norm=%f" % self.clip_norm
return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
if g is None:
continue
if self._need_clip_func is not None and not self._need_clip_func(p):
params_and_grads.append((p, g))
continue
new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
params_and_grads.append((p, new_grad))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
continue
if self._need_clip_func is not None and not self._need_clip_func(
p):
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads)
return params_and_grads
def _process_context(self, context, param, grad):
pass
......@@ -282,7 +363,7 @@ class GradientClipByNorm(BaseGradientClipAttr):
return param, new_grad
class GradientClipByGlobalNorm(BaseGradientClipAttr):
class GradientClipByGlobalNorm(GradientClipBase):
"""
Clips values of multiple tensors by the ratio of the sum of their norms.
......@@ -371,16 +452,104 @@ class GradientClipByGlobalNorm(BaseGradientClipAttr):
"""
def __init__(self, clip_norm, group_name="default_group"):
if not isinstance(group_name, six.string_types):
raise TypeError("'group_name' must be a %s." % (six.string_types))
self.clip_norm = clip_norm
def __init__(self, clip_norm, group_name="default_group", need_clip=None):
super(GradientClipByGlobalNorm, self).__init__(need_clip)
self.clip_norm = float(clip_norm)
self.group_name = group_name
def __str__(self):
return "ByGlobalNorm, group_name=%s, clip_norm=%f" % (self.group_name,
self.clip_norm)
return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
for p, g in params_grads:
if g is None:
continue
if self._need_clip_func is not None and not self._need_clip_func(p):
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(square)
sum_square_list.append(sum_square)
# all parameters have been filterd out
if len(sum_square_list) == 0:
return params_grads
global_norm_var = layers.concat(sum_square_list)
global_norm_var = layers.reduce_sum(global_norm_var)
global_norm_var = layers.sqrt(global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype='float32', value=self.clip_norm)
clip_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(
x=global_norm_var, y=max_global_norm))
for p, g in params_grads:
if g is None:
continue
if self._need_clip_func is not None and not self._need_clip_func(p):
params_and_grads.append((p, g))
continue
new_grad = layers.elementwise_mul(x=g, y=clip_var)
params_and_grads.append((p, new_grad))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
continue
if self._need_clip_func is not None and not self._need_clip_func(
p):
continue
merge_grad = g
with p.block.program._optimized_guard([p, g]):
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(
merge_grad)
square = layers.square(merge_grad)
sum_square = layers.reduce_sum(input=square)
sum_square_list.append(sum_square)
# all parameters have been filterd out
if len(sum_square_list) == 0:
return params_grads
with p.block.program._optimized_guard([p, g]):
global_norm_var = layers.sums(sum_square_list)
global_norm_var = layers.sqrt(x=global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype="float32", value=self.clip_norm)
scale_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(
x=max_global_norm, y=global_norm_var))
for p, g in params_grads:
if g is None:
continue
if self._need_clip_func is not None and not self._need_clip_func(
p):
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_grad = layers.elementwise_mul(x=g, y=scale_var)
params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads)
return params_and_grads
def _process_context(self, context, param, grad):
if self.group_name not in context:
......@@ -486,12 +655,28 @@ def set_gradient_clip(clip, param_list=None, program=None):
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
"""
if not isinstance(clip, BaseGradientClipAttr):
warnings.warn("Caution! 'set_gradient_clip' is not recommended "
"and may be deprecated in future! "
"We recommend a new strategy: clip gradient by "
"'optimizer.minimize(loss, grad_clip=clip)'. "
"This method can reduce the mistakes, please "
"see documention of 'optimzier.minimize'.")
if not isinstance(clip, GradientClipBase):
raise TypeError(
"'clip' should be an instance of BaseGradientClipAttr's derived class"
)
"'clip' should be an instance of GradientClipBase's derived class")
if program is None:
program = framework.default_main_program()
for op in program.block(0).ops:
if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr(
"op_namescope"):
warnings.warn(
"'minimize' has been invoked before, this will make 'set_gradient_clip' "
"be ineffective! Please invoke 'set_gradient_clip' before 'minimize'."
)
break
if param_list is None:
param_list = program.block(0).all_parameters()
if all(isinstance(elem, six.string_types) for elem in param_list):
......@@ -511,46 +696,45 @@ def append_gradient_clip_ops(param_grads):
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('append_clip_@CLIP'):
clip_attr = getattr(p, 'gradient_clip_attr', NullGradientClipAttr())
[p, g]), framework.name_scope('gradient_clip_@CLIP'):
clip_attr = getattr(p, 'gradient_clip_attr', None)
if clip_attr is None:
clip_attr = NullGradientClipAttr()
if not isinstance(clip_attr, BaseGradientClipAttr):
return param_grads
if not isinstance(clip_attr, GradientClipBase):
raise TypeError(
"clip attribute should be an instance of BaseGradientClipAttr"
)
"clip attribute should be an instance of GradientClipBase")
clip_attr._process_context(context=context, param=p, grad=g)
res = []
param_new_grad_dict = dict()
for p, g in param_grads:
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('append_graident_clip_@CLIP'):
[p, g]), framework.name_scope('graident_clip_@CLIP'):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_dict[param.name] = new_grad.name
res.append([param, new_grad])
# change wrong mapping relation between param & grad in clip op
clip_flag = '@CLIP'
block_id_list = []
for p, g in param_grads:
if g is None:
continue
block_id = p.block.idx
if block_id in block_id_list:
_correct_clip_op_role_var(res)
return res
# change wrong mapping relation between param & grad in clip op
def _correct_clip_op_role_var(params_grads):
for param, grad in params_grads:
if grad is None:
continue
block_id_list.append(block_id)
for op in p.block.program.global_block().ops:
if 'op_namescope' in op.all_attrs() and clip_flag in op.attr(
for op in param.block.program.global_block().ops:
if 'op_namescope' in op.all_attrs() and "gradient_clip" in op.attr(
"op_namescope"):
if op.attr('op_role_var'):
param_name = op.attr('op_role_var')[0]
correct_p_g = [param_name, param_new_grad_dict[param_name]]
index = 0
for i in range(len(params_grads)):
if params_grads[i][0].name == param_name:
index = i
correct_p_g = [param_name, params_grads[index][1].name]
op._set_attr('op_role_var', correct_p_g)
return res
ClipByValue = GradientClipByValue
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import copy
import six
import functools
from . import layers
from . import framework
from . import core
from .dygraph import base as imperative_base
__all__ = [
'GradClipByValue',
'GradClipByNorm',
'GradClipByGlobalNorm',
]
class GradClipBase(object):
def __str__(self):
raise NotImplementedError()
def _clip(self, para_and_grad):
raise NotImplementedError
@imperative_base.no_grad
def __call__(self, para_and_grad):
return self._clip(para_and_grad)
class GradClipByValue(GradClipBase):
"""
Clips gradient values to the range [min_value, max_value].
Given a gradient g, this operation clips its value to min_value and max_value.
- Any values less than min_value are set to min_value.
- Any values greater than max_value are set to max_value.
Args:
max_value (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. if not set by user, \
will be set to -max_value(max_value MUST be positive) by framework.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.optimizer import SGDOptimizer
with fluid.dygraph.guard():
value_clip = GradClipByValue( -1.0, 1.0 )
sgd = SGDOptimizer(learning_rate=1.0)
init_value = np.random.uniform( -1, 1, (10, 10)).astype('float32')
linear = Linear( 10, 10)
out = linear( to_variable(init_value) )
loss = fluid.layers.reduce_mean( out )
loss.backward()
sgd.minimize(loss, grad_clip = value_clip)
"""
@imperative_base.no_grad
def __init__(self, min_value, max_value=None):
if min_value is None:
assert (max_value > 0.0)
min_value = -max_value
else:
min_value = float(min_value)
self.max_value = max_value
self.min_value = min_value
def __str__(self):
return "ClipByValue, min = %f, max=%f" % (self.min_value,
self.max_value)
def _clip(self, para_and_grad):
out = []
for p, g in para_and_grad:
if g is None:
out.append((p, g))
continue
new_grad = layers.clip(x=g, min=self.min_value, max=self.max_value)
out.append((p, new_grad))
return out
class GradClipByNorm(GradClipBase):
"""
Clips tensor values to a maximum L2-norm.
This operator limits the L2 norm of the input :math:`X` within :math:`max\_norm`.
If the L2 norm of :math:`X` is less than or equal to :math:`max\_norm`, :math:`Out`
will be the same as :math:`X`. If the L2 norm of :math:`X` is greater than
:math:`max\_norm`, :math:`X` will be linearly scaled to make the L2 norm of
:math:`Out` equal to :math:`max\_norm`, as shown in the following formula:
.. math::
Out = \\frac{max\_norm * X}{norm(X)},
where :math:`norm(X)` represents the L2 norm of :math:`X`.
Args:
clip_norm (float): The maximum norm value
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.optimizer import SGDOptimizer
with fluid.dygraph.guard():
norm_clip = GradClipByNorm( 5.0 )
sgd = SGDOptimizer(learning_rate=1.0)
init_value = np.random.uniform( -1, 1, (10, 10)).astype('float32')
linear = Linear( 10, 10)
out = linear( to_variable(init_value) )
loss = fluid.layers.reduce_mean( out )
loss.backward()
sgd.minimize(loss, grad_clip = norm_clip)
"""
@imperative_base.no_grad
def __init__(self, clip_norm):
self.clip_norm = clip_norm
def __str__(self):
return "ClipByNorm, clip_norm=%f" % self.clip_norm
def _clip(self, para_and_grad):
out = []
for p, g in para_and_grad:
if g is None:
out.append((p, g))
continue
new_g = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
out.append((p, new_g))
return out
class GradClipByGlobalNorm(GradClipBase):
"""
Clips values of multiple tensors by the ratio of the sum of their norms.
Given a list of tensors t_list, and a clipping ratio max_global_norm, this
operation returns a list of clipped tensors list_clipped.
To perform the clipping, the values :math:`t\_list[i]` are set to:
.. math::
t\_list[i] = t\_list[i] * \\frac{max\_global\_norm}{\max(global\_norm, max\_global\_norm)}
where:
.. math::
global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
If :math:`max\_global\_norm > global\_norm` then the entries in t_list remain as they are,
otherwise they're all shrunk by the global ratio.
Args:
max_global_norm (float): The maximum norm value.
dtype (str, optional): The type of max_global_norm. Default: "float32".
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.nn import Linear
from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.optimizer import SGDOptimizer
with fluid.dygraph.guard():
gloabl_norm_clip = GradClipByGlobalNorm( 5.0 )
sgd = SGDOptimizer(learning_rate=1.0)
init_value = np.random.uniform( -1, 1, (10, 10)).astype('float32')
linear = Linear( 10, 10)
out = linear( to_variable(init_value) )
loss = fluid.layers.reduce_mean( out )
loss.backward()
sgd.minimize(loss, grad_clip = gloabl_norm_clip)
"""
@imperative_base.no_grad
def __init__(self, max_global_norm, dtype='float32'):
self.max_global_norm = layers.fill_constant(
shape=[1], dtype=dtype, value=max_global_norm)
def __str__(self):
return "ClipByGlobalNorm, max_global_norm=%f" % (self.max_global_norm)
def _clip(self, para_and_grad):
out = []
norm_arr = []
for p, g in para_and_grad:
if g is None:
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
power = layers.square(merge_grad)
sum_t = layers.reduce_sum(power)
norm_arr.append(sum_t)
norm_global = layers.concat(norm_arr)
norm_global = layers.reduce_sum(norm_global)
norm_global = layers.sqrt(norm_global)
clip_scale = self.max_global_norm / (layers.elementwise_max(
x=norm_global, y=self.max_global_norm))
for p, g in para_and_grad:
if g is None:
out.append((p, g))
continue
new_grad = g * clip_scale
out.append((p, new_grad))
return out
......@@ -2409,7 +2409,6 @@ class Block(object):
trainable = v.trainable
optimize_attr = v.optimize_attr
regularizer = v.regularizer
gradient_clip_attr = v.gradient_clip_attr
error_clip = v.error_clip
elif type(v) == Variable:
var_type = "Variable"
......@@ -2432,7 +2431,6 @@ class Block(object):
trainable=trainable,
optimize_attr=optimize_attr,
regularizer=regularizer,
gradient_clip_attr=gradient_clip_attr,
error_clip=error_clip)
else:
var = Parameter(
......@@ -2445,7 +2443,6 @@ class Block(object):
trainable=trainable,
optimize_attr=optimize_attr,
regularizer=regularizer,
gradient_clip_attr=gradient_clip_attr,
error_clip=error_clip)
elif var_type == "Variable":
var = Variable(
......@@ -2723,7 +2720,6 @@ class Block(object):
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip,
name=v.name)
else:
......@@ -2737,7 +2733,6 @@ class Block(object):
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip,
name=v.name)
self.vars[new_p.name] = new_p
......@@ -4646,8 +4641,6 @@ class Parameter(Variable):
Default: {'learning_rate': 1.0}
regularizer(WeightDecayRegularizer): The Regularizer which will
be applied on the parameter. Default: None
gradient_clip_attr(BaseGradientClipAttr): The gradient clip strategy
which will be applied on the parameter. Default: None
do_model_average(bool): True if the model average strategy will
be applied on this parameter.
"""
......@@ -4687,8 +4680,6 @@ class Parameter(Variable):
self.regularizer = kwargs.get('regularizer', None)
self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None)
self.do_model_average = kwargs.get('do_model_average', None)
self.is_distributed = False
......@@ -4723,7 +4714,7 @@ class Parameter(Variable):
if with_details:
res_str = Variable.to_string(self, throw_on_error, True)
additional_attr = ("trainable", "optimize_attr", "regularizer",
"gradient_clip_attr", "do_model_average")
"do_model_average")
for attr_name in additional_attr:
res_str += "%s: %s\n" % (attr_name,
cpt.to_text(getattr(self, attr_name)))
......@@ -4752,8 +4743,6 @@ class ParamBase(core.VarBase):
Default: {'learning_rate': 1.0}
regularizer(WeightDecayRegularizer): The Regularizer which will
be applied on the ParamBase. Default: None
gradient_clip_attr(BaseGradientClipAttr): The gradient clip strategy
which will be applied on the ParamBase. Default: None
do_model_average(bool): True if the model average strategy will
be applied on this ParamBase.
"""
......@@ -4792,8 +4781,6 @@ class ParamBase(core.VarBase):
self.regularizer = kwargs.get('regularizer', None)
self.gradient_clip_attr = kwargs.get('gradient_clip_attr', None)
self.do_model_average = kwargs.get('do_model_average', None)
self.is_distributed = False
......
......@@ -24,7 +24,7 @@ from . import framework
from . import layers
from . import unique_name
from .backward import append_backward, _some_in_set_, _append_grad_suffix_, _get_no_grad_set_name
from .clip import append_gradient_clip_ops, error_clip_callback
from .clip import GradientClipBase, error_clip_callback, append_gradient_clip_ops
from .framework import program_guard
from .initializer import Constant
from .layer_helper import LayerHelper
......@@ -109,6 +109,8 @@ class Optimizer(object):
self._opti_name_list = []
self._accumulators_holder = {}
self._param_device_map = dict()
# if pass grad_clip into minimize, it will not be None
self._grad_clip = None
@framework.dygraph_only
def state_dict(self):
......@@ -690,12 +692,17 @@ class Optimizer(object):
# ...
optimizer.apply_gradients(params_grads)
"""
params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads, table_param_and_grad, table_optimize_op = \
self._process_distribute_lookuptable(params_grads)
params_grads = append_gradient_clip_ops(params_grads)
# 'minimize(grad_clip)' or 'set_gradient_clip'
if self._grad_clip is not None:
params_grads = self._grad_clip(params_grads)
else:
params_grads = append_gradient_clip_ops(params_grads)
# Add regularization if any
params_grads = append_regularization_ops(params_grads,
......@@ -712,19 +719,19 @@ class Optimizer(object):
"""
Second part of `minimize`, appending optimization operators for
given `params_grads` pairs.
Args:
loss (Variable): loss variable to run optimizations.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
params_grads (list): list of (param, grad) pair to do optimization.
Returns:
list: A list of operators appended to the current program.
"""
if framework.in_dygraph_mode():
with program_guard(framework.default_main_program(),
framework.default_startup_program()):
if self._grad_clip is not None:
params_grads = self._grad_clip(params_grads)
params_grads = append_regularization_ops(params_grads,
self.regularization)
optimize_ops = self._create_optimization_pass(params_grads)
......@@ -809,16 +816,19 @@ class Optimizer(object):
Please refer to the example of current Optimizer.
"""
assert isinstance(loss, Variable), "The loss should be an Variable."
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipBase):
raise TypeError(
"'grad_clip' should be an instance of GradientClipBase's derived class"
)
self._grad_clip = grad_clip
params_grads = self.backward(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
if grad_clip is not None and framework.in_dygraph_mode():
# TODO(hongyu): FIX later, this is only for dygraph, should be work for static mode
params_grads = grad_clip(params_grads)
optimize_ops = self.apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads)
......@@ -1148,6 +1158,7 @@ class DGCMomentumOptimizer(Optimizer):
self.regular_type, self.regular_coeff = self._get_regularization_param(
self.regularization)
self._grad_clip = None
def _get_regularization_param(self, regularization):
regular_type = 0
......@@ -1404,24 +1415,28 @@ class DGCMomentumOptimizer(Optimizer):
dgc_op._set_attr(op_maker.kOpRoleVarAttrName(),
[param_var.name, grad_var.name])
@imperative_base.no_grad
def apply_gradients(self, params_grads):
params_grads = sorted(params_grads, key=lambda x: x[0].name)
params_grads, table_param_and_grad, table_optimize_op = \
self._process_distribute_lookuptable(params_grads)
not_dgc_params_grads = []
dgc_params_grads = []
# DGC clip and regularization in optimizer.backward
for param, grad in params_grads:
if not self._is_use_dgc(param, grad):
not_dgc_params_grads.append((param, grad))
else:
dgc_params_grads.append((param, grad))
# DGC clip and regularization in local
not_dgc_params_grads = append_gradient_clip_ops(not_dgc_params_grads)
# 'minimize(grad_clip)' or 'set_gradient_clip'
if self._grad_clip is not None:
not_dgc_params_grads = self._grad_clip(not_dgc_params_grads)
else:
not_dgc_params_grads = append_gradient_clip_ops(
not_dgc_params_grads)
# Add regularization if any
not_dgc_params_grads = append_regularization_ops(not_dgc_params_grads,
self.regularization)
......@@ -3942,16 +3957,13 @@ class RecomputeOptimizer(Optimizer):
def apply_optimize(self, loss, startup_program, params_grads):
"""
call the apply_optimize function of self._optimizer
Args:
loss (Variable): loss variable to run optimizations.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
params_grads (list): list of (param, grad) pair to do optimization.
Examples:
.. code-block:: python
import paddle.fluid as fluid
def mlp(input_x, input_y, hid_dim=128, label_dim=2):
......@@ -3979,7 +3991,6 @@ class RecomputeOptimizer(Optimizer):
cost, startup_program=None, params_grads=params_grads)
print("Finished apply_optimize")
"""
return self._optimizer.apply_optimize(
......@@ -3991,24 +4002,24 @@ class RecomputeOptimizer(Optimizer):
parameter_list=None,
no_grad_set=None,
grad_clip=None):
assert (isinstance(loss, Variable)), "The loss should be an Variable."
assert isinstance(loss, Variable), "The loss should be an Variable."
assert (self._checkpoints is not None
), "You should call _set_checkpoints first"
if framework.in_dygraph_mode():
raise NotImplementedError(
"DyGraph current does not support recompute")
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipBase):
raise TypeError(
"'grad_clip' should be an instance of GradientClipBase's derived class"
)
self._optimizer._grad_clip = grad_clip
params_grads = self.backward(
loss,
startup_program=startup_program,
parameter_list=parameter_list,
no_grad_set=no_grad_set)
if grad_clip:
# TODO(guru4elephant): should add grad_clip for static graph
pass
optimize_ops = self.apply_optimize(
loss, startup_program=startup_program, params_grads=params_grads)
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import six
import warnings
from .initializer import Initializer, Xavier, Constant
from .regularizer import WeightDecayRegularizer
......@@ -68,7 +69,6 @@ class ParamAttr(object):
learning_rate=1.0,
regularizer=None,
trainable=True,
gradient_clip=None,
do_model_average=True):
self.name = name
if isinstance(self.name, six.string_types) and self.name == "":
......@@ -78,7 +78,6 @@ class ParamAttr(object):
self.learning_rate = learning_rate
self.regularizer = regularizer
self.trainable = trainable
self.gradient_clip = gradient_clip
self.do_model_average = do_model_average
def _set_default_initializer(self, initializer):
......@@ -176,7 +175,6 @@ class ParamAttr(object):
},
'regularizer': self.regularizer,
'trainable': self.trainable,
'gradient_clip_attr': self.gradient_clip,
'do_model_average': self.do_model_average
}
if with_initializer:
......@@ -248,7 +246,6 @@ class WeightNormParamAttr(ParamAttr):
learning_rate=1.0,
regularizer=None,
trainable=True,
gradient_clip=None,
do_model_average=False):
super(WeightNormParamAttr, self).__init__(
name=name,
......@@ -256,6 +253,5 @@ class WeightNormParamAttr(ParamAttr):
learning_rate=learning_rate,
regularizer=regularizer,
trainable=trainable,
gradient_clip=gradient_clip,
do_model_average=do_model_average)
self.dim = dim
......@@ -476,15 +476,18 @@ class TestL2Decay(TranspilerTest):
size=1000,
act=None,
param_attr=fluid.ParamAttr(
name='fc_w',
regularizer=fluid.regularizer.L2Decay(),
gradient_clip=fluid.clip.GradientClipByValue(0.1)),
name='fc_w', regularizer=fluid.regularizer.L2Decay()),
bias_attr=fluid.ParamAttr(name='fc_b'))
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(avg_cost)
def filter(param):
return param.name == "fc_w"
clip = fluid.clip.GradientClipByValue(0.1, need_clip=filter)
sgd_optimizer.minimize(avg_cost, grad_clip=clip)
def transpiler_test_impl(self):
pserver, startup = self.get_pserver(self.pserver1_ep)
......
......@@ -25,7 +25,7 @@ from paddle.fluid import core
from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph_grad_clip import GradClipByValue, GradClipByNorm, GradClipByGlobalNorm
from paddle.fluid.clip import GradientClipByValue, GradientClipByNorm, GradientClipByGlobalNorm
class TestGradClipByGlobalNorm(unittest.TestCase):
......@@ -65,7 +65,7 @@ class TestGradClipByGlobalNorm(unittest.TestCase):
def get_dygrap_global_norm_result(self):
with fluid.dygraph.guard():
gloabl_norm_clip = GradClipByGlobalNorm(self.max_global_norm)
gloabl_norm_clip = GradientClipByGlobalNorm(self.max_global_norm)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
......@@ -135,7 +135,7 @@ class TestGradClipByNorm(unittest.TestCase):
def get_dygrap_norm_result(self):
with fluid.dygraph.guard():
norm_clip = GradClipByNorm(self.max_norm)
norm_clip = GradientClipByNorm(self.max_norm)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
......@@ -200,8 +200,8 @@ class TestGradClipByValue(unittest.TestCase):
def get_dygrap_clip_result(self):
with fluid.dygraph.guard():
value_clip = GradClipByValue(self.min_value, self.max_value)
value_clip = GradientClipByValue(
max=self.max_value, min=self.min_value)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
......@@ -225,7 +225,7 @@ class TestGradClipByValue(unittest.TestCase):
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
def test_clip_by_norm_2(self):
def test_clip_by_value_2(self):
self.init_value()
self.init_scale = 0.2
......@@ -236,7 +236,7 @@ class TestGradClipByValue(unittest.TestCase):
for (p_np, g_np), (p_dy, g_dy) in zip(np_p_g, dy_out_p_g):
self.assertTrue(np.allclose(g_np, g_dy, rtol=1e-6, atol=1e-8))
def test_clip_by_norm_3(self):
def test_clip_by_value_3(self):
self.init_value()
self.init_scale = 0.5
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
......@@ -54,26 +54,32 @@ class TestGradientClip(unittest.TestCase):
self.BATCH_SIZE = 2
reader = fake_imdb_reader(self.word_dict_len, self.BATCH_SIZE * 100)
self.train_data = paddle.batch(reader, batch_size=self.BATCH_SIZE)
self.init()
def init(self):
pass
def get_places(self):
places = [core.CPUPlace()]
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
places.append(fluid.CUDAPlace(0))
return places
def check_operators(self, place):
CLIP = 1
def clip_gradient(self, params_grads):
pass
prog = fluid.framework.Program()
startup_program = fluid.framework.Program()
def check_clip_result(self, out, out_clip):
pass
def check_gradient_clip(self, place):
prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(
main_program=prog, startup_program=startup_program):
image = fluid.layers.data(name='x', shape=[784], dtype='float32')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
predict = fluid.layers.fc(input=hidden2, size=10, act='softmax')
image = fluid.data(name='x', shape=[-1, 784], dtype='float32')
label = fluid.data(name='y', shape=[-1, 1], dtype='int64')
hidden = fluid.layers.fc(input=image, size=32, act='relu')
predict = fluid.layers.fc(input=hidden, size=10, act='softmax')
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(cost)
......@@ -84,45 +90,26 @@ class TestGradientClip(unittest.TestCase):
p_g = fluid.backward.append_backward(loss=avg_cost)
p_g_clip = fluid.backward.append_backward(loss=avg_cost_clip)
p_g = sorted(p_g, key=lambda x: x[0].name)
p_g_clip = sorted(p_g_clip, key=lambda x: x[0].name)
with fluid.program_guard(
main_program=prog_clip, startup_program=startup_program):
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByGlobalNorm(clip_norm=CLIP))
p_g_clip = fluid.clip.append_gradient_clip_ops(p_g_clip)
p_g_clip = self.clip_gradient(p_g_clip)
grad_list = [elem[1] for elem in p_g]
grad_clip_list = [elem[1] for elem in p_g_clip]
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=8192),
batch_size=128)
train_reader = paddle.batch(paddle.dataset.mnist.train(), batch_size=3)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[image, label], place=place)
exe.run(startup_program)
count = 0
for data in train_reader():
count += 1
if count > 5:
break
out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
out_clip = exe.run(prog_clip,
feed=feeder.feed(data),
fetch_list=grad_clip_list)
global_norm = 0
for v in out:
global_norm += np.sum(np.power(v, 2))
global_norm = np.sqrt(global_norm)
global_norm_clip = 0
for v in out_clip:
global_norm_clip += np.sum(np.power(v, 2))
global_norm_clip = np.sqrt(global_norm_clip)
assert np.isclose(
a=global_norm_clip, b=np.minimum(global_norm, CLIP), rtol=5e-3)
data = next(train_reader())
out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
out_clip = exe.run(prog_clip,
feed=feeder.feed(data),
fetch_list=grad_clip_list)
self.check_clip_result(out, out_clip)
def check_sparse_gradient_clip(self, place):
prog = fluid.framework.Program()
......@@ -134,11 +121,7 @@ class TestGradientClip(unittest.TestCase):
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
cost = bow_net(data, label, self.word_dict_len)
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0))
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(cost)
self.backward_and_optimize(cost)
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=[data, label], place=place)
......@@ -150,13 +133,345 @@ class TestGradientClip(unittest.TestCase):
print(val)
self.assertFalse(np.isnan(val))
def test_operators(self):
self.check_operators(core.CPUPlace())
def backward_and_optimize(cost):
pass
class TestGradientClipByGlobalNorm(TestGradientClip):
def init(self):
self.clip_norm = 0.2
def clip_gradient(self, params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
print(clip)
return clip(params_grads)
def check_clip_result(self, out, out_clip):
global_norm = 0
for v in out:
global_norm += np.sum(np.power(v, 2))
global_norm = np.sqrt(global_norm)
scale = self.clip_norm / np.maximum(self.clip_norm, global_norm)
res = []
for i in range(len(out)):
out[i] = scale * out[i]
for u, v in zip(out, out_clip):
self.assertTrue(
np.allclose(
a=u, b=v, rtol=1e-5, atol=1e-8),
"gradient clip by global norm has wrong results!")
# test whether the ouput is right when use 'set_gradient_clip'
def test_old_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
fluid.clip.set_gradient_clip(clip)
return fluid.clip.append_gradient_clip_ops(params_grads)
self.clip_gradient = func
self.check_gradient_clip(fluid.CPUPlace())
# test whether the ouput is right when use 'minimize(grad_clip)'
def test_new_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
print(clip)
return clip(params_grads)
def test_sparse_gradient_clip(self):
self.clip_gradient = func
self.check_gradient_clip(fluid.CPUPlace())
# invoke 'set_gradient_clip' in a wrong order
def test_wrong_API_order(self):
def backward_func(cost):
# no clip gradient
def fileter_func(param):
return param.name == "fc.w_0"
clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=5.0, need_clip=fileter_func)
fluid.clip.set_gradient_clip(clip)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
# if 'set_gradient_clip' and 'minimize(grad_clip)' together, 'set_gradient_clip' will be ineffective
sgd_optimizer.minimize(cost, grad_clip=clip)
# 'set_gradient_clip' must before 'minimize', otherwise, 'set_gradient_clip' will be ineffective
fluid.clip.set_gradient_clip(clip)
self.backward_and_optimize = backward_func
for place in self.get_places():
self.check_sparse_gradient_clip(place)
# if grad is None or not need clip
def test_none_grad(self):
def fileter_func(param):
return param.name == "x"
clip = fluid.clip.GradientClipByGlobalNorm(
self.clip_norm, need_clip=fileter_func)
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
y = fluid.default_main_program().global_block().create_parameter(
name="y", shape=[2, 3], dtype="float32")
# (x, None) should not be returned
params_grads = [(x, None), (x, y), (y, x)]
params_grads = clip(params_grads)
self.assertTrue(
len(clip(params_grads)) == 2,
"ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
)
self.assertTrue(
params_grads[0][1].name != 'y',
"ClipByGlobalNorm: param_grad (x, y) should be clipped!")
# raise typeError
def test_tpyeError(self):
# the type of need_clip must be an funciton
with self.assertRaises(TypeError):
clip = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm, need_clip="test")
# the type of minimize(grad_clip=) must be an instance of GradientClipBase's derived class
with self.assertRaises(TypeError):
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
loss = fluid.layers.reduce_mean(x)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
sgd_optimizer.minimize(loss, grad_clip="test")
# the type of RecomputeOptimizer.minimize(grad_clip=) must be an instance of GradientClipBase's derived class
with self.assertRaises(TypeError):
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
loss = fluid.layers.reduce_mean(x)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
recompute_optimizer = fluid.optimizer.RecomputeOptimizer(
sgd_optimizer)
recompute_optimizer._set_checkpoints([x])
recompute_optimizer.minimize(loss, grad_clip="test")
class TestGradientClipByNorm(TestGradientClip):
def init(self):
self.clip_norm = 0.2
def clip_gradient(self, params_grads):
clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)
print(clip)
return clip(params_grads)
def check_clip_result(self, out, out_clip):
for u, v in zip(out, out_clip):
norm = np.sqrt(np.sum(np.power(u, 2)))
scale = self.clip_norm / np.maximum(self.clip_norm, norm)
u = u * scale
self.assertTrue(
np.allclose(
a=u, b=v, rtol=1e-5, atol=1e-8),
"gradient clip by norm has wrong results!")
# test whether the ouput is right when use 'minimize(grad_clip)'
def test_gradient_clip(self):
self.check_gradient_clip(fluid.CPUPlace())
# if grad is None or not need clip
def test_none_grad(self):
def fileter_func(param):
return param.name == "z"
clip = fluid.clip.GradientClipByNorm(
self.clip_norm, need_clip=fileter_func)
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
y = fluid.default_main_program().global_block().create_parameter(
name="y", shape=[2, 3], dtype="float32")
# (x, None) should not be returned
params_grads = [(x, None), (x, y)]
params_grads = clip(params_grads)
self.assertTrue(
len(clip(params_grads)) == 1,
"ClipByNorm: when grad is None, it shouldn't be returned by gradient clip!"
)
self.assertTrue(
params_grads[0][1].name == 'y',
"ClipByNorm: grad should not be clipped when filtered out!")
class TestGradientClipByValue(TestGradientClip):
def init(self):
self.max = 0.2
self.min = 0.1
def clip_gradient(self, params_grads):
clip = fluid.clip.GradientClipByValue(max=self.max, min=self.min)
print(clip)
return clip(params_grads)
def check_clip_result(self, out, out_clip):
for i, v in enumerate(out):
out[i] = np.clip(v, self.min, self.max)
for u, v in zip(out, out_clip):
u = np.clip(u, self.min, self.max)
self.assertTrue(
np.allclose(
a=u, b=v, rtol=1e-6, atol=1e-8),
"gradient clip by value has wrong results!")
# test whether the ouput is right when use 'minimize(grad_clip)'
def test_gradient_clip(self):
self.check_gradient_clip(fluid.CPUPlace())
# if grad is None or not need clip
def test_none_grad(self):
def fileter_func(param):
return param.name == "z"
clip = fluid.clip.GradientClipByValue(
self.max, self.min, need_clip=fileter_func)
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32")
y = fluid.default_main_program().global_block().create_parameter(
name="y", shape=[2, 3], dtype="float32")
# (x, None) should not be returned
params_grads = [(x, None), (x, y)]
params_grads = clip(params_grads)
self.assertTrue(
len(clip(params_grads)) == 1,
"ClipByValue: when grad is None, it shouldn't be returned by gradient clip!"
)
self.assertTrue(
params_grads[0][1].name == 'y',
"ClipByValue: grad should not be clipped when filtered out!")
class TestDygraphGradientClip(unittest.TestCase):
def test_gradient_clip(self):
with fluid.dygraph.guard():
linear = fluid.dygraph.Linear(5, 5)
inputs = fluid.layers.uniform_random(
[16, 5], min=-10, max=10).astype('float32')
out = linear(fluid.dygraph.to_variable(inputs))
loss = fluid.layers.reduce_mean(out)
loss.backward()
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.0, parameter_list=linear.parameters())
self.check_clip_result(loss, sgd_optimizer)
def check_clip_result(self, loss, optimizer):
pass
class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
def setUp(self):
# only clip gradient of x (ParamBase)
def fileter_func(param):
return param.name == "x"
self.clip_norm = 0.8
self.clip1 = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm, need_clip=fileter_func)
self.clip2 = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm)
def check_clip_result(self, loss, optimizer):
# if grad is None
x = fluid.dygraph.to_variable(
np.array([2, 3]).astype("float32"), name="x")
y = fluid.dygraph.to_variable(
np.array([3, 4]).astype("float32"), name="y")
assert len(self.clip1([(x, x), (x, y), (x, None)])) == 2
# get params and grads from network
opt, params_grads = optimizer.minimize(loss, grad_clip=self.clip2)
_, grads = zip(*params_grads)
params_grads = self.clip2(params_grads)
_, grads_clip = zip(*params_grads)
global_norm = 0
for u in grads:
u = u.numpy()
global_norm += np.sum(np.power(u, 2))
global_norm = np.sqrt(global_norm)
global_norm_clip = 0
for v in grads_clip:
v = v.numpy()
global_norm_clip += np.sum(np.power(v, 2))
global_norm_clip = np.sqrt(global_norm_clip)
a = np.minimum(global_norm, self.clip_norm)
b = global_norm_clip
self.assertTrue(
np.isclose(
a=a, b=b, rtol=1e-6, atol=1e-8),
"gradient clip by global norm has wrong results, expetcd:%f, but recieved:%f"
% (a, b))
class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
def setUp(self):
# only clip gradient of linear_0.w_0 (ParamBase)
def fileter_func(param):
return param.name == "linear_0.w_0"
self.clip_norm = 0.8
self.clip = fluid.clip.GradientClipByNorm(
clip_norm=self.clip_norm, need_clip=fileter_func)
def check_clip_result(self, loss, optimizer):
# if grad is None
x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32"))
assert len(self.clip([(x, None)])) == 0
# get params and grads from network
self.clip([(fluid.dygraph.to_variable(np.array([2, 3])), None)])
params_grads = optimizer.backward(loss)
_, grads = zip(*params_grads)
params_grads = self.clip(params_grads)
_, grads_clip = zip(*params_grads)
for u, v in zip(grads, grads_clip):
u = u.numpy()
v = v.numpy()
a = np.sqrt(np.sum(np.power(u, 2)))
a = np.minimum(a, self.clip_norm)
b = np.sqrt(np.sum(np.power(v, 2)))
self.assertTrue(
np.isclose(
a=a, b=b, rtol=1e-6, atol=1e-8),
"gradient clip by norm has wrong results, expetcd:%f, but recieved:%f"
% (a, b))
class TestDygraphGradientClipByValue(TestDygraphGradientClip):
def setUp(self):
# only clip gradient of linear_0.w_0 (ParamBase)
def fileter_func(param):
return param.name == "linear_0.w_0"
self.max = 0.2
self.min = 0.1
self.clip = fluid.clip.GradientClipByValue(
max=self.max, min=self.min, need_clip=fileter_func)
def check_clip_result(self, loss, optimizer):
# if grad is None
x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32"))
assert len(self.clip([(x, None)])) == 0
# get params and grads from network
params_grads = optimizer.backward(loss)
_, grads = zip(*params_grads)
params_grads = self.clip(params_grads)
_, grads_clip = zip(*params_grads)
for u, v in zip(grads, grads_clip):
u = np.clip(u.numpy(), self.min, self.max)
v = v.numpy()
self.assertTrue(
np.allclose(
a=u, b=v, rtol=1e-6, atol=1e-8),
"gradient clip by value has wrong results!")
if __name__ == '__main__':
unittest.main()
......@@ -331,7 +331,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
model = MyLayer(size, vocab_size, size)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters())
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.001)
grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001)
indices = fluid.dygraph.to_variable(indices)
embed = fluid.dygraph.to_variable(embed)
......@@ -350,7 +350,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
model = MyLayer2(size, vocab_size, size)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters())
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(0.001)
grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001)
indices = fluid.dygraph.to_variable(indices)
emebd = fluid.dygraph.to_variable(embed)
......
......@@ -49,7 +49,7 @@ class TestSimpleNet(unittest.TestCase):
with fluid.dygraph.guard(place):
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = sort_sum_gradient
# grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(5.0)
# grad_clip = fluid.clip.GradientClipByGlobalNorm(5.0)
input_word = np.array([[1, 2], [2, 1]]).astype('int64')
input = to_variable(input_word)
......@@ -83,8 +83,7 @@ class TestSimpleNet(unittest.TestCase):
with fluid.dygraph.guard(place):
backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = sort_sum_gradient
grad_clip = fluid.dygraph_grad_clip.GradClipByGlobalNorm(
5.0)
grad_clip = fluid.clip.GradientClipByGlobalNorm(5.0)
input_word = np.array([[1, 2], [2, 1]]).astype('int64')
input = to_variable(input_word)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册