Unverified commit 30c8d8a8, authored by Connor Holmes, committed by GitHub

Initial dequant library implementation (#2521)

Parent 0b265326
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/

#pragma once

// Note: <cooperative_groups.h> and <cuda_fp16.h> are spelled out here for clarity; they may
// already be pulled in transitively by ds_kernel_utils.h.
#include <cooperative_groups.h>
#include <cuda_fp16.h>

#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "quantization.h"
#include "quantization_utils.h"

namespace cg = cooperative_groups;
namespace dequantize {
using Type = quantize::Type;
template <Type qType, int numBits>
using Params = quantize::Params<qType, numBits>;
constexpr int granularity = quantize::granularity;
using PackedInt4 = quantize::PackedInt4;
constexpr int h_per_chunk = granularity / sizeof(__half);
constexpr int h2_per_chunk = granularity / sizeof(__half2);
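// Added note: quantize::granularity is the per-access width in bytes (currently 16 in the
// quantization library), so each chunk covers h_per_chunk = 8 __half values or
// h2_per_chunk = 4 __half2 values.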
/*
Device function that reads quantized data from global memory, dequantizes
it, and stores it to global memory.
Template Arguments :
numBits - Number of bits in quantized element. int: 4, 8
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
unroll - Number of load steps to internally unroll. int
threads - Number of threads performing the dequantization. int
Function arguments:
global_output - __half pointer in global memory
data - Quantized data in global memory
global_params - Quantization parameters in global memory
elems_per_group - Number of elements in each quantization group
total_elems - Tensor size (note: does not need to be a multiple of elems_per_group)
*/
template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(__half* global_output,
const int8_t* data,
const float* global_params,
const int elems_per_group,
const int total_elems);
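// Added note: for numBits == 4 two elements are packed per byte, so `data` holds
// total_elems / 2 bytes (total_elems bytes for numBits == 8). `global_params` follows the
// quantize::Params layout from quantization_utils.h: one float per group (the dequant
// scale) for Type::Symmetric, and a (scale, offset) pair per group for Type::Asymmetric.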
/*
Device function that dequantizes a chunk of quantized input data into 16 bytes
(h_per_chunk values) of __half output.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization the data was produced with. Type::Symmetric or Type::Asymmetric
Function Arguments :
local_output - Local array to store dequantized data. __half* or __half2*
data - Pointer to quantized input data. int8_t*
Params - Parameters for quantization. Params<qType, numBits>
*/
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params);
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half* local_output, const int8_t* data, Params<qType, numBits> q_params);
/**************** Implementations ******************/
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
constexpr int32_t num_elems_packed = 8 / numBits;
constexpr int32_t iters = h_per_chunk / num_elems_packed;
#pragma unroll
for (int i = 0; i < iters; i++) {
if constexpr (num_elems_packed == 1) {
local_output[i] = q_params.dequantize(data[i]);
} else {
auto accessible_data = *(PackedInt4*)(&data[i]);
local_output[2 * i] = q_params.dequantize(accessible_data.low);
local_output[2 * i + 1] = q_params.dequantize(accessible_data.high);
}
}
}
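// Added note: for numBits == 8 each iteration dequantizes one byte into one __half
// (8 bytes in, 8 values out); for numBits == 4 each iteration unpacks one byte into two
// values via PackedInt4 (4 bytes in, 8 values out), so both paths fill h_per_chunk outputs.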
template <int numBits, Type qType>
DS_D_INLINE void chunk(__half2* local_output, const int8_t* data, Params<qType, numBits> q_params)
{
__half* local_output_cast = reinterpret_cast<__half*>(local_output);
chunk<numBits>(local_output_cast, data, q_params);
}
template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void _to_global(__half* global_output,
const int8_t* data,
const float* global_params,
const int elems_per_group,
const int total_elems)
{
cg::thread_block tb = cg::this_thread_block();
cg::thread_block_tile<hw_warp_size> warp = cg::tiled_partition<hw_warp_size>(tb);
// Load constants
// TODO(cmikeh2): Refactor into functions?
constexpr int load_granularity = granularity * numBits / 16;
constexpr int load_step_stride = load_granularity * threads;
constexpr int load_block_stride = load_step_stride * unroll;
// Store constants
constexpr int store_step_stride = h_per_chunk * threads;
constexpr int store_block_stride = store_step_stride * unroll;
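// Added note: for the 8-bit path with the launcher's threads = 512 and unroll = 8, these
// constants come out to load_granularity = 8 bytes, load_step_stride = 4096 bytes,
// load_block_stride = 32768 bytes, store_step_stride = 4096 __half values and
// store_block_stride = 32768 values (the 4-bit path halves the load-side strides).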
// Load offsets
const int load_block_offset = tb.group_index().x * load_block_stride;
// Note: we can use `load_granularity` since the dtype is `int8_t`.
const int load_thread_offset = tb.thread_index().x * load_granularity;
const int8_t* load_base = data + load_block_offset + load_thread_offset;
// Store offsets
const int store_block_offset = tb.group_index().x * store_block_stride;
const int store_thread_offset = tb.thread_index().x * h_per_chunk;
const int elem_id_base = store_block_offset + store_thread_offset;
int8_t local_load_buffer[load_granularity * unroll];
__half local_dequant_buffer[h_per_chunk * unroll];
/*
Note: Splitting this loop in half gave about a 3-5% performance increase for reasons that
aren't totally clear to me, so this is a deliberately unusual code structure.
*/
#pragma unroll
for (int i = 0; i < unroll; i++) {
const int elem_id_iter = elem_id_base + i * store_step_stride;
if (elem_id_iter < total_elems) {
mem_access::load_global<load_granularity>(local_load_buffer + i * load_granularity,
load_base + i * load_step_stride);
}
}
#pragma unroll
for (int i = 0; i < unroll; i++) {
const int elem_id_iter = elem_id_base + i * store_step_stride;
if (elem_id_iter < total_elems) {
// TODO(cmikeh2): Can we amortize this division? Perform once on the first iteration and
// use indexing math to do division free interpolation of the successive groups?
const int group_index = elem_id_iter / elems_per_group;
Params<qType, numBits> q_params(global_params, group_index);
chunk<numBits, qType>(local_dequant_buffer + i * h_per_chunk,
local_load_buffer + i * load_granularity,
q_params);
mem_access::store_global<granularity>(global_output + elem_id_iter,
local_dequant_buffer + i * h_per_chunk);
}
}
}
template <int numBits, Type qType, int unroll, int threads>
DS_D_INLINE void to_global(__half* global_output,
const int8_t* data,
const float* global_params,
const int elems_per_group,
const int total_elems)
{
if constexpr (numBits == 4 || numBits == 8) {
_to_global<numBits, qType, unroll, threads>(
global_output, data, global_params, elems_per_group, total_elems);
} else if constexpr (numBits == 3) {
// TODO(cmikeh2): Need this implementation
assert(false);
} else {
assert(false);
}
}
} // namespace dequantize
@@ -25,6 +25,15 @@ void launch_quant(int8_t* output_data,
int elems_per_group,
cudaStream_t stream);
void launch_dequantize_kernel(__half* dequant_data,
const int8_t* q_data,
const float* q_params,
quantize::Type q_type,
int num_bits,
int elems_per_group,
int total_elems,
cudaStream_t stream);
template <typename T>
void launch_fake_quantize_kernel(T* vals,
int total_count,
...
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <cassert>
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
@@ -33,7 +37,12 @@ public:
*/
DS_D_INLINE int8_t quantize(__half val);
DS_D_INLINE __half dequantize(int8_t val);
DS_D_INLINE void store(float* params, int group_index);
// Initialize from memory
DS_D_INLINE Params(const float* params, int group_index);
};
template <int numBits>
@@ -61,11 +70,22 @@ public:
return (int8_t)data_i32;
}
DS_D_INLINE __half dequantize(int8_t val)
{
const float val_deq_f = conversion::to<float>(val) * scale;
return conversion::to<__half>(val_deq_f);
}
DS_D_INLINE void store(float* params, int group_index)
{
const float store_scale = 1 / scale;
mem_access::store_global<sizeof(float)>(params + group_index, &store_scale);
}
DS_D_INLINE Params(const float* params, int group_index)
{
mem_access::load_global<sizeof(float)>(&scale, params + group_index);
}
};
template <int numBits>
@@ -84,10 +104,14 @@ public:
return (int8_t)data_i32;
}
DS_D_INLINE __half dequantize(int8_t val) { assert(false); }
DS_D_INLINE void store(float* params, int group_index)
{
mem_access::store_global<sizeof(float)>(params + group_index, &scale);
}
DS_D_INLINE Params(const float* params, int group_index) { assert(false); }
};
template <int numBits>
@@ -117,12 +141,26 @@ public:
return (int8_t)data_i32;
}
DS_D_INLINE __half dequantize(int8_t val)
{
const float val_deq_f = conversion::to<float>(val) * scale + offset;
return conversion::to<__half>(val_deq_f);
}
DS_D_INLINE void store(float* params, int group_index)
{
// Codegen should turn this into stg.64
const float store_scale = 1 / scale;
mem_access::store_global<sizeof(float)>(params + 2 * group_index, &store_scale);
mem_access::store_global<sizeof(float)>(params + 2 * group_index + 1, &offset);
}
DS_D_INLINE Params(const float* params, int group_index)
{
// Codegen should turn this into ldg.64
mem_access::load_global<sizeof(float)>(&scale, params + 2 * group_index);
mem_access::load_global<sizeof(float)>(&offset, params + 2 * group_index + 1);
}
};
/*
@@ -293,7 +331,7 @@ Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
local_output - Pointer to local memory to store quantized data. int8_t*
data - Pointer to input data. __half*
Params - Parameters for quantization. Params<qType, numBits>
*/
@@ -306,7 +344,7 @@ Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
local_output - Pointer to local memory to store quantized data. int8_t*
data - Pointer to input data. __half2*
Params - Parameters for quantization. Params<qType, numBits>
*/
...
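// Worked example of the new dequantize path (illustrative values): if the stored per-group
// parameter is 0.02, the symmetric path computes dequantize(64) = 64 * 0.02 = 1.28
// (returned as __half); the asymmetric path adds the stored offset as well, e.g.
// 64 * 0.02 + 0.5 = 1.78. Since store() writes 1 / scale, the value loaded back by
// Params(const float*, int) is already the multiplier needed here.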
@@ -263,7 +263,7 @@ DS_D_INLINE __half init<ROpType::Min>()
}
template <>
DS_D_INLINE __half init<ROpType::Max>()  // DS_D_INLINE qualifier added by this commit
{
constexpr __half_raw neg_inf = {0xFC00};
return __half(neg_inf);
...
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "dequantization_utils.h"
#include "memory_access_utils.h"
namespace cg = cooperative_groups;
template <int numBits, dequantize::Type qType, int unroll, int threads>
__global__ void dequantize_kernel(__half* __restrict__ dequant_data,
const int8_t* __restrict__ q_data,
const float* __restrict__ q_params,
int elems_per_group,
int total_elems)
{
dequantize::to_global<numBits, qType, unroll, threads>(
dequant_data, q_data, q_params, elems_per_group, total_elems);
}
#define LAUNCH_DEQUANT_KERNEL(num_bits, q_type) \
dequantize_kernel<num_bits, q_type, unroll, threads><<<grid, block, 0, stream>>>( \
dequant_data, q_data, q_params, elems_per_group, total_elems);
void launch_dequantize_kernel(__half* dequant_data,
const int8_t* q_data,
const float* q_params,
quantize::Type q_type,
int num_bits,
int elems_per_group,
int total_elems,
cudaStream_t stream)
{
constexpr int unroll = 8;
constexpr int threads = 512;
constexpr int elems_per_block = unroll * threads * dequantize::h_per_chunk;
const dim3 block(threads);
const dim3 grid((total_elems + elems_per_block - 1) / elems_per_block);
// TODO(cmikeh2): It may make sense to tune unroll, there is perf benefit for large
// problem sizes with this large unroll value.
if (num_bits == 8 && q_type == quantize::Type::Symmetric) {
LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Symmetric);
} else if (num_bits == 8 && q_type == quantize::Type::Asymmetric) {
LAUNCH_DEQUANT_KERNEL(8, quantize::Type::Asymmetric);
} else if (num_bits == 4 && q_type == quantize::Type::Symmetric) {
LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Symmetric);
} else if (num_bits == 4 && q_type == quantize::Type::Asymmetric) {
LAUNCH_DEQUANT_KERNEL(4, quantize::Type::Asymmetric);
}
}
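// Added note on the launch arithmetic: each thread dequantizes unroll * h_per_chunk =
// 8 * 8 = 64 __half values, so a 512-thread block covers elems_per_block = 32768 elements
// and, for example, total_elems = 1,048,576 launches a grid of 32 blocks. Ragged tails are
// handled by the per-chunk bounds check inside dequantize::to_global.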
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include "quantization.h"
@@ -112,6 +113,47 @@ std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
return {output, params};
}
int num_decompressed_elems(at::Tensor& quantized_data, int num_bits)
{
if (num_bits == 8) {
return quantized_data.size(-1);
} else if (num_bits == 4) {
return quantized_data.size(-1) * 2;
} else {
assert(false);
return 0;
}
}
at::Tensor dequantize(at::Tensor& quantized_data,
at::Tensor& params,
int groups,
int num_bits,
quantize::Type quant_type)
{
auto output_options = at::TensorOptions()
.dtype(torch::kFloat16)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
const int final_dim_size = num_decompressed_elems(quantized_data, num_bits);
auto output = torch::empty({quantized_data.size(0), final_dim_size}, output_options);
const int total_elems = quantized_data.size(0) * final_dim_size;
const int elems_per_group = total_elems / groups;
launch_dequantize_kernel((__half*)output.data_ptr(),
(const int8_t*)quantized_data.data_ptr(),
(const float*)params.data_ptr(),
quant_type,
num_bits,
elems_per_group,
total_elems,
at::cuda::getCurrentCUDAStream());
return output;
}
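// Added note: the shape math above means that, for example, a [512, 2048] int8 tensor
// quantized at 4 bits with groups = 512 dequantizes to a [512, 4096] fp16 tensor
// (total_elems = 2,097,152, elems_per_group = 4096). The integer division assumes
// total_elems is evenly divisible by groups.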
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("ds_quantize_fp32", &ds_quantize<float>, "DeepSpeed Quantize with fp32 (CUDA)");
@@ -133,4 +175,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.value("IntegerSymmetric", quantize::Type::IntegerSymmetric)
.export_values();
m.def("quantize", &quantize_kernel);
m.def("dequantize", &dequantize);
}
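# Illustrative Python usage of the new binding (a sketch only; it mirrors the unit test
# further below, and `x_fp16` is a hypothetical [num_groups, num_elems] float16 CUDA tensor):
#
#   from deepspeed.ops import op_builder
#   quantizer = op_builder.QuantizerBuilder().load()
#   q_data, q_params = quantizer.quantize(x_fp16, num_groups, 8, quantizer.Symmetric)
#   x_deq = quantizer.dequantize(q_data, q_params, num_groups, 8, quantizer.Symmetric)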
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "memory_access_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
@@ -99,7 +101,6 @@ void launch_quant(int8_t* output_data,
// warp-sized blocks rather than stepping up to 64/96 threads
const int one_step_threads = next_pow2((elems_per_group + h_per_step - 1) / h_per_step);
const int threads_per_group = (one_step_threads < max_threads) ? one_step_threads : max_threads;
const int warps_per_group = threads_per_group / hw_warp_size;  // removed by this commit
const int groups_per_block =
is_subblock_schedule ? (max_threads + threads_per_group - 1) / threads_per_group : 1;
...
@@ -17,6 +17,7 @@ class QuantizerBuilder(CUDAOpBuilder):
'csrc/quantization/pt_binding.cpp',
'csrc/quantization/fake_quantizer.cu',
'csrc/quantization/quantize.cu',
'csrc/quantization/dequantize.cu',
]
def include_paths(self):
...
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import pytest
import torch
from deepspeed.ops import op_builder
quantize_module = None
def int4x2to2xint4(int4X2tensor):
high = int4X2tensor >> 4
low = (int4X2tensor << 4) >> 4
return torch.stack((high, low), dim=-1).flatten()
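# Added example: each int8 packs two signed 4-bit values; 0x2F (47 as int8) unpacks to a
# high nibble of 2 and a low nibble of -1 (the left-then-right shift sign-extends 0xF),
# and the flattened result lists the high nibble first.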
def run_quantize(data, num_groups, q_bits, is_symmetric_quant):
global quantize_module
if quantize_module is None:
quantize_module = op_builder.QuantizerBuilder().load()
return quantize_module.quantize(
data,
num_groups,
q_bits,
quantize_module.Symmetric if is_symmetric_quant else quantize_module.Asymmetric)
def run_dequantize(quantized_data, params, num_groups, q_bits, is_symmetric_quant):
global quantize_module
if quantize_module is None:
quantize_module = op_builder.QuantizerBuilder().load()
return quantize_module.dequantize(
quantized_data,
params,
num_groups,
q_bits,
quantize_module.Symmetric if is_symmetric_quant else quantize_module.Asymmetric)
def run_ref_dequantize(quantized_data, params, num_groups, q_bits, is_symmetric_quant):
if (q_bits == 4):
quantized_data = int4x2to2xint4(quantized_data)
quantized_data = quantized_data.reshape(num_groups, -1).to(torch.float32)
if is_symmetric_quant:
return (quantized_data * params).to(torch.float16)
else:
scales = params[:, 0].reshape(-1, 1)
offsets = params[:, 1].reshape(-1, 1)
return (quantized_data * scales + offsets).to(torch.float16)
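# Added note: for symmetric quantization `params` holds one float per group (the
# dequantization scale the kernel stored as 1 / scale), so broadcasting it directly is
# sufficient; for asymmetric quantization each group carries a [scale, offset] pair,
# matching the Params store/load layout in quantization_utils.h.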
@pytest.mark.inference
@pytest.mark.parametrize("num_groups", [1, 13, 512])
@pytest.mark.parametrize("num_elems",
[8,
16,
32,
64,
128,
256,
4096,
8192,
12288,
16384])
@pytest.mark.parametrize("is_symmetric_quant", [True, False])
@pytest.mark.parametrize("q_bits", [4, 8])
def test_dequantize(num_elems, num_groups, is_symmetric_quant, q_bits):
activations = torch.randn((num_groups,
num_elems),
dtype=torch.float16,
device='cuda')
quantized_data, params = run_quantize(activations, num_groups, q_bits, is_symmetric_quant)
ds_dequant = run_dequantize(quantized_data,
params,
num_groups,
q_bits,
is_symmetric_quant)
ref_dequant = run_ref_dequantize(quantized_data,
params,
num_groups,
q_bits,
is_symmetric_quant)
assert (torch.allclose(ds_dequant.flatten(),
ref_dequant.flatten(),
rtol=3e-2,
atol=2e-3))
@@ -7,7 +7,6 @@ import torch
from deepspeed.ops import op_builder
inference_module = None
torch_minor_version = None  # removed by this commit
def run_quantize_ds(activations, num_groups, q_bits, is_symmetric_quant):
...