Unverified commit 35b350b2, authored by Reza Yazdani, committed by GitHub

Fix quantized-inference & Add generic support of checkpoint loading (#2547)

* fix checkpoint loading when it is a dictionary

* fix some issues with saving ckpt & int8 inference

* fix quantized-inference & add generic support of checkpoint loading

* remove int8 hard-coded flag

* fix mlp return tensors

* fix several issues with loading checkpoints of GPT-J, GPT-NeoX, and OPT with different TP sizes

* add more comments & description for checkpoint-loading module
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Parent: b8416282
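For orientation before the diff, here is a minimal usage sketch of the inference path this commit touches. The keyword arguments follow the DeepSpeed init_inference API of this period (names may differ across versions), and the model name and checkpoint JSON path are hypothetical placeholders, not part of this commit.

```python
# Hedged usage sketch: load an OPT model through DeepSpeed-Inference with kernel
# injection and a ds-inference checkpoint description. The generic loading path in
# load_checkpoint.py then pulls each layer's parameters using the names that the
# matching policy exposes via get_param_names().
import torch
import deepspeed
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-1.3b")
engine = deepspeed.init_inference(
    model,
    mp_size=2,                                 # tensor-parallel degree (TP size)
    dtype=torch.half,                          # or torch.int8 for quantized inference
    replace_with_kernel_inject=True,
    checkpoint="./ds_inference_config.json")   # hypothetical path to the checkpoint json
tokens = torch.tensor([[2, 100, 200]]).to(torch.cuda.current_device())
output = engine.module.generate(tokens, max_length=16)
```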
...@@ -763,11 +763,18 @@ void quantized_gemm(void* output, ...@@ -763,11 +763,18 @@ void quantized_gemm(void* output,
at::Tensor& weight, at::Tensor& weight,
at::Tensor& qscale, at::Tensor& qscale,
int groups, int groups,
int bsz) int bsz,
int hidden_size)
{ {
T* weight16 = (T*)Context::Instance().GetWorkSpace() + T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
12 * Context::Instance().GetMaxTokenLenght() * weight.size(1);
// auto options = at::TensorOptions()
// .dtype(at::kHalf)
// .layout(at::kStrided)
// .device(at::kCUDA)
// .requires_grad(false);
// auto tmp = torch::empty(weight.sizes(), options);
// T* weight16 = (T*)tmp.data_ptr();
launch_dequantize(weight16, launch_dequantize(weight16,
(int8_t*)weight.data_ptr(), (int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(), (float*)qscale.data_ptr(),
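The new hidden_size argument relocates the fp16 scratch buffer that quantized_gemm dequantizes the int8 weight into before running a regular GEMM. A rough torch-level reference of that dequantize-then-matmul step follows; the per-group layout (one fp32 scale per contiguous block of output rows) is an assumption for illustration, since the actual layout lives in launch_dequantize rather than in this hunk.

```python
# Rough functional reference for the dequantize-then-GEMM step performed by
# quantized_gemm. The (groups, rows-per-group) layout is assumed; the real kernel
# (launch_dequantize + cublas) works in fp16 on a preallocated workspace instead
# of materializing new tensors.
import torch

def quantized_gemm_reference(activations, weight_int8, qscale, groups):
    out_features, in_features = weight_int8.shape
    rows_per_group = out_features // groups
    # one fp32 scale per group, expanded to one scale per output row (assumed layout)
    scales = qscale.float().view(groups, 1).repeat(1, rows_per_group).view(-1, 1)
    weight_fp = weight_int8.float() * scales          # dequantize (fp16 in the CUDA kernel)
    return activations @ weight_fp.t()                # regular GEMM

x = torch.randn(4, 8)                                 # (tokens, hidden)
w = torch.randint(-127, 128, (16, 8), dtype=torch.int8)
s = torch.rand(4)                                     # 4 quantization groups
y = quantized_gemm_reference(x, w, s, groups=4)       # shape (4, 16)
```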
...@@ -814,7 +821,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, ...@@ -814,7 +821,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon); ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);
if (q_int8) { if (q_int8) {
quantized_gemm<T>(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz); quantized_gemm<T>(
output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else { } else {
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
...@@ -1202,15 +1210,19 @@ at::Tensor ds_vector_matmul(at::Tensor& input, ...@@ -1202,15 +1210,19 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
.layout(at::kStrided) .layout(at::kStrided)
.device(at::kCUDA) .device(at::kCUDA)
.requires_grad(false); .requires_grad(false);
int out_size = q_int8 ? weight.size(0) : weight.size(1); int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input.size(0) * input.size(1); int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace(); T* workspace = (T*)Context::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options); auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) { if (q_int8) {
quantized_gemm<T>( quantized_gemm<T>(output.data_ptr(),
output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz); (T*)input.data_ptr(),
weight,
q_scale,
q_scale.size(0),
bsz,
input.size(2));
} else { } else {
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
...@@ -1293,9 +1305,9 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, ...@@ -1293,9 +1305,9 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else { } else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon); ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
} }
if (q_int8) { if (q_int8) {
quantized_gemm<T>(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz); quantized_gemm<T>(
intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else { } else {
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
...@@ -1331,9 +1343,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output, ...@@ -1331,9 +1343,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
bsz, bsz,
Context::Instance().GetCurrentStream()); Context::Instance().GetCurrentStream());
} }
if (q_int8) { if (q_int8) {
quantized_gemm<T>( quantized_gemm<T>(output.data_ptr(),
output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz); intermediate,
weight1,
q_scale1,
q_scale1.size(0),
bsz,
input.size(2));
} else { } else {
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
...@@ -1449,64 +1467,95 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input, ...@@ -1449,64 +1467,95 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
template <typename T> template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input, at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight, at::Tensor& weight,
at::Tensor& weight_scale,
at::Tensor& bias, at::Tensor& bias,
at::Tensor& weight_out, at::Tensor& weight_out,
at::Tensor& weight_out_scale,
const float epsilon, const float epsilon,
bool preLayerNorm, bool preLayerNorm,
bool q_int8,
bool async_op) bool async_op)
{ {
auto input_cont = input.contiguous();
auto options = at::TensorOptions() auto options = at::TensorOptions()
.dtype(input_cont.options().dtype()) .dtype(input.options().dtype())
.layout(at::kStrided) .layout(at::kStrided)
.device(at::kCUDA) .device(at::kCUDA)
.requires_grad(false); .requires_grad(false);
auto intermediate = int intm_dim = q_int8 ? weight.size(0) : weight.size(1);
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options); // auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
int bsz = input_cont.size(0) * input_cont.size(1); // {input.size(0), input.size(1), out_size},
// options);
// T* intermediate = (T*)input.data_ptr() + torch::numel(input);
auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options);
int bsz = input.size(0) * input.size(1);
float alpha = (T)1.0; float alpha = (T)1.0;
float gemm_beta = (T)0.0; float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); if (q_int8) {
cublas_gemm_ex(Context::Instance().GetCublasHandle(), quantized_gemm<T>(intermediate.data_ptr(),
CUBLAS_OP_N, (T*)input.data_ptr(),
CUBLAS_OP_N, weight,
weight.size(1), weight_scale,
bsz, weight_scale.size(0),
input.size(2), bsz,
&alpha, input.size(2));
&gemm_beta, } else {
(T*)weight.data_ptr(), cublasSetStream(Context::Instance().GetCublasHandle(),
(T*)input_cont.data_ptr(), Context::Instance().GetCurrentStream());
(T*)intermediate.data_ptr(), cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
intm_dim,
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input.data_ptr(),
(T*)intermediate.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard); rocblas_gemm_algo_standard);
#else #else
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif #endif
}
launch_bias_gelu((T*)intermediate.data_ptr(), launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(), (T*)bias.data_ptr(),
weight.size(1), intm_dim,
bsz, bsz,
Context::Instance().GetCurrentStream()); Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(), int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
CUBLAS_OP_N, auto output = at::empty({input.size(0), input.size(1), out_size}, options);
CUBLAS_OP_N, if (q_int8) {
weight_out.size(1), quantized_gemm<T>(output.data_ptr(),
bsz, (T*)intermediate.data_ptr(),
intermediate.size(2), weight_out,
&alpha, weight_out_scale,
&gemm_beta, weight_out_scale.size(0),
(T*)weight_out.data_ptr(), bsz,
(T*)intermediate.data_ptr(), input.size(2));
(T*)output.data_ptr(), } else {
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
out_size,
bsz,
intm_dim,
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard); rocblas_gemm_algo_standard);
#else #else
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif #endif
}
// cudaEventRecord(Context::Instance().GetCompEvent(2), // cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true)); // Context::Instance().GetCurrentStream(true));
return output; return output;
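With the new q_int8 flag, fused_gemm_gelu sizes the intermediate and output from weight.size(0) / weight_out.size(0) and routes both GEMMs through quantized_gemm; functionally the op still computes a GEMM, a fused bias+GeLU, and a second GEMM. Below is a compact torch reference of the fp16 path, written in fp32 for portability; the exact GeLU approximation of the CUDA kernel may differ.

```python
# Functional reference for fused_gemm_gelu (fp16 path): GEMM -> bias + GeLU -> GEMM.
# The CUDA version writes into preallocated workspace buffers and, when q_int8 is
# set, first dequantizes weight / weight_out through quantized_gemm.
import torch
import torch.nn.functional as F

def fused_gemm_gelu_reference(x, weight, bias, weight_out):
    intermediate = F.gelu(x @ weight + bias)   # first GEMM fused with bias + GeLU
    return intermediate @ weight_out           # second GEMM (output bias handled outside this op)

x = torch.randn(2, 5, 64)        # (batch, seq, hidden)
w1 = torch.randn(64, 256)        # fp16 weights are stored as (in, out)
b1 = torch.randn(256)
w2 = torch.randn(256, 64)
out = fused_gemm_gelu_reference(x, w1, b1, w2)   # shape (2, 5, 64)
```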
......
...@@ -68,7 +68,7 @@ class DeepSpeedTransformerInference(nn.Module): ...@@ -68,7 +68,7 @@ class DeepSpeedTransformerInference(nn.Module):
merge_count, merge_count,
mlp_extra_grouping) mlp_extra_grouping)
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size, self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type, dtype=data_type,
device=device), device=device),
...@@ -131,7 +131,6 @@ class DeepSpeedTransformerInference(nn.Module): ...@@ -131,7 +131,6 @@ class DeepSpeedTransformerInference(nn.Module):
if (self.config.fp16 or self.config.q_int8) \ if (self.config.fp16 or self.config.q_int8) \
and input.dtype == torch.float: and input.dtype == torch.float:
input = input.half() input = input.half()
with torch.no_grad(): with torch.no_grad():
attention_output, key, value, context_outputtn_ctx, inp_norm = \ attention_output, key, value, context_outputtn_ctx, inp_norm = \
self.attention(input, self.attention(input,
......
...@@ -23,7 +23,7 @@ class LinearAllreduce(nn.Module): ...@@ -23,7 +23,7 @@ class LinearAllreduce(nn.Module):
class LinearLayer(nn.Module): class LinearLayer(nn.Module):
def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None): def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
super(LinearLayer, self).__init__() super(LinearLayer, self).__init__()
if weight is not None: if weight is not None:
self.weight = weight self.weight = weight
...@@ -33,10 +33,12 @@ class LinearLayer(nn.Module): ...@@ -33,10 +33,12 @@ class LinearLayer(nn.Module):
torch.empty(weight_shape, torch.empty(weight_shape,
dtype=dtype, dtype=dtype,
device=torch.cuda.current_device())) device=torch.cuda.current_device()))
self.bias = Parameter( self.bias = Parameter(
torch.empty(weight_shape[0], torch.empty(weight_shape[0],
dtype=dtype, dtype=dtype,
device=torch.cuda.current_device())) device=torch.cuda.current_device())) \
if bias is not None else None
def forward(self, input): def forward(self, input):
output = torch.matmul(input, self.weight.transpose(-1, -2)) output = torch.matmul(input, self.weight.transpose(-1, -2))
...@@ -57,7 +59,7 @@ class Normalize(nn.Module): ...@@ -57,7 +59,7 @@ class Normalize(nn.Module):
class EmbeddingLayer(nn.Module): class EmbeddingLayer(nn.Module):
def __init__(self, weight_shape, dtype=torch.float): def __init__(self, weight_shape, dtype=torch.half):
super(EmbeddingLayer, self).__init__() super(EmbeddingLayer, self).__init__()
self.weight = Parameter( self.weight = Parameter(
torch.empty(weight_shape[0], torch.empty(weight_shape[0],
...@@ -67,3 +69,28 @@ class EmbeddingLayer(nn.Module): ...@@ -67,3 +69,28 @@ class EmbeddingLayer(nn.Module):
def forward(self, input): def forward(self, input):
return F.embedding(input, self.weight) return F.embedding(input, self.weight)
class OPTEmbedding(EmbeddingLayer):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weight_shape):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(weight_shape)
def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (torch.cumsum(attention_mask,
dim=1).type_as(attention_mask) *
attention_mask).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return super().forward(positions + self.offset)
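The forward pass above derives position ids from the attention mask rather than from the sequence length, so padded slots collapse onto index -1 before the fixed OPT offset of 2 is added. A small worked example of that computation:

```python
# Worked example of the position computation in OPTEmbedding.forward.
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])              # one padded slot at the end
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask)
             * attention_mask).long() - 1                   # -> [[0, 1, 2, -1]]
past_key_values_length = 0
positions = positions[:, past_key_values_length:]           # drop already-cached steps
lookup_ids = positions + 2                                   # OPT offset -> [[2, 3, 4, 1]]
# lookup_ids are the rows fetched from the learned positional embedding table.
```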
from torch import nn from torch import nn
import deepspeed.ops.transformer as transformer_inference import deepspeed.ops.transformer as transformer_inference
from ..runtime.zero import GatheredParameters from ..runtime.zero import GatheredParameters
from .layers import LinearLayer, Normalize, EmbeddingLayer from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
import torch import torch
import gc import gc
...@@ -11,14 +11,18 @@ def load_model_with_checkpoint(r_module, ...@@ -11,14 +11,18 @@ def load_model_with_checkpoint(r_module,
mp_replace, mp_replace,
ckpt_type, ckpt_type,
weight_quantizer=None, weight_quantizer=None,
rank=0): rank=0,
param_names=None,
transformer_config=None,
megatron_v2=False):
error_msgs = [] error_msgs = []
def transpose(data): def transpose(data):
data = data.contiguous() with torch.no_grad():
data1 = data.transpose(-1, -2).reshape(-1) data = data.contiguous()
data.reshape(-1).copy_(data1) data1 = data.transpose(-1, -2).reshape(-1)
data1 = None data.reshape(-1).copy_(data1)
data1 = None
return data.reshape(data.shape[-1], data.shape[-2]) return data.reshape(data.shape[-1], data.shape[-2])
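The transpose helper now runs under torch.no_grad() so the in-place copy records no autograd history; it still writes the transposed checkpoint weight back into the tensor's own storage instead of keeping a second full copy. A small illustration of the resulting semantics, on a hypothetical tensor:

```python
# The helper is equivalent, in effect, to w.t().contiguous(), but the transposed
# values end up in w's own storage (only the temporary flattened copy is allocated
# and then dropped).
import torch

w = torch.arange(6.).reshape(2, 3)           # shape (2, 3)
with torch.no_grad():
    w_t_flat = w.transpose(-1, -2).reshape(-1)
    w.reshape(-1).copy_(w_t_flat)
w_t = w.reshape(3, 2)                         # shape (3, 2), same storage as w
assert torch.equal(w_t, torch.tensor([[0., 3.], [1., 4.], [2., 5.]]))
```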
def load(module, prefix): def load(module, prefix):
...@@ -87,7 +91,7 @@ def load_model_with_checkpoint(r_module, ...@@ -87,7 +91,7 @@ def load_model_with_checkpoint(r_module,
else: else:
assert tmp_data.dtype != torch.int8, \ assert tmp_data.dtype != torch.int8, \
'''Merging of the checkpoints are not supported when using INT8 checkpoint! \ '''Merging of the checkpoints are not supported when using INT8 checkpoint! \
Please use a as many GPUs as TP-size for the checkpoint''' Please use a as many GPUs as TP-size for the checkpoint'''
all_data = [ all_data = [
sd[j][prefix + sd[j][prefix +
n] if type(sd[j][prefix + n]) is list else n] if type(sd[j][prefix + n]) is list else
...@@ -138,37 +142,146 @@ def load_model_with_checkpoint(r_module, ...@@ -138,37 +142,146 @@ def load_model_with_checkpoint(r_module,
for n, child in module.named_children(): for n, child in module.named_children():
load_parameters(child, prefix + n + '.') load_parameters(child, prefix + n + '.')
else: else:
module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw,
weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight'])))
module.attention.attn_qkvb = mp_replace.copy(
module.attention.attn_qkvb.data,
sd[0][prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow,
weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight'])))
module.attention.attn_ob = mp_replace.copy(
module.attention.attn_ob.data,
sd[0][prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
'bias'])
module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w,
weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight'])))
module.mlp.inter_b = mp_replace.copy(
module.mlp.inter_b.data,
sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(module.mlp.output_w,
weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight'])))
module.mlp.output_b = mp_replace.copy(
module.mlp.output_b.data,
sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias'])
def _transpose(x):
heads = transformer_config.heads // mp_replace.mp_size
attention_head_size = x.shape[-1] // heads
new_x_shape = x.size()[:-1] + (heads, attention_head_size)
x_1 = x.view(*new_x_shape)
(q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
if len(q.shape) > 2:
return torch.cat((q.reshape(q.shape[0],
-1),
k.reshape(q.shape[0],
-1),
v.reshape(q.shape[0],
-1)),
dim=-1).reshape(x.shape)
else:
return torch.cat((q.reshape(-1),
k.reshape(-1),
v.reshape(-1)),
dim=-1).reshape(x.shape)
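_transpose undoes the Megatron-v2 per-head interleaving of q, k, and v so the fused parameter ends up grouped as all-q, then all-k, then all-v. A toy illustration with 2 heads and one element per q/k/v slice (sizes chosen only for readability, assuming transformer_config.heads // mp_replace.mp_size == 2 here):

```python
# Toy illustration of the Megatron-v2 qkv reordering performed by _transpose:
# per-head interleaved [q1 k1 v1 q2 k2 v2] becomes grouped [q1 q2 k1 k2 v1 v2].
import torch

x = torch.tensor([10., 20., 30., 40., 50., 60.])     # [q1, k1, v1, q2, k2, v2]
heads = 2
head_size = x.shape[-1] // heads                      # 3 (one q, k, v slice per head)
x_1 = x.view(heads, head_size)
q, k, v = torch.split(x_1, head_size // 3, dim=-1)    # each of shape (heads, 1)
out = torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1)), dim=-1).reshape(x.shape)
assert torch.equal(out, torch.tensor([10., 40., 20., 50., 30., 60.]))
```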
# This checks if the parameter exists in the checkpoint file and, if so, copies it into the corresponding destination tensor.
# Note that not all parameters are saved in one checkpoint; that's why we always need to check whether they exist!
def maybe_copy(module,
dst_name,
src_name,
qkv=False,
megatron_v2=False,
split_qkv=False):
if src_name in sd[0]:
dst = getattr(module, dst_name)
tmp = sd[0][src_name].cuda()
if len(dst.shape) == 1:
if split_qkv:
dst = mp_replace.qkv_copy(dst, tmp)
else:
dst = mp_replace.copy(dst, tmp)
if qkv and megatron_v2:
dst = torch.nn.parameter.Parameter(
_transpose(dst).contiguous())
else:
if split_qkv:
dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, tmp if weight_quantizer.q_int8 else \
(transpose(tmp).contiguous())))
else:
dst = weight_quantizer.quantize(mp_replace.copy(dst, tmp if weight_quantizer.q_int8 else \
transpose(tmp)))
if qkv and megatron_v2:
scale1 = dst.scale
dst = torch.nn.parameter.Parameter(
_transpose(dst).contiguous())
dst.scale = scale1
setattr(module, dst_name, dst)
# Extension of the maybe_copy function for when q, k, and v are stored in separate parameters.
def maybe_copy_qkv(module, dst_name, src_names, split_qkv=False):
if src_names[0] in sd[0]:
q = sd[0][src_names[0]]
k = sd[0][src_names[1]]
v = sd[0][src_names[2]]
qkv_data = torch.cat((q, k, v), dim=0)
dst = getattr(module, dst_name)
if len(dst.shape) == 1:
if split_qkv:
dst = mp_replace.qkv_copy(dst,
(qkv_data.cuda()).contiguous())
else:
dst = mp_replace.copy(dst, qkv_data.cuda())
else:
if split_qkv:
dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \
((transpose(qkv_data.cuda())).contiguous())))
else:
dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \
transpose(qkv_data.cuda())))
setattr(module, dst_name, dst)
if len(param_names) == 14:
qkv_w, qkv_b, attn_ow, attn_ob, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names
elif len(param_names) < 14:
q_w, k_w, v_w, attn_ow, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, _, split_qkv = param_names
else:
q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names
maybe_copy(module, 'norm_w', prefix + inp_normw)
maybe_copy(module, 'norm_b', prefix + inp_normb)
if len(param_names) == 14:
maybe_copy(module.attention,
'attn_qkvw',
prefix + qkv_w,
qkv=True,
megatron_v2=megatron_v2,
split_qkv=split_qkv)
maybe_copy(module.attention,
'attn_qkvb',
prefix + qkv_b,
qkv=True,
megatron_v2=megatron_v2,
split_qkv=split_qkv)
elif len(param_names) < 14:
maybe_copy_qkv(module.attention,
'attn_qkvw',
[prefix + q_w,
prefix + k_w,
prefix + v_w],
split_qkv=split_qkv)
else:
maybe_copy_qkv(module.attention,
'attn_qkvw',
[prefix + q_w,
prefix + k_w,
prefix + v_w],
split_qkv=split_qkv)
maybe_copy_qkv(module.attention,
'attn_qkvb',
[prefix + q_b,
prefix + k_b,
prefix + v_b],
split_qkv=split_qkv)
maybe_copy(module.attention, 'attn_ow', prefix + attn_ow)
if len(param_names) >= 14:
maybe_copy(module.attention, 'attn_ob', prefix + attn_ob)
maybe_copy(module.mlp, 'attn_nw', prefix + attn_nw)
maybe_copy(module.mlp, 'attn_nb', prefix + attn_nb)
maybe_copy(module.mlp, 'inter_w', prefix + mlp_intw)
maybe_copy(module.mlp, 'inter_b', prefix + mlp_intb)
maybe_copy(module.mlp, 'output_w', prefix + mlp_ow)
maybe_copy(module.mlp, 'output_b', prefix + mlp_ob)
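The three branches above line up with the get_param_names implementations added further down in this diff; the tuple length (including the trailing use_load_prefix and split_qkv flags) tells the loader which copy helper to use. Summarized as a sketch, with counts taken from those implementations:

```python
# Entry counts returned by get_param_names (including the trailing
# use_load_prefix and split_qkv flags), as defined later in this diff:
PARAM_NAME_COUNTS = {
    "HFGPTNEOLayerPolicy": 14,   # fused qkv weight/bias + both layer norms
    "BLOOMLayerPolicy": 14,
    "GPTNEOXLayerPolicy": 14,
    "HFGPTJLayerPolicy": 12,     # separate q/k/v weights, no attention biases
    "HFOPTLayerPolicy": 18,      # separate q/k/v weights and biases
}
# len == 14  -> maybe_copy on the fused attn_qkvw / attn_qkvb
# len <  14  -> maybe_copy_qkv on the three weight names only
# len >  14  -> maybe_copy_qkv on both the weight and the bias triplets
```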
try:
import transformers
OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
except:
OPTLearnedPositionalEmbedding = None
layer_policies = { layer_policies = {
nn.Linear: load, nn.Linear: load,
nn.Embedding: load, nn.Embedding: load,
...@@ -176,7 +289,9 @@ def load_model_with_checkpoint(r_module, ...@@ -176,7 +289,9 @@ def load_model_with_checkpoint(r_module,
EmbeddingLayer: load, EmbeddingLayer: load,
LinearLayer: load, LinearLayer: load,
Normalize: load, Normalize: load,
transformer_inference.DeepSpeedTransformerInference: load_transformer_layer transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
OPTLearnedPositionalEmbedding: load,
OPTEmbedding: load
} }
all_ds_ids = {} all_ds_ids = {}
...@@ -201,14 +316,17 @@ def load_model_with_checkpoint(r_module, ...@@ -201,14 +316,17 @@ def load_model_with_checkpoint(r_module,
ds_shape = child.weight.shape ds_shape = child.weight.shape
else: else:
ds_shape = child.weight.ds_shape ds_shape = child.weight.ds_shape
if child.__class__ is nn.LayerNorm: if child.__class__ is nn.LayerNorm:
child = Normalize(dim=ds_shape[-1], child = Normalize(dim=ds_shape[-1],
dtype=child.weight.dtype, dtype=child.weight.dtype,
eps=child.eps) eps=child.eps)
setattr(module, name, child) setattr(module, name, child)
elif child.__class__ is nn.Linear: elif child.__class__ is nn.Linear:
child = LinearLayer(weight=child.weight, bias=child.bias) child = LinearLayer(weight_shape=child.weight.shape,
bias=child.bias)
setattr(module, name, child)
elif child.__class__ is OPTLearnedPositionalEmbedding:
child = OPTEmbedding(weight_shape=ds_shape)
setattr(module, name, child) setattr(module, name, child)
else: else:
ds_id = None ds_id = None
...@@ -224,7 +342,8 @@ def load_model_with_checkpoint(r_module, ...@@ -224,7 +342,8 @@ def load_model_with_checkpoint(r_module,
else: else:
load_module_recursive( load_module_recursive(
child, child,
prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.', prefix if (level == 0 and ckpt_type == 'pp') and param_names[-2] else \
prefix + name + '.',
level + 1) level + 1)
load_module_recursive(r_module) load_module_recursive(r_module)
...@@ -232,10 +351,11 @@ def load_model_with_checkpoint(r_module, ...@@ -232,10 +351,11 @@ def load_model_with_checkpoint(r_module,
#XXX: hack to tie embedding w. lm_head for BLOOM, need to revist soon #XXX: hack to tie embedding w. lm_head for BLOOM, need to revist soon
embedding_weight = None embedding_weight = None
for n, p in r_module.named_parameters(): for n, p in r_module.named_parameters():
if "word_embeddings." in n: if "word_embeddings." in n or "embed_tokens." in n:
embedding_weight = p embedding_weight = p
assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist"
r_module.lm_head.weight = embedding_weight if embedding_weight is not None:
r_module.lm_head.weight = embedding_weight
for sd_ in sd: for sd_ in sd:
del sd_ del sd_
sd = None sd = None
......
...@@ -36,13 +36,13 @@ class ReplaceWithTensorSlicing: ...@@ -36,13 +36,13 @@ class ReplaceWithTensorSlicing:
return src return src
src_shape = src.shape src_shape = src.shape
dst_shape = dst.shape dst_shape = dst.shape
if self.out_dim == 0: if self.out_dim == 0:
src_split = torch.split(src.data, src_split = torch.split(src.data,
src_shape[self.out_dim] // self.mp_size, src_shape[self.out_dim] // self.mp_size,
dim=0) dim=0)
else: else:
src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1) src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1)
if (len(src_shape) == 2 and len(dst_shape) == 2): if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[self.out_dim] == dst_shape[self.out_dim]: if src_shape[self.out_dim] == dst_shape[self.out_dim]:
return torch.nn.parameter.Parameter(src) return torch.nn.parameter.Parameter(src)
...@@ -54,7 +54,6 @@ class ReplaceWithTensorSlicing: ...@@ -54,7 +54,6 @@ class ReplaceWithTensorSlicing:
qkv_size, qkv_size,
dim=self.out_dim) for src_s in src_split dim=self.out_dim) for src_s in src_split
] ]
weight_split = [ weight_split = [
torch.cat([qkv_s[i] for qkv_s in qkv_split], torch.cat([qkv_s[i] for qkv_s in qkv_split],
axis=self.out_dim) for i in range(len(qkv_split[0])) axis=self.out_dim) for i in range(len(qkv_split[0]))
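For the out_dim != 0 case shown here, the fused qkv tensor is first split into its q, k, and v thirds along the last dimension, each third is then sharded across the tensor-parallel ranks, and the shards belonging to one rank are concatenated back together. A toy illustration with mp_size=2; the qkv_size-per-rank computation is not shown in this hunk and is assumed for the example:

```python
# Toy illustration of how qkv_copy shards a fused qkv weight for tensor parallelism
# when out_dim is the last dimension: split into q/k/v, shard each, regroup per rank.
import torch

mp_size, out_dim = 2, -1
src = torch.arange(12.).view(1, 12)                  # columns: q0..q3 | k0..k3 | v0..v3
q, k, v = torch.split(src, src.shape[-1] // 3, dim=-1)
qkv_size = src.shape[-1] // 3 // mp_size             # q (or k, v) columns per rank
qkv_split = [torch.split(s, qkv_size, dim=out_dim) for s in (q, k, v)]
weight_split = [torch.cat([qkv_s[i] for qkv_s in qkv_split], dim=out_dim)
                for i in range(mp_size)]
# rank 0 gets columns [q0 q1 k0 k1 v0 v1], rank 1 gets [q2 q3 k2 k3 v2 v3]
assert weight_split[0].tolist() == [[0., 1., 4., 5., 8., 9.]]
```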
...@@ -137,8 +136,7 @@ def get_transformer_name(replaced_module): ...@@ -137,8 +136,7 @@ def get_transformer_name(replaced_module):
class GroupQuantizer: class GroupQuantizer:
def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8): def __init__(self, q_int8=True, group_size=1, num_bits=8):
self.num_groups = num_groups
self.group_size = group_size self.group_size = group_size
self.num_bits = num_bits self.num_bits = num_bits
self.q_int8 = q_int8 self.q_int8 = q_int8
...@@ -149,8 +147,9 @@ class GroupQuantizer: ...@@ -149,8 +147,9 @@ class GroupQuantizer:
inputs.scale = torch.empty(1) inputs.scale = torch.empty(1)
return inputs return inputs
q_range = 2**self.num_bits q_range = 2**self.num_bits
num_groups = inputs.shape[0] // self.group_size
inputs = inputs.to(torch.cuda.current_device()) inputs = inputs.to(torch.cuda.current_device())
input_flat = inputs.reshape(self.num_groups, -1).contiguous() input_flat = inputs.reshape(num_groups, -1).contiguous()
input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float() input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float() input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range) scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
...@@ -160,7 +159,7 @@ class GroupQuantizer: ...@@ -160,7 +159,7 @@ class GroupQuantizer:
#print(inputs.shape) #print(inputs.shape)
inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim) inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
input_flat = [ input_flat = [
inputs_split[i].reshape(self.num_groups, inputs_split[i].reshape(num_groups,
-1).contiguous() for i in range(2) -1).contiguous() for i in range(2)
] ]
input_min = [ input_min = [
...@@ -182,7 +181,7 @@ class GroupQuantizer: ...@@ -182,7 +181,7 @@ class GroupQuantizer:
out.scale = torch.cat([scale.squeeze().unsqueeze(0), out.scale = torch.cat([scale.squeeze().unsqueeze(0),
scale1[0], scale1[0],
scale1[1]], scale1[1]],
dim=0).reshape(self.num_groups, dim=0).reshape(num_groups,
-1).contiguous() -1).contiguous()
return out return out
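With group_size=1 the number of groups is now derived from the weight's leading dimension instead of being fixed at 32, so each row gets its own scale. The hunk only shows the scale computation; below is a standalone sketch of symmetric per-group int8 quantization consistent with that formula. The rounding/clamping details are an assumption, since the actual quantization step is outside this hunk.

```python
# Sketch of symmetric per-group int8 quantization consistent with the scale
# formula shown above: scale = max(|min|, |max|) * 2 / q_range, one group per
# slice of `group_size` leading rows. Rounding/clamping details are assumed.
import torch

def group_quantize(weights, num_bits=8, group_size=1):
    q_range = 2 ** num_bits
    num_groups = weights.shape[0] // group_size
    flat = weights.reshape(num_groups, -1).contiguous()
    w_min = flat.min(dim=1, keepdim=True)[0].float()
    w_max = flat.max(dim=1, keepdim=True)[0].float()
    scale = torch.max(w_min.abs(), w_max.abs()) * 2.0 / q_range
    q = (flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1).to(torch.int8)
    return q.reshape(weights.shape), scale

w = torch.randn(16, 64)
q, scale = group_quantize(w)                   # 16 groups of one row each, 16 scales
dequant = q.reshape(16, -1).float() * scale    # approximate reconstruction
```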
...@@ -286,6 +285,11 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True): ...@@ -286,6 +285,11 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
setattr(module, name, new_module) setattr(module, name, new_module)
selected_policy_g = None
megatron_v2_g = False
transformer_config_g = None
def replace_transformer_layer(orig_layer_impl, def replace_transformer_layer(orig_layer_impl,
model, model,
checkpoint_dict, checkpoint_dict,
...@@ -325,6 +329,9 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -325,6 +329,9 @@ def replace_transformer_layer(orig_layer_impl,
inference=False, inference=False,
layer_id=0): layer_id=0):
policy = policy_cls(child, inference=inference) policy = policy_cls(child, inference=inference)
global selected_policy_g
if selected_policy_g is None:
selected_policy_g = policy
if not policy.cuda_graph_supported: if not policy.cuda_graph_supported:
# policy says cuda graph is not supported raise an error if set # policy says cuda graph is not supported raise an error if set
assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable" assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable"
...@@ -340,6 +347,8 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -340,6 +347,8 @@ def replace_transformer_layer(orig_layer_impl,
moe = True moe = True
attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention() attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention()
global megatron_v2_g
megatron_v2_g = megatron_v2
if not moe or config.moe.type == 'standard': if not moe or config.moe.type == 'standard':
mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp() mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
else: else:
...@@ -439,6 +448,8 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -439,6 +448,8 @@ def replace_transformer_layer(orig_layer_impl,
bigscience_bloom=bigscience_bloom, bigscience_bloom=bigscience_bloom,
max_out_tokens=config.max_out_tokens, max_out_tokens=config.max_out_tokens,
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx) scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx)
global transformer_config_g
transformer_config_g = transformer_config
if moe: if moe:
new_module = transformer_inference.DeepSpeedMoEInference( new_module = transformer_inference.DeepSpeedMoEInference(
...@@ -553,6 +564,10 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -553,6 +564,10 @@ def replace_transformer_layer(orig_layer_impl,
if qkvw.is_meta or qkvw.numel() == 0 or qkvw.is_meta: if qkvw.is_meta or qkvw.numel() == 0 or qkvw.is_meta:
if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel(): if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
if qkvb is None:
attn_block.attn_qkvb = None
if dense_b is None:
attn_block.attn_ob = None
pass pass
else: else:
with GatheredParameters([ with GatheredParameters([
...@@ -911,7 +926,9 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -911,7 +926,9 @@ def replace_transformer_layer(orig_layer_impl,
mp_replace, mp_replace,
ckpt_type, ckpt_type,
quantizer, quantizer,
) param_names=selected_policy_g.get_param_names(),
transformer_config=transformer_config_g,
megatron_v2=megatron_v2_g)
pbar.update(1) pbar.update(1)
else: else:
import gc import gc
...@@ -935,12 +952,16 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -935,12 +952,16 @@ def replace_transformer_layer(orig_layer_impl,
torch.load(ckpt_file, torch.load(ckpt_file,
map_location='cpu') for ckpt_file in ckpt_files map_location='cpu') for ckpt_file in ckpt_files
] ]
load_model_with_checkpoint(replaced_module, load_model_with_checkpoint(
sds, replaced_module,
mp_replace, sds,
ckpt_type, mp_replace,
quantizer, ckpt_type,
int(rank % tp_split_size)) quantizer,
int(rank % tp_split_size),
param_names=selected_policy_g.get_param_names(),
transformer_config=transformer_config_g,
megatron_v2=megatron_v2_g)
sds = [None for _ in sds] sds = [None for _ in sds]
gc.collect() gc.collect()
...@@ -955,12 +976,16 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -955,12 +976,16 @@ def replace_transformer_layer(orig_layer_impl,
checkpoint["non_tp"][i] checkpoint["non_tp"][i]
) if base_dir1 else checkpoint["non_tp"][i] ) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')] sds = [torch.load(ckpt_file, map_location='cpu')]
load_model_with_checkpoint(replaced_module, load_model_with_checkpoint(
sds, replaced_module,
mp_replace, sds,
ckpt_type, mp_replace,
quantizer, ckpt_type,
int(rank % tp_split_size)) quantizer,
int(rank % tp_split_size),
param_names=selected_policy_g.get_param_names(),
transformer_config=transformer_config_g,
megatron_v2=megatron_v2_g)
sds = [None for _ in sds] sds = [None for _ in sds]
gc.collect() gc.collect()
print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec") print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
...@@ -986,6 +1011,7 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -986,6 +1011,7 @@ def replace_transformer_layer(orig_layer_impl,
non_tp_ckpt_name = f'non-tp.pt' non_tp_ckpt_name = f'non-tp.pt'
ckpt_files = [non_tp_ckpt_name] ckpt_files = [non_tp_ckpt_name]
os.makedirs(config.save_mp_checkpoint_path, exist_ok=True) os.makedirs(config.save_mp_checkpoint_path, exist_ok=True)
if not dist.is_initialized() or dist.get_rank() == 0: if not dist.is_initialized() or dist.get_rank() == 0:
print("Saving tp-sharded checkpoints") print("Saving tp-sharded checkpoints")
torch.save( torch.save(
...@@ -996,7 +1022,7 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -996,7 +1022,7 @@ def replace_transformer_layer(orig_layer_impl,
if transformer_name not in k if transformer_name not in k
}), }),
f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}') f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
config = json.dumps({ new_config = json.dumps({
'type': 'type':
ckpt_name, ckpt_name,
'base_dir': 'base_dir':
...@@ -1020,7 +1046,7 @@ def replace_transformer_layer(orig_layer_impl, ...@@ -1020,7 +1046,7 @@ def replace_transformer_layer(orig_layer_impl,
}) })
with open(f"{config.save_mp_checkpoint_path}/ds-inference_config.json", with open(f"{config.save_mp_checkpoint_path}/ds-inference_config.json",
"w") as cfg: "w") as cfg:
cfg.write(config) cfg.write(new_config)
rep_sd = replaced_module.state_dict() rep_sd = replaced_module.state_dict()
for n, p in replaced_module.named_parameters(): for n, p in replaced_module.named_parameters():
......
...@@ -92,15 +92,19 @@ class TransformerPolicy(DSPolicy): ...@@ -92,15 +92,19 @@ class TransformerPolicy(DSPolicy):
hf_model_config = None hf_model_config = None
def __init__( def __init__(
self, self,
inference=True, inference=True,
linear_layer=True, linear_layer=True,
scale_attention=True, scale_attention=True,
megatron_v2=False, megatron_v2=False,
# the type of activation function used in MLP # the type of activation function used in MLP
mlp_act_func_type=ActivationFuncType.GELU, mlp_act_func_type=ActivationFuncType.GELU,
# applies layer norm before attention if `pre_attn_norm` is set to True # applies layer norm before attention if `pre_attn_norm` is set to True
pre_attn_norm=True): pre_attn_norm=True,
# this flag shows whether or not using prefix in loading the checkpoint
use_load_prefix=False,
# whether or not the qkv is stored in the split-format
split_qkv=True):
super().__init__() super().__init__()
self.inference = inference self.inference = inference
self.linear_layer = linear_layer self.linear_layer = linear_layer
...@@ -108,6 +112,7 @@ class TransformerPolicy(DSPolicy): ...@@ -108,6 +112,7 @@ class TransformerPolicy(DSPolicy):
self.is_megatron_v2 = megatron_v2 self.is_megatron_v2 = megatron_v2
self.mlp_act_func_type = mlp_act_func_type self.mlp_act_func_type = mlp_act_func_type
self.pre_attn_norm = pre_attn_norm self.pre_attn_norm = pre_attn_norm
self.load_prefix = False
def attention(self): def attention(self):
""" """
...@@ -139,6 +144,31 @@ class TransformerPolicy(DSPolicy): ...@@ -139,6 +144,31 @@ class TransformerPolicy(DSPolicy):
""" """
raise NotImplementedError raise NotImplementedError
def get_param_names(self):
"""
Returns all the transformer parameter names to
be loaded from checkpoint files. The order of
the names is as follows:
1. Attention weights and biases;
2. MLP weights and biases;
3. LayerNorm weights and biases;
In addition to the parameter names, we require two
more parameters to help read the data correctly
from the checkpoint and split the qkv heads in the
right order:
1. `use_load_prefix` (Default: False): this specifies
whether we need to use the name of the first abstraction
layer of the model when searching for the parameter's name
in a checkpoint file. For more information on how this
is used, please see
https://github.com/microsoft/DeepSpeed/blob/fix-ckpt-loading/deepspeed/module_inject/load_checkpoint.py#L341
2. `split_qkv` (Default: True): we use this flag when splitting
the qkv parameter into heads. If it is False, it means the heads
of q, k, and v are stored together and need to be split by the
DeepSpeed-Inference API.
"""
raise NotImplementedError
class HFBertLayerPolicy(TransformerPolicy): class HFBertLayerPolicy(TransformerPolicy):
def __init__(self, client_module, inference=False): def __init__(self, client_module, inference=False):
...@@ -294,6 +324,22 @@ class HFGPTNEOLayerPolicy(TransformerPolicy): ...@@ -294,6 +324,22 @@ class HFGPTNEOLayerPolicy(TransformerPolicy):
self.client_module.ln_1.weight, \ self.client_module.ln_1.weight, \
self.client_module.ln_1.bias self.client_module.ln_1.bias
def get_param_names(self):
return 'attention.query_key_value.weight', \
'attention.query_key_value.bias', \
'attention.dense.weight', \
'attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.use_load_prefix, \
self.split_qkv
class HFGPTJLayerPolicy(TransformerPolicy): class HFGPTJLayerPolicy(TransformerPolicy):
_orig_layer_class = None _orig_layer_class = None
...@@ -339,6 +385,20 @@ class HFGPTJLayerPolicy(TransformerPolicy): ...@@ -339,6 +385,20 @@ class HFGPTJLayerPolicy(TransformerPolicy):
self.client_module.ln_1.weight, \ self.client_module.ln_1.weight, \
self.client_module.ln_1.bias self.client_module.ln_1.bias
def get_param_names(self):
return 'attn.q_proj.weight', \
'attn.k_proj.weight', \
'attn.v_proj.weight', \
'attn.out_proj.weight', \
'mlp.fc_in.weight', \
'mlp.fc_in.bias', \
'mlp.fc_out.weight', \
'mlp.fc_out.bias', \
'ln_1.weight', \
'ln_1.bias', \
self.use_load_prefix, \
self.split_qkv
class MegatronLayerPolicy(TransformerPolicy): class MegatronLayerPolicy(TransformerPolicy):
_orig_layer_class = None _orig_layer_class = None
...@@ -463,7 +523,11 @@ class HFGPT2LayerPolicy(TransformerPolicy): ...@@ -463,7 +523,11 @@ class HFGPT2LayerPolicy(TransformerPolicy):
class BLOOMLayerPolicy(TransformerPolicy): class BLOOMLayerPolicy(TransformerPolicy):
_orig_layer_class = None _orig_layer_class = None
def __init__(self, client_module, inference=True): def __init__(self,
client_module,
inference=True,
use_load_prefix=True,
split_qkv=False):
super().__init__(inference, linear_layer=True) super().__init__(inference, linear_layer=True)
self.client_module = client_module self.client_module = client_module
try: try:
...@@ -501,12 +565,28 @@ class BLOOMLayerPolicy(TransformerPolicy): ...@@ -501,12 +565,28 @@ class BLOOMLayerPolicy(TransformerPolicy):
self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias self.client_module.input_layernorm.bias
def get_param_names(self):
return 'self_attention.query_key_value.weight', \
'self_attention.query_key_value.bias', \
'self_attention.dense.weight', \
'self_attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.use_load_prefix, \
self.split_qkv
class GPTNEOXLayerPolicy(TransformerPolicy): class GPTNEOXLayerPolicy(TransformerPolicy):
_orig_layer_class = None _orig_layer_class = None
version = 0 version = 0
def __init__(self, client_module, inference=True, megatron_v2=True): def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False):
super().__init__(inference, megatron_v2=megatron_v2) super().__init__(inference, megatron_v2=megatron_v2)
self.client_module = client_module self.client_module = client_module
if GPTNEOXLayerPolicy._orig_layer_class is None: if GPTNEOXLayerPolicy._orig_layer_class is None:
...@@ -555,11 +635,27 @@ class GPTNEOXLayerPolicy(TransformerPolicy): ...@@ -555,11 +635,27 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
self.client_module.input_layernorm.weight, \ self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias self.client_module.input_layernorm.bias
def get_param_names(self):
return 'attention.query_key_value.weight', \
'attention.query_key_value.bias', \
'attention.dense.weight', \
'attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.use_load_prefix, \
self.split_qkv
class HFOPTLayerPolicy(TransformerPolicy): class HFOPTLayerPolicy(TransformerPolicy):
_orig_layer_class = None _orig_layer_class = None
def __init__(self, client_module, inference=True): def __init__(self, client_module, inference=True, use_load_prefix=True):
super().__init__(inference, super().__init__(inference,
linear_layer=True, linear_layer=True,
mlp_act_func_type=ActivationFuncType.ReLU, mlp_act_func_type=ActivationFuncType.ReLU,
...@@ -568,9 +664,9 @@ class HFOPTLayerPolicy(TransformerPolicy): ...@@ -568,9 +664,9 @@ class HFOPTLayerPolicy(TransformerPolicy):
try: try:
import transformers import transformers
HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
if isinstance(DSPolicy.hf_model_config, if isinstance(TransformerPolicy.hf_model_config,
transformers.models.opt.configuration_opt.OPTConfig): transformers.models.opt.configuration_opt.OPTConfig):
self.pre_attn_norm = self.hf_model_config.do_layer_norm_before self.pre_attn_norm = TransformerPolicy.hf_model_config.do_layer_norm_before
except: except:
HFOPTLayerPolicy._orig_layer_class = None HFOPTLayerPolicy._orig_layer_class = None
...@@ -612,6 +708,26 @@ class HFOPTLayerPolicy(TransformerPolicy): ...@@ -612,6 +708,26 @@ class HFOPTLayerPolicy(TransformerPolicy):
self.client_module.self_attn_layer_norm.weight, \ self.client_module.self_attn_layer_norm.weight, \
self.client_module.self_attn_layer_norm.bias self.client_module.self_attn_layer_norm.bias
def get_param_names(self):
return 'self_attn.q_proj.weight', \
'self_attn.q_proj.bias', \
'self_attn.k_proj.weight', \
'self_attn.k_proj.bias', \
'self_attn.v_proj.weight', \
'self_attn.v_proj.bias', \
'self_attn.out_proj.weight', \
'self_attn.out_proj.bias', \
'fc1.weight', \
'fc1.bias', \
'fc2.weight', \
'fc2.bias', \
'self_attn_layer_norm.weight', \
'self_attn_layer_norm.bias', \
'final_layer_norm.weight', \
'final_layer_norm.bias', \
self.use_load_prefix, \
self.split_qkv
# transformer-based policies # transformer-based policies
replace_policies = [ replace_policies = [
......
...@@ -331,7 +331,6 @@ class DeepSpeedSelfAttentionFunction(Function): ...@@ -331,7 +331,6 @@ class DeepSpeedSelfAttentionFunction(Function):
False, False,
attn_ow.scale, attn_ow.scale,
config.q_int8) config.q_int8)
return output, key_layer, value_layer, context_layer, qkv_out[-1] return output, key_layer, value_layer, context_layer, qkv_out[-1]
def selfAttention_int8(): def selfAttention_int8():
...@@ -394,7 +393,7 @@ class DeepSpeedSelfAttention(nn.Module): ...@@ -394,7 +393,7 @@ class DeepSpeedSelfAttention(nn.Module):
data_type_fp = torch.half if config.fp16 else torch.float data_type_fp = torch.half if config.fp16 else torch.float
self.config.layer_id = DeepSpeedSelfAttention.num_layers self.config.layer_id = DeepSpeedSelfAttention.num_layers
DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1 DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
qkv_size_per_partition, qkv_size_per_partition,
......
...@@ -41,10 +41,13 @@ class DeepSpeedMLPFunction(Function): ...@@ -41,10 +41,13 @@ class DeepSpeedMLPFunction(Function):
if attn_nw is None: if attn_nw is None:
output = fused_gemm_gelu(residual_norm, output = fused_gemm_gelu(residual_norm,
inter_w, inter_w,
inter_w.scale,
inter_b, inter_b,
output_w, output_w,
output_w.scale,
config.epsilon, config.epsilon,
config.pre_layer_norm, config.pre_layer_norm,
config.q_int8,
False) False)
else: else:
output, residual_add = mlp_gemm_func(input, output, residual_add = mlp_gemm_func(input,
...@@ -96,7 +99,7 @@ class DeepSpeedMLP(nn.Module): ...@@ -96,7 +99,7 @@ class DeepSpeedMLP(nn.Module):
self.config = config self.config = config
data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
data_type_fp = torch.half if config.fp16 else torch.float data_type_fp = torch.half if config.fp16 else torch.float
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu' device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type_fp, dtype=data_type_fp,
device=device), device=device),
......