From d1e8b1e2ca84e4e458e47da9e5a95a96bcf5f330 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Tue, 11 Apr 2023 10:36:38 +0800 Subject: [PATCH] Cherry pick for fix of operator precision. (#52705) * Fix scale kernel for low precision, cherry pick #50998. * Fix the FP16 precision problem of add_n. (#50129) * Change squared_l2_norm to reuse ReduceKernel, and register fp16 and bf16 kernel, which is cherry pick #48315. * Cherry-pick the fix of MPTypeTrait in KP, which is implemented in #50993. * Cherry-pick the multi-precision support of AdamW for bf16, #48041. * Fix compiling error. * Cherry-pick the fix of CubTensorReduceImpl for bfloat16 in #50993. * Fix unittest. --------- Co-authored-by: liuruyan <44316842+liuruyan@users.noreply.github.com> --- paddle/phi/kernels/funcs/reduce_function.h | 38 +- paddle/phi/kernels/gpu/add_n_kernel.cu | 19 +- paddle/phi/kernels/gpu/scale_kernel.cu | 23 +- .../gpu/squared_l2_norm_grad_kernel.cu | 39 +- .../phi/kernels/gpu/squared_l2_norm_kernel.cu | 31 +- .../kernels/primitive/compute_primitives.h | 6 + python/paddle/fluid/clip.py | 297 ++++++++------ .../fluid/tests/unittests/test_add_n_op.py | 64 +++ .../tests/unittests/test_gradient_clip.py | 262 ++++++++----- python/paddle/optimizer/adamw.py | 370 +++++++++++------- python/paddle/optimizer/optimizer.py | 12 +- 11 files changed, 763 insertions(+), 398 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_add_n_op.py diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index 446dfc73d5b..8f3fe513682 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -986,14 +986,16 @@ template class ReduceOp, typename TransformOp> -static typename std::enable_if::value, - void>::type -CubTensorReduceImpl(const Tx* x_data, - Ty* y_data, - const TransformOp& transform, - int reduce_num, - const KPDevice& dev_ctx, - KPStream stream) { +static + typename std::enable_if::value && + !std::is_same::value, + void>::type + CubTensorReduceImpl(const Tx* x_data, + Ty* y_data, + const TransformOp& transform, + int reduce_num, + const KPDevice& dev_ctx, + KPStream stream) { auto reducer = ReduceOp(); cub::TransformInputIterator trans_x(x_data, transform); @@ -1037,6 +1039,23 @@ CubTensorReduceImpl(const Tx* x_data, PADDLE_THROW(phi::errors::InvalidArgument( "Tx should not be float16 when using cub::DeviceReduce::Reduce().")); } + +template + class ReduceOp, + typename TransformOp> +static typename std::enable_if::value, + void>::type +CubTensorReduceImpl(const Tx* x_data, + Ty* y_data, + const TransformOp& transform, + int reduce_num, + const KPDevice& dev_ctx, + KPStream stream) { + PADDLE_THROW(phi::errors::InvalidArgument( + "Tx should not be bfloat16 when using cub::DeviceReduce::Reduce().")); +} #endif // PADDLE_WITH_XPU_KP template ::value; - bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16; + constexpr bool kIsTxBF16 = std::is_same::value; + bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16; #ifndef PADDLE_WITH_XPU_KP if (use_cub_reduce) { if (is_mean) { diff --git a/paddle/phi/kernels/gpu/add_n_kernel.cu b/paddle/phi/kernels/gpu/add_n_kernel.cu index f32ba597f5b..58c75097d45 100644 --- a/paddle/phi/kernels/gpu/add_n_kernel.cu +++ b/paddle/phi/kernels/gpu/add_n_kernel.cu @@ -14,10 +14,10 @@ #include "paddle/phi/kernels/add_n_kernel.h" -#include "paddle/phi/kernels/impl/add_n_kernel_impl.h" - #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/memcpy.h" 
+#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/kernels/impl/add_n_kernel_impl.h" namespace phi { @@ -38,16 +38,18 @@ __global__ void Sum2CUDAKernel(const T *in_0, template __global__ void SumArrayCUDAKernel( T **in, T *out, int64_t N, size_t in_size, bool read_dst) { + using MPType = typename phi::dtype::MPTypeTrait::Type; int id = blockIdx.x * blockDim.x + threadIdx.x; while (id < N) { - T total(read_dst ? out[id] : static_cast(0)); + MPType total(read_dst ? static_cast(out[id]) + : static_cast(0)); for (int i = 0; i < in_size; ++i) { const T *tmp = in[i]; if (tmp) { - total += tmp[id]; + total += static_cast(tmp[id]); } } - out[id] = total; + out[id] = static_cast(total); id += blockDim.x * gridDim.x; } } @@ -116,11 +118,12 @@ void AddNKernel(const Context &dev_ctx, int64_t length_0 = in_0.numel(); int64_t length_1 = in_1.numel(); if (length_0 && length_1 && in_0.IsInitialized() && in_1.IsInitialized()) { + using MPType = typename phi::dtype::MPTypeTrait::Type; auto result = EigenVector::Flatten(*out); auto &place = *dev_ctx.eigen_device(); - auto in_0_e = EigenVector::Flatten(in_0); - auto in_1_e = EigenVector::Flatten(in_1); - result.device(place) = in_0_e + in_1_e; + auto in_0_e = EigenVector::Flatten(in_0).template cast(); + auto in_1_e = EigenVector::Flatten(in_1).template cast(); + result.device(place) = (in_0_e + in_1_e).template cast(); } else if (length_0 && in_0.IsInitialized()) { auto result = EigenVector::Flatten(*out); auto &place = *dev_ctx.eigen_device(); diff --git a/paddle/phi/kernels/gpu/scale_kernel.cu b/paddle/phi/kernels/gpu/scale_kernel.cu index 1a574c05494..5b7ef9ba631 100644 --- a/paddle/phi/kernels/gpu/scale_kernel.cu +++ b/paddle/phi/kernels/gpu/scale_kernel.cu @@ -15,28 +15,30 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/scale_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" namespace phi { -template +template struct ScaleFunctor { - InT bias; - InT scale; + ParamT bias; + ParamT scale; bool bias_after_scale; - ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle) + ScaleFunctor(ParamT scale_data, ParamT bias_data, bool is_bias_after_sacle) : bias(bias_data), scale(scale_data), bias_after_scale(is_bias_after_sacle) {} - __device__ __forceinline__ InT operator()(const InT x) const { + __device__ __forceinline__ DataT operator()(const DataT x) const { if (bias_after_scale) { - return scale * x + bias; + return static_cast(scale * static_cast(x) + bias); } else { - return scale * (x + bias); + return static_cast(scale * (static_cast(x) + bias)); } } }; @@ -48,16 +50,21 @@ void ScaleKernel(const Context& dev_ctx, float bias, bool bias_after_scale, DenseTensor* out) { + using MT = typename phi::dtype::MPTypeTrait::Type; std::vector inputs; std::vector outputs; inputs.emplace_back(&x); outputs.emplace_back(out); dev_ctx.template Alloc(out); + if (x.numel() <= 0 || (!x.IsInitialized())) { + return; + } phi::funcs::ElementwiseKernel( dev_ctx, inputs, &outputs, - ScaleFunctor(scale.to(), static_cast(bias), bias_after_scale)); + ScaleFunctor( + scale.to(), static_cast(bias), bias_after_scale)); } } // namespace phi diff --git a/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu b/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu index 908a7557d1b..4557d44f150 100644 --- a/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/squared_l2_norm_grad_kernel.cu @@ -15,12 +15,47 @@ #include "paddle/phi/kernels/squared_l2_norm_grad_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/squared_l2_norm_grad_kernel_impl.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" + +namespace phi { +/** + * x*y*2.0 + */ +template +struct DoubleMulFunctor { + __device__ __forceinline__ T operator()(const T a, const T b) const { + return b * a * static_cast(2.0f); + } +}; + +template +void SquaredL2NormGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + DenseTensor* dx) { + dev_ctx.template Alloc(dx); + + PADDLE_ENFORCE_EQ( + dout.numel(), + 1, + phi::errors::InvalidArgument( + "Input(GRAD@Out) of SquaredL2NormGradOP should be a scalar.")); + std::vector ins{&x, &dout}; + std::vector outs{dx}; + + funcs::BroadcastKernel( + dev_ctx, ins, &outs, -1, phi::DoubleMulFunctor()); +} +} // namespace phi PD_REGISTER_KERNEL(squared_l2_norm_grad, GPU, ALL_LAYOUT, phi::SquaredL2NormGradKernel, float, - double) {} + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu b/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu index d585d209b42..1af6ad540e7 100644 --- a/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu +++ b/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu @@ -15,9 +15,34 @@ #include "paddle/phi/kernels/squared_l2_norm_kernel.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/bfloat16.h" +#include 
"paddle/phi/common/float16.h" +#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/kernels/impl/squared_l2_norm_kernel_impl.h" +#include "paddle/phi/kernels/funcs/reduce_function.h" -PD_REGISTER_KERNEL( - squared_l2_norm, GPU, ALL_LAYOUT, phi::SquaredL2NormKernel, float, double) { +namespace phi { + +template +void SquaredL2NormKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + dev_ctx.template Alloc(out); + std::vector origin_reduce_dims; + for (size_t i = 0; i < x.dims().size(); i++) { + origin_reduce_dims.push_back(i); + } + phi::funcs::ReduceKernel>( + dev_ctx, x, out, kps::SquareFunctor(), origin_reduce_dims, false); } + +} // namespace phi + +PD_REGISTER_KERNEL(squared_l2_norm, + GPU, + ALL_LAYOUT, + phi::SquaredL2NormKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/primitive/compute_primitives.h b/paddle/phi/kernels/primitive/compute_primitives.h index 2265077d51b..f46e1126c41 100644 --- a/paddle/phi/kernels/primitive/compute_primitives.h +++ b/paddle/phi/kernels/primitive/compute_primitives.h @@ -52,6 +52,12 @@ class MPTypeTrait { using Type = float; }; +template <> +class MPTypeTrait { + public: + using Type = float; +}; + /** * @brief Will be used in BlockYReduce, get the index of reduce_num in shared * memory. diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index ffd94d840c4..d803de22606 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -32,8 +32,11 @@ from .framework import default_main_program from paddle import _C_ops, _legacy_C_ops __all__ = [ - 'set_gradient_clip', 'ErrorClipByValue', 'ClipGradByValue', - 'ClipGradByNorm', 'ClipGradByGlobalNorm' + 'set_gradient_clip', + 'ErrorClipByValue', + 'ClipGradByValue', + 'ClipGradByNorm', + 'ClipGradByGlobalNorm', ] _clip_by_global_norm_using_mp_type_flag = False @@ -52,9 +55,10 @@ def _clip_by_global_norm_using_mp_type(*args): def _cast_to_mp_type_if_enabled(x): - if (x.dtype == core.VarDesc.VarType.FP16 - or x.dtype == core.VarDesc.VarType.BF16 - ) and _clip_by_global_norm_using_mp_type(): + if ( + x.dtype == core.VarDesc.VarType.FP16 + or x.dtype == core.VarDesc.VarType.BF16 + ) and _clip_by_global_norm_using_mp_type(): return x.astype(core.VarDesc.VarType.FP32) else: return x @@ -66,8 +70,7 @@ def _squared_l2_norm(x): """ x = _cast_to_mp_type_if_enabled(x) - if core.is_compiled_with_xpu( - ) or x.dtype == core.VarDesc.VarType.FP16 or x.dtype == core.VarDesc.VarType.BF16: + if core.is_compiled_with_xpu(): square = layers.square(x) sum_square = layers.reduce_sum(square) return sum_square @@ -78,7 +81,9 @@ def _squared_l2_norm(x): return _legacy_C_ops.squared_l2_norm(x) op_type = 'squared_l2_norm' - check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type) + check_variable_and_dtype( + x, 'x', ['float32', 'float64', 'float16', 'uint16'], op_type + ) helper = LayerHelper(op_type, **locals()) out = helper.create_variable_for_type_inference(x.dtype) @@ -89,7 +94,6 @@ def _squared_l2_norm(x): class BaseErrorClipAttr(object): - def __str__(self): raise NotImplementedError() @@ -164,8 +168,9 @@ def error_clip_callback(block, context): for grad_n in [n for n in op_desc.output_arg_names() if n in grad_to_var]: fwd_var = block._var_recursive(grad_to_var[grad_n]) error_clip = getattr(fwd_var, "error_clip", None) - if not (error_clip is None - or isinstance(error_clip, BaseErrorClipAttr)): + if not ( + error_clip is None or 
isinstance(error_clip, BaseErrorClipAttr) + ): raise TypeError( "Variable's error_clip should be an instance of BaseErrorClipAttr or None." ) @@ -174,7 +179,6 @@ def error_clip_callback(block, context): class ClipGradBase(object): - def __init__(self): super(ClipGradBase, self).__init__() @@ -197,7 +201,8 @@ class ClipGradBase(object): warnings.warn( "'set_gradient_clip' will be ineffective, because you have " "set 'need_clip' in 'ParamAttr'. So, 'set_gradient_clip' " - "is redundant and you can remove it.") + "is redundant and you can remove it." + ) break return self._static_clip(params_grads) @@ -211,34 +216,34 @@ class ClipGradBase(object): class ClipGradByValue(ClipGradBase): """ Limit the value of multi-dimensional Tensor :math:`X` to the range [min, max]. - + - Any values less than min are set to ``min``. - + - Any values greater than max are set to ``max``. - The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``. + The multi-dimensional Tensor :math:`X` is not passed from this class, but the gradients of all parameters set in ``optimizer``. If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped. - - Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` + + Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` (for example: :ref:`api_paddle_optimizer_SGD`). Note: - ``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0. + ``need_clip`` of ``ClipGradByValue`` HAS BEEN DEPRECATED since 2.0. Please use ``need_clip`` in ``ParamAttr`` to speficiy the clip scope. - + Args: max (float): The maximum value to clip by. - min (float, optional): The minimum value to clip by. if not set by user, it will be set to ``-max`` + min (float, optional): The minimum value to clip by. if not set by user, it will be set to ``-max`` automatically. In this case, ``max`` must be greater than 0. Examples: .. code-block:: python - + import paddle x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32') - linear = paddle.nn.Linear(in_features=10, out_features=10, - weight_attr=paddle.ParamAttr(need_clip=True), + linear = paddle.nn.Linear(in_features=10, out_features=10, + weight_attr=paddle.ParamAttr(need_clip=True), bias_attr=paddle.ParamAttr(need_clip=False)) out = linear(x) loss = paddle.mean(out) @@ -252,7 +257,7 @@ class ClipGradByValue(ClipGradBase): def __init__(self, max, min=None): super(ClipGradByValue, self).__init__() if min is None: - assert (max > 0.0) + assert max > 0.0 min = -max self.max = float(max) self.min = float(min) @@ -417,17 +422,17 @@ def _allow_pure_fp16_global_norm_clip(*args): class ClipGradByGlobalNorm(ClipGradBase): r""" - Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in + Given a list of Tensor :math:`t\_list` , calculate the global norm for the elements of all tensors in :math:`t\_list` , and limit it to ``clip_norm`` . - + - If the global norm is greater than ``clip_norm`` , all elements of :math:`t\_list` will be compressed by a ratio. - + - If the global norm is less than or equal to ``clip_norm`` , nothing will be done. - + The list of Tensor :math:`t\_list` is not passed from this class, but the gradients of all parameters set in ``optimizer``. If ``need_clip`` of specific param is ``False`` in its ``ParamAttr``, then the gradients of this param will not be clipped. 
- - Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` + + Gradient clip will takes effect after being set in ``optimizer`` , see the document ``optimizer`` (for example: :ref:`api_paddle_optimizer_SGD`). The clipping formula is: @@ -443,7 +448,7 @@ class ClipGradByGlobalNorm(ClipGradBase): global\_norm = \sqrt{\sum_{i=0}^{N-1}(l2norm(t\_list[i]))^2} Note: - ``need_clip`` of ``ClipGradyGlobalNorm`` HAS BEEN DEPRECATED since 2.0. + ``need_clip`` of ``ClipGradyGlobalNorm`` HAS BEEN DEPRECATED since 2.0. Please use ``need_clip`` in ``ParamAttr`` to speficiy the clip scope. Args: @@ -452,12 +457,12 @@ class ClipGradByGlobalNorm(ClipGradBase): Examples: .. code-block:: python - + import paddle x = paddle.uniform([10, 10], min=-1.0, max=1.0, dtype='float32') - linear = paddle.nn.Linear(in_features=10, out_features=10, - weight_attr=paddle.ParamAttr(need_clip=True), + linear = paddle.nn.Linear(in_features=10, out_features=10, + weight_attr=paddle.ParamAttr(need_clip=True), bias_attr=paddle.ParamAttr(need_clip=False)) out = linear(x) loss = paddle.mean(out) @@ -468,10 +473,9 @@ class ClipGradByGlobalNorm(ClipGradBase): sdg.step() """ - def __init__(self, - clip_norm, - group_name="default_group", - auto_skip_clip=False): + def __init__( + self, clip_norm, group_name="default_group", auto_skip_clip=False + ): super(ClipGradByGlobalNorm, self).__init__() self.clip_norm = float(clip_norm) self.group_name = group_name @@ -503,7 +507,10 @@ class ClipGradByGlobalNorm(ClipGradBase): merge_grad = layers.get_tensor_from_selected_rows(merge_grad) sum_square = _squared_l2_norm(merge_grad) - if sum_square.dtype == core.VarDesc.VarType.FP16 or sum_square.dtype == core.VarDesc.VarType.BF16: + if ( + sum_square.dtype == core.VarDesc.VarType.FP16 + or sum_square.dtype == core.VarDesc.VarType.BF16 + ): sum_square_list_fp16.append(sum_square) elif sum_square.dtype == core.VarDesc.VarType.FP32: sum_square_list_fp32.append(sum_square) @@ -511,8 +518,12 @@ class ClipGradByGlobalNorm(ClipGradBase): sum_square_list.append(sum_square) # all parameters have been filterd out - if len(sum_square_list) + len(sum_square_list_fp16) + len( - sum_square_list_fp32) == 0: + if ( + len(sum_square_list) + + len(sum_square_list_fp16) + + len(sum_square_list_fp32) + == 0 + ): return params_grads sum_dtype = 'float64' if len(sum_square_list) > 0 else "float32" @@ -531,22 +542,23 @@ class ClipGradByGlobalNorm(ClipGradBase): global_norm_var.append(global_norm_var_fp64) global_norm_var = paddle.add_n(global_norm_var) global_norm_var = layers.sqrt(global_norm_var) - max_global_norm = layers.fill_constant(shape=[1], - dtype=global_norm_var.dtype, - value=self.clip_norm) + max_global_norm = layers.fill_constant( + shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm + ) need_clip = False if not self.auto_skip_clip: # always apply clip need_clip = True - clip_var = layers.elementwise_div(x=max_global_norm, - y=layers.elementwise_max( - x=global_norm_var, - y=max_global_norm)) + clip_var = layers.elementwise_div( + x=max_global_norm, + y=layers.elementwise_max(x=global_norm_var, y=max_global_norm), + ) elif global_norm_var > max_global_norm: # only when global_norm_var > max_global_norm, grad need clip need_clip = True - clip_var = layers.elementwise_div(x=max_global_norm, - y=global_norm_var) + clip_var = layers.elementwise_div( + x=max_global_norm, y=global_norm_var + ) for p, g in params_grads: if g is None: @@ -556,8 +568,11 @@ class ClipGradByGlobalNorm(ClipGradBase): continue # 
TODO(wangxi): use inplace elementwise_mul if need_clip: - clip_input = (clip_var.astype(g.dtype) - if clip_var.dtype != g.dtype else clip_var) + clip_input = ( + clip_var.astype(g.dtype) + if clip_var.dtype != g.dtype + else clip_var + ) new_grad = layers.elementwise_mul(g, clip_input) params_and_grads.append((p, new_grad)) else: @@ -581,7 +596,8 @@ class ClipGradByGlobalNorm(ClipGradBase): if g.type == core.VarDesc.VarType.SELECTED_ROWS: merge_grad = layers.merge_selected_rows(g) merge_grad = layers.get_tensor_from_selected_rows( - merge_grad) + merge_grad + ) sum_square = _squared_l2_norm(merge_grad) if sum_square.dtype == core.VarDesc.VarType.FP16: sum_square_list_fp16.append(sum_square) @@ -591,8 +607,12 @@ class ClipGradByGlobalNorm(ClipGradBase): sum_square_list.append(sum_square) # all parameters have been filterd out - if len(sum_square_list) + len(sum_square_list_fp16) + len( - sum_square_list_fp32) == 0: + if ( + len(sum_square_list) + + len(sum_square_list_fp16) + + len(sum_square_list_fp32) + == 0 + ): return params_grads with p.block.program._optimized_guard([p, g]): @@ -601,10 +621,14 @@ class ClipGradByGlobalNorm(ClipGradBase): global_norm_var = [] if len(sum_square_list_fp16) > 0: global_norm_var_fp16 = layers.sums(sum_square_list_fp16) - if sum_square_list_fp32 or sum_square_list or not _allow_pure_fp16_global_norm_clip( + if ( + sum_square_list_fp32 + or sum_square_list + or not _allow_pure_fp16_global_norm_clip() ): global_norm_var.append( - global_norm_var_fp16.astype(sum_dtype)) + global_norm_var_fp16.astype(sum_dtype) + ) else: global_norm_var.append(global_norm_var_fp16) if len(sum_square_list_fp32) > 0: @@ -613,23 +637,28 @@ class ClipGradByGlobalNorm(ClipGradBase): global_norm_var.append(global_norm_var_fp32) else: global_norm_var.append( - global_norm_var_fp32.astype(sum_dtype)) + global_norm_var_fp32.astype(sum_dtype) + ) if len(sum_square_list) > 0: # fp64 global_norm_var_other_dtype = layers.sums(sum_square_list) global_norm_var.append(global_norm_var_other_dtype) - global_norm_var = layers.sums(global_norm_var) if len( - global_norm_var) > 1 else global_norm_var[0] + global_norm_var = ( + layers.sums(global_norm_var) + if len(global_norm_var) > 1 + else global_norm_var[0] + ) global_norm_var = layers.sqrt(x=global_norm_var) max_global_norm = layers.fill_constant( - shape=[1], - dtype=global_norm_var.dtype, - value=self.clip_norm) - scale_var = layers.elementwise_div(x=max_global_norm, - y=layers.elementwise_max( - x=max_global_norm, - y=global_norm_var)) + shape=[1], dtype=global_norm_var.dtype, value=self.clip_norm + ) + scale_var = layers.elementwise_div( + x=max_global_norm, + y=layers.elementwise_max( + x=max_global_norm, y=global_norm_var + ), + ) param_new_grad_name_dict = dict() for p, g in params_grads: if g is None: @@ -641,29 +670,32 @@ class ClipGradByGlobalNorm(ClipGradBase): with p.block.program._optimized_guard([p, g]): new_g = _cast_to_mp_type_if_enabled(g) # inplace - scale_input = (scale_var.astype('float16') if - new_g.dtype == core.VarDesc.VarType.FP16 and - scale_var.dtype != core.VarDesc.VarType.FP16 - else scale_var) + scale_input = ( + scale_var.astype('float16') + if new_g.dtype == core.VarDesc.VarType.FP16 + and scale_var.dtype != core.VarDesc.VarType.FP16 + else scale_var + ) # NOTE(Yuang Liu): For pure dp with gradient merge, the p and g # will be in different blocks with the gradient clip related ops. # We need to handle the correct block, otherwise will encounter # a 'NotFoundError' during compile time. 
block = default_main_program().current_block() - block.append_op(type='elementwise_mul', - inputs={ - 'X': new_g, - 'Y': scale_input - }, - outputs={'Out': new_g}) + block.append_op( + type='elementwise_mul', + inputs={'X': new_g, 'Y': scale_input}, + outputs={'Out': new_g}, + ) if new_g is not g: - block.append_op(type='cast', - inputs={'X': new_g}, - outputs={'Out': g}, - attrs={ - 'in_dtype': new_g.dtype, - 'out_dtype': g.dtype - }) + block.append_op( + type='cast', + inputs={'X': new_g}, + outputs={'Out': g}, + attrs={ + 'in_dtype': new_g.dtype, + 'out_dtype': g.dtype, + }, + ) param_new_grad_name_dict[p.name] = g.name params_and_grads.append((p, g)) @@ -676,7 +708,8 @@ class ClipGradByGlobalNorm(ClipGradBase): context[self.group_name] = [] context[self.group_name + "_clip_value"] = self.clip_norm context[self.group_name + "_clip"] = layers.fill_constant( - shape=[1], dtype=grad.dtype, value=self.clip_norm) + shape=[1], dtype=grad.dtype, value=self.clip_norm + ) else: if not self.clip_norm == context[self.group_name + "_clip_value"]: raise ValueError( @@ -699,20 +732,19 @@ class ClipGradByGlobalNorm(ClipGradBase): group_norm_var = layers.sums(input=self.context[self.group_name]) group_norm_var = layers.sqrt(x=group_norm_var) clip_var = self.context[self.group_name + "_clip"] - group_scale_var = layers.elementwise_div(x=clip_var, - y=layers.elementwise_max( - x=clip_var, - y=group_norm_var)) - assert group_scale_var.shape == (1, ) + group_scale_var = layers.elementwise_div( + x=clip_var, + y=layers.elementwise_max(x=clip_var, y=group_norm_var), + ) + assert group_scale_var.shape == (1,) self.context[group_scale_name] = group_scale_var # inplace - param.block.append_op(type='elementwise_mul', - inputs={ - 'X': grad, - 'Y': self.context[group_scale_name] - }, - outputs={'Out': grad}) + param.block.append_op( + type='elementwise_mul', + inputs={'X': grad, 'Y': self.context[group_scale_name]}, + outputs={'Out': grad}, + ) return param, grad @@ -721,23 +753,23 @@ class ClipGradByGlobalNorm(ClipGradBase): def set_gradient_clip(clip, param_list=None, program=None): """ :api_attr: Static Graph - + Warning: - - This API must be used after building network, and before ``minimize`` , - and it may be removed in future releases, so it is not recommended. + + This API must be used after building network, and before ``minimize`` , + and it may be removed in future releases, so it is not recommended. It is recommended to set ``grad_clip`` when initializing the ``optimizer`` , this is a better method to clip gradient. There are three clipping strategies: - :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , :ref:`api_fluid_clip_GradientClipByValue` . - + To specify parameters that require gradient clip. Args: - grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of - some derived class of ``GradientClipBase`` . There are three cliping strategies - ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , - :ref:`api_fluid_clip_GradientClipByValue` ). Default value: None, and there is no + grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of + some derived class of ``GradientClipBase`` . There are three cliping strategies + ( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` , + :ref:`api_fluid_clip_GradientClipByValue` ). 
Default value: None, and there is no gradient clipping. param_list (list(Variable), optional): Parameters that require gradient clip. It can be a list of parameter or a list of parameter's name. @@ -791,7 +823,7 @@ def set_gradient_clip(clip, param_list=None, program=None): param_list=[param_var1, param_var2]) sgd = fluid.optimizer.SGD(learning_rate=1e-3) sgd.minimize(loss) - + # network 4: use 'set_gradient_clip' and 'optimize(grad_clip=clip)' together with fluid.program_guard(fluid.Program(), fluid.Program()): loss = network() @@ -802,27 +834,31 @@ def set_gradient_clip(clip, param_list=None, program=None): # Set the gradient clipping strategy: clip2 sgd = fluid.optimizer.SGD(learning_rate=1e-3, grad_clip=clip2) sgd.minimize(loss) - # 'set_gradient_clip' will not take effect when setting has a conflict, + # 'set_gradient_clip' will not take effect when setting has a conflict, # and the gradient clipping strategy will be 'clip2' - - + + """ - warnings.warn("Caution! 'set_gradient_clip' is not recommended " - "and may be deprecated in future! " - "We recommend a new strategy: set 'grad_clip' " - "when initializing the 'optimizer'. " - "This method can reduce the mistakes, please " - "refer to documention of 'optimizer'.") + warnings.warn( + "Caution! 'set_gradient_clip' is not recommended " + "and may be deprecated in future! " + "We recommend a new strategy: set 'grad_clip' " + "when initializing the 'optimizer'. " + "This method can reduce the mistakes, please " + "refer to documention of 'optimizer'." + ) if not isinstance(clip, ClipGradBase): raise TypeError( - "'clip' should be an instance of ClipGradBase's derived class") + "'clip' should be an instance of ClipGradBase's derived class" + ) if program is None: program = framework.default_main_program() for op in program.block(0).ops: if 'op_namescope' in op.all_attrs() and "optimizer" in op.attr( - "op_namescope"): + "op_namescope" + ): warnings.warn( "'minimize' has been invoked before, this will make 'set_gradient_clip' " "be ineffective! Please invoke 'set_gradient_clip' before 'minimize'." 
@@ -847,14 +883,16 @@ def append_gradient_clip_ops(param_grads): for p, g in param_grads: if g is None: continue - with p.block.program._optimized_guard( - [p, g]), framework.name_scope('gradient_clip'): + with p.block.program._optimized_guard([p, g]), framework.name_scope( + 'gradient_clip' + ): clip_attr = getattr(p, 'gradient_clip_attr', None) if clip_attr is None: return param_grads if not isinstance(clip_attr, ClipGradBase): raise TypeError( - "clip attribute should be an instance of GradientClipBase") + "clip attribute should be an instance of GradientClipBase" + ) clip_attr._process_context(context=context, param=p, grad=g) @@ -863,8 +901,9 @@ def append_gradient_clip_ops(param_grads): for p, g in param_grads: if g is None: continue - with p.block.program._optimized_guard( - [p, g]), framework.name_scope('gradient_clip'): + with p.block.program._optimized_guard([p, g]), framework.name_scope( + 'gradient_clip' + ): param, new_grad = clip_attr._create_operators(param=p, grad=g) param_new_grad_name_dict[param.name] = new_grad.name res.append([param, new_grad]) @@ -888,12 +927,16 @@ def _correct_clip_op_role_var(params_grads, param_new_grad_name_dict): continue block_id_list.append(block_id) for op in param.block.program.global_block().ops: - if op.has_attr("op_namescope") and "gradient_clip" in op.attr( - "op_namescope") and op.attr('op_role_var'): + if ( + op.has_attr("op_namescope") + and "gradient_clip" in op.attr("op_namescope") + and op.attr('op_role_var') + ): param_name = op.attr('op_role_var')[0] if param_name in param_new_grad_name_dict: correct_p_g = [ - param_name, param_new_grad_name_dict[param_name] + param_name, + param_new_grad_name_dict[param_name], ] op._set_attr('op_role_var', correct_p_g) diff --git a/python/paddle/fluid/tests/unittests/test_add_n_op.py b/python/paddle/fluid/tests/unittests/test_add_n_op.py new file mode 100644 index 00000000000..3ca485b1419 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_add_n_op.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import unittest + +import numpy as np + +import paddle + + +class TestAddnOp(unittest.TestCase): + def setUp(self): + np.random.seed(20) + l = 32 + self.x_np = np.random.random([l, 16, 256]) + + def check_main(self, x_np, dtype, axis=None): + paddle.disable_static() + x = [] + for i in range(x_np.shape[0]): + val = paddle.to_tensor(x_np[i].astype(dtype)) + val.stop_gradient = False + x.append(val) + y = paddle.add_n(x) + x_g = paddle.grad(y, x) + y_np = y.numpy().astype('float32') + x_g_np = [] + for val in x_g: + x_g_np.append(val.numpy().astype('float32')) + paddle.enable_static() + return y_np, x_g_np + + def test_add_n_fp16(self): + if not paddle.is_compiled_with_cuda(): + return + y_np_16, x_g_np_16 = self.check_main(self.x_np, 'float16') + y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32') + + np.testing.assert_allclose(y_np_16, y_np_32, rtol=1e-03) + for i in range(len(x_g_np_32)): + np.testing.assert_allclose(x_g_np_16[i], x_g_np_32[i], rtol=1e-03) + + def test_add_n_api(self): + if not paddle.is_compiled_with_cuda(): + return + + y_np_32, x_g_np_32 = self.check_main(self.x_np, 'float32') + y_np_gt = np.sum(self.x_np, axis=0).astype('float32') + + np.testing.assert_allclose(y_np_32, y_np_gt, rtol=1e-06) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_gradient_clip.py b/python/paddle/fluid/tests/unittests/test_gradient_clip.py index 7140148119e..128ffee00b3 100644 --- a/python/paddle/fluid/tests/unittests/test_gradient_clip.py +++ b/python/paddle/fluid/tests/unittests/test_gradient_clip.py @@ -26,21 +26,17 @@ from paddle.fluid.clip import _allow_pure_fp16_global_norm_clip paddle.enable_static() -def bow_net(data, - label, - dict_dim, - emb_dim=128, - hid_dim=128, - hid_dim2=96, - class_dim=2): +def bow_net( + data, label, dict_dim, emb_dim=128, hid_dim=128, hid_dim2=96, class_dim=2 +): """ BOW net This model is from https://github.com/PaddlePaddle/models: fluid/PaddleNLP/text_classification/nets.py """ - emb = fluid.layers.embedding(input=data, - is_sparse=True, - size=[dict_dim, emb_dim]) + emb = fluid.layers.embedding( + input=data, is_sparse=True, size=[dict_dim, emb_dim] + ) bow = fluid.layers.sequence_pool(input=emb, pool_type='sum') bow_tanh = fluid.layers.tanh(bow) fc_1 = fluid.layers.fc(input=bow_tanh, size=hid_dim, act="tanh") @@ -53,7 +49,6 @@ def bow_net(data, class TestGradientClip(unittest.TestCase): - def setUp(self): self.word_dict_len = 5147 self.BATCH_SIZE = 2 @@ -77,8 +72,9 @@ class TestGradientClip(unittest.TestCase): def check_gradient_clip(self, place, dtype='float32'): prog = fluid.Program() startup_program = fluid.Program() - with fluid.program_guard(main_program=prog, - startup_program=startup_program): + with fluid.program_guard( + main_program=prog, startup_program=startup_program + ): image = fluid.data(name="a", shape=[-1, 784], dtype='float32') label = fluid.data(name="b", shape=[-1, 1], dtype='int64') if dtype != 'float32': @@ -99,8 +95,9 @@ class TestGradientClip(unittest.TestCase): p_g = sorted(p_g, key=lambda x: x[0].name) p_g_clip = sorted(p_g_clip, key=lambda x: x[0].name) - with fluid.program_guard(main_program=prog_clip, - startup_program=startup_program): + with fluid.program_guard( + main_program=prog_clip, startup_program=startup_program + ): p_g_clip = self.clip_gradient(p_g_clip) grad_list = [elem[1] for elem in p_g] @@ -113,20 +110,20 @@ class TestGradientClip(unittest.TestCase): data = next(train_reader()) out = exe.run(prog, feed=feeder.feed(data), fetch_list=grad_list) - 
out_clip = exe.run(prog_clip, - feed=feeder.feed(data), - fetch_list=grad_clip_list) + out_clip = exe.run( + prog_clip, feed=feeder.feed(data), fetch_list=grad_clip_list + ) self.check_clip_result(out, out_clip) def check_sparse_gradient_clip(self, place): prog = fluid.Program() startup_program = fluid.Program() - with fluid.program_guard(main_program=prog, - startup_program=startup_program): - data = fluid.data(name="words", - shape=[-1, 1], - dtype="int64", - lod_level=1) + with fluid.program_guard( + main_program=prog, startup_program=startup_program + ): + data = fluid.data( + name="words", shape=[-1, 1], dtype="int64", lod_level=1 + ) label = fluid.data(name="label", shape=[-1, 1], dtype="int64") cost = bow_net(data, label, self.word_dict_len) @@ -138,7 +135,7 @@ class TestGradientClip(unittest.TestCase): data = next(self.train_data()) val = exe.run(prog, feed=feeder.feed(data), fetch_list=[cost])[0] - self.assertEqual((1, ), val.shape) + self.assertEqual((1,), val.shape) self.assertFalse(np.isnan(val)) def backward_and_optimize(self, cost): @@ -146,7 +143,6 @@ class TestGradientClip(unittest.TestCase): class TestGradientClipByGlobalNorm(TestGradientClip): - def init(self): self.clip_norm = 0.2 @@ -166,13 +162,13 @@ class TestGradientClipByGlobalNorm(TestGradientClip): v, rtol=1e-05, atol=1e-08, - err_msg= - 'gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}' - .format(u, v, u - v)) + err_msg='gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}'.format( + u, v, u - v + ), + ) # test whether the output is right when use 'set_gradient_clip' def test_old_gradient_clip(self): - def func(params_grads): clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm) fluid.clip.set_gradient_clip(clip) @@ -183,7 +179,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip): # test whether the output is right when use grad_clip def test_new_gradient_clip(self): - def func(params_grads): clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm) return clip(params_grads) @@ -193,7 +188,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip): # test whether the output is right when use grad_clip under float64 def test_new_gradient_clip_fp64(self): - def func(params_grads): clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=self.clip_norm) return clip(params_grads) @@ -203,12 +197,12 @@ class TestGradientClipByGlobalNorm(TestGradientClip): # invoke 'set_gradient_clip' in a wrong order def test_wrong_API_order(self): - def backward_func(cost): clip = fluid.clip.GradientClipByGlobalNorm(clip_norm=5.0) fluid.clip.set_gradient_clip(clip) - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01, - grad_clip=clip) + sgd_optimizer = fluid.optimizer.SGD( + learning_rate=0.01, grad_clip=clip + ) # if 'set_gradient_clip' and 'optimize(grad_clip)' together, 'set_gradient_clip' will be ineffective sgd_optimizer.minimize(cost) # 'set_gradient_clip' must before 'minimize', otherwise, 'set_gradient_clip' will be ineffective @@ -222,43 +216,72 @@ class TestGradientClipByGlobalNorm(TestGradientClip): def test_tpyeError(self): # the type of optimizer(grad_clip=) must be an instance of GradientClipBase's derived class with self.assertRaises(TypeError): - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1, - grad_clip="test") + sgd_optimizer = fluid.optimizer.SGD( + learning_rate=0.1, grad_clip="test" + ) # if grad is None or not need clip def test_none_grad_fp32(self): ops = self._test_none_grad_helper("float32") - self.assertListEqual(ops, [ - 
'squared_l2_norm', 'squared_l2_norm', 'sum', 'sqrt', - 'fill_constant', 'elementwise_max', 'elementwise_div', - 'elementwise_mul', 'elementwise_mul' - ]) + self.assertListEqual( + ops, + [ + 'squared_l2_norm', + 'squared_l2_norm', + 'sum', + 'sqrt', + 'fill_constant', + 'elementwise_max', + 'elementwise_div', + 'elementwise_mul', + 'elementwise_mul', + ], + ) def test_none_grad_fp16(self): ops = self._test_none_grad_helper("float16") - self.assertListEqual(ops, [ - 'square', 'reduce_sum', 'square', 'reduce_sum', 'sum', 'cast', - 'sqrt', 'fill_constant', 'elementwise_max', 'elementwise_div', - 'cast', 'elementwise_mul', 'cast', 'elementwise_mul' - ]) + self.assertListEqual( + ops, + [ + 'squared_l2_norm', + 'squared_l2_norm', + 'sum', + 'cast', + 'sqrt', + 'fill_constant', + 'elementwise_max', + 'elementwise_div', + 'cast', + 'elementwise_mul', + 'cast', + 'elementwise_mul', + ], + ) def _test_none_grad_helper(self, dtype): prog = fluid.Program() startup_program = fluid.Program() - with fluid.program_guard(main_program=prog, - startup_program=startup_program): + with fluid.program_guard( + main_program=prog, startup_program=startup_program + ): clip = fluid.clip.GradientClipByGlobalNorm(self.clip_norm) - x = fluid.default_main_program().global_block().create_parameter( - name="x", shape=[2, 3], dtype=dtype) - y = fluid.default_main_program().global_block().create_parameter( - name="y", shape=[2, 3], dtype=dtype) + x = ( + fluid.default_main_program() + .global_block() + .create_parameter(name="x", shape=[2, 3], dtype=dtype) + ) + y = ( + fluid.default_main_program() + .global_block() + .create_parameter(name="y", shape=[2, 3], dtype=dtype) + ) # (x, None) should not be returned params_grads = [(x, None), (x, y), (y, x)] params_grads = clip(params_grads) self.assertTrue( len(params_grads) == 2, - "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!" 
+ "ClipByGlobalNorm: when grad is None, it shouldn't be returned by gradient clip!", ) ops = [op.type for op in x.block.ops] @@ -266,7 +289,6 @@ class TestGradientClipByGlobalNorm(TestGradientClip): class TestGradientClipByNorm(TestGradientClip): - def init(self): self.clip_norm = 0.2 @@ -280,11 +302,11 @@ class TestGradientClipByNorm(TestGradientClip): v, rtol=1e-05, atol=1e-08, - err_msg='gradient clip by norm has wrong results!') + err_msg='gradient clip by norm has wrong results!', + ) # test whether the output is right when use grad_clip def test_gradient_clip(self): - def func(params_grads): clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm) return clip(params_grads) @@ -295,25 +317,35 @@ class TestGradientClipByNorm(TestGradientClip): # if grad is None or not need clip def test_none_grad(self): clip = fluid.clip.GradientClipByNorm(self.clip_norm) - x = fluid.default_main_program().global_block().create_parameter( - name="x", shape=[2, 3], dtype="float32", need_clip=False) - y = fluid.default_main_program().global_block().create_parameter( - name="y", shape=[2, 3], dtype="float32", need_clip=False) + x = ( + fluid.default_main_program() + .global_block() + .create_parameter( + name="x", shape=[2, 3], dtype="float32", need_clip=False + ) + ) + y = ( + fluid.default_main_program() + .global_block() + .create_parameter( + name="y", shape=[2, 3], dtype="float32", need_clip=False + ) + ) # (x, None) should not be returned params_grads = [(x, None), (x, y)] params_grads = clip(params_grads) self.assertTrue( len(clip(params_grads)) == 1, - "ClipGradByNorm: when grad is None, it shouldn't be returned by gradient clip!" + "ClipGradByNorm: when grad is None, it shouldn't be returned by gradient clip!", ) self.assertTrue( params_grads[0][1].name == 'y', - "ClipGradByNorm: grad should not be clipped when filtered out!") + "ClipGradByNorm: grad should not be clipped when filtered out!", + ) class TestGradientClipByValue(TestGradientClip): - def init(self): self.max = 0.2 self.min = 0.1 @@ -328,11 +360,11 @@ class TestGradientClipByValue(TestGradientClip): v, rtol=1e-06, atol=1e-08, - err_msg='gradient clip by value has wrong results!') + err_msg='gradient clip by value has wrong results!', + ) # test whether the output is right when use grad_clip def test_gradient_clip(self): - def func(params_grads): clip = fluid.clip.GradientClipByValue(max=self.max, min=self.min) return clip(params_grads) @@ -343,37 +375,49 @@ class TestGradientClipByValue(TestGradientClip): # if grad is None or not need clip def test_none_grad(self): clip = fluid.clip.GradientClipByValue(self.max, self.min) - x = fluid.default_main_program().global_block().create_parameter( - name="x", shape=[2, 3], dtype="float32", need_clip=False) - y = fluid.default_main_program().global_block().create_parameter( - name="y", shape=[2, 3], dtype="float32", need_clip=False) + x = ( + fluid.default_main_program() + .global_block() + .create_parameter( + name="x", shape=[2, 3], dtype="float32", need_clip=False + ) + ) + y = ( + fluid.default_main_program() + .global_block() + .create_parameter( + name="y", shape=[2, 3], dtype="float32", need_clip=False + ) + ) # (x, None) should not be returned params_grads = [(x, None), (x, y)] params_grads = clip(params_grads) self.assertTrue( len(clip(params_grads)) == 1, - "ClipGradByValue: when grad is None, it shouldn't be returned by gradient clip!" 
+ "ClipGradByValue: when grad is None, it shouldn't be returned by gradient clip!", ) self.assertTrue( params_grads[0][1].name == 'y', - "ClipGradByValue: grad should not be clipped when filtered out!") + "ClipGradByValue: grad should not be clipped when filtered out!", + ) class TestDygraphGradientClip(unittest.TestCase): - def test_gradient_clip(self): with fluid.dygraph.guard(): linear = fluid.dygraph.Linear(5, 5) - inputs = fluid.layers.uniform_random([16, 5], min=-10, - max=10).astype('float32') + inputs = fluid.layers.uniform_random( + [16, 5], min=-10, max=10 + ).astype('float32') out = linear(fluid.dygraph.to_variable(inputs)) loss = fluid.layers.reduce_mean(out) loss.backward() sgd_optimizer = fluid.optimizer.SGD( learning_rate=0.0, parameter_list=linear.parameters(), - grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1)) + grad_clip=fluid.clip.GradientClipByGlobalNorm(0.1), + ) self.check_clip_result(loss, sgd_optimizer) def check_clip_result(self, loss, optimizer): @@ -381,20 +425,23 @@ class TestDygraphGradientClip(unittest.TestCase): class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip): - def setUp(self): self.clip_norm = 0.8 self.clip1 = fluid.clip.GradientClipByGlobalNorm( - clip_norm=self.clip_norm) + clip_norm=self.clip_norm + ) self.clip2 = fluid.clip.GradientClipByGlobalNorm( - clip_norm=self.clip_norm) + clip_norm=self.clip_norm + ) def check_clip_result(self, loss, optimizer): # if grad is None - x = fluid.dygraph.to_variable(np.array([2, 3]).astype("float32"), - name="x") - y = fluid.dygraph.to_variable(np.array([3, 4]).astype("float32"), - name="y") + x = fluid.dygraph.to_variable( + np.array([2, 3]).astype("float32"), name="x" + ) + y = fluid.dygraph.to_variable( + np.array([3, 4]).astype("float32"), name="y" + ) assert len(self.clip1([(x, x), (x, y), (x, None)])) == 2 # get params and grads from network opt, params_grads = optimizer.minimize(loss) @@ -419,11 +466,11 @@ class TestDygraphGradientClipByGlobalNorm(TestDygraphGradientClip): self.assertTrue( np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8), "gradient clip by global norm has wrong results, expetcd:%f, but received:%f" - % (a, b)) + % (a, b), + ) class TestDygraphGradientClipByNorm(TestDygraphGradientClip): - def setUp(self): self.clip_norm = 0.8 self.clip = fluid.clip.GradientClipByNorm(clip_norm=self.clip_norm) @@ -448,11 +495,11 @@ class TestDygraphGradientClipByNorm(TestDygraphGradientClip): self.assertTrue( np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8), "gradient clip by norm has wrong results, expetcd:%f, but received:%f" - % (a, b)) + % (a, b), + ) class TestDygraphGradientClipByValue(TestDygraphGradientClip): - def setUp(self): self.max = 0.2 self.min = 0.1 @@ -475,11 +522,11 @@ class TestDygraphGradientClipByValue(TestDygraphGradientClip): v, rtol=1e-06, atol=1e-08, - err_msg='gradient clip by value has wrong results!') + err_msg='gradient clip by value has wrong results!', + ) class SimpleNet(paddle.nn.Layer): - def __init__(self): super(SimpleNet, self).__init__() self.linear = paddle.nn.Linear(5, 5) @@ -492,19 +539,21 @@ class SimpleNet(paddle.nn.Layer): class TestDygraphGradientClipFP16(unittest.TestCase): - def test_gradient_clip(self): if fluid.core.is_compiled_with_cuda(): with fluid.dygraph.guard(): paddle.seed(10) model = SimpleNet() sgd_optimizer = paddle.optimizer.SGD( - learning_rate=0.0, parameters=model.parameters()) + learning_rate=0.0, parameters=model.parameters() + ) model, sgd_optimizer = paddle.amp.decorate( - models=model, optimizers=sgd_optimizer, level='O2') + 
models=model, optimizers=sgd_optimizer, level='O2' + ) scaler = paddle.amp.GradScaler(init_loss_scaling=1024) - inputs = fluid.layers.uniform_random([1, 5], min=-10, - max=10).astype('float32') + inputs = fluid.layers.uniform_random( + [1, 5], min=-10, max=10 + ).astype('float32') with paddle.amp.auto_cast(level='O2'): out = model(fluid.dygraph.to_variable(inputs)) loss = fluid.layers.reduce_mean(out) @@ -543,15 +592,16 @@ class TestDygraphGradientClipFP16(unittest.TestCase): self.assertTrue( np.isclose(a=a, b=b, rtol=1e-3, atol=1e-8), "gradient clip by global norm has wrong results, expetcd:%f, but received:%f" - % (a, b)) + % (a, b), + ) class TestDygraphGradientClipFP64(unittest.TestCase): - def test_gradient_clip(self): with fluid.dygraph.guard(): - inputs = fluid.layers.uniform_random([16, 5], min=-10, - max=10).astype('float64') + inputs = fluid.layers.uniform_random( + [16, 5], min=-10, max=10 + ).astype('float64') linear = fluid.dygraph.Linear(5, 5, dtype="float64") out = linear(fluid.dygraph.to_variable(inputs)) loss = fluid.layers.reduce_mean(out) @@ -589,11 +639,11 @@ class TestDygraphGradientClipFP64(unittest.TestCase): self.assertTrue( np.isclose(a=a, b=b, rtol=1e-6, atol=1e-8), "gradient clip by global norm has wrong results, expetcd:%f, but received:%f" - % (a, b)) + % (a, b), + ) class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase): - def check_main(self, expected_has_cast_op): main_prog = paddle.static.Program() startup_prog = paddle.static.Program() @@ -604,12 +654,12 @@ class TestPureFP16ClipGradByGlobalNorm(unittest.TestCase): param_and_grads = [] main_block = main_prog.global_block() for name, shape in zip(names, shapes): - p = main_block.create_parameter(name=name, - shape=shape, - dtype='float16') - g = main_block.create_parameter(name=p.name + '@GRAD', - shape=p.shape, - dtype=p.dtype) + p = main_block.create_parameter( + name=name, shape=shape, dtype='float16' + ) + g = main_block.create_parameter( + name=p.name + '@GRAD', shape=p.shape, dtype=p.dtype + ) param_and_grads.append((p, g)) clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) diff --git a/python/paddle/optimizer/adamw.py b/python/paddle/optimizer/adamw.py index fbe23c84a2a..4438abde9ca 100644 --- a/python/paddle/optimizer/adamw.py +++ b/python/paddle/optimizer/adamw.py @@ -148,19 +148,21 @@ class AdamW(Optimizer): _beta1_pow_acc_str = "beta1_pow_acc" _beta2_pow_acc_str = "beta2_pow_acc" - def __init__(self, - learning_rate=0.001, - beta1=0.9, - beta2=0.999, - epsilon=1e-8, - parameters=None, - weight_decay=0.01, - lr_ratio=None, - apply_decay_param_fun=None, - grad_clip=None, - lazy_mode=False, - multi_precision=False, - name=None): + def __init__( + self, + learning_rate=0.001, + beta1=0.9, + beta2=0.999, + epsilon=1e-8, + parameters=None, + weight_decay=0.01, + lr_ratio=None, + apply_decay_param_fun=None, + grad_clip=None, + lazy_mode=False, + multi_precision=False, + name=None, + ): assert learning_rate is not None assert beta1 is not None assert beta2 is not None @@ -171,14 +173,16 @@ class AdamW(Optimizer): raise ValueError("Invaild value of beta2, expect beta2 in [0,1).") if not 0 <= epsilon: raise ValueError("Invaild value of epsilon, expect epsilon >= 0.") - if not isinstance(weight_decay, float) and \ - not isinstance(weight_decay, framework.Variable): + if not isinstance(weight_decay, float) and not isinstance( + weight_decay, framework.Variable + ): raise TypeError("weight_decay should be float or Tensor.") if lr_ratio is not None: assert isinstance(lr_ratio, Callable) if not 
core.is_compiled_with_cuda(): raise NotImplementedError( - "'lr_ratio' is unimplemented in CPU, XPU and NPU") + "'lr_ratio' is unimplemented in CPU, XPU and NPU" + ) if parameters is not None: # paddle.Tensor is also iterable, so here we don't check whether @@ -187,13 +191,16 @@ class AdamW(Optimizer): if isinstance(parameters, (paddle.Tensor, core.eager.Tensor)): raise TypeError( "`parameters` argument given to the optimizer should be " - "an iterable of paddle Tensors, but got argument type is `{}`." - .format(type(parameters))) + "an iterable of paddle Tensors, but got argument type is `{}`.".format( + type(parameters) + ) + ) if isinstance(parameters, dict): raise TypeError( "`parameters` argument should not get dict type, " "if parameter groups is needed, please set `parameters`" - " as list of dict") + " as list of dict" + ) self._parameter_list = list(parameters) else: self._parameter_list = None @@ -207,8 +214,9 @@ class AdamW(Optimizer): if not isinstance(learning_rate, (float, LRScheduler)): raise TypeError( - "learning rate should be float or LRScheduler, got %s here" % - type(learning_rate)) + "learning rate should be float or LRScheduler, got %s here" + % type(learning_rate) + ) if grad_clip is not None: if not isinstance(grad_clip, GradientClipBase): raise TypeError( @@ -220,8 +228,9 @@ class AdamW(Optimizer): if self._parameter_list: if isinstance(self._parameter_list[0], dict): for param_group in self._parameter_list: - assert 'params' in param_group, \ - 'params should be set in parameters if parameter groups are optimized in different options' + assert ( + 'params' in param_group + ), 'params should be set in parameters if parameter groups are optimized in different options' self._dtype = self._parameter_list[0]['params'][0].dtype else: self._dtype = self._parameter_list[0].dtype @@ -260,7 +269,7 @@ class AdamW(Optimizer): 'beta2': beta2, 'epsilon': epsilon, 'lazy_mode': lazy_mode, - 'grad_clip': grad_clip + 'grad_clip': grad_clip, } self._param_groups = [] @@ -297,7 +306,8 @@ class AdamW(Optimizer): elif isinstance(params, set): raise TypeError( "optimizer parameters should be in ordered collections," - "but received set, please use list instead.") + "but received set, please use list instead." + ) else: param_group['params'] = list(params) @@ -311,11 +321,13 @@ class AdamW(Optimizer): if not param_set.isdisjoint(set(param_group['params'])): raise ValueError( - "some parameters appear in more than one parameter group") + "some parameters appear in more than one parameter group" + ) for param in param_group['params']: param.optimize_attr['learning_rate'] = param_group.get( - 'learning_rate', 1.) 
+ 'learning_rate', 1.0 + ) self._param_groups.append(param_group) @@ -327,19 +339,23 @@ class AdamW(Optimizer): var_name = param.name + "_fp32_master" var_name = unique_name.generate(var_name) - var = layers.create_global_var(name=var_name, - shape=param.shape, - value=0, - dtype='float32', - persistable=True) + var = layers.create_global_var( + name=var_name, + shape=param.shape, + value=0, + dtype='float32', + persistable=True, + ) block = self.helper.startup_program.global_block() - block.append_op(type="cast", - inputs={"X": [param]}, - outputs={"Out": [var]}, - attrs={ - "in_dtype": param.dtype, - "out_dtype": core.VarDesc.VarType.FP32 - }) + block.append_op( + type="cast", + inputs={"X": [param]}, + outputs={"Out": [var]}, + attrs={ + "in_dtype": param.dtype, + "out_dtype": core.VarDesc.VarType.FP32, + }, + ) self._master_weights[param.name] = var return var @@ -353,20 +369,31 @@ class AdamW(Optimizer): """ if self._name is not None: name = self._name + "_" + name - find_master = self._multi_precision and param.dtype == core.VarDesc.VarType.FP16 - target_param = self._master_weights[ - param.name] if find_master else param + find_master = self._multi_precision and ( + param.dtype == core.VarDesc.VarType.FP16 + or param.dtype == core.VarDesc.VarType.BF16 + ) + target_param = ( + self._master_weights[param.name] if find_master else param + ) target_name = target_param.name - if (name not in self._accumulators - or target_name not in self._accumulators[name]): + if ( + name not in self._accumulators + or target_name not in self._accumulators[name] + ): raise Exception( "Accumulator {} does not exist for parameter {}".format( - name, target_name)) + name, target_name + ) + ) return self._accumulators[name][target_name] def _add_moments_pows(self, p): acc_dtype = p.dtype - if acc_dtype == core.VarDesc.VarType.FP16: + if ( + acc_dtype == core.VarDesc.VarType.FP16 + or acc_dtype == core.VarDesc.VarType.BF16 + ): acc_dtype = core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) @@ -374,18 +401,24 @@ class AdamW(Optimizer): name=self._beta1_pow_acc_str, param=p, dtype=acc_dtype, - fill_value=0.9 if isinstance(self._beta1, Variable) \ - else self._beta1, + fill_value=0.9 + if isinstance(self._beta1, Variable) + else self._beta1, shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + type=core.VarDesc.VarType.LOD_TENSOR, + device='cpu', + ) self._add_accumulator( name=self._beta2_pow_acc_str, param=p, dtype=acc_dtype, - fill_value=0.999 if isinstance(self._beta2, Variable) \ - else self._beta2, + fill_value=0.999 + if isinstance(self._beta2, Variable) + else self._beta2, shape=[1], - type=core.VarDesc.VarType.LOD_TENSOR, device='cpu') + type=core.VarDesc.VarType.LOD_TENSOR, + device='cpu', + ) def _create_accumulators(self, block, parameters): assert isinstance(block, framework.Block) @@ -394,13 +427,19 @@ class AdamW(Optimizer): # Create accumulator tensors for first and second moments for p in parameters: - if self._multi_precision and p.dtype == core.VarDesc.VarType.FP16: + if self._multi_precision and ( + p.dtype == core.VarDesc.VarType.FP16 + or p.dtype == core.VarDesc.VarType.BF16 + ): master_p = self._create_master_weight(p) self._add_moments_pows(master_p) continue - if p.dtype == core.VarDesc.VarType.FP16 and not self._multi_precision: + if ( + p.dtype == core.VarDesc.VarType.FP16 + or p.dtype == core.VarDesc.VarType.BF16 + ) and not self._multi_precision: warnings.warn( - 
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence." + "Accumulating with FP16/BFP16 in optimizer can lead to poor accuracy or slow convergence." "Consider using multi_precision=True option of the Adam optimizer." ) self._add_moments_pows(p) @@ -413,53 +452,112 @@ class AdamW(Optimizer): # Whether we should do weight decay for the parameter. with_decay = True - if self._apply_decay_param_fun is not None \ - and not self._apply_decay_param_fun(param.name): + if ( + self._apply_decay_param_fun is not None + and not self._apply_decay_param_fun(param.name) + ): with_decay = False - moment1 = self._get_accumulator(self._moment1_acc_str, - param_and_grad[0]) - moment2 = self._get_accumulator(self._moment2_acc_str, - param_and_grad[0]) - beta1_pow_acc = self._get_accumulator(self._beta1_pow_acc_str, - param_and_grad[0]) - beta2_pow_acc = self._get_accumulator(self._beta2_pow_acc_str, - param_and_grad[0]) - find_master = self._multi_precision and param_and_grad[ - 0].dtype == core.VarDesc.VarType.FP16 - master_weight = (self._master_weights[param_and_grad[0].name] - if find_master else None) + moment1 = self._get_accumulator( + self._moment1_acc_str, param_and_grad[0] + ) + moment2 = self._get_accumulator( + self._moment2_acc_str, param_and_grad[0] + ) + beta1_pow_acc = self._get_accumulator( + self._beta1_pow_acc_str, param_and_grad[0] + ) + beta2_pow_acc = self._get_accumulator( + self._beta2_pow_acc_str, param_and_grad[0] + ) + find_master = self._multi_precision and ( + param_and_grad[0].dtype == core.VarDesc.VarType.FP16 + or param_and_grad[0].dtype == core.VarDesc.VarType.BF16 + ) + master_weight = ( + self._master_weights[param_and_grad[0].name] + if find_master + else None + ) lr = self._create_param_lr(param_and_grad) # create the adamw optimize op if framework._non_static_mode(): - lr_ratio_ = 1. 
if self._lr_ratio is None else self._lr_ratio( - param_and_grad[0]) - - _beta1 = self._beta1 if not isinstance( - self._beta1, Variable) else self._beta1.numpy().item(0) - _beta2 = self._beta2 if not isinstance( - self._beta2, Variable) else self._beta2.numpy().item(0) + lr_ratio_ = ( + 1.0 + if self._lr_ratio is None + else self._lr_ratio(param_and_grad[0]) + ) + + _beta1 = ( + self._beta1 + if not isinstance(self._beta1, Variable) + else self._beta1.numpy().item(0) + ) + _beta2 = ( + self._beta2 + if not isinstance(self._beta2, Variable) + else self._beta2.numpy().item(0) + ) if framework.in_dygraph_mode(): found_inf = self._get_auxiliary_var('found_inf') _, _, _, _, _, _ = _C_ops.adamw_( - param_and_grad[0], param_and_grad[1], lr, moment1, moment2, - beta1_pow_acc, beta2_pow_acc, master_weight, found_inf, - _beta1, _beta2, self._epsilon, lr_ratio_, - self._weight_decay, with_decay, self._lazy_mode, 1000, - find_master, False) + param_and_grad[0], + param_and_grad[1], + lr, + moment1, + moment2, + beta1_pow_acc, + beta2_pow_acc, + master_weight, + found_inf, + _beta1, + _beta2, + self._epsilon, + lr_ratio_, + self._weight_decay, + with_decay, + self._lazy_mode, + 1000, + find_master, + False, + ) else: _, _, _, _, _, _ = _legacy_C_ops.adamw( - param_and_grad[0], param_and_grad[1], lr, moment1, moment2, - beta1_pow_acc, beta2_pow_acc, master_weight, - param_and_grad[0], moment1, moment2, beta1_pow_acc, - beta2_pow_acc, master_weight, 'epsilon', self._epsilon, - 'lazy_mode', self._lazy_mode, - 'min_row_size_to_use_multithread', 1000, 'beta1', _beta1, - 'beta2', _beta2, "with_decay", with_decay, 'coeff', - self._weight_decay, 'multi_precision', find_master, - 'lr_ratio', lr_ratio_) + param_and_grad[0], + param_and_grad[1], + lr, + moment1, + moment2, + beta1_pow_acc, + beta2_pow_acc, + master_weight, + param_and_grad[0], + moment1, + moment2, + beta1_pow_acc, + beta2_pow_acc, + master_weight, + 'epsilon', + self._epsilon, + 'lazy_mode', + self._lazy_mode, + 'min_row_size_to_use_multithread', + 1000, + 'beta1', + _beta1, + 'beta2', + _beta2, + "with_decay", + with_decay, + 'coeff', + self._weight_decay, + 'multi_precision', + find_master, + 'lr_ratio', + lr_ratio_, + ) return None inputs = { @@ -486,18 +584,14 @@ class AdamW(Optimizer): "Beta2PowOut": [beta2_pow_acc], } attrs = { - "lazy_mode": - self._lazy_mode, - "min_row_size_to_use_multithread": - 1000, - "multi_precision": - find_master, - "with_decay": - with_decay, - "coeff": - self._weight_decay, - "lr_ratio": - 1. if self._lr_ratio is None else self._lr_ratio(param_and_grad[0]) + "lazy_mode": self._lazy_mode, + "min_row_size_to_use_multithread": 1000, + "multi_precision": find_master, + "with_decay": with_decay, + "coeff": self._weight_decay, + "lr_ratio": 1.0 + if self._lr_ratio is None + else self._lr_ratio(param_and_grad[0]), } if isinstance(self._beta1, Variable): @@ -517,11 +611,13 @@ class AdamW(Optimizer): inputs["MasterParam"] = master_weight outputs["MasterParamOut"] = master_weight - adamw_op = block.append_op(type=self.type, - inputs=inputs, - outputs=outputs, - attrs=attrs, - stop_gradient=True) + adamw_op = block.append_op( + type=self.type, + inputs=inputs, + outputs=outputs, + attrs=attrs, + stop_gradient=True, + ) return adamw_op @@ -541,7 +637,7 @@ class AdamW(Optimizer): .. code-block:: python import paddle - + a = paddle.rand([2,13], dtype="float32") linear = paddle.nn.Linear(13, 5) # This can be any optimizer supported by dygraph. 
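To make the multi-precision changes above easier to follow, here is a minimal, self-contained Python sketch (not part of the patch) of the master-weight scheme the new FP16/BF16 branches rely on: the low-precision parameter gets an fp32 shadow copy, the update is computed against that copy, and the result is cast back down. The tensors and the SGD-style update below are illustrative stand-ins rather than the adamw op itself, and casting to bfloat16 may require a Paddle build with bf16 support.

    import paddle

    # Hypothetical bf16 "parameter" and an fp32 gradient (illustrative only).
    param_bf16 = paddle.randn([13, 5]).astype('bfloat16')
    grad_fp32 = paddle.randn([13, 5]).astype('float32')

    # fp32 master copy, analogous to what _create_master_weight materializes
    # with a cast op; _add_moments_pows likewise keeps the moments in fp32.
    master_fp32 = param_bf16.astype('float32')

    # A plain SGD-style step done entirely in fp32 (the patch instead runs the
    # adamw op with multi_precision=find_master).
    master_fp32 = master_fp32 - 0.001 * grad_fp32

    # Cast the updated master weight back to the low-precision parameter.
    param_bf16 = master_fp32.astype('bfloat16')

From user code the same path is enabled simply by constructing the optimizer with multi_precision=True, e.g. paddle.optimizer.AdamW(learning_rate=1e-3, parameters=model.parameters(), multi_precision=True).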
@@ -560,24 +656,28 @@ class AdamW(Optimizer): if param._grad_ivar() is not None: grad_var = param._grad_ivar() if framework.in_dygraph_mode(): - if hasattr(grad_var, "is_selected_rows" - ) and grad_var.is_selected_rows( - ) and self.regularization is not None: + if ( + hasattr(grad_var, "is_selected_rows") + and grad_var.is_selected_rows() + and self.regularization is not None + ): raise RuntimeError( "AdamW don't support weight_decay with sparse parameters, please set it to None." ) else: - if hasattr( - grad_var, "_is_sparse") and grad_var._is_sparse( - ) and self.regularization is not None: + if ( + hasattr(grad_var, "_is_sparse") + and grad_var._is_sparse() + and self.regularization is not None + ): raise RuntimeError( "AdamW don't support weight_decay with sparse parameters, please set it to None." ) params_grads.append((param, grad_var)) - optimize_ops = self._apply_optimize(loss=None, - startup_program=None, - params_grads=params_grads) + optimize_ops = self._apply_optimize( + loss=None, startup_program=None, params_grads=params_grads + ) else: # optimize parameters in groups for param_group in self._param_groups: @@ -588,35 +688,41 @@ class AdamW(Optimizer): if param._grad_ivar() is not None: grad_var = param._grad_ivar() if framework.in_dygraph_mode(): - if hasattr(grad_var, "is_selected_rows" - ) and grad_var.is_selected_rows( - ) and self.regularization is not None: + if ( + hasattr(grad_var, "is_selected_rows") + and grad_var.is_selected_rows() + and self.regularization is not None + ): raise RuntimeError( "AdamW don't support weight_decay with sparse parameters, please set it to None." ) else: - if hasattr(grad_var, - "_is_sparse") and grad_var._is_sparse( - ) and self.regularization is not None: + if ( + hasattr(grad_var, "_is_sparse") + and grad_var._is_sparse() + and self.regularization is not None + ): raise RuntimeError( "AdamW don't support weight_decay with sparse parameters, please set it to None." 
                                )
                        params_grads['params'].append((param, grad_var))
                params_grads.update(
-                    {k: v
-                     for k, v in param_group.items() if k != 'params'})
-                self._apply_optimize(loss=None,
-                                     startup_program=None,
-                                     params_grads=params_grads)
+                    {k: v for k, v in param_group.items() if k != 'params'}
+                )
+                self._apply_optimize(
+                    loss=None, startup_program=None, params_grads=params_grads
+                )
 
     def _update_param_group(self, parameters):
         self._beta1 = parameters.get('beta1', self._default_dict['beta1'])
         self._beta2 = parameters.get('beta2', self._default_dict['beta2'])
         self._epsilon = parameters.get('epsilon',
                                        self._default_dict['epsilon'])
-        self._lazy_mode = parameters.get('lazy_mode',
-                                         self._default_dict['lazy_mode'])
-        self._weight_decay = parameters.get('weight_decay',
-                                            self._default_dict['weight_decay'])
+        self._lazy_mode = parameters.get(
+            'lazy_mode', self._default_dict['lazy_mode']
+        )
+        self._weight_decay = parameters.get(
+            'weight_decay', self._default_dict['weight_decay']
+        )
         parameters = parameters.get('params')
         return parameters
diff --git a/python/paddle/optimizer/optimizer.py b/python/paddle/optimizer/optimizer.py
index 2ab61bb5487..3622564504e 100644
--- a/python/paddle/optimizer/optimizer.py
+++ b/python/paddle/optimizer/optimizer.py
@@ -440,15 +440,21 @@ class Optimizer(object):
         return self._opti_name_list
 
     def _create_global_learning_rate(self):
-        # lr var can't be float16, for pure fp16 training, should extra handle the dtype for lr
+        # lr var can't be float16 or bfloat16, for pure fp16 or bf16 training, should extra handle the dtype for lr
         _lr_dtype = (
             paddle.get_default_dtype() if self._dtype is None else self._dtype
         )
         _lr_dtype = (
             paddle.float32
             if (
-                paddle.get_default_dtype() != "float16"
-                and _lr_dtype == paddle.float16
+                (
+                    paddle.get_default_dtype() != "float16"
+                    and _lr_dtype == paddle.float16
+                )
+                or (
+                    paddle.get_default_dtype() != "bfloat16"
+                    and _lr_dtype == paddle.bfloat16
+                )
             )
             else _lr_dtype
         )
-- 
GitLab
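As a closing illustration, the following standalone sketch (not part of the patch) restates the learning-rate dtype rule that the optimizer.py hunk above encodes: the lr variable may only stay in fp16/bf16 when that dtype is also the process-wide default dtype (pure fp16/bf16 training); otherwise it falls back to fp32. The helper name and the string-based comparison are assumptions made for brevity.

    import paddle

    def resolve_lr_dtype(opt_dtype=None):
        # Approximate restatement of _create_global_learning_rate's rule,
        # using dtype strings instead of paddle dtype objects.
        lr_dtype = paddle.get_default_dtype() if opt_dtype is None else opt_dtype
        if lr_dtype in ("float16", "bfloat16") and lr_dtype != paddle.get_default_dtype():
            return "float32"
        return lr_dtype

    # With the default dtype left at float32, a bf16 request falls back to fp32.
    print(resolve_lr_dtype("bfloat16"))  # prints: float32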