提交 150eb7d9 编写于 作者: M Molly Smith

naming. precommit

上级 bc450d48
......@@ -161,7 +161,7 @@ void launch_fused_add2<float>(float* out,
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
......@@ -178,7 +178,7 @@ void launch_fused_add2<__half>(__half* out,
int total_count = batch_size * seq_length * hidden_dim / 4;
dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4);
fused_add2_kernel<<<grid_dim, block_dim, 0, stream>>>(total_count, out, inp1, inp2);
}
......
......@@ -86,92 +86,110 @@ __global__ void attn_softmax_v2(__half* vals,
// if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset);
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane);
bool check1 = ((!triangular || (data_id <= seq_id)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool low_x_check = check1 && (data_id > window_stride);
bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && ((!triangular || ((data_id + reduceWidth) <= seq_id)) && (data_id + reduceWidth) > window_stride);
bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride);
bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride);
if (mask && alibi){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(alibi[data_id + alibi_offset])) + (__half2float(mask[data_id + mask_offset]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth])) + (__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*2])) + (__half2float(mask[data_id + mask_offset + reduceWidth*2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*3])) + (__half2float(mask[data_id + mask_offset + reduceWidth*3]))
: minus_infinity;
b.sync();
}
else if (mask){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(mask[data_id + mask_offset]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth*2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth*3]))
: minus_infinity;
b.sync();
}
else if (alibi){
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale + (__half2float(alibi[data_id + alibi_offset]))
bool check1 = ((!triangular || (data_id <= seq_id)) &&
(data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool low_x_check = check1 && (data_id > window_stride);
bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) &&
((!triangular || ((data_id + reduceWidth) <= seq_id)) &&
(data_id + reduceWidth) > window_stride);
bool high_x_check = check1 && ((data_id + reduceWidth * 2) < sequence_length) &&
((!triangular || ((data_id + reduceWidth * 2) <= seq_id)) &&
(data_id + reduceWidth * 2) > window_stride);
bool high_y_check = check1 && ((data_id + reduceWidth * 3) < sequence_length) &&
((!triangular || ((data_id + reduceWidth * 3) <= seq_id)) &&
(data_id + reduceWidth * 3) > window_stride);
if (mask && alibi) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset])) +
(__half2float(mask[data_id + mask_offset]))
: minus_infinity;
b.sync();
low_data[i].y =
low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth])) +
(__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*2]))
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*3]))
: minus_infinity;
b.sync();
}
else {
low_data[i].x = low_x_check
? __half2float(vals[data_id]) * layer_scale
b.sync();
high_data[i].x =
high_x_check
? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) +
(__half2float(mask[data_id + mask_offset + reduceWidth * 2]))
: minus_infinity;
b.sync();
high_data[i].y =
high_y_check
? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) +
(__half2float(mask[data_id + mask_offset + reduceWidth * 3]))
: minus_infinity;
b.sync();
} else if (mask) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(mask[data_id + mask_offset]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale +
(__half2float(mask[data_id + mask_offset + reduceWidth]))
: minus_infinity;
b.sync();
high_data[i].x =
high_x_check ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
(__half2float(mask[data_id + mask_offset + reduceWidth * 2]))
: minus_infinity;
b.sync();
high_data[i].y =
high_y_check ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
(__half2float(mask[data_id + mask_offset + reduceWidth * 3]))
: minus_infinity;
b.sync();
} else if (alibi) {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset]))
: minus_infinity;
b.sync();
low_data[i].y =
low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth]))
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth*2]) * layer_scale
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth*3]) * layer_scale
: minus_infinity;
b.sync();
b.sync();
high_data[i].x =
high_x_check
? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 2]))
: minus_infinity;
b.sync();
high_data[i].y =
high_y_check
? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale +
(__half2float(alibi[data_id + alibi_offset + reduceWidth * 3]))
: minus_infinity;
b.sync();
} else {
low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale
: minus_infinity;
b.sync();
low_data[i].y = low_y_check
? __half2float(vals[data_id + reduceWidth]) * layer_scale
: minus_infinity;
b.sync();
high_data[i].x = high_x_check
? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale
: minus_infinity;
b.sync();
high_data[i].y = high_y_check
? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale
: minus_infinity;
b.sync();
}
// if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id);
max_val = (low_data[i].x > max_val ? low_data[i].x : max_val);
max_val = (low_data[i].y > max_val ? low_data[i].y : max_val);
max_val = (high_data[i].x > max_val ? high_data[i].x : max_val);
max_val = (high_data[i].y > max_val ? high_data[i].y : max_val);
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
......@@ -227,13 +245,13 @@ __global__ void attn_softmax_v2(__half* vals,
b.sync();
if ((data_id + reduceWidth) < sequence_length)
vals[data_id + reduceWidth] = __float2half(low_data[i].y / sum);
b.sync();
if ((data_id + reduceWidth*2) < sequence_length)
vals[data_id + reduceWidth*2] = __float2half(high_data[i].x / sum);
b.sync();
if ((data_id + reduceWidth*3) < sequence_length)
vals[data_id + reduceWidth*3] = __float2half(high_data[i].y / sum);
b.sync();
b.sync();
if ((data_id + reduceWidth * 2) < sequence_length)
vals[data_id + reduceWidth * 2] = __float2half(high_data[i].x / sum);
b.sync();
if ((data_id + reduceWidth * 3) < sequence_length)
vals[data_id + reduceWidth * 3] = __float2half(high_data[i].y / sum);
b.sync();
}
}
}
......@@ -291,54 +309,50 @@ __global__ void attn_softmax_v2(float* vals,
for (int i = 0; i < iterations; i++) {
int data_id = i * (reduceWidth << 2) + (seq_lane);
bool check1 = ((!triangular || (data_id <= seq_id)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool low_x_check = check1 && (data_id > window_stride);
bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && ((!triangular || ((data_id + reduceWidth) <= seq_id)) && (data_id + reduceWidth) > window_stride);
bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride);
bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride);
if (attn_mask){
data[i].x = low_x_check
? vals[data_id] + attn_mask[data_id + mask_offset]
: minus_infinity;
b.sync();
data[i].y = low_y_check
? vals[data_id + reduceWidth] + attn_mask[data_id + mask_offset + reduceWidth]
: minus_infinity;
b.sync();
data[i].z = high_x_check
? vals[data_id + reduceWidth*2] + attn_mask[data_id + mask_offset + reduceWidth*2]
: minus_infinity;
b.sync();
data[i].w = high_y_check
? vals[data_id + reduceWidth*3] + attn_mask[data_id + mask_offset + reduceWidth*3]
: minus_infinity;
b.sync();
}
else {
data[i].x = low_x_check
? vals[data_id]
: minus_infinity;
b.sync();
data[i].y = low_y_check
? vals[data_id + reduceWidth]
: minus_infinity;
b.sync();
data[i].z = high_x_check
? vals[data_id + reduceWidth*2]
: minus_infinity;
b.sync();
data[i].w = high_y_check
? vals[data_id + reduceWidth*3]
: minus_infinity;
b.sync();
bool check1 = ((!triangular || (data_id <= seq_id)) &&
(data_id >> 2) >= window_stride4 && data_id < sequence_length);
bool x_check = check1 && (data_id > window_stride);
bool y_check = check1 && ((data_id + reduceWidth) < sequence_length) &&
((!triangular || ((data_id + reduceWidth) <= seq_id)) &&
(data_id + reduceWidth) > window_stride);
bool z_check = check1 && ((data_id + reduceWidth * 2) < sequence_length) &&
((!triangular || ((data_id + reduceWidth * 2) <= seq_id)) &&
(data_id + reduceWidth * 2) > window_stride);
bool w_check = check1 && ((data_id + reduceWidth * 3) < sequence_length) &&
((!triangular || ((data_id + reduceWidth * 3) <= seq_id)) &&
(data_id + reduceWidth * 3) > window_stride);
if (attn_mask) {
data[i].x = x_check ? vals[data_id] + attn_mask[data_id + mask_offset]
: minus_infinity;
b.sync();
data[i].y = y_check ? vals[data_id + reduceWidth] +
attn_mask[data_id + mask_offset + reduceWidth]
: minus_infinity;
b.sync();
data[i].z = z_check ? vals[data_id + reduceWidth * 2] +
attn_mask[data_id + mask_offset + reduceWidth * 2]
: minus_infinity;
b.sync();
data[i].w = w_check ? vals[data_id + reduceWidth * 3] +
attn_mask[data_id + mask_offset + reduceWidth * 3]
: minus_infinity;
b.sync();
} else {
data[i].x = x_check ? vals[data_id] : minus_infinity;
b.sync();
data[i].y = y_check ? vals[data_id + reduceWidth] : minus_infinity;
b.sync();
data[i].z = z_check ? vals[data_id + reduceWidth * 2] : minus_infinity;
b.sync();
data[i].w = w_check ? vals[data_id + reduceWidth * 3] : minus_infinity;
b.sync();
}
max_val = (data[i].x > max_val ? data[i].x : max_val);
max_val = (data[i].y > max_val ? data[i].y : max_val);
max_val = (data[i].z > max_val ? data[i].z : max_val);
max_val = (data[i].w > max_val ? data[i].w : max_val);
}
for (int i = 1; i < WARP_SIZE; i *= 2) {
......@@ -395,22 +409,35 @@ __global__ void attn_softmax_v2(float* vals,
b.sync();
if ((data_id + reduceWidth) < sequence_length)
vals[data_id + reduceWidth] = data[i].y / sum;
b.sync();
if ((data_id + reduceWidth*2) < sequence_length)
vals[data_id + reduceWidth*2] = data[i].z / sum;
b.sync();
if ((data_id + reduceWidth*3) < sequence_length)
vals[data_id + reduceWidth*3] = data[i].w / sum;
b.sync();
b.sync();
if ((data_id + reduceWidth * 2) < sequence_length)
vals[data_id + reduceWidth * 2] = data[i].z / sum;
b.sync();
if ((data_id + reduceWidth * 3) < sequence_length)
vals[data_id + reduceWidth * 3] = data[i].w / sum;
b.sync();
}
}
}
}
#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \
attn_softmax_v2<iterations><<<grid, block, 0, stream>>> \
(vals,mask,alibi,layer_scale,triangular,recompute,local_attention,window_size,total_count,heads, \
sequence_length,num_seq,head_offset,mask_stride,mp_size,reduce_width);
#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \
attn_softmax_v2<iterations><<<grid, block, 0, stream>>>(vals, \
mask, \
alibi, \
layer_scale, \
triangular, \
recompute, \
local_attention, \
window_size, \
total_count, \
heads, \
sequence_length, \
num_seq, \
head_offset, \
mask_stride, \
mp_size, \
reduce_width);
template <typename T>
void launch_attn_softmax_v2(T* vals,
......@@ -457,30 +484,23 @@ void launch_attn_softmax_v2(T* vals,
dim3 grid((total_count + partitions - 1) / partitions);
dim3 block(attn_threads);
if (sequence_length <= 32768){
if (iterations == 1){
if (sequence_length <= 32768) {
if (iterations == 1) {
LAUNCH_ATTN_SOFTMAX_V2(1);
}
else if (iterations == 2){
} else if (iterations == 2) {
LAUNCH_ATTN_SOFTMAX_V2(2);
}
else if (iterations == 4){
} else if (iterations == 4) {
LAUNCH_ATTN_SOFTMAX_V2(4);
}
else if (iterations == 8){
} else if (iterations == 8) {
LAUNCH_ATTN_SOFTMAX_V2(8);
}
else if (iterations == 16){
} else if (iterations == 16) {
LAUNCH_ATTN_SOFTMAX_V2(16);
}
else if (iterations == 32){
} else if (iterations == 32) {
LAUNCH_ATTN_SOFTMAX_V2(32);
}
else if (iterations == 64){
} else if (iterations == 64) {
LAUNCH_ATTN_SOFTMAX_V2(64);
}
}
else
} else
throw std::runtime_error("Unsupport Seq_Length!");
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册