提交 150eb7d9 编写于 作者: M Molly Smith

naming. precommit

上级 bc450d48
......@@ -86,69 +86,89 @@ __global__ void attn_softmax_v2(__half* vals,
// if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset);
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane);
bool check1 = ((!triangular || (data_id <= seq_id)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool check1 = ((!triangular || (data_id <= seq_id)) &&
(data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool low_x_check = check1 && (data_id > window_stride);
bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && ((!triangular || ((data_id + reduceWidth) <= seq_id)) && (data_id + reduceWidth) > window_stride);
bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride);
bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride);
if (mask && alibi){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(alibi[data_id + alibi_offset])) + (__half2float(mask[data_id + mask_offset]))
bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) &&
((!triangular || ((data_id + reduceWidth) <= seq_id)) &&
(data_id + reduceWidth) > window_stride);
bool high_x_check = check1 && ((data_id + reduceWidth * 2) < sequence_length) &&
((!triangular || ((data_id + reduceWidth * 2) <= seq_id)) &&
(data_id + reduceWidth * 2) > window_stride);
bool high_y_check = check1 && ((data_id + reduceWidth * 3) < sequence_length) &&
((!triangular || ((data_id + reduceWidth * 3) <= seq_id)) &&
(data_id + reduceWidth * 3) > window_stride);
if (mask && alibi) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset])) +
(__half2float(mask[data_id + mask_offset]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth])) + (__half2float(mask[data_id + mask_offset + reduceWidth]))
low_data[i].y =
low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth])) +
(__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*2])) + (__half2float(mask[data_id + mask_offset + reduceWidth*2]))
high_data[i].x =
high_x_check
? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) +
(__half2float(mask[data_id + mask_offset + reduceWidth * 2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*3])) + (__half2float(mask[data_id + mask_offset + reduceWidth*3]))
high_data[i].y =
high_y_check
? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) +
(__half2float(mask[data_id + mask_offset + reduceWidth * 3]))
: minus_infinity;
b.sync();
}
else if (mask){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(mask[data_id + mask_offset]))
} else if (mask) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(mask[data_id + mask_offset]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth]))
? __half2float(vals[data_id + reduceWidth]) * layer_scale +
(__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth*2]))
high_data[i].x =
high_x_check ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
(__half2float(mask[data_id + mask_offset + reduceWidth * 2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth*3]))
high_data[i].y =
high_y_check ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
(__half2float(mask[data_id + mask_offset + reduceWidth * 3]))
: minus_infinity;
b.sync();
}
else if (alibi){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(alibi[data_id + alibi_offset]))
} else if (alibi) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth]))
low_data[i].y =
low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*2]))
high_data[i].x =
high_x_check
? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*3]))
high_data[i].y =
high_y_check
? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 3]))
: minus_infinity;
b.sync();
}
else {
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale
} else {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
......@@ -156,22 +176,20 @@ __global__ void attn_softmax_v2(__half* vals,
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale
? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale
? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale
: minus_infinity;
b.sync();
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
......@@ -228,11 +246,11 @@ __global__ void attn_softmax_v2(__half* vals,
if ((data_id + reduceWidth) < sequence_length)
vals[data_id + reduceWidth] = __float2half(low_data[i].y / sum);
b.sync();
if ((data_id + reduceWidth*2) < sequence_length)
vals[data_id + reduceWidth*2] = __float2half(high_data[i].x / sum);
if ((data_id + reduceWidth * 2) < sequence_length)
vals[data_id + reduceWidth * 2] = __float2half(high_data[i].x / sum);
b.sync();
if ((data_id + reduceWidth*3) < sequence_length)
vals[data_id + reduceWidth*3] = __float2half(high_data[i].y / sum);
if ((data_id + reduceWidth * 3) < sequence_length)
vals[data_id + reduceWidth * 3] = __float2half(high_data[i].y / sum);
b.sync();
}
}
......@@ -291,46 +309,43 @@ __global__ void attn_softmax_v2(float* vals,
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane);
bool check1 = ((!triangular || (data_id <= seq_id)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool low_x_check = check1 && (data_id > window_stride);
bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && ((!triangular || ((data_id + reduceWidth) <= seq_id)) && (data_id + reduceWidth) > window_stride);
bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride);
bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride);
if (attn_mask){
data[i].x = low_x_check
? vals[data_id] + attn_mask[data_id + mask_offset]
bool check1 = ((!triangular || (data_id <= seq_id)) &&
(data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool x_check = check1 && (data_id > window_stride);
bool y_check = check1 && ((data_id + reduceWidth) < sequence_length) &&
((!triangular || ((data_id + reduceWidth) <= seq_id)) &&
(data_id + reduceWidth) > window_stride);
bool z_check = check1 && ((data_id + reduceWidth * 2) < sequence_length) &&
((!triangular || ((data_id + reduceWidth * 2) <= seq_id)) &&
(data_id + reduceWidth * 2) > window_stride);
bool w_check = check1 && ((data_id + reduceWidth * 3) < sequence_length) &&
((!triangular || ((data_id + reduceWidth * 3) <= seq_id)) &&
(data_id + reduceWidth * 3) > window_stride);
if (attn_mask) {
data[i].x = x_check ? vals[data_id] + attn_mask[data_id + mask_offset]
: minus_infinity;
b.sync();
data[i].y = low_y_check
? vals[data_id + reduceWidth] + attn_mask[data_id + mask_offset + reduceWidth]
data[i].y = y_check ? vals[data_id + reduceWidth] +
attn_mask[data_id + mask_offset + reduceWidth]
: minus_infinity;
b.sync();
data[i].z = high_x_check
? vals[data_id + reduceWidth*2] + attn_mask[data_id + mask_offset + reduceWidth*2]
data[i].z = z_check ? vals[data_id + reduceWidth * 2] +
attn_mask[data_id + mask_offset + reduceWidth * 2]
: minus_infinity;
b.sync();
data[i].w = high_y_check
? vals[data_id + reduceWidth*3] + attn_mask[data_id + mask_offset + reduceWidth*3]
data[i].w = w_check ? vals[data_id + reduceWidth * 3] +
attn_mask[data_id + mask_offset + reduceWidth * 3]
: minus_infinity;
b.sync();
}
else {
data[i].x = low_x_check
? vals[data_id]
: minus_infinity;
} else {
data[i].x = x_check ? vals[data_id] : minus_infinity;
b.sync();
data[i].y = low_y_check
? vals[data_id + reduceWidth]
: minus_infinity;
data[i].y = y_check ? vals[data_id + reduceWidth] : minus_infinity;
b.sync();
data[i].z = high_x_check
? vals[data_id + reduceWidth*2]
: minus_infinity;
data[i].z = z_check ? vals[data_id + reduceWidth * 2] : minus_infinity;
b.sync();
data[i].w = high_y_check
? vals[data_id + reduceWidth*3]
: minus_infinity;
data[i].w = w_check ? vals[data_id + reduceWidth * 3] : minus_infinity;
b.sync();
}
......@@ -338,7 +353,6 @@ __global__ void attn_softmax_v2(float* vals,
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
......@@ -396,11 +410,11 @@ __global__ void attn_softmax_v2(float* vals,
if ((data_id + reduceWidth) < sequence_length)
vals[data_id + reduceWidth] = data[i].y / sum;
b.sync();
if ((data_id + reduceWidth*2) < sequence_length)
vals[data_id + reduceWidth*2] = data[i].z / sum;
if ((data_id + reduceWidth * 2) < sequence_length)
vals[data_id + reduceWidth * 2] = data[i].z / sum;
b.sync();
if ((data_id + reduceWidth*3) < sequence_length)
vals[data_id + reduceWidth*3] = data[i].w / sum;
if ((data_id + reduceWidth * 3) < sequence_length)
vals[data_id + reduceWidth * 3] = data[i].w / sum;
b.sync();
}
}
......@@ -408,9 +422,22 @@ __global__ void attn_softmax_v2(float* vals,
}
// Launches the attn_softmax_v2 kernel instantiated for a fixed iteration count.
//
// `iterations` is used as a template argument, so it must be a compile-time
// constant at the expansion site.
//
// The macro is not self-contained: `grid`, `block` and `stream` (the launch
// configuration) and every kernel argument listed below (vals, mask, alibi,
// ..., reduce_width) are picked up by name from the scope in which the macro
// is expanded. The third launch-configuration argument (dynamic shared memory
// size) is fixed at 0.
//
// The expansion already ends with a semicolon, so `LAUNCH_ATTN_SOFTMAX_V2(n);`
// produces an extra empty statement; call it inside a braced block.
#define LAUNCH_ATTN_SOFTMAX_V2(iterations)                                    \
    attn_softmax_v2<iterations><<<grid, block, 0, stream>>>(vals,             \
                                                            mask,             \
                                                            alibi,            \
                                                            layer_scale,      \
                                                            triangular,       \
                                                            recompute,        \
                                                            local_attention,  \
                                                            window_size,      \
                                                            total_count,      \
                                                            heads,            \
                                                            sequence_length,  \
                                                            num_seq,          \
                                                            head_offset,      \
                                                            mask_stride,      \
                                                            mp_size,          \
                                                            reduce_width);
template <typename T>
void launch_attn_softmax_v2(T* vals,
......@@ -457,30 +484,23 @@ void launch_attn_softmax_v2(T* vals,
dim3 grid((total_count + partitions - 1) / partitions);
dim3 block(attn_threads);
if (sequence_length <= 32768){
if (iterations == 1){
if (sequence_length <= 32768) {
if (iterations == 1) {
LAUNCH_ATTN_SOFTMAX_V2(1);
}
else if (iterations == 2){
} else if (iterations == 2) {
LAUNCH_ATTN_SOFTMAX_V2(2);
}
else if (iterations == 4){
} else if (iterations == 4) {
LAUNCH_ATTN_SOFTMAX_V2(4);
}
else if (iterations == 8){
} else if (iterations == 8) {
LAUNCH_ATTN_SOFTMAX_V2(8);
}
else if (iterations == 16){
} else if (iterations == 16) {
LAUNCH_ATTN_SOFTMAX_V2(16);
}
else if (iterations == 32){
} else if (iterations == 32) {
LAUNCH_ATTN_SOFTMAX_V2(32);
}
else if (iterations == 64){
} else if (iterations == 64) {
LAUNCH_ATTN_SOFTMAX_V2(64);
}
}
else
} else
throw std::runtime_error("Unsupport Seq_Length!");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册