From d7f4599d11cf6fec67ecab38129a36d06cac10b5 Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Wed, 6 Jul 2022 15:09:13 +0800 Subject: [PATCH] Fix nan in fused multi transformer (#44093) --- paddle/fluid/distributed/store/tcp_store.cc | 5 +- .../fused/fused_multi_transformer_op.cu | 66 +++++++++++++++++-- 2 files changed, 64 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/distributed/store/tcp_store.cc b/paddle/fluid/distributed/store/tcp_store.cc index a67ca29a543..e4228e4428d 100644 --- a/paddle/fluid/distributed/store/tcp_store.cc +++ b/paddle/fluid/distributed/store/tcp_store.cc @@ -125,7 +125,10 @@ void MasterDaemon::CloseControlFd() { void MasterDaemon::StopByControlFd() { VLOG(4) << ("begin to run StopByControlFd"); if (_control_fd[1] != -1) { - ::write(_control_fd[1], "\0", 1); + PADDLE_ENFORCE_NE(::write(_control_fd[1], "\0", 1), + -1, + platform::errors::Fatal( + "failed to write control pipe errno:%d", errno)); // close the write end of the pipe ::close(_control_fd[1]); _control_fd[1] = -1; diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu index f806359093c..fafbcf724d7 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu @@ -294,6 +294,52 @@ inline __device__ uint4 mul(uint4 a, uint4 b) { return c; } +template <> +inline __device__ uint32_t mul(uint32_t a, float b) { + float2 tmp = half2_to_float2(a); + float2 tmp_res; + tmp_res.x = tmp.x * b; + tmp_res.y = tmp.y * b; + uint32_t res = float2_to_half2(tmp_res); + return res; +} + +template <> +inline __device__ uint2 mul(uint2 a, float b) { + uint2 res; + res.x = mul(a.x, b); + res.y = mul(a.y, b); + return res; +} + +template <> +inline __device__ uint4 mul(uint4 a, float b) { + uint4 res; + res.x = mul(a.x, b); + res.y = mul(a.y, b); + res.z = mul(a.z, b); + res.w = mul(a.w, b); + return res; +} + +template <> +inline __device__ float2 mul(float2 a, float b) { + float2 res; + res.x = a.x * b; + res.y = a.y * b; + return res; +} + +template <> +inline __device__ float4 mul(float4 a, float b) { + float4 res; + res.x = a.x * b; + res.y = a.y * b; + res.z = a.z * b; + res.w = a.w * b; + return res; +} + inline __device__ float sum(float v) { return v; } inline __device__ float sum(float2 v) { return v.x + v.y; } inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; } @@ -445,11 +491,15 @@ inline __device__ Float8_ cast_to_float(uint4 u) { } template -inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) { - K_vec qk_vec = mul(q[0], k[0]); +inline __device__ float qk_dot_(const K_vec (&q)[N], + const K_vec (&k)[N], + float inv_sqrt_dh) { + K_vec inv_q = mul(q[0], inv_sqrt_dh); + K_vec qk_vec = mul(inv_q, k[0]); #pragma unroll for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); + inv_q = mul(q[ii], inv_sqrt_dh); + qk_vec = fma(inv_q, k[ii], qk_vec); } float qk = sum(qk_vec); @@ -463,8 +513,10 @@ inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) { template struct Qk_dot { template - static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) { - return qk_dot_(q, k); + static inline __device__ float dot(const K_vec (&q)[N], + const K_vec (&k)[N], + float inv_sqrt_dh) { + return qk_dot_(q, k, inv_sqrt_dh); } }; @@ -706,7 +758,9 @@ __global__ void masked_multihead_attention_kernel( } } - float qk = Qk_dot::dot(q, k) * params.inv_sqrt_dh; + // NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k) + // may overflow with FP16 in large model. + float qk = Qk_dot::dot(q, k, params.inv_sqrt_dh); // bool is_mask = false; if (ti < params.timestep && tid % THREADS_PER_KEY == 0) { -- GitLab