提交 b7ad2a2d 作者: Molly Smith

fp32

上级 2d7d1749
...@@ -276,11 +276,9 @@ __global__ void attn_softmax_v2(float* vals, ...@@ -276,11 +276,9 @@ __global__ void attn_softmax_v2(float* vals,
vals += (iter_offset * sequence_length); vals += (iter_offset * sequence_length);
int batch_idx = iter_offset / (num_seq * heads); int batch_idx = iter_offset / (num_seq * heads);
int alibi_offset = batch_idx * heads * mp_size + head_offset;
int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride); int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
mask_offset = mask_offset * sequence_length; mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq; int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length); int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2)) int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
...@@ -292,58 +290,55 @@ __global__ void attn_softmax_v2(float* vals, ...@@ -292,58 +290,55 @@ __global__ void attn_softmax_v2(float* vals,
float max_val = minus_infinity; float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2); int data_id = i * (reduceWidth << 2) + (seq_lane);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 && bool check1 = ((!triangular || (data_id <= seq_id)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length);
data_id < sequence_length) { bool low_x_check = check1 && (data_id > window_stride);
if ((sequence_length - data_id) >= 4) { bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && ((!triangular || ((data_id + reduceWidth) <= seq_id)) && (data_id + reduceWidth) > window_stride);
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity); bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) && bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride);
(data_id + 1) > window_stride)
? vals[data_id + 1] if (attn_mask){
: minus_infinity; data[i].x = low_x_check
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) && ? vals[data_id] + attn_mask[data_id + mask_offset]
(data_id + 2) > window_stride) : minus_infinity;
? vals[data_id + 2] b.sync();
: minus_infinity; data[i].y = low_y_check
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) && ? vals[data_id + reduceWidth] + attn_mask[data_id + mask_offset + reduceWidth]
(data_id + 3) > window_stride) : minus_infinity;
? vals[data_id + 3] b.sync();
: minus_infinity; data[i].z = high_x_check
if (attn_mask) { ? vals[data_id + reduceWidth*2] + attn_mask[data_id + mask_offset + reduceWidth*2]
data[i].x += attn_mask[data_id + mask_offset]; : minus_infinity;
data[i].y += attn_mask[data_id + mask_offset + 1]; b.sync();
data[i].z += attn_mask[data_id + mask_offset + 2]; data[i].w = high_y_check
data[i].w += attn_mask[data_id + mask_offset + 3]; ? vals[data_id + reduceWidth*3] + attn_mask[data_id + mask_offset + reduceWidth*3]
} : minus_infinity;
} else { b.sync();
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity; }
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) && else {
(data_id + 1) > window_stride && (data_id + 1) < sequence_length) data[i].x = low_x_check
? (vals[data_id + 1]) ? vals[data_id]
: minus_infinity; : minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) && b.sync();
(data_id + 2) > window_stride && (data_id + 2) < sequence_length) data[i].y = low_y_check
? (vals[data_id + 2]) ? vals[data_id + reduceWidth]
: minus_infinity; : minus_infinity;
data[i].w = minus_infinity; b.sync();
if (attn_mask) { data[i].z = high_x_check
data[i].x += attn_mask[data_id + mask_offset]; ? vals[data_id + reduceWidth*2]
if ((data_id + 1) < sequence_length) : minus_infinity;
data[i].y += attn_mask[data_id + mask_offset + 1]; b.sync();
if ((data_id + 2) < sequence_length) data[i].w = high_y_check
data[i].z += attn_mask[data_id + mask_offset + 2]; ? vals[data_id + reduceWidth*3]
} : minus_infinity;
} b.sync();
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
} }
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} }
for (int i = 1; i < WARP_SIZE; i *= 2) { for (int i = 1; i < WARP_SIZE; i *= 2) {
...@@ -394,19 +389,19 @@ __global__ void attn_softmax_v2(float* vals, ...@@ -394,19 +389,19 @@ __global__ void attn_softmax_v2(float* vals,
sum += 1e-6; sum += 1e-6;
for (int i = 0; i < iterations; i++) { for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2); int data_id = i * (reduceWidth << 2) + (seq_lane);
if (data_id < sequence_length) { if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) { vals[data_id] = data[i].x / sum;
vals[data_id] = data[i].x / sum; b.sync();
vals[data_id + 1] = data[i].y / sum; if ((data_id + reduceWidth) < sequence_length)
vals[data_id + 2] = data[i].z / sum; vals[data_id + reduceWidth] = data[i].y / sum;
vals[data_id + 3] = data[i].w / sum; b.sync();
} else { if ((data_id + reduceWidth*2) < sequence_length)
vals[data_id] = data[i].x / sum; vals[data_id + reduceWidth*2] = data[i].z / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum; b.sync();
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum; if ((data_id + reduceWidth*3) < sequence_length)
} vals[data_id + reduceWidth*3] = data[i].w / sum;
b.sync();
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请注册