未验证 提交 d7f4599d 编写于 作者: L LiYuRio 提交者: GitHub

Fix nan in fused multi transformer (#44093)

上级 54a9daf2
......@@ -125,7 +125,10 @@ void MasterDaemon::CloseControlFd() {
void MasterDaemon::StopByControlFd() {
VLOG(4) << ("begin to run StopByControlFd");
if (_control_fd[1] != -1) {
::write(_control_fd[1], "\0", 1);
PADDLE_ENFORCE_NE(::write(_control_fd[1], "\0", 1),
-1,
platform::errors::Fatal(
"failed to write control pipe errno:%d", errno));
// close the write end of the pipe
::close(_control_fd[1]);
_control_fd[1] = -1;
......
......@@ -294,6 +294,52 @@ inline __device__ uint4 mul(uint4 a, uint4 b) {
return c;
}
template <>
inline __device__ uint32_t mul(uint32_t a, float b) {
float2 tmp = half2_to_float2(a);
float2 tmp_res;
tmp_res.x = tmp.x * b;
tmp_res.y = tmp.y * b;
uint32_t res = float2_to_half2(tmp_res);
return res;
}
template <>
inline __device__ uint2 mul(uint2 a, float b) {
uint2 res;
res.x = mul<uint32_t, uint32_t, float>(a.x, b);
res.y = mul<uint32_t, uint32_t, float>(a.y, b);
return res;
}
template <>
inline __device__ uint4 mul(uint4 a, float b) {
uint4 res;
res.x = mul<uint32_t, uint32_t, float>(a.x, b);
res.y = mul<uint32_t, uint32_t, float>(a.y, b);
res.z = mul<uint32_t, uint32_t, float>(a.z, b);
res.w = mul<uint32_t, uint32_t, float>(a.w, b);
return res;
}
template <>
inline __device__ float2 mul(float2 a, float b) {
float2 res;
res.x = a.x * b;
res.y = a.y * b;
return res;
}
template <>
inline __device__ float4 mul(float4 a, float b) {
float4 res;
res.x = a.x * b;
res.y = a.y * b;
res.z = a.z * b;
res.w = a.w * b;
return res;
}
inline __device__ float sum(float v) { return v; }
inline __device__ float sum(float2 v) { return v.x + v.y; }
inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }
......@@ -445,11 +491,15 @@ inline __device__ Float8_ cast_to_float(uint4 u) {
}
template <int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) {
K_vec qk_vec = mul<K_vec, K_vec, K_vec>(q[0], k[0]);
inline __device__ float qk_dot_(const K_vec (&q)[N],
const K_vec (&k)[N],
float inv_sqrt_dh) {
K_vec inv_q = mul<K_vec, K_vec, float>(q[0], inv_sqrt_dh);
K_vec qk_vec = mul<K_vec, K_vec, K_vec>(inv_q, k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
inv_q = mul<K_vec, K_vec, float>(q[ii], inv_sqrt_dh);
qk_vec = fma(inv_q, k[ii], qk_vec);
}
float qk = sum(qk_vec);
......@@ -463,8 +513,10 @@ inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) {
template <typename T, int THREADS_PER_KEY>
struct Qk_dot {
template <typename K_vec, int N>
static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) {
return qk_dot_<THREADS_PER_KEY>(q, k);
static inline __device__ float dot(const K_vec (&q)[N],
const K_vec (&k)[N],
float inv_sqrt_dh) {
return qk_dot_<THREADS_PER_KEY>(q, k, inv_sqrt_dh);
}
};
......@@ -706,7 +758,9 @@ __global__ void masked_multihead_attention_kernel(
}
}
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k) * params.inv_sqrt_dh;
// NOTE(liyurui): We should multiple q with inv_sqrt_dh first, for dot(q, k)
// may overflow with FP16 in large model.
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q, k, params.inv_sqrt_dh);
// bool is_mask = false;
if (ti < params.timestep && tid % THREADS_PER_KEY == 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册