// general_kernels.cu (14.6 KB)
// Committed by Jeff Rasley.
#include "general_kernels.h"

namespace cg = cooperative_groups;

// Column-wise sum of a row-major [rows x width] matrix:
//     out[c] = sum_r inp[r * width + c]
//
// Launch layout (as used by launch_fuse_transpose_bias_kernel):
//   grid  = one block per TILE_DIM consecutive columns,
//   block = (TILE_DIM, TILE_DIM) threads.
// Thread (x, y) accumulates, in fp32, a partial sum of column
// blockIdx.x * TILE_DIM + x over rows y, y + TILE_DIM, y + 2*TILE_DIM, ...
// The partials are transposed through the padded shared tile so that each
// TILE_DIM-thread group with a fixed threadIdx.y holds all partials of one
// column; the group then reduces them with shuffles and lane 0 writes the result.
//
// Columns at or beyond `width` (possible only if the grid over-covers a width
// that is not a multiple of TILE_DIM) contribute zero and are never written.
// Every thread still reaches the barrier and the shuffles below.
template <typename T>
__global__ void column_sum_reduce(const T* __restrict__ inp,
                                  T* __restrict__ out,
                                  int rows,
                                  int width)
{
    __shared__ float tile[TILE_DIM][TILE_DIM + 1];

    cg::thread_block b = cg::this_thread_block();
    cg::thread_block_tile<TILE_DIM> g = cg::tiled_partition<TILE_DIM>(b);

    int idx = blockDim.x * blockIdx.x + threadIdx.x;

    float localSum = 0.f;

    // Loop across matrix height. 64-bit offsets: rows * width may exceed INT_MAX.
    if (idx < width) {
        size_t offset = static_cast<size_t>(threadIdx.y) * width + idx;
        const size_t y_stride = static_cast<size_t>(width) * TILE_DIM;
        for (int r = threadIdx.y; r < rows; r += TILE_DIM) {
            localSum += (float)inp[offset];
            offset += y_stride;
        }
    }

    // Transposed store: afterwards row threadIdx.y of the tile holds the
    // TILE_DIM partial sums of column blockIdx.x * TILE_DIM + threadIdx.y.
    tile[threadIdx.x][threadIdx.y] = localSum;

    __syncthreads();

    // Sum the shared buffer.
    float sum = tile[threadIdx.y][threadIdx.x];

#ifndef __STOCHASTIC_MODE__
    __syncthreads();
#endif

    // Shuffle reduction within the TILE_DIM-wide group; lane 0 ends with the total.
    for (int i = 1; i < TILE_DIM; i <<= 1) sum += g.shfl_down(sum, i);

    if (threadIdx.x == 0) {
        int pos = blockIdx.x * TILE_DIM + threadIdx.y;
        if (pos < width) out[pos] = sum;
    }
}

// Primary template of the column-sum launcher. Only the float and __half
// explicit specializations are defined in this file; other T have no definition here.
template <typename T>
void launch_fuse_transpose_bias_kernel(const T* inp,
                                       T* out,
                                       int rows,
                                       int cols,
                                       cudaStream_t stream);

// fp32 launcher: enqueues column_sum_reduce on `stream` so that out[c] receives
// the sum of column c of the row-major [rows x cols] matrix `inp`.
// One (TILE_DIM x TILE_DIM) block covers TILE_DIM columns. The tile-multiple
// preconditions are checked with assert, i.e. in debug builds only.
template <>
void launch_fuse_transpose_bias_kernel<float>(const float* inp,
                                              float* out,
                                              int rows,
                                              int cols,
                                              cudaStream_t stream)
{
    assert(rows % TILE_DIM == 0);
    assert(cols % TILE_DIM == 0);

    const dim3 threads_per_block(TILE_DIM, TILE_DIM);
    const dim3 num_blocks(cols / TILE_DIM);

    column_sum_reduce<float><<<num_blocks, threads_per_block, 0, stream>>>(inp, out, rows, cols);
}

// fp16 launcher: enqueues column_sum_reduce on `stream` so that out[c] receives
// the sum of column c of the row-major [rows x cols] matrix `inp`. The kernel
// accumulates in fp32 and converts back to __half on store.
// The tile-multiple preconditions are checked with assert (debug builds only).
template <>
void launch_fuse_transpose_bias_kernel<__half>(const __half* inp,
                                               __half* out,
                                               int rows,
                                               int cols,
                                               cudaStream_t stream)
{
    assert(rows % TILE_DIM == 0);
    assert(cols % TILE_DIM == 0);

    const dim3 threads_per_block(TILE_DIM, TILE_DIM);
    const dim3 num_blocks(cols / TILE_DIM);

    column_sum_reduce<__half><<<num_blocks, threads_per_block, 0, stream>>>(inp, out, rows, cols);
}

// Element-wise sum of two fp32 buffers: out = inp1 + inp2.
//
// The buffers are viewed as gridDim.x rows of `row_stride` float4 vectors;
// block b handles row b. Threads stride across the row by blockDim.x, so any
// block size is correct. With blockDim.x == row_stride (what launch_fused_add2
// uses), each thread handles exactly one float4.
//
// size       total element count; not read by this kernel (kept for interface
//            compatibility).
// row_stride float4 vectors per row (hidden_dim / 4 at the call site).
// Preconditions: all pointers 16-byte aligned for the float4 accesses.
__global__ void fused_add2_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  int size,
                                  int row_stride)
{
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    float4* out_4 = reinterpret_cast<float4*>(out);

    // 64-bit row base: blockIdx.x * row_stride can exceed INT_MAX on large inputs.
    const size_t row_base = static_cast<size_t>(blockIdx.x) * row_stride;

    // Block-stride loop: correct even when blockDim.x != row_stride, e.g. when a
    // row has more vectors than the per-block thread limit.
    for (int col = threadIdx.x; col < row_stride; col += blockDim.x) {
        const size_t i = row_base + col;

        const float4 inp1_reg = inp1_4[i];
        const float4 inp2_reg = inp2_4[i];

        float4 val;
        val.x = inp1_reg.x + inp2_reg.x;
        val.y = inp1_reg.y + inp2_reg.y;
        val.z = inp1_reg.z + inp2_reg.z;
        val.w = inp1_reg.w + inp2_reg.w;

        out_4[i] = val;
    }
}

// Element-wise sum of two fp16 buffers: out = inp1 + inp2, computed in fp32
// and rounded to nearest on the way back to fp16.
//
// The buffers are viewed as gridDim.x rows of `row_stride` 8-byte vectors
// (one float2 load = four __half values); block b handles row b. Threads stride
// across the row by blockDim.x, so any block size is correct. With
// blockDim.x == row_stride (what launch_fused_add2 uses), each thread handles
// exactly one vector.
//
// size       total element count; not read by this kernel (kept for interface
//            compatibility).
// row_stride 4-half vectors per row (hidden_dim / 4 at the call site).
// Preconditions: all pointers 8-byte aligned for the float2 accesses.
__global__ void fused_add2_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  int size,
                                  int row_stride)
{
    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    float2* out_4 = reinterpret_cast<float2*>(out);

    // 64-bit row base: blockIdx.x * row_stride can exceed INT_MAX on large inputs.
    const size_t row_base = static_cast<size_t>(blockIdx.x) * row_stride;

    // Block-stride loop: correct even when blockDim.x != row_stride.
    for (int col = threadIdx.x; col < row_stride; col += blockDim.x) {
        const size_t i = row_base + col;

        // One 8-byte load per input, reinterpreted as two __half2 pairs.
        float2 inp1_4 = inp1_arr[i];
        float2 inp2_4 = inp2_arr[i];
        const __half2* inp1_h = reinterpret_cast<const __half2*>(&inp1_4);
        const __half2* inp2_h = reinterpret_cast<const __half2*>(&inp2_4);

        // Widen to fp32 for the addition.
        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

        const float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
        const float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

        inp1_h_f_0.x += inp2_h_f_0.x;
        inp1_h_f_0.y += inp2_h_f_0.y;
        inp1_h_f_1.x += inp2_h_f_1.x;
        inp1_h_f_1.y += inp2_h_f_1.y;

        // Narrow back to fp16 and store as a single 8-byte write.
        float2 val_f;
        __half2* val_h = reinterpret_cast<__half2*>(&val_f);

        val_h[0] = __float22half2_rn(inp1_h_f_0);
        val_h[1] = __float22half2_rn(inp1_h_f_1);

        out_4[i] = val_f;
    }
}

// fp32 launcher for out = inp1 + inp2 over a [batch_size * seq_length, hidden_dim]
// buffer, enqueued on `stream`. Uses one block per row and one thread per float4,
// so hidden_dim / 4 threads per block.
template <>
void launch_fused_add2<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              int batch_size,
                              int seq_length,
                              int hidden_dim,
                              cudaStream_t& stream)
{
    const int num_rows = batch_size * seq_length;
    const int vecs_per_row = hidden_dim / 4;
    const int total_elems = num_rows * hidden_dim;

    fused_add2_kernel<<<dim3(num_rows), dim3(vecs_per_row), 0, stream>>>(
        out, inp1, inp2, total_elems, vecs_per_row);
}

// fp16 launcher for out = inp1 + inp2 over a [batch_size * seq_length, hidden_dim]
// buffer, enqueued on `stream`. Uses one block per row and one thread per 4-half
// vector, so hidden_dim / 4 threads per block.
template <>
void launch_fused_add2<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               int batch_size,
                               int seq_length,
                               int hidden_dim,
                               cudaStream_t& stream)
{
    const int num_rows = batch_size * seq_length;
    const int vecs_per_row = hidden_dim / 4;
    const int total_elems = num_rows * hidden_dim;

    fused_add2_kernel<<<dim3(num_rows), dim3(vecs_per_row), 0, stream>>>(
        out, inp1, inp2, total_elems, vecs_per_row);
}

// Element-wise sum of three fp32 buffers: out = inp1 + inp2 + inp3,
// evaluated left to right.
//
// The buffers are viewed as gridDim.x rows of `row_stride` float4 vectors;
// block b handles row b. Threads stride across the row by blockDim.x, so any
// block size is correct. With blockDim.x == row_stride (what launch_fused_add3
// uses), each thread handles exactly one float4.
//
// size       total element count; not read by this kernel (kept for interface
//            compatibility).
// row_stride float4 vectors per row (hidden_size / 4 at the call site).
// Preconditions: all pointers 16-byte aligned for the float4 accesses.
__global__ void fused_add3_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  int size,
                                  int row_stride)
{
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
    float4* out_4 = reinterpret_cast<float4*>(out);

    // 64-bit row base: blockIdx.x * row_stride can exceed INT_MAX on large inputs.
    const size_t row_base = static_cast<size_t>(blockIdx.x) * row_stride;

    // Block-stride loop: correct even when blockDim.x != row_stride.
    for (int col = threadIdx.x; col < row_stride; col += blockDim.x) {
        const size_t i = row_base + col;

        const float4 inp1_reg = inp1_4[i];
        const float4 inp2_reg = inp2_4[i];
        const float4 inp3_reg = inp3_4[i];

        float4 val;
        val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x;
        val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y;
        val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z;
        val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w;

        out_4[i] = val;
    }
}

// Element-wise sum of three fp16 buffers: out = inp1 + (inp2 + inp3), computed
// in fp32 and rounded to nearest on the way back to fp16.
//
// The buffers are viewed as gridDim.x rows of `row_stride` 8-byte vectors
// (one float2 load = four __half values); block b handles row b. Threads stride
// across the row by blockDim.x, so any block size is correct. With
// blockDim.x == row_stride (what launch_fused_add3 uses), each thread handles
// exactly one vector.
//
// size       total element count; not read by this kernel (kept for interface
//            compatibility).
// row_stride 4-half vectors per row (hidden_size / 4 at the call site).
// Preconditions: all pointers 8-byte aligned for the float2 accesses.
__global__ void fused_add3_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  int size,
                                  int row_stride)
{
    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
    float2* out_4 = reinterpret_cast<float2*>(out);

    // 64-bit row base: blockIdx.x * row_stride can exceed INT_MAX on large inputs.
    const size_t row_base = static_cast<size_t>(blockIdx.x) * row_stride;

    // Block-stride loop: correct even when blockDim.x != row_stride.
    for (int col = threadIdx.x; col < row_stride; col += blockDim.x) {
        const size_t i = row_base + col;

        // One 8-byte load per input, reinterpreted as two __half2 pairs.
        float2 inp1_4 = inp1_arr[i];
        float2 inp2_4 = inp2_arr[i];
        float2 inp3_4 = inp3_arr[i];
        const __half2* inp1_h = reinterpret_cast<const __half2*>(&inp1_4);
        const __half2* inp2_h = reinterpret_cast<const __half2*>(&inp2_4);
        const __half2* inp3_h = reinterpret_cast<const __half2*>(&inp3_4);

        // Widen to fp32 for the additions.
        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

        const float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
        const float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

        const float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
        const float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

        inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x);
        inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y);
        inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x);
        inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y);

        // Narrow back to fp16 and store as a single 8-byte write.
        float2 val_f;
        __half2* val_h = reinterpret_cast<__half2*>(&val_f);

        val_h[0] = __float22half2_rn(inp1_h_f_0);
        val_h[1] = __float22half2_rn(inp1_h_f_1);

        out_4[i] = val_f;
    }
}

// fp32 launcher for out = inp1 + inp2 + inp3 over a
// [batch_size * seq_length, hidden_size] buffer, enqueued on `stream`.
// Uses one block per row and one thread per float4, so hidden_size / 4 threads per block.
template <>
void launch_fused_add3<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    const int num_rows = batch_size * seq_length;
    const int vecs_per_row = hidden_size / 4;
    const int total_elems = num_rows * hidden_size;

    fused_add3_kernel<<<dim3(num_rows), dim3(vecs_per_row), 0, stream>>>(
        out, inp1, inp2, inp3, total_elems, vecs_per_row);
}

// fp16 launcher for out = inp1 + inp2 + inp3 over a
// [batch_size * seq_length, hidden_size] buffer, enqueued on `stream`.
// Uses one block per row and one thread per 4-half vector, so hidden_size / 4
// threads per block.
template <>
void launch_fused_add3<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    const int num_rows = batch_size * seq_length;
    const int vecs_per_row = hidden_size / 4;
    const int total_elems = num_rows * hidden_size;

    fused_add3_kernel<<<dim3(num_rows), dim3(vecs_per_row), 0, stream>>>(
        out, inp1, inp2, inp3, total_elems, vecs_per_row);
}

// Element-wise sum of four fp32 buffers: out = inp1 + inp2 + inp3 + inp4,
// evaluated left to right.
//
// The buffers are viewed as gridDim.x rows of `row_stride` float4 vectors;
// block b handles row b. Threads stride across the row by blockDim.x, so any
// block size is correct. With blockDim.x == row_stride (what launch_fused_add4
// uses), each thread handles exactly one float4.
//
// size       total element count; not read by this kernel (kept for interface
//            compatibility).
// row_stride float4 vectors per row (hidden_size / 4 at the call site).
// Preconditions: all pointers 16-byte aligned for the float4 accesses.
__global__ void fused_add4_kernel(float* out,
                                  const float* inp1,
                                  const float* inp2,
                                  const float* inp3,
                                  const float* inp4,
                                  int size,
                                  int row_stride)
{
    const float4* inp1_4 = reinterpret_cast<const float4*>(inp1);
    const float4* inp2_4 = reinterpret_cast<const float4*>(inp2);
    const float4* inp3_4 = reinterpret_cast<const float4*>(inp3);
    const float4* inp4_4 = reinterpret_cast<const float4*>(inp4);
    float4* out_4 = reinterpret_cast<float4*>(out);

    // 64-bit row base: blockIdx.x * row_stride can exceed INT_MAX on large inputs.
    const size_t row_base = static_cast<size_t>(blockIdx.x) * row_stride;

    // Block-stride loop: correct even when blockDim.x != row_stride.
    for (int col = threadIdx.x; col < row_stride; col += blockDim.x) {
        const size_t i = row_base + col;

        const float4 inp1_reg = inp1_4[i];
        const float4 inp2_reg = inp2_4[i];
        const float4 inp3_reg = inp3_4[i];
        const float4 inp4_reg = inp4_4[i];

        float4 val;
        val.x = inp1_reg.x + inp2_reg.x + inp3_reg.x + inp4_reg.x;
        val.y = inp1_reg.y + inp2_reg.y + inp3_reg.y + inp4_reg.y;
        val.z = inp1_reg.z + inp2_reg.z + inp3_reg.z + inp4_reg.z;
        val.w = inp1_reg.w + inp2_reg.w + inp3_reg.w + inp4_reg.w;

        out_4[i] = val;
    }
}

// Element-wise sum of four fp16 buffers: out = inp1 + (inp2 + inp3 + inp4),
// computed in fp32 and rounded to nearest on the way back to fp16.
//
// The buffers are viewed as gridDim.x rows of `row_stride` 8-byte vectors
// (one float2 load = four __half values); block b handles row b. Threads stride
// across the row by blockDim.x, so any block size is correct. With
// blockDim.x == row_stride (what launch_fused_add4 uses), each thread handles
// exactly one vector.
//
// size       total element count; not read by this kernel (kept for interface
//            compatibility).
// row_stride 4-half vectors per row (hidden_size / 4 at the call site).
// Preconditions: all pointers 8-byte aligned for the float2 accesses.
__global__ void fused_add4_kernel(__half* out,
                                  const __half* inp1,
                                  const __half* inp2,
                                  const __half* inp3,
                                  const __half* inp4,
                                  int size,
                                  int row_stride)
{
    const float2* inp1_arr = reinterpret_cast<const float2*>(inp1);
    const float2* inp2_arr = reinterpret_cast<const float2*>(inp2);
    const float2* inp3_arr = reinterpret_cast<const float2*>(inp3);
    const float2* inp4_arr = reinterpret_cast<const float2*>(inp4);
    float2* out_4 = reinterpret_cast<float2*>(out);

    // 64-bit row base: blockIdx.x * row_stride can exceed INT_MAX on large inputs.
    const size_t row_base = static_cast<size_t>(blockIdx.x) * row_stride;

    // Block-stride loop: correct even when blockDim.x != row_stride.
    for (int col = threadIdx.x; col < row_stride; col += blockDim.x) {
        const size_t i = row_base + col;

        // One 8-byte load per input, reinterpreted as two __half2 pairs.
        float2 inp1_4 = inp1_arr[i];
        float2 inp2_4 = inp2_arr[i];
        float2 inp3_4 = inp3_arr[i];
        float2 inp4_4 = inp4_arr[i];
        const __half2* inp1_h = reinterpret_cast<const __half2*>(&inp1_4);
        const __half2* inp2_h = reinterpret_cast<const __half2*>(&inp2_4);
        const __half2* inp3_h = reinterpret_cast<const __half2*>(&inp3_4);
        const __half2* inp4_h = reinterpret_cast<const __half2*>(&inp4_4);

        // Widen to fp32 for the additions.
        float2 inp1_h_f_0 = __half22float2(inp1_h[0]);
        float2 inp1_h_f_1 = __half22float2(inp1_h[1]);

        const float2 inp2_h_f_0 = __half22float2(inp2_h[0]);
        const float2 inp2_h_f_1 = __half22float2(inp2_h[1]);

        const float2 inp3_h_f_0 = __half22float2(inp3_h[0]);
        const float2 inp3_h_f_1 = __half22float2(inp3_h[1]);

        const float2 inp4_h_f_0 = __half22float2(inp4_h[0]);
        const float2 inp4_h_f_1 = __half22float2(inp4_h[1]);

        inp1_h_f_0.x += (inp2_h_f_0.x + inp3_h_f_0.x + inp4_h_f_0.x);
        inp1_h_f_0.y += (inp2_h_f_0.y + inp3_h_f_0.y + inp4_h_f_0.y);
        inp1_h_f_1.x += (inp2_h_f_1.x + inp3_h_f_1.x + inp4_h_f_1.x);
        inp1_h_f_1.y += (inp2_h_f_1.y + inp3_h_f_1.y + inp4_h_f_1.y);

        // Narrow back to fp16 and store as a single 8-byte write.
        float2 val_f;
        __half2* val_h = reinterpret_cast<__half2*>(&val_f);

        val_h[0] = __float22half2_rn(inp1_h_f_0);
        val_h[1] = __float22half2_rn(inp1_h_f_1);

        out_4[i] = val_f;
    }
}

// fp32 launcher for out = inp1 + inp2 + inp3 + inp4 over a
// [batch_size * seq_length, hidden_size] buffer, enqueued on `stream`.
// Uses one block per row and one thread per float4, so hidden_size / 4 threads per block.
template <>
void launch_fused_add4<float>(float* out,
                              const float* inp1,
                              const float* inp2,
                              const float* inp3,
                              const float* inp4,
                              int batch_size,
                              int seq_length,
                              int hidden_size,
                              cudaStream_t& stream)
{
    const int num_rows = batch_size * seq_length;
    const int vecs_per_row = hidden_size / 4;
    const int total_elems = num_rows * hidden_size;

    fused_add4_kernel<<<dim3(num_rows), dim3(vecs_per_row), 0, stream>>>(
        out, inp1, inp2, inp3, inp4, total_elems, vecs_per_row);
}

// fp16 launcher for out = inp1 + inp2 + inp3 + inp4 over a
// [batch_size * seq_length, hidden_size] buffer, enqueued on `stream`.
// Uses one block per row and one thread per 4-half vector, so hidden_size / 4
// threads per block.
template <>
void launch_fused_add4<__half>(__half* out,
                               const __half* inp1,
                               const __half* inp2,
                               const __half* inp3,
                               const __half* inp4,
                               int batch_size,
                               int seq_length,
                               int hidden_size,
                               cudaStream_t& stream)
{
    const int num_rows = batch_size * seq_length;
    const int vecs_per_row = hidden_size / 4;
    const int total_elems = num_rows * hidden_size;

    fused_add4_kernel<<<dim3(num_rows), dim3(vecs_per_row), 0, stream>>>(
        out, inp1, inp2, inp3, inp4, total_elems, vecs_per_row);
}