diff --git a/csrc/includes/conversion_utils.h b/csrc/includes/conversion_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..4f952e5f0650a4929c32aa2f0b1a1c2d8f6ad7cd
--- /dev/null
+++ b/csrc/includes/conversion_utils.h
@@ -0,0 +1,625 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
+#pragma once
+
+#include "ds_kernel_utils.h"
+
+#include <cuda_fp16.h>
+#include <stdint.h>
+
+#ifdef BF16_AVAILABLE
+#include <cuda_bf16.h>
+#endif
+
+namespace conversion {
+
+// Basic primitive for constructing conversions
+template <typename TO, typename FROM>
+DS_D_INLINE TO to(FROM val)
+{
+    return to(val);
+}
+
+// Specializations
+
+/********************* Identity Conversions *********************/
+/*
+Identity conversions are useful in templated functions where we might have
+a fixed destination type. For example, I might have a kernel that accepts
+__half, __nv_bfloat16, and float but always want to do the core computation
+in fp32:
+
+T mem_value = input[idx];
+float compute_value = conversion::to<float, T>(mem_value);
+
+In practice, we should be able to elide the second template parameter:
+float compute_val = conversion::to<float>(mem_value);
+
+In this case, we need an implementation to handle the T = float case.
+
+NOTE: The type inference system appears to be unable to infer the first
+template parameter, even in the trivial case.
+*/
+
+// Floating point types
+template <>
+DS_D_INLINE double to(double val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE float to(float val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE __half to(__half val)
+{
+    return val;
+}
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE __nv_bfloat16 to(__nv_bfloat16 val)
+{
+    return val;
+}
+#endif
+
+// Integer types
+template <>
+DS_D_INLINE int8_t to(int8_t val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE uint8_t to(uint8_t val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE int16_t to(int16_t val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE uint16_t to(uint16_t val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE int32_t to(int32_t val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE uint32_t to(uint32_t val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE int64_t to(int64_t val)
+{
+    return val;
+}
+template <>
+DS_D_INLINE uint64_t to(uint64_t val)
+{
+    return val;
+}
+
+// TODO: evaluate if we want bools
+
+/********************* To Double Conversions *********************/
+
+// * to double variants
+
+// Would normally prefer not to use a C-style cast, but this is an important enough
+// conversion to keep
+template <>
+DS_D_INLINE double to(float val)
+{
+#ifdef PTX_AVAILABLE
+    double ret_val;
+    asm("cvt.f64.f32 %0, %1;\n" : "=d"(ret_val) : "f"(val));
+    return ret_val;
+#else
+    return double(val);
+#endif
+}
+// Note: there is a CVT instruction for __half -> double, but there's no inline interface
+// for passing a single half value
+template <>
+DS_D_INLINE double to(__half val)
+{
+    return to<double>(__half2float(val));
+}
+template <>
+DS_D_INLINE double to(int64_t val)
+{
+    return __ll2double_rn(val);
+}
+template <>
+DS_D_INLINE double to(int32_t val)
+{
+    return __int2double_rn(val);
+}
+template <>
+DS_D_INLINE double to(int16_t val)
+{
+    return __int2double_rn(val);
+}
+template <>
+DS_D_INLINE double to(int8_t val)
+{
+    return __int2double_rn(val);
+}
+template <>
+DS_D_INLINE double to(uint64_t val)
+{
+    return __ull2double_rn(val);
+}
+template <>
+DS_D_INLINE double to(uint32_t val)
+{
+    return __uint2double_rn(val);
+}
+template <>
+DS_D_INLINE double to(uint16_t val)
+{
+    return __uint2double_rn(val);
+}
+template <>
+DS_D_INLINE double to(uint8_t val)
+{
+    return __uint2double_rn(val);
+}
+
+// Same applies here
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE double to(__nv_bfloat16 val)
+{
+    return to<double>(__bfloat162float(val));
+}
+#endif
+
+/********************* To Float Conversions *********************/
+
+template <>
+DS_D_INLINE float to(double val)
+{
+    return __double2float_rn(val);
+}
+template <>
+DS_D_INLINE float to(__half val)
+{
+    return __half2float(val);
+}
+template <>
+DS_D_INLINE float to(int64_t val)
+{
+    return __ll2float_rn(val);
+}
+template <>
+DS_D_INLINE float to(int32_t val)
+{
+    return __int2float_rn(val);
+}
+template <>
+DS_D_INLINE float to(int16_t val)
+{
+    return __int2float_rn(val);
+}
+template <>
+DS_D_INLINE float to(int8_t val)
+{
+    return __int2float_rn(val);
+}
+template <>
+DS_D_INLINE float to(uint64_t val)
+{
+    return __ull2float_rn(val);
+}
+template <>
+DS_D_INLINE float to(uint32_t val)
+{
+    return __uint2float_rn(val);
+}
+template <>
+DS_D_INLINE float to(uint16_t val)
+{
+    return __uint2float_rn(val);
+}
+template <>
+DS_D_INLINE float to(uint8_t val)
+{
+    return __uint2float_rn(val);
+}
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE float to(__nv_bfloat16 val)
+{
+    return __bfloat162float(val);
+}
+#endif
+
+/********************* To Float2 Conversions *********************/
+template <>
+DS_D_INLINE float2 to(__half2 val)
+{
+    return __half22float2(val);
+}
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE float2 to(__nv_bfloat162 val)
+{
+    return __bfloat1622float2(val);
+}
+#endif
+
+/********************* To Half Conversions *********************/
+template <>
+DS_D_INLINE __half to(double val)
+{
+    return __double2half(val);
+}
+template <>
+DS_D_INLINE __half to(float val)
+{
+    return __float2half(val);
+}
+template <>
+DS_D_INLINE __half to(int64_t val)
+{
+    return __ll2half_rn(val);
+}
+template <>
+DS_D_INLINE __half to(int32_t val)
+{
+    return __int2half_rn(val);
+}
+template <>
+DS_D_INLINE __half to(int16_t val)
+{
+    return __short2half_rn(val);
+}
+template <>
+DS_D_INLINE __half to(int8_t val)
+{
+    return __int2half_rn(val);
+}
+template <>
+DS_D_INLINE __half to(uint64_t val)
+{
+    return __ull2half_rn(val);
+}
+template <>
+DS_D_INLINE __half to(uint32_t val)
+{
+    return __uint2half_rn(val);
+}
+template <>
+DS_D_INLINE __half to(uint16_t val)
+{
+    return __ushort2half_rn(val);
+}
+template <>
+DS_D_INLINE __half to(uint8_t val)
+{
+    return __uint2half_rn(val);
+}
+
+#ifdef BF16_AVAILABLE
+// No direct conversion
+template <>
+DS_D_INLINE __half to(__nv_bfloat16 val)
+{
+    return to<__half>(to<float>(val));
+}
+#endif
+
+/********************* To Half2 Conversions *********************/
+template <>
+DS_D_INLINE __half2 to(float2 val)
+{
+    return __float22half2_rn(val);
+}
+
+#ifdef BF16_AVAILABLE
+// No direct conversion
+template <>
+DS_D_INLINE __half2 to(__nv_bfloat162 val)
+{
+    return to<__half2>(to<float2>(val));
+}
+#endif
+
+/********************* To BF16 Conversions *********************/
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE __nv_bfloat16 to(double val)
+{
+    return __double2bfloat16(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(float val)
+{
+    return __float2bfloat16(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(int64_t val)
+{
+    return __ll2bfloat16_rn(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(int32_t val)
+{
+    return __int2bfloat16_rn(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(int16_t val)
+{
+    return __short2bfloat16_rn(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(int8_t val)
+{
+    return __int2bfloat16_rn(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(uint64_t val)
+{
+    return __ull2bfloat16_rn(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(uint32_t val)
+{
+    return __uint2bfloat16_rn(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(uint16_t val)
+{
+    return __ushort2bfloat16_rn(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat16 to(uint8_t val)
+{
+    return __uint2bfloat16_rn(val);
+}
+#endif
+
+/********************* To BF162 Conversions *********************/
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE __nv_bfloat162 to(float2 val)
+{
+    return __float22bfloat162_rn(val);
+}
+template <>
+DS_D_INLINE __nv_bfloat162 to(__half2 val)
+{
+    return to<__nv_bfloat162>(to<float2>(val));
+}
+#endif
+
+/********************* To INT64_T Conversions *********************/
+template <>
+DS_D_INLINE int64_t to(double val)
+{
+    return __double2ll_rn(val);
+}
+template <>
+DS_D_INLINE int64_t to(float val)
+{
+    return __float2ll_rn(val);
+}
+template <>
+DS_D_INLINE int64_t to(__half val)
+{
+    return __half2ll_rn(val);
+}
+// No direct support for integer casts at the C++ level and I don't feel they're so important
+// to demand a PTX implementation at this time
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE int64_t to(__nv_bfloat16 val)
+{
+    return __bfloat162ll_rn(val);
+}
+#endif
+
+/********************* To INT32_T Conversions *********************/
+template <>
+DS_D_INLINE int32_t to(double val)
+{
+    return __double2int_rn(val);
+}
+template <>
+DS_D_INLINE int32_t to(float val)
+{
+    return __float2int_rn(val);
+}
+template <>
+DS_D_INLINE int32_t to(__half val)
+{
+    return __half2int_rn(val);
+}
+// No direct support for integer casts at the C++ level and I don't feel they're so important
+// to demand a PTX implementation at this time
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE int32_t to(__nv_bfloat16 val)
+{
+    return __bfloat162int_rn(val);
+}
+#endif
+
+/********************* To INT16_T Conversions *********************/
+template <>
+DS_D_INLINE int16_t to(double val)
+{
+    return __double2int_rn(val);
+}
+template <>
+DS_D_INLINE int16_t to(float val)
+{
+    return __float2int_rn(val);
+}
+template <>
+DS_D_INLINE int16_t to(__half val)
+{
+    return __half2int_rn(val);
+}
+// No direct support for integer casts at the C++ level and I don't feel they're so important
+// to demand a PTX implementation at this time
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE int16_t to(__nv_bfloat16 val)
+{
+    return __bfloat162int_rn(val);
+}
+#endif
+
+/********************* To INT8_T Conversions *********************/
+template <>
+DS_D_INLINE int8_t to(double val)
+{
+    return __double2int_rn(val);
+}
+template <>
+DS_D_INLINE int8_t to(float val)
+{
+    return __float2int_rn(val);
+}
+template <>
+DS_D_INLINE int8_t to(__half val)
+{
+    return __half2int_rn(val);
+}
+// No direct support for integer casts at the C++ level and I don't feel they're so important
+// to demand a PTX implementation at this time
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE int8_t to(__nv_bfloat16 val)
+{
+    return __bfloat162int_rn(val);
+}
+#endif
+
+/********************* To UINT64_T Conversions *********************/
+template <>
+DS_D_INLINE uint64_t to(double val)
+{
+    return __double2ull_rn(val);
+}
+template <>
+DS_D_INLINE uint64_t to(float val)
+{
+    return __float2ull_rn(val);
+}
+template <>
+DS_D_INLINE uint64_t to(__half val)
+{
+    return __half2ull_rn(val);
+}
+// No direct support for integer casts at the C++ level and I don't feel they're so important
+// to demand a PTX implementation at this time
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE uint64_t to(__nv_bfloat16 val)
+{
+    return __bfloat162ull_rn(val);
+}
+#endif
+
+/********************* To UINT32_T Conversions *********************/
+template <>
+DS_D_INLINE uint32_t to(double val)
+{
+    return __double2uint_rn(val);
+}
+template <>
+DS_D_INLINE uint32_t to(float val)
+{
+    return __float2uint_rn(val);
+}
+template <>
+DS_D_INLINE uint32_t to(__half val)
+{
+    return __half2uint_rn(val);
+}
+// No direct support for integer casts at the C++ level and I don't feel they're so important
+// to demand a PTX implementation at this time
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE uint32_t to(__nv_bfloat16 val)
+{
+    return __bfloat162uint_rn(val);
+}
+#endif
+
+/********************* To UINT16_T Conversions *********************/
+template <>
+DS_D_INLINE uint16_t to(double val)
+{
+    return __double2uint_rn(val);
+}
+template <>
+DS_D_INLINE uint16_t to(float val)
+{
+    return __float2uint_rn(val);
+}
+template <>
+DS_D_INLINE uint16_t to(__half val)
+{
+    return __half2uint_rn(val);
+}
+// No direct support for integer casts at the C++ level and I don't feel they're so important
+// to demand a PTX implementation at this time
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE uint16_t to(__nv_bfloat16 val)
+{
+    return __bfloat162uint_rn(val);
+}
+#endif
+
+/********************* To UINT8_T Conversions *********************/
+template <>
+DS_D_INLINE uint8_t to(double val)
+{
+    return __double2uint_rn(val);
+}
+template <>
+DS_D_INLINE uint8_t to(float val)
+{
+    return __float2uint_rn(val);
+}
+template <>
+DS_D_INLINE uint8_t to(__half val)
+{
+    return __half2uint_rn(val);
+}
+// No direct support for integer casts at the C++ level and I don't feel they're so important
+// to demand a PTX implementation at this time
+
+#ifdef BF16_AVAILABLE
+template <>
+DS_D_INLINE uint8_t to(__nv_bfloat16 val)
+{
+    return __bfloat162uint_rn(val);
+}
+#endif
+
+} // namespace conversion
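Note: the header above is built around a load-in-storage-type, compute-in-fp32, convert-back pattern. A minimal usage sketch follows; it is illustrative only, and the kernel name scale_kernel and its arguments are hypothetical, not part of this patch:

    #include "conversion_utils.h"

    template <typename T>
    __global__ void scale_kernel(T* data, float alpha, int n)  // hypothetical example
    {
        const int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx < n) {
            // Widen the stored value (float, __half, or __nv_bfloat16) to fp32,
            // do the arithmetic in single precision, then narrow back to T.
            const float val = conversion::to<float>(data[idx]);
            data[idx] = conversion::to<T>(val * alpha);
        }
    }

Per the note in the header, the destination type is always given explicitly; only the source type is deduced from the argument.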
diff --git a/csrc/includes/custom_cuda_layers.h b/csrc/includes/custom_cuda_layers.h
index 9edc26b5d63b9eadfe60cc06047cb71cd616306e..006adadfd862f33a136e73a7873dcaebad413c94 100644
--- a/csrc/includes/custom_cuda_layers.h
+++ b/csrc/includes/custom_cuda_layers.h
@@ -1,21 +1,17 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+*/
+
 #pragma once

+#include "ds_kernel_utils.h"
+
 #include <cuda.h>
 #include <cuda_fp16.h>
+#include <curand_kernel.h>
 #include <stdio.h>
 #include <stdlib.h>

-#ifdef __HIP_PLATFORM_HCC__
-#define HALF_PRECISION_AVAILABLE = 1
-#include <hip/hip_cooperative_groups.h>
-#else
-#if __CUDA_ARCH__ >= 530
-#define HALF_PRECISION_AVAILABLE = 1
-#endif
-#include <cooperative_groups.h>
-#endif
-#include <curand_kernel.h>
-
 #include "context.h"
 #include "cublas_wrappers.h"

diff --git a/csrc/includes/ds_kernel_utils.h b/csrc/includes/ds_kernel_utils.h
new file mode 100644
index 0000000000000000000000000000000000000000..9b0b74bd914e6abc85bce51b59748066be3e30ad
--- /dev/null
+++ b/csrc/includes/ds_kernel_utils.h
@@ -0,0 +1,39 @@
+/*
+Copyright 2022 The Microsoft DeepSpeed Team
+
+Centralized header file for preprocessor macros and constants
+used throughout the codebase.
+*/
+
+#pragma once
+
+#include <cuda.h>
+
+#define DS_HD_INLINE __host__ __device__ __forceinline__
+#define DS_D_INLINE __device__ __forceinline__
+
+#ifdef __HIP_PLATFORM_HCC__
+
+// constexpr variant of warpSize for templating
+constexpr int hw_warp_size = 64;
+#define HALF_PRECISION_AVAILABLE = 1
+#include <hip/hip_cooperative_groups.h>
+
+#else  // !__HIP_PLATFORM_HCC__
+
+// constexpr variant of warpSize for templating
+constexpr int hw_warp_size = 32;
+
+#if __CUDA_ARCH__ >= 530
+#define HALF_PRECISION_AVAILABLE = 1
+#define PTX_AVAILABLE
+#endif  // __CUDA_ARCH__ >= 530
+
+#if __CUDA_ARCH__ >= 800
+#define ASYNC_COPY_AVAILABLE
+#define BF16_AVAILABLE
+#endif  // __CUDA_ARCH__ >= 800
+
+#include <cooperative_groups.h>
+
+#endif  //__HIP_PLATFORM_HCC__
diff --git a/csrc/includes/memory_access_utils.h b/csrc/includes/memory_access_utils.h
index 548f9f5c7d3b1bc1b248408b35010fb3db40a3f1..a3f858f1f8704acc8da88d98e8429f0116953806 100644
--- a/csrc/includes/memory_access_utils.h
+++ b/csrc/includes/memory_access_utils.h
@@ -5,15 +5,7 @@ Copyright 2022 The Microsoft DeepSpeed Team
 #pragma once

 #include <cuda.h>
-
-#if __CUDA_ARCH__ >= 800
-#define ASYNC_COPY_AVAILABLE
-#endif
-
-// Discuss basic GPUs to support
-#if __CUDA_ARCH__ >= 530
-#define PTX_AVAILABLE
-#endif
+#include "ds_kernel_utils.h"

 /////////////////////////////// Memory Access Utils ///////////////////////////////
 namespace mem_access {
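ds_kernel_utils.h exposes hw_warp_size as a constexpr so the warp width can be used where a constant expression is required (array bounds, template arguments), which the runtime built-in warpSize cannot. A rough sketch of that intended use follows, assuming a cooperative-groups warp reduction that is not part of this patch; the helper and kernel names are hypothetical:

    #include "ds_kernel_utils.h"  // pulls in the right cooperative_groups header per platform

    namespace cg = cooperative_groups;

    // Hypothetical helper: hw_warp_size parameterizes the tile type at compile time.
    DS_D_INLINE float warp_sum(cg::thread_block_tile<hw_warp_size> warp, float val)
    {
        for (int stride = hw_warp_size / 2; stride > 0; stride /= 2) {
            val += warp.shfl_down(val, stride);
        }
        return val;
    }

    __global__ void block_sums(const float* in, float* out)  // hypothetical example
    {
        auto warp = cg::tiled_partition<hw_warp_size>(cg::this_thread_block());

        float val = warp_sum(warp, in[blockIdx.x * blockDim.x + threadIdx.x]);

        // One partial sum per warp; the warp index is derived from the constexpr width.
        const int warp_id = threadIdx.x / hw_warp_size;
        if (warp.thread_rank() == 0) { out[blockIdx.x * (blockDim.x / hw_warp_size) + warp_id] = val; }
    }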
diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu
index 8bc58769ede7afba47ccc94d2ccab39f57c4375e..b3bcd30f57186a486e122cb337994ca726dc4a48 100644
--- a/csrc/transformer/inference/csrc/gelu.cu
+++ b/csrc/transformer/inference/csrc/gelu.cu
@@ -2,6 +2,7 @@
 Copyright 2022 The Microsoft DeepSpeed Team
 */

+#include "conversion_utils.h"
 #include "inference_cuda_layers.h"
 #include "memory_access_utils.h"

@@ -16,58 +17,29 @@ inline __device__ float gelu(const float x)
     return x * 0.5f * (1.0f + tanhf(sqrt_param * (x + mul_param * x * x * x)));
 }

-__global__ void fused_bias_gelu(float* input,
-                                const float* bias,
-                                int total_count,
-                                int intermediate_size)
+template <typename T>
+__global__ void fused_bias_gelu(T* input, const T* bias, int total_count, int intermediate_size)
 {
     // Input restriction: intermediate_size % vals_per_access == 0
     constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(float);
-    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
+    constexpr int values_per_access = granularity / sizeof(T);
+    const int offset = (blockIdx.x * blockDim.x + threadIdx.x) * values_per_access;

     if (offset < total_count) {
-        float data[vals_per_access];
-        float data_bias[vals_per_access];
+        T data[values_per_access];
+        T data_bias[values_per_access];
         mem_access::load_global<granularity>(data, input + offset);
         mem_access::load_global<granularity>(data_bias, bias + (offset % intermediate_size));

 #pragma unroll
-        for (int i = 0; i < vals_per_access; i++) { data[i] = gelu(data[i] + data_bias[i]); }
-
-        mem_access::store_global<granularity>(input + offset, data);
-    }
-}
-
-__global__ void fused_bias_gelu(__half* input,
-                                const __half* bias,
-                                int total_count,
-                                int intermediate_size)
-{
-    // Input restriction: intermediate_size % vals_per_access == 0
-    // This kernel doubles the per-thread ALU workload as compared to the float implementation
-#ifdef HALF_PRECISION_AVAILABLE
-    constexpr int granularity = 16;
-    constexpr int vals_per_access = granularity / sizeof(__half);
-    int offset = (blockIdx.x * blockDim.x + threadIdx.x) * vals_per_access;
-
-    if (offset < total_count) {
-        // Divide by 2 since we store two values per __half2
-        __half2 data[vals_per_access / 2];
-        __half2 bias_data[vals_per_access / 2];
-        mem_access::load_global<granularity>(data, input + offset);
-        mem_access::load_global<granularity>(bias_data, bias + (offset % intermediate_size));
-
-#pragma unroll
-        for (int i = 0; i < vals_per_access / 2; i++) {
-            float2 data_f = __half22float2(data[i]);
-            float2 bias_f = __half22float2(bias_data[i]);
-            data[i] = __floats2half2_rn(gelu(data_f.x + bias_f.x), gelu(data_f.y + bias_f.y));
+        for (int i = 0; i < values_per_access; i++) {
+            float data_f = conversion::to<float>(data[i]);
+            float bias_f = conversion::to<float>(data_bias[i]);
+            data[i] = conversion::to<T>(gelu(data_f + bias_f));
         }

         mem_access::store_global<granularity>(input + offset, data);
     }
-#endif
 }

 template <typename T>
diff --git a/csrc/transformer/inference/includes/inference_context.h b/csrc/transformer/inference/includes/inference_context.h
index fe0616a75c2ee9da87a46cf39130f9d4783f2cf4..6725fc72fb7cff917b814376e026e204bb5757d0 100644
--- a/csrc/transformer/inference/includes/inference_context.h
+++ b/csrc/transformer/inference/includes/inference_context.h
@@ -16,6 +16,8 @@ Copyright 2022 The Microsoft DeepSpeed Team
 #define GIGABYTE (1024 * 1024 * 1024)

 #define MAX_OUT_TOKENS 8192
+
+// TODO: refactor out
 #define WARP_SIZE 32

 #define CUDA_CHECK(callstr) \
diff --git a/csrc/transformer/inference/includes/inference_cuda_layers.h b/csrc/transformer/inference/includes/inference_cuda_layers.h
index 6302ceb2935da827a93f2c6414a1e5cc5f7afb94..1f86e2d858d13eeb7d3789c1744da8c259407c52 100644
--- a/csrc/transformer/inference/includes/inference_cuda_layers.h
+++ b/csrc/transformer/inference/includes/inference_cuda_layers.h
@@ -4,15 +4,7 @@ Copyright 2022 The Microsoft DeepSpeed Team

 #pragma once

-#ifdef __HIP_PLATFORM_HCC__
-#define HALF_PRECISION_AVAILABLE = 1
-#include <hip/hip_fp16.h>
-#else
-#if __CUDA_ARCH__ >= 530
-#define HALF_PRECISION_AVAILABLE = 1
-#endif
-#include <cuda_fp16.h>
-#endif
+#include "ds_kernel_utils.h"

 #include <cuda.h>
 #include <cuda_fp16.h>
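For context on how the templated fused_bias_gelu above is driven: since each thread consumes a 16-byte vector (4 floats, or 8 __half/__nv_bfloat16 values), a launcher sizes the grid by elements-per-thread. The sketch below is illustrative only; the launcher name, thread count, and stream plumbing are assumptions, not part of this patch.

    template <typename T>
    void launch_fused_bias_gelu_sketch(T* input,             // hypothetical launcher
                                       const T* bias,
                                       int total_count,
                                       int intermediate_size,
                                       cudaStream_t stream)
    {
        constexpr int threads = 1024;
        constexpr int granularity = 16;
        constexpr int values_per_access = granularity / sizeof(T);
        // Round up so every element is covered by exactly one thread's vector.
        const int blocks =
            (total_count + threads * values_per_access - 1) / (threads * values_per_access);

        fused_bias_gelu<<<blocks, threads, 0, stream>>>(input, bias, total_count, intermediate_size);
    }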