From b7b231a668ac51365cdce11dfafe6f7da04b2350 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Fri, 30 Sep 2022 09:49:48 +0800 Subject: [PATCH] support pure bfloat16 for more ops (#46364) * support pure bfloat16 * support bf16 linear * update PR to pass CI * tiny fix where_grad_kernel.cu * add bfloat16 to selu_grad to pass CI * fix selu grad compilation error --- .../operators/fused/fused_gemm_epilogue_op.cu | 13 +- .../platform/device/gpu/gpu_primitives.h | 111 +++++++++--------- paddle/phi/kernels/empty_kernel.cc | 1 + paddle/phi/kernels/funcs/activation_functor.h | 8 +- paddle/phi/kernels/funcs/eigen/broadcast.cu | 1 + .../phi/kernels/gpu/activation_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/activation_kernel.cu | 11 +- paddle/phi/kernels/gpu/adam_kernel.cu | 6 +- paddle/phi/kernels/gpu/clip_grad_kernel.cu | 1 + paddle/phi/kernels/gpu/clip_kernel.cu | 3 +- .../phi/kernels/gpu/embedding_grad_kernel.cu | 6 +- paddle/phi/kernels/gpu/embedding_kernel.cu | 3 +- paddle/phi/kernels/gpu/gelu_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/gelu_kernel.cu | 3 +- paddle/phi/kernels/gpu/pad3d_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/pad3d_kernel.cu | 1 + .../kernels/gpu/pixel_shuffle_grad_kernel.cu | 4 +- .../phi/kernels/gpu/pixel_shuffle_kernel.cu | 10 +- paddle/phi/kernels/gpu/selu_grad_kernel.cu | 9 +- paddle/phi/kernels/gpu/tile_grad_kernel.cu | 3 +- paddle/phi/kernels/gpu/where_grad_kernel.cu | 6 +- paddle/phi/kernels/gpu/where_kernel.cu | 12 +- paddle/phi/kernels/impl/selu_kernel_impl.h | 13 +- python/paddle/fluid/clip.py | 14 ++- python/paddle/optimizer/adam.py | 2 +- python/paddle/tensor/stat.py | 6 +- 26 files changed, 160 insertions(+), 96 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu index 5f3c60df9a..e5bab3cae4 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/scope_guard.h" +#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/float16.h" @@ -63,6 +64,9 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { if (std::is_same::value) { mat_type = CUDA_R_16F; } + if (std::is_same::value) { + mat_type = CUDA_R_16BF; + } if (std::is_same::value) { mat_type = CUDA_R_64F; scale_type = CUDA_R_64F; @@ -354,6 +358,9 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { if (std::is_same::value) { mat_type = CUDA_R_16F; } + if (std::is_same::value) { + mat_type = CUDA_R_16BF; + } if (std::is_same::value) { mat_type = CUDA_R_64F; scale_type = CUDA_R_64F; @@ -688,12 +695,14 @@ REGISTER_OP_CUDA_KERNEL( fused_gemm_epilogue, ops::FusedGemmEpilogueKernel, ops::FusedGemmEpilogueKernel, - ops::FusedGemmEpilogueKernel); + ops::FusedGemmEpilogueKernel, + ops::FusedGemmEpilogueKernel); REGISTER_OP_CUDA_KERNEL( fused_gemm_epilogue_grad, ops::FusedGemmEpilogueGradKernel, ops::FusedGemmEpilogueGradKernel, ops::FusedGemmEpilogueGradKernel); + paddle::platform::float16>, + ops::FusedGemmEpilogueKernel); #endif diff --git a/paddle/fluid/platform/device/gpu/gpu_primitives.h b/paddle/fluid/platform/device/gpu/gpu_primitives.h index b99d6de5db..96eddf0923 100644 --- a/paddle/fluid/platform/device/gpu/gpu_primitives.h +++ b/paddle/fluid/platform/device/gpu/gpu_primitives.h @@ -198,61 +198,6 @@ __device__ __forceinline__ void fastAtomicAdd(T *arr, T value) { CudaAtomicAdd(arr + index, value); } - -#ifdef PADDLE_WITH_CUDA -/* - * One thead block deals with elementwise atomicAdd for vector of len. - * @in: [x1, x2, x3, ...] - * @out:[y1+x1, y2+x2, y3+x3, ...] - * */ -template ::value>::type * = nullptr> -__device__ __forceinline__ void VectorizedAtomicAddPerBlock( - const int64_t len, int tid, int threads_per_block, const T *in, T *out) { - for (int i = tid; i < len; i += threads_per_block) { - CudaAtomicAdd(&out[i], in[i]); - } -} - -// Note: assume that len is even. If len is odd, call fastAtomicAdd directly. -template ::value>::type * = nullptr> -__device__ __forceinline__ void VectorizedAtomicAddPerBlock( - const int64_t len, int tid, int threads_per_block, const T *in, T *out) { -#if ((CUDA_VERSION < 10000) || \ - (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) - for (int i = tid; i < len; i += threads_per_block) { - CudaAtomicAdd(&out[i], in[i]); - } -#else - int i = 0; - int loops = len / 2 * 2; - - bool aligned_half2 = - (reinterpret_cast(out) % sizeof(__half2) == 0); - - if (aligned_half2) { - for (i = tid * 2; i < loops; i += threads_per_block * 2) { - __half2 value2; - T value_1 = in[i]; - T value_2 = in[i + 1]; - value2.x = *reinterpret_cast<__half *>(&value_1); - value2.y = *reinterpret_cast<__half *>(&value_2); - atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2); - } - for (; i < len; i += threads_per_block) { - fastAtomicAdd(out, i, len, in[i]); - } - } else { - for (int i = tid; i < len; i += threads_per_block) { - fastAtomicAdd(out, i, len, in[i]); - } - } -#endif -} -#endif #endif // NOTE(zhangbo): cuda do not have atomicCAS for __nv_bfloat16. @@ -601,5 +546,61 @@ CUDA_ATOMIC_WRAPPER(Min, float16) { } #endif +#ifdef PADDLE_CUDA_FP16 +#ifdef PADDLE_WITH_CUDA +/* + * One thead block deals with elementwise atomicAdd for vector of len. + * @in: [x1, x2, x3, ...] + * @out:[y1+x1, y2+x2, y3+x3, ...] + * */ +template ::value>::type * = nullptr> +__device__ __forceinline__ void VectorizedAtomicAddPerBlock( + const int64_t len, int tid, int threads_per_block, const T *in, T *out) { + for (int i = tid; i < len; i += threads_per_block) { + CudaAtomicAdd(&out[i], in[i]); + } +} + +// Note: assume that len is even. If len is odd, call fastAtomicAdd directly. +template ::value>::type * = nullptr> +__device__ __forceinline__ void VectorizedAtomicAddPerBlock( + const int64_t len, int tid, int threads_per_block, const T *in, T *out) { +#if ((CUDA_VERSION < 10000) || \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 700))) + for (int i = tid; i < len; i += threads_per_block) { + CudaAtomicAdd(&out[i], in[i]); + } +#else + int i = 0; + int loops = len / 2 * 2; + + bool aligned_half2 = + (reinterpret_cast(out) % sizeof(__half2) == 0); + + if (aligned_half2) { + for (i = tid * 2; i < loops; i += threads_per_block * 2) { + __half2 value2; + T value_1 = in[i]; + T value_2 = in[i + 1]; + value2.x = *reinterpret_cast<__half *>(&value_1); + value2.y = *reinterpret_cast<__half *>(&value_2); + atomicAdd(reinterpret_cast<__half2 *>(&out[i]), value2); + } + for (; i < len; i += threads_per_block) { + fastAtomicAdd(out, i, len, in[i]); + } + } else { + for (int i = tid; i < len; i += threads_per_block) { + fastAtomicAdd(out, i, len, in[i]); + } + } +#endif +} +#endif +#endif } // namespace platform } // namespace paddle diff --git a/paddle/phi/kernels/empty_kernel.cc b/paddle/phi/kernels/empty_kernel.cc index 2c969cc43d..01b07c438a 100644 --- a/paddle/phi/kernels/empty_kernel.cc +++ b/paddle/phi/kernels/empty_kernel.cc @@ -88,6 +88,7 @@ PD_REGISTER_KERNEL(empty, int64_t, bool, phi::dtype::float16, + phi::dtype::bfloat16, phi::dtype::complex, phi::dtype::complex) {} diff --git a/paddle/phi/kernels/funcs/activation_functor.h b/paddle/phi/kernels/funcs/activation_functor.h index 51420c5ecb..2af106ca38 100644 --- a/paddle/phi/kernels/funcs/activation_functor.h +++ b/paddle/phi/kernels/funcs/activation_functor.h @@ -2169,12 +2169,14 @@ struct CudaSeluFunctor : public BaseActivationFunctor { } __device__ __forceinline__ T operator()(const T x) const { - T res = x; - if (res <= zero) { + using MT = + typename std::conditional<(sizeof(T) > sizeof(float)), T, float>::type; + MT res = static_cast(x); + if (x <= zero) { res = alpha * expf(res) - alpha; } res *= scale; - return res; + return static_cast(res); } private: diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cu b/paddle/phi/kernels/funcs/eigen/broadcast.cu index 0b749f5c00..0c5a340887 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cu +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cu @@ -84,6 +84,7 @@ INSTANTIATION(EigenBroadcast, int); INSTANTIATION(EigenBroadcast, int64_t); INSTANTIATION(EigenBroadcastGrad, bool); INSTANTIATION(EigenBroadcastGrad, float); +INSTANTIATION(EigenBroadcastGrad, dtype::bfloat16); INSTANTIATION(EigenBroadcastGrad, dtype::float16); INSTANTIATION(EigenBroadcastGrad, double); INSTANTIATION(EigenBroadcastGrad, dtype::complex); diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu index 53f727ec51..b947c70cb8 100644 --- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu @@ -449,4 +449,5 @@ PD_REGISTER_KERNEL(pow_grad, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu index 0e9e754a99..e57332c407 100644 --- a/paddle/phi/kernels/gpu/activation_kernel.cu +++ b/paddle/phi/kernels/gpu/activation_kernel.cu @@ -265,5 +265,12 @@ PD_REGISTER_KERNEL(pow, double, int, int64_t, - phi::dtype::float16) {} -PD_REGISTER_KERNEL(selu, GPU, ALL_LAYOUT, phi::SeluKernel, float, double) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} +PD_REGISTER_KERNEL(selu, + GPU, + ALL_LAYOUT, + phi::SeluKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/adam_kernel.cu b/paddle/phi/kernels/gpu/adam_kernel.cu index d44f6d2800..0597311e21 100644 --- a/paddle/phi/kernels/gpu/adam_kernel.cu +++ b/paddle/phi/kernels/gpu/adam_kernel.cu @@ -372,7 +372,8 @@ PD_REGISTER_KERNEL(adam, phi::AdamDenseKernel, float, double, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { // Skip beta1_pow, beta2_pow, skip_update data transform kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); @@ -385,7 +386,8 @@ PD_REGISTER_KERNEL(merged_adam, phi::MergedAdamKernel, float, double, - phi::dtype::float16) { + phi::dtype::float16, + phi::dtype::bfloat16) { // Skip beta1_pow, beta2_pow data transform kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); diff --git a/paddle/phi/kernels/gpu/clip_grad_kernel.cu b/paddle/phi/kernels/gpu/clip_grad_kernel.cu index 4566e8468e..60d311a255 100644 --- a/paddle/phi/kernels/gpu/clip_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/clip_grad_kernel.cu @@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip_grad, double, int, int64_t, + phi::dtype::bfloat16, phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/clip_kernel.cu b/paddle/phi/kernels/gpu/clip_kernel.cu index 9e0050db7f..e8d519a5d3 100644 --- a/paddle/phi/kernels/gpu/clip_kernel.cu +++ b/paddle/phi/kernels/gpu/clip_kernel.cu @@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(clip, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu index 6694216214..e10d01ce9e 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -256,7 +256,8 @@ PD_REGISTER_KERNEL(embedding_grad, phi::EmbeddingGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} PD_REGISTER_KERNEL(embedding_sparse_grad, GPU, @@ -264,4 +265,5 @@ PD_REGISTER_KERNEL(embedding_sparse_grad, phi::EmbeddingSparseGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu index 90f3cc8d36..bb22fea5f6 100644 --- a/paddle/phi/kernels/gpu/embedding_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -125,4 +125,5 @@ PD_REGISTER_KERNEL(embedding, phi::EmbeddingKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gelu_grad_kernel.cu b/paddle/phi/kernels/gpu/gelu_grad_kernel.cu index 1f33d5c901..b1ffa921f9 100644 --- a/paddle/phi/kernels/gpu/gelu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/gelu_grad_kernel.cu @@ -99,4 +99,5 @@ PD_REGISTER_KERNEL(gelu_grad, phi::GeluGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/gelu_kernel.cu b/paddle/phi/kernels/gpu/gelu_kernel.cu index 509a5ccf4d..e0792c387d 100644 --- a/paddle/phi/kernels/gpu/gelu_kernel.cu +++ b/paddle/phi/kernels/gpu/gelu_kernel.cu @@ -93,4 +93,5 @@ PD_REGISTER_KERNEL(gelu, phi::GeluKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu index e9f820a318..fb7f1a2325 100644 --- a/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/pad3d_grad_kernel.cu @@ -509,4 +509,5 @@ PD_REGISTER_KERNEL(pad3d_grad, phi::Pad3dGradKernel, float, double, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/pad3d_kernel.cu b/paddle/phi/kernels/gpu/pad3d_kernel.cu index d1b1d70667..fa85c650bc 100644 --- a/paddle/phi/kernels/gpu/pad3d_kernel.cu +++ b/paddle/phi/kernels/gpu/pad3d_kernel.cu @@ -583,6 +583,7 @@ PD_REGISTER_KERNEL(pad3d, ALL_LAYOUT, phi::Pad3dKernel, phi::dtype::float16, + phi::dtype::bfloat16, float, double, int, diff --git a/paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu b/paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu index 1414fb9df0..5c88bbbf42 100644 --- a/paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/pixel_shuffle_grad_kernel.cu @@ -23,4 +23,6 @@ PD_REGISTER_KERNEL(pixel_shuffle_grad, ALL_LAYOUT, phi::PixelShuffleGradKernel, float, - double) {} + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu b/paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu index e43d6f9612..09eb0485a2 100644 --- a/paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu +++ b/paddle/phi/kernels/gpu/pixel_shuffle_kernel.cu @@ -18,5 +18,11 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/pixel_shuffle_kernel_impl.h" -PD_REGISTER_KERNEL( - pixel_shuffle, GPU, ALL_LAYOUT, phi::PixelShuffleKernel, float, double) {} +PD_REGISTER_KERNEL(pixel_shuffle, + GPU, + ALL_LAYOUT, + phi::PixelShuffleKernel, + float, + double, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/selu_grad_kernel.cu b/paddle/phi/kernels/gpu/selu_grad_kernel.cu index 0ed299413c..c715831ffc 100644 --- a/paddle/phi/kernels/gpu/selu_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/selu_grad_kernel.cu @@ -18,5 +18,10 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/impl/selu_grad_kernel_impl.h" -PD_REGISTER_KERNEL( - selu_grad, GPU, ALL_LAYOUT, phi::SeluGradKernel, float, double) {} +PD_REGISTER_KERNEL(selu_grad, + GPU, + ALL_LAYOUT, + phi::SeluGradKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/tile_grad_kernel.cu b/paddle/phi/kernels/gpu/tile_grad_kernel.cu index c092609e62..d1e356df40 100644 --- a/paddle/phi/kernels/gpu/tile_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/tile_grad_kernel.cu @@ -27,4 +27,5 @@ PD_REGISTER_KERNEL(tile_grad, double, int, int64_t, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/gpu/where_grad_kernel.cu b/paddle/phi/kernels/gpu/where_grad_kernel.cu index 709dddcb82..4c411bfb9c 100644 --- a/paddle/phi/kernels/gpu/where_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/where_grad_kernel.cu @@ -25,10 +25,10 @@ __global__ void WhereGradCUDAKernel( int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < N; idx += blockDim.x * gridDim.x) { if (dx != nullptr) { - dx[idx] = cond[idx] ? dout[idx] : 0.; + dx[idx] = cond[idx] ? dout[idx] : static_cast(0.); } if (dy != nullptr) { - dy[idx] = cond[idx] ? 0. : dout[idx]; + dy[idx] = cond[idx] ? static_cast(0.) : dout[idx]; } } } @@ -61,6 +61,8 @@ PD_REGISTER_KERNEL(where_grad, GPU, ALL_LAYOUT, phi::WhereGradKernel, + phi::dtype::float16, + phi::dtype::bfloat16, float, double, int, diff --git a/paddle/phi/kernels/gpu/where_kernel.cu b/paddle/phi/kernels/gpu/where_kernel.cu index 441be02b99..09a974fbc2 100644 --- a/paddle/phi/kernels/gpu/where_kernel.cu +++ b/paddle/phi/kernels/gpu/where_kernel.cu @@ -45,5 +45,13 @@ void WhereKernel(const Context& ctx, } // namespace phi -PD_REGISTER_KERNEL( - where, GPU, ALL_LAYOUT, phi::WhereKernel, float, double, int, int64_t) {} +PD_REGISTER_KERNEL(where, + GPU, + ALL_LAYOUT, + phi::WhereKernel, + float, + double, + int, + int64_t, + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/impl/selu_kernel_impl.h b/paddle/phi/kernels/impl/selu_kernel_impl.h index 288f7bb9b7..0725b14125 100644 --- a/paddle/phi/kernels/impl/selu_kernel_impl.h +++ b/paddle/phi/kernels/impl/selu_kernel_impl.h @@ -57,14 +57,17 @@ struct SeluGradFunctor { dx_data_ptr_(dx_data_ptr) {} HOSTDEVICE void operator()(size_t idx) const { - T y_ele = y_data_ptr_[idx]; - T dy_ele = dy_data_ptr_[idx]; + using MT = + typename std::conditional<(sizeof(T) > sizeof(float)), T, float>::type; - float tmp = scale_; + auto y_ele = static_cast(y_data_ptr_[idx]); + auto dy_ele = static_cast(dy_data_ptr_[idx]); + + auto tmp = static_cast(scale_); if (y_ele <= 0) { - tmp = y_ele + la_; + tmp = y_ele + static_cast(la_); } - dx_data_ptr_[idx] = dy_ele * tmp; + dx_data_ptr_[idx] = static_cast(dy_ele * tmp); } const T* y_data_ptr_; const T* dy_data_ptr_; diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index e6f2e17c05..e9e3645852 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -50,8 +50,9 @@ def _clip_by_global_norm_using_mp_type(*args): def _cast_to_mp_type_if_enabled(x): - if x.dtype == core.VarDesc.VarType.FP16 and _clip_by_global_norm_using_mp_type( - ): + if (x.dtype == core.VarDesc.VarType.FP16 + or x.dtype == core.VarDesc.VarType.BF16 + ) and _clip_by_global_norm_using_mp_type(): return x.astype(core.VarDesc.VarType.FP32) else: return x @@ -63,7 +64,8 @@ def _squared_l2_norm(x): """ x = _cast_to_mp_type_if_enabled(x) - if core.is_compiled_with_xpu() or x.dtype == core.VarDesc.VarType.FP16: + if core.is_compiled_with_xpu( + ) or x.dtype == core.VarDesc.VarType.FP16 or x.dtype == core.VarDesc.VarType.BF16: square = layers.square(x) sum_square = layers.reduce_sum(square) return sum_square @@ -499,7 +501,7 @@ class ClipGradByGlobalNorm(ClipGradBase): merge_grad = layers.get_tensor_from_selected_rows(merge_grad) sum_square = _squared_l2_norm(merge_grad) - if sum_square.dtype == core.VarDesc.VarType.FP16: + if sum_square.dtype == core.VarDesc.VarType.FP16 or sum_square.dtype == core.VarDesc.VarType.BF16: sum_square_list_fp16.append(sum_square) elif sum_square.dtype == core.VarDesc.VarType.FP32: sum_square_list_fp32.append(sum_square) @@ -552,8 +554,8 @@ class ClipGradByGlobalNorm(ClipGradBase): continue # TODO(wangxi): use inplace elementwise_mul if need_clip: - clip_input = (clip_var.astype('float16') if g.dtype - == core.VarDesc.VarType.FP16 else clip_var) + clip_input = (clip_var.astype(g.dtype) + if clip_var.dtype != g.dtype else clip_var) new_grad = layers.elementwise_mul(g, clip_input) params_and_grads.append((p, new_grad)) else: diff --git a/python/paddle/optimizer/adam.py b/python/paddle/optimizer/adam.py index 4f8122121b..41d22e778f 100644 --- a/python/paddle/optimizer/adam.py +++ b/python/paddle/optimizer/adam.py @@ -275,7 +275,7 @@ class Adam(Optimizer): def _add_moments_pows(self, p): acc_dtype = p.dtype - if acc_dtype == core.VarDesc.VarType.FP16: + if acc_dtype == core.VarDesc.VarType.FP16 or acc_dtype == core.VarDesc.VarType.BF16: acc_dtype = core.VarDesc.VarType.FP32 self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype) self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype) diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py index b5946459d3..144620f3c6 100644 --- a/python/paddle/tensor/stat.py +++ b/python/paddle/tensor/stat.py @@ -159,8 +159,10 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None): u = mean(x, axis, True, name) out = paddle.sum((x - u)**2, axis, keepdim=keepdim, name=name) - n = paddle.cast(paddle.numel(x), x.dtype) \ - / paddle.cast(paddle.numel(out), x.dtype) + dtype = x.dtype + n = paddle.cast(paddle.numel(x), paddle.int64) \ + / paddle.cast(paddle.numel(out), paddle.int64) + n = n.astype(dtype) if unbiased: one_const = paddle.ones([1], x.dtype) n = where(n > one_const, n - 1., one_const) -- GitLab