Unverified commit d1e8b1e2, authored by Yiqun Liu, committed by GitHub

Cherry pick for fix of operator precision. (#52705)

* Fix scale kernel for low precision, cherry pick #50998.

* Fix the FP16 precision problem of add_n. (#50129)

* Change squared_l2_norm to reuse ReduceKernel, and register fp16 and bf16 kernel, which is cherry pick #48315.

* Cherry-pick the fix of MPTypeTrait in KP, which is implemented in #50993.

* Cherry-pick the multi-precision support of AdamW for bf16, #48041.

* Fix compiling error.

* Cherry-pick the fix of CubTensorReduceImpl for bfloat16 in #50993.

* Fix unittest.

---------
Co-authored-by: liuruyan <44316842+liuruyan@users.noreply.github.com>
Parent: d12588d2
@@ -986,14 +986,16 @@ template <typename Tx,
          template <typename>
          class ReduceOp,
          typename TransformOp>
static
    typename std::enable_if<!std::is_same<Tx, phi::dtype::float16>::value &&
                                !std::is_same<Tx, phi::dtype::bfloat16>::value,
                            void>::type
    CubTensorReduceImpl(const Tx* x_data,
                        Ty* y_data,
                        const TransformOp& transform,
                        int reduce_num,
                        const KPDevice& dev_ctx,
                        KPStream stream) {
  auto reducer = ReduceOp<Ty>();
  cub::TransformInputIterator<Ty, TransformOp, const Tx*> trans_x(x_data,
                                                                  transform);
@@ -1037,6 +1039,23 @@ CubTensorReduceImpl(const Tx* x_data,
  PADDLE_THROW(phi::errors::InvalidArgument(
      "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
}

template <typename Tx,
          typename Ty,
          template <typename>
          class ReduceOp,
          typename TransformOp>
static typename std::enable_if<std::is_same<Tx, phi::dtype::bfloat16>::value,
                               void>::type
CubTensorReduceImpl(const Tx* x_data,
                    Ty* y_data,
                    const TransformOp& transform,
                    int reduce_num,
                    const KPDevice& dev_ctx,
                    KPStream stream) {
  PADDLE_THROW(phi::errors::InvalidArgument(
      "Tx should not be bfloat16 when using cub::DeviceReduce::Reduce()."));
}
#endif  // PADDLE_WITH_XPU_KP

template <typename Tx,
@@ -1081,7 +1100,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
  config.SetOutputData(y_data, dev_ctx, &tmp);

  constexpr bool kIsTxFP16 = std::is_same<Tx, phi::dtype::float16>::value;
  constexpr bool kIsTxBF16 = std::is_same<Tx, phi::dtype::bfloat16>::value;
  bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16;
#ifndef PADDLE_WITH_XPU_KP
  if (use_cub_reduce) {
    if (is_mean) {
...
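The hunks above keep float16/bfloat16 inputs off the cub::DeviceReduce::Reduce() path and route them through the ReduceKernel implementation, whose MPTypeTrait-based accumulation works in float. A rough NumPy sketch of why accumulating in the storage precision is a problem (illustrative only, not the kernel code; the helper name and sizes are made up):

import numpy as np

def sum_with_mp_accumulator(values):
    # Accumulate in float32 (the "multi-precision" type), cast back once.
    total = np.float32(0.0)
    for v in values:
        total += np.float32(v)
    return np.float16(total)

x = np.full(10000, 0.1, dtype=np.float16)
naive = np.float16(0.0)
for v in x:
    naive += v  # fp16 accumulator: increments vanish once rounding dominates
print(naive, sum_with_mp_accumulator(x), np.float32(x).sum())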
@@ -14,10 +14,10 @@
#include "paddle/phi/kernels/add_n_kernel.h"

#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/impl/add_n_kernel_impl.h"

namespace phi {
@@ -38,16 +38,18 @@ __global__ void Sum2CUDAKernel(const T *in_0,
template <class T>
__global__ void SumArrayCUDAKernel(
    T **in, T *out, int64_t N, size_t in_size, bool read_dst) {
  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
  int id = blockIdx.x * blockDim.x + threadIdx.x;
  while (id < N) {
    MPType total(read_dst ? static_cast<MPType>(out[id])
                          : static_cast<MPType>(0));
    for (int i = 0; i < in_size; ++i) {
      const T *tmp = in[i];
      if (tmp) {
        total += static_cast<MPType>(tmp[id]);
      }
    }
    out[id] = static_cast<T>(total);
    id += blockDim.x * gridDim.x;
  }
}
@@ -116,11 +118,12 @@ void AddNKernel(const Context &dev_ctx,
    int64_t length_0 = in_0.numel();
    int64_t length_1 = in_1.numel();
    if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) {
      using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
      auto result = EigenVector<T>::Flatten(*out);
      auto &place = *dev_ctx.eigen_device();
      auto in_0_e = EigenVector<T>::Flatten(in_0).template cast<MPType>();
      auto in_1_e = EigenVector<T>::Flatten(in_1).template cast<MPType>();
      result.device(place) = (in_0_e + in_1_e).template cast<T>();
    } else if (length_0 && in_0.IsInitialized()) {
      auto result = EigenVector<T>::Flatten(*out);
      auto &place = *dev_ctx.eigen_device();
...
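The SumArrayCUDAKernel/AddNKernel change above accumulates in MPTypeTrait<T>::Type (float for fp16/bf16 inputs) and casts back to the storage dtype once at the end. A NumPy sketch of that accumulation pattern (illustrative only; the helper name and shapes are hypothetical):

import numpy as np

def add_n_mp(arrays):
    # Float32 accumulator, single cast back to the storage dtype at the end.
    total = np.zeros_like(arrays[0], dtype=np.float32)
    for a in arrays:
        total += a.astype(np.float32)
    return total.astype(arrays[0].dtype)

ins = [np.random.random([16, 256]).astype(np.float16) for _ in range(32)]
out = add_n_mp(ins)
ref = np.sum([a.astype(np.float32) for a in ins], axis=0)
print(np.abs(out - ref).max())  # only the final fp16 rounding remains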
@@ -15,28 +15,30 @@ limitations under the License. */
#include "paddle/phi/kernels/scale_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"

namespace phi {

template <typename DataT, typename ParamT>
struct ScaleFunctor {
  ParamT bias;
  ParamT scale;
  bool bias_after_scale;

  ScaleFunctor(ParamT scale_data, ParamT bias_data, bool is_bias_after_sacle)
      : bias(bias_data),
        scale(scale_data),
        bias_after_scale(is_bias_after_sacle) {}

  __device__ __forceinline__ DataT operator()(const DataT x) const {
    if (bias_after_scale) {
      return static_cast<DataT>(scale * static_cast<ParamT>(x) + bias);
    } else {
      return static_cast<DataT>(scale * (static_cast<ParamT>(x) + bias));
    }
  }
};
@@ -48,16 +50,21 @@ void ScaleKernel(const Context& dev_ctx,
                 float bias,
                 bool bias_after_scale,
                 DenseTensor* out) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
  std::vector<const DenseTensor*> inputs;
  std::vector<DenseTensor*> outputs;
  inputs.emplace_back(&x);
  outputs.emplace_back(out);
  dev_ctx.template Alloc<T>(out);
  if (x.numel() <= 0 || (!x.IsInitialized())) {
    return;
  }
  phi::funcs::ElementwiseKernel<T>(
      dev_ctx,
      inputs,
      &outputs,
      ScaleFunctor<T, MT>(
          scale.to<MT>(), static_cast<MT>(bias), bias_after_scale));
}

}  // namespace phi
...
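The new ScaleFunctor keeps scale and bias in ParamT (the MPType, i.e. float for fp16/bf16 tensors) and only casts the final result back to DataT, so the parameters are no longer rounded to the storage dtype first. A NumPy sketch of the same idea (illustrative only; the helper name and values are made up):

import numpy as np

def scale_mp(x, scale, bias, bias_after_scale=True):
    # scale/bias stay in float32; x is promoted per element and the result
    # is cast back to the storage dtype, mirroring the functor.
    xf = x.astype(np.float32)
    y = scale * xf + bias if bias_after_scale else scale * (xf + bias)
    return y.astype(x.dtype)

x = np.random.uniform(-1, 1, [1024]).astype(np.float16)
mp = scale_mp(x, np.float32(0.1), np.float32(1e-3))
naive = np.float16(0.1) * x + np.float16(1e-3)  # params rounded to fp16 first
print(np.abs(mp.astype(np.float32) - naive.astype(np.float32)).max())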
@@ -15,12 +15,47 @@
#include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"

namespace phi {

/**
 * x*y*2.0
 */
template <typename T>
struct DoubleMulFunctor {
  __device__ __forceinline__ T operator()(const T a, const T b) const {
    return b * a * static_cast<T>(2.0f);
  }
};

template <typename T, typename Context>
void SquaredL2NormGradKernel(const Context& dev_ctx,
                             const DenseTensor& x,
                             const DenseTensor& dout,
                             DenseTensor* dx) {
  dev_ctx.template Alloc<T>(dx);

  PADDLE_ENFORCE_EQ(
      dout.numel(),
      1,
      phi::errors::InvalidArgument(
          "Input(GRAD@Out) of SquaredL2NormGradOP should be a scalar."));
  std::vector<const DenseTensor*> ins{&x, &dout};
  std::vector<DenseTensor*> outs{dx};
  funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>(
      dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor<T>());
}

}  // namespace phi

PD_REGISTER_KERNEL(squared_l2_norm_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SquaredL2NormGradKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
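The registered gradient kernel computes dx = 2 * x * dout with the scalar dout broadcast against x, which is what BroadcastKernel with DoubleMulFunctor expresses. A tiny NumPy check of that formula (illustrative only):

import numpy as np

x = np.random.random([4, 5]).astype(np.float32)
dout = np.array([0.5], dtype=np.float32)  # GRAD@Out is a single element
dx = 2.0 * x * dout                       # broadcast: dx has the shape of x
print(dx.shape)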
@@ -15,9 +15,34 @@
#include "paddle/phi/kernels/squared_l2_norm_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"

namespace phi {

template <typename T, typename Context>
void SquaredL2NormKernel(const Context& dev_ctx,
                         const DenseTensor& x,
                         DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  std::vector<int> origin_reduce_dims;
  for (size_t i = 0; i < x.dims().size(); i++) {
    origin_reduce_dims.push_back(i);
  }
  phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::SquareFunctor<T, T>>(
      dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims, false);
}

}  // namespace phi

PD_REGISTER_KERNEL(squared_l2_norm,
                   GPU,
                   ALL_LAYOUT,
                   phi::SquaredL2NormKernel,
                   float,
                   double,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {}
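The forward kernel is now expressed as a reduction of x squared over every axis, which is why it can reuse ReduceKernel and pick up the fp16/bf16 registrations. A NumPy sketch of the same computation (illustrative only):

import numpy as np

x = np.random.random([3, 4, 5]).astype(np.float32)
out = np.square(x).sum(axis=tuple(range(x.ndim)))  # reduce over all dims
print(out, (x ** 2).sum())                          # identical values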
@@ -52,6 +52,12 @@ class MPTypeTrait<phi::dtype::float16> {
  using Type = float;
};

template <>
class MPTypeTrait<phi::dtype::bfloat16> {
 public:
  using Type = float;
};

/**
 * @brief Will be used in BlockYReduce, get the index of reduce_num in shared
 * memory.
...
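MPTypeTrait now maps bfloat16, like float16, to float. A rough Python analogue of the trait as a dtype lookup (hypothetical helper, not a Paddle API):

import numpy as np

MP_TYPE = {np.float16: np.float32}  # bfloat16 (no NumPy type) would also map to float32

def mp_type(dtype):
    return MP_TYPE.get(dtype, dtype)

print(mp_type(np.float16), mp_type(np.float64))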
@@ -32,8 +32,11 @@ from .framework import default_main_program
from paddle import _C_ops, _legacy_C_ops

__all__ = [
    'set_gradient_clip',
    'ErrorClipByValue',
    'ClipGradByValue',
    'ClipGradByNorm',
    'ClipGradByGlobalNorm',
]

_clip_by_global_norm_using_mp_type_flag = False
@@ -52,9 +55,10 @@ def _clip_by_global_norm_using_mp_type(*args):
def _cast_to_mp_type_if_enabled(x):
    if (
        x.dtype == core.VarDesc.VarType.FP16
        or x.dtype == core.VarDesc.VarType.BF16
    ) and _clip_by_global_norm_using_mp_type():
        return x.astype(core.VarDesc.VarType.FP32)
    else:
        return x
@@ -66,8 +70,7 @@ def _squared_l2_norm(x):
    """
    x = _cast_to_mp_type_if_enabled(x)
    if core.is_compiled_with_xpu():
        square = layers.square(x)
        sum_square = layers.reduce_sum(square)
        return sum_square
@@ -78,7 +81,9 @@ def _squared_l2_norm(x):
        return _legacy_C_ops.squared_l2_norm(x)

    op_type = 'squared_l2_norm'
    check_variable_and_dtype(
        x, 'x', ['float32', 'float64', 'float16', 'uint16'], op_type
    )
    helper = LayerHelper(op_type, **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
@@ -89,7 +94,6 @@ def _squared_l2_norm(x):
class BaseErrorClipAttr(object):
    def __str__(self):
        raise NotImplementedError()
@@ -164,8 +168,9 @@ def error_clip_callback(block, context):
    for grad_n in [n for n in op_desc.output_arg_names() if n in grad_to_var]:
        fwd_var = block._var_recursive(grad_to_var[grad_n])
        error_clip = getattr(fwd_var, "error_clip", None)
        if not (
            error_clip is None or isinstance(error_clip, BaseErrorClipAttr)
        ):
            raise TypeError(
                "Variable's error_clip should be an instance of BaseErrorClipAttr or None."
            )
@@ -174,7 +179,6 @@ def error_clip_callback(block, context):
class ClipGradBase(object):
    def __init__(self):
        super(ClipGradBase, self).__init__()
@@ -197,7 +201,8 @@ class ClipGradBase(object):
                    warnings.warn(
                        "'set_gradient_clip' will be ineffective, because you have "
                        "set 'need_clip' in 'ParamAttr'. So, 'set_gradient_clip' "
                        "is redundant and you can remove it."
                    )
                    break
            return self._static_clip(params_grads)
@@ -211,34 +216,34 @@ class ClipGradBase(object):
class ClipGradByValue(ClipGradBase):
    """
    Limit the value of multi-dimensional Tensor :math:`X` to the range [min, max].

    - Any values less than min are set to ``min``.

    - Any values greater than max are set to ``max``.

    The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    Note:
        ``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to speficiy the clip scope.

    Args:
        max (float): The maximum value to clip by.
        min (float, optional): The minimum value to clip by. if not set by user, it will be set to ``-max``
            automatically. In this case, ``max`` must be greater than 0.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
@@ -252,7 +257,7 @@ class ClipGradByValue(ClipGradBase):
    def __init__(self, max, min=None):
        super(ClipGradByValue, self).__init__()
        if min is None:
            assert max > 0.0
            min = -max
        self.max = float(max)
        self.min = float(min)
@@ -417,17 +422,17 @@ def _allow_pure_fp16_global_norm_clip(*args):
class ClipGradByGlobalNorm(ClipGradBase):
    r"""
    Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in
    :math:`t\_list` , and limit it to ``clip_norm`` .

    - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio.

    - If the global norm is less than or equal to ``clip_norm`` , nothing will be done.

    The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``.
    If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped.

    Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer``
    (for example: :ref:`api_paddle_optimizer_SGD`).

    The clipping formula is:
@@ -443,7 +448,7 @@ class ClipGradByGlobalNorm(ClipGradBase):
        global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2}

    Note:
        ``need_clip`` of ``ClipGradyGlobalNorm`` HAS BEEN DEPRECATED since 2.0.
        Please use ``need_clip`` in ``ParamAttr`` to speficiy the clip scope.

    Args:
@@ -452,12 +457,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
    Examples:
        .. code-block:: python

            import paddle

            x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32')
            linear = paddle.nn.Linear(in_features=10, out_features=10,
                                      weight_attr=paddle.ParamAttr(need_clip=True),
                                      bias_attr=paddle.ParamAttr(need_clip=False))
            out = linear(x)
            loss = paddle.mean(out)
@@ -468,10 +473,9 @@ class ClipGradByGlobalNorm(ClipGradBase):
            sdg.step()
    """

    def __init__(
        self, clip_norm, group_name="default_group", auto_skip_clip=False
    ):
        super(ClipGradByGlobalNorm, self).__init__()
        self.clip_norm = float(clip_norm)
        self.group_name = group_name
@@ -503,7 +507,10 @@ class ClipGradByGlobalNorm(ClipGradBase):
                merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
            sum_square = _squared_l2_norm(merge_grad)
            if (
                sum_square.dtype == core.VarDesc.VarType.FP16
                or sum_square.dtype == core.VarDesc.VarType.BF16
            ):
                sum_square_list_fp16.append(sum_square)
            elif sum_square.dtype == core.VarDesc.VarType.FP32:
                sum_square_list_fp32.append(sum_square)
@@ -511,8 +518,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
                sum_square_list.append(sum_square)

        # all parameters have been filterd out
        if (
            len(sum_square_list)
            + len(sum_square_list_fp16)
            + len(sum_square_list_fp32)
            == 0
        ):
            return params_grads

        sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32"
@@ -531,22 +542,23 @@ class ClipGradByGlobalNorm(ClipGradBase):
            global_norm_var.append(global_norm_var_fp64)
        global_norm_var = paddle.add_n(global_norm_var)
        global_norm_var = layers.sqrt(global_norm_var)
        max_global_norm = layers.fill_constant(
            shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
        )

        need_clip = False
        if not self.auto_skip_clip:  # always apply clip
            need_clip = True
            clip_var = layers.elementwise_div(
                x=max_global_norm,
                y=layers.elementwise_max(x=global_norm_var, y=max_global_norm),
            )
        elif global_norm_var > max_global_norm:
            # only when global_norm_var > max_global_norm, grad need clip
            need_clip = True
            clip_var = layers.elementwise_div(
                x=max_global_norm, y=global_norm_var
            )

        for p, g in params_grads:
            if g is None:
@@ -556,8 +568,11 @@ class ClipGradByGlobalNorm(ClipGradBase):
                continue
            # TODO(wangxi): use inplace elementwise_mul
            if need_clip:
                clip_input = (
                    clip_var.astype(g.dtype)
                    if clip_var.dtype != g.dtype
                    else clip_var
                )
                new_grad = layers.elementwise_mul(g, clip_input)
                params_and_grads.append((p, new_grad))
            else:
@@ -581,7 +596,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
            if g.type == core.VarDesc.VarType.SELECTED_ROWS:
                merge_grad = layers.merge_selected_rows(g)
                merge_grad = layers.get_tensor_from_selected_rows(
                    merge_grad
                )
            sum_square = _squared_l2_norm(merge_grad)
            if sum_square.dtype == core.VarDesc.VarType.FP16:
                sum_square_list_fp16.append(sum_square)
@@ -591,8 +607,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
                sum_square_list.append(sum_square)

        # all parameters have been filterd out
        if (
            len(sum_square_list)
            + len(sum_square_list_fp16)
            + len(sum_square_list_fp32)
            == 0
        ):
            return params_grads

        with p.block.program._optimized_guard([p, g]):
@@ -601,10 +621,14 @@ class ClipGradByGlobalNorm(ClipGradBase):
            global_norm_var = []
            if len(sum_square_list_fp16) > 0:
                global_norm_var_fp16 = layers.sums(sum_square_list_fp16)
                if (
                    sum_square_list_fp32
                    or sum_square_list
                    or not _allow_pure_fp16_global_norm_clip()
                ):
                    global_norm_var.append(
                        global_norm_var_fp16.astype(sum_dtype)
                    )
                else:
                    global_norm_var.append(global_norm_var_fp16)
            if len(sum_square_list_fp32) > 0:
@@ -613,23 +637,28 @@ class ClipGradByGlobalNorm(ClipGradBase):
                    global_norm_var.append(global_norm_var_fp32)
                else:
                    global_norm_var.append(
                        global_norm_var_fp32.astype(sum_dtype)
                    )
            if len(sum_square_list) > 0:
                # fp64
                global_norm_var_other_dtype = layers.sums(sum_square_list)
                global_norm_var.append(global_norm_var_other_dtype)

            global_norm_var = (
                layers.sums(global_norm_var)
                if len(global_norm_var) > 1
                else global_norm_var[0]
            )
            global_norm_var = layers.sqrt(x=global_norm_var)
            max_global_norm = layers.fill_constant(
                shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm
            )
            scale_var = layers.elementwise_div(
                x=max_global_norm,
                y=layers.elementwise_max(
                    x=max_global_norm, y=global_norm_var
                ),
            )
            param_new_grad_name_dict = dict()
            for p, g in params_grads:
                if g is None:
@@ -641,29 +670,32 @@ class ClipGradByGlobalNorm(ClipGradBase):
                with p.block.program._optimized_guard([p, g]):
                    new_g = _cast_to_mp_type_if_enabled(g)
                    # inplace
                    scale_input = (
                        scale_var.astype('float16')
                        if new_g.dtype == core.VarDesc.VarType.FP16
                        and scale_var.dtype != core.VarDesc.VarType.FP16
                        else scale_var
                    )
                    # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g
                    # will be in different blocks with the gradient clip related ops.
                    # We need to handle the correct block, otherwise will encounter
                    # a 'NotFoundError' during compile time.
                    block = default_main_program().current_block()
                    block.append_op(
                        type='elementwise_mul',
                        inputs={'X': new_g, 'Y': scale_input},
                        outputs={'Out': new_g},
                    )
                    if new_g is not g:
                        block.append_op(
                            type='cast',
                            inputs={'X': new_g},
                            outputs={'Out': g},
                            attrs={
                                'in_dtype': new_g.dtype,
                                'out_dtype': g.dtype,
                            },
                        )

                param_new_grad_name_dict[p.name] = g.name
                params_and_grads.append((p, g))
@@ -676,7 +708,8 @@ class ClipGradByGlobalNorm(ClipGradBase):
            context[self.group_name] = []
            context[self.group_name + "_clip_value"] = self.clip_norm
            context[self.group_name + "_clip"] = layers.fill_constant(
                shape=[1], dtype=grad.dtype, value=self.clip_norm
            )
        else:
            if not self.clip_norm == context[self.group_name + "_clip_value"]:
                raise ValueError(
@@ -699,20 +732,19 @@ class ClipGradByGlobalNorm(ClipGradBase):
            group_norm_var = layers.sums(input=self.context[self.group_name])
            group_norm_var = layers.sqrt(x=group_norm_var)
            clip_var = self.context[self.group_name + "_clip"]
            group_scale_var = layers.elementwise_div(
                x=clip_var,
                y=layers.elementwise_max(x=clip_var, y=group_norm_var),
            )
            assert group_scale_var.shape == (1,)
            self.context[group_scale_name] = group_scale_var

        # inplace
        param.block.append_op(
            type='elementwise_mul',
            inputs={'X': grad, 'Y': self.context[group_scale_name]},
            outputs={'Out': grad},
        )

        return param, grad
@@ -721,23 +753,23 @@ class ClipGradByGlobalNorm(ClipGradBase):
def set_gradient_clip(clip, param_list=None, program=None):
    """
    :api_attr: Static Graph

    Warning:
        This API must be used after building network, and before ``minimize`` ,
        and it may be removed in future releases, so it is not recommended.
        It is recommended to set ``grad_clip`` when initializing the ``optimizer`` ,
        this is a better method to clip gradient. There are three clipping strategies:
        :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
        :ref:`api_fluid_clip_GradientClipByValue` .

    To specify parameters that require gradient clip.

    Args:
        grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
            some derived class of ``GradientClipBase`` . There are three cliping strategies
            ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
            :ref:`api_fluid_clip_GradientClipByValue` ). Default value: None, and there is no
            gradient clipping.
        param_list (list(Variable), optional): Parameters that require gradient clip.
                It can be a list of parameter or a list of parameter's name.
@@ -791,7 +823,7 @@ def set_gradient_clip(clip, param_list=None, program=None):
                param_list=[param_var1, param_var2])
            sgd = fluid.optimizer.SGD(learning_rate=1e-3)
            sgd.minimize(loss)

        # network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            loss = network()
@@ -802,27 +834,31 @@ def set_gradient_clip(clip, param_list=None, program=None):
            # Set the gradient clipping strategy: clip2
            sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2)
            sgd.minimize(loss)
        # 'set_gradient_clip' will not take effect when setting has a conflict,
        # and the gradient clipping strategy will be 'clip2'
    """
    warnings.warn(
        "Caution! 'set_gradient_clip' is not recommended "
        "and may be deprecated in future! "
        "We recommend a new strategy: set 'grad_clip' "
        "when initializing the 'optimizer'. "
        "This method can reduce the mistakes, please "
        "refer to documention of 'optimizer'."
    )
    if not isinstance(clip, ClipGradBase):
        raise TypeError(
            "'clip' should be an instance of ClipGradBase's derived class"
        )
    if program is None:
        program = framework.default_main_program()

    for op in program.block(0).ops:
        if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr(
            "op_namescope"
        ):
            warnings.warn(
                "'minimize' has been invoked before, this will make 'set_gradient_clip' "
                "be ineffective! Please invoke 'set_gradient_clip' before 'minimize'."
@@ -847,14 +883,16 @@ def append_gradient_clip_ops(param_grads):
    for p, g in param_grads:
        if g is None:
            continue
        with p.block.program._optimized_guard([p, g]), framework.name_scope(
            'gradient_clip'
        ):
            clip_attr = getattr(p, 'gradient_clip_attr', None)
            if clip_attr is None:
                return param_grads
            if not isinstance(clip_attr, ClipGradBase):
                raise TypeError(
                    "clip attribute should be an instance of GradientClipBase"
                )

            clip_attr._process_context(context=context, param=p, grad=g)
@@ -863,8 +901,9 @@ def append_gradient_clip_ops(param_grads):
    for p, g in param_grads:
        if g is None:
            continue
        with p.block.program._optimized_guard([p, g]), framework.name_scope(
            'gradient_clip'
        ):
            param, new_grad = clip_attr._create_operators(param=p, grad=g)
            param_new_grad_name_dict[param.name] = new_grad.name
            res.append([param, new_grad])
@@ -888,12 +927,16 @@ def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict):
            continue
        block_id_list.append(block_id)
        for op in param.block.program.global_block().ops:
            if (
                op.has_attr("op_namescope")
                and "gradient_clip" in op.attr("op_namescope")
                and op.attr('op_role_var')
            ):
                param_name = op.attr('op_role_var')[0]
                if param_name in param_new_grad_name_dict:
                    correct_p_g = [
                        param_name,
                        param_new_grad_name_dict[param_name],
                    ]
                    op._set_attr('op_role_var', correct_p_g)
...
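For reference, the global-norm clipping that ClipGradByGlobalNorm's docstring describes boils down to scaling every gradient by clip_norm / max(global_norm, clip_norm). A standalone NumPy sketch of that formula (illustrative only; the helper is not part of the patch):

import numpy as np

def clip_grads_by_global_norm(grads, clip_norm):
    # global_norm = sqrt(sum_i ||g_i||_2^2); g_i <- g_i * clip_norm / max(global_norm, clip_norm)
    global_norm = np.sqrt(sum((g.astype(np.float64) ** 2).sum() for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [(g * scale).astype(g.dtype) for g in grads]

grads = [np.random.randn(10).astype(np.float32) for _ in range(3)]
clipped = clip_grads_by_global_norm(grads, clip_norm=0.2)
print(np.sqrt(sum((g ** 2).sum() for g in clipped)))  # <= 0.2 up to rounding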
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np

import paddle


class TestAddnOp(unittest.TestCase):
    def setUp(self):
        np.random.seed(20)
        l = 32
        self.x_np = np.random.random([l, 16, 256])

    def check_main(self, x_np, dtype, axis=None):
        paddle.disable_static()
        x = []
        for i in range(x_np.shape[0]):
            val = paddle.to_tensor(x_np[i].astype(dtype))
            val.stop_gradient = False
            x.append(val)
        y = paddle.add_n(x)
        x_g = paddle.grad(y, x)
        y_np = y.numpy().astype('float32')
        x_g_np = []
        for val in x_g:
            x_g_np.append(val.numpy().astype('float32'))
        paddle.enable_static()
        return y_np, x_g_np

    def test_add_n_fp16(self):
        if not paddle.is_compiled_with_cuda():
            return
        y_np_16, x_g_np_16 = self.check_main(self.x_np, 'float16')
        y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32')

        np.testing.assert_allclose(y_np_16, y_np_32, rtol=1e-03)
        for i in range(len(x_g_np_32)):
            np.testing.assert_allclose(x_g_np_16[i], x_g_np_32[i], rtol=1e-03)

    def test_add_n_api(self):
        if not paddle.is_compiled_with_cuda():
            return

        y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32')
        y_np_gt = np.sum(self.x_np, axis=0).astype('float32')

        np.testing.assert_allclose(y_np_32, y_np_gt, rtol=1e-06)


if __name__ == "__main__":
    unittest.main()
@@ -26,21 +26,17 @@ from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip
paddle.enable_static()


def bow_net(
    data, label, dict_dim, emb_dim=128, hid_dim=128, hid_dim2=96, class_dim=2
):
    """
    BOW net
    This model is from https://github.com/PaddlePaddle/models:
    fluid/PaddleNLP/text_classification/nets.py
    """
    emb = fluid.layers.embedding(
        input=data, is_sparse=True, size=[dict_dim, emb_dim]
    )
    bow = fluid.layers.sequence_pool(input=emb, pool_type='sum')
    bow_tanh = fluid.layers.tanh(bow)
    fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh")
@@ -53,7 +49,6 @@ def bow_net(data,
class TestGradientClip(unittest.TestCase):
    def setUp(self):
        self.word_dict_len = 5147
        self.BATCH_SIZE = 2
@@ -77,8 +72,9 @@ class TestGradientClip(unittest.TestCase):
    def check_gradient_clip(self, place, dtype='float32'):
        prog = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(
            main_program=prog, startup_program=startup_program
        ):
            image = fluid.data(name="a", shape=[-1, 784], dtype='float32')
            label = fluid.data(name="b", shape=[-1, 1], dtype='int64')
            if dtype != 'float32':
@@ -99,8 +95,9 @@ class TestGradientClip(unittest.TestCase):
        p_g = sorted(p_g, key=lambda x: x[0].name)
        p_g_clip = sorted(p_g_clip, key=lambda x: x[0].name)

        with fluid.program_guard(
            main_program=prog_clip, startup_program=startup_program
        ):
            p_g_clip = self.clip_gradient(p_g_clip)

        grad_list = [elem[1] for elem in p_g]
@@ -113,20 +110,20 @@ class TestGradientClip(unittest.TestCase):
        data = next(train_reader())
        out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list)
        out_clip = exe.run(
            prog_clip, feed=feeder.feed(data), fetch_list=grad_clip_list
        )
        self.check_clip_result(out, out_clip)

    def check_sparse_gradient_clip(self, place):
        prog = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(
            main_program=prog, startup_program=startup_program
        ):
            data = fluid.data(
                name="words", shape=[-1, 1], dtype="int64", lod_level=1
            )
            label = fluid.data(name="label", shape=[-1, 1], dtype="int64")
            cost = bow_net(data, label, self.word_dict_len)
@@ -138,7 +135,7 @@ class TestGradientClip(unittest.TestCase):
        data = next(self.train_data())
        val = exe.run(prog, feed=feeder.feed(data), fetch_list=[cost])[0]
        self.assertEqual((1,), val.shape)
        self.assertFalse(np.isnan(val))

    def backward_and_optimize(self, cost):
@@ -146,7 +143,6 @@ class TestGradientClip(unittest.TestCase):
class TestGradientClipByGlobalNorm(TestGradientClip):
    def init(self):
        self.clip_norm = 0.2
@@ -166,13 +162,13 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
                v,
                rtol=1e-05,
                atol=1e-08,
                err_msg='gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}'.format(
                    u, v, u - v
                ),
            )

    # test whether the output is right when use 'set_gradient_clip'
    def test_old_gradient_clip(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
            fluid.clip.set_gradient_clip(clip)
@@ -183,7 +179,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
    # test whether the output is right when use grad_clip
    def test_new_gradient_clip(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
            return clip(params_grads)
@@ -193,7 +188,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
    # test whether the output is right when use grad_clip under float64
    def test_new_gradient_clip_fp64(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm)
            return clip(params_grads)
@@ -203,12 +197,12 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
    # invoke 'set_gradient_clip' in a wrong order
    def test_wrong_API_order(self):
        def backward_func(cost):
            clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0)
            fluid.clip.set_gradient_clip(clip)
            sgd_optimizer = fluid.optimizer.SGD(
                learning_rate=0.01, grad_clip=clip
            )
            # if 'set_gradient_clip' and 'optimize(grad_clip)' together, 'set_gradient_clip' will be ineffective
            sgd_optimizer.minimize(cost)
            # 'set_gradient_clip' must before 'minimize', otherwise, 'set_gradient_clip' will be ineffective
@@ -222,43 +216,72 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
    def test_tpyeError(self):
        # the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class
        with self.assertRaises(TypeError):
            sgd_optimizer = fluid.optimizer.SGD(
                learning_rate=0.1, grad_clip="test"
            )

    # if grad is None or not need clip
    def test_none_grad_fp32(self):
        ops = self._test_none_grad_helper("float32")
        self.assertListEqual(
            ops,
            [
                'squared_l2_norm',
                'squared_l2_norm',
                'sum',
                'sqrt',
                'fill_constant',
                'elementwise_max',
                'elementwise_div',
                'elementwise_mul',
                'elementwise_mul',
            ],
        )

    def test_none_grad_fp16(self):
        ops = self._test_none_grad_helper("float16")
        self.assertListEqual(
            ops,
            [
                'squared_l2_norm',
                'squared_l2_norm',
                'sum',
                'cast',
                'sqrt',
                'fill_constant',
                'elementwise_max',
                'elementwise_div',
                'cast',
                'elementwise_mul',
                'cast',
                'elementwise_mul',
            ],
        )

    def _test_none_grad_helper(self, dtype):
        prog = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(
            main_program=prog, startup_program=startup_program
        ):
            clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm)
            x = (
                fluid.default_main_program()
                .global_block()
                .create_parameter(name="x", shape=[2, 3], dtype=dtype)
            )
            y = (
                fluid.default_main_program()
                .global_block()
                .create_parameter(name="y", shape=[2, 3], dtype=dtype)
            )

            # (x, None) should not be returned
            params_grads = [(x, None), (x, y), (y, x)]
            params_grads = clip(params_grads)
            self.assertTrue(
                len(params_grads) == 2,
                "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!",
            )

            ops = [op.type for op in x.block.ops]
@@ -266,7 +289,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip):
class TestGradientClipByNorm(TestGradientClip):
    def init(self):
        self.clip_norm = 0.2
@@ -280,11 +302,11 @@ class TestGradientClipByNorm(TestGradientClip):
            v,
            rtol=1e-05,
            atol=1e-08,
            err_msg='gradient clip by norm has wrong results!',
        )

    # test whether the output is right when use grad_clip
    def test_gradient_clip(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)
            return clip(params_grads)
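For context, the per-tensor rule this test class exercises rescales each gradient independently: g is multiplied by clip_norm / max(||g||_2, clip_norm). A NumPy sketch of that rule (illustrative only; the helper is not part of the patch):

import numpy as np

def clip_by_norm(g, clip_norm):
    # Leave g unchanged if its L2 norm is within clip_norm, otherwise rescale it.
    norm = np.sqrt((g.astype(np.float64) ** 2).sum())
    return (g * (clip_norm / max(norm, clip_norm))).astype(g.dtype)

g = np.random.randn(2, 3).astype(np.float32)
print(np.linalg.norm(clip_by_norm(g, 0.2)))  # <= 0.2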
@@ -295,25 +317,35 @@ class TestGradientClipByNorm(TestGradientClip):
    # if grad is None or not need clip
    def test_none_grad(self):
        clip = fluid.clip.GradientClipByNorm(self.clip_norm)
        x = (
            fluid.default_main_program()
            .global_block()
            .create_parameter(
                name="x", shape=[2, 3], dtype="float32", need_clip=False
            )
        )
        y = (
            fluid.default_main_program()
            .global_block()
            .create_parameter(
                name="y", shape=[2, 3], dtype="float32", need_clip=False
            )
        )

        # (x, None) should not be returned
        params_grads = [(x, None), (x, y)]
        params_grads = clip(params_grads)
        self.assertTrue(
            len(clip(params_grads)) == 1,
            "ClipGradByNorm: when grad is None, it shouldn't be returned by gradient clip!",
        )
        self.assertTrue(
            params_grads[0][1].name == 'y',
            "ClipGradByNorm: grad should not be clipped when filtered out!",
        )


class TestGradientClipByValue(TestGradientClip):
    def init(self):
        self.max = 0.2
        self.min = 0.1
@@ -328,11 +360,11 @@ class TestGradientClipByValue(TestGradientClip):
                v,
                rtol=1e-06,
                atol=1e-08,
                err_msg='gradient clip by value has wrong results!',
            )

    # test whether the output is right when use grad_clip
    def test_gradient_clip(self):
        def func(params_grads):
            clip = fluid.clip.GradientClipByValue(max=self.max, min=self.min)
            return clip(params_grads)

@@ -343,37 +375,49 @@ class TestGradientClipByValue(TestGradientClip):
    # if grad is None or not need clip
    def test_none_grad(self):
        clip = fluid.clip.GradientClipByValue(self.max, self.min)
        x = (
            fluid.default_main_program()
            .global_block()
            .create_parameter(
                name="x", shape=[2, 3], dtype="float32", need_clip=False
            )
        )
        y = (
            fluid.default_main_program()
            .global_block()
            .create_parameter(
                name="y", shape=[2, 3], dtype="float32", need_clip=False
            )
        )

        # (x, None) should not be returned
        params_grads = [(x, None), (x, y)]
        params_grads = clip(params_grads)
        self.assertTrue(
            len(clip(params_grads)) == 1,
            "ClipGradByValue: when grad is None, it shouldn't be returned by gradient clip!",
        )
        self.assertTrue(
            params_grads[0][1].name == 'y',
            "ClipGradByValue: grad should not be clipped when filtered out!",
        )
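Taken together, the `test_none_grad` cases above pin down the filtering contract shared by all three clip classes: a clip instance is called with a list of `(param, grad)` pairs, drops every pair whose grad is `None`, and passes through, unclipped, the grads of parameters created with `need_clip=False`. A minimal, self-contained sketch of that contract follows; `FakeParam` and `toy_clip_by_value` are illustrative stand-ins, not Paddle APIs.

# Hedged sketch of the (param, grad) filtering contract exercised above.
# FakeParam and toy_clip_by_value are hypothetical stand-ins, not Paddle APIs.
from collections import namedtuple

FakeParam = namedtuple("FakeParam", ["name", "need_clip"])

def toy_clip_by_value(params_grads, min_v=0.1, max_v=0.2):
    kept = []
    for p, g in params_grads:
        if g is None:            # (param, None) pairs are dropped
            continue
        if not p.need_clip:      # filtered params keep their raw grad
            kept.append((p, g))
            continue
        kept.append((p, min(max(g, min_v), max_v)))
    return kept

x = FakeParam("x", need_clip=False)
y = FakeParam("y", need_clip=True)
print(toy_clip_by_value([(x, None), (x, 0.5), (y, 0.5)]))
# -> [(x, 0.5), (y, 0.2)]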
class TestDygraphGradientClip(unittest.TestCase):
    def test_gradient_clip(self):
        with fluid.dygraph.guard():
            linear = fluid.dygraph.Linear(5, 5)
            inputs = fluid.layers.uniform_random(
                [16, 5], min=-10, max=10
            ).astype('float32')
            out = linear(fluid.dygraph.to_variable(inputs))
            loss = fluid.layers.reduce_mean(out)
            loss.backward()
            sgd_optimizer = fluid.optimizer.SGD(
                learning_rate=0.0,
                parameter_list=linear.parameters(),
                grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1),
            )
            self.check_clip_result(loss, sgd_optimizer)

    def check_clip_result(self, loss, optimizer):

@@ -381,20 +425,23 @@ class TestDygraphGradientClip(unittest.TestCase):

class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
    def setUp(self):
        self.clip_norm = 0.8
        self.clip1 = fluid.clip.GradientClipByGlobalNorm(
            clip_norm=self.clip_norm
        )
        self.clip2 = fluid.clip.GradientClipByGlobalNorm(
            clip_norm=self.clip_norm
        )

    def check_clip_result(self, loss, optimizer):
        # if grad is None
        x = fluid.dygraph.to_variable(
            np.array([2, 3]).astype("float32"), name="x"
        )
        y = fluid.dygraph.to_variable(
            np.array([3, 4]).astype("float32"), name="y"
        )
        assert len(self.clip1([(x, x), (x, y), (x, None)])) == 2
        # get params and grads from network
        opt, params_grads = optimizer.minimize(loss)
@@ -419,11 +466,11 @@ class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip):
        self.assertTrue(
            np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8),
            "gradient clip by global norm has wrong results, expetcd:%f, but received:%f"
            % (a, b),
        )


class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
    def setUp(self):
        self.clip_norm = 0.8
        self.clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm)

@@ -448,11 +495,11 @@ class TestDygraphGradientClipByNorm(TestDygraphGradientClip):
        self.assertTrue(
            np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8),
            "gradient clip by norm has wrong results, expetcd:%f, but received:%f"
            % (a, b),
        )


class TestDygraphGradientClipByValue(TestDygraphGradientClip):
    def setUp(self):
        self.max = 0.2
        self.min = 0.1

@@ -475,11 +522,11 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip):
                v,
                rtol=1e-06,
                atol=1e-08,
                err_msg='gradient clip by value has wrong results!',
            )


class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = paddle.nn.Linear(5, 5)
@@ -492,19 +539,21 @@ class SimpleNet(paddle.nn.Layer):

class TestDygraphGradientClipFP16(unittest.TestCase):
    def test_gradient_clip(self):
        if fluid.core.is_compiled_with_cuda():
            with fluid.dygraph.guard():
                paddle.seed(10)
                model = SimpleNet()
                sgd_optimizer = paddle.optimizer.SGD(
                    learning_rate=0.0, parameters=model.parameters()
                )
                model, sgd_optimizer = paddle.amp.decorate(
                    models=model, optimizers=sgd_optimizer, level='O2'
                )
                scaler = paddle.amp.GradScaler(init_loss_scaling=1024)
                inputs = fluid.layers.uniform_random(
                    [1, 5], min=-10, max=10
                ).astype('float32')
                with paddle.amp.auto_cast(level='O2'):
                    out = model(fluid.dygraph.to_variable(inputs))
                    loss = fluid.layers.reduce_mean(out)

@@ -543,15 +592,16 @@ class TestDygraphGradientClipFP16(unittest.TestCase):
                self.assertTrue(
                    np.isclose(a=a, b=b, rtol=1e-3, atol=1e-8),
                    "gradient clip by global norm has wrong results, expetcd:%f, but received:%f"
                    % (a, b),
                )
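The FP16 test above wires AMP level O2, a GradScaler, and global-norm clipping together. Roughly, the intended dygraph flow looks like the sketch below; it assumes a recent Paddle 2.x API and is illustrative rather than a copy of the test.

# Hedged sketch of AMP (O2) training with loss scaling and global-norm
# clipping; assumes a recent Paddle 2.x dygraph API and a GPU build.
import paddle

model = paddle.nn.Linear(5, 5)
opt = paddle.optimizer.SGD(
    learning_rate=0.1,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(0.8),
)
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')
scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

x = paddle.rand([16, 5], dtype='float32')
with paddle.amp.auto_cast(level='O2'):
    loss = model(x).mean()
scaled = scaler.scale(loss)    # scale the loss to avoid fp16 underflow
scaled.backward()
scaler.minimize(opt, scaled)   # unscale, clip by global norm, then step
opt.clear_grad()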
class TestDygraphGradientClipFP64(unittest.TestCase):
    def test_gradient_clip(self):
        with fluid.dygraph.guard():
            inputs = fluid.layers.uniform_random(
                [16, 5], min=-10, max=10
            ).astype('float64')
            linear = fluid.dygraph.Linear(5, 5, dtype="float64")
            out = linear(fluid.dygraph.to_variable(inputs))
            loss = fluid.layers.reduce_mean(out)

@@ -589,11 +639,11 @@ class TestDygraphGradientClipFP64(unittest.TestCase):
            self.assertTrue(
                np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8),
                "gradient clip by global norm has wrong results, expetcd:%f, but received:%f"
                % (a, b),
            )


class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase):
    def check_main(self, expected_has_cast_op):
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()

@@ -604,12 +654,12 @@ class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase):
            param_and_grads = []
            main_block = main_prog.global_block()
            for name, shape in zip(names, shapes):
                p = main_block.create_parameter(
                    name=name, shape=shape, dtype='float16'
                )
                g = main_block.create_parameter(
                    name=p.name + '@GRAD', shape=p.shape, dtype=p.dtype
                )
                param_and_grads.append((p, g))
            clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
......
@@ -148,19 +148,21 @@ class AdamW(Optimizer):
    _beta1_pow_acc_str = "beta1_pow_acc"
    _beta2_pow_acc_str = "beta2_pow_acc"

    def __init__(
        self,
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        epsilon=1e-8,
        parameters=None,
        weight_decay=0.01,
        lr_ratio=None,
        apply_decay_param_fun=None,
        grad_clip=None,
        lazy_mode=False,
        multi_precision=False,
        name=None,
    ):
        assert learning_rate is not None
        assert beta1 is not None
        assert beta2 is not None

@@ -171,14 +173,16 @@ class AdamW(Optimizer):
            raise ValueError("Invaild value of beta2, expect beta2 in [0,1).")
        if not 0 <= epsilon:
            raise ValueError("Invaild value of epsilon, expect epsilon >= 0.")
        if not isinstance(weight_decay, float) and not isinstance(
            weight_decay, framework.Variable
        ):
            raise TypeError("weight_decay should be float or Tensor.")
        if lr_ratio is not None:
            assert isinstance(lr_ratio, Callable)
            if not core.is_compiled_with_cuda():
                raise NotImplementedError(
                    "'lr_ratio' is unimplemented in CPU, XPU and NPU"
                )
        if parameters is not None:
            # paddle.Tensor is also iterable, so here we don't check whether
@@ -187,13 +191,16 @@ class AdamW(Optimizer):
            if isinstance(parameters, (paddle.Tensor, core.eager.Tensor)):
                raise TypeError(
                    "`parameters` argument given to the optimizer should be "
                    "an iterable of paddle Tensors, but got argument type is `{}`.".format(
                        type(parameters)
                    )
                )
            if isinstance(parameters, dict):
                raise TypeError(
                    "`parameters` argument should not get dict type, "
                    "if parameter groups is needed, please set `parameters`"
                    " as list of dict"
                )
            self._parameter_list = list(parameters)
        else:
            self._parameter_list = None

@@ -207,8 +214,9 @@ class AdamW(Optimizer):
        if not isinstance(learning_rate, (float, LRScheduler)):
            raise TypeError(
                "learning rate should be float or LRScheduler, got %s here"
                % type(learning_rate)
            )
        if grad_clip is not None:
            if not isinstance(grad_clip, GradientClipBase):
                raise TypeError(

@@ -220,8 +228,9 @@ class AdamW(Optimizer):
        if self._parameter_list:
            if isinstance(self._parameter_list[0], dict):
                for param_group in self._parameter_list:
                    assert (
                        'params' in param_group
                    ), 'params should be set in parameters if parameter groups are optimized in different options'
                self._dtype = self._parameter_list[0]['params'][0].dtype
            else:
                self._dtype = self._parameter_list[0].dtype
@@ -260,7 +269,7 @@ class AdamW(Optimizer):
            'beta2': beta2,
            'epsilon': epsilon,
            'lazy_mode': lazy_mode,
            'grad_clip': grad_clip,
        }

        self._param_groups = []

@@ -297,7 +306,8 @@ class AdamW(Optimizer):
        elif isinstance(params, set):
            raise TypeError(
                "optimizer parameters should be in ordered collections,"
                "but received set, please use list instead."
            )
        else:
            param_group['params'] = list(params)

@@ -311,11 +321,13 @@ class AdamW(Optimizer):
        if not param_set.isdisjoint(set(param_group['params'])):
            raise ValueError(
                "some parameters appear in more than one parameter group"
            )

        for param in param_group['params']:
            param.optimize_attr['learning_rate'] = param_group.get(
                'learning_rate', 1.0
            )

        self._param_groups.append(param_group)

@@ -327,19 +339,23 @@ class AdamW(Optimizer):
            var_name = param.name + "_fp32_master"
            var_name = unique_name.generate(var_name)
            var = layers.create_global_var(
                name=var_name,
                shape=param.shape,
                value=0,
                dtype='float32',
                persistable=True,
            )
            block = self.helper.startup_program.global_block()
            block.append_op(
                type="cast",
                inputs={"X": [param]},
                outputs={"Out": [var]},
                attrs={
                    "in_dtype": param.dtype,
                    "out_dtype": core.VarDesc.VarType.FP32,
                },
            )
            self._master_weights[param.name] = var
        return var
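The hunk above is the core of the multi-precision path: each FP16 (and now BF16) parameter gets an FP32 "master" copy, created once in the startup program via a cast op, and the optimizer updates that copy. Stripped of the Paddle machinery, the idea looks like the numpy sketch below; the names are made up and it is not the real implementation.

# Hedged numpy sketch of the master-weight idea: update in fp32, then write
# the result back to the low-precision copy. Names here are illustrative only.
import numpy as np

params = {"w": np.random.rand(2, 2).astype(np.float16)}        # fp16/bf16 weights
master = {k: v.astype(np.float32) for k, v in params.items()}  # fp32 masters

def sgd_step(grads, lr=0.1):
    for name, grad in grads.items():
        master[name] -= lr * grad.astype(np.float32)            # update in fp32
        params[name] = master[name].astype(params[name].dtype)  # cast back down

sgd_step({"w": np.ones((2, 2), dtype=np.float16)})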
@@ -353,20 +369,31 @@ class AdamW(Optimizer):
        """
        if self._name is not None:
            name = self._name + "_" + name
        find_master = self._multi_precision and (
            param.dtype == core.VarDesc.VarType.FP16
            or param.dtype == core.VarDesc.VarType.BF16
        )
        target_param = (
            self._master_weights[param.name] if find_master else param
        )
        target_name = target_param.name
        if (
            name not in self._accumulators
            or target_name not in self._accumulators[name]
        ):
            raise Exception(
                "Accumulator {} does not exist for parameter {}".format(
                    name, target_name
                )
            )
        return self._accumulators[name][target_name]

    def _add_moments_pows(self, p):
        acc_dtype = p.dtype
        if (
            acc_dtype == core.VarDesc.VarType.FP16
            or acc_dtype == core.VarDesc.VarType.BF16
        ):
            acc_dtype = core.VarDesc.VarType.FP32
        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype)
@@ -374,18 +401,24 @@ class AdamW(Optimizer):
            name=self._beta1_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.9
            if isinstance(self._beta1, Variable)
            else self._beta1,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )
        self._add_accumulator(
            name=self._beta2_pow_acc_str,
            param=p,
            dtype=acc_dtype,
            fill_value=0.999
            if isinstance(self._beta2, Variable)
            else self._beta2,
            shape=[1],
            type=core.VarDesc.VarType.LOD_TENSOR,
            device='cpu',
        )

    def _create_accumulators(self, block, parameters):
        assert isinstance(block, framework.Block)

@@ -394,13 +427,19 @@ class AdamW(Optimizer):
        # Create accumulator tensors for first and second moments
        for p in parameters:
            if self._multi_precision and (
                p.dtype == core.VarDesc.VarType.FP16
                or p.dtype == core.VarDesc.VarType.BF16
            ):
                master_p = self._create_master_weight(p)
                self._add_moments_pows(master_p)
                continue
            if (
                p.dtype == core.VarDesc.VarType.FP16
                or p.dtype == core.VarDesc.VarType.BF16
            ) and not self._multi_precision:
                warnings.warn(
                    "Accumulating with FP16/BFP16 in optimizer can lead to poor accuracy or slow convergence."
                    "Consider using multi_precision=True option of the Adam optimizer."
                )
            self._add_moments_pows(p)
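`_add_moments_pows` and `_create_accumulators` above apply a single rule: moment accumulators for FP16/BF16 parameters are kept in FP32 (with a warning when `multi_precision` is off). The decision reduces to a small pure function; the helper below is hypothetical and only mirrors that rule.

# Hypothetical helper mirroring the accumulator-dtype rule above.
def accumulator_dtype(param_dtype: str) -> str:
    return "float32" if param_dtype in ("float16", "bfloat16") else param_dtype

assert accumulator_dtype("bfloat16") == "float32"
assert accumulator_dtype("float16") == "float32"
assert accumulator_dtype("float64") == "float64"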
@@ -413,53 +452,112 @@ class AdamW(Optimizer):
        # Whether we should do weight decay for the parameter.
        with_decay = True
        if (
            self._apply_decay_param_fun is not None
            and not self._apply_decay_param_fun(param.name)
        ):
            with_decay = False

        moment1 = self._get_accumulator(
            self._moment1_acc_str, param_and_grad[0]
        )
        moment2 = self._get_accumulator(
            self._moment2_acc_str, param_and_grad[0]
        )
        beta1_pow_acc = self._get_accumulator(
            self._beta1_pow_acc_str, param_and_grad[0]
        )
        beta2_pow_acc = self._get_accumulator(
            self._beta2_pow_acc_str, param_and_grad[0]
        )
        find_master = self._multi_precision and (
            param_and_grad[0].dtype == core.VarDesc.VarType.FP16
            or param_and_grad[0].dtype == core.VarDesc.VarType.BF16
        )
        master_weight = (
            self._master_weights[param_and_grad[0].name]
            if find_master
            else None
        )
        lr = self._create_param_lr(param_and_grad)

        # create the adamw optimize op
        if framework._non_static_mode():
            lr_ratio_ = (
                1.0
                if self._lr_ratio is None
                else self._lr_ratio(param_and_grad[0])
            )

            _beta1 = (
                self._beta1
                if not isinstance(self._beta1, Variable)
                else self._beta1.numpy().item(0)
            )
            _beta2 = (
                self._beta2
                if not isinstance(self._beta2, Variable)
                else self._beta2.numpy().item(0)
            )

            if framework.in_dygraph_mode():
                found_inf = self._get_auxiliary_var('found_inf')
                _, _, _, _, _, _ = _C_ops.adamw_(
                    param_and_grad[0],
                    param_and_grad[1],
                    lr,
                    moment1,
                    moment2,
                    beta1_pow_acc,
                    beta2_pow_acc,
                    master_weight,
                    found_inf,
                    _beta1,
                    _beta2,
                    self._epsilon,
                    lr_ratio_,
                    self._weight_decay,
                    with_decay,
                    self._lazy_mode,
                    1000,
                    find_master,
                    False,
                )
            else:
                _, _, _, _, _, _ = _legacy_C_ops.adamw(
                    param_and_grad[0],
                    param_and_grad[1],
                    lr,
                    moment1,
                    moment2,
                    beta1_pow_acc,
                    beta2_pow_acc,
                    master_weight,
                    param_and_grad[0],
                    moment1,
                    moment2,
                    beta1_pow_acc,
                    beta2_pow_acc,
                    master_weight,
                    'epsilon',
                    self._epsilon,
                    'lazy_mode',
                    self._lazy_mode,
                    'min_row_size_to_use_multithread',
                    1000,
                    'beta1',
                    _beta1,
                    'beta2',
                    _beta2,
                    "with_decay",
                    with_decay,
                    'coeff',
                    self._weight_decay,
                    'multi_precision',
                    find_master,
                    'lr_ratio',
                    lr_ratio_,
                )

            return None

        inputs = {
@@ -486,18 +584,14 @@ class AdamW(Optimizer):
            "Beta2PowOut": [beta2_pow_acc],
        }
        attrs = {
            "lazy_mode": self._lazy_mode,
            "min_row_size_to_use_multithread": 1000,
            "multi_precision": find_master,
            "with_decay": with_decay,
            "coeff": self._weight_decay,
            "lr_ratio": 1.0
            if self._lr_ratio is None
            else self._lr_ratio(param_and_grad[0]),
        }

        if isinstance(self._beta1, Variable):

@@ -517,11 +611,13 @@ class AdamW(Optimizer):
            inputs["MasterParam"] = master_weight
            outputs["MasterParamOut"] = master_weight

        adamw_op = block.append_op(
            type=self.type,
            inputs=inputs,
            outputs=outputs,
            attrs=attrs,
            stop_gradient=True,
        )

        return adamw_op

@@ -541,7 +637,7 @@ class AdamW(Optimizer):
            .. code-block:: python

                import paddle

                a = paddle.rand([2,13], dtype="float32")
                linear = paddle.nn.Linear(13, 5)
                # This can be any optimizer supported by dygraph.
@@ -560,24 +656,28 @@ class AdamW(Optimizer):
                if param._grad_ivar() is not None:
                    grad_var = param._grad_ivar()
                    if framework.in_dygraph_mode():
                        if (
                            hasattr(grad_var, "is_selected_rows")
                            and grad_var.is_selected_rows()
                            and self.regularization is not None
                        ):
                            raise RuntimeError(
                                "AdamW don't support weight_decay with sparse parameters, please set it to None."
                            )
                    else:
                        if (
                            hasattr(grad_var, "_is_sparse")
                            and grad_var._is_sparse()
                            and self.regularization is not None
                        ):
                            raise RuntimeError(
                                "AdamW don't support weight_decay with sparse parameters, please set it to None."
                            )
                    params_grads.append((param, grad_var))

            optimize_ops = self._apply_optimize(
                loss=None, startup_program=None, params_grads=params_grads
            )
        else:
            # optimize parameters in groups
            for param_group in self._param_groups:

@@ -588,35 +688,41 @@ class AdamW(Optimizer):
                    if param._grad_ivar() is not None:
                        grad_var = param._grad_ivar()
                        if framework.in_dygraph_mode():
                            if (
                                hasattr(grad_var, "is_selected_rows")
                                and grad_var.is_selected_rows()
                                and self.regularization is not None
                            ):
                                raise RuntimeError(
                                    "AdamW don't support weight_decay with sparse parameters, please set it to None."
                                )
                        else:
                            if (
                                hasattr(grad_var, "_is_sparse")
                                and grad_var._is_sparse()
                                and self.regularization is not None
                            ):
                                raise RuntimeError(
                                    "AdamW don't support weight_decay with sparse parameters, please set it to None."
                                )
                        params_grads['params'].append((param, grad_var))
                params_grads.update(
                    {k: v for k, v in param_group.items() if k != 'params'}
                )
                self._apply_optimize(
                    loss=None, startup_program=None, params_grads=params_grads
                )

    def _update_param_group(self, parameters):
        self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
        self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
        self._epsilon = parameters.get('epsilon', self._default_dict['epsilon'])
        self._lazy_mode = parameters.get(
            'lazy_mode', self._default_dict['lazy_mode']
        )
        self._weight_decay = parameters.get(
            'weight_decay', self._default_dict['weight_decay']
        )
        parameters = parameters.get('params')
        return parameters
@@ -440,15 +440,21 @@ class Optimizer(object):
        return self._opti_name_list

    def _create_global_learning_rate(self):
        # lr var can't be float16 or bfloat16, so for pure fp16 or bf16 training, the dtype of lr needs extra handling
        _lr_dtype = (
            paddle.get_default_dtype() if self._dtype is None else self._dtype
        )
        _lr_dtype = (
            paddle.float32
            if (
                (
                    paddle.get_default_dtype() != "float16"
                    and _lr_dtype == paddle.float16
                )
                or (
                    paddle.get_default_dtype() != "bfloat16"
                    and _lr_dtype == paddle.bfloat16
                )
            )
            else _lr_dtype
        )
......
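The `_create_global_learning_rate` hunk above widens the learning-rate dtype handling: the lr variable is promoted to float32 whenever it would otherwise be fp16 or bf16 while the global default dtype is not that low-precision type. A hedged summary of the rule as a plain function, with strings standing in for Paddle dtypes and illustrative names only:

# Hedged summary of the lr-dtype rule; strings stand in for Paddle dtypes.
def lr_dtype(default_dtype, opt_dtype=None):
    d = default_dtype if opt_dtype is None else opt_dtype
    if d == "float16" and default_dtype != "float16":
        return "float32"
    if d == "bfloat16" and default_dtype != "bfloat16":
        return "float32"
    return d

assert lr_dtype("float32", "float16") == "float32"   # promoted for pure-fp16 params
assert lr_dtype("float16") == "float16"              # default already fp16, keep it
assert lr_dtype("float32", "bfloat16") == "float32"  # promoted for bf16 params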