Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Greenplum
DeepSpeed
提交
30c8d8a8
D
DeepSpeed
项目概览
Greenplum
/
DeepSpeed
上一次同步 12 个月
通知
10
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
D
DeepSpeed
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
30c8d8a8
编写于
11月 18, 2022
作者:
C
Connor Holmes
提交者:
GitHub
11月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Initial dequant library implementation (#2521)
上级
0b265326
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
421 addition
and
8 deletion
+421
-8
csrc/includes/dequantization_utils.h
csrc/includes/dequantization_utils.h
+175
-0
csrc/includes/quantization.h
csrc/includes/quantization.h
+9
-0
csrc/includes/quantization_utils.h
csrc/includes/quantization_utils.h
+41
-3
csrc/includes/reduction_utils.h
csrc/includes/reduction_utils.h
+1
-1
csrc/quantization/dequantize.cu
csrc/quantization/dequantize.cu
+52
-0
csrc/quantization/pt_binding.cpp
csrc/quantization/pt_binding.cpp
+43
-0
csrc/quantization/quantize.cu
csrc/quantization/quantize.cu
+4
-3
op_builder/quantizer.py
op_builder/quantizer.py
+1
-0
tests/unit/ops/quantizer/test_dequantize.py
tests/unit/ops/quantizer/test_dequantize.py
+95
-0
tests/unit/ops/quantizer/test_quantize.py
tests/unit/ops/quantizer/test_quantize.py
+0
-1
未找到文件。
csrc/includes/dequantization_utils.h
0 → 100644
浏览文件 @
30c8d8a8
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "quantization.h"
#include "quantization_utils.h"
namespace
cg
=
cooperative_groups
;
#pragma once
namespace
dequantize
{
using
Type
=
quantize
::
Type
;
template
<
Type
qType
,
int
numBits
>
using
Params
=
quantize
::
Params
<
qType
,
numBits
>
;
constexpr
int
granularity
=
quantize
::
granularity
;
using
PackedInt4
=
quantize
::
PackedInt4
;
constexpr
int
h_per_chunk
=
granularity
/
sizeof
(
__half
);
constexpr
int
h2_per_chunk
=
granularity
/
sizeof
(
__half2
);
/*
Device function that reads quantized data from global memory, dequantizes
it, and stores it to global memory.
Template Arguments :
numBits - Number of bits in quantized element. int: 4, 8
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
unroll - Number of load steps to internally unroll int
threads - Number of threads to perform dequant int
Function arguments:
global_output - __half pointer in global memory
data - Quantized data in global memory
global_params - Quantization parameters in global memory
elems_per_group - Number of elements in each quantization group
total_elems - Tensor size (note, does not need to be multiple of elems_per_group)
*/
template
<
int
numBits
,
Type
qType
,
int
unroll
,
int
threads
>
DS_D_INLINE
void
to_global
(
__half
*
global_output
,
const
int8_t
*
data
,
const
float
*
global_params
,
const
int
elems_per_group
,
const
int
total_elems
);
/*
Device function that quantizes 16 bytes of __half type input data.
Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
local_output - Local array to store dequantized data __half* or __half2*
data - Pointer to quantized input data. int8_t*
Params - Parameters for quantization. Params<qType, numBits>
*/
template
<
int
numBits
,
Type
qType
>
DS_D_INLINE
void
chunk
(
__half2
*
local_output
,
const
int8_t
*
data
,
Params
<
qType
,
numBits
>
q_params
);
template
<
int
numBits
,
Type
qType
>
DS_D_INLINE
void
chunk
(
__half
*
local_output
,
const
int8_t
*
data
,
Params
<
qType
,
numBits
>
q_params
);
/**************** Implementations ******************/
template
<
int
numBits
,
Type
qType
>
DS_D_INLINE
void
chunk
(
__half
*
local_output
,
const
int8_t
*
data
,
Params
<
qType
,
numBits
>
q_params
)
{
constexpr
int32_t
num_elems_packed
=
8
/
numBits
;
constexpr
int32_t
iters
=
h_per_chunk
/
num_elems_packed
;
#pragma unroll
for
(
int
i
=
0
;
i
<
iters
;
i
++
)
{
if
constexpr
(
num_elems_packed
==
1
)
{
local_output
[
i
]
=
q_params
.
dequantize
(
data
[
i
]);
}
else
{
auto
accessible_data
=
*
(
PackedInt4
*
)(
&
data
[
i
]);
local_output
[
2
*
i
]
=
q_params
.
dequantize
(
accessible_data
.
low
);
local_output
[
2
*
i
+
1
]
=
q_params
.
dequantize
(
accessible_data
.
high
);
}
}
}
template
<
int
numBits
,
Type
qType
>
DS_D_INLINE
void
chunk
(
__half2
*
local_output
,
const
int8_t
*
data
,
Params
<
qType
,
numBits
>
q_params
)
{
__half
*
local_output_cast
=
reinterpret_cast
<
__half
*>
(
local_output
);
chunk
<
numBits
>
(
local_output_cast
,
data
,
q_params
);
}
template
<
int
numBits
,
Type
qType
,
int
unroll
,
int
threads
>
DS_D_INLINE
void
_to_global
(
__half
*
global_output
,
const
int8_t
*
data
,
const
float
*
global_params
,
const
int
elems_per_group
,
const
int
total_elems
)
{
cg
::
thread_block
tb
=
cg
::
this_thread_block
();
cg
::
thread_block_tile
<
hw_warp_size
>
warp
=
cg
::
tiled_partition
<
hw_warp_size
>
(
tb
);
// Load constants
// TODO(cmikeh2): Refactor into functions?
constexpr
int
load_granularity
=
granularity
*
numBits
/
16
;
constexpr
int
load_step_stride
=
load_granularity
*
threads
;
constexpr
int
load_block_stride
=
load_step_stride
*
unroll
;
// Store constants
constexpr
int
store_step_stride
=
h_per_chunk
*
threads
;
constexpr
int
store_block_stride
=
store_step_stride
*
unroll
;
// Load offsets
const
int
load_block_offset
=
tb
.
group_index
().
x
*
load_block_stride
;
// Note: we can use `load_granularity` since the dtype is `int8_t`.
const
int
load_thread_offset
=
tb
.
thread_index
().
x
*
load_granularity
;
const
int8_t
*
load_base
=
data
+
load_block_offset
+
load_thread_offset
;
// Store offsets
const
int
store_block_offset
=
tb
.
group_index
().
x
*
store_block_stride
;
const
int
store_thread_offset
=
tb
.
thread_index
().
x
*
h_per_chunk
;
const
int
elem_id_base
=
store_block_offset
+
store_thread_offset
;
int8_t
local_load_buffer
[
load_granularity
*
unroll
];
__half
local_dequant_buffer
[
h_per_chunk
*
unroll
];
/*
Note: Splitting this loop in half gave about 3-5% performance increase for reasons that aren't
totally clear to me, so this is a deliberately weird code structure.
*/
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll
;
i
++
)
{
const
int
elem_id_iter
=
elem_id_base
+
i
*
store_step_stride
;
if
(
elem_id_iter
<
total_elems
)
{
mem_access
::
load_global
<
load_granularity
>
(
local_load_buffer
+
i
*
load_granularity
,
load_base
+
i
*
load_step_stride
);
}
}
#pragma unroll
for
(
int
i
=
0
;
i
<
unroll
;
i
++
)
{
const
int
elem_id_iter
=
elem_id_base
+
i
*
store_step_stride
;
if
(
elem_id_iter
<
total_elems
)
{
// TODO(cmikeh2): Can we amortize this division? Perform once on the first iteration and
// use indexing math to do division free interpolation of the successive groups?
const
int
group_index
=
elem_id_iter
/
elems_per_group
;
Params
<
qType
,
numBits
>
q_params
(
global_params
,
group_index
);
chunk
<
numBits
,
qType
>
(
local_dequant_buffer
+
i
*
h_per_chunk
,
local_load_buffer
+
i
*
load_granularity
,
q_params
);
mem_access
::
store_global
<
granularity
>
(
global_output
+
elem_id_iter
,
local_dequant_buffer
+
i
*
h_per_chunk
);
}
}
}
template
<
int
numBits
,
Type
qType
,
int
unroll
,
int
threads
>
DS_D_INLINE
void
to_global
(
__half
*
global_output
,
const
int8_t
*
data
,
const
float
*
global_params
,
const
int
elems_per_group
,
const
int
total_elems
)
{
if
constexpr
(
numBits
==
4
||
numBits
==
8
)
{
_to_global
<
numBits
,
qType
,
unroll
,
threads
>
(
global_output
,
data
,
global_params
,
elems_per_group
,
total_elems
);
}
else
if
constexpr
(
numBits
==
3
)
{
// TODO(cmikeh2): Need this implementation
assert
(
false
);
}
else
{
assert
(
false
);
}
}
}
// namespace dequantize
csrc/includes/quantization.h
浏览文件 @
30c8d8a8
...
@@ -25,6 +25,15 @@ void launch_quant(int8_t* output_data,
...
@@ -25,6 +25,15 @@ void launch_quant(int8_t* output_data,
int
elems_per_group
,
int
elems_per_group
,
cudaStream_t
stream
);
cudaStream_t
stream
);
void
launch_dequantize_kernel
(
__half
*
dequant_data
,
const
int8_t
*
q_data
,
const
float
*
q_params
,
quantize
::
Type
q_type
,
int
num_bits
,
int
elems_per_group
,
int
total_elems
,
cudaStream_t
stream
);
template
<
typename
T
>
template
<
typename
T
>
void
launch_fake_quantize_kernel
(
T
*
vals
,
void
launch_fake_quantize_kernel
(
T
*
vals
,
int
total_count
,
int
total_count
,
...
...
csrc/includes/quantization_utils.h
浏览文件 @
30c8d8a8
#include <cstdio>
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include <cassert>
#include "conversion_utils.h"
#include "conversion_utils.h"
#include "ds_kernel_utils.h"
#include "ds_kernel_utils.h"
#include "memory_access_utils.h"
#include "memory_access_utils.h"
...
@@ -33,7 +37,12 @@ public:
...
@@ -33,7 +37,12 @@ public:
*/
*/
DS_D_INLINE
int8_t
quantize
(
__half
val
);
DS_D_INLINE
int8_t
quantize
(
__half
val
);
DS_D_INLINE
__half
dequantize
(
int8_t
val
);
DS_D_INLINE
void
store
(
float
*
params
,
int
group_index
);
DS_D_INLINE
void
store
(
float
*
params
,
int
group_index
);
// Initialize from memory
DS_D_INLINE
Params
(
const
float
*
params
,
int
group_index
);
};
};
template
<
int
numBits
>
template
<
int
numBits
>
...
@@ -61,11 +70,22 @@ public:
...
@@ -61,11 +70,22 @@ public:
return
(
int8_t
)
data_i32
;
return
(
int8_t
)
data_i32
;
}
}
DS_D_INLINE
__half
dequantize
(
int8_t
val
)
{
const
float
val_deq_f
=
conversion
::
to
<
float
>
(
val
)
*
scale
;
return
conversion
::
to
<
__half
>
(
val_deq_f
);
}
DS_D_INLINE
void
store
(
float
*
params
,
int
group_index
)
DS_D_INLINE
void
store
(
float
*
params
,
int
group_index
)
{
{
const
float
store_scale
=
1
/
scale
;
const
float
store_scale
=
1
/
scale
;
mem_access
::
store_global
<
sizeof
(
float
)
>
(
params
+
group_index
,
&
store_scale
);
mem_access
::
store_global
<
sizeof
(
float
)
>
(
params
+
group_index
,
&
store_scale
);
}
}
DS_D_INLINE
Params
(
const
float
*
params
,
int
group_index
)
{
mem_access
::
load_global
<
sizeof
(
float
)
>
(
&
scale
,
params
+
group_index
);
}
};
};
template
<
int
numBits
>
template
<
int
numBits
>
...
@@ -84,10 +104,14 @@ public:
...
@@ -84,10 +104,14 @@ public:
return
(
int8_t
)
data_i32
;
return
(
int8_t
)
data_i32
;
}
}
DS_D_INLINE
__half
dequantize
(
int8_t
val
)
{
assert
(
false
);
}
DS_D_INLINE
void
store
(
float
*
params
,
int
group_index
)
DS_D_INLINE
void
store
(
float
*
params
,
int
group_index
)
{
{
mem_access
::
store_global
<
sizeof
(
float
)
>
(
params
+
group_index
,
&
scale
);
mem_access
::
store_global
<
sizeof
(
float
)
>
(
params
+
group_index
,
&
scale
);
}
}
DS_D_INLINE
Params
(
const
float
*
params
,
int
group_index
)
{
assert
(
false
);
}
};
};
template
<
int
numBits
>
template
<
int
numBits
>
...
@@ -117,12 +141,26 @@ public:
...
@@ -117,12 +141,26 @@ public:
return
(
int8_t
)
data_i32
;
return
(
int8_t
)
data_i32
;
}
}
DS_D_INLINE
__half
dequantize
(
int8_t
val
)
{
const
float
val_deq_f
=
conversion
::
to
<
float
>
(
val
)
*
scale
+
offset
;
return
conversion
::
to
<
__half
>
(
val_deq_f
);
}
DS_D_INLINE
void
store
(
float
*
params
,
int
group_index
)
DS_D_INLINE
void
store
(
float
*
params
,
int
group_index
)
{
{
// Codegen should turn this into stg.64
const
float
store_scale
=
1
/
scale
;
const
float
store_scale
=
1
/
scale
;
mem_access
::
store_global
<
sizeof
(
float
)
>
(
params
+
2
*
group_index
,
&
store_scale
);
mem_access
::
store_global
<
sizeof
(
float
)
>
(
params
+
2
*
group_index
,
&
store_scale
);
mem_access
::
store_global
<
sizeof
(
float
)
>
(
params
+
2
*
group_index
+
1
,
&
offset
);
mem_access
::
store_global
<
sizeof
(
float
)
>
(
params
+
2
*
group_index
+
1
,
&
offset
);
}
}
DS_D_INLINE
Params
(
const
float
*
params
,
int
group_index
)
{
// Codegen should turn this into ldg.64
mem_access
::
load_global
<
sizeof
(
float
)
>
(
&
scale
,
params
+
2
*
group_index
);
mem_access
::
load_global
<
sizeof
(
float
)
>
(
&
offset
,
params
+
2
*
group_index
+
1
);
}
};
};
/*
/*
...
@@ -293,7 +331,7 @@ Template Arguments :
...
@@ -293,7 +331,7 @@ Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
Function Arguments :
local_output - Pointer to
shared memory to store quantized data.
int8_t*
local_output - Pointer to
local memory to store quantized data.
int8_t*
data - Pointer to input data. __half*
data - Pointer to input data. __half*
Params - Parameters for quantization. Params<qType, numBits>
Params - Parameters for quantization. Params<qType, numBits>
*/
*/
...
@@ -306,7 +344,7 @@ Template Arguments :
...
@@ -306,7 +344,7 @@ Template Arguments :
numBits - Number of bits in quantized element. int : 8 or 4
numBits - Number of bits in quantized element. int : 8 or 4
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
qType - Type of quantization to perform. Type::Symmetric or Type::Asymmetric
Function Arguments :
Function Arguments :
local_output - Pointer to
shared memory to store quantized data.
int8_t*
local_output - Pointer to
local memory to store quantized data.
int8_t*
data - Pointer to input data. __half2*
data - Pointer to input data. __half2*
Params - Parameters for quantization. Params<qType, numBits>
Params - Parameters for quantization. Params<qType, numBits>
*/
*/
...
...
csrc/includes/reduction_utils.h
浏览文件 @
30c8d8a8
...
@@ -263,7 +263,7 @@ DS_D_INLINE __half init<ROpType::Min>()
...
@@ -263,7 +263,7 @@ DS_D_INLINE __half init<ROpType::Min>()
}
}
template
<
>
template
<
>
__half
init
<
ROpType
::
Max
>
()
DS_D_INLINE
__half
init
<
ROpType
::
Max
>
()
{
{
constexpr
__half_raw
neg_inf
=
{
0xFC00
};
constexpr
__half_raw
neg_inf
=
{
0xFC00
};
return
__half
(
neg_inf
);
return
__half
(
neg_inf
);
...
...
csrc/quantization/dequantize.cu
0 → 100644
浏览文件 @
30c8d8a8
/*
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "dequantization_utils.h"
#include "memory_access_utils.h"
namespace
cg
=
cooperative_groups
;
template
<
int
numBits
,
dequantize
::
Type
qType
,
int
unroll
,
int
threads
>
__global__
void
dequantize_kernel
(
__half
*
__restrict__
dequant_data
,
const
int8_t
*
__restrict__
q_data
,
const
float
*
__restrict__
q_params
,
int
elems_per_group
,
int
total_elems
)
{
dequantize
::
to_global
<
numBits
,
qType
,
unroll
,
threads
>
(
dequant_data
,
q_data
,
q_params
,
elems_per_group
,
total_elems
);
}
#define LAUNCH_DEQUANT_KERNEL(num_bits, q_type) \
dequantize_kernel<num_bits, q_type, unroll, threads><<<grid, block, 0, stream>>>( \
dequant_data, q_data, q_params, elems_per_group, total_elems);
void
launch_dequantize_kernel
(
__half
*
dequant_data
,
const
int8_t
*
q_data
,
const
float
*
q_params
,
quantize
::
Type
q_type
,
int
num_bits
,
int
elems_per_group
,
int
total_elems
,
cudaStream_t
stream
)
{
constexpr
int
unroll
=
8
;
constexpr
int
threads
=
512
;
constexpr
int
elems_per_block
=
unroll
*
threads
*
dequantize
::
h_per_chunk
;
const
dim3
block
(
threads
);
const
dim3
grid
((
total_elems
+
elems_per_block
-
1
)
/
elems_per_block
);
// TODO(cmikeh2): It may make sense to tune unroll, there is perf benefit for large
// problem sizes with this large unroll value.
if
(
num_bits
==
8
&&
q_type
==
quantize
::
Type
::
Symmetric
)
{
LAUNCH_DEQUANT_KERNEL
(
8
,
quantize
::
Type
::
Symmetric
);
}
else
if
(
num_bits
==
8
&&
q_type
==
quantize
::
Type
::
Asymmetric
)
{
LAUNCH_DEQUANT_KERNEL
(
8
,
quantize
::
Type
::
Asymmetric
);
}
else
if
(
num_bits
==
4
&&
q_type
==
quantize
::
Type
::
Symmetric
)
{
LAUNCH_DEQUANT_KERNEL
(
4
,
quantize
::
Type
::
Symmetric
);
}
else
if
(
num_bits
==
4
&&
q_type
==
quantize
::
Type
::
Asymmetric
)
{
LAUNCH_DEQUANT_KERNEL
(
4
,
quantize
::
Type
::
Asymmetric
);
}
}
csrc/quantization/pt_binding.cpp
浏览文件 @
30c8d8a8
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include <cassert>
#include <vector>
#include <vector>
#include "quantization.h"
#include "quantization.h"
...
@@ -112,6 +113,47 @@ std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
...
@@ -112,6 +113,47 @@ std::vector<at::Tensor> quantize_kernel(at::Tensor& input_vals,
return
{
output
,
params
};
return
{
output
,
params
};
}
}
int
num_decompressed_elems
(
at
::
Tensor
&
quantized_data
,
int
num_bits
)
{
if
(
num_bits
==
8
)
{
return
quantized_data
.
size
(
-
1
);
}
else
if
(
num_bits
==
4
)
{
return
quantized_data
.
size
(
-
1
)
*
2
;
}
else
{
assert
(
false
);
return
0
;
}
}
at
::
Tensor
dequantize
(
at
::
Tensor
&
quantized_data
,
at
::
Tensor
&
params
,
int
groups
,
int
num_bits
,
quantize
::
Type
quant_type
)
{
auto
output_options
=
at
::
TensorOptions
()
.
dtype
(
torch
::
kFloat16
)
.
layout
(
at
::
kStrided
)
.
device
(
at
::
kCUDA
)
.
requires_grad
(
false
);
const
int
final_dim_size
=
num_decompressed_elems
(
quantized_data
,
num_bits
);
auto
output
=
torch
::
empty
({
quantized_data
.
size
(
0
),
final_dim_size
},
output_options
);
const
int
total_elems
=
quantized_data
.
size
(
0
)
*
final_dim_size
;
const
int
elems_per_group
=
total_elems
/
groups
;
launch_dequantize_kernel
((
__half
*
)
output
.
data_ptr
(),
(
const
int8_t
*
)
quantized_data
.
data_ptr
(),
(
const
float
*
)
params
.
data_ptr
(),
quant_type
,
num_bits
,
elems_per_group
,
total_elems
,
at
::
cuda
::
getCurrentCUDAStream
());
return
output
;
}
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
{
m
.
def
(
"ds_quantize_fp32"
,
&
ds_quantize
<
float
>
,
"DeepSpeed Quantize with fp32 (CUDA)"
);
m
.
def
(
"ds_quantize_fp32"
,
&
ds_quantize
<
float
>
,
"DeepSpeed Quantize with fp32 (CUDA)"
);
...
@@ -133,4 +175,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
...
@@ -133,4 +175,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
.
value
(
"IntegerSymmetric"
,
quantize
::
Type
::
IntegerSymmetric
)
.
value
(
"IntegerSymmetric"
,
quantize
::
Type
::
IntegerSymmetric
)
.
export_values
();
.
export_values
();
m
.
def
(
"quantize"
,
&
quantize_kernel
);
m
.
def
(
"quantize"
,
&
quantize_kernel
);
m
.
def
(
"dequantize"
,
&
dequantize
);
}
}
csrc/quantization/quantize.cu
浏览文件 @
30c8d8a8
#include <cstdio>
/*
#include "custom_cuda_layers.h"
Copyright 2022 The Microsoft DeepSpeed Team
*/
#include "memory_access_utils.h"
#include "memory_access_utils.h"
#include "quantization.h"
#include "quantization.h"
#include "quantization_utils.h"
#include "quantization_utils.h"
...
@@ -99,7 +101,6 @@ void launch_quant(int8_t* output_data,
...
@@ -99,7 +101,6 @@ void launch_quant(int8_t* output_data,
// warp-sized blocks rather than stepping up to 64/96 threads
// warp-sized blocks rather than stepping up to 64/96 threads
const
int
one_step_threads
=
next_pow2
((
elems_per_group
+
h_per_step
-
1
)
/
h_per_step
);
const
int
one_step_threads
=
next_pow2
((
elems_per_group
+
h_per_step
-
1
)
/
h_per_step
);
const
int
threads_per_group
=
(
one_step_threads
<
max_threads
)
?
one_step_threads
:
max_threads
;
const
int
threads_per_group
=
(
one_step_threads
<
max_threads
)
?
one_step_threads
:
max_threads
;
const
int
warps_per_group
=
threads_per_group
/
hw_warp_size
;
const
int
groups_per_block
=
const
int
groups_per_block
=
is_subblock_schedule
?
(
max_threads
+
threads_per_group
-
1
)
/
threads_per_group
:
1
;
is_subblock_schedule
?
(
max_threads
+
threads_per_group
-
1
)
/
threads_per_group
:
1
;
...
...
op_builder/quantizer.py
浏览文件 @
30c8d8a8
...
@@ -17,6 +17,7 @@ class QuantizerBuilder(CUDAOpBuilder):
...
@@ -17,6 +17,7 @@ class QuantizerBuilder(CUDAOpBuilder):
'csrc/quantization/pt_binding.cpp'
,
'csrc/quantization/pt_binding.cpp'
,
'csrc/quantization/fake_quantizer.cu'
,
'csrc/quantization/fake_quantizer.cu'
,
'csrc/quantization/quantize.cu'
,
'csrc/quantization/quantize.cu'
,
'csrc/quantization/dequantize.cu'
,
]
]
def
include_paths
(
self
):
def
include_paths
(
self
):
...
...
tests/unit/ops/quantizer/test_dequantize.py
0 → 100644
浏览文件 @
30c8d8a8
"""
Copyright 2022 The Microsoft DeepSpeed Team
"""
import
pytest
import
torch
from
deepspeed.ops
import
op_builder
quantize_module
=
None
def
int4x2to2xint4
(
int4X2tensor
):
high
=
int4X2tensor
>>
4
low
=
(
int4X2tensor
<<
4
)
>>
4
return
torch
.
stack
((
high
,
low
),
dim
=-
1
).
flatten
()
def
run_quantize
(
data
,
num_groups
,
q_bits
,
is_symmetric_quant
):
global
quantize_module
if
quantize_module
is
None
:
quantize_module
=
op_builder
.
QuantizerBuilder
().
load
()
return
quantize_module
.
quantize
(
data
,
num_groups
,
q_bits
,
quantize_module
.
Symmetric
if
is_symmetric_quant
else
quantize_module
.
Asymmetric
)
def
run_dequantize
(
quantized_data
,
params
,
num_groups
,
q_bits
,
is_symmetric_quant
):
global
quantize_module
if
quantize_module
is
None
:
quantize_module
=
op_builder
.
QuantizerBuilder
().
load
()
return
quantize_module
.
dequantize
(
quantized_data
,
params
,
num_groups
,
q_bits
,
quantize_module
.
Symmetric
if
is_symmetric_quant
else
quantize_module
.
Asymmetric
)
def
run_ref_dequantize
(
quantized_data
,
params
,
num_groups
,
q_bits
,
is_symmetric_quant
):
if
(
q_bits
==
4
):
quantized_data
=
int4x2to2xint4
(
quantized_data
)
quantized_data
=
quantized_data
.
reshape
(
num_groups
,
-
1
).
to
(
torch
.
float32
)
if
is_symmetric_quant
:
return
(
quantized_data
*
params
).
to
(
torch
.
float16
)
else
:
scales
=
params
[:,
0
].
reshape
(
-
1
,
1
)
offsets
=
params
[:,
1
].
reshape
(
-
1
,
1
)
return
(
quantized_data
*
scales
+
offsets
).
to
(
torch
.
float16
)
@
pytest
.
mark
.
inference
@
pytest
.
mark
.
parametrize
(
"num_groups"
,
[
1
,
13
,
512
])
@
pytest
.
mark
.
parametrize
(
"num_elems"
,
[
8
,
16
,
32
,
64
,
128
,
256
,
4096
,
8192
,
12288
,
16384
])
@
pytest
.
mark
.
parametrize
(
"is_symmetric_quant"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"q_bits"
,
[
4
,
8
])
def
test_dequantize
(
num_elems
,
num_groups
,
is_symmetric_quant
,
q_bits
):
activations
=
torch
.
randn
((
num_groups
,
num_elems
),
dtype
=
torch
.
float16
,
device
=
'cuda'
)
quantized_data
,
params
=
run_quantize
(
activations
,
num_groups
,
q_bits
,
is_symmetric_quant
)
ds_dequant
=
run_dequantize
(
quantized_data
,
params
,
num_groups
,
q_bits
,
is_symmetric_quant
)
ref_dequant
=
run_ref_dequantize
(
quantized_data
,
params
,
num_groups
,
q_bits
,
is_symmetric_quant
)
assert
(
torch
.
allclose
(
ds_dequant
.
flatten
(),
ref_dequant
.
flatten
(),
rtol
=
3e-2
,
atol
=
2e-3
))
tests/unit/ops/quantizer/test_quantize.py
浏览文件 @
30c8d8a8
...
@@ -7,7 +7,6 @@ import torch
...
@@ -7,7 +7,6 @@ import torch
from
deepspeed.ops
import
op_builder
from
deepspeed.ops
import
op_builder
inference_module
=
None
inference_module
=
None
torch_minor_version
=
None
def
run_quantize_ds
(
activations
,
num_groups
,
q_bits
,
is_symmetric_quant
):
def
run_quantize_ds
(
activations
,
num_groups
,
q_bits
,
is_symmetric_quant
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录