提交 b7ad2a2d,作者:Molly Smith

fp32

上级 2d7d1749
......@@ -276,11 +276,9 @@ __global__ void attn_softmax_v2(float* vals,
vals += (iter_offset * sequence_length);
int batch_idx = iter_offset / (num_seq * heads);
int alibi_offset = batch_idx * heads * mp_size + head_offset;
int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;
int real_seq_id = seq_id + (num_seq == sequence_length ? 0 : sequence_length);
int window_stride4 = (local_attention && (real_seq_id >> 2) > (window_size >> 2))
......@@ -292,58 +290,55 @@ __global__ void attn_softmax_v2(float* vals,
float max_val = minus_infinity;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
if ((!triangular || ((data_id >> 2) <= seq_id4)) && (data_id >> 2) >= window_stride4 &&
data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
data[i].x = (data_id > window_stride ? vals[data_id] : minus_infinity);
data[i].y = ((!triangular || ((data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride)
? vals[data_id + 1]
: minus_infinity;
data[i].z = ((!triangular || ((data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride)
? vals[data_id + 2]
: minus_infinity;
data[i].w = ((!triangular || ((data_id + 3) <= seq_id)) &&
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
if (attn_mask) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
data[i].w += attn_mask[data_id + mask_offset + 3];
}
} else {
data[i].x = data_id > window_stride ? vals[data_id] : minus_infinity;
data[i].y = (((!triangular || (data_id + 1) <= seq_id)) &&
(data_id + 1) > window_stride && (data_id + 1) < sequence_length)
? (vals[data_id + 1])
: minus_infinity;
data[i].z = (((!triangular || (data_id + 2) <= seq_id)) &&
(data_id + 2) > window_stride && (data_id + 2) < sequence_length)
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
if (attn_mask) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
if ((data_id + 2) < sequence_length)
data[i].z += attn_mask[data_id + mask_offset + 2];
}
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
} else {
data[i].x = minus_infinity;
data[i].y = minus_infinity;
data[i].z = minus_infinity;
data[i].w = minus_infinity;
int data_id = i * (reduceWidth << 2) + (seq_lane);
bool check1 = ((!triangular || (data_id <= seq_id)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool low_x_check = check1 && (data_id > window_stride);
bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && ((!triangular || ((data_id + reduceWidth) <= seq_id)) && (data_id + reduceWidth) > window_stride);
bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride);
bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride);
if (attn_mask){
data[i].x = low_x_check
? vals[data_id] + attn_mask[data_id + mask_offset]
: minus_infinity;
b.sync();
data[i].y = low_y_check
? vals[data_id + reduceWidth] + attn_mask[data_id + mask_offset + reduceWidth]
: minus_infinity;
b.sync();
data[i].z = high_x_check
? vals[data_id + reduceWidth*2] + attn_mask[data_id + mask_offset + reduceWidth*2]
: minus_infinity;
b.sync();
data[i].w = high_y_check
? vals[data_id + reduceWidth*3] + attn_mask[data_id + mask_offset + reduceWidth*3]
: minus_infinity;
b.sync();
}
else {
data[i].x = low_x_check
? vals[data_id]
: minus_infinity;
b.sync();
data[i].y = low_y_check
? vals[data_id + reduceWidth]
: minus_infinity;
b.sync();
data[i].z = high_x_check
? vals[data_id + reduceWidth*2]
: minus_infinity;
b.sync();
data[i].w = high_y_check
? vals[data_id + reduceWidth*3]
: minus_infinity;
b.sync();
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
......@@ -394,19 +389,19 @@ __global__ void attn_softmax_v2(float* vals,
sum += 1e-6;
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane << 2);
int data_id = i * (reduceWidth << 2) + (seq_lane);
if (data_id < sequence_length) {
if ((sequence_length - data_id) >= 4) {
vals[data_id] = data[i].x / sum;
vals[data_id + 1] = data[i].y / sum;
vals[data_id + 2] = data[i].z / sum;
vals[data_id + 3] = data[i].w / sum;
} else {
vals[data_id] = data[i].x / sum;
if ((data_id + 1) < sequence_length) vals[data_id + 1] = data[i].y / sum;
if ((data_id + 2) < sequence_length) vals[data_id + 2] = data[i].z / sum;
}
vals[data_id] = data[i].x / sum;
b.sync();
if ((data_id + reduceWidth) < sequence_length)
vals[data_id + reduceWidth] = data[i].y / sum;
b.sync();
if ((data_id + reduceWidth*2) < sequence_length)
vals[data_id + reduceWidth*2] = data[i].z / sum;
b.sync();
if ((data_id + reduceWidth*3) < sequence_length)
vals[data_id + reduceWidth*3] = data[i].w / sum;
b.sync();
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册