From 150eb7d96b6084190265b440739317216992bd82 Mon Sep 17 00:00:00 2001 From: Molly Smith Date: Thu, 30 Mar 2023 04:02:05 +0500 Subject: [PATCH] naming. precommit --- csrc/transformer/general_kernels.cu | 4 +- csrc/transformer/inference/csrc/softmax.cu | 326 +++++++++++---------- 2 files changed, 175 insertions(+), 155 deletions(-) diff --git a/csrc/transformer/general_kernels.cu b/csrc/transformer/general_kernels.cu index 2d9eeb82..ea549100 100644 --- a/csrc/transformer/general_kernels.cu +++ b/csrc/transformer/general_kernels.cu @@ -161,7 +161,7 @@ void launch_fused_add2(float* out, int total_count = batch_size * seq_length * hidden_dim / 4; dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); fused_add2_kernel<<>>(total_count, out, inp1, inp2); } @@ -178,7 +178,7 @@ void launch_fused_add2<__half>(__half* out, int total_count = batch_size * seq_length * hidden_dim / 4; dim3 grid_dim = DS_GET_BLOCKS(total_count); //(batch_size * seq_length); - dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); + dim3 block_dim = DS_CUDA_NUM_THREADS; //(hidden_dim / 4); fused_add2_kernel<<>>(total_count, out, inp1, inp2); } diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu index 5599eebf..47a5df9d 100644 --- a/csrc/transformer/inference/csrc/softmax.cu +++ b/csrc/transformer/inference/csrc/softmax.cu @@ -86,92 +86,110 @@ __global__ void attn_softmax_v2(__half* vals, // if (lane == 0) printf("%d, %d: %d \n", wid, blockIdx.x, mask_offset); for (int i = 0; i < iterations; i++) { int data_id = i * (reduceWidth << 2) + (seq_lane); - bool check1 = ((!triangular || (data_id <= seq_id)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length); - bool low_x_check = check1 && (data_id > window_stride); - bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && ((!triangular || ((data_id + reduceWidth) <= seq_id)) && (data_id + reduceWidth) > window_stride); - bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride); - bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride); - - if (mask && alibi){ - low_data[i].x = low_x_check - ? __half2float(vals[data_id]) * layer_scale + (__half2float(alibi[data_id + alibi_offset])) + (__half2float(mask[data_id + mask_offset])) - : minus_infinity; - b.sync(); - low_data[i].y = low_y_check - ? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth])) + (__half2float(mask[data_id + mask_offset + reduceWidth])) - : minus_infinity; - b.sync(); - high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*2])) + (__half2float(mask[data_id + mask_offset + reduceWidth*2])) - : minus_infinity; - b.sync(); - high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*3])) + (__half2float(mask[data_id + mask_offset + reduceWidth*3])) - : minus_infinity; - b.sync(); - } - else if (mask){ - low_data[i].x = low_x_check - ? __half2float(vals[data_id]) * layer_scale + (__half2float(mask[data_id + mask_offset])) - : minus_infinity; - b.sync(); - low_data[i].y = low_y_check - ? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth])) - : minus_infinity; - b.sync(); - high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth*2])) - : minus_infinity; - b.sync(); - high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(mask[data_id + mask_offset + reduceWidth*3])) - : minus_infinity; - b.sync(); - } - else if (alibi){ - low_data[i].x = low_x_check - ? __half2float(vals[data_id]) * layer_scale + (__half2float(alibi[data_id + alibi_offset])) + bool check1 = ((!triangular || (data_id <= seq_id)) && + (data_id >> 2) >= window_stride4 && data_id < sequence_length); + bool low_x_check = check1 && (data_id > window_stride); + bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && + ((!triangular || ((data_id + reduceWidth) <= seq_id)) && + (data_id + reduceWidth) > window_stride); + bool high_x_check = check1 && ((data_id + reduceWidth * 2) < sequence_length) && + ((!triangular || ((data_id + reduceWidth * 2) <= seq_id)) && + (data_id + reduceWidth * 2) > window_stride); + bool high_y_check = check1 && ((data_id + reduceWidth * 3) < sequence_length) && + ((!triangular || ((data_id + reduceWidth * 3) <= seq_id)) && + (data_id + reduceWidth * 3) > window_stride); + + if (mask && alibi) { + low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + + (__half2float(alibi[data_id + alibi_offset])) + + (__half2float(mask[data_id + mask_offset])) + : minus_infinity; + b.sync(); + low_data[i].y = + low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale + + (__half2float(alibi[data_id + alibi_offset + reduceWidth])) + + (__half2float(mask[data_id + mask_offset + reduceWidth])) : minus_infinity; - b.sync(); - low_data[i].y = low_y_check - ? __half2float(vals[data_id + reduceWidth]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth])) - : minus_infinity; - b.sync(); - high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth*2]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*2])) - : minus_infinity; - b.sync(); - high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth*3]) * layer_scale + (__half2float(alibi[data_id + alibi_offset + reduceWidth*3])) - : minus_infinity; - b.sync(); - } - else { - low_data[i].x = low_x_check - ? __half2float(vals[data_id]) * layer_scale + b.sync(); + high_data[i].x = + high_x_check + ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + + (__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) + + (__half2float(mask[data_id + mask_offset + reduceWidth * 2])) + : minus_infinity; + b.sync(); + high_data[i].y = + high_y_check + ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + + (__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) + + (__half2float(mask[data_id + mask_offset + reduceWidth * 3])) + : minus_infinity; + b.sync(); + } else if (mask) { + low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + + (__half2float(mask[data_id + mask_offset])) + : minus_infinity; + b.sync(); + low_data[i].y = low_y_check + ? __half2float(vals[data_id + reduceWidth]) * layer_scale + + (__half2float(mask[data_id + mask_offset + reduceWidth])) + : minus_infinity; + b.sync(); + high_data[i].x = + high_x_check ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + + (__half2float(mask[data_id + mask_offset + reduceWidth * 2])) + : minus_infinity; + b.sync(); + high_data[i].y = + high_y_check ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + + (__half2float(mask[data_id + mask_offset + reduceWidth * 3])) + : minus_infinity; + b.sync(); + } else if (alibi) { + low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + + (__half2float(alibi[data_id + alibi_offset])) + : minus_infinity; + b.sync(); + low_data[i].y = + low_y_check ? __half2float(vals[data_id + reduceWidth]) * layer_scale + + (__half2float(alibi[data_id + alibi_offset + reduceWidth])) : minus_infinity; - b.sync(); - low_data[i].y = low_y_check - ? __half2float(vals[data_id + reduceWidth]) * layer_scale - : minus_infinity; - b.sync(); - high_data[i].x = high_x_check - ? __half2float(vals[data_id + reduceWidth*2]) * layer_scale - : minus_infinity; - b.sync(); - high_data[i].y = high_y_check - ? __half2float(vals[data_id + reduceWidth*3]) * layer_scale - : minus_infinity; - b.sync(); + b.sync(); + high_data[i].x = + high_x_check + ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + + (__half2float(alibi[data_id + alibi_offset + reduceWidth * 2])) + : minus_infinity; + b.sync(); + high_data[i].y = + high_y_check + ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + + (__half2float(alibi[data_id + alibi_offset + reduceWidth * 3])) + : minus_infinity; + b.sync(); + } else { + low_data[i].x = low_x_check ? __half2float(vals[data_id]) * layer_scale + : minus_infinity; + b.sync(); + low_data[i].y = low_y_check + ? __half2float(vals[data_id + reduceWidth]) * layer_scale + : minus_infinity; + b.sync(); + high_data[i].x = high_x_check + ? __half2float(vals[data_id + reduceWidth * 2]) * layer_scale + : minus_infinity; + b.sync(); + high_data[i].y = high_y_check + ? __half2float(vals[data_id + reduceWidth * 3]) * layer_scale + : minus_infinity; + b.sync(); } - // if(lane == 0) printf("%f , %d, %d \n", low_data[i].x, data_id, seq_id); max_val = (low_data[i].x > max_val ? low_data[i].x : max_val); max_val = (low_data[i].y > max_val ? low_data[i].y : max_val); max_val = (high_data[i].x > max_val ? high_data[i].x : max_val); max_val = (high_data[i].y > max_val ? high_data[i].y : max_val); - } for (int i = 1; i < WARP_SIZE; i *= 2) { @@ -227,13 +245,13 @@ __global__ void attn_softmax_v2(__half* vals, b.sync(); if ((data_id + reduceWidth) < sequence_length) vals[data_id + reduceWidth] = __float2half(low_data[i].y / sum); - b.sync(); - if ((data_id + reduceWidth*2) < sequence_length) - vals[data_id + reduceWidth*2] = __float2half(high_data[i].x / sum); - b.sync(); - if ((data_id + reduceWidth*3) < sequence_length) - vals[data_id + reduceWidth*3] = __float2half(high_data[i].y / sum); - b.sync(); + b.sync(); + if ((data_id + reduceWidth * 2) < sequence_length) + vals[data_id + reduceWidth * 2] = __float2half(high_data[i].x / sum); + b.sync(); + if ((data_id + reduceWidth * 3) < sequence_length) + vals[data_id + reduceWidth * 3] = __float2half(high_data[i].y / sum); + b.sync(); } } } @@ -291,54 +309,50 @@ __global__ void attn_softmax_v2(float* vals, for (int i = 0; i < iterations; i++) { int data_id = i * (reduceWidth << 2) + (seq_lane); - bool check1 = ((!triangular || (data_id <= seq_id)) && (data_id >> 2) >= window_stride4 && data_id < sequence_length); - bool low_x_check = check1 && (data_id > window_stride); - bool low_y_check = check1 && ((data_id + reduceWidth) < sequence_length) && ((!triangular || ((data_id + reduceWidth) <= seq_id)) && (data_id + reduceWidth) > window_stride); - bool high_x_check = check1 && ((data_id + reduceWidth*2) < sequence_length) && ((!triangular || ((data_id + reduceWidth*2) <= seq_id)) && (data_id + reduceWidth*2) > window_stride); - bool high_y_check = check1 && ((data_id + reduceWidth*3) < sequence_length) && ((!triangular || ((data_id + reduceWidth*3) <= seq_id)) && (data_id + reduceWidth*3) > window_stride); - - if (attn_mask){ - data[i].x = low_x_check - ? vals[data_id] + attn_mask[data_id + mask_offset] - : minus_infinity; - b.sync(); - data[i].y = low_y_check - ? vals[data_id + reduceWidth] + attn_mask[data_id + mask_offset + reduceWidth] - : minus_infinity; - b.sync(); - data[i].z = high_x_check - ? vals[data_id + reduceWidth*2] + attn_mask[data_id + mask_offset + reduceWidth*2] - : minus_infinity; - b.sync(); - data[i].w = high_y_check - ? vals[data_id + reduceWidth*3] + attn_mask[data_id + mask_offset + reduceWidth*3] - : minus_infinity; - b.sync(); - } - else { - data[i].x = low_x_check - ? vals[data_id] - : minus_infinity; - b.sync(); - data[i].y = low_y_check - ? vals[data_id + reduceWidth] - : minus_infinity; - b.sync(); - data[i].z = high_x_check - ? vals[data_id + reduceWidth*2] - : minus_infinity; - b.sync(); - data[i].w = high_y_check - ? vals[data_id + reduceWidth*3] - : minus_infinity; - b.sync(); + bool check1 = ((!triangular || (data_id <= seq_id)) && + (data_id >> 2) >= window_stride4 && data_id < sequence_length); + bool x_check = check1 && (data_id > window_stride); + bool y_check = check1 && ((data_id + reduceWidth) < sequence_length) && + ((!triangular || ((data_id + reduceWidth) <= seq_id)) && + (data_id + reduceWidth) > window_stride); + bool z_check = check1 && ((data_id + reduceWidth * 2) < sequence_length) && + ((!triangular || ((data_id + reduceWidth * 2) <= seq_id)) && + (data_id + reduceWidth * 2) > window_stride); + bool w_check = check1 && ((data_id + reduceWidth * 3) < sequence_length) && + ((!triangular || ((data_id + reduceWidth * 3) <= seq_id)) && + (data_id + reduceWidth * 3) > window_stride); + + if (attn_mask) { + data[i].x = x_check ? vals[data_id] + attn_mask[data_id + mask_offset] + : minus_infinity; + b.sync(); + data[i].y = y_check ? vals[data_id + reduceWidth] + + attn_mask[data_id + mask_offset + reduceWidth] + : minus_infinity; + b.sync(); + data[i].z = z_check ? vals[data_id + reduceWidth * 2] + + attn_mask[data_id + mask_offset + reduceWidth * 2] + : minus_infinity; + b.sync(); + data[i].w = w_check ? vals[data_id + reduceWidth * 3] + + attn_mask[data_id + mask_offset + reduceWidth * 3] + : minus_infinity; + b.sync(); + } else { + data[i].x = x_check ? vals[data_id] : minus_infinity; + b.sync(); + data[i].y = y_check ? vals[data_id + reduceWidth] : minus_infinity; + b.sync(); + data[i].z = z_check ? vals[data_id + reduceWidth * 2] : minus_infinity; + b.sync(); + data[i].w = w_check ? vals[data_id + reduceWidth * 3] : minus_infinity; + b.sync(); } max_val = (data[i].x > max_val ? data[i].x : max_val); max_val = (data[i].y > max_val ? data[i].y : max_val); max_val = (data[i].z > max_val ? data[i].z : max_val); max_val = (data[i].w > max_val ? data[i].w : max_val); - } for (int i = 1; i < WARP_SIZE; i *= 2) { @@ -395,22 +409,35 @@ __global__ void attn_softmax_v2(float* vals, b.sync(); if ((data_id + reduceWidth) < sequence_length) vals[data_id + reduceWidth] = data[i].y / sum; - b.sync(); - if ((data_id + reduceWidth*2) < sequence_length) - vals[data_id + reduceWidth*2] = data[i].z / sum; - b.sync(); - if ((data_id + reduceWidth*3) < sequence_length) - vals[data_id + reduceWidth*3] = data[i].w / sum; - b.sync(); + b.sync(); + if ((data_id + reduceWidth * 2) < sequence_length) + vals[data_id + reduceWidth * 2] = data[i].z / sum; + b.sync(); + if ((data_id + reduceWidth * 3) < sequence_length) + vals[data_id + reduceWidth * 3] = data[i].w / sum; + b.sync(); } } } } -#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ - attn_softmax_v2<<>> \ - (vals,mask,alibi,layer_scale,triangular,recompute,local_attention,window_size,total_count,heads, \ - sequence_length,num_seq,head_offset,mask_stride,mp_size,reduce_width); +#define LAUNCH_ATTN_SOFTMAX_V2(iterations) \ + attn_softmax_v2<<>>(vals, \ + mask, \ + alibi, \ + layer_scale, \ + triangular, \ + recompute, \ + local_attention, \ + window_size, \ + total_count, \ + heads, \ + sequence_length, \ + num_seq, \ + head_offset, \ + mask_stride, \ + mp_size, \ + reduce_width); template void launch_attn_softmax_v2(T* vals, @@ -457,30 +484,23 @@ void launch_attn_softmax_v2(T* vals, dim3 grid((total_count + partitions - 1) / partitions); dim3 block(attn_threads); - if (sequence_length <= 32768){ - if (iterations == 1){ + if (sequence_length <= 32768) { + if (iterations == 1) { LAUNCH_ATTN_SOFTMAX_V2(1); - } - else if (iterations == 2){ + } else if (iterations == 2) { LAUNCH_ATTN_SOFTMAX_V2(2); - } - else if (iterations == 4){ + } else if (iterations == 4) { LAUNCH_ATTN_SOFTMAX_V2(4); - } - else if (iterations == 8){ + } else if (iterations == 8) { LAUNCH_ATTN_SOFTMAX_V2(8); - } - else if (iterations == 16){ + } else if (iterations == 16) { LAUNCH_ATTN_SOFTMAX_V2(16); - } - else if (iterations == 32){ + } else if (iterations == 32) { LAUNCH_ATTN_SOFTMAX_V2(32); - } - else if (iterations == 64){ + } else if (iterations == 64) { LAUNCH_ATTN_SOFTMAX_V2(64); } - } - else + } else throw std::runtime_error("Unsupport Seq_Length!"); } -- GitLab