Unverified commit afdc7287
Authored Aug 30, 2022 by Reza Yazdani; committed by GitHub on Aug 30, 2022
Ds-inference Int8 support through ZeroQuant technology (#2217)

Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent: 088212a7

Showing 9 changed files with 719 additions and 358 deletions (+719 -358)
csrc/transformer/inference/csrc/dequantize.cu                  +87    -0
csrc/transformer/inference/csrc/gelu.cu                         +1    -0
csrc/transformer/inference/csrc/pt_binding.cpp                +134   -81
csrc/transformer/inference/includes/custom_cuda_layers.h        +8    -0
deepspeed/__init__.py                                          +17    -3
deepspeed/inference/engine.py                                  +43   -31
deepspeed/module_inject/load_checkpoint.py                    +122   -54
deepspeed/module_inject/replace_module.py                     +193   -71
deepspeed/ops/transformer/inference/transformer_inference.py  +114  -118
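The commit wires ZeroQuant-style int8 weight quantization end to end through the inference engine. As a minimal usage sketch (not part of the diff; the model name and mp_size are illustrative assumptions), the new path is reached by passing dtype=torch.int8 to init_inference:

# Minimal sketch, assuming an HF causal-LM and a single GPU; dtype=torch.int8
# is what routes the injected kernels through the new quantized GEMM path.
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # illustrative model
engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.int8,
                                  replace_with_kernel_inject=True)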
csrc/transformer/inference/csrc/dequantize.cu

@@ -108,3 +108,90 @@ template void launch_dequantize<__half>(__half*,
                                         unsigned,
                                         unsigned,
                                         cudaStream_t);
+
+__global__ void dequantize_kernel(float* output,
+                                  const int8_t* input,
+                                  const float* qscale,
+                                  int hidden_dim,
+                                  unsigned merge_hidden,
+                                  int cnt)
+{
+}
+
+__global__ void dequantize_kernel(__half* output,
+                                  const int8_t* input,
+                                  const float* qscale,
+                                  unsigned hidden_dim,
+                                  unsigned merge_hidden,
+                                  int cnt)
+{
+    unsigned bid = blockIdx.x * gridDim.y + blockIdx.y;
+    unsigned tid = threadIdx.x;
+
+    float local_scale = qscale[blockIdx.x];
+
+    const float* input_cast = reinterpret_cast<const float*>(input);
+    float2* output_cast = reinterpret_cast<float2*>(output);
+
+    input_cast += bid * merge_hidden;
+    output_cast += bid * merge_hidden;
+
+    for (int c = 0; c < cnt; c++) {
+        if (tid < merge_hidden) {
+            float q = input_cast[tid];
+            int8_t* q_int8 = (int8_t*)&q;
+
+            float2 q_f;
+            __half* q_h = (__half*)&q_f;
+
+            q_h[0] = __float2half(local_scale * (float)q_int8[0]);
+            q_h[1] = __float2half(local_scale * (float)q_int8[1]);
+            q_h[2] = __float2half(local_scale * (float)q_int8[2]);
+            q_h[3] = __float2half(local_scale * (float)q_int8[3]);
+            // q_h[4] = __float2half(local_scale * (float)q_int8[4]);
+            // q_h[5] = __float2half(local_scale * (float)q_int8[5]);
+            // q_h[6] = __float2half(local_scale * (float)q_int8[6]);
+            // q_h[7] = __float2half(local_scale * (float)q_int8[7]);
+            output_cast[tid] = q_f;
+            tid += blockDim.x;
+        }
+    }
+}
+
+template <typename T>
+void launch_dequantize(T* output,
+                       const int8_t* input,
+                       const float* qscale,
+                       unsigned output_size,
+                       unsigned hidden_dim,
+                       unsigned groups,
+                       cudaStream_t stream)
+{
+    unsigned threads = 1024;
+    hidden_dim /= 4;
+    unsigned hid_cnt = threads / hidden_dim;
+    unsigned thd_cnt = (hidden_dim - 1) / threads + 1;
+    hid_cnt = hid_cnt > 0 ? hid_cnt : 1;
+
+    unsigned blocks = output_size / hid_cnt / groups;
+    dim3 block_dims(threads);
+    dim3 grid_dims(groups, blocks);
+
+    dequantize_kernel<<<grid_dims, block_dims, 0, stream>>>(
+        output, input, qscale, hidden_dim, hid_cnt * hidden_dim, thd_cnt);
+}
+
+template void launch_dequantize<float>(float*,
+                                       const int8_t*,
+                                       const float*,
+                                       unsigned,
+                                       unsigned,
+                                       unsigned,
+                                       cudaStream_t);
+template void launch_dequantize<__half>(__half*,
+                                        const int8_t*,
+                                        const float*,
+                                        unsigned,
+                                        unsigned,
+                                        unsigned,
+                                        cudaStream_t);
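For reference, a host-side sketch of the per-group symmetric dequantization the kernel performs (an assumption for illustration, using a row-major [groups, elems] layout; not part of the diff):

import torch

def dequantize_ref(q_int8: torch.Tensor, qscale: torch.Tensor) -> torch.Tensor:
    # q_int8: [groups, elems] int8 weights; qscale: [groups] fp32 per-group scales
    return (q_int8.float() * qscale.view(-1, 1)).half()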
csrc/transformer/inference/csrc/gelu.cu

 #include "custom_cuda_layers.h"
+namespace cg = cooperative_groups;

 #define MAX_CAP 4
 #define MAX_SEQ 2048
csrc/transformer/inference/csrc/pt_binding.cpp

@@ -558,15 +558,55 @@ void ds_layernorm_internal(T* workspace,
                            Context::Instance().GetCurrentStream());
 }

+template <typename T>
+void quantized_gemm(at::Tensor& output,
+                    T* input,
+                    at::Tensor& weight,
+                    at::Tensor& qscale,
+                    int groups,
+                    int bsz)
+{
+    auto weight16 = at::empty({weight.size(0), weight.size(1)}, output.options());
+
+    launch_dequantize((T*)weight16.data_ptr(),
+                      (int8_t*)weight.data_ptr(),
+                      (float*)qscale.data_ptr(),
+                      weight.size(0),
+                      weight.size(1),
+                      groups,
+                      Context::Instance().GetCurrentStream());
+
+    float alpha = (T)1.0;
+    float gemm_beta = (T)0.0;
+    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+                   CUBLAS_OP_T,
+                   CUBLAS_OP_N,
+                   weight.size(0),
+                   bsz,
+                   weight.size(1),
+                   &alpha,
+                   &gemm_beta,
+                   (T*)weight16.data_ptr(),
+                   (T*)input,
+                   (T*)output.data_ptr(),
+#ifdef __HIP_PLATFORM_HCC__
+                   rocblas_gemm_algo_standard);
+#else
+                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+#endif
+}
+
 template <typename T>
 at::Tensor qkv_unfused_cublas(at::Tensor& output,
                               at::Tensor& input,
                               at::Tensor& weight,
+                              at::Tensor& q_scale,
                               at::Tensor& bias,
                               at::Tensor& gamma,
                               at::Tensor& beta,
                               const float epsilon,
-                              bool add_bias)
+                              bool add_bias,
+                              bool q_int8)
 {
     int bsz = input.size(0) * input.size(1);
     T* workspace = (T*)Context::Instance().GetWorkSpace();
@@ -574,48 +614,55 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
     ds_layernorm_internal<T>(workspace, input, gamma, beta, epsilon);
     // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());

-    float alpha = (T)1.0;
-    float gemm_beta = (T)0.0;
-    cublasSetStream(Context::Instance().GetCublasHandle(),
-                    Context::Instance().GetCurrentStream());
-    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
-                   CUBLAS_OP_N,
-                   CUBLAS_OP_N,
-                   weight.size(1),
-                   bsz,
-                   input.size(2),
-                   &alpha,
-                   &gemm_beta,
-                   (T*)weight.data_ptr(),
-                   workspace,
-                   (T*)output.data_ptr(),
+    if (q_int8) {
+        quantized_gemm<T>(output, workspace, weight, q_scale, q_scale.size(0), bsz);
+    } else {
+        float alpha = (T)1.0;
+        float gemm_beta = (T)0.0;
+        cublasSetStream(Context::Instance().GetCublasHandle(),
+                        Context::Instance().GetCurrentStream());
+        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+                       CUBLAS_OP_N,
+                       CUBLAS_OP_N,
+                       weight.size(1),
+                       bsz,
+                       input.size(2),
+                       &alpha,
+                       &gemm_beta,
+                       (T*)weight.data_ptr(),
+                       workspace,
+                       (T*)output.data_ptr(),
 #ifdef __HIP_PLATFORM_HCC__
-                   rocblas_gemm_algo_standard);
+                       rocblas_gemm_algo_standard);
 #else
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif
+    }
     if (add_bias)
         launch_bias_add((T*)output.data_ptr(),
                         (T*)bias.data_ptr(),
-                        weight.size(1),
+                        q_int8 ? weight.size(0) : weight.size(1),
                         bsz,
                         Context::Instance().GetCurrentStream());
     return torch::from_blob(workspace, input.sizes(), input.options());
 }

 template <typename T>
 std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
                                     at::Tensor& weight,
                                     at::Tensor& q_scale,
                                     at::Tensor& bias,
                                     at::Tensor& gamma,
                                     at::Tensor& beta,
                                     const float epsilon,
                                     bool add_bias,
-                                    unsigned num_layers)
+                                    unsigned num_layers,
+                                    bool q_int8)
 {
     int bsz = input.size(0) * input.size(1);
     T* workspace = (T*)Context::Instance().GetWorkSpace();
+    int out_size = q_int8 ? weight.size(0) : weight.size(1);
     if (!workspace) {
         cublasSetStream(Context::Instance().GetCublasHandle(),
                         Context::Instance().GetCurrentStream());
@@ -628,9 +675,9 @@ std::vector<at::Tensor> ds_qkv_gemm(at::Tensor& input,
                        .device(at::kCUDA)
                        .requires_grad(false);

-    auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);
-    auto inp_norm = qkv_unfused_cublas<T>(output, input, weight, bias, gamma, beta, epsilon, add_bias);
+    auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
+    auto inp_norm = qkv_unfused_cublas<T>(
+        output, input, weight, q_scale, bias, gamma, beta, epsilon, add_bias, q_int8);

     return {output, inp_norm};
 }
@@ -654,20 +701,18 @@ void quantized_gemm(at::Tensor& output,
     launch_dequantize((T*)weight16.data_ptr(),
                       (int8_t*)weight.data_ptr(),
                       (float*)qscale.data_ptr(),
-                      weight.size(1),
                       weight.size(0),
                       weight.size(1),
                       groups,
-                      merge_count,
                       Context::Instance().GetCurrentStream());

     cublasSetStream(Context::Instance().GetCublasHandle(),
                     Context::Instance().GetCurrentStream());
     float alpha = (T)1.0;
     float gemm_beta = (T)0.0;
     cublas_gemm_ex(Context::Instance().GetCublasHandle(),
-                   CUBLAS_OP_T,
+                   CUBLAS_OP_N,
                    CUBLAS_OP_N,
-                   weight.size(1),
+                   weight.size(0),
                    bsz,
                    input.size(2),
                    &alpha,
@@ -796,7 +841,11 @@ at::Tensor ds_linear_layer_int8(at::Tensor& input,
 }

 template <typename T>
-at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op)
+at::Tensor ds_vector_matmul(at::Tensor& input,
+                            at::Tensor& weight,
+                            bool async_op,
+                            at::Tensor& q_scale,
+                            bool q_int8)
 {
     auto input_cont = input.contiguous();
     auto options = at::TensorOptions()
@@ -805,28 +854,33 @@ at::Tensor ds_vector_matmul(at::Tensor& input, at::Tensor& weight, bool async_op
                        .device(at::kCUDA)
                        .requires_grad(false);

-    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
+    int out_size = q_int8 ? weight.size(0) : weight.size(1);
     int bsz = input_cont.size(0) * input_cont.size(1);
-    float alpha = (T)1.0;
-    float gemm_beta = (T)0.0;
-    cublasSetStream(Context::Instance().GetCublasHandle(),
-                    Context::Instance().GetCurrentStream(async_op));
-    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
-                   CUBLAS_OP_N,
-                   CUBLAS_OP_N,
-                   weight.size(1),
-                   bsz,
-                   input_cont.size(2),
-                   &alpha,
-                   &gemm_beta,
-                   (T*)weight.data_ptr(),
-                   (T*)input_cont.data_ptr(),
-                   (T*)output.data_ptr(),
+
+    auto output = at::empty({input_cont.size(0), input_cont.size(1), out_size}, options);
+    if (q_int8) {
+        quantized_gemm<T>(output, (T*)input_cont.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
+    } else {
+        float alpha = (T)1.0;
+        float gemm_beta = (T)0.0;
+        cublasSetStream(Context::Instance().GetCublasHandle(),
+                        Context::Instance().GetCurrentStream(async_op));
+        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+                       CUBLAS_OP_N,
+                       CUBLAS_OP_N,
+                       weight.size(1),
+                       bsz,
+                       input_cont.size(2),
+                       &alpha,
+                       &gemm_beta,
+                       (T*)weight.data_ptr(),
+                       (T*)input_cont.data_ptr(),
+                       (T*)output.data_ptr(),
 #ifdef __HIP_PLATFORM_HCC__
-                   rocblas_gemm_algo_standard);
+                       rocblas_gemm_algo_standard);
 #else
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif
+    }
     return output;
 }
@@ -862,6 +916,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
                               const float epsilon,
                               bool preLayerNorm,
                               bool mlp_after_attn,
+                              at::Tensor& q_scale,
+                              bool q_int8,
                               ActivationFuncType act_func_type)
 {
     int bsz = input.size(0) * input.size(1);
@@ -881,36 +937,40 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
                           mlp_after_attn,
                           Context::Instance().GetCurrentStream());
-    float alpha = (T)1.0;
-    float gemm_beta = (T)0.0;
-    cublasSetStream(Context::Instance().GetCublasHandle(),
-                    Context::Instance().GetCurrentStream());
-    cublas_gemm_ex(Context::Instance().GetCublasHandle(),
-                   CUBLAS_OP_N,
-                   CUBLAS_OP_N,
-                   weight.size(1),
-                   bsz,
-                   input.size(2),
-                   &alpha,
-                   &gemm_beta,
-                   (T*)weight.data_ptr(),
-                   (T*)inp_norm.data_ptr(),
-                   (T*)output.data_ptr(),
+    if (q_int8) {
+        quantized_gemm<T>(output, (T*)inp_norm.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
+    } else {
+        float alpha = (T)1.0;
+        float gemm_beta = (T)0.0;
+        cublasSetStream(Context::Instance().GetCublasHandle(),
+                        Context::Instance().GetCurrentStream());
+        cublas_gemm_ex(Context::Instance().GetCublasHandle(),
+                       CUBLAS_OP_N,
+                       CUBLAS_OP_N,
+                       weight.size(1),
+                       bsz,
+                       input.size(2),
+                       &alpha,
+                       &gemm_beta,
+                       (T*)weight.data_ptr(),
+                       (T*)inp_norm.data_ptr(),
+                       (T*)output.data_ptr(),
 #ifdef __HIP_PLATFORM_HCC__
-                   rocblas_gemm_algo_standard);
+                       rocblas_gemm_algo_standard);
 #else
-                   CUBLAS_GEMM_DEFAULT_TENSOR_OP);
+                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif
+    }
     if (act_func_type == ActivationFuncType::GELU) {
         launch_bias_gelu((T*)output.data_ptr(),
                          (T*)bias.data_ptr(),
-                         weight.size(1),
+                         q_int8 ? weight.size(0) : weight.size(1),
                          bsz,
                          Context::Instance().GetCurrentStream());
     } else if (act_func_type == ActivationFuncType::ReLU) {
         launch_bias_relu((T*)output.data_ptr(),
                          (T*)bias.data_ptr(),
-                         weight.size(1),
+                         q_int8 ? weight.size(0) : weight.size(1),
                          bsz,
                          Context::Instance().GetCurrentStream());
     }
@@ -929,6 +989,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                     const float epsilon,
                                     bool preLayerNorm,
                                     bool mlp_after_attn,
+                                    at::Tensor& q_scale,
+                                    bool q_int8,
                                     int activation_type)
 {
     auto input_cont = input.contiguous();
@@ -938,7 +1000,10 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                        .device(at::kCUDA)
                        .requires_grad(false);

-    auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
+    int out_size = q_int8 ? weight.size(0) : weight.size(1);
+    auto output = at::from_blob((T*)Context::Instance().GetWorkSpace(),
+                                {input_cont.size(0), input_cont.size(1), out_size},
+                                options);
     int bsz = input_cont.size(0) * input_cont.size(1);

     auto act_func_type = static_cast<ActivationFuncType>(activation_type);
@@ -953,6 +1018,8 @@ std::vector<at::Tensor> ds_mlp_gemm(at::Tensor& input,
                                         epsilon,
                                         preLayerNorm,
                                         mlp_after_attn,
+                                        q_scale,
+                                        q_int8,
                                         act_func_type);
     return {output, res_add};
@@ -984,20 +1051,6 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
     auto inp_norm = at::empty_like(input_cont);
     auto residual_add = (preLayerNorm ? at::empty_like(input_cont) : inp_norm);
-    // computing the blocking across K dimension
-    // launch_residual_layer_norm((T*)inp_norm.data_ptr(),
-    //                            (T*)residual_add.data_ptr(),
-    //                            (T*)input_cont.data_ptr(),
-    //                            (T*)residual.data_ptr(),
-    //                            (T*)input_bias.data_ptr(),
-    //                            (T*)gamma.data_ptr(),
-    //                            (T*)beta.data_ptr(),
-    //                            epsilon,
-    //                            bsz,
-    //                            input_cont.size(2),
-    //                            preLayerNorm,
-    //                            Context::Instance().GetCurrentStream());

     quantized_gemm<T>(output, inp_norm, weight, q_scale, groups, 0);
     launch_bias_gelu((T*)output.data_ptr(),
                      (T*)bias.data_ptr(),
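The new quantized_gemm helper dequantizes the int8 weight into a temporary fp16 buffer and then runs a regular cuBLAS GEMM against the activations. A PyTorch-level sketch of that dataflow (an assumption for illustration, with per-group scales taken over equal slices of the weight; not part of the diff):

import torch

def quantized_gemm_ref(x, w_q, qscale, groups):
    # x: [tokens, in] float activations; w_q: [out, in] int8; qscale: [groups] float scales
    w16 = (w_q.float().reshape(groups, -1) * qscale.view(groups, 1)).reshape(w_q.shape)
    return x @ w16.t()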
csrc/transformer/inference/includes/custom_cuda_layers.h

@@ -109,6 +109,14 @@ void launch_dequantize(T* output,
                        cudaStream_t stream);
+template <typename T>
+void launch_dequantize(T* output,
+                       const int8_t* input,
+                       const float* qscale,
+                       unsigned output_size,
+                       unsigned hidden_dim,
+                       unsigned groups,
+                       cudaStream_t stream);
 template <typename T>
 void launch_gptj_residual_add(T* input,
                               T* output,
                               T* attn,
deepspeed/__init__.py

@@ -242,7 +242,8 @@ def init_inference(model,
                    moe_type='standard',
                    args=None,
                    enable_cuda_graph=False,
-                   save_mp_checkpoint_path=None):
+                   save_mp_checkpoint_path=None,
+                   base_dir=""):
     """Initialize the DeepSpeed InferenceEngine.

     Arguments:
@@ -278,7 +279,19 @@ def init_inference(model,
             of groups used in quantization. A tuple is passed in if we want to mention that there is extra-grouping
             for the MLP part of a Transformer layer (e.g. (True, 8) shows we quantize the model using 8 groups for
             all the network except the MLP part that we use 8 extra grouping).
-        replace_with_kernel_inject: If set we inject kernel as we initialize the inference-engine
+        replace_with_kernel_inject: this flag needs to be set to true to inject inference kernels for models such as
+            Bert, GPT2, GPT-Neo and GPT-J. Otherwise, the injection_dict provides the names of two linear layers as
+            a tuple: (attention_output projection, transformer output projection)
+        return_tuple: Specify whether or not the transformer layers need to return a tuple or a Tensor. It is set to
+            True by default (returning a tuple).
+        ep_size: The expert-parallelism size which is used for partitioning the experts across the GPUs in the
+            expert-parallel group.
+        moe: Specify if the type of Transformer is MoE. It is set to False by default.
+        moe_experts: The global number of experts used in an MoE layer.
+        moe_type: Specify the type of MoE layer. We have two types of MoE layer: 'Standard' and 'Residual'. It is
+            set to the 'Standard' type by default.
+        args: All the arguments used for launching the inference api that can be useful at the inference-engine for
+            injecting the optimizations.
+        enable_cuda_graph: use this flag for capturing the CUDA-Graph of the inference ops, so that it can run
+            faster using the graph replay method. It is set to False by default.
+        save_mp_checkpoint_path: The path for which we want to save the loaded model with a checkpoint. This feature
+            is used for adjusting the parallelism degree to help alleviate the model loading overhead. It does not
+            save any new checkpoint if no path is passed.
+        base_dir: This shows the root directory under which all the checkpoint files exist. This can be passed
+            through the json config too.

     Returns:
         A deepspeed.InferenceEngine wrapped model.
@@ -309,6 +322,7 @@ def init_inference(model,
                              moe_type,
                              args,
                              enable_cuda_graph,
-                             save_mp_checkpoint_path)
+                             save_mp_checkpoint_path,
+                             base_dir)

     return engine
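As a usage sketch of the two new arguments (the model and the path are illustrative assumptions, not from the diff): a first run can save a tp-sharded copy of the loaded model, and later runs can point the generated json config (or base_dir) at it to cut the model-loading overhead.

import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")  # illustrative model
engine = deepspeed.init_inference(model,
                                  mp_size=2,
                                  dtype=torch.int8,
                                  replace_with_kernel_inject=True,
                                  save_mp_checkpoint_path="/tmp/bloom-int8-tp2",  # hypothetical path
                                  base_dir="")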
deepspeed/inference/engine.py

@@ -51,7 +51,8 @@ class InferenceEngine(Module):
                  moe_type='standard',
                  config=None,
                  enable_cuda_graph=False,
-                 save_mp_checkpoint_path=None):
+                 save_mp_checkpoint_path=None,
+                 base_dir=""):
         """
         Args:
             model: torch.nn.Module
@@ -100,6 +101,9 @@ class InferenceEngine(Module):
         self.checkpoint_engine = TorchCheckpointEngine()
         self._init_quantization_setting(quantization_setting)

+        # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
+        self.remove_mask_prepare_for_bloom()
+
         if enable_cuda_graph:
             assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
                 "If you want to use cuda graph, please upgrade torch to at least v1.10"
@@ -135,7 +139,8 @@ class InferenceEngine(Module):
                 moe_type,
                 training_mp_size,
                 self.checkpoint if replace_with_kernel_inject else None,
-                save_mp_checkpoint_path=save_mp_checkpoint_path)
+                save_mp_checkpoint_path=save_mp_checkpoint_path,
+                base_dir=base_dir)
         elif replace_method == 'auto':
             self._apply_injection_policy(
                 return_tuple=return_tuple,
@@ -145,7 +150,8 @@ class InferenceEngine(Module):
                 moe_type=moe_type,
                 training_mp_size=training_mp_size,
                 checkpoint_dir=self.checkpoint if replace_with_kernel_inject else None,
-                save_mp_checkpoint_path=save_mp_checkpoint_path)
+                save_mp_checkpoint_path=save_mp_checkpoint_path,
+                base_dir=base_dir)

         device = torch.cuda.current_device()
         self.module.to(device)
@@ -165,6 +171,11 @@ class InferenceEngine(Module):
         self.config = getattr(self.module, 'config', None) if config is None else config
         self.generate = getattr(self.module, 'generate', None)

+    def remove_mask_prepare_for_bloom(self):
+        if hasattr(self.module, 'transformer'):
+            if hasattr(self.module.transformer, '_prepare_attn_mask'):
+                self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
+
     def _create_model_parallel_group(self):
         # Call the init process
         if InferenceEngine.inference_mp_group is None:
@@ -326,36 +337,37 @@ class InferenceEngine(Module):
                                 moe_type='standard',
                                 training_mp_size=1,
                                 checkpoint_dir=None,
-                                save_mp_checkpoint_path=False):
+                                save_mp_checkpoint_path=False,
+                                base_dir=""):
         checkpoint = SDLoaderFactory.get_sd_loader_json(
             checkpoint_dir,
             self.checkpoint_engine) if checkpoint_dir is not None else None
-        replace_transformer_layer(client_module,
-                                  self.module,
-                                  triangular_masking=self.triangular_masking,
-                                  policy=injection_policy,
-                                  mp_size=self.mp_world_size,
-                                  mp_group=self.mp_group,
-                                  ep_group=self.ep_group,
-                                  expert_mp_group=self.expert_mp_group,
-                                  config=self.config,
-                                  fp16=(self.dtype == torch.half),
-                                  training=False,
-                                  return_tuple=return_tuple,
-                                  quantize=(self.dtype == torch.int8),
-                                  quantize_settings=(self.quantization_scales,
-                                                     self.quantize_merge_count,
-                                                     self.mlp_extra_grouping,
-                                                     self.quantize_groups),
-                                  replace_with_kernel_inject=replace_with_kernel_inject,
-                                  moe=moe,
-                                  moe_experts=moe_experts,
-                                  moe_type=moe_type,
-                                  training_mp_size=training_mp_size,
-                                  checkpoint_dict=checkpoint,
-                                  save_mp_checkpoint_path=save_mp_checkpoint_path,
-                                  )
+        replace_transformer_layer(client_module,
+                                  self.module,
+                                  triangular_masking=self.triangular_masking,
+                                  policy=injection_policy,
+                                  mp_size=self.mp_world_size,
+                                  mp_group=self.mp_group,
+                                  ep_group=self.ep_group,
+                                  expert_mp_group=self.expert_mp_group,
+                                  config=self.config,
+                                  fp16=(self.dtype == torch.half) or (self.dtype == torch.int8),
+                                  training=False,
+                                  return_tuple=return_tuple,
+                                  quantize=(self.dtype == torch.int8),
+                                  quantize_settings=(self.quantization_scales,
+                                                     self.quantize_merge_count,
+                                                     self.mlp_extra_grouping,
+                                                     self.quantize_groups),
+                                  replace_with_kernel_inject=replace_with_kernel_inject,
+                                  moe=moe,
+                                  moe_experts=moe_experts,
+                                  moe_type=moe_type,
+                                  training_mp_size=training_mp_size,
+                                  checkpoint_dict=checkpoint,
+                                  save_mp_checkpoint_path=save_mp_checkpoint_path,
+                                  base_dir=base_dir)

     def _get_all_ckpt_names(self, checkpoints_path, tag):
         ckpt_file_pattern = self._get_ckpt_name(checkpoints_path,
@@ -450,7 +462,7 @@ class InferenceEngine(Module):
         return 'model'

     def _convert_to_dtype(self):
-        if self.dtype is torch.int8 and self.quantization_scales is None:
+        if False:  # self.dtype is torch.int8 and self.quantization_scales is None:
             quantizer = WeightQuantization(mlp_extra_grouping=self.mlp_extra_grouping)
             model, self.quantization_scales = quantizer.model_quantize(self.module,
                                                                        self.injection_dict,
deepspeed/module_inject/load_checkpoint.py

@@ -3,9 +3,15 @@ import deepspeed.ops.transformer as transformer_inference
 from ..runtime.zero import GatheredParameters
 from .layers import LinearLayer, Normalize, EmbeddingLayer
 import torch
+import gc


-def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
+def load_model_with_checkpoint(r_module,
+                               sd,
+                               mp_replace,
+                               ckpt_type,
+                               weight_quantizer=None,
+                               rank=0):
     error_msgs = []

     def transpose(data):
@@ -15,7 +21,7 @@ def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
         return data.reshape(data.shape[-1], data.shape[-2])

     def load(module, prefix):
-        args = (sd, prefix, {}, True, [], [], error_msgs)
+        args = (sd[0], prefix, {}, True, [], [], error_msgs)

         if len(list(module.parameters())) > 0 and list(module.parameters())[0].numel() == 0:
@@ -25,81 +31,142 @@ def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
         else:
             if hasattr(module, 'weight'):
                 module.weight = mp_replace.copy(module.weight.data,
-                                                sd[prefix + 'weight'])
-            if prefix + 'bias' in sd.keys():
-                module.bias = mp_replace.copy(module.bias.data, sd[prefix + 'bias'])
+                                                sd[0][prefix + 'weight'])
+            if prefix + 'bias' in sd[0].keys():
+                module.bias = mp_replace.copy(module.bias.data,
+                                              sd[0][prefix + 'bias'])
+        args = None
+        gc.collect()

     def load_transformer_layer(module, prefix):
         if ckpt_type == "tp":

             def load_parameters(module, prefix):
                 for n, p in module.named_parameters():
-                    if len(n.split('.')) == 1:
-                        src_shape = sd[prefix + n].shape
+                    if prefix + n in sd[0] and len(n.split('.')) == 1:
+                        if type(sd[0][prefix + n]) is list:
+                            tmp_data, scale = sd[0][prefix + n]
+                            tmp_data = tmp_data
+                            scale = scale.to(torch.cuda.current_device())
+                        else:
+                            tmp_data = sd[0][prefix + n].to(torch.cuda.current_device())
+                            scale = None
+                        src_shape = tmp_data.shape
                         dst_shape = p.shape
+                        inner_dim = 1 if tmp_data.dtype == torch.int8 else 0
+                        outer_dim = 0 if tmp_data.dtype == torch.int8 else 1
                         if (len(src_shape) == 2 and len(dst_shape) == 2):
-                            if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
-                                p.data.copy_(sd[prefix + n])
+                            if (src_shape[inner_dim] == dst_shape[0]
+                                    and src_shape[outer_dim] == dst_shape[1]):
+                                if tmp_data.dtype != torch.int8:
+                                    p = weight_quantizer.quantize(
+                                        transpose(tmp_data) if weight_quantizer.q_int8 else tmp_data)
+                                else:
+                                    p = torch.nn.parameter.Parameter(tmp_data, requires_grad=False)
+                                    p.scale = scale
+                                setattr(module, n, p)
                             else:
-                                if src_shape[0] != dst_shape[0]:
-                                    weight_split = torch.split(
-                                        sd[prefix + n],
-                                        dst_shape[0],
-                                        dim=0)[rank].to(torch.cuda.current_device()).contiguous()
+                                dim = inner_dim if src_shape[inner_dim] != dst_shape[0] else outer_dim
+                                dim1 = 0 if src_shape[inner_dim] != dst_shape[0] else 1
+                                if src_shape[dim] > dst_shape[dim1]:
+                                    weight_partition = torch.split(
+                                        tmp_data,
+                                        dst_shape[dim1],
+                                        dim=dim)[rank].to(torch.cuda.current_device())
+                                    assert tmp_data.dtype != torch.int8 or scale.numel() > weight_quantizer.num_groups * (rank + 1), \
+                                        '''ERROR: We require the quantization scales for larger TP-size when loading INT8 checkpoint! \
+                                           Please use the FP16 checkpoint to generate INT8 checkpoint with the sharding parameters!'''
+                                    scale = scale.view(-1)[weight_quantizer.num_groups * (rank + 1):].reshape(
+                                        weight_quantizer.num_groups, -1).contiguous()
                                 else:
-                                    weight_split = torch.split(
-                                        sd[prefix + n],
-                                        dst_shape[1],
-                                        dim=1)[rank].to(torch.cuda.current_device()).contiguous()
-                                p.data.copy_(weight_split.contiguous())
+                                    assert tmp_data.dtype != torch.int8, \
+                                        '''Merging of the checkpoints are not supported when using INT8 checkpoint! \
+                                           Please use a as many GPUs as TP-size for the checkpoint'''
+                                    all_data = [
+                                        sd[j][prefix + n] if type(sd[j][prefix + n]) is list
+                                        else sd[j][prefix + n].to(torch.cuda.current_device())
+                                        for j in range(len(sd))
+                                    ]
+                                    weight_partition = torch.cat([
+                                        ad[0].to(torch.cuda.current_device()) if type(ad) is list else ad
+                                        for ad in all_data
+                                    ], dim=dim)
+                                    if tmp_data.dtype == torch.int8:
+                                        scale = torch.cat(
+                                            [ad[1].to(torch.cuda.current_device()) for ad in all_data],
+                                            dim=dim)
+
+                                if tmp_data.dtype != torch.int8:
+                                    weight_partition = weight_quantizer.quantize(
+                                        transpose(weight_partition), \
+                                        parallel_dim=(0 if dim == 1 else 1)) if weight_quantizer.q_int8 else \
+                                        weight_quantizer.quantize(weight_partition)
+                                else:
+                                    weight_partition = torch.nn.parameter.Parameter(weight_partition,
+                                                                                    requires_grad=False)
+                                    weight_partition.scale = scale
+                                setattr(module, n, weight_partition)
                         else:
                             if src_shape[0] == dst_shape[0]:
-                                p.data.copy_(sd[prefix + n])
+                                p.data.copy_(tmp_data)
                             else:
-                                bias_split = torch.split(
-                                    sd[prefix + n],
-                                    dst_shape[-1])[rank].to(torch.cuda.current_device()).contiguous()
-                                p.data.copy_(bias_split)
+                                if src_shape[0] > dst_shape[0]:
+                                    bias_split = torch.split(
+                                        tmp_data,
+                                        dst_shape[-1])[rank].to(torch.cuda.current_device()).contiguous()
+                                    p.data.copy_(bias_split)
+                                else:
+                                    p.data.copy_(
+                                        torch.cat([sd[j][prefix + n] for j in range(len(sd))],
+                                                  dim=0).to(torch.cuda.current_device()).contiguous())

             load_parameters(module, prefix)
             for n, child in module.named_children():
                 load_parameters(child, prefix + n + '.')
         else:
-            module.norm_w.data.copy_(sd[prefix + 'input_layernorm.' + 'weight'])
-            module.norm_b.data.copy_(sd[prefix + 'input_layernorm.' + 'bias'])
-            module.attention.attn_qkvw = mp_replace.copy(
-                module.attention.attn_qkvw.data,
-                transpose(sd[prefix + 'self_attention.query_key_value.' + 'weight']))
-            module.attention.attn_qkvb = mp_replace.copy(
-                module.attention.attn_qkvb.data,
-                sd[prefix + 'self_attention.query_key_value.' + 'bias'])
-            module.attention.attn_ow = mp_replace.copy(
-                module.attention.attn_ow.data,
-                transpose(sd[prefix + 'self_attention.dense.' + 'weight']))
-            module.attention.attn_ob = mp_replace.copy(
-                module.attention.attn_ob.data,
-                sd[prefix + 'self_attention.dense.' + 'bias'])
-            module.mlp.attn_nw.data.copy_(sd[prefix + 'post_attention_layernorm.' + 'weight'])
-            module.mlp.attn_nb.data.copy_(sd[prefix + 'post_attention_layernorm.' + 'bias'])
-            module.mlp.inter_w = mp_replace.copy(
-                module.mlp.inter_w.data,
-                transpose(sd[prefix + 'mlp.dense_h_to_4h.' + 'weight']))
-            module.mlp.inter_b = mp_replace.copy(
-                module.mlp.inter_b.data,
-                sd[prefix + 'mlp.dense_h_to_4h.' + 'bias'])
-            module.mlp.output_w = mp_replace.copy(
-                module.mlp.output_w.data,
-                transpose(sd[prefix + 'mlp.dense_4h_to_h.' + 'weight']))
-            module.mlp.output_b = mp_replace.copy(
-                module.mlp.output_b.data,
-                sd[prefix + 'mlp.dense_4h_to_h.' + 'bias'])
+            module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight'])
+            module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias'])
+            module.attention.attn_qkvw = mp_replace.copy(
+                module.attention.attn_qkvw,
+                weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight'])
+                if weight_quantizer.q_int8 else \
+                weight_quantizer.quantize(
+                    transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight'])))
+            module.attention.attn_qkvb = mp_replace.copy(
+                module.attention.attn_qkvb.data,
+                sd[0][prefix + 'self_attention.query_key_value.' + 'bias'])
+            module.attention.attn_ow = mp_replace.copy(
+                module.attention.attn_ow,
+                weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight'])
+                if weight_quantizer.q_int8 else \
+                weight_quantizer.quantize(
+                    transpose(sd[0][prefix + 'self_attention.dense.' + 'weight'])))
+            module.attention.attn_ob = mp_replace.copy(
+                module.attention.attn_ob.data,
+                sd[0][prefix + 'self_attention.dense.' + 'bias'])
+            module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + 'weight'])
+            module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' + 'bias'])
+            module.mlp.inter_w = mp_replace.copy(
+                module.mlp.inter_w,
+                weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight'])
+                if weight_quantizer.q_int8 else \
+                weight_quantizer.quantize(
+                    transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight'])))
+            module.mlp.inter_b = mp_replace.copy(
+                module.mlp.inter_b.data,
+                sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias'])
+            module.mlp.output_w = mp_replace.copy(
+                module.mlp.output_w,
+                weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight'])
+                if weight_quantizer.q_int8 else \
+                weight_quantizer.quantize(
+                    transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight'])))
+            module.mlp.output_b = mp_replace.copy(
+                module.mlp.output_b.data,
+                sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias'])

     layer_policies = {
         nn.Linear: load,
@@ -117,7 +184,7 @@ def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
         for name, child in module.named_children():
             if child.__class__ in layer_policies:
                 checking_key = prefix + name + '.'
-                if not any(checking_key in item for item in sd.keys()):
+                if not any(checking_key in item for item in sd[0].keys()):
                     if hasattr(child, 'weight') and \
                         (hasattr(child.weight, 'ds_id') and \
                         child.weight.ds_id in all_ds_ids):
@@ -168,6 +235,7 @@ def load_model_with_checkpoint(r_module, sd, mp_replace, ckpt_type, rank=0):
                     embedding_weight = p
         assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist"
         r_module.lm_head.weight = embedding_weight
-    del sd
+    for sd_ in sd:
+        del sd_
+    sd = None
     gc.collect()
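A sketch (an assumption for illustration) of the checkpoint entries load_model_with_checkpoint now expects: `sd` becomes a list of state dicts, and int8 entries are stored as [quantized_tensor, scale] pairs rather than plain tensors.

import torch

sd = [{
    # int8 weight plus its per-group scales, as saved by the new tp-sharding path
    "transformer.h.0.attn_qkvw": [torch.zeros(3, 4, dtype=torch.int8), torch.ones(32, 1)],
    # non-quantized entries stay plain tensors
    "transformer.h.0.attn_qkvb": torch.zeros(4, dtype=torch.float16),
}]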
deepspeed/module_inject/replace_module.py

@@ -5,7 +5,7 @@ import deepspeed
 import deepspeed.ops.transformer as transformer_inference
 from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy
 from .replace_policy import replace_policies
-from ..runtime.weight_quantizer import WeightQuantization
+# from ..runtime.weight_quantizer import WeightQuantization
 from deepspeed import comm as dist
 from torch import nn
@@ -115,8 +115,10 @@ class ReplaceWithTensorSlicing:
                                        dst_shape[-1])[self.gpu_index].to(
                                            torch.cuda.current_device()).contiguous()
                 dst.data.copy_(bias_split)

-        return torch.nn.parameter.Parameter(dst, requires_grad=False)
+        dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
+        if hasattr(src, 'scale'):
+            dst.scale = src.scale
+        return dst


 def get_transformer_name(replaced_module):
@@ -134,6 +136,57 @@ def get_transformer_name(replaced_module):
     return transformer_name


+class GroupQuantizer:
+    def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8):
+        self.num_groups = num_groups
+        self.group_size = group_size
+        self.num_bits = num_bits
+        self.q_int8 = q_int8
+
+    def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
+        if not self.q_int8 or not qkv:
+            inputs = torch.nn.Parameter(inputs, requires_grad=False)
+            inputs.scale = torch.empty(1)
+            return inputs
+        q_range = 2**self.num_bits
+        inputs = inputs.to(torch.cuda.current_device())
+        input_flat = inputs.reshape(self.num_groups, -1).contiguous()
+        input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
+        input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
+        scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
+        input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
+        inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
+        out = torch.nn.Parameter(inputs_q, requires_grad=False)
+        #print(inputs.shape)
+        inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
+        input_flat = [
+            inputs_split[i].reshape(self.num_groups, -1).contiguous() for i in range(2)
+        ]
+        input_min = [
+            torch.min(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)
+        ]
+        input_max = [
+            torch.max(input_flat[i], dim=1, keepdim=True)[0].float() for i in range(2)
+        ]
+        scale1 = [(torch.max(input_min[i].abs(), input_max[i].abs()) * 2.0 /
+                   (q_range)).squeeze().unsqueeze(0) for i in range(2)]
+
+        out.scale = torch.cat([scale.squeeze().unsqueeze(0), scale1[0], scale1[1]],
+                              dim=0).reshape(self.num_groups, -1).contiguous()
+        return out
+
+
 def replace_transformer_layer(orig_layer_impl,
                               model,
                               policy=None,
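A small usage sketch of the GroupQuantizer added above (an assumption for illustration; quantize moves tensors to the current CUDA device, so this needs a GPU):

import torch

quantizer = GroupQuantizer(q_int8=True, num_groups=32)
w = torch.randn(4096, 4096, dtype=torch.float16)
w_q = quantizer.quantize(w)  # int8 nn.Parameter carrying a .scale attribute
# approximate reconstruction from the whole-tensor per-group scales,
# which occupy the first num_groups entries of the flattened scale tensor
scales = w_q.scale.view(-1)[:32].view(32, 1)
w_hat = (w_q.float().reshape(32, -1) * scales).reshape(w.shape)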
@@ -161,7 +214,8 @@ def replace_transformer_layer(orig_layer_impl,
                               moe_experts=1,
                               moe_type='standard',
                               checkpoint_dict=None,
-                              save_mp_checkpoint_path=None):
+                              save_mp_checkpoint_path=None,
+                              base_dir=""):
    """ Replace bert-style transformer layers with DeepSpeed's transformer layer
    Arguments:
        orig_layer_impl (torch.nn.Module): the original transformer layer implementation to look for,
@@ -225,7 +279,7 @@ def replace_transformer_layer(orig_layer_impl,
             _res_h4h_w, _res_h4h_b, _res_4hh_w, _res_4hh_b, _res_coef = policy.mlp(moe_type)
             attn_nw, attn_nb, input_nw, input_nb = policy.layerNorm()

-            if quantize:
+            if False:
                 if policy_cls is not HFBertLayerPolicy:
                     qkvw = qkvw.to(torch.int8)
                     dense_w = dense_w.to(torch.int8)
@@ -257,6 +311,7 @@ def replace_transformer_layer(orig_layer_impl,
             #expert_mp_replace = ReplaceWithTensorSlicing(mp_group=expert_mp_group)

+            quantizer = GroupQuantizer(q_int8=quantize)
             if inference:
                 if moe:
                     ep_world_size = dist.get_world_size()
@@ -329,21 +384,21 @@ def replace_transformer_layer(orig_layer_impl,
                     new_module = transformer_inference.DeepSpeedTransformerInference(
                         transformer_config,
                         mp_group=mp_group,
-                        quantize_scales=quantization_scales[layer_id],
+                        # quantize_scales=quantization_scales[layer_id],
                         quantize_groups=quantize_groups,
                         merge_count=merge_count,
                         mlp_extra_grouping=mlp_extra_grouping,
                         qkv_merging=(policy_cls is HFBertLayerPolicy))

-                if quantize and qkvw.dtype != torch.int8:
-                    quantize_bits = 8
-                    quantizer = WeightQuantization()
-                    if policy_cls is HFBertLayerPolicy:
-                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
-                    else:
-                        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
-                    qkvw.data.copy_(data_quantized)
-                    qkvw.data = qkvw.data.to(torch.int8)
+                # if quantize and qkvw.dtype != torch.int8:
+                #    quantize_bits = 8
+                #    quantizer = WeightQuantization()
+                #    if policy_cls is HFBertLayerPolicy:
+                #        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups * 3)
+                #    else:
+                #        data_quantized, _ = quantizer.quantize_data(qkvw.data, quantize_bits, quantize_groups)
+                #    qkvw.data.copy_(data_quantized)
+                #    qkvw.data = qkvw.data.to(torch.int8)
             else:
                 if moe:
@@ -478,18 +533,17 @@ def replace_transformer_layer(orig_layer_impl,
                     attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
                     attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)
                 else:
-                    if bigscience_bloom:
-                        attn_block.attn_qkvw = mp_replace.copy(attn_block.attn_qkvw, qkvw)
-                        attn_block.attn_qkvb = mp_replace.copy(attn_block.attn_qkvb, qkvb)
-                    else:
-                        attn_block.attn_qkvw = mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw)
-                        attn_block.attn_qkvb = mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)
-
-                    attn_block.attn_ow = mp_replace.copy(attn_block.attn_ow, dense_w)
+                    attn_block.attn_qkvw = quantizer.quantize(
+                        mp_replace.copy(attn_block.attn_qkvw, qkvw) if bigscience_bloom else \
+                        mp_replace.qkv_copy(attn_block.attn_qkvw, qkvw))
+                    attn_block.attn_qkvb = \
+                        mp_replace.copy(attn_block.attn_qkvb, qkvb) if bigscience_bloom else \
+                        mp_replace.qkv_copy(attn_block.attn_qkvb, qkvb)
+
+                    attn_block.attn_ow = quantizer.quantize(mp_replace.copy(attn_block.attn_ow, dense_w))
                     attn_block.attn_ob = mp_replace.copy(attn_block.attn_ob, dense_b)

             if moe:
@@ -545,9 +599,13 @@ def replace_transformer_layer(orig_layer_impl,
                                                    mpl_block.output_b,
                                                    _4hh_b)
                 else:
-                    mpl_block.inter_w = mp_replace.copy(mpl_block.inter_w, _h4h_w)
+                    mpl_block.inter_w = quantizer.quantize(mp_replace.copy(mpl_block.inter_w, _h4h_w))
                     mpl_block.inter_b = mp_replace.copy(mpl_block.inter_b, _h4h_b)
-                    mpl_block.output_w = mp_replace.copy(mpl_block.output_w, _4hh_w)
+                    mpl_block.output_w = quantizer.quantize(mp_replace.copy(mpl_block.output_w, _4hh_w))
                     mpl_block.output_b = mp_replace.copy(mpl_block.output_b, _4hh_b)

                 if attn_nw is None:
@@ -782,50 +840,92 @@ def replace_transformer_layer(orig_layer_impl,
                                             replace_fn=replace_fn,
                                             _replace_policy=policy)

+    quantizer = GroupQuantizer(q_int8=quantize)
     world_size = dist.get_world_size() if dist.is_initialized() else 1
     rank = dist.get_rank() if dist.is_initialized() else 0
     if checkpoint_dict is not None:
         start_time = time.time()
         checkpoint = checkpoint_dict['checkpoints']
+        ckpt_list = checkpoint["tp"] if type(checkpoint) is dict else checkpoint
         ckpt_type = checkpoint_dict.get('parallelization', 'pp')
-        ckpt_mp_size = checkpoint_dict.get('mp_size', mp_size)
-        base_dir = checkpoint_dict.get('base_dir', '')
+        ckpt_mp_size = checkpoint_dict.get('tp_size', len(ckpt_list))
+        ckpt_mp_size = checkpoint_dict.get('mp_size', ckpt_mp_size)
+        base_dir1 = checkpoint_dict.get('base_dir', base_dir)

-        if ckpt_type == 'pp':
+        if ckpt_type == 'pp' and type(checkpoint) is list:
             pbar = tqdm.tqdm(total=len(checkpoint),
                              desc=f"Loading {len(checkpoint)} checkpoint shards")

             for i in range(len(checkpoint)):
                 if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
                     pbar.update(1)
-                sd = torch.load(checkpoint[i], map_location='cpu')
-                load_model_with_checkpoint(replaced_module, sd, mp_replace, ckpt_type)
+                sd = [torch.load(os.path.join(base_dir1, checkpoint[i]), map_location='cpu')]
+                load_model_with_checkpoint(replaced_module,
+                                           sd,
+                                           mp_replace,
+                                           ckpt_type,
+                                           quantizer)
         else:
-            num_checkpoints = len(checkpoint) // ckpt_mp_size
-            assert world_size >= ckpt_mp_size, \
-                "Currently, merging checkpoints is not supported (when world_size is smaller than #checkpoints)!"
-            checkpoint_stride = world_size // ckpt_mp_size
-            if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
-                pbar = tqdm.tqdm(total=num_checkpoints,
-                                 desc=f"Loading {num_checkpoints} checkpoint shards")
+            import gc
+            num_checkpoints = len(ckpt_list) // ckpt_mp_size
+            tp_split_size = (world_size / ckpt_mp_size)
+            sd_offset = int(rank / tp_split_size)
+            sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset
+            pbar = tqdm.tqdm(total=num_checkpoints,
+                             desc=f"Loading {num_checkpoints} checkpoint shards")
             for i in range(num_checkpoints):
-                if not deepspeed.comm.is_initialized() or deepspeed.comm.get_rank() == 0:
-                    pbar.update(1)
-                ckpt_index = i * ckpt_mp_size + (rank // checkpoint_stride)
-                ckpt_file = os.path.join(base_dir, checkpoint[ckpt_index]) if base_dir else checkpoint[ckpt_index]
-                sd = torch.load(ckpt_file, map_location='cpu')
-                load_model_with_checkpoint(replaced_module,
-                                           sd,
-                                           mp_replace,
-                                           ckpt_type,
-                                           rank % (world_size // ckpt_mp_size))
+                pbar.update(1)
+                ckpt_index = i * ckpt_mp_size + sd_offset
+                ckpt_files = [
+                    os.path.join(base_dir1, ckpt_list[ckpt_index + j])
+                    if base_dir1 else ckpt_list[ckpt_index + j] for j in range(sd_count)
+                ]
+                sds = [torch.load(ckpt_file, map_location='cpu') for ckpt_file in ckpt_files]
+                load_model_with_checkpoint(replaced_module,
+                                           sds,
+                                           mp_replace,
+                                           ckpt_type,
+                                           quantizer,
+                                           int(rank % tp_split_size))
+                sds = [None for _ in sds]
+                gc.collect()
+
+            if "non_tp" in checkpoint:
+                pbar = tqdm.tqdm(total=len(checkpoint["non_tp"]),
+                                 desc=f"Loading {len(checkpoint['non_tp'])} checkpoint shards")
+                for i in range(len(checkpoint["non_tp"])):
+                    pbar.update(1)
+                    ckpt_file = os.path.join(base_dir1, checkpoint["non_tp"][i]) \
+                        if base_dir1 else checkpoint["non_tp"][i]
+                    sds = [torch.load(ckpt_file, map_location='cpu')]
+                    load_model_with_checkpoint(replaced_module,
+                                               sds,
+                                               mp_replace,
+                                               ckpt_type,
+                                               quantizer,
+                                               int(rank % tp_split_size))
+                    sds = [None for _ in sds]
+                    gc.collect()
         print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")

     if save_mp_checkpoint_path is not None:
         from collections import OrderedDict
         import json
+        num_partitions = 8

         if checkpoint_dict is None:
             ckpt_name = "ds_model"
@@ -840,8 +940,8 @@ def replace_transformer_layer(orig_layer_impl,
         if dist.is_initialized():
             dist.barrier()
         transformer_name = get_transformer_name(replaced_module)
-        non_tp_ckpt_name = f'{ckpt_name}-non-tp.pt'
-        ckpt_files = [non_tp_ckpt_name] * world_size
+        non_tp_ckpt_name = f'non-tp.pt'
+        ckpt_files = [non_tp_ckpt_name]
         os.makedirs(save_mp_checkpoint_path, exist_ok=True)
         if not dist.is_initialized() or dist.get_rank() == 0:
             print("Saving tp-sharded checkpoints")
@@ -853,25 +953,47 @@ def replace_transformer_layer(orig_layer_impl,
                 if transformer_name not in k
             }),
             f'{save_mp_checkpoint_path}/{non_tp_ckpt_name}')
-        ckpt_files += [f'{ckpt_name}-tp_{r:0>2d}.pt' for r in range(world_size)]
         config = json.dumps({
-            'type': ckpt_name,
-            'base_dir': f'{save_mp_checkpoint_path}',
-            'checkpoints': ckpt_files,
-            'version': 1.0,
-            'parallelization': 'tp',
-            'mp_size': world_size
+            'type': ckpt_name,
+            'base_dir': f'{save_mp_checkpoint_path}',
+            'checkpoints': {
+                "non_tp": ckpt_files,
+                "tp": [
+                    f'tp_{r:0>2d}_{m:0>2d}.pt' for m in range(num_partitions)
+                    for r in range(world_size)
+                ]
+            },
+            'version': 1.0,
+            'parallelization': 'tp',
+            'tp_size': world_size,
+            'dtype': 'int8' if quantize else ('float16' if fp16 else 'float32')
         })
-        with open(f"{save_mp_checkpoint_path}/{ckpt_name}_ds-inference_config.json", "w") as cfg:
+        with open(f"{save_mp_checkpoint_path}/ds-inference_config.json", "w") as cfg:
             cfg.write(config)
-        torch.save(
-            OrderedDict({
-                k: v
-                for k, v in dict(replaced_module.state_dict()).items()
-                if transformer_name in k
-            }),
-            f'{save_mp_checkpoint_path}/{ckpt_name}-tp_{rank:0>2d}.pt')
+
+        rep_sd = replaced_module.state_dict()
+        for n, p in replaced_module.named_parameters():
+            if hasattr(p, 'scale'):
+                rep_sd[n] = [p, p.scale]
+        keys = list(rep_sd.keys())
+        partition_size = (len(keys) // num_partitions + 1)
+        for m in range(num_partitions):
+            torch.save(
+                OrderedDict({
+                    k: [rep_sd[k], rep_sd[k].scale] if hasattr(rep_sd[k], 'scale') else rep_sd[k]
+                    for k in keys[m * partition_size:(m + 1) * partition_size]
+                    if transformer_name in k
+                }),
+                f'{save_mp_checkpoint_path}/tp_{rank:0>2d}_{m:0>2d}.pt')

     return replaced_module
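A worked example of the shard-mapping arithmetic introduced above (a sketch under the assumption of 4 inference GPUs loading a checkpoint saved with tp_size=2): each saved shard covers two ranks, so every rank loads exactly one file at an offset derived from its rank.

world_size, ckpt_mp_size = 4, 2
tp_split_size = world_size / ckpt_mp_size          # 2.0
for rank in range(world_size):
    sd_offset = int(rank / tp_split_size)          # 0, 0, 1, 1
    sd_count = int((rank + max(1, tp_split_size)) / tp_split_size) - sd_offset  # 1 per rank
    print(rank, sd_offset, sd_count)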
deepspeed/ops/transformer/inference/transformer_inference.py

@@ -206,16 +206,6 @@ class DeepSpeedSelfAttentionFunction(Function):
                  value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

-            if layer_past is not None:
-                past_key, past_value = layer_past
-                # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
-                key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=1)
-                value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=1)
-
-            presents = (key_layer, value_layer)
-
             # [batch_size, head_dim, q_length, k_length]
             output_size = (query_layer.size(0),
                            query_layer.size(2),
@@ -223,24 +213,37 @@ class DeepSpeedSelfAttentionFunction(Function):
                            key_layer.size(1))

             # [batch_size, q_length, num_heads, head_dim] -> [q_length, batch_size * num_heads, head_dim]
-            query_layer = query_layer.transpose(1, 0).reshape(
-                output_size[2], output_size[0] * output_size[1], -1)
+            query_layer = query_layer.transpose(1, 2).reshape(
+                output_size[0] * output_size[1], output_size[2], -1)

             # [batch_size, k_length, num_heads, head_dim] -> [k_length, batch_size * num_heads, head_dim]
-            key_layer = key_layer.transpose(1, 0).reshape(
-                output_size[3], output_size[0] * output_size[1], -1)
+            key_layer = key_layer.transpose(1, 2).reshape(
+                output_size[0] * output_size[1], output_size[3], -1).transpose(-1, -2)
+
+            value_layer = value_layer.transpose(1, 2).reshape(
+                output_size[0] * output_size[1], output_size[3], -1)
+
+            if layer_past is not None:
+                past_key, past_value = layer_past
+                # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
+                key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
+                value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
+
+            presents = (key_layer, value_layer)

             # Raw attention scores. [batch_size * num_heads, q_length, k_length]
-            matmul_result = torch.matmul(query_layer.transpose(1, 0),
-                                         key_layer.transpose(1, 0).transpose(1, 2))
+            matmul_result = torch.matmul(query_layer, key_layer)
             # change view to [batch_size, num_heads, q_length, k_length]
-            attention_scores = matmul_result.view(*output_size)
+            attention_scores = matmul_result.view(output_size[0], output_size[1], output_size[2], -1)

             offset = dist.get_rank() * num_attention_heads_per_partition if dist.is_initialized() else 0
@@ -261,12 +264,7 @@ class DeepSpeedSelfAttentionFunction(Function):
             attention_probs_reshaped = attention_probs.view(*matmul_result.shape)

             # matmul: [batch_size * num_heads, q_length, head_dim]
-            context_layer = torch.bmm(
-                attention_probs_reshaped,
-                value_layer.transpose(1, 2).reshape(-1, value_layer.size(1), value_layer.size(3)))
+            context_layer = torch.bmm(attention_probs_reshaped, value_layer)

             # change view [batch_size, num_heads, q_length, head_dim]
             context_layer = context_layer.view(
@@ -418,15 +416,21 @@ class DeepSpeedSelfAttentionFunction(Function):
             qkv_out = qkv_func(input,
                                attn_qkvw,
+                               attn_qkvw.scale,
                                (attn_qkvb if attn_qkvb is not None else norm_b),
                                norm_w,
                                norm_b,
                                config.epsilon,
                                (attn_qkvb is not None),
-                               1 if config.bigscience_bloom else DeepSpeedTransformerInference.layer_id)
+                               DeepSpeedTransformerInference.layer_id,
+                               config.q_int8)
             context_layer, key_layer, value_layer = compute_attention(
                 qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask)
-            output = vector_matmul_func(context_layer, attn_ow, False)
+            output = vector_matmul_func(context_layer,
+                                        attn_ow,
+                                        False,
+                                        attn_ow.scale,
+                                        config.q_int8)

             return output, key_layer, value_layer, context_layer, qkv_out[-1]
@@ -458,7 +462,7 @@ class DeepSpeedSelfAttentionFunction(Function):
                                   (merge_count))
             return output, key_layer, value_layer, context_layer

-        if config.q_int8:
+        if False:  # config.q_int8:
             output, key_layer, value_layer, context_layer = selfAttention_int8()
         else:
             output, key_layer, value_layer, context_layer, inp_norm = selfAttention_fp()
@@ -486,30 +490,34 @@ class DeepSpeedSelfAttention(nn.Module):
                  qkv_merging=False):
         super(DeepSpeedSelfAttention, self).__init__()
         self.config = config
-        data_type = torch.half if config.fp16 else torch.float
+        data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
+        data_type_fp = torch.half if config.fp16 else torch.float
         self.config.layer_id = DeepSpeedSelfAttention.num_layers
         DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
         device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
-        self.attn_qkvw = nn.Parameter(
-            torch.empty(self.config.hidden_size,
-                        (self.config.hidden_size // self.config.mp_size) * 3,
-                        dtype=data_type,
-                        device=device))
-        self.attn_qkvb = nn.Parameter(
-            torch.empty((self.config.hidden_size // self.config.mp_size) * 3,
-                        dtype=data_type,
-                        device=device))
-        self.attn_ow = nn.Parameter(
-            torch.empty(self.config.hidden_size // self.config.mp_size,
-                        self.config.hidden_size,
-                        dtype=data_type,
-                        device=device))
-        self.attn_ob = nn.Parameter(
-            torch.empty(self.config.hidden_size, dtype=data_type, device=device))
+        self.attn_qkvw = nn.Parameter(
+            torch.empty(self.config.hidden_size,
+                        (self.config.hidden_size // self.config.mp_size) * 3,
+                        dtype=data_type,
+                        device=device),
+            requires_grad=False)
+        self.attn_qkvb = nn.Parameter(
+            torch.empty((self.config.hidden_size // self.config.mp_size) * 3,
+                        dtype=data_type_fp,
+                        device=device),
+            requires_grad=False)
+        self.attn_ow = nn.Parameter(
+            torch.empty(self.config.hidden_size // self.config.mp_size,
+                        self.config.hidden_size,
+                        dtype=data_type,
+                        device=device),
+            requires_grad=False)
+        self.attn_ob = nn.Parameter(
+            torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
+            requires_grad=False)

         self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
         self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
@@ -595,36 +603,16 @@ class DeepSpeedMLPFunction(Function):
                 bias_residual_func,
                 activation_func_type=ActivationFuncType.GELU):

-        if config.q_int8:
-            (intermediate, residual_add) = inference_cuda_module.mlp_gemm_int8(
-                input, residual, bias, inter_w, inter_b, attn_nw, attn_nb,
-                config.epsilon, q_scales[2], (q_groups * (2**merge_count)),
-                config.pre_layer_norm)
-            output = inference_cuda_module.vector_matmul_int8(intermediate,
-                                                              output_w,
-                                                              q_scales[3],
-                                                              q_groups,
-                                                              (merge_count))
+        if attn_nw is None:
+            output = fused_gemm_gelu(residual_norm, inter_w, inter_b, output_w,
+                                     config.epsilon, config.pre_layer_norm, False)
         else:
-            if attn_nw is None:
-                output = fused_gemm_gelu(residual_norm, inter_w, inter_b, output_w,
-                                         config.epsilon, config.pre_layer_norm, False)
-            else:
-                intermediate, residual_add = mlp_gemm_func(input,
+            intermediate, residual_add = mlp_gemm_func(input,
                                                        residual,
                                                        bias,
                                                        inter_w,
@@ -634,9 +622,14 @@ class DeepSpeedMLPFunction(Function):
                                                        config.epsilon,
                                                        config.pre_layer_norm,
                                                        config.mlp_after_attn,
+                                                       inter_w.scale,
+                                                       config.q_int8,
                                                        config.mlp_act_func_type)
-            output = vector_matmul_func(intermediate, output_w, False)
+            output = vector_matmul_func(intermediate,
+                                        output_w,
+                                        False,
+                                        output_w.scale,
+                                        config.q_int8)

         inference_cuda_module.residual_add(
             output,
             residual if config.pre_layer_norm else residual_add,
@@ -668,34 +661,38 @@ class DeepSpeedMLP(nn.Module):
         super(DeepSpeedMLP, self).__init__()

         self.config = config
-        data_type = torch.half if config.fp16 else torch.float
+        data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
+        data_type_fp = torch.half if config.fp16 else torch.float
         device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
-        self.attn_nw = nn.Parameter(
-            torch.empty(self.config.hidden_size, dtype=data_type, device=device))
-        self.attn_nb = nn.Parameter(
-            torch.empty(self.config.hidden_size, dtype=data_type, device=device))
-        self.inter_w = nn.Parameter(
-            torch.empty(self.config.hidden_size,
-                        self.config.intermediate_size // self.config.mp_size,
-                        dtype=data_type,
-                        device=device))
-        self.inter_b = nn.Parameter(
-            torch.empty(self.config.intermediate_size // self.config.mp_size,
-                        dtype=data_type,
-                        device=device))
-        self.output_w = nn.Parameter(
-            torch.empty((self.config.intermediate_size // self.config.mp_size),
-                        self.config.hidden_size,
-                        dtype=data_type,
-                        device=device))
-        self.output_b = nn.Parameter(
-            torch.empty(self.config.hidden_size, dtype=data_type, device=device))
+        self.attn_nw = nn.Parameter(
+            torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
+            requires_grad=False)
+        self.attn_nb = nn.Parameter(
+            torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
+            requires_grad=False)
+        self.inter_w = nn.Parameter(
+            torch.empty(self.config.hidden_size,
+                        self.config.intermediate_size // self.config.mp_size,
+                        dtype=data_type,
+                        device=device),
+            requires_grad=False)
+        self.inter_b = nn.Parameter(
+            torch.empty(self.config.intermediate_size // self.config.mp_size,
+                        dtype=data_type_fp,
+                        device=device),
+            requires_grad=False)
+        self.output_w = nn.Parameter(
+            torch.empty((self.config.intermediate_size // self.config.mp_size),
+                        self.config.hidden_size,
+                        dtype=data_type,
+                        device=device),
+            requires_grad=False)
+        self.output_b = nn.Parameter(
+            torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
+            requires_grad=False)

         # used for quantization
         self.q_scales = q_scales
@@ -790,14 +787,14 @@ class DeepSpeedTransformerInference(nn.Module):
                 mlp_extra_grouping)

         device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
-        self.norm_w = nn.Parameter(
-            torch.empty(self.config.hidden_size, dtype=data_type, device=device))
-        self.norm_b = nn.Parameter(
-            torch.empty(self.config.hidden_size, dtype=data_type, device=device))
+        self.norm_w = nn.Parameter(
+            torch.empty(self.config.hidden_size, dtype=data_type, device=device),
+            requires_grad=False)
+        self.norm_b = nn.Parameter(
+            torch.empty(self.config.hidden_size, dtype=data_type, device=device),
+            requires_grad=False)
         self.layer_past = None

     def forward(
@@ -826,7 +823,6 @@ class DeepSpeedTransformerInference(nn.Module):
         # We set the prev key/value to None when there is a prompt
         if input.shape[1] > 1:
             self.layer_past = None
-
         layer_past = layer_past if layer_past is not None else self.layer_past
         head_mask = layer_head_mask if layer_head_mask is not None else head_mask
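A shape-level sketch of the reworked attention layout above (an assumption for illustration): queries become [b*h, q, d], keys [b*h, d, k] (pre-transposed for the matmul), and values [b*h, k, d], so cached past keys concatenate on dim=-1 and past values on dim=-2.

import torch

b, h, q, k, d = 2, 4, 1, 8, 16
query = torch.randn(b * h, q, d)
key = torch.randn(b * h, d, k)      # stored already transposed
value = torch.randn(b * h, k, d)
scores = torch.matmul(query, key)                # [b*h, q, k]
context = torch.bmm(scores.softmax(-1), value)   # [b*h, q, d]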