Unverified commit c199edac, authored by Michael Wyatt and committed by GitHub

refactor to use mem_access (#2317)

Parent 060078ab
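The commit replaces the hand-rolled float4/float2 vector loads in these kernels with DeepSpeed's mem_access helpers, which move a fixed byte granularity per thread through a small thread-local buffer. A minimal sketch of that access pattern, using only the load_global/store_global<granularity> calls visible in the diff (the kernel name, the header path, and the +1 operation are illustrative, not part of the PR):

// Sketch only: the mem_access pattern this commit adopts, not code from the PR.
#include "memory_access_utils.h"  // assumed location of the mem_access helpers

__global__ void add_one_kernel(float* data, int total_count)
{
    constexpr int granularity = 16;                                // bytes moved per thread per access
    constexpr int vals_per_access = granularity / sizeof(float);   // 4 floats per access
    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;

    if (offset < total_count) {
        float buf[vals_per_access];  // thread-local staging buffer (registers)
        mem_access::load_global<granularity>(buf, data + offset);
#pragma unroll
        for (int i = 0; i < vals_per_access; i++) { buf[i] += 1.0f; }
        mem_access::store_global<granularity>(data + offset, buf);
    }
}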
@@ -95,54 +95,44 @@ template void launch_bias_gelu<__half>(__half*, const __half*, int, int, cudaStream_t);
 // Not called directly from DeepSpeed, but used in ds_qkv_gemm_int8, ds_linear_layer, etc.
 __global__ void fused_bias_add(float* input, const float* bias, int total_count, int hidden_size)
 {
-    float4* input_cast = reinterpret_cast<float4*>(input);
-    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
     if (offset < total_count) {
-        float4 data = input_cast[offset];
-        float4 bias_data = bias_cast[offset % hidden_size];
+        float data[vals_per_access];
+        float bias_data[vals_per_access];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));
-        data.x += bias_data.x;
-        data.y += bias_data.y;
-        data.z += bias_data.z;
-        data.w += bias_data.w;
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) { data[i] += bias_data[i]; }
-        input_cast[offset] = data;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 }
 __global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
 {
 #ifdef HALF_PRECISION_AVAILABLE
-    float2* input_cast = reinterpret_cast<float2*>(input);
-    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
     if (offset < total_count) {
-        float2 vals_vec = input_cast[offset];
-        float2 bias_vec = bias_cast[offset % hidden_size];
-        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
-        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
-        float2 low_data = __half22float2(vals_half[0]);
-        float2 high_data = __half22float2(vals_half[1]);
-        float2 low_bias = __half22float2(bias_half[0]);
-        float2 high_bias = __half22float2(bias_half[1]);
-        low_data.x += low_bias.x;
-        low_data.y += low_bias.y;
-        high_data.x += high_bias.x;
-        high_data.y += high_bias.y;
+        __half2 data[vals_per_access / 2];
+        __half2 bias_data[vals_per_access / 2];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(bias_data, bias + (offset % hidden_size));
-        vals_half[0] = __float22half2_rn(low_data);
-        vals_half[1] = __float22half2_rn(high_data);
+#pragma unroll
+        for (int i = 0; i < vals_per_access / 2; i++) {
+            float2 data_f = __half22float2(data[i]);
+            float2 bias_f = __half22float2(bias_data[i]);
+            data[i] = __floats2half2_rn(data_f.x + bias_f.x, data_f.y + bias_f.y);
+        }
-        input_cast[offset] = vals_vec;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 #endif
 }
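In the __half kernel the same 16-byte granularity covers 8 half-precision values per thread, handled as 4 __half2 pairs: each pair is widened to float2, the bias is added in fp32, and the result is rounded back with __floats2half2_rn. The constants, spelled out as a quick check (this fragment assumes cuda_fp16.h is included; pairs_per_access is just a name for this note):

// Quick arithmetic check for the half-precision path; values follow from the diff above.
constexpr int granularity = 16;                                // bytes per access
constexpr int vals_per_access = granularity / sizeof(__half);  // 16 / 2 = 8 __half values per thread
constexpr int pairs_per_access = vals_per_access / 2;          // 4 __half2 pairs per thread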
@@ -150,12 +140,15 @@ __global__ void fused_bias_add(__half* input, const __half* bias, int total_count, int hidden_size)
 template <typename T>
 void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, cudaStream_t stream)
 {
-    int total_count = batch_size * (hidden_size / 4);
-    int threads = 1024;  // hidden_size / iterations / 4;
+    constexpr int threads = 1024;
+    constexpr int granularity = 16;
+    const int total_count = batch_size * hidden_size;
+    const int elems_per_block = threads * (granularity / sizeof(T));
     dim3 block_dims(threads);
-    dim3 grid_dims(((total_count - 1) / threads + 1));  // (batch_size);
+    dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
-    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size / 4);
+    fused_bias_add<<<grid_dims, block_dims, 0, stream>>>(input, bias, total_count, hidden_size);
 }
 template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
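After the refactor the launcher works in raw element counts (batch_size * hidden_size) rather than float4 vectors, and sizes the grid by how many elements one 1024-thread block covers at 16-byte granularity. A hedged host-side usage sketch with made-up sizes (the buffer names and dimensions are illustrative, not from the PR):

// Sketch only: hypothetical call site for launch_bias_add<float>.
const int batch_size = 8;
const int hidden_size = 4096;

float *input_d, *bias_d;
cudaMalloc(&input_d, batch_size * hidden_size * sizeof(float));
cudaMalloc(&bias_d, hidden_size * sizeof(float));
// ... fill input_d and bias_d ...

launch_bias_add<float>(input_d, bias_d, hidden_size, batch_size, /*stream=*/0);
// total_count = 8 * 4096 = 32768 elements; each block covers 1024 * (16 / 4) = 4096 elements,
// so grid_dims = ceil(32768 / 4096) = 8 blocks.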