Unverified · Commit d1e8b1e2 · Authored by Yiqun Liu, committed by GitHub

Cherry-pick fixes for operator precision. (#52705)

* Fix the scale kernel for low precision; cherry-pick of #50998.

* Fix the FP16 precision problem of add_n. (#50129)

* Change squared_l2_norm to reuse ReduceKernel and register FP16 and BF16 kernels; cherry-pick of #48315.

* Cherry-pick the fix of MPTypeTrait in KP, implemented in #50993.

* Cherry-pick the multi-precision support of AdamW for BF16. (#48041)

* Fix a compilation error.

* Cherry-pick the fix of CubTensorReduceImpl for bfloat16 in #50993.

* Fix unit tests.

---------
Co-authored-by: liuruyan <44316842+liuruyan@users.noreply.github.com>
Parent d12588d2
......@@ -986,14 +986,16 @@ template <typename Tx,
template <typename>
class ReduceOp,
typename TransformOp>
static typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value,
void>::type
CubTensorReduceImpl(const Tx* x_data,
Ty* y_data,
const TransformOp& transform,
int reduce_num,
const KPDevice& dev_ctx,
KPStream stream) {
static
typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value &&
!std::is_same<Tx, phi::dtype::bfloat16>::value,
void>::type
CubTensorReduceImpl(const Tx* x_data,
Ty* y_data,
const TransformOp& transform,
int reduce_num,
const KPDevice& dev_ctx,
KPStream stream) {
auto reducer = ReduceOp<Ty>();
cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
transform);
......@@ -1037,6 +1039,23 @@ CubTensorReduceImpl(const Tx* x_data,
PADDLE_THROW(phi::errors::InvalidArgument(
"Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
}
template <typename Tx,
typename Ty,
template <typename>
class ReduceOp,
typename TransformOp>
static typename std::enable_if<std::is_same<Tx, phi::dtype::bfloat16>::value,
void>::type
CubTensorReduceImpl(const Tx* x_data,
Ty* y_data,
const TransformOp& transform,
int reduce_num,
const KPDevice& dev_ctx,
KPStream stream) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Tx should not be bfloat16 when using cub::DeviceReduce::Reduce()."));
}
#endif // PADDLE_WITH_XPU_KP
template <typename Tx,
......@@ -1081,7 +1100,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
config.SetOutputData(y_data, dev_ctx, &tmp);
constexpr bool kIsTxFP16 = std::is_same<Tx, phi::dtype::float16>::value;
bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16;
constexpr bool kIsTxBF16 = std::is_same<Tx, phi::dtype::bfloat16>::value;
bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16;
#ifndef PADDLE_WITH_XPU_KP
if (use_cub_reduce) {
if (is_mean) {
......
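The two hunks above keep float16 and bfloat16 inputs off the cub::DeviceReduce fast path, whose accumulator is the input type Tx, and route them through ReduceKernel, which accumulates in MPType (float). A minimal NumPy sketch of the failure mode this avoids, assuming a serial accumulation in the storage dtype; sizes and values are illustrative only:

import numpy as np

x = np.full(10_000, 0.01, dtype=np.float16)   # exact sum is about 100

acc16 = np.float16(0.0)
for v in x:                       # accumulate in the storage dtype, as a reduce templated on Tx would
    acc16 = np.float16(acc16 + v)

acc32 = np.float32(0.0)
for v in x:                       # accumulate in float32, as the MPType-based ReduceKernel path does
    acc32 += np.float32(v)

print(acc16)   # stalls near 32.0 once each increment rounds away in fp16
print(acc32)   # ~100.0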
......@@ -14,10 +14,10 @@
#include "paddle/phi/kernels/add_n_kernel.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"
namespace phi {
......@@ -38,16 +38,18 @@ __global__ void Sum2CUDAKernel(const T *in_0,
template <class T>
__global__ void SumArrayCUDAKernel(
T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
int id = blockIdx.x * blockDim.x + threadIdx.x;
while (id < N) {
T total(read_dst ? out[id] : static_cast<T>(0));
MPType total(read_dst ? static_cast<MPType>(out[id])
: static_cast<MPType>(0));
for (int i = 0; i < in_size; ++i) {
const T *tmp = in[i];
if (tmp) {
total += tmp[id];
total += static_cast<MPType>(tmp[id]);
}
}
out[id] = total;
out[id] = static_cast<T>(total);
id += blockDim.x * gridDim.x;
}
}
......@@ -116,11 +118,12 @@ void AddNKernel(const Context &dev_ctx,
int64_t length_0 = in_0.numel();
int64_t length_1 = in_1.numel();
if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
auto in_0_e = EigenVector<T>::Flatten(in_0);
auto in_1_e = EigenVector<T>::Flatten(in_1);
result.device(place) = in_0_e + in_1_e;
auto in_0_e = EigenVector<T>::Flatten(in_0).template cast<MPType>();
auto in_1_e = EigenVector<T>::Flatten(in_1).template cast<MPType>();
result.device(place) = (in_0_e + in_1_e).template cast<T>();
} else if (length_0 && in_0.IsInitialized()) {
auto result = EigenVector<T>::Flatten(*out);
auto &place = *dev_ctx.eigen_device();
......
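The add_n change applies the same idea per element: each output value is accumulated across all inputs in MPType and cast back to T only once at the end. A hand-rolled NumPy reference of that behaviour, not the CUDA kernel itself:

import numpy as np

def add_n_mixed(tensors):
    # Accumulate every element in float32 (the MPType for float16/bfloat16),
    # then cast back to the storage dtype once at the end.
    acc = np.zeros_like(tensors[0], dtype=np.float32)
    for t in tensors:
        acc += t.astype(np.float32)
    return acc.astype(tensors[0].dtype)

xs = [np.random.random((4, 8)).astype(np.float16) for _ in range(64)]
ref = sum(t.astype(np.float64) for t in xs)   # high-precision reference
print(np.abs(add_n_mixed(xs).astype(np.float64) - ref).max())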
......@@ -15,28 +15,30 @@ limitations under the License. */
#include "paddle/phi/kernels/scale_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
namespace phi {
template <typename InT>
template <typename DataT, typename ParamT>
struct ScaleFunctor {
InT bias;
InT scale;
ParamT bias;
ParamT scale;
bool bias_after_scale;
ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle)
ScaleFunctor(ParamT scale_data, ParamT bias_data, bool is_bias_after_sacle)
: bias(bias_data),
scale(scale_data),
bias_after_scale(is_bias_after_sacle) {}
__device__ __forceinline__ InT operator()(const InT x) const {
__device__ __forceinline__ DataT operator()(const DataT x) const {
if (bias_after_scale) {
return scale * x + bias;
return static_cast<DataT>(scale * static_cast<ParamT>(x) + bias);
} else {
return scale * (x + bias);
return static_cast<DataT>(scale * (static_cast<ParamT>(x) + bias));
}
}
};
......@@ -48,16 +50,21 @@ void ScaleKernel(const Context& dev_ctx,
float bias,
bool bias_after_scale,
DenseTensor* out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
std::vector<const DenseTensor*> inputs;
std::vector<DenseTensor*> outputs;
inputs.emplace_back(&x);
outputs.emplace_back(out);
dev_ctx.template Alloc<T>(out);
if (x.numel() <= 0 || (!x.IsInitialized())) {
return;
}
phi::funcs::ElementwiseKernel<T>(
dev_ctx,
inputs,
&outputs,
ScaleFunctor<T>(scale.to<T>(), static_cast<T>(bias), bias_after_scale));
ScaleFunctor<T, MT>(
scale.to<MT>(), static_cast<MT>(bias), bias_after_scale));
}
} // namespace phi
......
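ScaleFunctor now carries scale and bias in ParamT (float for fp16/bf16 data) and promotes each element before the multiply-add. A NumPy sketch of the functor's arithmetic; the function name and inputs are illustrative:

import numpy as np

def scale_mixed(x, scale, bias, bias_after_scale=True):
    # x is stored in a low-precision dtype; the arithmetic runs in float32.
    xf = x.astype(np.float32)
    y = scale * xf + bias if bias_after_scale else scale * (xf + bias)
    return y.astype(x.dtype)

x = np.random.random(16).astype(np.float16)
print(scale_mixed(x, scale=1e-4, bias=1e-4))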
......@@ -15,12 +15,47 @@
#include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
namespace phi {
/**
* x*y*2.0
*/
template <typename T>
struct DoubleMulFunctor {
__device__ __forceinline__ T operator()(const T a, const T b) const {
return b * a * static_cast<T>(2.0f);
}
};
template <typename T, typename Context>
void SquaredL2NormGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& dout,
DenseTensor* dx) {
dev_ctx.template Alloc<T>(dx);
PADDLE_ENFORCE_EQ(
dout.numel(),
1,
phi::errors::InvalidArgument(
"Input(GRAD@Out) of SquaredL2NormGradOP should be a scalar."));
std::vector<const DenseTensor*> ins{&x, &dout};
std::vector<DenseTensor*> outs{dx};
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor<T>());
}
} // namespace phi
PD_REGISTER_KERNEL(squared_l2_norm_grad,
GPU,
ALL_LAYOUT,
phi::SquaredL2NormGradKernel,
float,
double) {}
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
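For y = sum(x ** 2), the gradient is dx = 2 * x * dout, which is what DoubleMulFunctor computes with the scalar dout broadcast over x. A NumPy check of that identity against a central difference; the step size is arbitrary:

import numpy as np

x = np.random.random((3, 4)).astype(np.float64)
dout = np.float64(1.0)          # upstream gradient of the scalar output

analytic = 2.0 * x * dout       # what the new grad kernel produces

eps = 1e-6                      # central-difference check of d(sum(x**2)) / dx[0, 0]
xp, xm = x.copy(), x.copy()
xp[0, 0] += eps
xm[0, 0] -= eps
numeric = ((xp ** 2).sum() - (xm ** 2).sum()) / (2 * eps)
print(analytic[0, 0], numeric)  # should agree to roughly 1e-6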
......@@ -15,9 +15,34 @@
#include "paddle/phi/kernels/squared_l2_norm_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
PD_REGISTER_KERNEL(
squared_l2_norm, GPU, ALL_LAYOUT, phi::SquaredL2NormKernel, float, double) {
namespace phi {
template <typename T, typename Context>
void SquaredL2NormKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
std::vector<int> origin_reduce_dims;
for (size_t i = 0; i < x.dims().size(); i++) {
origin_reduce_dims.push_back(i);
}
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::SquareFunctor<T, T>>(
dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims, false);
}
} // namespace phi
PD_REGISTER_KERNEL(squared_l2_norm,
GPU,
ALL_LAYOUT,
phi::SquaredL2NormKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
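The forward kernel is now expressed as a full reduction of SquareFunctor over every axis, so the new fp16/bf16 registrations inherit the MPType accumulation of ReduceKernel. A one-line NumPy equivalent of that reduction, with float32 accumulation made explicit:

import numpy as np

x = np.random.random((2, 3, 4)).astype(np.float16)
out = np.square(x.astype(np.float32)).sum()   # square, then add-reduce over all axes in float32
print(out.astype(np.float16))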
......@@ -52,6 +52,12 @@ class MPTypeTrait<phi::dtype::float16> {
using Type = float;
};
template <>
class MPTypeTrait<phi::dtype::bfloat16> {
public:
using Type = float;
};
/**
* @brief Will be used in BlockYReduce, get the index of reduce_num in shared
* memory.
......
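MPTypeTrait chooses the dtype used for intermediate arithmetic; the new specialization maps bfloat16 to float, matching the existing float16 case. A plain-Python sketch of that mapping, illustrative only and not a Paddle API:

# storage dtype -> compute ("multi-precision") dtype used for accumulation
MP_TYPE = {
    'float16': 'float32',
    'bfloat16': 'float32',   # the specialization added above
    'float32': 'float32',
    'float64': 'float64',
}
print(MP_TYPE['bfloat16'])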
......@@ -32,8 +32,11 @@ from .framework import default_main_program
from paddle import _C_ops, _legacy_C_ops
__all__ = [
'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue',
'ClipGradByNorm', 'ClipGradByGlobalNorm'
'set_gradient_clip',
'ErrorClipByValue',
'ClipGradByValue',
'ClipGradByNorm',
'ClipGradByGlobalNorm',
]
_clip_by_global_norm_using_mp_type_flag = False
......@@ -52,9 +55,10 @@ def _clip_by_global_norm_using_mp_type(*args):
def _cast_to_mp_type_if_enabled(x):
if (x.dtype == core.VarDesc.VarType.FP16
or x.dtype == core.VarDesc.VarType.BF16
) and _clip_by_global_norm_using_mp_type():
if (
x.dtype == core.VarDesc.VarType.FP16
or x.dtype == core.VarDesc.VarType.BF16
) and _clip_by_global_norm_using_mp_type():
return x.astype(core.VarDesc.VarType.FP32)
else:
return x
......@@ -66,8 +70,7 @@ def _squared_l2_norm(x):
"""
x = _cast_to_mp_type_if_enabled(x)
if core.is_compiled_with_xpu(
) or x.dtype == core.VarDesc.VarType.FP16 or x.dtype == core.VarDesc.VarType.BF16:
if core.is_compiled_with_xpu():
square = layers.square(x)
sum_square = layers.reduce_sum(square)
return sum_square
......@@ -78,7 +81,9 @@ def _squared_l2_norm(x):
return _legacy_C_ops.squared_l2_norm(x)
op_type = 'squared_l2_norm'
check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
check_variable_and_dtype(
x, 'x', ['float32', 'float64', 'float16', 'uint16'], op_type
)
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......@@ -89,7 +94,6 @@ def _squared_l2_norm(x):
class BaseErrorClipAttr(object):
def __str__(self):
raise NotImplementedError()
......@@ -164,8 +168,9 @@ def error_clip_callback(block, context):
for grad_n in [n for n in op_desc.output_arg_names() if n in grad_to_var]:
fwd_var = block._var_recursive(grad_to_var[grad_n])
error_clip = getattr(fwd_var, "error_clip", None)
if not (error_clip is None
or isinstance(error_clip, BaseErrorClipAttr)):
if not (
error_clip is None or isinstance(error_clip, BaseErrorClipAttr)
):
raise TypeError(
"Variable's error_clip should be an instance of BaseErrorClipAttr or None."
)
......@@ -174,7 +179,6 @@ def error_clip_callback(block, context):
class ClipGradBase(object):
def __init__(self):
super(ClipGradBase, self).__init__()
......@@ -197,7 +201,8 @@ class ClipGradBase(object):
warnings.warn(
"'set_gradient_clip' will be ineffective, because you have "
"set 'need_clip' in 'ParamAttr'. So, 'set_gradient_clip' "
"is redundant and you can remove it.")
"is redundant and you can remove it."
)
break
return self._static_clip(params_grads)
......@@ -211,34 +216,34 @@ class ClipGradBase(object):
class ClipGradByValue(ClipGradBase):
"""
Limit the value of multi-dimensional Tensor :math:`X` to the range [min, max].
- Any values less than min are set to ``min``.
- Any values greater than max are set to ``max``.
The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.
Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
(for example: :ref:`api_paddle_optimizer_SGD`).
Note:
``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0.
``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0.
Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
Args:
max (float): The maximum value to clip by.
min (float, optional): The minimum value to clip by. If not set by user, it will be set to ``-max``
min (float, optional): The minimum value to clip by. If not set by user, it will be set to ``-max``
automatically. In this case, ``max`` must be greater than 0.
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(x)
loss = paddle.mean(out)
......@@ -252,7 +257,7 @@ class ClipGradByValue(ClipGradBase):
def __init__(self, max, min=None):
super(ClipGradByValue, self).__init__()
if min is None:
assert (max > 0.0)
assert max > 0.0
min = -max
self.max = float(max)
self.min = float(min)
......@@ -417,17 +422,17 @@ def _allow_pure_fp16_global_norm_clip(*args):
class ClipGradByGlobalNorm(ClipGradBase):
r"""
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
:math:`t\_list` , and limit it to ``clip_norm`` .
- If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.
- If the global norm is less than or equal to ``clip_norm`` , nothing will be done.
The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.
Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
Gradient clip will take effect after being set in ``optimizer`` , see the document ``optimizer``
(for example: :ref:`api_paddle_optimizer_SGD`).
The clipping formula is:
......@@ -443,7 +448,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}
Note:
``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
``need_clip`` of ``ClipGradByGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
Please use ``need_clip`` in ``ParamAttr`` to specify the clip scope.
Args:
......@@ -452,12 +457,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
Examples:
.. code-block:: python
import paddle
x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
linear = paddle.nn.Linear(in_features=10, out_features=10,
weight_attr=paddle.ParamAttr(need_clip=True),
bias_attr=paddle.ParamAttr(need_clip=False))
out = linear(x)
loss = paddle.mean(out)
......@@ -468,10 +473,9 @@ class ClipGradByGlobalNorm(ClipGradBase):
sdg.step()
"""
def __init__(self,
clip_norm,
group_name="default_group",
auto_skip_clip=False):
def __init__(
self, clip_norm, group_name="default_group", auto_skip_clip=False
):
super(ClipGradByGlobalNorm, self).__init__()
self.clip_norm = float(clip_norm)
self.group_name = group_name
......@@ -503,7 +507,10 @@ class ClipGradByGlobalNorm(ClipGradBase):
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16 or sum_square.dtype == core.VarDesc.VarType.BF16:
if (
sum_square.dtype == core.VarDesc.VarType.FP16
or sum_square.dtype == core.VarDesc.VarType.BF16
):
sum_square_list_fp16.append(sum_square)
elif sum_square.dtype == core.VarDesc.VarType.FP32:
sum_square_list_fp32.append(sum_square)
......@@ -511,8 +518,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
sum_square_list.append(sum_square)
# all parameters have been filtered out
if len(sum_square_list) + len(sum_square_list_fp16) + len(
sum_square_list_fp32) == 0:
if (
len(sum_square_list)
+ len(sum_square_list_fp16)
+ len(sum_square_list_fp32)
== 0
):
return params_grads
sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
......@@ -531,22 +542,23 @@ class ClipGradByGlobalNorm(ClipGradBase):
global_norm_var.append(global_norm_var_fp64)
global_norm_var = paddle.add_n(global_norm_var)
global_norm_var = layers.sqrt(global_norm_var)
max_global_norm = layers.fill_constant(shape=[1],
dtype=global_norm_var.dtype,
value=self.clip_norm)
max_global_norm = layers.fill_constant(
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
)
need_clip = False
if not self.auto_skip_clip: # always apply clip
need_clip = True
clip_var = layers.elementwise_div(x=max_global_norm,
y=layers.elementwise_max(
x=global_norm_var,
y=max_global_norm))
clip_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(x=global_norm_var, y=max_global_norm),
)
elif global_norm_var > max_global_norm:
# only when global_norm_var > max_global_norm, grad need clip
need_clip = True
clip_var = layers.elementwise_div(x=max_global_norm,
y=global_norm_var)
clip_var = layers.elementwise_div(
x=max_global_norm, y=global_norm_var
)
for p, g in params_grads:
if g is None:
......@@ -556,8 +568,11 @@ class ClipGradByGlobalNorm(ClipGradBase):
continue
# TODO(wangxi): use inplace elementwise_mul
if need_clip:
clip_input = (clip_var.astype(g.dtype)
if clip_var.dtype != g.dtype else clip_var)
clip_input = (
clip_var.astype(g.dtype)
if clip_var.dtype != g.dtype
else clip_var
)
new_grad = layers.elementwise_mul(g, clip_input)
params_and_grads.append((p, new_grad))
else:
......@@ -581,7 +596,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(
merge_grad)
merge_grad
)
sum_square = _squared_l2_norm(merge_grad)
if sum_square.dtype == core.VarDesc.VarType.FP16:
sum_square_list_fp16.append(sum_square)
......@@ -591,8 +607,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
sum_square_list.append(sum_square)
# all parameters have been filtered out
if len(sum_square_list) + len(sum_square_list_fp16) + len(
sum_square_list_fp32) == 0:
if (
len(sum_square_list)
+ len(sum_square_list_fp16)
+ len(sum_square_list_fp32)
== 0
):
return params_grads
with p.block.program._optimized_guard([p, g]):
......@@ -601,10 +621,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
global_norm_var = []
if len(sum_square_list_fp16) > 0:
global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
if sum_square_list_fp32 or sum_square_list or not _allow_pure_fp16_global_norm_clip(
if (
sum_square_list_fp32
or sum_square_list
or not _allow_pure_fp16_global_norm_clip()
):
global_norm_var.append(
global_norm_var_fp16.astype(sum_dtype))
global_norm_var_fp16.astype(sum_dtype)
)
else:
global_norm_var.append(global_norm_var_fp16)
if len(sum_square_list_fp32) > 0:
......@@ -613,23 +637,28 @@ class ClipGradByGlobalNorm(ClipGradBase):
global_norm_var.append(global_norm_var_fp32)
else:
global_norm_var.append(
global_norm_var_fp32.astype(sum_dtype))
global_norm_var_fp32.astype(sum_dtype)
)
if len(sum_square_list) > 0:
# fp64
global_norm_var_other_dtype = layers.sums(sum_square_list)
global_norm_var.append(global_norm_var_other_dtype)
global_norm_var = layers.sums(global_norm_var) if len(
global_norm_var) > 1 else global_norm_var[0]
global_norm_var = (
layers.sums(global_norm_var)
if len(global_norm_var) > 1
else global_norm_var[0]
)
global_norm_var = layers.sqrt(x=global_norm_var)
max_global_norm = layers.fill_constant(
shape=[1],
dtype=global_norm_var.dtype,
value=self.clip_norm)
scale_var = layers.elementwise_div(x=max_global_norm,
y=layers.elementwise_max(
x=max_global_norm,
y=global_norm_var))
shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
)
scale_var = layers.elementwise_div(
x=max_global_norm,
y=layers.elementwise_max(
x=max_global_norm, y=global_norm_var
),
)
param_new_grad_name_dict = dict()
for p, g in params_grads:
if g is None:
......@@ -641,29 +670,32 @@ class ClipGradByGlobalNorm(ClipGradBase):
with p.block.program._optimized_guard([p, g]):
new_g = _cast_to_mp_type_if_enabled(g)
# inplace
scale_input = (scale_var.astype('float16') if
new_g.dtype == core.VarDesc.VarType.FP16 and
scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var)
scale_input = (
scale_var.astype('float16')
if new_g.dtype == core.VarDesc.VarType.FP16
and scale_var.dtype != core.VarDesc.VarType.FP16
else scale_var
)
# NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
# will be in different blocks with the gradient clip related ops.
# We need to handle the correct block, otherwise will encounter
# a 'NotFoundError' during compile time.
block = default_main_program().current_block()
block.append_op(type='elementwise_mul',
inputs={
'X': new_g,
'Y': scale_input
},
outputs={'Out': new_g})
block.append_op(
type='elementwise_mul',
inputs={'X': new_g, 'Y': scale_input},
outputs={'Out': new_g},
)
if new_g is not g:
block.append_op(type='cast',
inputs={'X': new_g},
outputs={'Out': g},
attrs={
'in_dtype': new_g.dtype,
'out_dtype': g.dtype
})
block.append_op(
type='cast',
inputs={'X': new_g},
outputs={'Out': g},
attrs={
'in_dtype': new_g.dtype,
'out_dtype': g.dtype,
},
)
param_new_grad_name_dict[p.name] = g.name
params_and_grads.append((p, g))
......@@ -676,7 +708,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
context[self.group_name] = []
context[self.group_name + "_clip_value"] = self.clip_norm
context[self.group_name + "_clip"] = layers.fill_constant(
shape=[1], dtype=grad.dtype, value=self.clip_norm)
shape=[1], dtype=grad.dtype, value=self.clip_norm
)
else:
if not self.clip_norm == context[self.group_name + "_clip_value"]:
raise ValueError(
......@@ -699,20 +732,19 @@ class ClipGradByGlobalNorm(ClipGradBase):
group_norm_var = layers.sums(input=self.context[self.group_name])
group_norm_var = layers.sqrt(x=group_norm_var)
clip_var = self.context[self.group_name + "_clip"]
group_scale_var = layers.elementwise_div(x=clip_var,
y=layers.elementwise_max(
x=clip_var,
y=group_norm_var))
assert group_scale_var.shape == (1, )
group_scale_var = layers.elementwise_div(
x=clip_var,
y=layers.elementwise_max(x=clip_var, y=group_norm_var),
)
assert group_scale_var.shape == (1,)
self.context[group_scale_name] = group_scale_var
# inplace
param.block.append_op(type='elementwise_mul',
inputs={
'X': grad,
'Y': self.context[group_scale_name]
},
outputs={'Out': grad})
param.block.append_op(
type='elementwise_mul',
inputs={'X': grad, 'Y': self.context[group_scale_name]},
outputs={'Out': grad},
)
return param, grad
......@@ -721,23 +753,23 @@ class ClipGradByGlobalNorm(ClipGradBase):
def set_gradient_clip(clip, param_list=None, program=None):
"""
:api_attr: Static Graph
Warning:
This API must be used after building network, and before ``minimize`` ,
and it may be removed in future releases, so it is not recommended.
This API must be used after building network, and before ``minimize`` ,
and it may be removed in future releases, so it is not recommended.
It is recommended to set ``grad_clip`` when initializing the ``optimizer`` ,
this is a better method to clip gradient. There are three clipping strategies:
:ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` .
To specify parameters that require gradient clip.
Args:
grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default value: None, and there is no
grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of
some derived class of ``GradientClipBase`` . There are three clipping strategies
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
:ref:`api_fluid_clip_GradientClipByValue` ). Default value: None, and there is no
gradient clipping.
param_list (list(Variable), optional): Parameters that require gradient clip.
It can be a list of parameter or a list of parameter's name.
......@@ -791,7 +823,7 @@ def set_gradient_clip(clip, param_list=None, program=None):
param_list=[param_var1, param_var2])
sgd = fluid.optimizer.SGD(learning_rate=1e-3)
sgd.minimize(loss)
# network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
with fluid.program_guard(fluid.Program(), fluid.Program()):
loss = network()
......@@ -802,27 +834,31 @@ def set_gradient_clip(clip, param_list=None, program=None):
# Set the gradient clipping strategy: clip2
sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
sgd.minimize(loss)
# 'set_gradient_clip' will not take effect when the settings conflict,
# 'set_gradient_clip' will not take effect when the settings conflict,
# and the gradient clipping strategy will be 'clip2'
"""
warnings.warn("Caution! 'set_gradient_clip' is not recommended "
"and may be deprecated in future! "
"We recommend a new strategy: set 'grad_clip' "
"when initializing the 'optimizer'. "
"This method can reduce the mistakes, please "
"refer to documention of 'optimizer'.")
warnings.warn(
"Caution! 'set_gradient_clip' is not recommended "
"and may be deprecated in future! "
"We recommend a new strategy: set 'grad_clip' "
"when initializing the 'optimizer'. "
"This method can reduce the mistakes, please "
"refer to documention of 'optimizer'."
)
if not isinstance(clip, ClipGradBase):
raise TypeError(
"'clip' should be an instance of ClipGradBase's derived class")
"'clip' should be an instance of ClipGradBase's derived class"
)
if program is None:
program = framework.default_main_program()
for op in program.block(0).ops:
if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr(
"op_namescope"):
"op_namescope"
):
warnings.warn(
"'minimize' has been invoked before, this will make 'set_gradient_clip' "
"be ineffective! Please invoke 'set_gradient_clip' before 'minimize'."
......@@ -847,14 +883,16 @@ def append_gradient_clip_ops(param_grads):
for p, g in param_grads:
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('gradient_clip'):
with p.block.program._optimized_guard([p, g]), framework.name_scope(
'gradient_clip'
):
clip_attr = getattr(p, 'gradient_clip_attr', None)
if clip_attr is None:
return param_grads
if not isinstance(clip_attr, ClipGradBase):
raise TypeError(
"clip attribute should be an instance of GradientClipBase")
"clip attribute should be an instance of GradientClipBase"
)
clip_attr._process_context(context=context, param=p, grad=g)
......@@ -863,8 +901,9 @@ def append_gradient_clip_ops(param_grads):
for p, g in param_grads:
if g is None:
continue
with p.block.program._optimized_guard(
[p, g]), framework.name_scope('gradient_clip'):
with p.block.program._optimized_guard([p, g]), framework.name_scope(
'gradient_clip'
):
param, new_grad = clip_attr._create_operators(param=p, grad=g)
param_new_grad_name_dict[param.name] = new_grad.name
res.append([param, new_grad])
......@@ -888,12 +927,16 @@ def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
continue
block_id_list.append(block_id)
for op in param.block.program.global_block().ops:
if op.has_attr("op_namescope") and "gradient_clip" in op.attr(
"op_namescope") and op.attr('op_role_var'):
if (
op.has_attr("op_namescope")
and "gradient_clip" in op.attr("op_namescope")
and op.attr('op_role_var')
):
param_name = op.attr('op_role_var')[0]
if param_name in param_new_grad_name_dict:
correct_p_g = [
param_name, param_new_grad_name_dict[param_name]
param_name,
param_new_grad_name_dict[param_name],
]
op._set_attr('op_role_var', correct_p_g)
......
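The clip.py changes are mostly reformatting, plus letting fp16/bf16 gradients go through the squared_l2_norm kernel directly instead of square + reduce_sum. The clipping math itself is unchanged; a NumPy sketch of ClipGradByGlobalNorm's formula, with illustrative shapes:

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    # global_norm = sqrt(sum_i ||g_i||^2); every gradient is scaled by
    # clip_norm / max(global_norm, clip_norm), so gradients only shrink, never grow.
    global_norm = np.sqrt(sum(np.sum(np.square(g.astype(np.float32))) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [(g.astype(np.float32) * scale).astype(g.dtype) for g in grads]

grads = [np.random.randn(4, 4).astype(np.float16), np.random.randn(8).astype(np.float32)]
clipped = clip_by_global_norm(grads, clip_norm=0.2)
print(np.sqrt(sum(np.sum(np.square(g.astype(np.float64))) for g in clipped)))   # <= ~0.2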
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
class TestAddnOp(unittest.TestCase):
def setUp(self):
np.random.seed(20)
l = 32
self.x_np = np.random.random([l, 16, 256])
def check_main(self, x_np, dtype, axis=None):
paddle.disable_static()
x = []
for i in range(x_np.shape[0]):
val = paddle.to_tensor(x_np[i].astype(dtype))
val.stop_gradient = False
x.append(val)
y = paddle.add_n(x)
x_g = paddle.grad(y, x)
y_np = y.numpy().astype('float32')
x_g_np = []
for val in x_g:
x_g_np.append(val.numpy().astype('float32'))
paddle.enable_static()
return y_np, x_g_np
def test_add_n_fp16(self):
if not paddle.is_compiled_with_cuda():
return
y_np_16, x_g_np_16 = self.check_main(self.x_np, 'float16')
y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32')
np.testing.assert_allclose(y_np_16, y_np_32, rtol=1e-03)
for i in range(len(x_g_np_32)):
np.testing.assert_allclose(x_g_np_16[i], x_g_np_32[i], rtol=1e-03)
def test_add_n_api(self):
if not paddle.is_compiled_with_cuda():
return
y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32')
y_np_gt = np.sum(self.x_np, axis=0).astype('float32')
np.testing.assert_allclose(y_np_32, y_np_gt, rtol=1e-06)
if __name__ == "__main__":
unittest.main()
......@@ -26,21 +26,17 @@ from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip
paddle.enable_static()
def bow_net(data,
label,
dict_dim,
emb_dim=128,
hid_dim=128,
hid_dim2=96,
class_dim=2):
def bow_net(
data, label, dict_dim, emb_dim=128, hid_dim=128, hid_dim2=96, class_dim=2
):
"""
BOW net
This model is from https://github.com/PaddlePaddle/models:
fluid/PaddleNLP/text_classification/nets.py
"""
emb = fluid.layers.embedding(input=data,
is_sparse=True,
size=[dict_dim, emb_dim])
emb = fluid.layers.embedding(
input=data, is_sparse=True, size=[dict_dim, emb_dim]
)
bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
bow_tanh = fluid.layers.tanh(bow)
fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
......@@ -53,7 +49,6 @@ def bow_net(data,
class TestGradientClip(unittest.TestCase):
def setUp(self):
self.word_dict_len = 5147
self.BATCH_SIZE = 2
......@@ -77,8 +72,9 @@ class TestGradientClip(unittest.TestCase):
def check_gradient_clip(self, place, dtype='float32'):
prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program=prog,
startup_program=startup_program):
with fluid.program_guard(
main_program=prog, startup_program=startup_program
):
image = fluid.data(name="a", shape=[-1, 784], dtype='float32')
label = fluid.data(name="b", shape=[-1, 1], dtype='int64')
if dtype != 'float32':
......@@ -99,8 +95,9 @@ class TestGradientClip(unittest.TestCase):
p_g = sorted(p_g, key=lambda x: x[0].name)
p_g_clip = sorted(p_g_clip, key=lambda x: x[0].name)
with fluid.program_guard(main_program=prog_clip,
startup_program=startup_program):
with fluid.program_guard(
main_program=prog_clip, startup_program=startup_program
):
p_g_clip = self.clip_gradient(p_g_clip)
grad_list = [elem[1] for elem in p_g]
......@@ -113,20 +110,20 @@ class TestGradientClip(unittest.TestCase):
data = next(train_reader())
out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
out_clip = exe.run(prog_clip,
feed=feeder.feed(data),
fetch_list=grad_clip_list)
out_clip = exe.run(
prog_clip, feed=feeder.feed(data), fetch_list=grad_clip_list
)
self.check_clip_result(out, out_clip)
def check_sparse_gradient_clip(self, place):
prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program=prog,
startup_program=startup_program):
data = fluid.data(name="words",
shape=[-1, 1],
dtype="int64",
lod_level=1)
with fluid.program_guard(
main_program=prog, startup_program=startup_program
):
data = fluid.data(
name="words", shape=[-1, 1], dtype="int64", lod_level=1
)
label = fluid.data(name="label", shape=[-1, 1], dtype="int64")
cost = bow_net(data, label, self.word_dict_len)
......@@ -138,7 +135,7 @@ class TestGradientClip(unittest.TestCase):
data = next(self.train_data())
val = exe.run(prog, feed=feeder.feed(data), fetch_list=[cost])[0]
self.assertEqual((1, ), val.shape)
self.assertEqual((1,), val.shape)
self.assertFalse(np.isnan(val))
def backward_and_optimize(self, cost):
......@@ -146,7 +143,6 @@ class TestGradientClip(unittest.TestCase):
class TestGradientClipByGlobalNorm(TestGradientClip):
def init(self):
self.clip_norm = 0.2
......@@ -166,13 +162,13 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
v,
rtol=1e-05,
atol=1e-08,
err_msg=
'gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}'
.format(u, v, u - v))
err_msg='gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}'.format(
u, v, u - v
),
)
# test whether the output is right when use 'set_gradient_clip'
def test_old_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
fluid.clip.set_gradient_clip(clip)
......@@ -183,7 +179,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
# test whether the output is right when use grad_clip
def test_new_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
return clip(params_grads)
......@@ -193,7 +188,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
# test whether the output is right when use grad_clip under float64
def test_new_gradient_clip_fp64(self):
def func(params_grads):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
return clip(params_grads)
......@@ -203,12 +197,12 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
# invoke 'set_gradient_clip' in a wrong order
def test_wrong_API_order(self):
def backward_func(cost):
clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
fluid.clip.set_gradient_clip(clip)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01,
grad_clip=clip)
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.01, grad_clip=clip
)
# if 'set_gradient_clip' and 'optimize(grad_clip)' are used together, 'set_gradient_clip' will be ineffective
sgd_optimizer.minimize(cost)
# 'set_gradient_clip' must come before 'minimize'; otherwise, 'set_gradient_clip' will be ineffective
......@@ -222,43 +216,72 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
def test_tpyeError(self):
# the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class
with self.assertRaises(TypeError):
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1,
grad_clip="test")
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.1, grad_clip="test"
)
# if grad is None or not need clip
def test_none_grad_fp32(self):
ops = self._test_none_grad_helper("float32")
self.assertListEqual(ops, [
'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt',
'fill_constant', 'elementwise_max', 'elementwise_div',
'elementwise_mul', 'elementwise_mul'
])
self.assertListEqual(
ops,
[
'squared_l2_norm',
'squared_l2_norm',
'sum',
'sqrt',
'fill_constant',
'elementwise_max',
'elementwise_div',
'elementwise_mul',
'elementwise_mul',
],
)
def test_none_grad_fp16(self):
ops = self._test_none_grad_helper("float16")
self.assertListEqual(ops, [
'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', 'cast',
'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div',
'cast', 'elementwise_mul', 'cast', 'elementwise_mul'
])
self.assertListEqual(
ops,
[
'squared_l2_norm',
'squared_l2_norm',
'sum',
'cast',
'sqrt',
'fill_constant',
'elementwise_max',
'elementwise_div',
'cast',
'elementwise_mul',
'cast',
'elementwise_mul',
],
)
def _test_none_grad_helper(self, dtype):
prog = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program=prog,
startup_program=startup_program):
with fluid.program_guard(
main_program=prog, startup_program=startup_program
):
clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype=dtype)
y = fluid.default_main_program().global_block().create_parameter(
name="y", shape=[2, 3], dtype=dtype)
x = (
fluid.default_main_program()
.global_block()
.create_parameter(name="x", shape=[2, 3], dtype=dtype)
)
y = (
fluid.default_main_program()
.global_block()
.create_parameter(name="y", shape=[2, 3], dtype=dtype)
)
# (x, None) should not be returned
params_grads = [(x, None), (x, y), (y, x)]
params_grads = clip(params_grads)
self.assertTrue(
len(params_grads) == 2,
"ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!"
"ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!",
)
ops = [op.type for op in x.block.ops]
......@@ -266,7 +289,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
class TestGradientClipByNorm(TestGradientClip):
def init(self):
self.clip_norm = 0.2
......@@ -280,11 +302,11 @@ class TestGradientClipByNorm(TestGradientClip):
v,
rtol=1e-05,
atol=1e-08,
err_msg='gradient clip by norm has wrong results!')
err_msg='gradient clip by norm has wrong results!',
)
# test whether the output is right when use grad_clip
def test_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)
return clip(params_grads)
......@@ -295,25 +317,35 @@ class TestGradientClipByNorm(TestGradientClip):
# if grad is None or not need clip
def test_none_grad(self):
clip = fluid.clip.GradientClipByNorm(self.clip_norm)
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32", need_clip=False)
y = fluid.default_main_program().global_block().create_parameter(
name="y", shape=[2, 3], dtype="float32", need_clip=False)
x = (
fluid.default_main_program()
.global_block()
.create_parameter(
name="x", shape=[2, 3], dtype="float32", need_clip=False
)
)
y = (
fluid.default_main_program()
.global_block()
.create_parameter(
name="y", shape=[2, 3], dtype="float32", need_clip=False
)
)
# (x, None) should not be returned
params_grads = [(x, None), (x, y)]
params_grads = clip(params_grads)
self.assertTrue(
len(clip(params_grads)) == 1,
"ClipGradByNorm: when grad is None, it shouldn't be returned by gradient clip!"
"ClipGradByNorm: when grad is None, it shouldn't be returned by gradient clip!",
)
self.assertTrue(
params_grads[0][1].name == 'y',
"ClipGradByNorm: grad should not be clipped when filtered out!")
"ClipGradByNorm: grad should not be clipped when filtered out!",
)
class TestGradientClipByValue(TestGradientClip):
def init(self):
self.max = 0.2
self.min = 0.1
......@@ -328,11 +360,11 @@ class TestGradientClipByValue(TestGradientClip):
v,
rtol=1e-06,
atol=1e-08,
err_msg='gradient clip by value has wrong results!')
err_msg='gradient clip by value has wrong results!',
)
# test whether the output is right when use grad_clip
def test_gradient_clip(self):
def func(params_grads):
clip = fluid.clip.GradientClipByValue(max=self.max, min=self.min)
return clip(params_grads)
......@@ -343,37 +375,49 @@ class TestGradientClipByValue(TestGradientClip):
# if grad is None or not need clip
def test_none_grad(self):
clip = fluid.clip.GradientClipByValue(self.max, self.min)
x = fluid.default_main_program().global_block().create_parameter(
name="x", shape=[2, 3], dtype="float32", need_clip=False)
y = fluid.default_main_program().global_block().create_parameter(
name="y", shape=[2, 3], dtype="float32", need_clip=False)
x = (
fluid.default_main_program()
.global_block()
.create_parameter(
name="x", shape=[2, 3], dtype="float32", need_clip=False
)
)
y = (
fluid.default_main_program()
.global_block()
.create_parameter(
name="y", shape=[2, 3], dtype="float32", need_clip=False
)
)
# (x, None) should not be returned
params_grads = [(x, None), (x, y)]
params_grads = clip(params_grads)
self.assertTrue(
len(clip(params_grads)) == 1,
"ClipGradByValue: when grad is None, it shouldn't be returned by gradient clip!"
"ClipGradByValue: when grad is None, it shouldn't be returned by gradient clip!",
)
self.assertTrue(
params_grads[0][1].name == 'y',
"ClipGradByValue: grad should not be clipped when filtered out!")
"ClipGradByValue: grad should not be clipped when filtered out!",
)
class TestDygraphGradientClip(unittest.TestCase):
def test_gradient_clip(self):
with fluid.dygraph.guard():
linear = fluid.dygraph.Linear(5, 5)
inputs = fluid.layers.uniform_random([16, 5], min=-10,
max=10).astype('float32')
inputs = fluid.layers.uniform_random(
[16, 5], min=-10, max=10
).astype('float32')
out = linear(fluid.dygraph.to_variable(inputs))
loss = fluid.layers.reduce_mean(out)
loss.backward()
sgd_optimizer = fluid.optimizer.SGD(
learning_rate=0.0,
parameter_list=linear.parameters(),
grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1))
grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1),
)
self.check_clip_result(loss, sgd_optimizer)
def check_clip_result(self, loss, optimizer):
......@@ -381,20 +425,23 @@ class TestDygraphGradientClip(unittest.TestCase):
class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
def setUp(self):
self.clip_norm = 0.8
self.clip1 = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm)
clip_norm=self.clip_norm
)
self.clip2 = fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.clip_norm)
clip_norm=self.clip_norm
)
def check_clip_result(self, loss, optimizer):
# if grad is None
x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32"),
name="x")
y = fluid.dygraph.to_variable(np.array([3, 4]).astype("float32"),
name="y")
x = fluid.dygraph.to_variable(
np.array([2, 3]).astype("float32"), name="x"
)
y = fluid.dygraph.to_variable(
np.array([3, 4]).astype("float32"), name="y"
)
assert len(self.clip1([(x, x), (x, y), (x, None)])) == 2
# get params and grads from network
opt, params_grads = optimizer.minimize(loss)
......@@ -419,11 +466,11 @@ class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
self.assertTrue(
np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8),
"gradient clip by global norm has wrong results, expetcd:%f, but received:%f"
% (a, b))
% (a, b),
)
class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
def setUp(self):
self.clip_norm = 0.8
self.clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)
......@@ -448,11 +495,11 @@ class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
self.assertTrue(
np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8),
"gradient clip by norm has wrong results, expetcd:%f, but received:%f"
% (a, b))
% (a, b),
)
class TestDygraphGradientClipByValue(TestDygraphGradientClip):
def setUp(self):
self.max = 0.2
self.min = 0.1
......@@ -475,11 +522,11 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip):
v,
rtol=1e-06,
atol=1e-08,
err_msg='gradient clip by value has wrong results!')
err_msg='gradient clip by value has wrong results!',
)
class SimpleNet(paddle.nn.Layer):
def __init__(self):
super(SimpleNet, self).__init__()
self.linear = paddle.nn.Linear(5, 5)
......@@ -492,19 +539,21 @@ class SimpleNet(paddle.nn.Layer):
class TestDygraphGradientClipFP16(unittest.TestCase):
def test_gradient_clip(self):
if fluid.core.is_compiled_with_cuda():
with fluid.dygraph.guard():
paddle.seed(10)
model = SimpleNet()
sgd_optimizer = paddle.optimizer.SGD(
learning_rate=0.0, parameters=model.parameters())
learning_rate=0.0, parameters=model.parameters()
)
model, sgd_optimizer = paddle.amp.decorate(
models=model, optimizers=sgd_optimizer, level='O2')
models=model, optimizers=sgd_optimizer, level='O2'
)
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
inputs = fluid.layers.uniform_random([1, 5], min=-10,
max=10).astype('float32')
inputs = fluid.layers.uniform_random(
[1, 5], min=-10, max=10
).astype('float32')
with paddle.amp.auto_cast(level='O2'):
out = model(fluid.dygraph.to_variable(inputs))
loss = fluid.layers.reduce_mean(out)
......@@ -543,15 +592,16 @@ class TestDygraphGradientClipFP16(unittest.TestCase):
self.assertTrue(
np.isclose(a=a, b=b, rtol=1e-3, atol=1e-8),
"gradient clip by global norm has wrong results, expetcd:%f, but received:%f"
% (a, b))
% (a, b),
)
class TestDygraphGradientClipFP64(unittest.TestCase):
def test_gradient_clip(self):
with fluid.dygraph.guard():
inputs = fluid.layers.uniform_random([16, 5], min=-10,
max=10).astype('float64')
inputs = fluid.layers.uniform_random(
[16, 5], min=-10, max=10
).astype('float64')
linear = fluid.dygraph.Linear(5, 5, dtype="float64")
out = linear(fluid.dygraph.to_variable(inputs))
loss = fluid.layers.reduce_mean(out)
......@@ -589,11 +639,11 @@ class TestDygraphGradientClipFP64(unittest.TestCase):
self.assertTrue(
np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8),
"gradient clip by global norm has wrong results, expetcd:%f, but received:%f"
% (a, b))
% (a, b),
)
class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase):
def check_main(self, expected_has_cast_op):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
......@@ -604,12 +654,12 @@ class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase):
param_and_grads = []
main_block = main_prog.global_block()
for name, shape in zip(names, shapes):
p = main_block.create_parameter(name=name,
shape=shape,
dtype='float16')
g = main_block.create_parameter(name=p.name + '@GRAD',
shape=p.shape,
dtype=p.dtype)
p = main_block.create_parameter(
name=name, shape=shape, dtype='float16'
)
g = main_block.create_parameter(
name=p.name + '@GRAD', shape=p.shape, dtype=p.dtype
)
param_and_grads.append((p, g))
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
......
......@@ -148,19 +148,21 @@ class AdamW(Optimizer):
_beta1_pow_acc_str = "beta1_pow_acc"
_beta2_pow_acc_str = "beta2_pow_acc"
def __init__(self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
parameters=None,
weight_decay=0.01,
lr_ratio=None,
apply_decay_param_fun=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
name=None):
def __init__(
self,
learning_rate=0.001,
beta1=0.9,
beta2=0.999,
epsilon=1e-8,
parameters=None,
weight_decay=0.01,
lr_ratio=None,
apply_decay_param_fun=None,
grad_clip=None,
lazy_mode=False,
multi_precision=False,
name=None,
):
assert learning_rate is not None
assert beta1 is not None
assert beta2 is not None
......@@ -171,14 +173,16 @@ class AdamW(Optimizer):
raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
if not 0 <= epsilon:
raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
if not isinstance(weight_decay, float) and \
not isinstance(weight_decay, framework.Variable):
if not isinstance(weight_decay, float) and not isinstance(
weight_decay, framework.Variable
):
raise TypeError("weight_decay should be float or Tensor.")
if lr_ratio is not None:
assert isinstance(lr_ratio, Callable)
if not core.is_compiled_with_cuda():
raise NotImplementedError(
"'lr_ratio' is unimplemented in CPU, XPU and NPU")
"'lr_ratio' is unimplemented in CPU, XPU and NPU"
)
if parameters is not None:
# paddle.Tensor is also iterable, so here we don't check whether
......@@ -187,13 +191,16 @@ class AdamW(Optimizer):
if isinstance(parameters, (paddle.Tensor, core.eager.Tensor)):
raise TypeError(
"`parameters` argument given to the optimizer should be "
"an iterable of paddle Tensors, but got argument type is `{}`."
.format(type(parameters)))
"an iterable of paddle Tensors, but got argument type is `{}`.".format(
type(parameters)
)
)
if isinstance(parameters, dict):
raise TypeError(
"`parameters` argument should not get dict type, "
"if parameter groups is needed, please set `parameters`"
" as list of dict")
" as list of dict"
)
self._parameter_list = list(parameters)
else:
self._parameter_list = None
......@@ -207,8 +214,9 @@ class AdamW(Optimizer):
if not isinstance(learning_rate, (float, LRScheduler)):
raise TypeError(
"learning rate should be float or LRScheduler, got %s here" %
type(learning_rate))
"learning rate should be float or LRScheduler, got %s here"
% type(learning_rate)
)
if grad_clip is not None:
if not isinstance(grad_clip, GradientClipBase):
raise TypeError(
......@@ -220,8 +228,9 @@ class AdamW(Optimizer):
if self._parameter_list:
if isinstance(self._parameter_list[0], dict):
for param_group in self._parameter_list:
assert 'params' in param_group, \
'params should be set in parameters if parameter groups are optimized in different options'
assert (
'params' in param_group
), 'params should be set in parameters if parameter groups are optimized in different options'
self._dtype = self._parameter_list[0]['params'][0].dtype
else:
self._dtype = self._parameter_list[0].dtype
......@@ -260,7 +269,7 @@ class AdamW(Optimizer):
'beta2': beta2,
'epsilon': epsilon,
'lazy_mode': lazy_mode,
'grad_clip': grad_clip
'grad_clip': grad_clip,
}
self._param_groups = []
......@@ -297,7 +306,8 @@ class AdamW(Optimizer):
elif isinstance(params, set):
raise TypeError(
"optimizer parameters should be in ordered collections,"
"but received set, please use list instead.")
"but received set, please use list instead."
)
else:
param_group['params'] = list(params)
......@@ -311,11 +321,13 @@ class AdamW(Optimizer):
if not param_set.isdisjoint(set(param_group['params'])):
raise ValueError(
"some parameters appear in more than one parameter group")
"some parameters appear in more than one parameter group"
)
for param in param_group['params']:
param.optimize_attr['learning_rate'] = param_group.get(
'learning_rate', 1.)
'learning_rate', 1.0
)
self._param_groups.append(param_group)
......@@ -327,19 +339,23 @@ class AdamW(Optimizer):
var_name = param.name + "_fp32_master"
var_name = unique_name.generate(var_name)
var = layers.create_global_var(name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True)
var = layers.create_global_var(
name=var_name,
shape=param.shape,
value=0,
dtype='float32',
persistable=True,
)
block = self.helper.startup_program.global_block()
block.append_op(type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32
})
block.append_op(
type="cast",
inputs={"X": [param]},
outputs={"Out": [var]},
attrs={
"in_dtype": param.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
},
)
self._master_weights[param.name] = var
return var
......@@ -353,20 +369,31 @@ class AdamW(Optimizer):
"""
if self._name is not None:
name = self._name + "_" + name
find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16
target_param = self._master_weights[
param.name] if find_master else param
find_master = self._multi_precision and (
param.dtype == core.VarDesc.VarType.FP16
or param.dtype == core.VarDesc.VarType.BF16
)
target_param = (
self._master_weights[param.name] if find_master else param
)
target_name = target_param.name
if (name not in self._accumulators
or target_name not in self._accumulators[name]):
if (
name not in self._accumulators
or target_name not in self._accumulators[name]
):
raise Exception(
"Accumulator {} does not exist for parameter {}".format(
name, target_name))
name, target_name
)
)
return self._accumulators[name][target_name]
def _add_moments_pows(self, p):
acc_dtype = p.dtype
if acc_dtype == core.VarDesc.VarType.FP16:
if (
acc_dtype == core.VarDesc.VarType.FP16
or acc_dtype == core.VarDesc.VarType.BF16
):
acc_dtype = core.VarDesc.VarType.FP32
self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
......@@ -374,18 +401,24 @@ class AdamW(Optimizer):
name=self._beta1_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.9 if isinstance(self._beta1, Variable) \
else self._beta1,
fill_value=0.9
if isinstance(self._beta1, Variable)
else self._beta1,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
type=core.VarDesc.VarType.LOD_TENSOR,
device='cpu',
)
self._add_accumulator(
name=self._beta2_pow_acc_str,
param=p,
dtype=acc_dtype,
fill_value=0.999 if isinstance(self._beta2, Variable) \
else self._beta2,
fill_value=0.999
if isinstance(self._beta2, Variable)
else self._beta2,
shape=[1],
type=core.VarDesc.VarType.LOD_TENSOR, device='cpu')
type=core.VarDesc.VarType.LOD_TENSOR,
device='cpu',
)
def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block)
......@@ -394,13 +427,19 @@ class AdamW(Optimizer):
# Create accumulator tensors for first and second moments
for p in parameters:
if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16:
if self._multi_precision and (
p.dtype == core.VarDesc.VarType.FP16
or p.dtype == core.VarDesc.VarType.BF16
):
master_p = self._create_master_weight(p)
self._add_moments_pows(master_p)
continue
if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision:
if (
p.dtype == core.VarDesc.VarType.FP16
or p.dtype == core.VarDesc.VarType.BF16
) and not self._multi_precision:
warnings.warn(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Accumulating with FP16/BFP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Adam optimizer."
)
self._add_moments_pows(p)
......@@ -413,53 +452,112 @@ class AdamW(Optimizer):
# Whether we should do weight decay for the parameter.
with_decay = True
if self._apply_decay_param_fun is not None \
and not self._apply_decay_param_fun(param.name):
if (
self._apply_decay_param_fun is not None
and not self._apply_decay_param_fun(param.name)
):
with_decay = False
moment1 = self._get_accumulator(self._moment1_acc_str,
param_and_grad[0])
moment2 = self._get_accumulator(self._moment2_acc_str,
param_and_grad[0])
beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str,
param_and_grad[0])
beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str,
param_and_grad[0])
find_master = self._multi_precision and param_and_grad[
0].dtype == core.VarDesc.VarType.FP16
master_weight = (self._master_weights[param_and_grad[0].name]
if find_master else None)
moment1 = self._get_accumulator(
self._moment1_acc_str, param_and_grad[0]
)
moment2 = self._get_accumulator(
self._moment2_acc_str, param_and_grad[0]
)
beta1_pow_acc = self._get_accumulator(
self._beta1_pow_acc_str, param_and_grad[0]
)
beta2_pow_acc = self._get_accumulator(
self._beta2_pow_acc_str, param_and_grad[0]
)
find_master = self._multi_precision and (
param_and_grad[0].dtype == core.VarDesc.VarType.FP16
or param_and_grad[0].dtype == core.VarDesc.VarType.BF16
)
master_weight = (
self._master_weights[param_and_grad[0].name]
if find_master
else None
)
lr = self._create_param_lr(param_and_grad)
# create the adamw optimize op
if framework._non_static_mode():
lr_ratio_ = 1. if self._lr_ratio is None else self._lr_ratio(
param_and_grad[0])
_beta1 = self._beta1 if not isinstance(
self._beta1, Variable) else self._beta1.numpy().item(0)
_beta2 = self._beta2 if not isinstance(
self._beta2, Variable) else self._beta2.numpy().item(0)
lr_ratio_ = (
1.0
if self._lr_ratio is None
else self._lr_ratio(param_and_grad[0])
)
_beta1 = (
self._beta1
if not isinstance(self._beta1, Variable)
else self._beta1.numpy().item(0)
)
_beta2 = (
self._beta2
if not isinstance(self._beta2, Variable)
else self._beta2.numpy().item(0)
)
if framework.in_dygraph_mode():
found_inf = self._get_auxiliary_var('found_inf')
_, _, _, _, _, _ = _C_ops.adamw_(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, master_weight, found_inf,
_beta1, _beta2, self._epsilon, lr_ratio_,
self._weight_decay, with_decay, self._lazy_mode, 1000,
find_master, False)
param_and_grad[0],
param_and_grad[1],
lr,
moment1,
moment2,
beta1_pow_acc,
beta2_pow_acc,
master_weight,
found_inf,
_beta1,
_beta2,
self._epsilon,
lr_ratio_,
self._weight_decay,
with_decay,
self._lazy_mode,
1000,
find_master,
False,
)
else:
_, _, _, _, _, _ = _legacy_C_ops.adamw(
param_and_grad[0], param_and_grad[1], lr, moment1, moment2,
beta1_pow_acc, beta2_pow_acc, master_weight,
param_and_grad[0], moment1, moment2, beta1_pow_acc,
beta2_pow_acc, master_weight, 'epsilon', self._epsilon,
'lazy_mode', self._lazy_mode,
'min_row_size_to_use_multithread', 1000, 'beta1', _beta1,
'beta2', _beta2, "with_decay", with_decay, 'coeff',
self._weight_decay, 'multi_precision', find_master,
'lr_ratio', lr_ratio_)
param_and_grad[0],
param_and_grad[1],
lr,
moment1,
moment2,
beta1_pow_acc,
beta2_pow_acc,
master_weight,
param_and_grad[0],
moment1,
moment2,
beta1_pow_acc,
beta2_pow_acc,
master_weight,
'epsilon',
self._epsilon,
'lazy_mode',
self._lazy_mode,
'min_row_size_to_use_multithread',
1000,
'beta1',
_beta1,
'beta2',
_beta2,
"with_decay",
with_decay,
'coeff',
self._weight_decay,
'multi_precision',
find_master,
'lr_ratio',
lr_ratio_,
)
return None
inputs = {
......@@ -486,18 +584,14 @@ class AdamW(Optimizer):
"Beta2PowOut": [beta2_pow_acc],
}
attrs = {
"lazy_mode":
self._lazy_mode,
"min_row_size_to_use_multithread":
1000,
"multi_precision":
find_master,
"with_decay":
with_decay,
"coeff":
self._weight_decay,
"lr_ratio":
1. if self._lr_ratio is None else self._lr_ratio(param_and_grad[0])
"lazy_mode": self._lazy_mode,
"min_row_size_to_use_multithread": 1000,
"multi_precision": find_master,
"with_decay": with_decay,
"coeff": self._weight_decay,
"lr_ratio": 1.0
if self._lr_ratio is None
else self._lr_ratio(param_and_grad[0]),
}
if isinstance(self._beta1, Variable):
......@@ -517,11 +611,13 @@ class AdamW(Optimizer):
inputs["MasterParam"] = master_weight
outputs["MasterParamOut"] = master_weight
adamw_op = block.append_op(type=self.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True)
adamw_op = block.append_op(
type=self.type,
inputs=inputs,
outputs=outputs,
attrs=attrs,
stop_gradient=True,
)
return adamw_op
......@@ -541,7 +637,7 @@ class AdamW(Optimizer):
.. code-block:: python
import paddle
a = paddle.rand([2,13], dtype="float32")
linear = paddle.nn.Linear(13, 5)
# This can be any optimizer supported by dygraph.
......@@ -560,24 +656,28 @@ class AdamW(Optimizer):
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
if framework.in_dygraph_mode():
if hasattr(grad_var, "is_selected_rows"
) and grad_var.is_selected_rows(
) and self.regularization is not None:
if (
hasattr(grad_var, "is_selected_rows")
and grad_var.is_selected_rows()
and self.regularization is not None
):
raise RuntimeError(
"AdamW don't support weight_decay with sparse parameters, please set it to None."
)
else:
if hasattr(
grad_var, "_is_sparse") and grad_var._is_sparse(
) and self.regularization is not None:
if (
hasattr(grad_var, "_is_sparse")
and grad_var._is_sparse()
and self.regularization is not None
):
raise RuntimeError(
"AdamW don't support weight_decay with sparse parameters, please set it to None."
)
params_grads.append((param, grad_var))
optimize_ops = self._apply_optimize(loss=None,
startup_program=None,
params_grads=params_grads)
optimize_ops = self._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)
else:
# optimize parameters in groups
for param_group in self._param_groups:
......@@ -588,35 +688,41 @@ class AdamW(Optimizer):
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
if framework.in_dygraph_mode():
if hasattr(grad_var, "is_selected_rows"
) and grad_var.is_selected_rows(
) and self.regularization is not None:
if (
hasattr(grad_var, "is_selected_rows")
and grad_var.is_selected_rows()
and self.regularization is not None
):
raise RuntimeError(
"AdamW don't support weight_decay with sparse parameters, please set it to None."
)
else:
if hasattr(grad_var,
"_is_sparse") and grad_var._is_sparse(
) and self.regularization is not None:
if (
hasattr(grad_var, "_is_sparse")
and grad_var._is_sparse()
and self.regularization is not None
):
raise RuntimeError(
"AdamW don't support weight_decay with sparse parameters, please set it to None."
)
params_grads['params'].append((param, grad_var))
params_grads.update(
{k: v
for k, v in param_group.items() if k != 'params'})
self._apply_optimize(loss=None,
startup_program=None,
params_grads=params_grads)
{k: v for k, v in param_group.items() if k != 'params'}
)
self._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads
)
def _update_param_group(self, parameters):
self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
self._lazy_mode = parameters.get('lazy_mode',
self._default_dict['lazy_mode'])
self._weight_decay = parameters.get('weight_decay',
self._default_dict['weight_decay'])
self._lazy_mode = parameters.get(
'lazy_mode', self._default_dict['lazy_mode']
)
self._weight_decay = parameters.get(
'weight_decay', self._default_dict['weight_decay']
)
parameters = parameters.get('params')
return parameters
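The AdamW change extends multi_precision to bfloat16 parameters: a float32 master weight holds the real state, the moments stay in float32, and the low-precision parameter is refreshed from the master after every step. A simplified NumPy sketch of one such update, not the fused adamw op; float16 stands in for bfloat16 because NumPy has no bf16 dtype:

import numpy as np

def adamw_step(p_low, master, m, v, g_low, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01, t=1):
    # All optimizer state (master weight, moments) lives in float32; only the
    # parameter handed to the forward pass stays low precision.
    g = g_low.astype(np.float32)
    master *= 1.0 - lr * weight_decay                  # decoupled weight decay
    m[:] = beta1 * m + (1 - beta1) * g
    v[:] = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    master -= lr * m_hat / (np.sqrt(v_hat) + eps)
    return master.astype(p_low.dtype), master, m, v    # cast back for the next forward

p = np.random.randn(8).astype(np.float16)
master = p.astype(np.float32)
m, v = np.zeros(8, np.float32), np.zeros(8, np.float32)
g = np.random.randn(8).astype(np.float16)
p, master, m, v = adamw_step(p, master, m, v, g)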
......@@ -440,15 +440,21 @@ class Optimizer(object):
return self._opti_name_list
def _create_global_learning_rate(self):
# lr var can't be float16, for pure fp16 training, should extra handle the dtype for lr
# lr var can't be float16 or bfloat16; for pure fp16 or bf16 training, the lr dtype needs extra handling
_lr_dtype = (
paddle.get_default_dtype() if self._dtype is None else self._dtype
)
_lr_dtype = (
paddle.float32
if (
paddle.get_default_dtype() != "float16"
and _lr_dtype == paddle.float16
(
paddle.get_default_dtype() != "float16"
and _lr_dtype == paddle.float16
)
or (
paddle.get_default_dtype() != "bfloat16"
and _lr_dtype == paddle.bfloat16
)
)
else _lr_dtype
)
......
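The optimizer.py hunk keeps the learning-rate variable in float32 whenever the parameter dtype is fp16 or bf16 but the global default dtype is not that low-precision type. A plain-Python restatement of the decision, illustrative only:

def lr_dtype(param_dtype, default_dtype):
    # the lr var can't be a low-precision dtype unless the user explicitly
    # set the global default dtype to it
    if param_dtype == 'float16' and default_dtype != 'float16':
        return 'float32'
    if param_dtype == 'bfloat16' and default_dtype != 'bfloat16':
        return 'float32'
    return param_dtype

print(lr_dtype('bfloat16', 'float32'))   # float32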