Unverified commit d0dfe38d, authored by Molly Smith, committed by GitHub

Update relu.cu with mem_access_utils (#2306)

Parent b2d550ab
@@ -3,7 +3,9 @@ Copyright 2022 The Microsoft DeepSpeed Team
 */
 #include "inference_cuda_layers.h"
+#include "memory_access_utils.h"
 namespace cg = cooperative_groups;
 #define MAX_CAP 4
 #define MAX_SEQ 2048
@@ -14,25 +16,21 @@ __global__ void fused_bias_relu(float* input,
                                 int total_count,
                                 int intermediate_size)
 {
-    float4* input_cast = reinterpret_cast<float4*>(input);
-    const float4* bias_cast = reinterpret_cast<const float4*>(bias);
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    // Input restriction: intermediate_size % vals_per_access == 0
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(float);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
     if (offset < total_count) {
-        float4 data = input_cast[offset];
-        float4 bias_data = bias_cast[offset % intermediate_size];
+        float data[vals_per_access];
+        float data_bias[vals_per_access];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));
-        data.x += bias_data.x;
-        data.y += bias_data.y;
-        data.z += bias_data.z;
-        data.w += bias_data.w;
+#pragma unroll
+        for (int i = 0; i < vals_per_access; i++) { data[i] = relu(data[i] + data_bias[i]); }
-        data.x = relu(data.x);
-        data.y = relu(data.y);
-        data.z = relu(data.z);
-        data.w = relu(data.w);
-        input_cast[offset] = data;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 }
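Note: the kernels in this file rely on mem_access::load_global / mem_access::store_global from the newly included memory_access_utils.h, which is not shown in this diff. As a rough mental model only (the namespace name, the void* signatures, and the 16-byte-only restriction below are assumptions, not the real DeepSpeed header), each call boils down to one vectorized 16-byte global-memory transaction per thread:

// Minimal sketch, assuming a 16-byte granularity only; the actual helpers in
// memory_access_utils.h are more general (other widths, async variants).
#include <cuda_runtime.h>

namespace mem_access_sketch {

// Copy 16 bytes from global memory into thread-local storage in one transaction.
template <int granularity>
__device__ __forceinline__ void load_global(void* dst, const void* src)
{
    static_assert(granularity == 16, "sketch covers 16-byte accesses only");
    *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
}

// Copy 16 bytes from thread-local storage back to global memory in one transaction.
template <int granularity>
__device__ __forceinline__ void store_global(void* dst, const void* src)
{
    static_assert(granularity == 16, "sketch covers 16-byte accesses only");
    *reinterpret_cast<float4*>(dst) = *reinterpret_cast<const float4*>(src);
}

}  // namespace mem_access_sketch

At 16 bytes per access this gives 4 floats or 8 __half values per thread, which is exactly the vals_per_access arithmetic the kernels compute.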
@@ -41,40 +39,28 @@ __global__ void fused_bias_relu(__half* input,
                                 int total_count,
                                 int intermediate_size)
 {
+    // Input restriction: intermediate_size % vals_per_access == 0
+    // This kernel doubles the per-thread ALU workload as compared to the float implementation
 #ifdef HALF_PRECISION_AVAILABLE
-    float2* input_cast = reinterpret_cast<float2*>(input);
-    const float2* bias_cast = reinterpret_cast<const float2*>(bias);
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    constexpr int granularity = 16;
+    constexpr int vals_per_access = granularity / sizeof(__half);
+    int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
     if (offset < total_count) {
-        float2 vals_vec = input_cast[offset];
-        float2 bias_vec = bias_cast[offset % intermediate_size];
-        __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec);
-        __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec);
-        float2 low_data = __half22float2(vals_half[0]);
-        float2 high_data = __half22float2(vals_half[1]);
-        float2 low_bias = __half22float2(bias_half[0]);
-        float2 high_bias = __half22float2(bias_half[1]);
-        low_data.x += low_bias.x;
-        low_data.y += low_bias.y;
-        high_data.x += high_bias.x;
-        high_data.y += high_bias.y;
-        low_data.x = relu(low_data.x);
-        low_data.y = relu(low_data.y);
-        high_data.x = relu(high_data.x);
-        high_data.y = relu(high_data.y);
-        vals_half[0] = __float22half2_rn(low_data);
-        vals_half[1] = __float22half2_rn(high_data);
+        // Divide by 2 since we store two values per __half2
+        __half2 data[vals_per_access / 2];
+        __half2 bias_data[vals_per_access / 2];
+        mem_access::load_global<granularity>(data, input + offset);
+        mem_access::load_global<granularity>(bias_data, bias + (offset % intermediate_size));
+#pragma unroll
+        for (int i = 0; i < vals_per_access / 2; i++) {
+            float2 data_f = __half22float2(data[i]);
+            float2 bias_f = __half22float2(bias_data[i]);
+            data[i] = __floats2half2_rn(relu(data_f.x + bias_f.x), relu(data_f.y + bias_f.y));
+        }
-        input_cast[offset] = vals_vec;
+        mem_access::store_global<granularity>(input + offset, data);
     }
 #endif
 }
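Note: relu is a scalar device helper defined elsewhere in the DeepSpeed inference sources and is not part of this diff; a minimal sketch of the behavior both kernels assume:

// Hypothetical scalar helper mirroring what the kernels above expect;
// the actual definition is not shown in this diff.
__device__ __forceinline__ float relu(float x) { return x < 0.0f ? 0.0f : x; }

In the __half path, vals_per_access is 16 / sizeof(__half) = 8, so each thread processes four __half2 packs; every pack is widened with __half22float2, bias-added and clamped in fp32, then narrowed back with __floats2half2_rn before the single 16-byte store.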
@@ -86,13 +72,16 @@ void launch_bias_relu(T* input,
                       int batch_size,
                       cudaStream_t stream)
 {
-    int total_count = batch_size * (intermediate_size / 4);
-    int threads = 1024;  // intermediate_size / iterations / 4;
+    constexpr int threads = 1024;
+    constexpr int granularity = 16;
+    const int total_count = batch_size * intermediate_size;
+    const int elems_per_block = threads * (granularity / sizeof(T));
     dim3 block_dims(threads);
-    dim3 grid_dims(((total_count - 1) / 1024 + 1));  // (batch_size);
+    dim3 grid_dims((total_count + elems_per_block - 1) / elems_per_block);
     fused_bias_relu<<<grid_dims, block_dims, 0, stream>>>(
-        input, bias, total_count, intermediate_size / 4);
+        input, bias, total_count, intermediate_size);
 }
 template void launch_bias_relu<float>(float*, const float*, int, int, cudaStream_t);
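For a concrete sense of the new launch math (the sizes below are illustrative, not taken from the commit), here is a host-side sketch of the grid computation for both element widths:

// Worked example of the launcher's grid math under assumed sizes
// (128 x 4096 is hypothetical, not from the commit).
#include <cstdio>
#include <initializer_list>

int main()
{
    const int batch_size = 128;          // hypothetical token count
    const int intermediate_size = 4096;  // hypothetical FFN width
    const int threads = 1024;
    const int granularity = 16;  // bytes accessed per thread per load/store

    const int total_count = batch_size * intermediate_size;  // 524,288 elements

    // sizeof(T) == 4 -> float path, sizeof(T) == 2 -> __half path
    for (int elem_size : {4, 2}) {
        const int elems_per_block = threads * (granularity / elem_size);
        const int blocks = (total_count + elems_per_block - 1) / elems_per_block;
        std::printf("sizeof(T)=%d: %d elems/block -> %d blocks\n",
                    elem_size, elems_per_block, blocks);
    }
    return 0;
}

Because offset is scaled by vals_per_access and intermediate_size is required to be a multiple of vals_per_access, the bounds check if (offset < total_count) still guards a whole 16-byte chunk, and the ceil-division above launches the final partial block when total_count does not divide evenly.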