未验证 提交 e73de8ce 编写于 作者: M Molly Smith 提交者: GitHub

Optimize Softmax Kernel (#3112)

* Simplify kernel

* Coalesce memory attempt 1. Logits divergence.

* Logits fix?

* sync after every global mem access

* template on iterations. Down to 8.3% cuda time for 8k tokens

* Up to 64 iterations

* Add alibi/mask check

* fp32

* Revert builder.py

* naming. precommit

* Revert "naming. precommit"

This reverts commit 150eb7d9.

* naming. spacing

* Spacing. simplify checks

* remove bsyncs

* missed bsyncs

* precommit
上级 f2c9a827
...@@ -30,6 +30,7 @@ void CheckCudaErrorAux(const char* file, unsigned line) ...@@ -30,6 +30,7 @@ void CheckCudaErrorAux(const char* file, unsigned line)
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
template <int iterations>
__global__ void attn_softmax_v2(__half* vals, __global__ void attn_softmax_v2(__half* vals,
__half* mask, __half* mask,
__half* alibi, __half* alibi,
...@@ -45,7 +46,6 @@ __global__ void attn_softmax_v2(__half* vals, ...@@ -45,7 +46,6 @@ __global__ void attn_softmax_v2(__half* vals,
int head_offset, int head_offset,
int mask_stride, int mask_stride,
int mp_size, int mp_size,
int iterations,
int reduceWidth) int reduceWidth)
{ {
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
...@@ -75,7 +75,6 @@ __global__ void attn_softmax_v2(__half* vals, ...@@ -75,7 +75,6 @@ __global__ void attn_softmax_v2(__half* vals,
alibi_offset = (alibi_offset + ((iter_offset / num_seq) % heads)) * sequence_length; alibi_offset = (alibi_offset + ((iter_offset / num_seq) % heads)) * sequence_length;
mask_offset = mask_offset * sequence_length; mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq; int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
...@@ -87,83 +86,95 @@ __global__ void attn_softmax_v2(__half* vals, ...@@ -87,83 +86,95 @@ __global__ void attn_softmax_v2(__half* vals,
float max_val = minus_infinity; float max_val = minus_infinity;
// if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset); // if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset);
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2); int data_id = i * (reduceWidth << 2) + (seq_lane);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && bool check = (data_id >> 2) >= window_stride4;
data_id < sequence_length) { bool low_x_check = check && (data_id < sequence_length) &&
if ((sequence_length - data_id) >= 4) { (!triangular || (data_id <= seq_id)) && (data_id > window_stride);
low_data[i].x = data_id > window_stride bool low_y_check = check && ((data_id + reduceWidth) < sequence_length) &&
? __half2float(vals[data_id]) * layer_scale (!triangular || ((data_id + reduceWidth) <= seq_id)) &&
((data_id + reduceWidth) > window_stride);
bool high_x_check = check && ((data_id + reduceWidth * 2) < sequence_length) &&
(!triangular || ((data_id + reduceWidth * 2) <= seq_id)) &&
((data_id + reduceWidth * 2) > window_stride);
bool high_y_check = check && ((data_id + reduceWidth * 3) < sequence_length) &&
(!triangular || ((data_id + reduceWidth * 3) <= seq_id)) &&
((data_id + reduceWidth * 3) > window_stride);
if (mask && alibi) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset])) +
(__half2float(mask[data_id + mask_offset]))
: minus_infinity; : minus_infinity;
low_data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && low_data[i].y =
(data_id + 1) > window_stride) low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale +
? __half2float(vals[data_id + 1]) * layer_scale (__half2float(alibi[data_id + alibi_offset + reduceWidth])) +
(__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
high_data[i].x =
high_x_check
? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) +
(__half2float(mask[data_id + mask_offset + reduceWidth * 2]))
: minus_infinity;
high_data[i].y =
high_y_check
? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) +
(__half2float(mask[data_id + mask_offset + reduceWidth * 3]))
: minus_infinity;
} else if (mask) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(mask[data_id + mask_offset]))
: minus_infinity; : minus_infinity;
high_data[i].x = ((!triangular || ((data_id + 2) <= seq_id)) && low_data[i].y = low_y_check
(data_id + 2) > window_stride) ? __half2float(vals[data_id + reduceWidth]) * layer_scale +
? __half2float(vals[data_id + 2]) * layer_scale (__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity; : minus_infinity;
high_data[i].y = ((!triangular || ((data_id + 3) <= seq_id)) && high_data[i].x =
(data_id + 3) > window_stride) high_x_check ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
? __half2float(vals[data_id + 3]) * layer_scale (__half2float(mask[data_id + mask_offset + reduceWidth * 2]))
: minus_infinity;
high_data[i].y =
high_y_check ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
(__half2float(mask[data_id + mask_offset + reduceWidth * 3]))
: minus_infinity;
} else if (alibi) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset]))
: minus_infinity; : minus_infinity;
if (alibi) {
low_data[i].x = low_data[i].x + __half2float(alibi[data_id + alibi_offset]);
low_data[i].y = low_data[i].y =
low_data[i].y + __half2float(alibi[data_id + alibi_offset + 1]); low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth]))
: minus_infinity;
high_data[i].x = high_data[i].x =
high_data[i].x + __half2float(alibi[data_id + alibi_offset + 2]); high_x_check
? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 2]))
: minus_infinity;
high_data[i].y = high_data[i].y =
high_data[i].y + __half2float(alibi[data_id + alibi_offset + 3]); high_y_check
} ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
if (mask) { (__half2float(alibi[data_id + alibi_offset + reduceWidth * 3]))
low_data[i].x += __half2float(mask[data_id + mask_offset]); : minus_infinity;
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
high_data[i].y += __half2float(mask[data_id + mask_offset + 3]);
}
} else { } else {
low_data[i].x = data_id > window_stride low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale
? __half2float(vals[data_id]) * layer_scale
: minus_infinity; : minus_infinity;
low_data[i].y = (((!triangular || (data_id + 1) <= seq_id) && low_data[i].y = low_y_check
(data_id + 1) > window_stride) && ? __half2float(vals[data_id + reduceWidth]) * layer_scale
(data_id + 1) < sequence_length)
? __half2float(vals[data_id + 1]) * layer_scale
: minus_infinity; : minus_infinity;
high_data[i].x = (((!triangular || (data_id + 2) <= seq_id) && high_data[i].x = high_x_check
(data_id + 2) > window_stride) && ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale
(data_id + 2) < sequence_length) : minus_infinity;
? __half2float(vals[data_id + 2]) * layer_scale high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale
: minus_infinity; : minus_infinity;
if (alibi) {
low_data[i].x = low_data[i].x + __half2float(alibi[data_id + alibi_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y =
low_data[i].y + __half2float(alibi[data_id + alibi_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x =
high_data[i].x + __half2float(alibi[data_id + alibi_offset + 2]);
}
high_data[i].y = minus_infinity;
if (mask) {
low_data[i].x += __half2float(mask[data_id + mask_offset]);
if ((data_id + 1) < sequence_length)
low_data[i].y += __half2float(mask[data_id + mask_offset + 1]);
if ((data_id + 2) < sequence_length)
high_data[i].x += __half2float(mask[data_id + mask_offset + 2]);
}
} }
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
} else {
low_data[i].x = minus_infinity;
low_data[i].y = minus_infinity;
high_data[i].x = minus_infinity;
high_data[i].y = minus_infinity;
}
} }
for (int i = 1; i < WARP_SIZE; i *= 2) { for (int i = 1; i < WARP_SIZE; i *= 2) {
...@@ -212,26 +223,21 @@ __global__ void attn_softmax_v2(__half* vals, ...@@ -212,26 +223,21 @@ __global__ void attn_softmax_v2(__half* vals,
} }
sum += 1e-6; sum += 1e-6;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2); int data_id = i * (reduceWidth << 2) + (seq_lane);
if (data_id < sequence_length) { if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = __float2half(low_data[i].x / sum); vals[data_id] = __float2half(low_data[i].x / sum);
vals[data_id + 1] = __float2half(low_data[i].y / sum); if ((data_id + reduceWidth) < sequence_length)
vals[data_id + 2] = __float2half(high_data[i].x / sum); vals[data_id + reduceWidth] = __float2half(low_data[i].y / sum);
vals[data_id + 3] = __float2half(high_data[i].y / sum); if ((data_id + reduceWidth * 2) < sequence_length)
} else { vals[data_id + reduceWidth * 2] = __float2half(high_data[i].x / sum);
vals[data_id] = __float2half(low_data[i].x / sum); if ((data_id + reduceWidth * 3) < sequence_length)
if ((data_id + 1) < sequence_length) vals[data_id + reduceWidth * 3] = __float2half(high_data[i].y / sum);
vals[data_id + 1] = __float2half(low_data[i].y / sum);
if ((data_id + 2) < sequence_length)
vals[data_id + 2] = __float2half(high_data[i].x / sum);
}
} }
} }
} }
} }
template <int iterations>
__global__ void attn_softmax_v2(float* vals, __global__ void attn_softmax_v2(float* vals,
float* attn_mask, float* attn_mask,
float* alibi, float* alibi,
...@@ -247,7 +253,6 @@ __global__ void attn_softmax_v2(float* vals, ...@@ -247,7 +253,6 @@ __global__ void attn_softmax_v2(float* vals,
int head_offset, int head_offset,
int mask_stride, int mask_stride,
int mp_size, int mp_size,
int iterations,
int reduceWidth) int reduceWidth)
{ {
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
...@@ -269,11 +274,9 @@ __global__ void attn_softmax_v2(float* vals, ...@@ -269,11 +274,9 @@ __global__ void attn_softmax_v2(float* vals,
vals += (iter_offset * sequence_length); vals += (iter_offset * sequence_length);
int batch_idx = iter_offset / (num_seq * heads); int batch_idx = iter_offset / (num_seq * heads);
int alibi_offset = batch_idx * heads * mp_size + head_offset;
int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
mask_offset = mask_offset * sequence_length; mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq; int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
...@@ -285,58 +288,43 @@ __global__ void attn_softmax_v2(float* vals, ...@@ -285,58 +288,43 @@ __global__ void attn_softmax_v2(float* vals,
float max_val = minus_infinity; float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2); int data_id = i * (reduceWidth << 2) + (seq_lane);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && bool check = (data_id >> 2) >= window_stride4;
data_id < sequence_length) { bool x_check = check && (data_id < sequence_length) &&
if ((sequence_length - data_id) >= 4) { (!triangular || (data_id <= seq_id)) && (data_id > window_stride);
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity); bool y_check = check && ((data_id + reduceWidth) < sequence_length) &&
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && (!triangular || ((data_id + reduceWidth) <= seq_id)) &&
(data_id + 1) > window_stride) ((data_id + reduceWidth) > window_stride);
? vals[data_id + 1] bool z_check = check && ((data_id + reduceWidth * 2) < sequence_length) &&
: minus_infinity; (!triangular || ((data_id + reduceWidth * 2) <= seq_id)) &&
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) && ((data_id + reduceWidth * 2) > window_stride);
(data_id + 2) > window_stride) bool w_check = check && ((data_id + reduceWidth * 3) < sequence_length) &&
? vals[data_id + 2] (!triangular || ((data_id + reduceWidth * 3) <= seq_id)) &&
((data_id + reduceWidth * 3) > window_stride);
if (attn_mask) {
data[i].x = x_check ? vals[data_id] + attn_mask[data_id + mask_offset]
: minus_infinity; : minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) && data[i].y = y_check ? vals[data_id + reduceWidth] +
(data_id + 3) > window_stride) attn_mask[data_id + mask_offset + reduceWidth]
? vals[data_id + 3]
: minus_infinity; : minus_infinity;
if (attn_mask) { data[i].z = z_check ? vals[data_id + reduceWidth * 2] +
data[i].x += attn_mask[data_id + mask_offset]; attn_mask[data_id + mask_offset + reduceWidth * 2]
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity; : minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) && data[i].w = w_check ? vals[data_id + reduceWidth * 3] +
(data_id + 2) > window_stride && (data_id + 2) < sequence_length) attn_mask[data_id + mask_offset + reduceWidth * 3]
? (vals[data_id + 2])
: minus_infinity; : minus_infinity;
data[i].w = minus_infinity; } else {
if (attn_mask) { data[i].x = x_check ? vals[data_id] : minus_infinity;
data[i].x += attn_mask[data_id + mask_offset]; data[i].y = y_check ? vals[data_id + reduceWidth] : minus_infinity;
if ((data_id + 1) < sequence_length) data[i].z = z_check ? vals[data_id + reduceWidth * 2] : minus_infinity;
data[i].y += attn_mask[data_id + mask_offset + 1]; data[i].w = w_check ? vals[data_id + reduceWidth * 3] : minus_infinity;
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
} }
max_val = (data[i].x > max_val ? data[i].x : max_val); max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val); max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val); max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val); max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
}
} }
for (int i = 1; i < WARP_SIZE; i *= 2) { for (int i = 1; i < WARP_SIZE; i *= 2) {
...@@ -387,24 +375,38 @@ __global__ void attn_softmax_v2(float* vals, ...@@ -387,24 +375,38 @@ __global__ void attn_softmax_v2(float* vals,
sum += 1e-6; sum += 1e-6;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2); int data_id = i * (reduceWidth << 2) + (seq_lane);
if (data_id < sequence_length) { if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum; vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum; if ((data_id + reduceWidth) < sequence_length)
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum; vals[data_id + reduceWidth] = data[i].y / sum;
} if ((data_id + reduceWidth * 2) < sequence_length)
vals[data_id + reduceWidth * 2] = data[i].z / sum;
if ((data_id + reduceWidth * 3) < sequence_length)
vals[data_id + reduceWidth * 3] = data[i].w / sum;
} }
} }
} }
} }
#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \
attn_softmax_v2<iterations><<<grid, block, 0, stream>>>(vals, \
mask, \
alibi, \
layer_scale, \
triangular, \
recompute, \
local_attention, \
window_size, \
total_count, \
heads, \
sequence_length, \
num_seq, \
head_offset, \
mask_stride, \
mp_size, \
reduce_width);
template <typename T> template <typename T>
void launch_attn_softmax_v2(T* vals, void launch_attn_softmax_v2(T* vals,
T* mask, T* mask,
...@@ -450,25 +452,23 @@ void launch_attn_softmax_v2(T* vals, ...@@ -450,25 +452,23 @@ void launch_attn_softmax_v2(T* vals,
dim3 grid((total_count + partitions - 1) / partitions); dim3 grid((total_count + partitions - 1) / partitions);
dim3 block(attn_threads); dim3 block(attn_threads);
if (sequence_length <= 32768) if (sequence_length <= 32768) {
attn_softmax_v2<<<grid, block, 0, stream>>>(vals, if (iterations == 1) {
mask, LAUNCH_ATTN_SOFTMAX_V2(1);
alibi, } else if (iterations == 2) {
layer_scale, LAUNCH_ATTN_SOFTMAX_V2(2);
triangular, } else if (iterations == 4) {
recompute, LAUNCH_ATTN_SOFTMAX_V2(4);
local_attention, } else if (iterations == 8) {
window_size, LAUNCH_ATTN_SOFTMAX_V2(8);
total_count, } else if (iterations == 16) {
heads, LAUNCH_ATTN_SOFTMAX_V2(16);
sequence_length, } else if (iterations == 32) {
num_seq, LAUNCH_ATTN_SOFTMAX_V2(32);
head_offset, } else if (iterations == 64) {
mask_stride, LAUNCH_ATTN_SOFTMAX_V2(64);
mp_size, }
iterations, } else
reduce_width);
else
throw std::runtime_error("Unsupport Seq_Length!"); throw std::runtime_error("Unsupport Seq_Length!");
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册