Unverified commit 78d4ca1f authored by lokoppakmsft, committed by GitHub

Deepspeed quantization library v0.1 (#2450)

* Initial commit Deepspeed quantization library

* Match function signatures

* Add Quantization Kernel

* adding offset comparison and pre-commit changes

* format fixes

* File name changes

* pt_binding_changes

* test name change

* Integer quantization, minor refactors

* Add directed test_case

* format fixes

* Move param calculation to constructor of params class

* Use local function and add elemsPerBlock

* change function to be specialized

* sub block reduce

* add new schedule

* Add new schedule test case

* fix illegal writes in sch1

* Style fixes in comments
Co-authored-by: Connor Holmes <connorholmes@microsoft.com>
Parent: d40a15fc
......@@ -5,6 +5,7 @@ Copyright 2022 The Microsoft DeepSpeed Team
#pragma once
#include "ds_kernel_utils.h"
#include "quantization.h"
#include <cuda.h>
#include <cuda_fp16.h>
......@@ -41,30 +42,6 @@ Copyright 2022 The Microsoft DeepSpeed Team
#define WARP_SIZE_BITS 5
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
// Fused bias add with gelu activation
template <typename T>
void launch_bias_gelu(const T* input,
......
#pragma once
#include <cuda_fp16.h>
#include "ds_kernel_utils.h"
namespace quantize {
enum class Type { Symmetric, Asymmetric, IntegerSymmetric };
struct PackedInt4 {
int8_t high : 4;
int8_t low : 4;
};
DS_HD_INLINE bool requires_offset(Type qType) { return qType == Type::Asymmetric; }
} // namespace quantize
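For the 4-bit path, two signed 4-bit codes share a single int8, which is what PackedInt4 models. A minimal Python sketch of that packing, assuming the convention the directed test below recovers with int4x2to2xint4 (high nibble carries the first element of each pair); the helper names here are illustrative, not part of the library:

def pack_int4_pair(first: int, second: int) -> int:
    # Place `first` in the high nibble and `second` in the low nibble.
    # Inputs are signed 4-bit values in [-8, 7]; the result is an unsigned byte.
    return ((first & 0xF) << 4) | (second & 0xF)

def unpack_int4_pair(byte: int) -> tuple:
    # Sign-extend each nibble back to a Python int, mirroring int4x2to2xint4 in the test.
    def sext4(v):
        return v - 16 if v >= 8 else v
    return sext4((byte >> 4) & 0xF), sext4(byte & 0xF)

assert unpack_int4_pair(pack_int4_pair(-3, 7)) == (-3, 7)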
template <int numBits, quantize::Type qType>
void launch_quant(int8_t* output_data,
float* params,
const __half* input_data,
int groups,
int elems_per_group,
cudaStream_t stream);
template <typename T>
void launch_fake_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_sr_fake_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_fake_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template <typename T>
void launch_sr_fake_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
#include <cstdio>
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "reduction_utils.h"
#pragma once
using rop = reduce::ROpType;
namespace quantize {
constexpr int granularity = 16;
constexpr int h_per_load = granularity / sizeof(__half);
constexpr int h2_per_load = granularity / sizeof(__half2);
constexpr int max_threads = 1024;
/*
Class to hold the quantization parameters for a given tensor.
Holds the implementation of the quantization operation.
*/
template <Type qType, int numBits>
class Params {
public:
/*
Quantization implementation. Supports:
1) 4-bit
2) 8-bit
3) Symmetric
4) Asymmetric
Function Arguments :
val : The __half value to quantize.
*/
DS_D_INLINE int8_t quantize(__half val);
DS_D_INLINE void store(float* params, int group_index);
};
template <int numBits>
class Params<Type::Symmetric, numBits> {
public:
float scale;
DS_D_INLINE Params(float max)
{
if (max == 0) {
scale = 1.0;
} else {
scale = (1 << numBits) / (2 * max);
}
}
DS_D_INLINE int8_t quantize(__half val)
{
constexpr int32_t q_min = -(1 << (numBits - 1));
constexpr int32_t q_max = (1 << (numBits - 1)) - 1;
float val_f = conversion::to<float>(val) * scale;
int32_t data_i32 = conversion::to<int32_t>(val_f);
data_i32 = min(max(data_i32, q_min), q_max);
return (int8_t)data_i32;
}
DS_D_INLINE void store(float* params, int group_index)
{
const float store_scale = 1 / scale;
mem_access::store_global<sizeof(float)>(params + group_index, &store_scale);
}
};
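For reference, the arithmetic of Params<Type::Symmetric, numBits> written out on the host. This is a hedged Python sketch of a single group, using round-to-nearest as the reference test does (the device path converts through conversion::to<int32_t>, whose rounding mode is not spelled out here):

def symmetric_quantize_group(vals, num_bits=8):
    # vals: the float values belonging to one quantization group.
    q_min, q_max = -(1 << (num_bits - 1)), (1 << (num_bits - 1)) - 1
    absmax = max(abs(v) for v in vals)
    scale = 1.0 if absmax == 0 else (1 << num_bits) / (2 * absmax)
    q = [min(max(round(v * scale), q_min), q_max) for v in vals]
    # store() writes 1/scale per group, so dequantization is a single multiply.
    return q, 1.0 / scale

Dequantizing is then just q[i] times the stored per-group value.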
template <int numBits>
class Params<Type::IntegerSymmetric, numBits> {
public:
int32_t scale;
DS_D_INLINE Params(float max) { scale = conversion::to<int32_t>(max + 0.5f); }
DS_D_INLINE int8_t quantize(__half val)
{
constexpr int32_t q_max = (1 << (numBits - 1)) - 1;
float val_f = conversion::to<float>(val) * q_max;
float scaled_val = val_f / conversion::to<float>(scale);
int32_t data_i32 = conversion::to<int32_t>(scaled_val);
return (int8_t)data_i32;
}
DS_D_INLINE void store(float* params, int group_index)
{
mem_access::store_global<sizeof(float)>(params + group_index, &scale);
}
};
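The IntegerSymmetric variant rounds the group absmax to an integer scale and stores that int32 directly. A rough sketch mirroring the run_integer_quantize reference in the test file (the clamp is taken from that reference; the device quantize() above does not clamp):

def integer_symmetric_quantize_group(vals, num_bits=8):
    q_min, q_max = -(1 << (num_bits - 1)), (1 << (num_bits - 1)) - 1
    absmax = max(abs(v) for v in vals)
    scale = int(absmax + 0.5)  # integer scale, as in the constructor; assumes absmax >= 0.5
    q = [min(max(round(v * q_max / scale), q_min), q_max) for v in vals]
    return q, scale            # the int32 scale is the only stored parameter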
template <int numBits>
class Params<Type::Asymmetric, numBits> {
public:
float scale;
float offset;
DS_D_INLINE Params(float max, float min)
{
if (max == min) {
scale = 1.0;
} else {
scale = (1 << numBits) / (max - min);
}
offset = -(1 << (numBits - 1)) - (min * scale);
}
DS_D_INLINE int8_t quantize(__half val)
{
constexpr int32_t q_min = -(1 << (numBits - 1));
constexpr int32_t q_max = (1 << (numBits - 1)) - 1;
float val_f = conversion::to<float>(val) * scale + offset;
int32_t data_i32 = conversion::to<int32_t>(val_f);
data_i32 = min(max(data_i32, q_min), q_max);
return (int8_t)data_i32;
}
DS_D_INLINE void store(float* params, int group_index)
{
const float store_scale = 1 / scale;
mem_access::store_global<sizeof(float)>(params + 2 * group_index, &store_scale);
mem_access::store_global<sizeof(float)>(params + 2 * group_index + 1, &offset);
}
};
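The asymmetric variant also tracks the group minimum and stores two values per group (1/scale and the offset), matching the params layout written by store(). A minimal sketch under the same round-to-nearest assumption:

def asymmetric_quantize_group(vals, num_bits=8):
    q_min, q_max = -(1 << (num_bits - 1)), (1 << (num_bits - 1)) - 1
    lo, hi = min(vals), max(vals)
    scale = 1.0 if hi == lo else (1 << num_bits) / (hi - lo)
    offset = q_min - lo * scale
    q = [min(max(round(v * scale + offset), q_min), q_max) for v in vals]
    return q, (1.0 / scale, offset)  # params layout per group: [1/scale, offset]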
/*
GroupStats tracks the statistics needed for a quantization group, abstracting the
type-specific details away from the main loop.
*/
template <Type qType>
class GroupStats {
public:
DS_D_INLINE void update(__half2 val);
DS_D_INLINE void reduce(cg::thread_block& tb, cg::thread_block_tile<hw_warp_size>& warp);
};
template <>
class GroupStats<Type::Symmetric> {
public:
// Symmetric quantization only tracks the maximum absolute value
__half2 cur_max;
float max;
/*
Technically, this would give bad results if there
are 0 values to process since the reduction would
give -inf instead of 0. We do not consider this
to be a reasonable edge case.
*/
DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, __half2>(); }
/*
Updates the running absmax used to calculate params.
Function Arguments :
val : The __half2 value to update the running absmax with.
*/
DS_D_INLINE void update(__half2 val)
{
cur_max = reduce::element<rop::Max>(cur_max, __habs2(val));
}
/*
Function to return calculated quantization params.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
Function Arguments :
tb - Threadblock object. cg::thread_block
warp - Warp object. cg::thread_block_tile<hw_warp_size>
*/
template <int numBits, int threads_per_group>
DS_D_INLINE Params<Type::Symmetric, numBits> get_params(
cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp)
{
const float2 partial_max = conversion::to<float2>(cur_max);
float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);
reduce::partitioned_block<rop::Max, threads_per_group>(tb, warp, max);
Params<Type::Symmetric, numBits> params(max);
return params;
}
};
template <>
class GroupStats<Type::IntegerSymmetric> {
public:
// Symmetric quantization only tracks the maximum absolute value
__half2 cur_max;
/*
Technically, this would give bad results if there
are 0 values to process since the reduction would
give -inf instead of 0. We do not consider this
to be a reasonable edge case.
*/
DS_D_INLINE GroupStats() { cur_max = reduce::init<rop::Max, __half2>(); }
/*
Updates the running absmax used to calculate params.
Function Arguments :
val : The __half2 value to update the running absmax with.
*/
DS_D_INLINE void update(__half2 val)
{
cur_max = reduce::element<rop::Max>(cur_max, __habs2(val));
}
/*
Function to return calculated quantization params.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
Function Arguments :
tb - Threadblock object. cg::thread_block
warp - Warp object. cg::thread_block_tile<hw_warp_size>
*/
template <int numBits, int threads_per_group>
DS_D_INLINE Params<Type::IntegerSymmetric, numBits> get_params(
cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp)
{
const float2 partial_max = conversion::to<float2>(cur_max);
float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);
reduce::partitioned_block<rop::Max, threads_per_group>(tb, warp, max);
Params<Type::IntegerSymmetric, numBits> params(max);
return params;
}
};
template <>
class GroupStats<Type::Asymmetric> {
public:
__half2 cur_max;
__half2 cur_min;
/*
Initialize cur_max to -inf, cur_min to inf since
we are doing a true range analysis.
*/
DS_D_INLINE GroupStats()
{
cur_max = reduce::init<rop::Max, __half2>();
cur_min = reduce::init<rop::Min, __half2>();
}
/*
Updates the running min and max used to calculate params.
Function Arguments :
val : The __half2 value to update the running min and max with.
*/
DS_D_INLINE void update(__half2 val)
{
cur_max = reduce::element<rop::Max>(cur_max, val);
cur_min = reduce::element<rop::Min>(cur_min, val);
}
/*
Function to return calculated quantization params.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
Function Arguments :
tb - Threadblock object. cg::thread_block
warp - Warp object. cg::thread_block_tile<hw_warp_size>
*/
template <int numBits, int threads_per_group>
DS_D_INLINE Params<Type::Asymmetric, numBits> get_params(
cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp)
{
const float2 partial_max = conversion::to<float2>(cur_max);
float max = reduce::element<rop::Max>(partial_max.x, partial_max.y);
const float2 partial_min = conversion::to<float2>(cur_min);
float min = reduce::element<rop::Min>(partial_min.x, partial_min.y);
reduce::partitioned_block<rop::Max, rop::Min, threads_per_group>(tb, warp, max, min);
Params<Type::Asymmetric, numBits> params(max, min);
return params;
}
};
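Outside the kernel, the job of GroupStats is just a per-group reduction. A quick PyTorch sketch of the statistics each specialization gathers (shapes are illustrative):

import torch

groups, elems_per_group = 4, 4096
x = torch.randn(groups, elems_per_group)

# Symmetric / IntegerSymmetric: only the per-group absmax is needed.
group_absmax = x.abs().amax(dim=-1)   # shape: (groups,)

# Asymmetric: a true range analysis, so both extremes are reduced.
group_max = x.amax(dim=-1)
group_min = x.amin(dim=-1)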
/*
Device function that quantizes 16 bytes of __half type input data.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
local_output - Pointer to shared memory to store quantized data. int8_t*
data - Pointer to input data. __half*
Params - Parameters for quantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half* data, Params<qType, numBits> q_params);
/*
Device function that quantizes 16 bytes of __half2 type input data.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
local_output - Pointer to shared memory to store quantized data. int8_t*
data - Pointer to input data. __half2*
Params - Parameters for quantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half2* data, Params<qType, numBits> q_params);
/*
Helper function to do serial reduction on register-file arrays.
Template Arguments :
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
numChunks - Number of 16-byte chunks of input data held in the register file. int
Function Arguments :
local_buffer - Pointer to memory with the input half2 data to be quantized.
*/
template <Type qType, int numChunks>
DS_D_INLINE GroupStats<qType> _local_serial_reduce(__half2* local_buffer);
/*
The main loop of the kernel: quantizes an array of __half2 input data held in local memory when
the quantization parameters have already been computed.
Template Arguments :
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
numBits - Number of bits in quantized element. int : 8 or 4
numChunks - Number of chunks (16 bytes of input data each). int
Function Arguments :
local_buffer - Pointer to memory with the input half2 data to be quantized.
scales - Pointer to output scales.
offsets - Pointer to output offsets.
output_data - Pointer to output data.
elems_per_group - Number of elements to quantize in a group.
q_params - Quantization parameters.
*/
template <int numBits, Type qType, int numChunks, int threads_per_group, int max_threads>
DS_D_INLINE void local_array(cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp,
__half2* local_buffer,
float* __restrict__ scales,
float* __restrict__ offsets,
int8_t* __restrict__ output_data,
const int& elems_per_group,
const int& groups,
Params<qType, numBits> q_params);
/*
The main loop of the kernel: quantizes an array of __half2 input data held in local memory. This
overload computes the quantization parameters for each group before quantizing.
Template Arguments :
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
numBits - Number of bits in quantized element. int : 8 or 4
numChunks - Number of chunks (16 bytes of input data each). int
Function Arguments :
local_buffer - Pointer to memory with the input half2 data to be quantized.
scales - Pointer to output scales.
offsets - Pointer to output offsets.
output_data - Pointer to output data.
elems_per_group - Number of elements to quantize in a group.
*/
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
__device__ void local_array(__half2* local_buffer,
float* __restrict__ scales,
float* __restrict__ offsets,
int8_t* __restrict__ output_data,
const int& elems_per_group,
const int& groups);
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half* data, Params<qType, numBits> q_params)
{
constexpr int32_t elems = 16 / sizeof(__half);
constexpr int32_t num_elems_packed = 8 / numBits;
#pragma unroll
for (int i = 0, oi = 0; i < elems; i += num_elems_packed, oi++) {
if (num_elems_packed == 1) {
// TODO(cmikeh2): refactor to use conversion utils
local_output[i] = q_params.quantize(data[i]);
} else if (num_elems_packed == 2) {
int8_t data_i8_1 = q_params.quantize(data[i]);
int8_t data_i8_2 = q_params.quantize(data[i + 1]);
auto data_i8 = PackedInt4{data_i8_2, data_i8_1};
local_output[oi] = *((int8_t*)(&data_i8));
}
}
}
template <int numBits, Type qType>
DS_D_INLINE void _chunk(int8_t* local_output, const __half2* data, Params<qType, numBits> q_params)
{
const __half* data_cast = reinterpret_cast<const __half*>(data);
_chunk<numBits>(local_output, data_cast, q_params);
}
template <Type qType, int numChunks>
DS_D_INLINE GroupStats<qType> _local_serial_reduce(__half2* local_buffer)
{
GroupStats<qType> stats;
#pragma unroll
for (int i = 0; i < numChunks * h2_per_load; i++) { stats.update(local_buffer[i]); }
return stats;
}
template <Type qType, int numBits, int numChunks, int threads_per_group, int max_threads>
DS_D_INLINE void local_array(cg::thread_block& tb,
cg::thread_block_tile<hw_warp_size>& warp,
__half2* local_buffer,
float* __restrict__ global_params,
int8_t* __restrict__ output_data,
const int& elems_per_group,
const int& groups,
Params<qType, numBits> q_params)
{
constexpr int num_ele_int8 = 8 / numBits;
constexpr int num_int8_out = quantize::h_per_load / num_ele_int8;
// Indexing offsets
const int block_num =
(tb.group_index().x * max_threads / threads_per_group) + tb.thread_index().y;
const int block_offset = block_num * elems_per_group;
const int elem_offset = tb.thread_index().x * quantize::h_per_load;
const int base_offset = (block_offset + elem_offset) / num_ele_int8;
const int stride = tb.size() * quantize::h_per_load / num_ele_int8;
int8_t local_output[num_int8_out];
if (tb.thread_index().x == 0 && block_num < groups) {
q_params.store(
global_params,
(tb.group_index().x * max_threads / threads_per_group) + tb.thread_index().y);
}
#pragma unroll
for (int i = 0; i < numChunks; i++) {
if (elem_offset + i * stride * num_ele_int8 < elems_per_group && block_num < groups) {
quantize::_chunk<numBits, qType>(
local_output, local_buffer + i * quantize::h2_per_load, q_params);
mem_access::store_global<num_int8_out>(output_data + (base_offset + i * stride),
local_output);
}
}
}
template <Type qType,
int numBits,
int numChunks,
int threads_per_group = max_threads,
int max_threads = 256>
__device__ void local_array(__half2* local_buffer,
float* __restrict__ global_params,
int8_t* __restrict__ output_data,
const int& elems_per_group,
const int& groups)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
auto group_stats = _local_serial_reduce<qType, numChunks>(local_buffer);
auto params = group_stats.template get_params<numBits, threads_per_group>(tb, warp);
quantize::local_array<qType, numBits, numChunks, threads_per_group, max_threads>(
tb, warp, local_buffer, global_params, output_data, elems_per_group, groups, params);
}
} // namespace quantize
......@@ -514,11 +514,11 @@ DS_D_INLINE void partitioned_block(cg::thread_block& tb,
float& val)
{
if (num_threads <= hw_warp_size) {
_warp<Op, num_threads>(warp, val);
_warp<Op, num_threads>(warp, &val);
} else {
constexpr int num_warps = num_threads / hw_warp_size;
const int warp_offset = warp.meta_group_rank() & ~(num_warps - 1);
_block<num_warps, Op>(tb, warp, val, warp_offset);
_block<num_warps, Op>(tb, warp, &val, warp_offset);
}
}
......
......@@ -4,7 +4,7 @@
namespace cg = cooperative_groups;
__global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
__global__ void fake_quantize_kernel(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
......@@ -82,7 +82,7 @@ __global__ void quantize_kernel(__half* vals, int group_size, int num_bits)
#endif
}
__global__ void quantize_kernel(float* vals, int group_size, int num_bits)
__global__ void fake_quantize_kernel(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
......@@ -162,34 +162,35 @@ __global__ void quantize_kernel(float* vals, int group_size, int num_bits)
}
template <typename T>
void launch_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
void launch_fake_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(vals, total_count / group_num, num_bits);
fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, total_count / group_num, num_bits);
}
template void launch_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void sr_quantize_kernel(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
template void launch_fake_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_fake_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void sr_fake_quantize_kernel(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
......@@ -321,11 +322,11 @@ __global__ void sr_quantize_kernel(__half* vals,
#endif
}
__global__ void sr_quantize_kernel(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
__global__ void sr_fake_quantize_kernel(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
......@@ -441,11 +442,11 @@ __global__ void sr_quantize_kernel(float* vals,
}
template <typename T>
void launch_sr_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
void launch_sr_fake_quantize_kernel(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
......@@ -453,21 +454,21 @@ void launch_sr_quantize_kernel(T* vals,
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_fake_quantize_kernel(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_fake_quantize_kernel(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
__global__ void fake_quantize_kernel_asym(__half* vals, int group_size, int num_bits)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
......@@ -580,7 +581,7 @@ __global__ void quantize_kernel_asym(__half* vals, int group_size, int num_bits)
#endif
}
__global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
__global__ void fake_quantize_kernel_asym(float* vals, int group_size, int num_bits)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
......@@ -684,35 +685,35 @@ __global__ void quantize_kernel_asym(float* vals, int group_size, int num_bits)
}
template <typename T>
void launch_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
void launch_fake_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 grid_dim(group_num);
dim3 block_dim(1024);
quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
fake_quantize_kernel_asym<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, num_bits);
}
template void launch_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void sr_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
template void launch_fake_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_fake_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
__global__ void sr_fake_quantize_kernel_asym(__half* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)
......@@ -864,11 +865,11 @@ __global__ void sr_quantize_kernel_asym(__half* vals,
#endif
}
__global__ void sr_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
__global__ void sr_fake_quantize_kernel_asym(float* vals,
int token_size,
int token_num,
int num_bits,
std::pair<uint64_t, uint64_t> seed)
{
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<32> g = cg::tiled_partition<32>(b);
......@@ -995,11 +996,11 @@ __global__ void sr_quantize_kernel_asym(float* vals,
}
}
template <typename T>
void launch_sr_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
void launch_sr_fake_quantize_kernel_asym(T* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream)
{
dim3 block_dim(1024);
dim3 grid_dim(group_num);
......@@ -1007,16 +1008,16 @@ void launch_sr_quantize_kernel_asym(T* vals,
uint64_t inc = total_count / grid_dim.x / block_dim.x;
std::pair<uint64_t, uint64_t> seed = Context::Instance().IncrementOffset(inc);
sr_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
sr_fake_quantize_kernel<<<grid_dim, block_dim, 0, stream>>>(
vals, (total_count / group_num) / 4, group_num, num_bits, seed);
}
template void launch_sr_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_fake_quantize_kernel_asym(float* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
template void launch_sr_fake_quantize_kernel_asym(__half* vals,
int total_count,
int group_num,
int num_bits,
cudaStream_t stream);
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <vector>
#include "custom_cuda_layers.h"
#include "quantization.h"
template <typename T>
at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
......@@ -10,8 +10,8 @@ at::Tensor ds_quantize(at::Tensor& vals, int groups, int bits)
int size = 1;
for (auto dim : t_size) size *= dim;
if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
launch_quantize_kernel(
if ((((size / groups) - 1) / 4096 + 1) <= 256) {
launch_fake_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
......@@ -25,7 +25,7 @@ at::Tensor ds_sr_quantize(at::Tensor& vals, int groups, int bits)
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
launch_sr_quantize_kernel(
launch_sr_fake_quantize_kernel(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
......@@ -38,8 +38,8 @@ at::Tensor ds_quantize_asym(at::Tensor& vals, int groups, int bits)
int size = 1;
for (auto dim : t_size) size *= dim;
if ((((size / groups) - 1) / 4096 + 1) <= MAX_REG) {
launch_quantize_kernel_asym(
if ((((size / groups) - 1) / 4096 + 1) <= 256) {
launch_fake_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
......@@ -53,12 +53,65 @@ at::Tensor ds_sr_quantize_asym(at::Tensor& vals, int groups, int bits)
for (auto dim : t_size) size *= dim;
if (((size / groups) / 4 / 1024) <= 256) {
launch_sr_quantize_kernel_asym(
launch_sr_fake_quantize_kernel_asym(
(T*)vals.data_ptr(), size, groups, bits, at::cuda::getCurrentCUDAStream());
}
return vals;
}
#define QUANTIZATION_CASE(TYPE, BITS) \
case TYPE: \
launch_quant<BITS, TYPE>((int8_t*)output.data_ptr(), \
(float*)params.data_ptr(), \
(__half*)input_vals.data_ptr(), \
groups, \
elems_per_group, \
at::cuda::getCurrentCUDAStream()); \
break;
std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
int groups,
int numBits,
quantize::Type quantType)
{
auto dtype = (quantType == quantize::Type::IntegerSymmetric) ? torch::kInt32 : at::kFloat;
auto params_options = at::TensorOptions()
.dtype(dtype)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int param_elems = (quantize::requires_offset(quantType)) ? 2 : 1;
auto params = torch::empty({groups, param_elems}, params_options);
auto output_options = at::TensorOptions()
.dtype(at::kChar)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int K = input_vals.size(1) / (numBits == 8 ? 1 : 2);
int M = input_vals.size(0);
auto output = torch::empty({M, K}, output_options);
const int elems_per_group = at::numel(input_vals) / groups;
if (numBits == 4) {
switch (quantType) {
QUANTIZATION_CASE(quantize::Type::Symmetric, 4)
QUANTIZATION_CASE(quantize::Type::Asymmetric, 4)
QUANTIZATION_CASE(quantize::Type::IntegerSymmetric, 4)
}
} else {
switch (quantType) {
QUANTIZATION_CASE(quantize::Type::Symmetric, 8)
QUANTIZATION_CASE(quantize::Type::Asymmetric, 8)
QUANTIZATION_CASE(quantize::Type::IntegerSymmetric, 8)
}
}
return {output, params};
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
......@@ -74,4 +127,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("ds_sr_quantize_asym_fp16",
&ds_sr_quantize_asym<__half>,
"DeepSpeed Quantize with fp16 (CUDA)");
pybind11::enum_<quantize::Type>(m, "QuantizationType")
.value("Symmetric", quantize::Type::Symmetric)
.value("Asymmetric", quantize::Type::Asymmetric)
.value("IntegerSymmetric", quantize::Type::IntegerSymmetric)
.export_values();
m.def("quantize", &quantize_kernel);
}
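A hedged usage sketch of the new quantize binding from Python, mirroring how the unit test below loads the op. Shapes follow quantize_kernel above: an int8 output of shape (M, K) with two 4-bit values packed per byte when numBits == 4, and params of shape (groups, 1) or (groups, 2):

import torch
from deepspeed.ops import op_builder

quantizer = op_builder.QuantizerBuilder().load()

activations = torch.randn(32, 4096, dtype=torch.float16, device='cuda')
groups, num_bits = 32, 8

output, params = quantizer.quantize(activations, groups, num_bits,
                                    quantizer.Symmetric)
# output: torch.int8 of shape (32, 4096); params: torch.float32 of shape (32, 1)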
#include <cstdio>
#include "custom_cuda_layers.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
#include "reduction_utils.h"
namespace cg = cooperative_groups;
/*
Pure quantization kernel with no fusion.
*/
template <int q_bits,
quantize::Type quant_type,
int UNROLL,
int internal_unroll,
int threads_per_group,
int max_threads>
__global__ void cached_quantization(int8_t* __restrict__ output_data,
float* __restrict__ params,
const __half* __restrict__ input_data,
int groups,
int elems_per_group)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// Indexing offsets
const int block_offset =
(tb.group_index().x * (max_threads / threads_per_group) * elems_per_group) +
(tb.thread_index().y * elems_per_group);
const int elem_offset = tb.thread_index().x * quantize::h_per_load;
const int base_offset = block_offset + elem_offset;
const int stride = tb.size() * quantize::h_per_load;
const __half* input_base = input_data + base_offset;
__half2 local_buffer[UNROLL * internal_unroll * quantize::h2_per_load];
#pragma unroll
for (int i = 0; i < UNROLL; i++) {
// Convenience pointer; should resolve to register indices rather than being materialized.
__half2* iteration_buffer = local_buffer + i * internal_unroll * quantize::h2_per_load;
#pragma unroll
for (int j = 0; j < internal_unroll; j++) {
const int iteration = i * internal_unroll + j;
mem_access::load_global<quantize::granularity>(
iteration_buffer + j * quantize::h2_per_load,
input_base + iteration * stride,
elem_offset + iteration * stride < elems_per_group);
}
}
quantize::
local_array<quant_type, q_bits, UNROLL * internal_unroll, threads_per_group, max_threads>(
local_buffer, params, output_data, elems_per_group, groups);
}
/********* Launcher methods ***********/
int next_pow2(const int val)
{
int rounded_val = val - 1;
rounded_val |= rounded_val >> 1;
rounded_val |= rounded_val >> 2;
rounded_val |= rounded_val >> 4;
rounded_val |= rounded_val >> 8;
return rounded_val + 1;
}
int32_t round_to_32(int32_t raw_value) { return (((raw_value - 1) >> 5) + 1) << 5; }
#define LAUNCH_CACHED_QUANT( \
q_bits, quant_type, unroll_factor, internal_unroll, threads_per_group, max_threads) \
cached_quantization<q_bits, \
quant_type, \
unroll_factor, \
internal_unroll, \
threads_per_group, \
max_threads> \
<<<grid, block, 0, stream>>>(output_data, params, input_data, groups, elems_per_group);
template <int numBits, quantize::Type qType>
void launch_quant(int8_t* output_data,
float* params,
const __half* input_data,
const int groups,
const int elems_per_group,
cudaStream_t stream)
{
constexpr int max_threads = 256;
constexpr int internal_unroll = 2;
const bool is_subblock_schedule = (elems_per_group <= 128) ? true : false;
const int h_per_step = is_subblock_schedule ? quantize::h_per_load
: quantize::h_per_load * internal_unroll;
// Scheduling concern: may be slightly faster for some inputs to assign multiple stages of
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_group + h_per_step - 1) / h_per_step);
const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;
const int warps_per_group = threads_per_group / hw_warp_size;
const int groups_per_block =
is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
const int groups_launch = (groups_per_block + groups - 1) / groups_per_block;
dim3 block(threads_per_group, groups_per_block);
dim3 grid(groups_launch);
const int elems_per_step = threads_per_group * h_per_step;
const int external_unroll = (elems_per_group + elems_per_step - 1) / elems_per_step;
if (is_subblock_schedule) {
// <=128
if (threads_per_group == 1) {
LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 1, max_threads);
} else if (threads_per_group == 2) {
LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 2, max_threads);
} else if (threads_per_group == 4) {
LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 4, max_threads);
} else if (threads_per_group == 8) {
LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 8, max_threads);
} else if (threads_per_group == 16) {
LAUNCH_CACHED_QUANT(numBits, qType, 1, 1, 16, max_threads);
}
} else if (external_unroll == 1) {
// 129 - 4096 elems
// (this can launch with 1-7 warps as well)
LAUNCH_CACHED_QUANT(numBits, qType, 1, internal_unroll, max_threads, max_threads);
} else if (external_unroll == 2) {
// 4097 - 8192 elems
LAUNCH_CACHED_QUANT(numBits, qType, 2, internal_unroll, max_threads, max_threads);
} else if (external_unroll == 3) {
// 8193 - 12288 elems
LAUNCH_CACHED_QUANT(numBits, qType, 3, internal_unroll, max_threads, max_threads);
} else if (external_unroll == 4) {
// 12289 - 16384 elems
LAUNCH_CACHED_QUANT(numBits, qType, 4, internal_unroll, max_threads, max_threads);
}
}
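To see which branch a given problem size takes, the launcher's scheduling arithmetic can be reproduced on the host. A Python sketch with the constants copied from launch_quant above (h_per_load = 8 for __half with 16-byte loads):

def next_pow2(val: int) -> int:
    # Python equivalent of the launcher's bit-twiddling helper (valid for val >= 1).
    return 1 << (val - 1).bit_length()

def quant_schedule(elems_per_group: int, groups: int):
    max_threads, internal_unroll, h_per_load = 256, 2, 8
    is_subblock = elems_per_group <= 128
    h_per_step = h_per_load if is_subblock else h_per_load * internal_unroll
    one_step_threads = next_pow2((elems_per_group + h_per_step - 1) // h_per_step)
    threads_per_group = min(one_step_threads, max_threads)
    groups_per_block = (max_threads // threads_per_group) if is_subblock else 1
    blocks = (groups + groups_per_block - 1) // groups_per_block
    elems_per_step = threads_per_group * h_per_step
    external_unroll = (elems_per_group + elems_per_step - 1) // elems_per_step
    return dict(threads_per_group=threads_per_group,
                groups_per_block=groups_per_block,
                blocks=blocks,
                external_unroll=external_unroll)

# e.g. a 4096-element group fills one 256-thread block with external_unroll == 1,
# while a 64-element group runs 8 threads per group and packs 32 groups per block.
print(quant_schedule(4096, groups=512))
print(quant_schedule(64, groups=512))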
template void launch_quant<8, quantize::Type::Symmetric>(int8_t*,
float*,
const __half*,
int,
int,
cudaStream_t);
template void launch_quant<8, quantize::Type::Asymmetric>(int8_t*,
float*,
const __half*,
int,
int,
cudaStream_t);
template void launch_quant<8, quantize::Type::IntegerSymmetric>(int8_t*,
float*,
const __half*,
int,
int,
cudaStream_t);
template void launch_quant<4, quantize::Type::Symmetric>(int8_t*,
float*,
const __half*,
int,
int,
cudaStream_t);
template void launch_quant<4, quantize::Type::Asymmetric>(int8_t*,
float*,
const __half*,
int,
int,
cudaStream_t);
template void launch_quant<4, quantize::Type::IntegerSymmetric>(int8_t*,
float*,
const __half*,
int,
int,
cudaStream_t);
......@@ -15,7 +15,8 @@ class QuantizerBuilder(CUDAOpBuilder):
def sources(self):
return [
'csrc/quantization/pt_binding.cpp',
'csrc/quantization/quantizer.cu',
'csrc/quantization/fake_quantizer.cu',
'csrc/quantization/quantize.cu',
]
def include_paths(self):
......
......@@ -40,7 +40,7 @@ def run_quant_dequant(inputs, groups, bits):
@pytest.mark.parametrize("groups", [1, 16])
# Test with number of quant groups as 1 and 16.
# Note that we have an explicit boundary for groups: (((size / groups) - 1) / 4096 + 1) <= MAX_REG.
def test_quant_dequant(tensor_shape, groups):
def test_fake_quant_dequant(tensor_shape, groups):
input_tensor = torch.rand((tensor_shape), dtype=torch.float16).cuda()
......
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import pytest
import torch
from deepspeed.ops import op_builder
inference_module = None
torch_minor_version = None
def run_quantize_ds(activations, num_groups, q_bits, is_symmetric_quant):
global inference_module
if inference_module is None:
inference_module = op_builder.QuantizerBuilder().load()
return inference_module.quantize(
activations,
num_groups,
q_bits,
inference_module.Symmetric
if is_symmetric_quant else inference_module.Asymmetric)
def get_q_props(q_bits):
q_range = 2**q_bits
q_min = -(2**(q_bits - 1))
q_max = (2**(q_bits - 1) - 1)
q_min = torch.IntTensor([q_min]).to(device='cuda')
q_max = torch.IntTensor([q_max]).to(device='cuda')
return q_range, q_max, q_min
def get_scale_zero_point(q_bits,
is_symmetric_quant,
max,
min,
absmax,
scales=None,
zero_points=None):
q_range, q_max, q_min = get_q_props(q_bits)
if is_symmetric_quant:
scale = torch.empty_like(absmax)
for i, x in enumerate(absmax):
scale[i] = torch.ones_like(x) if x == 0 else q_range / (2 * x)
zero_point = torch.zeros(scale.shape, dtype=torch.float32, device='cuda')
else:
scale = torch.empty_like(max)
for i, x in enumerate(max):
scale[i] = torch.ones_like(x) if max[i] == min[i] else q_range / (max[i] -
min[i])
zero_point = q_min - (min * scale)
return scale, zero_point
def int4x2to2xint4(int4X2tensor):
high = int4X2tensor >> 4
low = (int4X2tensor << 4) >> 4
return torch.stack((high, low), dim=-1).flatten()
def run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups):
# Reference implementation
# https://pytorch.org/docs/stable/quantization-support.html
activations_ref = activations_ref.reshape(num_groups, -1).to(dtype=torch.float32)
max_abs_activations_ref = torch.amax(torch.abs(activations_ref),
dim=-1).view(num_groups,
-1)
max_activations_ref = torch.amax(activations_ref, dim=-1).view(num_groups, -1)
min_activations_ref = torch.amin(activations_ref, dim=-1).view(num_groups, -1)
_, q_max, q_min = get_q_props(q_bits)
scale, zero_point = get_scale_zero_point(q_bits, is_symmetric_quant, max_activations_ref, min_activations_ref, max_abs_activations_ref)
data_f = activations_ref * scale
if not is_symmetric_quant:
data_f = data_f + zero_point
data_i32 = torch.round(data_f).to(dtype=torch.int32)
data_i32 = torch.minimum(torch.maximum(data_i32,
q_min.expand_as(data_i32)),
q_max.expand_as(data_i32))
data_i8 = data_i32.to(dtype=torch.int8)
scales = (1.0 / scale).reshape(-1, 1)
offsets = zero_point.reshape(-1, 1)
params = torch.cat((scales, offsets), dim=-1)
return data_i8, params
@pytest.mark.inference
@pytest.mark.parametrize("num_groups", [1, 13, 512])
@pytest.mark.parametrize("num_elems",
[8,
16,
32,
64,
128,
256,
4096,
8192,
12288,
16384])
@pytest.mark.parametrize("is_symmetric_quant", [True, False])
@pytest.mark.parametrize("q_bits", [4, 8])
@pytest.mark.parametrize("directed_case", ["all_zeros", None])
def test_float_quantize(num_elems,
num_groups,
is_symmetric_quant,
q_bits,
directed_case):
if directed_case == "all_zeros":
activations_ds = torch.zeros((num_groups,
num_elems),
dtype=torch.float16,
device='cuda')
else:
activations_ds = torch.randn((num_groups,
num_elems),
dtype=torch.float16,
device='cuda')
activations_ref = activations_ds.clone().detach()
ref_out_tensor, ref_params = run_float_quantize(q_bits, is_symmetric_quant, activations_ref, num_groups)
ds_out_tensor, ds_out_params = run_quantize_ds(activations_ds, num_groups, q_bits, is_symmetric_quant)
if (q_bits == 4):
ds_out_tensor = int4x2to2xint4(ds_out_tensor)
# Allow a max difference of 1 to account for rounding differences in the PyTorch reference implementation
assert (torch.all(
torch.lt(torch.abs(ds_out_tensor.flatten() - ref_out_tensor.flatten()),
2)))
if is_symmetric_quant:
assert (torch.allclose(ds_out_params.flatten(), ref_params[:, 0].flatten()))
else:
assert (torch.allclose(ds_out_params[:,
0].flatten(),
ref_params[:,
0].flatten()))
assert (torch.allclose(ds_out_params[:,
1].flatten(),
ref_params[:,
1].flatten(),
atol=5e-5,
rtol=5e-5))
def run_integer_quantize_ds(activations, num_groups, q_bits):
global inference_module
if inference_module is None:
inference_module = op_builder.QuantizerBuilder().load()
return inference_module.quantize(activations,
num_groups,
q_bits,
inference_module.IntegerSymmetric)
def run_integer_quantize(q_bits, activations_ref, num_groups):
activations_ref = activations_ref.reshape(num_groups, -1).to(dtype=torch.float32)
max_abs_activations_ref = torch.amax(torch.abs(activations_ref),
dim=-1).view(num_groups,
-1)
_, q_max, q_min = get_q_props(q_bits)
print(max_abs_activations_ref)
max_abs_activations_ref = (max_abs_activations_ref + 1).to(torch.int8).to(
torch.float32)
print(max_abs_activations_ref)
numerator = activations_ref * q_max
print(numerator.dtype)
denominator = max_abs_activations_ref
data_f = numerator / denominator
data_i32 = torch.round(data_f).to(dtype=torch.int32)
data_i32 = torch.minimum(torch.maximum(data_i32,
q_min.expand_as(data_i32)),
q_max.expand_as(data_i32))
data_i8 = data_i32.to(dtype=torch.int8)
return data_i8, max_abs_activations_ref.to(torch.int32)
@pytest.mark.inference
@pytest.mark.parametrize("num_groups", [1, 2, 4, 8, 16, 32, 64, 512])
@pytest.mark.parametrize("num_elems", [4096, 8192, 12288, 16384])
@pytest.mark.parametrize("q_bits", [4, 8])
def test_integer_quantize(num_elems, num_groups, q_bits):
activations_ds = torch.ones((num_groups,
num_elems),
dtype=torch.float16,
device='cuda') * 0.35
activations_ref = activations_ds.clone().detach()
ref_out_tensor, ref_params = run_integer_quantize(q_bits, activations_ref, num_groups)
ds_out_tensor, ds_out_params = run_integer_quantize_ds(activations_ds, num_groups, q_bits)
if (q_bits == 4):
ds_out_tensor = int4x2to2xint4(ds_out_tensor)
# Allow a max difference of 1 to account for rounding differences in the PyTorch reference implementation
assert (torch.all(
torch.lt(torch.abs(ds_out_tensor.flatten() - ref_out_tensor.flatten()),
2)))
assert (torch.allclose(ds_out_params.flatten(), ref_params.flatten()))