Unverified commit 35b350b2 authored by Reza Yazdani, committed by GitHub

Fix quantized-inference & Add generic support of checkpoint loading (#2547)

* fix checkpoint loading when it is a dictionary

* fix some issues with saving ckpt & int8 inference

* fix quantized-inference & add generic support of checkpoint loading

* remove int8 hard-coded flag

* fix mlp return tensors

* fix several issues when loading checkpoints of GPT-J, GPT-NEOX, and OPT with different TP sizes

* add more comments & descriptions for the checkpoint-loading module

Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Parent b8416282
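At a high level, the generic checkpoint loading added here works by having each injection policy report the checkpoint names of its transformer parameters; the loader then copies whichever of those tensors are actually present in a shard into the fused inference module. A minimal sketch of the idea (all class and parameter names below are illustrative, not the DeepSpeed API):

import torch

class ToyPolicy:
    # Order mirrors the convention documented in get_param_names():
    # attention weights/biases, MLP weights/biases, layernorm weights/biases,
    # then the use_load_prefix and split_qkv flags.
    def get_param_names(self):
        return ('attn.qkv.weight', 'attn.qkv.bias',
                'attn.out.weight', 'attn.out.bias',
                'mlp.fc_in.weight', 'mlp.fc_in.bias',
                'mlp.fc_out.weight', 'mlp.fc_out.bias',
                'ln_in.weight', 'ln_in.bias',
                'ln_post.weight', 'ln_post.bias',
                False,   # use_load_prefix
                True)    # split_qkv

def maybe_copy(dst, sd, dst_name, src_name):
    # Not every shard contains every parameter, so copy only when present.
    if src_name in sd:
        dst[dst_name].copy_(sd[src_name])

dst = {'attn_qkvw': torch.zeros(6, 2), 'norm_w': torch.zeros(2)}
sd = {'layer.0.attn.qkv.weight': torch.randn(6, 2),
      'layer.0.ln_in.weight': torch.ones(2)}
names = ToyPolicy().get_param_names()
maybe_copy(dst, sd, 'attn_qkvw', 'layer.0.' + names[0])
maybe_copy(dst, sd, 'norm_w', 'layer.0.' + names[8])
print(dst['norm_w'])   # tensor([1., 1.])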
......@@ -763,11 +763,18 @@ void quantized_gemm(void* output,
at::Tensor& weight,
at::Tensor& qscale,
int groups,
int bsz)
int bsz,
int hidden_size)
{
T* weight16 = (T*)Context::Instance().GetWorkSpace() +
12 * Context::Instance().GetMaxTokenLenght() * weight.size(1);
T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
// auto options = at::TensorOptions()
// .dtype(at::kHalf)
// .layout(at::kStrided)
// .device(at::kCUDA)
// .requires_grad(false);
// auto tmp = torch::empty(weight.sizes(), options);
// T* weight16 = (T*)tmp.data_ptr();
launch_dequantize(weight16,
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
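The change above sizes the dequantization workspace by hidden_size * bsz (now passed in explicitly) instead of the context's max token length. Functionally, quantized_gemm dequantizes the INT8 weight group by group and then runs a regular GEMM; a rough Python sketch of that semantics (the per-group layout here is an assumption for illustration, not the exact CUDA kernel behavior):

import torch

def quantized_gemm_ref(x, w_int8, scales, groups):
    # w_int8: [out, in] INT8 weight; scales: one scale per block of rows
    # (assumed grouping, for illustration only).
    rows = w_int8.shape[0] // groups
    w = w_int8.float().view(groups, rows, -1) * scales.view(groups, 1, 1)
    return x @ w.view(w_int8.shape).t()

x = torch.randn(4, 8)                                    # [bsz, hidden]
w = torch.randint(-127, 128, (16, 8), dtype=torch.int8)  # quantized weight
s = torch.rand(4, 1)                                     # 4 groups of 4 rows
print(quantized_gemm_ref(x, w, s, groups=4).shape)       # torch.Size([4, 16])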
......@@ -814,7 +821,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
ds_layer_norm_internal<T>(workspace, input, gamma, beta, epsilon);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(
output.data_ptr(), workspace, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
......@@ -1202,15 +1210,19 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
int out_size = q_int8 ? weight.size(0) : weight.size(1);
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
auto output = at::from_blob(workspace, {input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(
output.data_ptr(), (T*)input.data_ptr(), weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(output.data_ptr(),
(T*)input.data_ptr(),
weight,
q_scale,
q_scale.size(0),
bsz,
input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
......@@ -1293,9 +1305,9 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
} else {
ds_layer_norm_internal(inp_norm, input, gamma, beta, epsilon);
}
if (q_int8) {
quantized_gemm<T>(intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz);
quantized_gemm<T>(
intermediate, inp_norm, weight, q_scale, q_scale.size(0), bsz, input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
......@@ -1331,9 +1343,15 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
bsz,
Context::Instance().GetCurrentStream());
}
if (q_int8) {
quantized_gemm<T>(
output.data_ptr(), intermediate, weight1, q_scale1, q_scale1.size(0), bsz);
quantized_gemm<T>(output.data_ptr(),
intermediate,
weight1,
q_scale1,
q_scale1.size(0),
bsz,
input.size(2));
} else {
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
......@@ -1449,64 +1467,95 @@ std::vector<at::Tensor> ds_mlp_gemm_int8(at::Tensor& input,
template <typename T>
at::Tensor fused_gemm_gelu(at::Tensor& input,
at::Tensor& weight,
at::Tensor& weight_scale,
at::Tensor& bias,
at::Tensor& weight_out,
at::Tensor& weight_out_scale,
const float epsilon,
bool preLayerNorm,
bool q_int8,
bool async_op)
{
auto input_cont = input.contiguous();
auto options = at::TensorOptions()
.dtype(input_cont.options().dtype())
.dtype(input.options().dtype())
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto intermediate =
at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options);
auto output = at::empty({input_cont.size(0), input_cont.size(1), weight_out.size(1)}, options);
int bsz = input_cont.size(0) * input_cont.size(1);
int intm_dim = q_int8 ? weight.size(0) : weight.size(1);
// auto output = at::from_blob((T*)Context::Instance().GetWorkSpace() + torch::numel(input),
// {input.size(0), input.size(1), out_size},
// options);
// T* intermediate = (T*)input.data_ptr() + torch::numel(input);
auto intermediate = at::empty({input.size(0), input.size(1), intm_dim}, options);
int bsz = input.size(0) * input.size(1);
float alpha = (T)1.0;
float gemm_beta = (T)0.0;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight.size(1),
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input_cont.data_ptr(),
(T*)intermediate.data_ptr(),
if (q_int8) {
quantized_gemm<T>(intermediate.data_ptr(),
(T*)input.data_ptr(),
weight,
weight_scale,
weight_scale.size(0),
bsz,
input.size(2));
} else {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
intm_dim,
bsz,
input.size(2),
&alpha,
&gemm_beta,
(T*)weight.data_ptr(),
(T*)input.data_ptr(),
(T*)intermediate.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
launch_bias_gelu((T*)intermediate.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
intm_dim,
bsz,
Context::Instance().GetCurrentStream());
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
weight_out.size(1),
bsz,
intermediate.size(2),
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
int out_size = q_int8 ? weight_out.size(0) : weight_out.size(1);
auto output = at::empty({input.size(0), input.size(1), out_size}, options);
if (q_int8) {
quantized_gemm<T>(output.data_ptr(),
(T*)intermediate.data_ptr(),
weight_out,
weight_out_scale,
weight_out_scale.size(0),
bsz,
input.size(2));
} else {
cublas_gemm_ex(Context::Instance().GetCublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
out_size,
bsz,
intm_dim,
&alpha,
&gemm_beta,
(T*)weight_out.data_ptr(),
(T*)intermediate.data_ptr(),
(T*)output.data_ptr(),
#ifdef __HIP_PLATFORM_HCC__
rocblas_gemm_algo_standard);
rocblas_gemm_algo_standard);
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
}
// cudaEventRecord(Context::Instance().GetCompEvent(2),
// Context::Instance().GetCurrentStream(true));
return output;
......
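For reference, the fused_gemm_gelu path above computes GEMM -> bias-GELU -> GEMM, now with an optional INT8 branch that dequantizes each weight with its scale before the matmul. A functional sketch in plain PyTorch (not the fused kernel; per-tensor scales are assumed for brevity):

import torch
import torch.nn.functional as F

def fused_gemm_gelu_ref(x, w1, b1, w2, w1_scale=None, w2_scale=None):
    if w1.dtype == torch.int8:              # dequantize-then-matmul path
        w1 = w1.float() * w1_scale
        w2 = w2.float() * w2_scale
    intermediate = F.gelu(x @ w1 + b1)      # first GEMM + bias-GELU
    return intermediate @ w2                # second GEMM

x = torch.randn(2, 4, 8)                    # [batch, seq, hidden]
out = fused_gemm_gelu_ref(x,
                          torch.randn(8, 32), torch.randn(32),
                          torch.randn(32, 8))
print(out.shape)                            # torch.Size([2, 4, 8])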
......@@ -68,7 +68,7 @@ class DeepSpeedTransformerInference(nn.Module):
merge_count,
mlp_extra_grouping)
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
self.norm_w = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type,
device=device),
......@@ -131,7 +131,6 @@ class DeepSpeedTransformerInference(nn.Module):
if (self.config.fp16 or self.config.q_int8) \
and input.dtype == torch.float:
input = input.half()
with torch.no_grad():
attention_output, key, value, context_outputtn_ctx, inp_norm = \
self.attention(input,
......
......@@ -23,7 +23,7 @@ class LinearAllreduce(nn.Module):
class LinearLayer(nn.Module):
def __init__(self, weight_shape=None, dtype=None, weight=None, bias=None):
def __init__(self, weight_shape=None, dtype=torch.half, weight=None, bias=None):
super(LinearLayer, self).__init__()
if weight is not None:
self.weight = weight
......@@ -33,10 +33,12 @@ class LinearLayer(nn.Module):
torch.empty(weight_shape,
dtype=dtype,
device=torch.cuda.current_device()))
self.bias = Parameter(
torch.empty(weight_shape[0],
dtype=dtype,
device=torch.cuda.current_device()))
device=torch.cuda.current_device())) \
if bias is not None else None
def forward(self, input):
output = torch.matmul(input, self.weight.transpose(-1, -2))
......@@ -57,7 +59,7 @@ class Normalize(nn.Module):
class EmbeddingLayer(nn.Module):
def __init__(self, weight_shape, dtype=torch.float):
def __init__(self, weight_shape, dtype=torch.half):
super(EmbeddingLayer, self).__init__()
self.weight = Parameter(
torch.empty(weight_shape[0],
......@@ -67,3 +69,28 @@ class EmbeddingLayer(nn.Module):
def forward(self, input):
return F.embedding(input, self.weight)
class OPTEmbedding(EmbeddingLayer):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, weight_shape):
# OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(weight_shape)
def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()
# create positions depending on attention_mask
positions = (torch.cumsum(attention_mask,
dim=1).type_as(attention_mask) *
attention_mask).long() - 1
# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
return super().forward(positions + self.offset)
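A small worked example of the position computation above: positions count non-padding tokens cumulatively, padding slots end up at -1, and OPT's fixed offset of 2 is added before the embedding lookup.

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])     # two left-padding tokens
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask)
             * attention_mask).long() - 1
print(positions)       # tensor([[-1, -1,  0,  1,  2]])
print(positions + 2)   # ids used for the embedding lookup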
from torch import nn
import deepspeed.ops.transformer as transformer_inference
from ..runtime.zero import GatheredParameters
from .layers import LinearLayer, Normalize, EmbeddingLayer
from .layers import LinearLayer, Normalize, EmbeddingLayer, OPTEmbedding
import torch
import gc
......@@ -11,14 +11,18 @@ def load_model_with_checkpoint(r_module,
mp_replace,
ckpt_type,
weight_quantizer=None,
rank=0):
rank=0,
param_names=None,
transformer_config=None,
megatron_v2=False):
error_msgs = []
def transpose(data):
data = data.contiguous()
data1 = data.transpose(-1, -2).reshape(-1)
data.reshape(-1).copy_(data1)
data1 = None
with torch.no_grad():
data = data.contiguous()
data1 = data.transpose(-1, -2).reshape(-1)
data.reshape(-1).copy_(data1)
data1 = None
return data.reshape(data.shape[-1], data.shape[-2])
def load(module, prefix):
......@@ -87,7 +91,7 @@ def load_model_with_checkpoint(r_module,
else:
assert tmp_data.dtype != torch.int8, \
'''Merging of the checkpoints is not supported when using an INT8 checkpoint! \
Please use as many GPUs as the TP size of the checkpoint'''
Please use as many GPUs as the TP size of the checkpoint'''
all_data = [
sd[j][prefix +
n] if type(sd[j][prefix + n]) is list else
......@@ -138,37 +142,146 @@ def load_model_with_checkpoint(r_module,
for n, child in module.named_children():
load_parameters(child, prefix + n + '.')
else:
module.norm_w.data.copy_(sd[0][prefix + 'input_layernorm.' + 'weight'])
module.norm_b.data.copy_(sd[0][prefix + 'input_layernorm.' + 'bias'])
module.attention.attn_qkvw = mp_replace.copy(module.attention.attn_qkvw,
weight_quantizer.quantize(sd[0][prefix + 'self_attention.query_key_value.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.query_key_value.' + 'weight'])))
module.attention.attn_qkvb = mp_replace.copy(
module.attention.attn_qkvb.data,
sd[0][prefix + 'self_attention.query_key_value.' + 'bias'])
module.attention.attn_ow = mp_replace.copy(module.attention.attn_ow,
weight_quantizer.quantize(sd[0][prefix + 'self_attention.dense.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'self_attention.dense.' + 'weight'])))
module.attention.attn_ob = mp_replace.copy(
module.attention.attn_ob.data,
sd[0][prefix + 'self_attention.dense.' + 'bias'])
module.mlp.attn_nw.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
'weight'])
module.mlp.attn_nb.data.copy_(sd[0][prefix + 'post_attention_layernorm.' +
'bias'])
module.mlp.inter_w = mp_replace.copy(module.mlp.inter_w,
weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_h_to_4h.' + 'weight'])))
module.mlp.inter_b = mp_replace.copy(
module.mlp.inter_b.data,
sd[0][prefix + 'mlp.dense_h_to_4h.' + 'bias'])
module.mlp.output_w = mp_replace.copy(module.mlp.output_w,
weight_quantizer.quantize(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight']) if weight_quantizer.q_int8 else \
weight_quantizer.quantize(transpose(sd[0][prefix + 'mlp.dense_4h_to_h.' + 'weight'])))
module.mlp.output_b = mp_replace.copy(
module.mlp.output_b.data,
sd[0][prefix + 'mlp.dense_4h_to_h.' + 'bias'])
def _transpose(x):
heads = transformer_config.heads // mp_replace.mp_size
attention_head_size = x.shape[-1] // heads
new_x_shape = x.size()[:-1] + (heads, attention_head_size)
x_1 = x.view(*new_x_shape)
(q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
if len(q.shape) > 2:
return torch.cat((q.reshape(q.shape[0],
-1),
k.reshape(q.shape[0],
-1),
v.reshape(q.shape[0],
-1)),
dim=-1).reshape(x.shape)
else:
return torch.cat((q.reshape(-1),
k.reshape(-1),
v.reshape(-1)),
dim=-1).reshape(x.shape)
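_transpose regroups a fused qkv tensor from the per-head-interleaved layout of Megatron-v2 checkpoints into the q/k/v-contiguous layout used here; a tiny illustration with 2 heads and head size 1 (bias-shaped input, for simplicity):

import torch

x = torch.tensor([1., 10., 100., 2., 20., 200.])     # q1,k1,v1,q2,k2,v2
heads = 2
x_1 = x.view(heads, x.shape[-1] // heads)             # [2, 3]: one row per head
q, k, v = torch.split(x_1, x_1.shape[-1] // 3, dim=-1)
print(torch.cat((q.reshape(-1), k.reshape(-1), v.reshape(-1))))
# tensor([  1.,   2.,  10.,  20., 100., 200.])  -> q1,q2,k1,k2,v1,v2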
# This checks if the parameter exists in the checkpoint file and maybe copies it into the corresponding destination tensor.
# Note that not all parameters are saved in one checkpoint; that's why we always need to check whether they exist!
def maybe_copy(module,
dst_name,
src_name,
qkv=False,
megatron_v2=False,
split_qkv=False):
if src_name in sd[0]:
dst = getattr(module, dst_name)
tmp = sd[0][src_name].cuda()
if len(dst.shape) == 1:
if split_qkv:
dst = mp_replace.qkv_copy(dst, tmp)
else:
dst = mp_replace.copy(dst, tmp)
if qkv and megatron_v2:
dst = torch.nn.parameter.Parameter(
_transpose(dst).contiguous())
else:
if split_qkv:
dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, tmp if weight_quantizer.q_int8 else \
(transpose(tmp).contiguous())))
else:
dst = weight_quantizer.quantize(mp_replace.copy(dst, tmp if weight_quantizer.q_int8 else \
transpose(tmp)))
if qkv and megatron_v2:
scale1 = dst.scale
dst = torch.nn.parameter.Parameter(
_transpose(dst).contiguous())
dst.scale = scale1
setattr(module, dst_name, dst)
# Extending the maybe_copy function for when the q, k, and v are in separate parameters!
def maybe_copy_qkv(module, dst_name, src_names, split_qkv=False):
if src_names[0] in sd[0]:
q = sd[0][src_names[0]]
k = sd[0][src_names[1]]
v = sd[0][src_names[2]]
qkv_data = torch.cat((q, k, v), dim=0)
dst = getattr(module, dst_name)
if len(dst.shape) == 1:
if split_qkv:
dst = mp_replace.qkv_copy(dst,
(qkv_data.cuda()).contiguous())
else:
dst = mp_replace.copy(dst, qkv_data.cuda())
else:
if split_qkv:
dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \
((transpose(qkv_data.cuda())).contiguous())))
else:
dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \
transpose(qkv_data.cuda())))
setattr(module, dst_name, dst)
if len(param_names) == 14:
qkv_w, qkv_b, attn_ow, attn_ob, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names
elif len(param_names) < 14:
q_w, k_w, v_w, attn_ow, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, _, split_qkv = param_names
else:
q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names
maybe_copy(module, 'norm_w', prefix + inp_normw)
maybe_copy(module, 'norm_b', prefix + inp_normb)
if len(param_names) == 14:
maybe_copy(module.attention,
'attn_qkvw',
prefix + qkv_w,
qkv=True,
megatron_v2=megatron_v2,
split_qkv=split_qkv)
maybe_copy(module.attention,
'attn_qkvb',
prefix + qkv_b,
qkv=True,
megatron_v2=megatron_v2,
split_qkv=split_qkv)
elif len(param_names) < 14:
maybe_copy_qkv(module.attention,
'attn_qkvw',
[prefix + q_w,
prefix + k_w,
prefix + v_w],
split_qkv=split_qkv)
else:
maybe_copy_qkv(module.attention,
'attn_qkvw',
[prefix + q_w,
prefix + k_w,
prefix + v_w],
split_qkv=split_qkv)
maybe_copy_qkv(module.attention,
'attn_qkvb',
[prefix + q_b,
prefix + k_b,
prefix + v_b],
split_qkv=split_qkv)
maybe_copy(module.attention, 'attn_ow', prefix + attn_ow)
if len(param_names) >= 14:
maybe_copy(module.attention, 'attn_ob', prefix + attn_ob)
maybe_copy(module.mlp, 'attn_nw', prefix + attn_nw)
maybe_copy(module.mlp, 'attn_nb', prefix + attn_nb)
maybe_copy(module.mlp, 'inter_w', prefix + mlp_intw)
maybe_copy(module.mlp, 'inter_b', prefix + mlp_intb)
maybe_copy(module.mlp, 'output_w', prefix + mlp_ow)
maybe_copy(module.mlp, 'output_b', prefix + mlp_ob)
try:
import transformers
OPTLearnedPositionalEmbedding = transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding
except:
OPTLearnedPositionalEmbedding = None
layer_policies = {
nn.Linear: load,
nn.Embedding: load,
......@@ -176,7 +289,9 @@ def load_model_with_checkpoint(r_module,
EmbeddingLayer: load,
LinearLayer: load,
Normalize: load,
transformer_inference.DeepSpeedTransformerInference: load_transformer_layer
transformer_inference.DeepSpeedTransformerInference: load_transformer_layer,
OPTLearnedPositionalEmbedding: load,
OPTEmbedding: load
}
all_ds_ids = {}
......@@ -201,14 +316,17 @@ def load_model_with_checkpoint(r_module,
ds_shape = child.weight.shape
else:
ds_shape = child.weight.ds_shape
if child.__class__ is nn.LayerNorm:
child = Normalize(dim=ds_shape[-1],
dtype=child.weight.dtype,
eps=child.eps)
setattr(module, name, child)
elif child.__class__ is nn.Linear:
child = LinearLayer(weight=child.weight, bias=child.bias)
child = LinearLayer(weight_shape=child.weight.shape,
bias=child.bias)
setattr(module, name, child)
elif child.__class__ is OPTLearnedPositionalEmbedding:
child = OPTEmbedding(weight_shape=ds_shape)
setattr(module, name, child)
else:
ds_id = None
......@@ -224,7 +342,8 @@ def load_model_with_checkpoint(r_module,
else:
load_module_recursive(
child,
prefix if level == 0 and ckpt_type == 'pp' else prefix + name + '.',
prefix if (level == 0 and ckpt_type == 'pp') and param_names[-2] else \
prefix + name + '.',
level + 1)
load_module_recursive(r_module)
......@@ -232,10 +351,11 @@ def load_model_with_checkpoint(r_module,
#XXX: hack to tie embedding w. lm_head for BLOOM, need to revisit soon
embedding_weight = None
for n, p in r_module.named_parameters():
if "word_embeddings." in n:
if "word_embeddings." in n or "embed_tokens." in n:
embedding_weight = p
assert hasattr(r_module, 'lm_head'), "attempting to set lm_head but it doesn't exist"
r_module.lm_head.weight = embedding_weight
if embedding_weight is not None:
r_module.lm_head.weight = embedding_weight
for sd_ in sd:
del sd_
sd = None
......
......@@ -36,13 +36,13 @@ class ReplaceWithTensorSlicing:
return src
src_shape = src.shape
dst_shape = dst.shape
if self.out_dim == 0:
src_split = torch.split(src.data,
src_shape[self.out_dim] // self.mp_size,
dim=0)
else:
src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1)
if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[self.out_dim] == dst_shape[self.out_dim]:
return torch.nn.parameter.Parameter(src)
......@@ -54,7 +54,6 @@ class ReplaceWithTensorSlicing:
qkv_size,
dim=self.out_dim) for src_s in src_split
]
weight_split = [
torch.cat([qkv_s[i] for qkv_s in qkv_split],
axis=self.out_dim) for i in range(len(qkv_split[0]))
......@@ -137,8 +136,7 @@ def get_transformer_name(replaced_module):
class GroupQuantizer:
def __init__(self, q_int8=True, num_groups=32, group_size=32, num_bits=8):
self.num_groups = num_groups
def __init__(self, q_int8=True, group_size=1, num_bits=8):
self.group_size = group_size
self.num_bits = num_bits
self.q_int8 = q_int8
......@@ -149,8 +147,9 @@ class GroupQuantizer:
inputs.scale = torch.empty(1)
return inputs
q_range = 2**self.num_bits
num_groups = inputs.shape[0] // self.group_size
inputs = inputs.to(torch.cuda.current_device())
input_flat = inputs.reshape(self.num_groups, -1).contiguous()
input_flat = inputs.reshape(num_groups, -1).contiguous()
input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
input_max = torch.max(input_flat, dim=1, keepdim=True)[0].float()
scale = torch.max(input_min.abs(), input_max.abs()) * 2.0 / (q_range)
......@@ -160,7 +159,7 @@ class GroupQuantizer:
#print(inputs.shape)
inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
input_flat = [
inputs_split[i].reshape(self.num_groups,
inputs_split[i].reshape(num_groups,
-1).contiguous() for i in range(2)
]
input_min = [
......@@ -182,7 +181,7 @@ class GroupQuantizer:
out.scale = torch.cat([scale.squeeze().unsqueeze(0),
scale1[0],
scale1[1]],
dim=0).reshape(self.num_groups,
dim=0).reshape(num_groups,
-1).contiguous()
return out
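With this change the quantizer derives the number of groups from the tensor's leading dimension (inputs.shape[0] // group_size) rather than a fixed num_groups. A minimal sketch of the symmetric per-group INT8 scheme it implements (the round/clamp step is an assumption, since the diff is truncated after the scale computation):

import torch

def group_quantize(x, group_size=1, num_bits=8):
    q_range = 2 ** num_bits
    num_groups = x.shape[0] // group_size
    flat = x.reshape(num_groups, -1).float()
    scale = torch.max(flat.min(dim=1, keepdim=True)[0].abs(),
                      flat.max(dim=1, keepdim=True)[0].abs()) * 2.0 / q_range
    q = (flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1).to(torch.int8)
    return q.reshape(x.shape), scale

w = torch.randn(4, 8)
q, scale = group_quantize(w)
w_hat = (q.reshape(4, -1).float() * scale).reshape(w.shape)
print((w - w_hat).abs().max())   # small quantization error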
......@@ -286,6 +285,11 @@ def generic_injection(module, fp16=False, enable_cuda_graph=True):
setattr(module, name, new_module)
selected_policy_g = None
megatron_v2_g = False
transformer_config_g = None
def replace_transformer_layer(orig_layer_impl,
model,
checkpoint_dict,
......@@ -325,6 +329,9 @@ def replace_transformer_layer(orig_layer_impl,
inference=False,
layer_id=0):
policy = policy_cls(child, inference=inference)
global selected_policy_g
if selected_policy_g is None:
selected_policy_g = policy
if not policy.cuda_graph_supported:
# policy says cuda graph is not supported raise an error if set
assert not config.enable_cuda_graph, "cuda graph is not supported with this model, please disable"
......@@ -340,6 +347,8 @@ def replace_transformer_layer(orig_layer_impl,
moe = True
attn_linear_layer, qkvw, qkvb, dense_w, dense_b, scale_attention, megatron_v2 = policy.attention()
global megatron_v2_g
megatron_v2_g = megatron_v2
if not moe or config.moe.type == 'standard':
mlp_linear_layer, _h4h_w, _h4h_b, _4hh_w, _4hh_b = policy.mlp()
else:
......@@ -439,6 +448,8 @@ def replace_transformer_layer(orig_layer_impl,
bigscience_bloom=bigscience_bloom,
max_out_tokens=config.max_out_tokens,
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx)
global transformer_config_g
transformer_config_g = transformer_config
if moe:
new_module = transformer_inference.DeepSpeedMoEInference(
......@@ -553,6 +564,10 @@ def replace_transformer_layer(orig_layer_impl,
if qkvw.is_meta or qkvw.numel() == 0 or qkvw.is_meta:
if qkvw.is_meta or qkvw.ds_tensor.numel() < attn_block.attn_qkvw.numel():
if qkvb is None:
attn_block.attn_qkvb = None
if dense_b is None:
attn_block.attn_ob = None
pass
else:
with GatheredParameters([
......@@ -911,7 +926,9 @@ def replace_transformer_layer(orig_layer_impl,
mp_replace,
ckpt_type,
quantizer,
)
param_names=selected_policy_g.get_param_names(),
transformer_config=transformer_config_g,
megatron_v2=megatron_v2_g)
pbar.update(1)
else:
import gc
......@@ -935,12 +952,16 @@ def replace_transformer_layer(orig_layer_impl,
torch.load(ckpt_file,
map_location='cpu') for ckpt_file in ckpt_files
]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size))
load_model_with_checkpoint(
replaced_module,
sds,
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size),
param_names=selected_policy_g.get_param_names(),
transformer_config=transformer_config_g,
megatron_v2=megatron_v2_g)
sds = [None for _ in sds]
gc.collect()
......@@ -955,12 +976,16 @@ def replace_transformer_layer(orig_layer_impl,
checkpoint["non_tp"][i]
) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')]
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size))
load_model_with_checkpoint(
replaced_module,
sds,
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size),
param_names=selected_policy_g.get_param_names(),
transformer_config=transformer_config_g,
megatron_v2=megatron_v2_g)
sds = [None for _ in sds]
gc.collect()
print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
......@@ -986,6 +1011,7 @@ def replace_transformer_layer(orig_layer_impl,
non_tp_ckpt_name = f'non-tp.pt'
ckpt_files = [non_tp_ckpt_name]
os.makedirs(config.save_mp_checkpoint_path, exist_ok=True)
if not dist.is_initialized() or dist.get_rank() == 0:
print("Saving tp-sharded checkpoints")
torch.save(
......@@ -996,7 +1022,7 @@ def replace_transformer_layer(orig_layer_impl,
if transformer_name not in k
}),
f'{config.save_mp_checkpoint_path}/{non_tp_ckpt_name}')
config = json.dumps({
new_config = json.dumps({
'type':
ckpt_name,
'base_dir':
......@@ -1020,7 +1046,7 @@ def replace_transformer_layer(orig_layer_impl,
})
with open(f"{config.save_mp_checkpoint_path}/ds-inference_config.json",
"w") as cfg:
cfg.write(config)
cfg.write(new_config)
rep_sd = replaced_module.state_dict()
for n, p in replaced_module.named_parameters():
......
......@@ -92,15 +92,19 @@ class TransformerPolicy(DSPolicy):
hf_model_config = None
def __init__(
self,
inference=True,
linear_layer=True,
scale_attention=True,
megatron_v2=False,
# the type of activation function used in MLP
mlp_act_func_type=ActivationFuncType.GELU,
# applies layer norm before attention if `pre_attn_norm` is set to True
pre_attn_norm=True):
self,
inference=True,
linear_layer=True,
scale_attention=True,
megatron_v2=False,
# the type of activation function used in MLP
mlp_act_func_type=ActivationFuncType.GELU,
# applies layer norm before attention if `pre_attn_norm` is set to True
pre_attn_norm=True,
# this flag shows whether or not using prefix in loading the checkpoint
use_load_prefix=False,
# whether or not the qkv is stored in the split-format
split_qkv=True):
super().__init__()
self.inference = inference
self.linear_layer = linear_layer
......@@ -108,6 +112,7 @@ class TransformerPolicy(DSPolicy):
self.is_megatron_v2 = megatron_v2
self.mlp_act_func_type = mlp_act_func_type
self.pre_attn_norm = pre_attn_norm
self.load_prefix = False
def attention(self):
"""
......@@ -139,6 +144,31 @@ class TransformerPolicy(DSPolicy):
"""
raise NotImplementedError
def get_param_names(self):
"""
Returns all the transformer parameter names to
be loaded from checkpoint files. The order of
the names is as follows:
1. Attention weights and biases;
2. MLP weights and biases;
3. LayerNorm weights and biases;
In addition to the parameter names, we require two
more flags to help read the data correctly
from the checkpoint and split the qkv heads in the
right order:
1. `use_load_prefix` (Default: False): this specifies
whether we need to use the name of the first abstraction
layer of the model when searching for the parameter's name
in a checkpoint file. For more information on how this
is used please see
https://github.com/microsoft/DeepSpeed/blob/fix-ckpt-loading/deepspeed/module_inject/load_checkpoint.py#L341
2. `split_qkv` (Default: True): we use this flag when splitting
the qkv parameter into heads. If it is False, it means the heads
of q, k, and v are stored together and need to be split inside the
DeepSpeed-Inference API.
"""
raise NotImplementedError
class HFBertLayerPolicy(TransformerPolicy):
def __init__(self, client_module, inference=False):
......@@ -294,6 +324,22 @@ class HFGPTNEOLayerPolicy(TransformerPolicy):
self.client_module.ln_1.weight, \
self.client_module.ln_1.bias
def get_param_names(self):
return 'attention.query_key_value.weight', \
'attention.query_key_value.bias', \
'attention.dense.weight', \
'attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.use_load_prefix, \
self.split_qkv
class HFGPTJLayerPolicy(TransformerPolicy):
_orig_layer_class = None
......@@ -339,6 +385,20 @@ class HFGPTJLayerPolicy(TransformerPolicy):
self.client_module.ln_1.weight, \
self.client_module.ln_1.bias
def get_param_names(self):
return 'attn.q_proj.weight', \
'attn.k_proj.weight', \
'attn.v_proj.weight', \
'attn.out_proj.weight', \
'mlp.fc_in.weight', \
'mlp.fc_in.bias', \
'mlp.fc_out.weight', \
'mlp.fc_out.bias', \
'ln_1.weight', \
'ln_1.bias', \
self.use_load_prefix, \
self.split_qkv
class MegatronLayerPolicy(TransformerPolicy):
_orig_layer_class = None
......@@ -463,7 +523,11 @@ class HFGPT2LayerPolicy(TransformerPolicy):
class BLOOMLayerPolicy(TransformerPolicy):
_orig_layer_class = None
def __init__(self, client_module, inference=True):
def __init__(self,
client_module,
inference=True,
use_load_prefix=True,
split_qkv=False):
super().__init__(inference, linear_layer=True)
self.client_module = client_module
try:
......@@ -501,12 +565,28 @@ class BLOOMLayerPolicy(TransformerPolicy):
self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias
def get_param_names(self):
return 'self_attention.query_key_value.weight', \
'self_attention.query_key_value.bias', \
'self_attention.dense.weight', \
'self_attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.use_load_prefix, \
self.split_qkv
class GPTNEOXLayerPolicy(TransformerPolicy):
_orig_layer_class = None
version = 0
def __init__(self, client_module, inference=True, megatron_v2=True):
def __init__(self, client_module, inference=True, megatron_v2=True, split_qkv=False):
super().__init__(inference, megatron_v2=megatron_v2)
self.client_module = client_module
if GPTNEOXLayerPolicy._orig_layer_class is None:
......@@ -555,11 +635,27 @@ class GPTNEOXLayerPolicy(TransformerPolicy):
self.client_module.input_layernorm.weight, \
self.client_module.input_layernorm.bias
def get_param_names(self):
return 'attention.query_key_value.weight', \
'attention.query_key_value.bias', \
'attention.dense.weight', \
'attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.use_load_prefix, \
self.split_qkv
class HFOPTLayerPolicy(TransformerPolicy):
_orig_layer_class = None
def __init__(self, client_module, inference=True):
def __init__(self, client_module, inference=True, use_load_prefix=True):
super().__init__(inference,
linear_layer=True,
mlp_act_func_type=ActivationFuncType.ReLU,
......@@ -568,9 +664,9 @@ class HFOPTLayerPolicy(TransformerPolicy):
try:
import transformers
HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
if isinstance(DSPolicy.hf_model_config,
if isinstance(TransformerPolicy.hf_model_config,
transformers.models.opt.configuration_opt.OPTConfig):
self.pre_attn_norm = self.hf_model_config.do_layer_norm_before
self.pre_attn_norm = TransformerPolicy.hf_model_config.do_layer_norm_before
except:
HFOPTLayerPolicy._orig_layer_class = None
......@@ -612,6 +708,26 @@ class HFOPTLayerPolicy(TransformerPolicy):
self.client_module.self_attn_layer_norm.weight, \
self.client_module.self_attn_layer_norm.bias
def get_param_names(self):
return 'self_attn.q_proj.weight', \
'self_attn.q_proj.bias', \
'self_attn.k_proj.weight', \
'self_attn.k_proj.bias', \
'self_attn.v_proj.weight', \
'self_attn.v_proj.bias', \
'self_attn.out_proj.weight', \
'self_attn.out_proj.bias', \
'fc1.weight', \
'fc1.bias', \
'fc2.weight', \
'fc2.bias', \
'self_attn_layer_norm.weight', \
'self_attn_layer_norm.bias', \
'final_layer_norm.weight', \
'final_layer_norm.bias', \
self.use_load_prefix, \
self.split_qkv
# transformer-based policies
replace_policies = [
......
......@@ -331,7 +331,6 @@ class DeepSpeedSelfAttentionFunction(Function):
False,
attn_ow.scale,
config.q_int8)
return output, key_layer, value_layer, context_layer, qkv_out[-1]
def selfAttention_int8():
......@@ -394,7 +393,7 @@ class DeepSpeedSelfAttention(nn.Module):
data_type_fp = torch.half if config.fp16 else torch.float
self.config.layer_id = DeepSpeedSelfAttention.num_layers
DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
qkv_size_per_partition,
......
......@@ -41,10 +41,13 @@ class DeepSpeedMLPFunction(Function):
if attn_nw is None:
output = fused_gemm_gelu(residual_norm,
inter_w,
inter_w.scale,
inter_b,
output_w,
output_w.scale,
config.epsilon,
config.pre_layer_norm,
config.q_int8,
False)
else:
output, residual_add = mlp_gemm_func(input,
......@@ -96,7 +99,7 @@ class DeepSpeedMLP(nn.Module):
self.config = config
data_type = torch.int8 if config.q_int8 else torch.half if config.fp16 else torch.float
data_type_fp = torch.half if config.fp16 else torch.float
device = torch.cuda.current_device() if config.bigscience_bloom else 'cpu'
device = torch.cuda.current_device() #if config.bigscience_bloom else 'cpu'
self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size,
dtype=data_type_fp,
device=device),
......