未验证 提交 e14d40e5 编写于 作者: A Arash Bakhtiari 提交者: GitHub

Refactor fused_bias_residual kernels for better readability (#2356)

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
上级 79692af1
......@@ -126,120 +126,127 @@ void launch_bias_add(T* input, const T* bias, int hidden_size, int batch_size, c
template void launch_bias_add<float>(float*, const float*, int, int, cudaStream_t);
template void launch_bias_add<__half>(__half*, const __half*, int, int, cudaStream_t);
// Fused bias + residual addition (fp32, vectorized as float4).
//
// All pointers are treated as arrays of float4, so `total_count` and
// `intermediate_size` are in units of float4 (i.e. elements / 4); the
// host launcher is responsible for dividing by 4. The bias tensors are
// broadcast across rows via `offset % intermediate_size`.
//
//   preln == true:
//     residual = (residual + attn + bias + attn_bias) * mp_scale + hidden_state
//   preln == false:
//     residual = residual + hidden_state + bias
//
// Expects a 1-D launch; each thread handles one float4. Threads past
// `total_count` exit via the bounds guard, so the grid may overshoot.
__global__ void fused_bias_residual(float* residual,
                                    const float* hidden_state,
                                    const float* attn,
                                    const float* bias,
                                    const float* attn_bias,
                                    const int total_count,
                                    const int intermediate_size,
                                    const float mp_scale,
                                    const bool preln)
{
    float4* res_fl4_ptr = reinterpret_cast<float4*>(residual);
    const float4* hs_fl4_ptr = reinterpret_cast<const float4*>(hidden_state);
    const float4* attn_fl4_ptr = reinterpret_cast<const float4*>(attn);
    const float4* bias_fl4_ptr = reinterpret_cast<const float4*>(bias);
    const float4* attn_bias_fl4_ptr = reinterpret_cast<const float4*>(attn_bias);
    const int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float4 res_fl4 = res_fl4_ptr[offset];
        const float4 hs_fl4 = hs_fl4_ptr[offset];
        const float4 attn_fl4 = attn_fl4_ptr[offset];
        // Biases repeat every `intermediate_size` float4s (per-row broadcast).
        const float4 bias_fl4 = bias_fl4_ptr[offset % intermediate_size];
        const float4 attn_bias_fl4 = attn_bias_fl4_ptr[offset % intermediate_size];
        if (preln) {
            // residual = (residual + attention + bias + attention_bias) *
            // mp_scale + hidden_state
            res_fl4.x =
                (res_fl4.x + attn_fl4.x + bias_fl4.x + attn_bias_fl4.x) * mp_scale + (hs_fl4.x);
            res_fl4.y =
                (res_fl4.y + attn_fl4.y + bias_fl4.y + attn_bias_fl4.y) * mp_scale + (hs_fl4.y);
            res_fl4.z =
                (res_fl4.z + attn_fl4.z + bias_fl4.z + attn_bias_fl4.z) * mp_scale + (hs_fl4.z);
            res_fl4.w =
                (res_fl4.w + attn_fl4.w + bias_fl4.w + attn_bias_fl4.w) * mp_scale + (hs_fl4.w);
        } else {
            // residual += hidden_state + bias
            res_fl4.x = res_fl4.x + hs_fl4.x + bias_fl4.x;
            res_fl4.y = res_fl4.y + hs_fl4.y + bias_fl4.y;
            res_fl4.z = res_fl4.z + hs_fl4.z + bias_fl4.z;
            res_fl4.w = res_fl4.w + hs_fl4.w + bias_fl4.w;
        }
        res_fl4_ptr[offset] = res_fl4;
    }
}
// Fused bias + residual addition (__half, vectorized as float2 == 4 halves).
//
// Each thread loads one float2 (four __half values), converts the two
// packed __half2 pairs to float2 for the arithmetic, then converts back
// with round-to-nearest. `total_count` / `intermediate_size` are in units
// of float2 (elements / 4); biases broadcast via `offset % intermediate_size`.
//
//   preln == true:
//     residual = (residual + attn + bias + attn_bias) * mp_scale + hidden_state
//   preln == false:
//     residual = residual + hidden_state + bias
//
// Compiled to a no-op unless HALF_PRECISION_AVAILABLE is defined.
__global__ void fused_bias_residual(__half* residual,
                                    const __half* hidden_state,
                                    const __half* attn,
                                    const __half* bias,
                                    const __half* attn_bias,
                                    const int total_count,
                                    const int intermediate_size,
                                    const float mp_scale,
                                    const bool preln)
{
#ifdef HALF_PRECISION_AVAILABLE
    float2* res_fl2_ptr = reinterpret_cast<float2*>(residual);
    const float2* hs_fl2_ptr = reinterpret_cast<const float2*>(hidden_state);
    const float2* attn_fl2_ptr = reinterpret_cast<const float2*>(attn);
    const float2* bias_fl2_ptr = reinterpret_cast<const float2*>(bias);
    const float2* attn_bias_fl2_ptr = reinterpret_cast<const float2*>(attn_bias);
    const int offset = blockIdx.x * blockDim.x + threadIdx.x;

    if (offset < total_count) {
        float2 res_fl2 = res_fl2_ptr[offset];
        const float2 hs_fl2 = hs_fl2_ptr[offset];
        const float2 attn_fl2 = attn_fl2_ptr[offset];
        // Biases repeat every `intermediate_size` float2s (per-row broadcast).
        const float2 bias_fl2 = bias_fl2_ptr[offset % intermediate_size];
        const float2 attn_bias_fl2 = attn_bias_fl2_ptr[offset % intermediate_size];

        // Reinterpret each float2 as two packed __half2 lanes.
        __half2* res_half2 = reinterpret_cast<__half2*>(&res_fl2);
        const __half2* hs_half2 = reinterpret_cast<const __half2*>(&hs_fl2);
        const __half2* attn_half2 = reinterpret_cast<const __half2*>(&attn_fl2);
        const __half2* bias_half2 = reinterpret_cast<const __half2*>(&bias_fl2);
        const __half2* attn_bias_half2 = reinterpret_cast<const __half2*>(&attn_bias_fl2);

        // Widen to fp32 for the accumulation to limit rounding error.
        float2 res_low = __half22float2(res_half2[0]);
        float2 res_high = __half22float2(res_half2[1]);

        const float2 hs_low = __half22float2(hs_half2[0]);
        const float2 hs_high = __half22float2(hs_half2[1]);

        const float2 attn_low = __half22float2(attn_half2[0]);
        const float2 attn_high = __half22float2(attn_half2[1]);

        const float2 bias_low = __half22float2(bias_half2[0]);
        const float2 bias_high = __half22float2(bias_half2[1]);

        const float2 attn_bias_low = __half22float2(attn_bias_half2[0]);
        const float2 attn_bias_high = __half22float2(attn_bias_half2[1]);

        if (preln) {
            // residual = (residual + attention + bias + attention_bias) *
            // mp_scale + hidden_state
            res_low.x =
                (res_low.x + attn_low.x + bias_low.x + attn_bias_low.x) * mp_scale + hs_low.x;
            res_low.y =
                (res_low.y + attn_low.y + bias_low.y + attn_bias_low.y) * mp_scale + hs_low.y;
            res_high.x =
                (res_high.x + attn_high.x + bias_high.x + attn_bias_high.x) * mp_scale + hs_high.x;
            res_high.y =
                (res_high.y + attn_high.y + bias_high.y + attn_bias_high.y) * mp_scale + hs_high.y;
        } else {
            // residual += hidden_state + bias
            res_low.x = (res_low.x + hs_low.x + bias_low.x);
            res_low.y = (res_low.y + hs_low.y + bias_low.y);
            res_high.x = (res_high.x + hs_high.x + bias_high.x);
            res_high.y = (res_high.y + hs_high.y + bias_high.y);
        }
        // Narrow back to __half2 (round-to-nearest) and store the float2.
        res_half2[0] = __float22half2_rn(res_low);
        res_half2[1] = __float22half2_rn(res_high);

        res_fl2_ptr[offset] = res_fl2;
    }
#endif
}
template <typename T>
void launch_bias_residual(T* input,
T* output,
void launch_bias_residual(T* residual,
T* hidden_state,
T* attn,
T* bias,
T* attn_bias,
......@@ -253,8 +260,15 @@ void launch_bias_residual(T* input,
dim3 block_dims(1024);
dim3 grid_dims((total_count - 1) / 1024 + 1); // (batch_size);
fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(
input, output, attn, bias, attn_bias, total_count, hidden_dim / 4, 1.0 / mp_size, preln);
fused_bias_residual<<<grid_dims, block_dims, 0, stream>>>(residual,
hidden_state,
attn,
bias,
attn_bias,
total_count,
hidden_dim / 4,
1.0 / mp_size,
preln);
}
template void launch_bias_residual<
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册