diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index 5eb30478e80b402222a5491c819be9a01a4ab811..928f64c290e95b0806899a55cd339ac91615b5a5 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -95,54 +95,44 @@ template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStr
 // Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc.
 __global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
 {
-    float4* input_cast = reinterpret_cast<float4*>(input);
-    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
 
     if (offset < total_count) {
-        float4 data = input_cast[offset];
-        float4 bias_data = bias_cast[offset % hidden_size];
+        float data[vals_per_access];
+        float bias_data[vals_per_access];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));
 
-        data.x += bias_data.x;
-        data.y += bias_data.y;
-        data.z += bias_data.z;
-        data.w += bias_data.w;
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) { data[i] += bias_data[i]; }
 
-        input_cast[offset] = data;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 }
 
 __global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
 {
 #ifdef HALF_PRECISION_AVAILABLE
-
-    float2* input_cast = reinterpret_cast<float2*>(input);
-    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
-
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
 
     if (offset < total_count) {
-        float2 vals_vec = input_cast[offset];
-        float2 bias_vec = bias_cast[offset % hidden_size];
-
-        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
-        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
-
-        float2 low_data = __half22float2(vals_half[0]);
-        float2 high_data = __half22float2(vals_half[1]);
-
-        float2 low_bias = __half22float2(bias_half[0]);
-        float2 high_bias = __half22float2(bias_half[1]);
-
-        low_data.x += low_bias.x;
-        low_data.y += low_bias.y;
-        high_data.x += high_bias.x;
-        high_data.y += high_bias.y;
+        __half2 data[vals_per_access / 2];
+        __half2 bias_data[vals_per_access / 2];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));
 
-        vals_half[0] = __float22half2_rn(low_data);
-        vals_half[1] = __float22half2_rn(high_data);
+#pragma unroll
+        for (int i = 0; i < vals_per_access / 2; i++) {
+            float2 data_f = __half22float2(data[i]);
+            float2 bias_f = __half22float2(bias_data[i]);
+            data[i] = __floats2half2_rn(data_f.x + bias_f.x, data_f.y + bias_f.y);
+        }
 
-        input_cast[offset] = vals_vec;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 #endif
 }
@@ -150,12 +140,15 @@ __global__ void fused_bias_add(__half* input, const __half* bias, int total_coun
 template <typename T>
 void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
 {
-    int total_count = batch_size * (hidden_size / 4);
-    int threads = 1024;  // hidden_size / iterations / 4;
+    constexpr int threads = 1024;
+    constexpr int granularity = 16;
+
+    const int total_count = batch_size * hidden_size;
+    const int elems_per_block = threads * (granularity / sizeof(T));
     dim3 block_dims(threads);
-    dim3 grid_dims(((total_count - 1) / threads + 1));  // (batch_size);
+    dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
 
-    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size / 4);
+    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size);
 }
 
 template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
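
Reviewer note (not part of the patch): the rewritten kernels rely on the `mem_access::load_global` / `mem_access::store_global` helpers from DeepSpeed's memory-access utilities, templated on the access width in bytes. As a rough mental model for the float path, a 16-byte helper behaves like the sketch below; the `sketch` namespace, the `float*` signatures, and the alignment comment are illustrative assumptions, not the actual `memory_access_utils.h` implementation.

```cuda
// Illustrative sketch only (assumed semantics, float path): one 16-byte
// vectorized global-memory transaction per thread instead of four scalar
// accesses. Names and signatures here are assumptions for illustration.
#include <cuda_runtime.h>

namespace sketch {

template <int granularity>
__device__ __forceinline__ void load_global(float* dst, const float* src)
{
    static_assert(granularity == 16, "sketch covers only 16-byte accesses");
    // src must be 16-byte aligned; the kernel arranges this by computing
    // offsets in multiples of vals_per_access (= 4 floats).
    const float4 v = *reinterpret_cast<const float4*>(src);
    dst[0] = v.x;
    dst[1] = v.y;
    dst[2] = v.z;
    dst[3] = v.w;
}

template <int granularity>
__device__ __forceinline__ void store_global(float* dst, const float* src)
{
    static_assert(granularity == 16, "sketch covers only 16-byte accesses");
    float4 v;
    v.x = src[0];
    v.y = src[1];
    v.z = src[2];
    v.w = src[3];
    // Single 16-byte store back to global memory.
    *reinterpret_cast<float4*>(dst) = v;
}

}  // namespace sketch
```

Under this model, each thread covers `granularity / sizeof(T)` elements per access, which is why `launch_bias_add` now expresses `total_count` in elements (rather than in float4/float2 vectors) and sizes `grid_dims` by `elems_per_block`.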