diff --git a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu b/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
index f92479888f8177e9348162a3dba79369aa2b6710..94699c9ce69541224fa2eb1e6f3a3b090f164d33 100644
--- a/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
+++ b/paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
@@ -68,6 +68,11 @@ __device__ __inline__ void load_data_upper_tri(plat::float16* dst,
   *(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
 }
 
+__device__ __inline__ void load_data_upper_tri(plat::bfloat16* dst,
+                                               const plat::bfloat16* src) {
+  *(reinterpret_cast<float2*>(dst)) = *(reinterpret_cast<const float2*>(src));
+}
+
 __device__ __inline__ void load_data_upper_tri(float* dst, const float* src) {
   *(reinterpret_cast<float4*>(dst)) = *(reinterpret_cast<const float4*>(src));
 }
@@ -76,6 +81,10 @@ __device__ __inline__ void load_zero_vector_upper_tri(plat::float16* dst) {
   *(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
 }
 
+__device__ __inline__ void load_zero_vector_upper_tri(plat::bfloat16* dst) {
+  *(reinterpret_cast<float2*>(dst)) = make_float2(0.0f, 0.0f);
+}
+
 __device__ __inline__ void load_zero_vector_upper_tri(float* dst) {
   *(reinterpret_cast<float4*>(dst)) = make_float4(0.0f, 0.0f, 0.0f, 0.0f);
 }
@@ -595,8 +604,11 @@ namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     fused_softmax_mask_upper_triangle,
     ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::float16>,
+    ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, plat::bfloat16>,
     ops::SoftmaxMaskFuseUpperTriangleKernel<phi::GPUContext, float>);
 REGISTER_OP_CUDA_KERNEL(
     fused_softmax_mask_upper_triangle_grad,
     ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, plat::float16>,
+    ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext,
+                                                plat::bfloat16>,
     ops::SoftmaxMaskFuseUpperTriangleGradKernel<phi::GPUContext, float>);
diff --git a/paddle/fluid/operators/math.h b/paddle/fluid/operators/math.h
index 47281fb0280f0fc5128d978d9aedaeb4e8d19cd3..f376663ecec04d420e63e3d72effea3fcc6483e1 100644
--- a/paddle/fluid/operators/math.h
+++ b/paddle/fluid/operators/math.h
@@ -15,6 +15,7 @@
 #pragma once
 
 #include "math.h"  // NOLINT
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 
@@ -33,6 +34,10 @@ inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
   return static_cast<platform::float16>(::logf(static_cast<float>(x)));
 }
 
+inline HOSTDEVICE phi::dtype::bfloat16 real_log(phi::dtype::bfloat16 x) {
+  return static_cast<phi::dtype::bfloat16>(::logf(static_cast<float>(x)));
+}
+
 inline HOSTDEVICE float real_log(float x) { return ::logf(x); }
 
 inline HOSTDEVICE double real_log(double x) { return ::log(x); }
diff --git a/paddle/fluid/operators/math/cross_entropy.cu b/paddle/fluid/operators/math/cross_entropy.cu
index c366dd6fcef349b6d2859e801f00cef96f44d13d..f8bd4b60d47d98b95b419a912c852679466d1bb3 100644
--- a/paddle/fluid/operators/math/cross_entropy.cu
+++ b/paddle/fluid/operators/math/cross_entropy.cu
@@ -13,9 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/cross_entropy.h"
+
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/operators/math.h"
 #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
+#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 
@@ -152,7 +154,10 @@ void CrossEntropyFunctor<DeviceContext, T>::operator()(
 
 template class CrossEntropyFunctor<phi::GPUContext, float>;
 template class CrossEntropyFunctor<phi::GPUContext, double>;
-template class CrossEntropyFunctor<phi::GPUContext, platform::float16>;
+template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::float16>;
+#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(8, 1, 0)
+template class CrossEntropyFunctor<phi::GPUContext, phi::dtype::bfloat16>;
+#endif
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/operators/math/cross_entropy.h b/paddle/fluid/operators/math/cross_entropy.h
index 0de10789ba02ee97b3710ac766519217347a0fcd..651579005b99e049955a090b6382fa05729ace61 100644
--- a/paddle/fluid/operators/math/cross_entropy.h
+++ b/paddle/fluid/operators/math/cross_entropy.h
@@ -17,7 +17,8 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/common/bfloat16.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/core/hostdevice.h"
 
 namespace paddle {
@@ -46,14 +47,30 @@ struct TolerableValue {
 // Also. In standard implementation of cross entropy, other
 // framework not has the ValueClipping.
 template <>
-struct TolerableValue<platform::float16> {
-  HOSTDEVICE platform::float16 operator()(const platform::float16& x) const {
-    if (platform::isfinite(x))
+struct TolerableValue<phi::dtype::float16> {
+  HOSTDEVICE phi::dtype::float16 operator()(
+      const phi::dtype::float16& x) const {
+    if (phi::dtype::isfinite(x)) {
       return x;
-    else if (x > static_cast<platform::float16>(0))
-      return std::numeric_limits<platform::float16>::max();
-    else
-      return std::numeric_limits<platform::float16>::min();
+    } else if (x > static_cast<phi::dtype::float16>(0)) {
+      return std::numeric_limits<phi::dtype::float16>::max();
+    } else {
+      return std::numeric_limits<phi::dtype::float16>::min();
+    }
+  }
+};
+
+template <>
+struct TolerableValue<phi::dtype::bfloat16> {
+  HOSTDEVICE phi::dtype::bfloat16 operator()(
+      const phi::dtype::bfloat16& x) const {
+    if (phi::dtype::isfinite(x)) {
+      return x;
+    } else if (x > static_cast<phi::dtype::bfloat16>(0)) {
+      return std::numeric_limits<phi::dtype::bfloat16>::max();
+    } else {
+      return std::numeric_limits<phi::dtype::bfloat16>::min();
+    }
   }
 };
 
diff --git a/paddle/phi/kernels/full_kernel.cc b/paddle/phi/kernels/full_kernel.cc
index 9622bff5c255aef470cbd50c9e8496e39bf7d02b..ce898210633b7d02c75bafa04c958f7f93a11e02 100644
--- a/paddle/phi/kernels/full_kernel.cc
+++ b/paddle/phi/kernels/full_kernel.cc
@@ -59,7 +59,8 @@ PD_REGISTER_KERNEL(full_batch_size_like,
                    int,
                    int64_t,
                    bool,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
 #endif
diff --git a/paddle/phi/kernels/gpu/activation_grad_kernel.cu b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
index b947c70cb89d495d6fe9f58fe889119e4e8e54a9..d40f3b5013a9a1be416972be5caed4fb7dbc3fe8 100644
--- a/paddle/phi/kernels/gpu/activation_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -370,7 +370,8 @@ PD_REGISTER_KERNEL(exp_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(soft_shrink_grad, SoftShrinkGradKernel)
 PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_shrink_grad, HardShrinkGradKernel)
@@ -385,7 +386,8 @@ PD_REGISTER_KERNEL(expm1_grad,
                    phi::Expm1GradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(logit_grad,
                    GPU,
@@ -393,7 +395,8 @@ PD_REGISTER_KERNEL(logit_grad,
                    phi::LogitGradKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(square_grad,
                    GPU,
diff --git a/paddle/phi/kernels/gpu/activation_kernel.cu b/paddle/phi/kernels/gpu/activation_kernel.cu
index e57332c40756af5e6b3e87f1ed8d966124945553..ab32f420701755ad0356ccb36e675790751af01d 100644
--- a/paddle/phi/kernels/gpu/activation_kernel.cu
+++ b/paddle/phi/kernels/gpu/activation_kernel.cu
@@ -212,21 +212,24 @@ PD_REGISTER_KERNEL(exp,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(expm1,
                    GPU,
                    ALL_LAYOUT,
                    phi::Expm1Kernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(logit,
                    GPU,
                    ALL_LAYOUT,
                    phi::LogitKernel,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(square,
                    GPU,
                    ALL_LAYOUT,
diff --git a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
index 13db18534955b6a1200641b53b185e11afd033bd..4c440ed0dd71937508aefa11506b5e6534f7d11a 100644
--- a/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
+++ b/paddle/phi/kernels/gpu/arg_min_max_kernel.cu
@@ -255,6 +255,7 @@ PD_REGISTER_KERNEL(arg_min,
                    ALL_LAYOUT,
                    phi::ArgMinKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int32_t,
@@ -267,6 +268,7 @@ PD_REGISTER_KERNEL(arg_max,
                    ALL_LAYOUT,
                    phi::ArgMaxKernel,
                    phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int32_t,
diff --git a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu
index 5d40304c5e0c669d1af3b22bb9a79ecb6b34f0ef..93cdf64a8ef37bb360981eade44d4e6149871fc4 100644
--- a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu
@@ -282,6 +282,7 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx,
 
 }  // namespace phi
 
+#ifdef PADDLE_WITH_HIP
 PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
                    GPU,
                    ALL_LAYOUT,
@@ -289,3 +290,23 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
                    float,
                    double,
                    phi::dtype::float16) {}
+#else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CrossEntropyWithSoftmaxGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+#else
+PD_REGISTER_KERNEL(cross_entropy_with_softmax_grad,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CrossEntropyWithSoftmaxGradKernel,
+                   float,
+                   double,
+                   phi::dtype::float16) {}
+#endif
+#endif
diff --git a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu
index 76201a1077edbb4dd82fc013473acbcd764b9027..087ba293fb840c6a46e7a873b7d79270df44cca8 100644
--- a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu
+++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu
@@ -252,7 +252,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
     input -= offset;
     size += offset;
     if (tid >= offset) {
-      val = reducer(val, input[tid]);
+      val = reducer(val, static_cast<AccT>(input[tid]));
     }
     size -= blockDim.x;
     input += blockDim.x;
@@ -268,14 +268,14 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
 
 #pragma unroll
     for (int i = 0; i < VecSize; ++i) {
-      val = reducer(val, ins[i]);
+      val = reducer(val, static_cast<AccT>(ins[i]));
     }
   }
 
   // scalar part
   tid = size - remain + threadIdx.x;
   for (; tid < size; tid += blockDim.x) {
-    val = reducer(val, input[tid]);
+    val = reducer(val, static_cast<AccT>(input[tid]));
   }
   return val;
 }
@@ -1470,6 +1470,16 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
                    float,
                    phi::dtype::float16) {}
 #else
+#if CUDNN_VERSION_MIN(8, 1, 0)
+PD_REGISTER_KERNEL(cross_entropy_with_softmax,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::CrossEntropyWithSoftmaxKernel,
+                   float,
+                   double,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
+#else
 PD_REGISTER_KERNEL(cross_entropy_with_softmax,
                    GPU,
                    ALL_LAYOUT,
@@ -1478,3 +1488,4 @@ PD_REGISTER_KERNEL(cross_entropy_with_softmax,
                    double,
                    phi::dtype::float16) {}
 #endif
+#endif
diff --git a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu
index a78dc717b046b61b54bc344f9f943b0faeb0ce50..da1045c27c58d7b5f69e9e14de03a1224f3d8807 100644
--- a/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/gather_nd_grad_kernel.cu
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/gather_nd_grad_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
@@ -63,4 +64,5 @@ PD_REGISTER_KERNEL(gather_nd_grad,
                    double,
                    int64_t,
                    int,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/gather_nd_kernel.cu b/paddle/phi/kernels/gpu/gather_nd_kernel.cu
index 7b2412958902d3a98e9f612a06f4efdf1eb044bb..b8ac4aa263afa81a1111b243572438c0b2ae410e 100644
--- a/paddle/phi/kernels/gpu/gather_nd_kernel.cu
+++ b/paddle/phi/kernels/gpu/gather_nd_kernel.cu
@@ -15,6 +15,7 @@
 #include "paddle/phi/kernels/gather_nd_kernel.h"
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/bfloat16.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/gather.cu.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
@@ -58,4 +59,5 @@ PD_REGISTER_KERNEL(gather_nd,
                    int,
                    int16_t,
                    bool,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu b/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
index d2671dff7b0184f50345bdeeadb59d018a3badf0..db1c3966e911b858cd559b4a6bf8ea6ad5c8e3c8 100644
--- a/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
@@ -134,6 +134,8 @@ PD_REGISTER_KERNEL(index_sample_grad,
                    GPU,
                    ALL_LAYOUT,
                    phi::IndexSampleGradKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
diff --git a/paddle/phi/kernels/gpu/index_sample_kernel.cu b/paddle/phi/kernels/gpu/index_sample_kernel.cu
index 9b95d761fcbad475d142553741ffd658f0688884..053851fa26598923576cb150b34b4a3686c818a0 100644
--- a/paddle/phi/kernels/gpu/index_sample_kernel.cu
+++ b/paddle/phi/kernels/gpu/index_sample_kernel.cu
@@ -107,6 +107,8 @@ PD_REGISTER_KERNEL(index_sample,
                    GPU,
                    ALL_LAYOUT,
                    phi::IndexSampleKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
                    float,
                    double,
                    int,
diff --git a/paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu b/paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu
index 3271b38ae87268b2d372a7bf2b3147410ef57f02..ba93ed41a493b8913c4333dedc4482a5dc5d7d9c 100644
--- a/paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/tril_triu_grad_kernel.cu
@@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(tril_triu_grad,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/gpu/tril_triu_kernel.cu b/paddle/phi/kernels/gpu/tril_triu_kernel.cu
index 65dcca70584b81c52605f619c14d6c2eeb24438c..db42fa7d425ddf3327c2ab48fff6054398db0820 100644
--- a/paddle/phi/kernels/gpu/tril_triu_kernel.cu
+++ b/paddle/phi/kernels/gpu/tril_triu_kernel.cu
@@ -25,4 +25,5 @@ PD_REGISTER_KERNEL(tril_triu,
                    double,
                    int,
                    int64_t,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/kps/compare_kernel.cu b/paddle/phi/kernels/kps/compare_kernel.cu
index b981d802255a2a4550034b69fe7d7fc10e341587..b882fcc2a6c960032936c040bb1604776f60de6a 100644
--- a/paddle/phi/kernels/kps/compare_kernel.cu
+++ b/paddle/phi/kernels/kps/compare_kernel.cu
@@ -114,7 +114,8 @@ PD_REGISTER_KERNEL(less_than,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(less_equal,
                    KPS,
                    ALL_LAYOUT,
@@ -125,7 +126,8 @@ PD_REGISTER_KERNEL(less_equal,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(greater_than,
                    KPS,
                    ALL_LAYOUT,
@@ -136,7 +138,8 @@ PD_REGISTER_KERNEL(greater_than,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(greater_equal,
                    KPS,
                    ALL_LAYOUT,
@@ -147,7 +150,8 @@ PD_REGISTER_KERNEL(greater_equal,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(equal,
                    KPS,
                    ALL_LAYOUT,
@@ -158,7 +162,8 @@ PD_REGISTER_KERNEL(equal,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 PD_REGISTER_KERNEL(not_equal,
                    KPS,
                    ALL_LAYOUT,
@@ -169,7 +174,8 @@ PD_REGISTER_KERNEL(not_equal,
                    int64_t,
                    float,
                    double,
-                   phi::dtype::float16) {}
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {}
 
 PD_REGISTER_KERNEL(equal_all,
                    KPS,
diff --git a/paddle/phi/kernels/shape_kernel.cc b/paddle/phi/kernels/shape_kernel.cc
index 2c2b41e3c66fc7d192d28f06633ac72bf02c35b2..b866719859c8f0388a1394495c25ccbafbe4d59d 100644
--- a/paddle/phi/kernels/shape_kernel.cc
+++ b/paddle/phi/kernels/shape_kernel.cc
@@ -63,7 +63,8 @@ PD_REGISTER_KERNEL(shape,
                    double,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
 #endif
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index f987e8b89cf2549996cdcf09f374808086dc4224..422a11c7e885388e5e2ed566a99e495c63062585 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -3791,8 +3791,17 @@ def gather_nd(x, index, name=None):
     check_variable_and_dtype(
         x,
         'x',
-        ['bool', 'float32', 'float64', 'int16', 'int32', 'int64'],
-        'gather_np',
+        [
+            'bool',
+            'float16',
+            'uint16',
+            'float32',
+            'float64',
+            'int16',
+            'int32',
+            'int64',
+        ],
+        'gather_nd',
     )
     check_variable_and_dtype(index, 'index', ['int32', 'int64'], 'gather_np')
     helper = LayerHelper('gather_nd', **locals())