Unverified commit fe0dc40d, authored by 骑马小猫, committed by GitHub

[FluidAPI]remove clip api (#48946)

Parent 822ea0f9
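
The renames in the hunks below all follow one pattern: the gradient-clip classes previously imported from `paddle.fluid.clip` are now taken from `paddle.nn`. A minimal dygraph sketch of the replacement usage (assuming a Paddle 2.x build where the `paddle.nn` names shown in this diff exist):

    import paddle

    # ClipGradByGlobalNorm replaces fluid.clip.GradientClipByGlobalNorm in this diff;
    # ClipGradByNorm and ClipGradByValue follow the same GradientClip* -> ClipGrad* rename.
    linear = paddle.nn.Linear(10, 10)
    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
    sgd = paddle.optimizer.SGD(
        learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip
    )
    loss = paddle.mean(linear(paddle.uniform([4, 10])))
    loss.backward()
    sgd.step()
    sgd.clear_grad()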
......@@ -20,11 +20,11 @@ __all__ = []
import paddle
from paddle.common_ops_import import LayerHelper
from paddle.fluid.clip import GradientClipByNorm, append_gradient_clip_ops
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.optimizer import Momentum, Optimizer
from paddle.framework import core
from paddle.nn.clip import ClipGradByNorm, append_gradient_clip_ops
from paddle.static import create_global_var
......@@ -76,9 +76,9 @@ class DGCMomentumOptimizer(Optimizer):
self._dgc_clip_norm = None
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipByNorm):
if not isinstance(grad_clip, ClipGradByNorm):
raise TypeError(
"The type of grad_clip should be 'GradientClipByNorm', because DGCMomentumOptimizer only support GradientClipByNorm"
"The type of grad_clip should be 'ClipGradByNorm', because DGCMomentumOptimizer only support ClipGradByNorm"
)
assert isinstance(num_trainers, int), (
"The type of num_trainers should be 'int', but received %s"
......
......@@ -15,9 +15,8 @@
import paddle
from paddle import framework
from paddle.autograd import no_grad
from paddle.fluid import layers
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.framework import core
from paddle.nn import ClipGradByGlobalNorm, clip
from ...base.topology import ParallelMode
from ...utils.hybrid_parallel_util import (
......@@ -62,8 +61,8 @@ class HybridParallelClipGrad:
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
merge_grad = clip.merge_selected_rows(g)
merge_grad = clip.get_tensor_from_selected_rows(merge_grad)
square = paddle.square(merge_grad)
sum_square = paddle.sum(square)
......
......@@ -30,7 +30,7 @@ import paddle
import paddle.distributed as dist
from paddle.distributed import ParallelMode, fleet
from paddle.fluid import core
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.nn import ClipGradByGlobalNorm
from paddle.optimizer import Optimizer
HybridParallelClipGrad = (
......
......@@ -25,8 +25,8 @@ import paddle.fluid.framework as framework
from paddle import nn
from paddle.autograd import PyLayer
from paddle.distributed import collective
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.fluid.framework import EagerParamBase
from paddle.nn import ClipGradByGlobalNorm
from .group_sharded_storage import GradStorage
from .group_sharded_utils import GroupShardedClipGrad, Type, device_guard
......
......@@ -23,6 +23,7 @@ from paddle import _legacy_C_ops
from paddle.fluid import core, layers
from paddle.fluid.dygraph import to_variable
from paddle.fluid.framework import dygraph_only
from paddle.nn import clip
class Taskflow:
......@@ -65,8 +66,8 @@ class GroupShardedClipGrad:
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.get_tensor_from_selected_rows(
layers.merge_selected_rows(g)
merge_grad = clip.get_tensor_from_selected_rows(
clip.merge_selected_rows(g)
)
square = paddle.square(merge_grad)
sum_square = paddle.sum(square)
......
......@@ -159,7 +159,7 @@ def auc(stat_pos, stat_neg, scope=None, util=None):
.. code-block:: python
# in model.py
similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(output, min=-15.0, max=15.0))
similarity_norm = fluid.layers.sigmoid(paddle.clip(output, min=-15.0, max=15.0))
binary_predict = fluid.layers.concat(
input=[paddle.subtract(fluid.layers.ceil(similarity_norm), similarity_norm), similarity_norm], axis=1)
self.auc, batch_auc, [batch_stat_pos, batch_stat_neg, stat_pos, stat_neg] =
......
......@@ -90,7 +90,6 @@ from .transpiler import (
DistributeTranspilerConfig,
)
from .lod_tensor import create_lod_tensor, create_random_int_lodtensor
from . import clip
from . import profiler
from . import unique_name
from . import parallel_executor
......@@ -164,7 +163,6 @@ __all__ = (
'ParamAttr',
'WeightNormParamAttr',
'DataFeeder',
'clip',
'profiler',
'unique_name',
'Scope',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import warnings
import functools
import paddle
from . import layers
from . import framework
from . import core
from . import name_scope
from .dygraph import base as imperative_base
from .data_feeder import check_variable_and_dtype
from .framework import in_dygraph_mode
from .layer_helper import LayerHelper
from .framework import default_main_program
from paddle import _C_ops, _legacy_C_ops
__all__ = [
'set_gradient_clip',
'ErrorClipByValue',
'ClipGradByValue',
'ClipGradByNorm',
'ClipGradByGlobalNorm',
]
_clip_by_global_norm_using_mp_type_flag = False
def _clip_by_global_norm_using_mp_type(*args):
global _clip_by_global_norm_using_mp_type_flag
assert len(args) <= 1
if len(args) == 1:
assert isinstance(args[0], bool)
old_value = _clip_by_global_norm_using_mp_type_flag
_clip_by_global_norm_using_mp_type_flag = args[0]
return old_value
else:
return _clip_by_global_norm_using_mp_type_flag
def _cast_to_mp_type_if_enabled(x):
if (
x.dtype == core.VarDesc.VarType.FP16
or x.dtype == core.VarDesc.VarType.BF16
) and _clip_by_global_norm_using_mp_type():
return x.astype(core.VarDesc.VarType.FP32)
else:
return x
def _squared_l2_norm(x):
r"""
This OP returns the squared L2 norm of a tensor.
"""
x = _cast_to_mp_type_if_enabled(x)
if (
core.is_compiled_with_xpu()
or x.dtype == core.VarDesc.VarType.FP16
or x.dtype == core.VarDesc.VarType.BF16
):
square = paddle.square(x)
sum_square = paddle.sum(square)
return sum_square
if in_dygraph_mode():
return _C_ops.squared_l2_norm(x)
else:
op_type = 'squared_l2_norm'
check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(x.dtype)
inputs = {"X": x}
outputs = {'Out': out}
helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
return out
class BaseErrorClipAttr:
def __str__(self):
raise NotImplementedError()
def _append_clip_op(self, block, grad_name):
raise NotImplementedError()
class ErrorClipByValue(BaseErrorClipAttr):
r"""
Clips tensor values to the range [min, max].
Given a tensor ``t`` (see Examples below), this operation clips its values
into the range [``min``, ``max``] in place.
- Any values less than min are set to min.
- Any values greater than max are set to max.
Args:
max (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. If not set by the user, it
will be set to ``-max`` by the framework.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
BATCH_SIZE = 128
CLIP_MAX = 2e-6
CLIP_MIN = -1e-6
prog = fluid.framework.Program()
with fluid.program_guard(main_program=prog):
image = fluid.layers.data(
name='x', shape=[784], dtype='float32')
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
predict = fluid.layers.fc(
input=hidden2, size=10, act='softmax')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
cost = paddle.nn.functional.cross_entropy(input=predict, label=label, reduction='none', use_softmax=False)
avg_cost = paddle.mean(cost)
prog_clip = prog.clone()
prog_clip.block(0).var(hidden1.name)._set_error_clip(
fluid.clip.ErrorClipByValue(
max=CLIP_MAX, min=CLIP_MIN
)
)
"""
def __init__(self, max, min=None):
max = float(max)
if min is None:
min = -max
else:
min = float(min)
self.max = max
self.min = min
def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max)
def _append_clip_op(self, block, grad_name):
clip_op_desc = block.desc.append_op()
clip_op_desc.set_type("clip")
clip_op_desc.set_input("X", [grad_name])
clip_op_desc.set_output("Out", [grad_name])
clip_op_desc._set_attr("min", self.min)
clip_op_desc._set_attr("max", self.max)
def error_clip_callback(block, context):
# the context is a grad_to_var map
grad_to_var = context
op_desc = block.desc.op(block.desc.op_size() - 1)
for grad_n in [n for n in op_desc.output_arg_names() if n in grad_to_var]:
fwd_var = block._var_recursive(grad_to_var[grad_n])
error_clip = getattr(fwd_var, "error_clip", None)
if not (
error_clip is None or isinstance(error_clip, BaseErrorClipAttr)
):
raise TypeError(
"Variable's error_clip should be an instance of BaseErrorClipAttr or None."
)
if error_clip is not None:
error_clip._append_clip_op(block, grad_n)
class ClipGradBase:
def __init__(self):
super().__init__()
def __str__(self):
raise NotImplementedError()
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
raise NotImplementedError
def _static_clip(self, params_grads):
raise NotImplementedError
def __call__(self, params_grads):
if in_dygraph_mode():
return self._dygraph_clip(params_grads)
else:
for p, g in params_grads:
if getattr(p, 'gradient_clip_attr', None) is not None:
warnings.warn(
"'set_gradient_clip' will be ineffective, because you have "
"set 'need_clip' in 'ParamAttr'. So, 'set_gradient_clip' "
"is redundant and you can remove it."
)
break
return self._static_clip(params_grads)
def _process_context(self, context, param, grad):
raise NotImplementedError()
def _create_operators(self, param, grad):
raise NotImplementedError()
class ClipGradByValue(ClipGradBase):
"""
Limit the value of multi-dimensional Tensor :math:`X` to the range [min, max].
- Any values less than min are set to ``min``.
- Any values greater than max are set to ``max``.
The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
If ``need_clip`` of a specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.
Gradient clipping will take effect after being set in ``optimizer`` ; see the documentation of ``optimizer``
(for example: :ref:`api_paddle_optimizer_SGD`).
Note:
``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0.
Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
Args:
max (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. If not set by the user, it will be set to ``-max``
automatically. In this case, ``max`` must be greater than 0.
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(x)
loss = paddle.mean(out)
loss.backward()
clip = paddle.nn.ClipGradByValue(min=-1, max=1)
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
sgd.step()
"""
def __init__(self, max, min=None):
super().__init__()
if min is None:
assert max > 0.0
min = -max
self.max = float(max)
self.min = float(min)
def __str__(self):
return "Clip Gradient By Value, min = %f, max=%f" % (self.min, self.max)
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
new_grad = paddle.clip(x=g, min=self.min, max=self.max)
params_and_grads.append((p, new_grad))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
param_new_grad_name_dict = dict()
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_grad = layers.clip(x=g, min=self.min, max=self.max)
params_and_grads.append((p, new_grad))
param_new_grad_name_dict[p.name] = new_grad.name
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
pass
def _create_operators(self, param, grad):
new_grad = layers.clip(x=grad, min=self.min, max=self.max)
return param, new_grad
class ClipGradByNorm(ClipGradBase):
r"""
Limit the l2 norm of multi-dimensional Tensor :math:`X` to ``clip_norm`` .
- If the l2 norm of :math:`X` is greater than ``clip_norm`` , :math:`X` will be compressed by a ratio.
- If the l2 norm of :math:`X` is less than or equal to ``clip_norm`` , nothing will be done.
The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
If ``need_clip`` of a specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.
Gradient clipping will take effect after being set in ``optimizer`` ; see the documentation of ``optimizer``
(for example: :ref:`api_paddle_optimizer_SGD`).
The clipping formula is:
.. math::
Out =
\left\{
\begin{array}{ccl}
X & & if (norm(X) \leq clip\_norm) \\
\frac{clip\_norm*X}{norm(X)} & & if (norm(X) > clip\_norm) \\
\end{array}
\right.
where :math:`norm(X)` represents the L2 norm of :math:`X`.
.. math::
norm(X) = ( \sum_{i=1}^{n}|x\_i|^2)^{ \frac{1}{2}}
Note:
``need_clip`` of ``ClipGradByNorm`` HAS BEEN DEPRECATED since 2.0.
Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
Args:
clip_norm(float): The maximum norm value.
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(x)
loss = paddle.mean(out)
loss.backward()
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
sgd.step()
"""
def __init__(self, clip_norm):
super().__init__()
self.clip_norm = float(clip_norm)
def __str__(self):
return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
params_and_grads.append((p, new_grad))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
with framework.name_scope('gradient_clip'):
param_new_grad_name_dict = dict()
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_grad = layers.clip_by_norm(x=g, max_norm=self.clip_norm)
param_new_grad_name_dict[p.name] = new_grad.name
params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
pass
def _create_operators(self, param, grad):
new_grad = layers.clip_by_norm(x=grad, max_norm=self.clip_norm)
return param, new_grad
_allow_pure_fp16_global_norm_clip_flag = False
def _allow_pure_fp16_global_norm_clip(*args):
global _allow_pure_fp16_global_norm_clip_flag
if len(args) == 0:
return _allow_pure_fp16_global_norm_clip_flag
else:
assert len(args) == 1 and isinstance(args[0], bool)
old_value = _allow_pure_fp16_global_norm_clip_flag
_allow_pure_fp16_global_norm_clip_flag = args[0]
return old_value
class ClipGradByGlobalNorm(ClipGradBase):
r"""
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
:math:`t\_list` , and limit it to ``clip_norm`` .
- If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.
- If the global norm is less than or equal to ``clip_norm`` , nothing will be done.
The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
If ``need_clip`` of a specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.
Gradient clipping will take effect after being set in ``optimizer`` ; see the documentation of ``optimizer``
(for example: :ref:`api_paddle_optimizer_SGD`).
The clipping formula is:
.. math::
t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}
where:
.. math::
global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
Note:
``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
Args:
clip_norm (float): The maximum norm value.
group_name (str, optional): The group name for this clip. Default value is ``default_group``.
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(x)
loss = paddle.mean(out)
loss.backward()
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
sgd.step()
"""
def __init__(
self, clip_norm, group_name="default_group", auto_skip_clip=False
):
super().__init__()
self.clip_norm = float(clip_norm)
self.group_name = group_name
assert isinstance(auto_skip_clip, bool)
self.auto_skip_clip = auto_skip_clip
def __str__(self):
return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)
@imperative_base.no_grad
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_fp32 = []
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
if in_dygraph_mode() and g.is_selected_rows():
merge_grad = layers.merge_selected_rows(g)
merge_grad = merge_grad._get_tensor_from_selected_rows()
elif g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
sum_square = _squared_l2_norm(merge_grad)
if (
sum_square.dtype == core.VarDesc.VarType.FP16
or sum_square.dtype == core.VarDesc.VarType.BF16
):
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
else:
sum_square_list.append(sum_square)
# all parameters have been filtered out
if (
len(sum_square_list)
+ len(sum_square_list_fp16)
+ len(sum_square_list_fp32)
== 0
):
return params_grads
sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32)
else:
global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
if len(sum_square_list) > 0:
global_norm_var_fp64 = paddle.add_n(sum_square_list)
global_norm_var.append(global_norm_var_fp64)
global_norm_var = paddle.add_n(global_norm_var)
global_norm_var = paddle.sqrt(global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
)
need_clip = False
if not self.auto_skip_clip: # always apply clip
need_clip = True
clip_var = paddle.divide(
x=max_global_norm,
y=paddle.maximum(x=global_norm_var, y=max_global_norm),
)
elif global_norm_var > max_global_norm:
# only when global_norm_var > max_global_norm, the gradients need to be clipped
need_clip = True
clip_var = paddle.divide(x=max_global_norm, y=global_norm_var)
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
# TODO(wangxi): use inplace elementwise_mul
if need_clip:
clip_input = (
clip_var.astype(g.dtype)
if clip_var.dtype != g.dtype
else clip_var
)
new_grad = paddle.multiply(g, clip_input)
params_and_grads.append((p, new_grad))
else:
params_and_grads.append((p, g))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_fp32 = []
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
with p.block.program._optimized_guard([p, g]):
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(
merge_grad
)
sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
else:
sum_square_list.append(sum_square)
# all parameters have been filtered out
if (
len(sum_square_list)
+ len(sum_square_list_fp16)
+ len(sum_square_list_fp32)
== 0
):
return params_grads
with p.block.program._optimized_guard([p, g]):
sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
if (
sum_square_list_fp32
or sum_square_list
or not _allow_pure_fp16_global_norm_clip()
):
global_norm_var.append(
global_norm_var_fp16.astype(sum_dtype)
)
else:
global_norm_var.append(global_norm_var_fp16)
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32)
else:
global_norm_var.append(
global_norm_var_fp32.astype(sum_dtype)
)
if len(sum_square_list) > 0:
# fp64
global_norm_var_other_dtype = layers.sums(sum_square_list)
global_norm_var.append(global_norm_var_other_dtype)
global_norm_var = (
layers.sums(global_norm_var)
if len(global_norm_var) > 1
else global_norm_var[0]
)
global_norm_var = paddle.sqrt(x=global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
)
scale_var = paddle.divide(
x=max_global_norm,
y=paddle.maximum(x=max_global_norm, y=global_norm_var),
)
param_new_grad_name_dict = dict()
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_g = _cast_to_mp_type_if_enabled(g)
# inplace
scale_input = (
scale_var.astype('float16')
if new_g.dtype == core.VarDesc.VarType.FP16
and scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var
)
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
# will be in different blocks with the gradient clip related ops.
# We need to handle the correct block, otherwise we will encounter
# a 'NotFoundError' during compile time.
block = default_main_program().current_block()
block.append_op(
type='elementwise_mul',
inputs={'X': new_g, 'Y': scale_input},
outputs={'Out': new_g},
)
if new_g is not g:
block.append_op(
type='cast',
inputs={'X': new_g},
outputs={'Out': g},
attrs={
'in_dtype': new_g.dtype,
'out_dtype': g.dtype,
},
)
param_new_grad_name_dict[p.name] = g.name
params_and_grads.append((p, g))
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
if self.group_name not in context:
context[self.group_name] = []
context[self.group_name + "_clip_value"] = self.clip_norm
context[self.group_name + "_clip"] = layers.fill_constant(
shape=[1], dtype=grad.dtype, value=self.clip_norm
)
else:
if not self.clip_norm == context[self.group_name + "_clip_value"]:
raise ValueError(
"All parameters' 'clip_norm' of a same group should be the same"
)
merge_grad = grad
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(grad)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
local_norm_var = _squared_l2_norm(merge_grad)
context[self.group_name].append(local_norm_var)
self.context = context
def _create_operators(self, param, grad):
group_scale_name = self.group_name + "_scale"
if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name])
group_norm_var = paddle.sqrt(x=group_norm_var)
clip_var = self.context[self.group_name + "_clip"]
group_scale_var = paddle.divide(
x=clip_var,
y=paddle.maximum(x=clip_var, y=group_norm_var),
)
assert group_scale_var.shape == (1,)
self.context[group_scale_name] = group_scale_var
# inplace
param.block.append_op(
type='elementwise_mul',
inputs={'X': grad, 'Y': self.context[group_scale_name]},
outputs={'Out': grad},
)
return param, grad
@framework.dygraph_not_support
def set_gradient_clip(clip, param_list=None, program=None):
"""
:api_attr: Static Graph
Warning:
This API must be used after building the network and before ``minimize``,
and it may be removed in future releases, so it is not recommended.
It is recommended to set ``grad_clip`` when initializing the ``optimizer`` instead;
this is a better way to clip gradients. There are three clipping strategies:
:ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` .
Specify which parameters require gradient clipping.
Args:
clip (GradientClipBase): Gradient clipping strategy, an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ).
param_list (list(Variable), optional): Parameters that require gradient clip.
It can be a list of Parameters or a list of parameter names.
Default None, meaning that all parameters in the program will be included.
program (Program, optional): The program where parameters are located.
Default None, meaning that using :ref:`api_fluid_default_main_program` .
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
def network():
image = fluid.data(name='image', shape=[
None, 28], dtype='float32')
param_attr1 = fluid.ParamAttr("fc1_param")
fc1 = fluid.layers.fc(image, size=10, param_attr=param_attr1)
param_attr2 = fluid.ParamAttr("fc2_param")
fc2 = fluid.layers.fc(fc1, size=10, param_attr=param_attr2)
loss = fluid.layers.reduce_mean(fc2)
return loss
# network 1: clip all parameter gradient
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByGlobalNorm(clip_norm=2.0))
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
# network 2: clip parameter gradient by name
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
param_list=["fc1_param", "fc2_param"])
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
# network 3: clip parameter gradient by value
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
param_var1 = fluid.default_main_program().global_block().var("fc1_param")
param_var2 = fluid.default_main_program().global_block().var("fc2_param")
fluid.clip.set_gradient_clip(
fluid.clip.GradientClipByValue(min=-1.0, max=1.0),
param_list=[param_var1, param_var2])
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
# network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
clip1 = fluid.clip.GradientClipByValue(min=-1.0, max=1.0)
clip2 = fluid.clip.GradientClipByNorm(clip_norm=1.0)
# Set the gradient clipping strategy: clip1
fluid.clip.set_gradient_clip(clip1)
# Set the gradient clipping strategy: clip2
sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
sgd.minimize(loss)
# 'set_gradient_clip' will not take effect when the settings conflict,
# and the gradient clipping strategy will be 'clip2'
"""
warnings.warn(
"Caution! 'set_gradient_clip' is not recommended "
"and may be deprecated in future! "
"We recommend a new strategy: set 'grad_clip' "
"when initializing the 'optimizer'. "
"This method can reduce the mistakes, please "
"refer to documention of 'optimizer'."
)
if not isinstance(clip, ClipGradBase):
raise TypeError(
"'clip' should be an instance of ClipGradBase's derived class"
)
if program is None:
program = framework.default_main_program()
for op in program.block(0).ops:
if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr(
"op_namescope"
):
warnings.warn(
"'minimize' has been invoked before, this will make 'set_gradient_clip' "
"be ineffective! Please invoke 'set_gradient_clip' before 'minimize'."
)
break
if param_list is None:
param_list = program.block(0).all_parameters()
if all(isinstance(elem, str) for elem in param_list):
param_list = [program.block(0).var(elem) for elem in param_list]
if not all(isinstance(elem, framework.Parameter) for elem in param_list):
raise TypeError(
"'param_list' should be a list of Parameter or basestring(parameter's name)."
)
for param in param_list:
param.gradient_clip_attr = copy.deepcopy(clip)
def append_gradient_clip_ops(param_grads):
context = dict()
for p, g in param_grads:
if g is None:
continue
with p.block.program._optimized_guard([p, g]), framework.name_scope(
'gradient_clip'
):
clip_attr = getattr(p, 'gradient_clip_attr', None)
if clip_attr is None:
return param_grads
if not isinstance(clip_attr, ClipGradBase):
raise TypeError(
"clip attribute should be an instance of GradientClipBase"
)
clip_attr._process_context(context=context, param=p, grad=g)
res = []
param_new_grad_name_dict = dict()
for p, g in param_grads:
if g is None:
continue
with p.block.program._optimized_guard([p, g]), framework.name_scope(
'gradient_clip'
):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad])
_correct_clip_op_role_var(res, param_new_grad_name_dict)
return res
# correct the wrong mapping relation between param & grad in the clip op
# Note: This function is sensitive to the time cost of the network with gradient clipping
# and should not be changed lightly. If you must change it, please test the time cost.
def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
block_id_list = []
if len(param_new_grad_name_dict) == 0:
return
for param, grad in params_grads:
if grad is None:
continue
block_id = param.block.idx
if block_id in block_id_list:
continue
block_id_list.append(block_id)
for op in param.block.program.global_block().ops:
if (
op.has_attr("op_namescope")
and "gradient_clip" in op.attr("op_namescope")
and op.attr('op_role_var')
):
param_name = op.attr('op_role_var')[0]
if param_name in param_new_grad_name_dict:
correct_p_g = [
param_name,
param_new_grad_name_dict[param_name],
]
op._set_attr('op_role_var', correct_p_g)
GradientClipBase = ClipGradBase
GradientClipByValue = ClipGradByValue
GradientClipByNorm = ClipGradByNorm
GradientClipByGlobalNorm = ClipGradByGlobalNorm
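The module-level helpers removed with this file keep their names under `paddle.nn.clip`, which is where the optimizer and test hunks below now import them from. A hedged static-graph sketch of the relocated `set_gradient_clip` path (assuming Paddle 2.x static APIs; it must be called before `minimize`, as the docstring above notes):

    import paddle

    paddle.enable_static()
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
        hidden = paddle.static.nn.fc(x, size=10)
        loss = paddle.mean(hidden)
        # formerly fluid.clip.set_gradient_clip / fluid.clip.GradientClipByGlobalNorm
        paddle.nn.clip.set_gradient_clip(paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0))
        paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)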
......@@ -185,7 +185,7 @@ class FleetUtil:
# below is part of model
emb = my_slot_net(slots, label) # emb can be fc layer of size 1
similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(\
similarity_norm = fluid.layers.sigmoid(paddle.clip(\
emb, min=-15.0, max=15.0), name="similarity_norm")\
binary_predict = fluid.layers.concat(input=[\
paddle.subtract(\
......@@ -1374,7 +1374,7 @@ class FleetUtil:
label = fluid.layers.data(name="click", shape=[-1, 1],\
dtype="int64", lod_level=0, append_batch_size=False)
emb = my_slot_net(slots, label) # emb can be fc layer of size 1
similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(\
similarity_norm = fluid.layers.sigmoid(paddle.clip(\
emb, min=-15.0, max=15.0), name="similarity_norm")\
binary_predict = fluid.layers.concat(input=[\
paddle.subtract(\
......@@ -1574,7 +1574,7 @@ class FleetUtil:
label = fluid.layers.data(name="click", shape=[-1, 1],\
dtype="int64", lod_level=0, append_batch_size=False)
emb = my_slot_net(slots, label) # emb can be fc layer of size 1
similarity_norm = fluid.layers.sigmoid(fluid.layers.clip(\
similarity_norm = fluid.layers.sigmoid(paddle.clip(\
emb, min=-15.0, max=15.0), name="similarity_norm")\
binary_predict = fluid.layers.concat(input=[\
paddle.subtract(\
......
......@@ -63,10 +63,6 @@ __all__ = [
'fc',
'embedding',
'autoincreased_step_counter',
'clip',
'clip_by_norm',
'merge_selected_rows',
'get_tensor_from_selected_rows',
]
OP_NAMEMAPPING = {
......@@ -997,199 +993,3 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
)
return out
@templatedoc()
def clip(x, min, max, name=None):
"""
:old_api: paddle.fluid.layers.clip
${comment}
Args:
x(${x_type}): ${x_comment}
min(float): ${min_comment}
max(float): ${max_comment}
name(str, optional): The default value is None.
Normally there is no need for the user to set this property.
For more information, please refer to :ref:`api_guide_Name`
Returns:
${out_comment}
Return Type:
${out_type}
Examples:
.. code-block:: python
import paddle.fluid as fluid
input = fluid.data(
name='data', shape=[1], dtype='float32')
reward = fluid.layers.clip(x=input, min=-1.0, max=1.0)
"""
helper = LayerHelper("clip", **locals())
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'clip')
if name is None:
name = unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])
)
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False
)
helper.append_op(
type="clip",
inputs={"X": x},
attrs={"min": min, "max": max},
outputs={"Out": out},
)
return out
@templatedoc()
def clip_by_norm(x, max_norm, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
max_norm(${max_norm_type}): ${max_norm_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually there is no need to set
name, and it is None by default.
Returns:
Tensor:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
input = paddle.to_tensor([[2.0, 2.0], [2.0, 2.0]], dtype='float32')
reward = fluid.layers.clip_by_norm(x=input, max_norm=1.0)
# [[0.5, 0.5], [0.5, 0.5]]
"""
if in_dygraph_mode():
return _C_ops.clip_by_norm(x, max_norm)
else:
helper = LayerHelper("clip_by_norm", **locals())
check_variable_and_dtype(x, 'X', ['float32', 'float16'], 'clip_by_norm')
check_type(max_norm, 'max_norm', (float), 'clip_by_norm')
if name is None:
name = unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])
)
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False
)
helper.append_op(
type="clip_by_norm",
inputs={"X": x},
attrs={"max_norm": max_norm},
outputs={"Out": out},
)
return out
@templatedoc()
def merge_selected_rows(x, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
name(str|None): Name of the output.
Returns:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
import paddle.fluid as fluid
b = fluid.default_main_program().global_block()
var = b.create_var(
name="X", dtype="float32", persistable=True,
type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
y = fluid.layers.merge_selected_rows(var)
"""
if in_dygraph_mode():
return _C_ops.merge_selected_rows(x)
else:
helper = LayerHelper("merge_selected_rows", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="merge_selected_rows",
inputs={"X": x},
attrs={},
outputs={"Out": out},
)
return out
@templatedoc()
def get_tensor_from_selected_rows(x, name=None):
"""
This operator gets tensor data from input with SelectedRows type, and outputs a LoDTensor.
.. code-block:: text
input x is SelectedRows:
x.rows = [0, 5, 5, 4, 19]
x.height = 20
x.value = [[1, 1] [2, 2] [2, 2] [3, 3] [6, 6]]
Output is LoDTensor:
out.shape = [5, 2]
out.data = [[1, 1],
[2, 2],
[2, 2],
[3, 3],
[6, 6]]
Args:
x(SelectedRows): Input with SelectedRows type. The data type is float32, float64, int32 or int64.
name(str, optional): The default value is None. Normally there is no need for the user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: LoDTensor transformed from SelectedRows. The data type is the same as the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
b = fluid.default_main_program().global_block()
input = b.create_var(name="X", dtype="float32", persistable=True, type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
out = fluid.layers.get_tensor_from_selected_rows(input)
"""
check_type(x, 'x', Variable, 'get_tensor_from_selected_rows')
if x.type != core.VarDesc.VarType.SELECTED_ROWS:
raise TypeError(
"The type of 'x' in get_tensor_from_selected_rows must be SELECTED_ROWS."
)
helper = LayerHelper('get_tensor_from_selected_rows', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='get_tensor_from_selected_rows',
inputs={'X': x},
outputs={'Out': out},
attrs={},
)
return out
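The four layers deleted above map onto public replacements that the rest of this diff already uses: `paddle.clip`, plus `clip_by_norm`, `merge_selected_rows`, and `get_tensor_from_selected_rows` under `paddle.nn.clip`. A small hedged sketch of the first two (dygraph, Paddle 2.x assumed; the SelectedRows helpers keep the same signatures but still require SelectedRows input, as in the hunks above):

    import paddle
    from paddle.nn import clip

    x = paddle.to_tensor([[2.0, 2.0], [2.0, 2.0]])
    clipped = paddle.clip(x, min=-1.0, max=1.0)   # was fluid.layers.clip
    scaled = clip.clip_by_norm(x, max_norm=1.0)   # was fluid.layers.clip_by_norm
    print(scaled.numpy())                         # [[0.5, 0.5], [0.5, 0.5]]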
......@@ -38,13 +38,6 @@ from .backward import (
_append_grad_suffix_,
_get_no_grad_set_name,
)
from .clip import (
GradientClipBase,
GradientClipByNorm,
error_clip_callback,
append_gradient_clip_ops,
ClipGradByGlobalNorm,
)
from .framework import program_guard
from .initializer import Constant
from .layer_helper import LayerHelper
......@@ -160,7 +153,7 @@ class Optimizer:
)
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipBase):
if not isinstance(grad_clip, paddle.nn.clip.GradientClipBase):
raise TypeError(
"'grad_clip' should be an instance of GradientClipBase's derived class"
)
......@@ -1030,7 +1023,7 @@ class Optimizer:
params_grads.append((param, grad_var))
else:
if callbacks is None:
callbacks = [error_clip_callback]
callbacks = [paddle.nn.clip.error_clip_callback]
else:
assert isinstance(callbacks, list)
program = loss.block.program
......@@ -1260,7 +1253,7 @@ class Optimizer:
# NOTE(zhiqiu): currently, only support ClipGradByGlobalNorm and without regularization.
if self._flatten_param_grads and self.regularization is None:
if self._grad_clip is None or isinstance(
self._grad_clip, ClipGradByGlobalNorm
self._grad_clip, paddle.nn.ClipGradByGlobalNorm
):
params_grads = self.flatten_param_grads(params_grads)
......@@ -1268,7 +1261,7 @@ class Optimizer:
if self._grad_clip is not None:
params_grads = self._grad_clip(params_grads)
else:
params_grads = append_gradient_clip_ops(params_grads)
params_grads = paddle.nn.clip.append_gradient_clip_ops(params_grads)
# Add regularization if any
params_grads = self.append_regularization_ops(
......
......@@ -38,13 +38,13 @@ with fluid.program_guard(main_program=prog):
prog_clip = prog.clone()
prog_clip.block(0).var(hidden1.name)._set_error_clip(
fluid.clip.ErrorClipByValue(max=CLIP_MAX, min=CLIP_MIN)
paddle.nn.clip.ErrorClipByValue(max=CLIP_MAX, min=CLIP_MIN)
)
avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
fluid.backward.append_backward(loss=avg_cost)
fluid.backward.append_backward(
loss=avg_cost_clip, callbacks=[fluid.clip.error_clip_callback]
loss=avg_cost_clip, callbacks=[paddle.nn.clip.error_clip_callback]
)
hidden1_grad = prog.block(0).var(hidden1.name + "@GRAD")
......
......@@ -122,7 +122,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
opt = paddle.optimizer.AdamW(
learning_rate=lr_val,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0),
grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
)
acc_steps = 2 # accumulated steps for pipeline
......
......@@ -122,7 +122,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
opt = fluid.optimizer.Momentum(
learning_rate=lr_val,
momentum=0.9,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0),
grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
)
acc_steps = 2 # accumulated steps for pipeline
......
......@@ -15,10 +15,10 @@
import unittest
import paddle
import paddle.fluid.clip as clip
import paddle.fluid.framework as framework
import paddle.fluid.optimizer as optimizer
import paddle.fluid.regularizer as regularizer
import paddle.nn.clip as clip
paddle.enable_static()
......@@ -76,7 +76,7 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
rampup_begin_step=0,
num_trainers=2,
regularization=regularization,
grad_clip=clip.GradientClipByNorm(1.0),
grad_clip=clip.ClipGradByNorm(1.0),
)
if use_recompute:
......@@ -144,14 +144,14 @@ class TestDGCMomentumOptimizer(unittest.TestCase):
print("dgc regular_coeff=" + str(coeff))
def test_tpyeError(self):
# the type of DGCMomentumOptimizer(grad_clip=) must be 'GradientClipByNorm'
# the type of DGCMomentumOptimizer(grad_clip=) must be 'ClipGradByNorm'
with self.assertRaises(TypeError):
dgc_momentum_optimizer = self.MockDGCMomentum(
learning_rate=0.01,
momentum=0.2,
rampup_begin_step=0,
num_trainers=2,
grad_clip=clip.GradientClipByGlobalNorm(1.0),
grad_clip=clip.ClipGradByGlobalNorm(1.0),
)
def test_momentum_without_dgc(self):
......
......@@ -354,7 +354,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
}
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
clip = paddle.fluid.clip.GradientClipByGlobalNorm(1.0)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
......@@ -552,7 +552,7 @@ class TestFleetHybridOptimizer(TestFleetMetaOptimizer):
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
strategy.fuse_grad_merge = True
clip = paddle.fluid.clip.GradientClipByGlobalNorm(1.0)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
......@@ -940,7 +940,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
}
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
clip = paddle.fluid.clip.GradientClipByGlobalNorm(1.0)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
......@@ -1044,7 +1044,7 @@ class TestFleetHybridOptimizerBoundary(TestFleetMetaOptimizer):
}
strategy.fuse_all_reduce_ops = True
strategy.fuse_grad_size_in_MB = 32
clip = paddle.fluid.clip.GradientClipByGlobalNorm(1.0)
clip = paddle.nn.ClipGradByGlobalNorm(1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
......
......@@ -640,7 +640,7 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
)
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'sharding')
clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
)
......@@ -1309,7 +1309,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
"micro_batch_size": 2,
"accumulate_steps": 4,
}
clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
self.optimizer(
avg_cost, strategy, train_prog, startup_prog, grad_clip=clip
)
......@@ -1547,7 +1547,7 @@ class TestFleetShardingHybridOptimizer(TestFleetMetaOptimizer):
"micro_batch_size": 2,
"accumulate_steps": 4,
}
clip = paddle.fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
self.optimizer(
avg_cost,
strategy,
......
......@@ -22,8 +22,8 @@ import paddle
import paddle.distributed.fleet as fleet
import paddle.fluid.core as core
from paddle.distributed.fleet.meta_optimizers.common import CollectiveHelper
from paddle.fluid.clip import ClipGradBase, _clip_by_global_norm_using_mp_type
from paddle.incubate import DistributedFusedLamb
from paddle.nn.clip import ClipGradBase, _clip_by_global_norm_using_mp_type
from paddle.vision.models import resnet18 as resnet
......
......@@ -19,6 +19,7 @@ import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.jit.dy2static import Call
from paddle.nn import clip
SEED = 2020
np.random.seed(SEED)
......@@ -89,11 +90,11 @@ def len_with_selected_rows(place):
type=fluid.core.VarDesc.VarType.SELECTED_ROWS,
)
# y is Variable(SelectedRows)
y = fluid.layers.merge_selected_rows(var)
y = clip.merge_selected_rows(var)
y_len = Call(len)(y)
# z is inner tensor with shape [4, 2]
z = fluid.layers.get_tensor_from_selected_rows(y)
z = clip.get_tensor_from_selected_rows(y)
z_len = Call(len)(z)
# set data for selected_rows
......
......@@ -22,8 +22,8 @@ from seq2seq_dygraph_model import AttentionModel, BaseModel
from seq2seq_utils import Seq2SeqModelHyperParams, get_data_iter
import paddle.fluid as fluid
from paddle.fluid.clip import GradientClipByGlobalNorm
from paddle.jit import ProgramTranslator
from paddle.nn import ClipGradByGlobalNorm
place = (
fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
......@@ -71,7 +71,7 @@ def train(args, attn_model=False):
dropout=args.dropout,
)
gloabl_norm_clip = GradientClipByGlobalNorm(args.max_grad_norm)
gloabl_norm_clip = ClipGradByGlobalNorm(args.max_grad_norm)
optimizer = fluid.optimizer.SGD(
args.learning_rate,
parameter_list=model.parameters(),
......
......@@ -127,7 +127,7 @@ class ElementwiseActivationMkldnnFusePassTest_Add_Clip(
):
def set_params(self):
self.operand = paddle.add
self.act = fluid.layers.clip
self.act = paddle.clip
self.act_alpha = 0.0
self.act_beta = 10.0
......@@ -219,7 +219,7 @@ class ElementwiseActivationMkldnnFusePassTest_Sub_Clip(
):
def set_params(self):
self.operand = paddle.subtract
self.act = fluid.layers.clip
self.act = paddle.clip
self.act_alpha = 0.0
self.act_beta = 10.0
......@@ -319,7 +319,7 @@ class ElementwiseActivationMkldnnFusePassTest_Mul_Clip(
):
def set_params(self):
self.operand = paddle.multiply
self.act = fluid.layers.clip
self.act = paddle.clip
self.act_alpha = 0.0
self.act_beta = 10.0
......
......@@ -106,7 +106,7 @@ class TensorRTSubgraphPassHardSwishPluginTest(
class TensorRTSubgraphPassClipTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.clip(x, 0, 1)
return paddle.clip(x, 0, 1)
class TensorRTSubgraphPassTanhTest(TensorRTSubgraphPassActivationTest):
......
......@@ -117,13 +117,13 @@ class TestClipOpError(unittest.TestCase):
input_data = np.random.random((2, 4)).astype("float32")
def test_Variable():
fluid.layers.clip(x=input_data, min=-1.0, max=1.0)
paddle.clip(x=input_data, min=-1.0, max=1.0)
self.assertRaises(TypeError, test_Variable)
def test_dtype():
x2 = fluid.layers.data(name='x2', shape=[1], dtype='int32')
fluid.layers.clip(x=x2, min=-1.0, max=1.0)
paddle.clip(x=x2, min=-1.0, max=1.0)
self.assertRaises(TypeError, test_dtype)
paddle.disable_static()
......
......@@ -686,7 +686,7 @@ class TestAdamOpV2(unittest.TestCase):
value = np.arange(26).reshape(2, 13).astype("float32")
a = fluid.dygraph.to_variable(value)
linear = paddle.nn.Linear(13, 5)
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
adam = paddle.optimizer.Adam(
0.1, parameters=linear.parameters(), grad_clip=clip
)
......
......@@ -20,12 +20,13 @@ from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.nn import clip
class TestClipByNormOp(OpTest):
def setUp(self):
self.max_relative_error = 0.006
self.python_api = fluid.layers.clip_by_norm
self.python_api = clip.clip_by_norm
self.init_dtype()
self.initTestCase()
input = np.random.random(self.shape).astype(self.dtype)
......
......@@ -128,15 +128,9 @@ class TestClipOpError(unittest.TestCase):
input_data = np.random.random((2, 4)).astype("float32")
def test_Variable():
fluid.layers.clip(x=input_data, min=-1.0, max=1.0)
paddle.clip(x=input_data, min=-1.0, max=1.0)
self.assertRaises(TypeError, test_Variable)
def test_dtype():
x2 = fluid.layers.data(name='x2', shape=[1], dtype='int32')
fluid.layers.clip(x=x2, min=-1.0, max=1.0)
self.assertRaises(TypeError, test_dtype)
paddle.disable_static()
......
......@@ -584,7 +584,7 @@ class TestL2Decay(TranspilerTest):
def filter(param):
return param.name == "fc_w"
clip = fluid.clip.GradientClipByValue(0.1, need_clip=filter)
clip = paddle.nn.ClipGradByValue(0.1, need_clip=filter)
sgd_optimizer.minimize(avg_cost, grad_clip=clip)
def transpiler_test_impl(self):
......
......@@ -504,8 +504,8 @@ class PaddingRNNTestBase(unittest.TestCase):
self.feed_order,
) = res_vars
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(
paddle.nn.clip.set_gradient_clip(
clip=paddle.nn.ClipGradByGlobalNorm(
clip_norm=config.max_grad_norm
)
)
......
......@@ -64,7 +64,7 @@ class TestFleetExecutor(unittest.TestCase):
)
opt = paddle.optimizer.AdamW(
learning_rate=lr_val,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0),
grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
)
opt.minimize(loss)
# TODO: section_program will be removed in the future
......
......@@ -64,7 +64,7 @@ class TestFleetExecutor(unittest.TestCase):
)
opt = paddle.optimizer.AdamW(
learning_rate=lr_val,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0),
grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
)
opt.minimize(loss)
# TODO: section_program will be removed in the future
......
......@@ -47,7 +47,7 @@ class TestFleetExecutor(unittest.TestCase):
)
opt = paddle.optimizer.AdamW(
learning_rate=lr_val,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0),
grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0),
)
opt.minimize(loss)
# TODO: section_program will be removed in the future
......
......@@ -20,6 +20,7 @@ import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from paddle.fluid.op import Operator
from paddle.nn import clip
class TestGetTensorFromSelectedRowsError(unittest.TestCase):
......@@ -31,12 +32,12 @@ class TestGetTensorFromSelectedRowsError(unittest.TestCase):
x_data = np.random.random((2, 4)).astype("float32")
def test_Variable():
fluid.layers.get_tensor_from_selected_rows(x=x_data)
clip.get_tensor_from_selected_rows(x=x_data)
self.assertRaises(TypeError, test_Variable)
def test_SELECTED_ROWS():
fluid.layers.get_tensor_from_selected_rows(x=x_var)
clip.get_tensor_from_selected_rows(x=x_var)
self.assertRaises(TypeError, test_SELECTED_ROWS)
......
......@@ -17,12 +17,8 @@ import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.clip import (
GradientClipByGlobalNorm,
GradientClipByNorm,
GradientClipByValue,
)
from paddle.fluid.dygraph.base import to_variable
from paddle.nn import ClipGradByGlobalNorm, ClipGradByNorm, ClipGradByValue
class TestGradClipByGlobalNorm(unittest.TestCase):
......@@ -67,7 +63,7 @@ class TestGradClipByGlobalNorm(unittest.TestCase):
def get_dygrap_global_norm_result(self):
with fluid.dygraph.guard():
gloabl_norm_clip = GradientClipByGlobalNorm(self.max_global_norm)
gloabl_norm_clip = ClipGradByGlobalNorm(self.max_global_norm)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
......@@ -142,7 +138,7 @@ class TestGradClipByNorm(unittest.TestCase):
def get_dygrap_norm_result(self):
with fluid.dygraph.guard():
norm_clip = GradientClipByNorm(self.max_norm)
norm_clip = ClipGradByNorm(self.max_norm)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
......@@ -212,9 +208,7 @@ class TestGradClipByValue(unittest.TestCase):
def get_dygrap_clip_result(self):
with fluid.dygraph.guard():
value_clip = GradientClipByValue(
max=self.max_value, min=self.min_value
)
value_clip = ClipGradByValue(max=self.max_value, min=self.min_value)
p_g_var = []
for p, g in self.para_and_grad:
new_p = to_variable(p)
......
......@@ -20,7 +20,7 @@ from fake_reader import fake_imdb_reader
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip
from paddle.nn.clip import _allow_pure_fp16_global_norm_clip
paddle.enable_static()
......@@ -173,9 +173,9 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
# test whether the output is right when use 'set_gradient_clip'
def test_old_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
fluid.clip.set_gradient_clip(clip)
return fluid.clip.append_gradient_clip_ops(params_grads)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=self.clip_norm)
paddle.nn.clip.set_gradient_clip(clip)
return paddle.nn.clip.append_gradient_clip_ops(params_grads)
self.clip_gradient = func
self.check_gradient_clip(fluid.CPUPlace())
......@@ -183,7 +183,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
# test whether the output is right when use grad_clip
def test_new_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=self.clip_norm)
return clip(params_grads)
self.clip_gradient = func
......@@ -192,7 +192,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
# test whether the output is right when use grad_clip under float64
def test_new_gradient_clip_fp64(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=self.clip_norm)
return clip(params_grads)
self.clip_gradient = func
......@@ -201,15 +201,15 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
# invoke 'set_gradient_clip' in a wrong order
def test_wrong_API_order(self):
def backward_func(cost):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
fluid.clip.set_gradient_clip(clip)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=5.0)
paddle.nn.clip.set_gradient_clip(clip)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.01, grad_clip=clip
)
# if 'set_gradient_clip' and 'optimize(grad_clip)' are used together, 'set_gradient_clip' will be ineffective
sgd_optimizer.minimize(cost)
# 'set_gradient_clip' must be called before 'minimize'; otherwise, 'set_gradient_clip' will be ineffective
fluid.clip.set_gradient_clip(clip)
paddle.nn.clip.set_gradient_clip(clip)
self.backward_and_optimize = backward_func
for place in self.get_places():
......@@ -269,7 +269,7 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
with fluid.program_guard(
main_program=prog, startup_program=startup_program
):
clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
x = (
fluid.default_main_program()
.global_block()
......@@ -313,7 +313,7 @@ class TestGradientClipByNorm(TestGradientClip):
# test whether the output is right when use grad_clip
def test_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)
clip = paddle.nn.ClipGradByNorm(clip_norm=self.clip_norm)
return clip(params_grads)
self.clip_gradient = func
......@@ -321,7 +321,7 @@ class TestGradientClipByNorm(TestGradientClip):
# if grad is None or not need clip
def test_none_grad(self):
clip = fluid.clip.GradientClipByNorm(self.clip_norm)
clip = paddle.nn.ClipGradByNorm(self.clip_norm)
x = (
fluid.default_main_program()
.global_block()
......@@ -371,7 +371,7 @@ class TestGradientClipByValue(TestGradientClip):
# test whether the output is right when use grad_clip
def test_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByValue(max=self.max, min=self.min)
clip = paddle.nn.ClipGradByValue(max=self.max, min=self.min)
return clip(params_grads)
self.clip_gradient = func
......@@ -379,7 +379,7 @@ class TestGradientClipByValue(TestGradientClip):
# if grad is None or not need clip
def test_none_grad(self):
clip = fluid.clip.GradientClipByValue(self.max, self.min)
clip = paddle.nn.ClipGradByValue(self.max, self.min)
x = (
fluid.default_main_program()
.global_block()
......@@ -419,7 +419,7 @@ class TestDygraphGradientClip(unittest.TestCase):
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.0,
parameter_list=linear.parameters(),
grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1),
grad_clip=paddle.nn.ClipGradByGlobalNorm(0.1),
)
self.check_clip_result(loss, sgd_optimizer)
......@@ -430,12 +430,8 @@ class TestDygraphGradientClip(unittest.TestCase):
class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
def setUp(self):
self.clip_norm = 0.8
self.clip1 = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm
)
self.clip2 = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm
)
self.clip1 = paddle.nn.ClipGradByGlobalNorm(clip_norm=self.clip_norm)
self.clip2 = paddle.nn.ClipGradByGlobalNorm(clip_norm=self.clip_norm)
def check_clip_result(self, loss, optimizer):
# if grad is None
......@@ -476,7 +472,7 @@ class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
def setUp(self):
self.clip_norm = 0.8
self.clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)
self.clip = paddle.nn.ClipGradByNorm(clip_norm=self.clip_norm)
def check_clip_result(self, loss, optimizer):
# if grad is None
......@@ -506,7 +502,7 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip):
def setUp(self):
self.max = 0.2
self.min = 0.1
self.clip = fluid.clip.GradientClipByValue(max=self.max, min=self.min)
self.clip = paddle.nn.ClipGradByValue(max=self.max, min=self.min)
def check_clip_result(self, loss, optimizer):
# if grad is None
......@@ -572,7 +568,7 @@ class TestDygraphGradientClipFP16(unittest.TestCase):
params_grads.append((param, param._grad_ivar()))
_, grads = zip(*params_grads)
# clip grads
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.8)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=0.8)
params_grads = clip(params_grads)
_, grads_clip = zip(*params_grads)
# param update
......@@ -616,7 +612,7 @@ class TestDygraphGradientClipFP64(unittest.TestCase):
params_grads.append((param, param._grad_ivar()))
_, grads = zip(*params_grads)
# clip grads
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=0.1)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=0.1)
params_grads = clip(params_grads)
_, grads_clip = zip(*params_grads)
......
......@@ -361,7 +361,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
place = fluid.CPUPlace()
with fluid.dygraph.guard(place):
model = MyLayer(size, vocab_size, size)
grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001)
grad_clip = paddle.nn.ClipGradByGlobalNorm(0.001)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters(), grad_clip=grad_clip
)
......@@ -380,7 +380,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
with fluid.dygraph.guard(place):
model = MyLayer2(size, vocab_size, size)
grad_clip = fluid.clip.GradientClipByGlobalNorm(0.001)
grad_clip = paddle.nn.ClipGradByGlobalNorm(0.001)
optimizer = fluid.optimizer.AdamOptimizer(
0.001, parameter_list=model.parameters(), grad_clip=grad_clip
)
......
......@@ -52,7 +52,7 @@ class TestSimpleNet(unittest.TestCase):
fluid.set_flags(
{'FLAGS_sort_sum_gradient': sort_sum_gradient}
)
# grad_clip = fluid.clip.GradientClipByGlobalNorm(5.0)
# grad_clip = paddle.nn.ClipGradByGlobalNorm(5.0)
input_word = np.array([[1, 2], [2, 1]]).astype('int64')
input = paddle.to_tensor(input_word)
......@@ -91,7 +91,7 @@ class TestSimpleNet(unittest.TestCase):
fluid.set_flags(
{'FLAGS_sort_sum_gradient': sort_sum_gradient}
)
grad_clip = fluid.clip.GradientClipByGlobalNorm(5.0)
grad_clip = paddle.nn.ClipGradByGlobalNorm(5.0)
input_word = np.array([[1, 2], [2, 1]]).astype('int64')
input = to_variable(input_word)
......
......@@ -131,13 +131,13 @@ class TestClipOpError(unittest.TestCase):
input_data = np.random.random((2, 4)).astype("float32")
def test_Variable():
fluid.layers.clip(x=input_data, min=-1.0, max=1.0)
paddle.clip(x=input_data, min=-1.0, max=1.0)
self.assertRaises(TypeError, test_Variable)
def test_dtype():
x2 = fluid.layers.data(name='x2', shape=[1], dtype='int32')
fluid.layers.clip(x=x2, min=-1.0, max=1.0)
paddle.clip(x=x2, min=-1.0, max=1.0)
self.assertRaises(TypeError, test_dtype)
paddle.disable_static()
......
......@@ -1535,7 +1535,7 @@ class Model:
assert isinstance(
self._optimizer._grad_clip,
(paddle.nn.ClipGradByGlobalNorm, paddle.nn.ClipGradByNorm),
), "Only GradientClipByNorm and GradientClipByGlobalNorm are supported in amp training with level=O2 currently."
), "Only ClipGradByNorm and ClipGradByGlobalNorm are supported in amp training with level=O2 currently."
self._adapter._amp_custom_lists = {}
self._adapter._amp_configs = {}
......
......@@ -15,13 +15,14 @@
import paddle
import paddle.distributed as dist
from paddle.fluid import core, layers
from paddle.fluid.clip import ClipGradBase, _squared_l2_norm
from paddle.fluid.dygraph import base as imperative_base
from paddle.nn import clip
from paddle.nn.clip import ClipGradBase, _squared_l2_norm
class ClipGradForMOEByGlobalNorm(ClipGradBase):
r"""
The Algrithm is the same as paddle.fluid.clip.ClipGradByGlobalNorm
The Algorithm is the same as paddle.nn.ClipGradByGlobalNorm
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
:math:`t\_list` , and limit it to ``clip_norm`` .
......@@ -113,8 +114,8 @@ class ClipGradForMOEByGlobalNorm(ClipGradBase):
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
merge_grad = clip.merge_selected_rows(g)
merge_grad = clip.get_tensor_from_selected_rows(merge_grad)
sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square)
......
......@@ -16,11 +16,11 @@ import os
import paddle
from paddle.fluid import core, framework, unique_name
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.fluid.executor import global_scope
from paddle.fluid.framework import Variable, name_scope
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.optimizer import Optimizer
from paddle.nn import ClipGradByGlobalNorm
def init_communicator(block, rank, ranks, ring_id):
......
......@@ -12,9 +12,1074 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define the functions to clip gradients of parameters
from ..fluid.clip import ClipGradByGlobalNorm # noqa: F401
from ..fluid.clip import ClipGradByNorm # noqa: F401
from ..fluid.clip import ClipGradByValue # noqa: F401
import copy
import warnings
import paddle
import paddle.autograd as imperative_base
from paddle import _C_ops, _legacy_C_ops
from paddle.common_ops_import import Variable, check_type, default_main_program
from paddle.fluid import core, framework, layers, unique_name
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.framework import LayerHelper, _non_static_mode, in_dygraph_mode
from paddle.tensor.layer_function_generator import templatedoc
__all__ = []
@templatedoc()
def clip_by_norm(x, max_norm, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
max_norm(${max_norm_type}): ${max_norm_comment}
name(str, optional): For detailed information, please refer
to :ref:`api_guide_Name`. Usually name is no need to set and
None by default.
Returns:
Tensor:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
import paddle
from paddle.nn import clip
input = paddle.to_tensor([[2.0, 2.0], [2.0, 2.0]], dtype='float32')
reward = clip.clip_by_norm(x=input, max_norm=1.0)
# [[0.5, 0.5], [0.5, 0.5]]
"""
if in_dygraph_mode():
return _C_ops.clip_by_norm(x, max_norm)
if _non_static_mode():
return _legacy_C_ops.clip_by_norm(x, 'max_norm', max_norm)
helper = LayerHelper("clip_by_norm", **locals())
check_variable_and_dtype(x, 'X', ['float32', 'float16'], 'clip_by_norm')
check_type(max_norm, 'max_norm', (float), 'clip_by_norm')
if name is None:
name = unique_name.generate_with_ignorable_key(
".".join([helper.name, 'tmp'])
)
out = helper.create_variable(
type=x.type, name=name, dtype=x.dtype, persistable=False
)
helper.append_op(
type="clip_by_norm",
inputs={"X": x},
attrs={"max_norm": max_norm},
outputs={"Out": out},
)
return out
@templatedoc()
def merge_selected_rows(x, name=None):
"""
${comment}
Args:
x(${x_type}): ${x_comment}
name(str, optional): Name of the output. Default: None.
Returns:
out(${out_type}): ${out_comment}
Examples:
.. code-block:: python
import paddle.fluid as fluid
from paddle.nn import clip
b = fluid.default_main_program().global_block()
var = b.create_var(
name="X", dtype="float32", persistable=True,
type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
y = clip.merge_selected_rows(var)
"""
if in_dygraph_mode():
return _C_ops.merge_selected_rows(x)
if _non_static_mode():
return _legacy_C_ops.merge_selected_rows(x)
helper = LayerHelper("merge_selected_rows", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type="merge_selected_rows",
inputs={"X": x},
attrs={},
outputs={"Out": out},
)
return out
@templatedoc()
def get_tensor_from_selected_rows(x, name=None):
"""
Get tensor data from input with SelectedRows type, and outputs a Tensor.
.. code-block:: text
input x is SelectedRows:
x.rows = [0, 5, 5, 4, 19]
x.height = 20
x.value = [[1, 1] [2, 2] [2, 2] [3, 3] [6, 6]]
Output is LoDTensor:
out.shape = [5, 2]
out.data = [[1, 1],
[2, 2],
[2, 2],
[3, 3],
[6, 6]]
Args:
x(SelectedRows): Input with SelectedRows type. The data type is float32, float64, int32 or int64.
name(str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: LoDTensor transformed from SelectedRows. The data type is the same as the input.
Examples:
.. code-block:: python
from paddle import nnp.py
b = fluid.default_main_program().global_block()
input = b.create_var(name="X", dtype="float32", persistable=True, type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
out = nn.get_tensor_from_selected_rows(input)
"""
check_type(x, 'x', Variable, 'get_tensor_from_selected_rows')
if x.type != core.VarDesc.VarType.SELECTED_ROWS:
raise TypeError(
"The type of 'x' in get_tensor_from_selected_rows must be SELECTED_ROWS."
)
helper = LayerHelper('get_tensor_from_selected_rows', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='get_tensor_from_selected_rows',
inputs={'X': x},
outputs={'Out': out},
attrs={},
)
return out
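# Descriptive note (added for clarity; not in the original source): when the switch below is
# turned on via _clip_by_global_norm_using_mp_type(True), fp16/bf16 gradients are cast to
# fp32 (see _cast_to_mp_type_if_enabled) before their squared norms are accumulated,
# which helps avoid overflow in the global-norm computation.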
_clip_by_global_norm_using_mp_type_flag = False
def _clip_by_global_norm_using_mp_type(*args):
global _clip_by_global_norm_using_mp_type_flag
assert len(args) <= 1
if len(args) == 1:
assert isinstance(args[0], bool)
old_value = _clip_by_global_norm_using_mp_type_flag
_clip_by_global_norm_using_mp_type_flag = args[0]
return old_value
else:
return _clip_by_global_norm_using_mp_type_flag
def _cast_to_mp_type_if_enabled(x):
if (
x.dtype == core.VarDesc.VarType.FP16
or x.dtype == core.VarDesc.VarType.BF16
) and _clip_by_global_norm_using_mp_type():
return x.astype(core.VarDesc.VarType.FP32)
else:
return x
def _squared_l2_norm(x):
r"""
Return the squared L2 norm of a tensor.
"""
x = _cast_to_mp_type_if_enabled(x)
if (
core.is_compiled_with_xpu()
or x.dtype == core.VarDesc.VarType.FP16
or x.dtype == core.VarDesc.VarType.BF16
):
square = paddle.square(x)
sum_square = paddle.sum(square)
return sum_square
if in_dygraph_mode():
return _C_ops.squared_l2_norm(x)
op_type = 'squared_l2_norm'
check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(x.dtype)
inputs = {"X": x}
outputs = {'Out': out}
helper.append_op(type=op_type, inputs=inputs, outputs=outputs)
return out
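# Illustrative sketch (added for clarity; not in the original source): _squared_l2_norm
# returns the sum of squared elements, i.e. the squared L2 norm, not the norm itself.
# Assuming a float32 tensor t:
#   t = paddle.to_tensor([3.0, 4.0])
#   _squared_l2_norm(t)  # Tensor holding 25.0 (= 3**2 + 4**2)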
class BaseErrorClipAttr:
def __str__(self):
raise NotImplementedError()
def _append_clip_op(self, block, grad_name):
raise NotImplementedError()
class ErrorClipByValue(BaseErrorClipAttr):
r"""
Clip tensor values to the range [min, max].
Given a tensor ``t`` (see Examples below), this operation clips its values to the range [``min``, ``max``] in place.
- Any values less than min are set to min.
- Any values greater than max are set to max.
Args:
max (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. If not set by the user,
it will be set to ``-max`` by the framework.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle
paddle.enable_static()
BATCH_SIZE = 128
CLIP_MAX = 2e-6
CLIP_MIN = -1e-6
prog = fluid.framework.Program()
with fluid.program_guard(main_program=prog):
image = fluid.layers.data(
name='x', shape=[784], dtype='float32')
hidden1 = fluid.layers.fc(input=image, size=128, act='relu')
hidden2 = fluid.layers.fc(input=hidden1, size=64, act='relu')
predict = fluid.layers.fc(
input=hidden2, size=10, act='softmax')
label = fluid.layers.data(name='y', shape=[1], dtype='int64')
cost = paddle.nn.functional.cross_entropy(input=predict, label=label)
avg_cost = paddle.mean(cost)
prog_clip = prog.clone()
prog_clip.block(0).var(hidden1.name)._set_error_clip(
paddle.nn.clip.ErrorClipByValue(
max=CLIP_MAX, min=CLIP_MIN)
)
"""
def __init__(self, max, min=None):
max = float(max)
if min is None:
min = -max
else:
min = float(min)
self.max = max
self.min = min
def __str__(self):
return "ByValue, min=%f, max=%f" % (self.min, self.max)
def _append_clip_op(self, block, grad_name):
clip_op_desc = block.desc.append_op()
clip_op_desc.set_type("clip")
clip_op_desc.set_input("X", [grad_name])
clip_op_desc.set_output("Out", [grad_name])
clip_op_desc._set_attr("min", self.min)
clip_op_desc._set_attr("max", self.max)
def error_clip_callback(block, context):
# the context is a grad_to_var map
grad_to_var = context
op_desc = block.desc.op(block.desc.op_size() - 1)
for grad_n in [n for n in op_desc.output_arg_names() if n in grad_to_var]:
fwd_var = block._var_recursive(grad_to_var[grad_n])
error_clip = getattr(fwd_var, "error_clip", None)
if not (
error_clip is None or isinstance(error_clip, BaseErrorClipAttr)
):
raise TypeError(
"Variable's error_clip should be an instance of BaseErrorClipAttr or None."
)
if error_clip is not None:
error_clip._append_clip_op(block, grad_n)
class ClipGradBase:
def __init__(self):
super().__init__()
def __str__(self):
raise NotImplementedError()
@imperative_base.no_grad()
def _dygraph_clip(self, params_grads):
raise NotImplementedError
def _static_clip(self, params_grads):
raise NotImplementedError
def __call__(self, params_grads):
if _non_static_mode():
return self._dygraph_clip(params_grads)
else:
for p, g in params_grads:
if getattr(p, 'gradient_clip_attr', None) is not None:
warnings.warn(
"'set_gradient_clip' will be ineffective, because you have "
"set 'need_clip' in 'ParamAttr'. So, 'set_gradient_clip' "
"is redundant and you can remove it."
)
break
return self._static_clip(params_grads)
def _process_context(self, context, param, grad):
raise NotImplementedError()
def _create_operators(self, param, grad):
raise NotImplementedError()
class ClipGradByValue(ClipGradBase):
"""
Limit the value of multi-dimensional Tensor :math:`X` to the range [min, max].
- Any values less than min are set to ``min``.
- Any values greater than max are set to ``max``.
The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.
Gradient clip will take effect after being set in ``optimizer`` , see the documentation of ``optimizer``
(for example: :ref:`api_paddle_optimizer_SGD`).
Note:
``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0.
Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
Args:
max (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. If not set by the user, it will be set to ``-max``
automatically. In this case, ``max`` must be greater than 0.
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(x)
loss = paddle.mean(out)
loss.backward()
clip = paddle.nn.ClipGradByValue(min=-1, max=1)
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
sgd.step()
"""
def __init__(self, max, min=None):
super().__init__()
if min is None:
assert max > 0.0
min = -max
self.max = float(max)
self.min = float(min)
def __str__(self):
return "Clip Gradient By Value, min = %f, max=%f" % (self.min, self.max)
@imperative_base.no_grad()
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
new_grad = paddle.clip(x=g, min=self.min, max=self.max)
params_and_grads.append((p, new_grad))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
param_new_grad_name_dict = dict()
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_grad = paddle.clip(x=g, min=self.min, max=self.max)
params_and_grads.append((p, new_grad))
param_new_grad_name_dict[p.name] = new_grad.name
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
pass
def _create_operators(self, param, grad):
new_grad = paddle.clip(x=grad, min=self.min, max=self.max)
return param, new_grad
class ClipGradByNorm(ClipGradBase):
r"""
Limit the l2 norm of multi-dimensional Tensor :math:`X` to ``clip_norm`` .
- If the l2 norm of :math:`X` is greater than ``clip_norm`` , :math:`X` will be compressed by a ratio.
- If the l2 norm of :math:`X` is less than or equal to ``clip_norm`` , nothing will be done.
The multidimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.
Gradient clip will take effect after being set in ``optimizer`` , see the documentation of ``optimizer``
(for example: :ref:`api_paddle_optimizer_SGD`).
The clipping formula is:
.. math::
Out =
\left\{
\begin{array}{ccl}
X & & if (norm(X) \leq clip\_norm) \\
\frac{clip\_norm*X}{norm(X)} & & if (norm(X) > clip\_norm) \\
\end{array}
\right.
where :math:`norm(X)` represents the L2 norm of :math:`X`.
.. math::
norm(X) = ( \sum_{i=1}^{n}|x\_i|^2)^{ \frac{1}{2}}
Note:
``need_clip`` of ``ClipGradByNorm`` HAS BEEN DEPRECATED since 2.0.
Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
Args:
clip_norm(float): The maximum norm value.
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(x)
loss = paddle.mean(out)
loss.backward()
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
sgd.step()
"""
def __init__(self, clip_norm):
super().__init__()
self.clip_norm = float(clip_norm)
def __str__(self):
return "Gradient Clip By Norm, clip_norm=%f" % self.clip_norm
@imperative_base.no_grad()
def _dygraph_clip(self, params_grads):
params_and_grads = []
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
new_grad = clip_by_norm(x=g, max_norm=self.clip_norm)
params_and_grads.append((p, new_grad))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
with framework.name_scope('gradient_clip'):
param_new_grad_name_dict = dict()
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_grad = clip_by_norm(x=g, max_norm=self.clip_norm)
param_new_grad_name_dict[p.name] = new_grad.name
params_and_grads.append((p, new_grad))
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
pass
def _create_operators(self, param, grad):
new_grad = clip_by_norm(x=grad, max_norm=self.clip_norm)
return param, new_grad
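# Descriptive note (added for clarity; not in the original source): when the switch below is
# turned on via _allow_pure_fp16_global_norm_clip(True) and every gradient is fp16,
# ClipGradByGlobalNorm._static_clip keeps the squared-norm sum in fp16 instead of
# promoting it to fp32.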
_allow_pure_fp16_global_norm_clip_flag = False
def _allow_pure_fp16_global_norm_clip(*args):
global _allow_pure_fp16_global_norm_clip_flag
if len(args) == 0:
return _allow_pure_fp16_global_norm_clip_flag
else:
assert len(args) == 1 and isinstance(args[0], bool)
old_value = _allow_pure_fp16_global_norm_clip_flag
_allow_pure_fp16_global_norm_clip_flag = args[0]
return old_value
class ClipGradByGlobalNorm(ClipGradBase):
r"""
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
:math:`t\_list` , and limit it to ``clip_norm`` .
- If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.
- If the global norm is less than or equal to ``clip_norm`` , nothing will be done.
The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.
Gradient clip will take effect after being set in ``optimizer`` , see the documentation of ``optimizer``
(for example: :ref:`api_paddle_optimizer_SGD`).
The clipping formula is:
.. math::
t\_list[i] = t\_list[i] * \frac{clip\_norm}{\max(global\_norm, clip\_norm)}
where:
.. math::
global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
Note:
``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
Args:
clip_norm (float): The maximum norm value.
group_name (str, optional): The group name for this clip. Default value is ``default_group``.
auto_skip_clip (bool, optional): Whether to skip clipping when the global norm does not exceed ``clip_norm``. Default value is ``False``.
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(x)
loss = paddle.mean(out)
loss.backward()
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
sgd = paddle.optimizer.SGD(learning_rate=0.1, parameters=linear.parameters(), grad_clip=clip)
sgd.step()
"""
def __init__(
self, clip_norm, group_name="default_group", auto_skip_clip=False
):
super().__init__()
self.clip_norm = float(clip_norm)
self.group_name = group_name
assert isinstance(auto_skip_clip, bool)
self.auto_skip_clip = auto_skip_clip
def __str__(self):
return "Gradient Clip By GlobalNorm, global_norm=%f" % (self.clip_norm)
@imperative_base.no_grad()
def _dygraph_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_fp32 = []
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
if in_dygraph_mode() and g.is_selected_rows():
merge_grad = merge_selected_rows(g)
merge_grad = merge_grad._get_tensor_from_selected_rows()
elif g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = merge_selected_rows(g)
merge_grad = get_tensor_from_selected_rows(merge_grad)
sum_square = _squared_l2_norm(merge_grad)
if (
sum_square.dtype == core.VarDesc.VarType.FP16
or sum_square.dtype == core.VarDesc.VarType.BF16
):
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
else:
sum_square_list.append(sum_square)
# all parameters have been filtered out
if (
len(sum_square_list)
+ len(sum_square_list_fp16)
+ len(sum_square_list_fp32)
== 0
):
return params_grads
sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = paddle.add_n(sum_square_list_fp16)
global_norm_var.append(global_norm_var_fp16.astype(sum_dtype))
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = paddle.add_n(sum_square_list_fp32)
if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32)
else:
global_norm_var.append(global_norm_var_fp32.astype(sum_dtype))
if len(sum_square_list) > 0:
global_norm_var_fp64 = paddle.add_n(sum_square_list)
global_norm_var.append(global_norm_var_fp64)
global_norm_var = paddle.add_n(global_norm_var)
global_norm_var = paddle.sqrt(global_norm_var)
max_global_norm = paddle.full(
shape=[1], dtype=global_norm_var.dtype, fill_value=self.clip_norm
)
need_clip = False
if not self.auto_skip_clip: # always apply clip
need_clip = True
clip_var = paddle.divide(
x=max_global_norm,
y=paddle.maximum(x=global_norm_var, y=max_global_norm),
)
elif global_norm_var > max_global_norm:
# only when global_norm_var > max_global_norm, grad need clip
need_clip = True
clip_var = paddle.divide(x=max_global_norm, y=global_norm_var)
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
# TODO(wangxi): use inplace elementwise_mul
if need_clip:
clip_input = (
clip_var.astype(g.dtype)
if clip_var.dtype != g.dtype
else clip_var
)
new_grad = paddle.multiply(g, clip_input)
params_and_grads.append((p, new_grad))
else:
params_and_grads.append((p, g))
return params_and_grads
def _static_clip(self, params_grads):
params_and_grads = []
sum_square_list = []
sum_square_list_fp16 = []
sum_square_list_fp32 = []
with framework.name_scope('gradient_clip'):
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
with p.block.program._optimized_guard([p, g]):
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = merge_selected_rows(g)
merge_grad = get_tensor_from_selected_rows(merge_grad)
sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
else:
sum_square_list.append(sum_square)
# all parameters have been filtered out
if (
len(sum_square_list)
+ len(sum_square_list_fp16)
+ len(sum_square_list_fp32)
== 0
):
return params_grads
with p.block.program._optimized_guard([p, g]):
sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
if (
sum_square_list_fp32
or sum_square_list
or not _allow_pure_fp16_global_norm_clip()
):
global_norm_var.append(
global_norm_var_fp16.astype(sum_dtype)
)
else:
global_norm_var.append(global_norm_var_fp16)
if len(sum_square_list_fp32) > 0:
global_norm_var_fp32 = layers.sums(sum_square_list_fp32)
if sum_dtype == 'float32':
global_norm_var.append(global_norm_var_fp32)
else:
global_norm_var.append(
global_norm_var_fp32.astype(sum_dtype)
)
if len(sum_square_list) > 0:
# fp64
global_norm_var_other_dtype = layers.sums(sum_square_list)
global_norm_var.append(global_norm_var_other_dtype)
global_norm_var = (
layers.sums(global_norm_var)
if len(global_norm_var) > 1
else global_norm_var[0]
)
global_norm_var = paddle.sqrt(x=global_norm_var)
max_global_norm = paddle.full(
shape=[1],
dtype=global_norm_var.dtype,
fill_value=self.clip_norm,
)
scale_var = paddle.divide(
x=max_global_norm,
y=paddle.maximum(x=max_global_norm, y=global_norm_var),
)
param_new_grad_name_dict = dict()
for p, g in params_grads:
if g is None:
continue
if getattr(p, 'need_clip', True) is False:
params_and_grads.append((p, g))
continue
with p.block.program._optimized_guard([p, g]):
new_g = _cast_to_mp_type_if_enabled(g)
# inplace
scale_input = (
scale_var.astype('float16')
if new_g.dtype == core.VarDesc.VarType.FP16
and scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var
)
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
# will be in different blocks with the gradient clip related ops.
# We need to handle the correct block, otherwise will encounter
# a 'NotFoundError' during compile time.
block = default_main_program().current_block()
block.append_op(
type='elementwise_mul',
inputs={'X': new_g, 'Y': scale_input},
outputs={'Out': new_g},
)
if new_g is not g:
block.append_op(
type='cast',
inputs={'X': new_g},
outputs={'Out': g},
attrs={
'in_dtype': new_g.dtype,
'out_dtype': g.dtype,
},
)
param_new_grad_name_dict[p.name] = g.name
params_and_grads.append((p, g))
_correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict)
return params_and_grads
def _process_context(self, context, param, grad):
if self.group_name not in context:
context[self.group_name] = []
context[self.group_name + "_clip_value"] = self.clip_norm
context[self.group_name + "_clip"] = paddle.full(
shape=[1], dtype=grad.dtype, fill_value=self.clip_norm
)
else:
if not self.clip_norm == context[self.group_name + "_clip_value"]:
raise ValueError(
"All parameters' 'clip_norm' of a same group should be the same"
)
merge_grad = grad
if grad.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = merge_selected_rows(grad)
merge_grad = get_tensor_from_selected_rows(merge_grad)
local_norm_var = _squared_l2_norm(merge_grad)
context[self.group_name].append(local_norm_var)
self.context = context
def _create_operators(self, param, grad):
group_scale_name = self.group_name + "_scale"
if group_scale_name not in self.context:
group_norm_var = layers.sums(input=self.context[self.group_name])
group_norm_var = paddle.sqrt(x=group_norm_var)
clip_var = self.context[self.group_name + "_clip"]
group_scale_var = paddle.divide(
x=clip_var,
y=paddle.maximum(x=clip_var, y=group_norm_var),
)
assert group_scale_var.shape == (1,)
self.context[group_scale_name] = group_scale_var
# inplace
param.block.append_op(
type='elementwise_mul',
inputs={'X': grad, 'Y': self.context[group_scale_name]},
outputs={'Out': grad},
)
return param, grad
@framework.dygraph_not_support
def set_gradient_clip(clip, param_list=None, program=None):
"""
Warning:
This API must be used after building network, and before ``minimize`` ,
and it may be removed in future releases, so it is not recommended.
It is recommended to set ``grad_clip`` when initializing the ``optimizer`` ,
which is a better way to clip gradients. There are three clipping strategies:
``paddle.nn.ClipGradByGlobalNorm`` , ``paddle.nn.ClipGradByNorm`` ,
``paddle.nn.ClipGradByValue`` .
To specify parameters that require gradient clip.
Args:
clip (ClipGradBase): Gradient clipping strategy; an instance of
some derived class of ``ClipGradBase`` . There are three clipping strategies
( ``paddle.nn.ClipGradByGlobalNorm`` , ``paddle.nn.ClipGradByNorm`` ,
``paddle.nn.ClipGradByValue`` ).
param_list (list(Variable), optional): Parameters that require gradient clip.
It can be a list of parameter or a list of parameter's name.
Default None, meaning that all parameters in the program will be included.
program (Program, optional): The program where parameters are located.
Default None, meaning that using :ref:`api_fluid_default_main_program` .
Returns:
None
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
paddle.enable_static()
def network():
image = fluid.data(name='image', shape=[
None, 28], dtype='float32')
param_attr1 = fluid.ParamAttr("fc1_param")
fc1 = fluid.layers.fc(image, size=10, param_attr=param_attr1)
param_attr2 = fluid.ParamAttr("fc2_param")
fc2 = fluid.layers.fc(fc1, size=10, param_attr=param_attr2)
loss = paddle.mean(fc2)
return loss
# network 1: clip all parameter gradient
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
paddle.nn.clip.set_gradient_clip(
paddle.nn.ClipGradByGlobalNorm(clip_norm=2.0))
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
# network 2: clip parameter gradient by name
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
paddle.nn.clip.set_gradient_clip(
paddle.nn.ClipGradByValue(min=-1.0, max=1.0),
param_list=["fc1_param", "fc2_param"])
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
# network 3: clip parameter gradient by passing parameter variables
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
param_var1 = fluid.default_main_program().global_block().var("fc1_param")
param_var2 = fluid.default_main_program().global_block().var("fc2_param")
paddle.nn.clip.set_gradient_clip(
paddle.nn.ClipGradByValue(min=-1.0, max=1.0),
param_list=[param_var1, param_var2])
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
# network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
clip1 = paddle.nn.ClipGradByValue(min=-1.0, max=1.0)
clip2 = paddle.nn.ClipGradByNorm(clip_norm=1.0)
# Set the gradient clipping strategy: clip1
paddle.nn.clip.set_gradient_clip(clip1)
# Set the gradient clipping strategy: clip2
sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
sgd.minimize(loss)
# 'set_gradient_clip' will not take effect when setting has a conflict,
# and the gradient clipping strategy will be 'clip2'
"""
warnings.warn(
"Caution! 'set_gradient_clip' is not recommended "
"and may be deprecated in future! "
"We recommend a new strategy: set 'grad_clip' "
"when initializing the 'optimizer'. "
"This method can reduce the mistakes, please "
"refer to documention of 'optimizer'."
)
if not isinstance(clip, ClipGradBase):
raise TypeError(
"'clip' should be an instance of ClipGradBase's derived class"
)
if program is None:
program = framework.default_main_program()
for op in program.block(0).ops:
if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr(
"op_namescope"
):
warnings.warn(
"'minimize' has been invoked before, this will make 'set_gradient_clip' "
"be ineffective! Please invoke 'set_gradient_clip' before 'minimize'."
)
break
if param_list is None:
param_list = program.block(0).all_parameters()
if all(isinstance(elem, str) for elem in param_list):
param_list = [program.block(0).var(elem) for elem in param_list]
if not all(isinstance(elem, framework.Parameter) for elem in param_list):
raise TypeError(
"'param_list' should be a list of Parameter or basestring(parameter's name)."
)
for param in param_list:
param.gradient_clip_attr = copy.deepcopy(clip)
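# Descriptive note (added for clarity; not in the original source): appends gradient-clip
# ops for parameters whose 'gradient_clip_attr' was attached by set_gradient_clip(); the
# param/grad pairs are returned unchanged as soon as a parameter without that attribute
# is encountered.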
def append_gradient_clip_ops(param_grads):
context = dict()
for p, g in param_grads:
if g is None:
continue
with p.block.program._optimized_guard([p, g]), framework.name_scope(
'gradient_clip'
):
clip_attr = getattr(p, 'gradient_clip_attr', None)
if clip_attr is None:
return param_grads
if not isinstance(clip_attr, ClipGradBase):
raise TypeError(
"clip attribute should be an instance of GradientClipBase"
)
clip_attr._process_context(context=context, param=p, grad=g)
res = []
param_new_grad_name_dict = dict()
for p, g in param_grads:
if g is None:
continue
with p.block.program._optimized_guard([p, g]), framework.name_scope(
'gradient_clip'
):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad])
_correct_clip_op_role_var(res, param_new_grad_name_dict)
return res
# change wrong mapping relation between param & grad in clip op
# Note: This function is sensitive to the time cost of the network with gradient clipping
# and should not be changed easily. If you must change, please test the time cost.
def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
block_id_list = []
if len(param_new_grad_name_dict) == 0:
return
for param, grad in params_grads:
if grad is None:
continue
block_id = param.block.idx
if block_id in block_id_list:
continue
block_id_list.append(block_id)
for op in param.block.program.global_block().ops:
if (
op.has_attr("op_namescope")
and "gradient_clip" in op.attr("op_namescope")
and op.attr('op_role_var')
):
param_name = op.attr('op_role_var')[0]
if param_name in param_new_grad_name_dict:
correct_p_g = [
param_name,
param_new_grad_name_dict[param_name],
]
op._set_attr('op_role_var', correct_p_g)
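# Backward-compatible aliases for the class names formerly exposed by paddle.fluid.clip.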
GradientClipBase = ClipGradBase
GradientClipByValue = ClipGradByValue
GradientClipByNorm = ClipGradByNorm
GradientClipByGlobalNorm = ClipGradByGlobalNorm
......@@ -20,10 +20,10 @@ import paddle
from .. import _C_ops
from ..fluid import core, framework, unique_name
from ..fluid.clip import GradientClipBase
from ..fluid.dygraph import base as imperative_base
from ..fluid.framework import Parameter, Variable
from ..fluid.layer_helper import LayerHelper
from ..nn.clip import GradientClipBase
from .lr import LRScheduler
from .optimizer import Optimizer
......
......@@ -18,6 +18,7 @@ from collections import defaultdict
import numpy as np
import paddle
import paddle.autograd as imperative_base
from paddle import _C_ops
from paddle.fluid import core
from paddle.fluid.framework import (
......@@ -32,12 +33,6 @@ from paddle.fluid.framework import (
from ..fluid import framework, unique_name
from ..fluid.backward import _get_no_grad_set_name, append_backward
from ..fluid.clip import (
GradientClipBase,
append_gradient_clip_ops,
error_clip_callback,
)
from ..fluid.dygraph import base as imperative_base
from ..fluid.framework import Parameter, program_guard
from ..fluid.initializer import Constant
from ..fluid.layer_helper import LayerHelper
......@@ -168,7 +163,7 @@ class Optimizer:
"""
@imperative_base.no_grad
@imperative_base.no_grad()
def __init__(
self,
learning_rate,
......@@ -225,7 +220,7 @@ class Optimizer:
% type(learning_rate)
)
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipBase):
if not isinstance(grad_clip, paddle.nn.clip.GradientClipBase):
raise TypeError(
"'grad_clip' should be an instance of GradientClipBase's derived class"
)
......@@ -1042,7 +1037,7 @@ class Optimizer:
params_grads.append((parameter_list[index], grad))
else:
if callbacks is None:
callbacks = [error_clip_callback]
callbacks = [paddle.nn.clip.error_clip_callback]
else:
assert isinstance(callbacks, list)
program = loss.block.program
......@@ -1103,7 +1098,7 @@ class Optimizer:
params_grads = self._grad_clip(params_grads)
else:
params_grads = append_gradient_clip_ops(params_grads)
params_grads = paddle.nn.clip.append_gradient_clip_ops(params_grads)
# Add regularization if any
params_grads = self.append_regularization_ops(
......@@ -1317,7 +1312,7 @@ class Optimizer:
else:
core.clear_gradients(param_list, set_to_zero)
@imperative_base.no_grad
@imperative_base.no_grad()
def minimize(
self, loss, startup_program=None, parameters=None, no_grad_set=None
):
......@@ -1380,7 +1375,7 @@ class Optimizer:
return optimize_ops, params_grads
@imperative_base.no_grad
@imperative_base.no_grad()
@framework.dygraph_only
def step(self):
"""
......