Unverified commit fd1449c7, authored by Lev Kurilenko and committed by GitHub

Port Reza's INT8-quantization fix to container architecture (#2725)

Co-authored-by: Reza Yazdani <reyazda@microsoft.com>
Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
Co-authored-by: Heyang Qin <heyangqin@microsoft.com>
Co-authored-by: Michael Wyatt <michaelwyatt@microsoft.com>
Parent e4b3b610
......@@ -768,15 +768,15 @@ void quantized_gemm(void* output,
int bsz,
int hidden_size)
{
T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
// auto options = at::TensorOptions()
// .dtype(at::kHalf)
// .layout(at::kStrided)
// .device(at::kCUDA)
// .requires_grad(false);
// auto tmp = torch::empty(weight.sizes(), options);
// T* weight16 = (T*)tmp.data_ptr();
// T* weight16 = (T*)Context::Instance().GetWorkSpace() + 12 * hidden_size * bsz;
auto options = at::TensorOptions()
.dtype(at::kHalf)
.layout(at::kStrided)
.device(at::kCUDA)
.requires_grad(false);
auto tmp = torch::empty(weight.sizes(), options);
T* weight16 = (T*)tmp.data_ptr();
launch_dequantize(weight16,
(int8_t*)weight.data_ptr(),
(float*)qscale.data_ptr(),
......
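The change above replaces the workspace-pointer path with a freshly allocated temporary FP16 tensor that the INT8 weight is dequantized into before the GEMM. A minimal PyTorch sketch of that per-group dequantize-then-matmul pattern (function and variable names are illustrative, not the DeepSpeed kernel API):

import torch

def dequantize_gemm(activations, q_weight, q_scale):
    # activations: [bsz, hidden] input
    # q_weight:    [out, hidden] INT8 weight, quantized in row groups
    # q_scale:     [num_groups, 1] per-group scale factors
    num_groups = q_scale.shape[0]
    w = q_weight.float().reshape(num_groups, -1) * q_scale.float()
    weight16 = w.reshape(q_weight.shape).to(activations.dtype)  # temporary dequantized copy
    return activations @ weight16.t()

x = torch.randn(2, 16)
q_w = torch.randint(-128, 127, (32, 16), dtype=torch.int8)
out = dequantize_gemm(x, q_w, torch.full((8, 1), 0.01))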
......@@ -2,6 +2,8 @@ from .base import *
from .features.meta_tensor import MetaTensorContainer
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
supported_models = {None}
......@@ -28,21 +30,52 @@ class DS_BloomContainer(MetaTensorContainer, BaseTransformerContainer):
self.module.attention.attn_qkvb,
self.qkvb)
def get_param_names(self):
return 'self_attention.query_key_value.weight', \
'self_attention.query_key_value.bias', \
'self_attention.dense.weight', \
'self_attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.policy.use_load_prefix, \
self.policy.split_qkv
def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
param_names = (
'self_attention.query_key_value.weight', \
'self_attention.query_key_value.bias', \
'self_attention.dense.weight', \
'self_attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
'input_layernorm.weight', \
'input_layernorm.bias'
)
for i in range(0, 2):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i],
qkv=True,
megatron_v2=self.is_megatron_v2,
split_qkv=self.split_qkv)
for i in range(2, 4):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
for i in range(4, 10):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
for i in range(10, 12):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
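In the BLOOM loop above, each checkpoint name in param_names lines up positionally with the DeepSpeed-internal name at the same index of transformer_param_names: indices 0-3 are the attention QKV and output projection, 4-9 the MLP tensors plus the post-attention layernorm, and 10-11 the input layernorm. As a quick illustration (plain Python, not part of the container API), the destination sub-modules break down as:

owner_by_index = {
    **{i: 'module.attention' for i in range(0, 4)},   # attn_qkvw/b, attn_ow/ob
    **{i: 'module.mlp'       for i in range(4, 10)},  # inter_w/b, output_w/b, attn_nw/nb
    **{i: 'module'           for i in range(10, 12)}, # norm_w, norm_b
}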
class BLOOMLayerPolicy(TransformerPolicy):
......@@ -82,9 +115,9 @@ class BLOOMLayerPolicy(TransformerPolicy):
def mlp(self):
return self.client_module.mlp.dense_h_to_4h.weight, \
self.client_module.mlp.dense_h_to_4h.bias, \
self.client_module.mlp.dense_4h_to_h.weight, \
self.client_module.mlp.dense_4h_to_h.bias
self.client_module.mlp.dense_h_to_4h.bias, \
self.client_module.mlp.dense_4h_to_h.weight, \
self.client_module.mlp.dense_4h_to_h.bias
def layernorm(self):
return self.client_module.post_attention_layernorm.weight, \
......
......@@ -33,14 +33,9 @@ class MetaTensorContainer(ABC):
super().transpose()
@abstractmethod
def get_param_names(self):
def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
"""
Returns all the transformer parameter names to
be loaded from checkpoint files. The order of
the names is as follows:
1. Attention weights and biases;
2. MLP weights and biases;
3. LayerNorm weights and biases;
Load all the transformer parameters from the checkpoint file (sd).
In addition to the parameter names, we require two
more parameters to help read the data correctly
from the checkpoint and split the qkv heads in the
......@@ -50,12 +45,12 @@ class MetaTensorContainer(ABC):
layer of the model for searching the parameter's name
in a checkpoint file. For more information on how this
is used, please see
https://github.com/microsoft/DeepSpeed/blob/fix-ckpt-loading/deepspeed/module_inject/load_checkpoint.py#L341
https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/module_inject/load_checkpoint.py
2. `split_qkv` (Default: True): we use this flag when splitting
the qkv parameter into heads. If it is False, it means the heads
of q, k, and v are stored together and need to be split in the
DeepSpeed-Inference API.
"""
raise NotImplementedError(
"A get_param_names() function must be defined in the model container \
"A load_params() function must be defined in the model container \
when inheriting the MetaTensorContainer feature")
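A minimal sketch of what a concrete container is now expected to implement, assuming the same helpers the containers in this PR import; the container class and checkpoint names are hypothetical:

from ..policy import transformer_param_names, maybe_copy

class MyModelContainer(MetaTensorContainer, BaseTransformerContainer):
    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        # Hypothetical checkpoint names, ordered to match transformer_param_names.
        param_names = ('attn.qkv.weight', 'attn.qkv.bias',
                       'attn.dense.weight', 'attn.dense.bias')
        for i in range(4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace,
                       transformer_param_names[i], prefix + param_names[i],
                       qkv=(i < 2), split_qkv=self.split_qkv)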
......@@ -4,6 +4,9 @@ from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInfe
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
class DS_GPTJContainer(MetaTensorContainer, BaseTransformerContainer):
......@@ -18,19 +21,50 @@ class DS_GPTJContainer(MetaTensorContainer, BaseTransformerContainer):
self.module.config.scale_attention = self.scale_attention
return self.module
def get_param_names(self):
return 'attn.q_proj.weight', \
'attn.k_proj.weight', \
'attn.v_proj.weight', \
'attn.out_proj.weight', \
'mlp.fc_in.weight', \
'mlp.fc_in.bias', \
'mlp.fc_out.weight', \
'mlp.fc_out.bias', \
'ln_1.weight', \
'ln_1.bias', \
self.policy.use_load_prefix, \
self.policy.split_qkv
def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
param_names = (
'attn.q_proj.weight', \
'attn.k_proj.weight', \
'attn.v_proj.weight', \
'attn.out_proj.weight', \
'mlp.fc_in.weight', \
'mlp.fc_in.bias', \
'mlp.fc_out.weight', \
'mlp.fc_out.bias', \
'ln_1.weight', \
'ln_1.bias'
)
maybe_copy_qkv(
module.attention,
sd,
weight_quantizer,
mp_replace,
'attn_qkvw',
[prefix + param_names[0],
prefix + param_names[1],
prefix + param_names[2]],
split_qkv=self.split_qkv)
for i in range(3, 4):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
prefix + param_names[i])
for i in range(4, 8):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
for i in range(8, 10):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i + 2],
prefix + param_names[i])
class HFGPTJLayerPolicy(TransformerPolicy):
......
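GPT-J keeps q, k, and v as separate checkpoint tensors, so the container above fuses them through maybe_copy_qkv before the copy. The fusion itself is a concatenation along the output dimension, as in this standalone sketch:

import torch

hidden = 8
q, k, v = (torch.randn(hidden, hidden) for _ in range(3))

# Fused layout handed to the inference kernel: [3 * hidden, hidden]
qkv_data = torch.cat((q, k, v), dim=0)
assert qkv_data.shape == (3 * hidden, hidden)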
......@@ -4,6 +4,9 @@ from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInfe
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer):
......@@ -18,21 +21,53 @@ class DS_GPTNEOContainer(MetaTensorContainer, BaseTransformerContainer):
self.module.config.scale_attention = self.scale_attention
return self.module
def get_param_names(self):
return 'attention.query_key_value.weight', \
'attention.query_key_value.bias', \
'attention.dense.weight', \
'attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.policy.use_load_prefix, \
self.policy.split_qkv
def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
param_names = (
'attn.attention.q_proj.weight', \
'attn.attention.k_proj.weight', \
'attn.attention.v_proj.weight', \
'attn.attention.out_proj.weight', \
'attn.attention.out_proj.bias', \
'mlp.c_fc.weight', \
'mlp.c_fc.bias', \
'mlp.c_proj.weight', \
'mlp.c_proj.bias', \
'ln_2.weight', \
'ln_2.bias', \
'ln_1.weight', \
'ln_1.bias'
)
maybe_copy_qkv(
module.attention,
sd,
weight_quantizer,
mp_replace,
'attn_qkvw',
[prefix + param_names[0],
prefix + param_names[1],
prefix + param_names[2]],
split_qkv=self.split_qkv)
for i in range(3, 5):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
prefix + param_names[i])
for i in range(5, 11):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
prefix + param_names[i])
for i in range(11, 13):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 1],
prefix + param_names[i])
class HFGPTNEOLayerPolicy(TransformerPolicy):
......
......@@ -4,6 +4,8 @@ from .features.megatron import MegatronContainer
from deepspeed.model_implementations.transformers.ds_gpt import DeepSpeedGPTInference
import torch
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from packaging import version as pkg_version
......@@ -26,21 +28,53 @@ class DS_GPTNEOXContainer(MetaTensorContainer,
return self.module
def get_param_names(self):
return 'attention.query_key_value.weight', \
'attention.query_key_value.bias', \
'attention.dense.weight', \
'attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'input_layernorm.weight', \
'input_layernorm.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
self.policy.use_load_prefix, \
self.policy.split_qkv
def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
param_names = (
'attention.query_key_value.weight', \
'attention.query_key_value.bias', \
'attention.dense.weight', \
'attention.dense.bias', \
'mlp.dense_h_to_4h.weight', \
'mlp.dense_h_to_4h.bias', \
'mlp.dense_4h_to_h.weight', \
'mlp.dense_4h_to_h.bias', \
'post_attention_layernorm.weight', \
'post_attention_layernorm.bias', \
'input_layernorm.weight', \
'input_layernorm.bias'
)
for i in range(0, 2):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i],
qkv=True,
megatron_v2=self.is_megatron_v2,
split_qkv=self.split_qkv,
heads=self.client_module.attention.num_attention_heads)
for i in range(2, 4):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
for i in range(4, 10):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
for i in range(10, 12):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i],
prefix + param_names[i])
class GPTNEOXLayerPolicy(TransformerPolicy):
......
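Among the containers in this PR, GPT-NeoX is the one that forwards heads= to maybe_copy, because Megatron-style checkpoints store the fused QKV tensor interleaved per attention head and _transpose (added to policy.py later in this diff) regroups it into contiguous q, k, and v blocks. A simplified standalone sketch of that regrouping, with toy sizes:

import torch

heads, head_dim = 4, 3
hidden = heads * head_dim
# Per-head interleaved layout along the last dim: [q_h | k_h | v_h] for each head h.
x = torch.arange(3 * hidden, dtype=torch.float32).reshape(1, 3 * hidden)

x_heads = x.view(x.shape[0], heads, 3 * head_dim)
q, k, v = torch.split(x_heads, head_dim, dim=-1)
# Regroup into contiguous [all q | all k | all v], the layout the kernel expects.
reordered = torch.cat((q.reshape(x.shape[0], -1),
                       k.reshape(x.shape[0], -1),
                       v.reshape(x.shape[0], -1)), dim=-1).reshape(x.shape)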
......@@ -4,6 +4,9 @@ from deepspeed.model_implementations.transformers.ds_opt import DeepSpeedOPTInfe
import torch
from torch.nn.parameter import Parameter
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_copy_qkv
from deepspeed.utils.types import ActivationFuncType
......@@ -19,25 +22,59 @@ class DS_OPTContainer(MetaTensorContainer, BaseTransformerContainer):
self.module.config.scale_attention = self.scale_attention
return self.module
def get_param_names(self):
return 'self_attn.q_proj.weight', \
'self_attn.q_proj.bias', \
'self_attn.k_proj.weight', \
'self_attn.k_proj.bias', \
'self_attn.v_proj.weight', \
'self_attn.v_proj.bias', \
'self_attn.out_proj.weight', \
'self_attn.out_proj.bias', \
'fc1.weight', \
'fc1.bias', \
'fc2.weight', \
'fc2.bias', \
'self_attn_layer_norm.weight', \
'self_attn_layer_norm.bias', \
'final_layer_norm.weight', \
'final_layer_norm.bias', \
self.policy.use_load_prefix, \
self.policy.split_qkv
def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
param_names = (
'self_attn.q_proj.weight', \
'self_attn.k_proj.weight', \
'self_attn.v_proj.weight', \
'self_attn.q_proj.bias', \
'self_attn.k_proj.bias', \
'self_attn.v_proj.bias', \
'self_attn.out_proj.weight', \
'self_attn.out_proj.bias', \
'fc1.weight', \
'fc1.bias', \
'fc2.weight', \
'fc2.bias', \
'final_layer_norm.weight', \
'final_layer_norm.bias', \
'self_attn_layer_norm.weight', \
'self_attn_layer_norm.bias'
)
for i in range(0, 6, 3):
maybe_copy_qkv(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i // 3],
[
prefix + param_names[i],
prefix + param_names[i + 1],
prefix + param_names[i + 2]
],
split_qkv=self.split_qkv)
for i in range(6, 8):
maybe_copy(module.attention,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 4],
prefix + param_names[i])
for i in range(8, 14):
maybe_copy(module.mlp,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 4],
prefix + param_names[i])
for i in range(14, 16):
maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
transformer_param_names[i - 4],
prefix + param_names[i])
class HFOPTLayerPolicy(TransformerPolicy):
......
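In the OPT loop above, range(0, 6, 3) walks the first six checkpoint names in two strides of three (q/k/v weights, then q/k/v biases), and i // 3 selects the matching fused destination: attn_qkvw for i = 0 and attn_qkvb for i = 3. A quick check of that index arithmetic:

dst = ('attn_qkvw', 'attn_qkvb')
for i in range(0, 6, 3):
    # i = 0 -> q/k/v weights -> attn_qkvw; i = 3 -> q/k/v biases -> attn_qkvb
    print(i, [i, i + 1, i + 2], dst[i // 3])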
......@@ -16,11 +16,10 @@ def load_model_with_checkpoint(r_module,
sd,
mp_replace,
ckpt_type,
ckpt_mp_size,
weight_quantizer=None,
rank=0,
param_names=None,
transformer_config=None,
megatron_v2=False):
replace_policy=None):
error_msgs = []
def transpose(data):
......@@ -37,6 +36,12 @@ def load_model_with_checkpoint(r_module,
if hasattr(module, 'weight'):
module.weight = mp_replace.copy(module.weight.data, sd[0][prefix + 'weight'])
if prefix + 'bias' in sd[0].keys():
if module.bias.data.is_meta:
# meta tensor cannot be cast or copied to, so we need to replace it with a normal tensor here
module.bias = torch.nn.parameter.Parameter(
data=torch.empty_like(module.bias.data,
device="cpu"),
requires_grad=module.bias.data.requires_grad)
module.bias = mp_replace.copy(module.bias.data, sd[0][prefix + 'bias'])
args = None
gc.collect()
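The meta-tensor guard added above is needed because torch.Tensor.copy_ cannot write into a tensor on the meta device; the parameter is first re-materialized with real storage and only then filled from the checkpoint. A tiny standalone illustration of the same pattern:

import torch

bias = torch.nn.Parameter(torch.empty(4, device="meta"))
if bias.data.is_meta:
    # Meta tensors carry only shape/dtype metadata; allocate real storage first.
    bias = torch.nn.Parameter(torch.empty_like(bias.data, device="cpu"),
                              requires_grad=bias.requires_grad)
bias.data.copy_(torch.zeros(4))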
......@@ -51,6 +56,8 @@ def load_model_with_checkpoint(r_module,
tmp_data, scale = sd[0][prefix + n]
tmp_data = tmp_data
scale = scale.to(get_accelerator().current_device_name())
# set the quantizer number of groups using the checkpoint scale shape
weight_quantizer.num_groups = scale.shape[0]
else:
tmp_data = sd[0][prefix + n].to(
get_accelerator().current_device_name())
......@@ -100,11 +107,31 @@ def load_model_with_checkpoint(r_module,
get_accelerator().current_device_name())
for j in range(len(sd))
]
weight_partition = torch.cat([
ad[0].to(get_accelerator().current_device_name())
if type(ad) is list else ad for ad in all_data
],
dim=dim)
# Check if the weight tensor is for the QKV parameter
if src_shape[1] == (3 *
src_shape[0]) // ckpt_mp_size:
qkv_size = src_shape[outer_dim] // 3
src_split = [
torch.split(src[0].data,
qkv_size,
dim=outer_dim)
for src in all_data
]
weight_partition = torch.cat([
torch.cat([qkv_s[i] for qkv_s in src_split],
axis=outer_dim)
for i in range(len(src_split[0]))
],
dim=dim)
else:
weight_partition = torch.cat([
ad[0].to(
get_accelerator().current_device_name())
if type(ad) is list else ad
for ad in all_data
],
dim=dim)
if tmp_data.dtype == torch.int8:
scale = torch.cat([
ad[1].to(
......@@ -135,150 +162,46 @@ def load_model_with_checkpoint(r_module,
).current_device_name()).contiguous()
p.data.copy_(bias_split)
else:
p.data.copy_(
torch.cat(
[sd[j][prefix + n] for j in range(len(sd))],
dim=0).to(get_accelerator(
).current_device_name()).contiguous())
# Check if the weight tensor is for the QKV parameter
if src_shape[0] == (3 * r_module.config.hidden_size
) // ckpt_mp_size:
qkv_size = src_shape[0] // 3
src_split = [
torch.split(sd[j][prefix + n],
qkv_size,
dim=0) for j in range(len(sd))
]
p.data.copy_(
torch.cat(
[
torch.cat([
qkv_s[i] for qkv_s in src_split
],
axis=0)
for i in range(len(src_split[0]))
],
dim=0).to(get_accelerator(
).current_device_name()).contiguous())
else:
p.data.copy_(
torch.cat(
[
sd[j][prefix + n]
for j in range(len(sd))
],
dim=0).to(get_accelerator(
).current_device_name()).contiguous())
load_parameters(module, prefix)
for n, child in module.named_children():
load_parameters(child, prefix + n + '.')
else:
def _transpose(x):
heads = transformer_config.heads // mp_replace.mp_size
attention_head_size = x.shape[-1] // heads
new_x_shape = x.size()[:-1] + (heads, attention_head_size)
x_1 = x.view(*new_x_shape)
(q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=(x_1.dim() - 1))
if len(q.shape) > 2:
return torch.cat((q.reshape(q.shape[0],
-1),
k.reshape(q.shape[0],
-1),
v.reshape(q.shape[0],
-1)),
dim=-1).reshape(x.shape)
else:
return torch.cat((q.reshape(-1),
k.reshape(-1),
v.reshape(-1)),
dim=-1).reshape(x.shape)
# This checks if the parameter exits in the checkpoint file and maybe copies it into the corresponding destination tensor.
# Note that not all parameters are saved in one checkpoint, that's why we always need to check if they exist!
def maybe_copy(module,
dst_name,
src_name,
qkv=False,
megatron_v2=False,
split_qkv=False):
if src_name in sd[0]:
dst = getattr(module, dst_name)
tmp = sd[0][src_name].cuda()
if len(dst.shape) == 1:
if split_qkv:
dst = mp_replace.qkv_copy(dst, tmp)
else:
dst = mp_replace.copy(dst, tmp)
if qkv and megatron_v2:
dst = torch.nn.parameter.Parameter(
_transpose(dst).contiguous())
else:
if split_qkv:
dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, tmp if weight_quantizer.q_int8 else \
(transpose(tmp).contiguous())))
else:
dst = weight_quantizer.quantize(mp_replace.copy(dst, tmp if weight_quantizer.q_int8 else \
transpose(tmp)))
if qkv and megatron_v2:
scale1 = dst.scale
dst = torch.nn.parameter.Parameter(
_transpose(dst).contiguous())
dst.scale = scale1
setattr(module, dst_name, dst)
# Extending the maybe_copy function for when the q, k, and v are in separate parameters!
def maybe_copy_qkv(module, dst_name, src_names, split_qkv=False):
if src_names[0] in sd[0]:
q = sd[0][src_names[0]]
k = sd[0][src_names[1]]
v = sd[0][src_names[2]]
qkv_data = torch.cat((q, k, v), dim=0)
dst = getattr(module, dst_name)
if len(dst.shape) == 1:
if split_qkv:
dst = mp_replace.qkv_copy(dst,
(qkv_data.cuda()).contiguous())
else:
dst = mp_replace.copy(dst, qkv_data.cuda())
else:
if split_qkv:
dst = weight_quantizer.quantize(mp_replace.qkv_copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \
((transpose(qkv_data.cuda())).contiguous())))
else:
dst = weight_quantizer.quantize(mp_replace.copy(dst, qkv_data.cuda() if weight_quantizer.q_int8 else \
transpose(qkv_data.cuda())))
setattr(module, dst_name, dst)
if len(param_names) == 14:
qkv_w, qkv_b, attn_ow, attn_ob, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names
elif len(param_names) < 14:
q_w, k_w, v_w, attn_ow, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, _, split_qkv = param_names
else:
q_w, q_b, k_w, k_b, v_w, v_b, attn_ow, attn_ob, \
mlp_intw, mlp_intb, mlp_ow, mlp_ob, \
inp_normw, inp_normb, attn_nw, attn_nb, _, split_qkv = param_names
maybe_copy(module, 'norm_w', prefix + inp_normw)
maybe_copy(module, 'norm_b', prefix + inp_normb)
if len(param_names) == 14:
maybe_copy(module.attention,
'attn_qkvw',
prefix + qkv_w,
qkv=True,
megatron_v2=megatron_v2,
split_qkv=split_qkv)
maybe_copy(module.attention,
'attn_qkvb',
prefix + qkv_b,
qkv=True,
megatron_v2=megatron_v2,
split_qkv=split_qkv)
elif len(param_names) < 14:
maybe_copy_qkv(module.attention,
'attn_qkvw',
[prefix + q_w,
prefix + k_w,
prefix + v_w],
split_qkv=split_qkv)
else:
maybe_copy_qkv(module.attention,
'attn_qkvw',
[prefix + q_w,
prefix + k_w,
prefix + v_w],
split_qkv=split_qkv)
maybe_copy_qkv(module.attention,
'attn_qkvb',
[prefix + q_b,
prefix + k_b,
prefix + v_b],
split_qkv=split_qkv)
maybe_copy(module.attention, 'attn_ow', prefix + attn_ow)
if len(param_names) >= 14:
maybe_copy(module.attention, 'attn_ob', prefix + attn_ob)
maybe_copy(module.mlp, 'attn_nw', prefix + attn_nw)
maybe_copy(module.mlp, 'attn_nb', prefix + attn_nb)
maybe_copy(module.mlp, 'inter_w', prefix + mlp_intw)
maybe_copy(module.mlp, 'inter_b', prefix + mlp_intb)
maybe_copy(module.mlp, 'output_w', prefix + mlp_ow)
maybe_copy(module.mlp, 'output_b', prefix + mlp_ob)
replace_policy.load_params(module,
sd[0],
weight_quantizer,
mp_replace,
prefix)
try:
import transformers
......@@ -345,24 +268,22 @@ def load_model_with_checkpoint(r_module,
if ds_id is not None:
all_ds_ids[ds_id] = child.weight
setattr(module, name, child)
layer_policies[child.__class__](child, prefix + name + '.')
else:
load_module_recursive(
child,
prefix if (level == 0 and ckpt_type == 'pp') and param_names[-2] else \
prefix if (level == 0 and ckpt_type == 'pp') and replace_policy.use_load_prefix else \
prefix + name + '.',
level + 1)
load_module_recursive(r_module)
#XXX: hack to tie embedding w. lm_head for BLOOM, need to revisit soon
embedding_weight = None
for n, p in r_module.named_parameters():
if "word_embeddings." in n or "embed_tokens." in n:
if "word_embeddings." in n or "embed_tokens." in n or "wte." in n:
embedding_weight = p
if embedding_weight is not None:
if embedding_weight is not None and r_module.lm_head.weight.is_meta:
r_module.lm_head.weight = embedding_weight
for sd_ in sd:
del sd_
......
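The new QKV branch in the shard-merging code above avoids naively concatenating tensor-parallel checkpoint shards, which would interleave q, k, and v blocks from different ranks; instead each shard is split into its q/k/v thirds and the thirds are concatenated component-wise. A standalone sketch of that merge with toy shapes:

import torch

hidden, mp = 4, 2
# Two checkpoint shards of a fused QKV weight, sharded along the QKV output dim.
shards = [torch.randn(3 * hidden // mp, hidden) for _ in range(mp)]

qkv_size = shards[0].shape[0] // 3
# Split every shard into its q, k, v thirds ...
src_split = [torch.split(s, qkv_size, dim=0) for s in shards]
# ... then merge component-wise so q, k, and v each stay contiguous.
merged = torch.cat([torch.cat([parts[i] for parts in src_split], dim=0)
                    for i in range(3)], dim=0)
assert merged.shape == (3 * hidden, hidden)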
......@@ -3,6 +3,21 @@ Copyright 2022 The Microsoft DeepSpeed Team
'''
from abc import ABC, abstractmethod
from deepspeed.utils.types import ActivationFuncType
import torch
transformer_param_names = (
'attn_qkvw', \
'attn_qkvb', \
'attn_ow' , \
'attn_ob', \
'inter_w', \
'inter_b', \
'output_w', \
'output_b', \
'attn_nw', \
'attn_nb', \
'norm_w', \
'norm_b')
class DSPolicy(ABC):
......@@ -86,3 +101,104 @@ class TransformerPolicy(DSPolicy):
gamma and beta with shape: (hidden)
"""
raise NotImplementedError
# TODO (lekurile): This function exists in base container as well, consolidate at some point
def transpose(data):
with torch.no_grad():
data = data.contiguous()
data1 = data.transpose(-1, -2).reshape(-1)
data.reshape(-1).copy_(data1)
data1 = None
return data.reshape(data.shape[-1], data.shape[-2])
# TODO (lekurile): This function exists in megatron feature container as well, consolidate at some point
def _transpose(x, heads=1, mp_replace=None):
heads = heads // mp_replace.mp_size
outer_dim = -1
attention_head_size = x.shape[outer_dim] // heads
new_x_shape = x.size()[:outer_dim] + (heads, attention_head_size)
x_1 = x.view(*new_x_shape)
(q, k, v) = torch.split(x_1, (x_1.shape[-1] // 3), dim=-1)
if len(q.shape) > 2:
new_shape = (q.shape[0], ) + (-1, )
return torch.cat((q.reshape(new_shape),
k.reshape(new_shape),
v.reshape(new_shape)),
dim=outer_dim).reshape(x.shape)
else:
return torch.cat((q.reshape(-1),
k.reshape(-1),
v.reshape(-1)),
dim=-1).reshape(x.shape)
# This checks if the parameter exists in the checkpoint file and maybe copies it into the corresponding destination tensor.
# Note that not all parameters are saved in one checkpoint, that's why we always need to check if they exist!
def maybe_copy(module,
sd,
weight_quantizer,
mp_replace,
dst_name,
src_name,
qkv=False,
megatron_v2=False,
split_qkv=False,
heads=1):
if src_name in sd:
dst = getattr(module, dst_name)
tmp = sd[src_name]
if len(dst.shape) == 1:
if split_qkv:
dst = mp_replace.qkv_copy(dst, tmp)
else:
dst = mp_replace.copy(dst, tmp)
if qkv and megatron_v2:
dst = torch.nn.parameter.Parameter(
_transpose(dst,
heads=heads,
mp_replace=mp_replace).contiguous())
else:
if split_qkv:
dst = mp_replace.qkv_copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \
(transpose(tmp).contiguous())), int8=weight_quantizer.q_int8)
else:
if qkv and megatron_v2:
tmp = _transpose(transpose(tmp),
heads=heads,
mp_replace=mp_replace).contiguous()
if weight_quantizer.q_int8:
tmp = transpose(tmp)
dst = mp_replace.copy(dst, weight_quantizer.quantize(tmp if weight_quantizer.q_int8 else \
transpose(tmp)), int8=weight_quantizer.q_int8)
setattr(module, dst_name, dst)
# Extending the maybe_copy function for when the q, k, and v are in separate parameters!
def maybe_copy_qkv(module,
sd,
weight_quantizer,
mp_replace,
dst_name,
src_names,
split_qkv=False):
if src_names[0] in sd:
q = sd[src_names[0]]
k = sd[src_names[1]]
v = sd[src_names[2]]
qkv_data = torch.cat((q, k, v), dim=0)
dst = getattr(module, dst_name)
if len(dst.shape) == 1:
if split_qkv:
dst = mp_replace.qkv_copy(dst, qkv_data.contiguous())
else:
dst = mp_replace.copy(dst, qkv_data)
else:
if split_qkv:
dst = mp_replace.qkv_copy(dst, weight_quantizer.quantize(qkv_data.cuda() if weight_quantizer.q_int8 else \
((transpose(qkv_data)).contiguous())), int8=weight_quantizer.q_int8)
else:
dst = mp_replace.copy(dst, weight_quantizer.quantize(qkv_data.cuda() if weight_quantizer.q_int8 else \
transpose(qkv_data)), int8=weight_quantizer.q_int8)
setattr(module, dst_name, dst)
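A hedged usage sketch of the relocated maybe_copy helper, with toy stand-ins for the quantizer and tensor-slicer (everything prefixed with _Toy is hypothetical; the real callers are the containers earlier in this diff, and the import assumes policy.py lives under deepspeed/module_inject/):

import torch
from deepspeed.module_inject.policy import maybe_copy

class _ToyQuantizer:                  # stand-in for GroupQuantizer
    q_int8 = False
    def quantize(self, t, **kwargs):  # pass-through when q_int8 is False
        return torch.nn.Parameter(t, requires_grad=False)

class _ToySlicer:                     # stand-in for ReplaceWithTensorSlicing
    def copy(self, dst, src, int8=False):
        dst.data.copy_(src.data.reshape(dst.shape))
        return dst
    qkv_copy = copy

class _ToyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attn_ow = torch.nn.Parameter(torch.zeros(4, 4))

attn = _ToyAttention()
sd = {'layer.0.attn.dense.weight': torch.randn(4, 4)}
maybe_copy(attn, sd, _ToyQuantizer(), _ToySlicer(),
           'attn_ow', 'layer.0.attn.dense.weight')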
......@@ -36,35 +36,39 @@ class ReplaceWithTensorSlicing:
for merging your checkpoints before replacing the transformer layer with\
inference-kernels'
def qkv_copy(self, dst, src):
def qkv_copy(self, dst, src, int8=False):
if src is None:
return src
src_shape = src.shape
dst_shape = dst.shape
if self.out_dim == 0:
src_split = torch.split(src.data,
src_shape[self.out_dim] // self.mp_size,
dim=0)
else:
src_split = torch.split(src.data, src.shape[-1] // 3, dim=-1)
outer_dim = 0 if int8 else -1
inner_dim = -1 if int8 else 0
src_split = torch.split(src.data, src.shape[outer_dim] // 3, dim=outer_dim)
if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[self.out_dim] == dst_shape[self.out_dim]:
return torch.nn.parameter.Parameter(src)
if src_shape[outer_dim] == dst_shape[self.out_dim]:
dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
if hasattr(src, 'scale'):
dst.scale = src.scale
return dst
if self.out_dim == 1:
self.merge_assert(src_shape[self.out_dim], dst_shape[self.out_dim])
self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
qkv_size = dst_shape[self.out_dim] // 3
qkv_split = [
torch.split(src_s,
qkv_size,
dim=self.out_dim) for src_s in src_split
dim=outer_dim) for src_s in src_split
]
weight_split = [
torch.cat([qkv_s[i] for qkv_s in qkv_split],
axis=self.out_dim) for i in range(len(qkv_split[0]))
axis=outer_dim) for i in range(len(qkv_split[0]))
]
dst.data.copy_(weight_split[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous())
dst = dst.reshape(-1).data.copy_(
weight_split[self.gpu_index].contiguous().reshape(-1)).reshape(
weight_split[self.gpu_index].shape)
else:
dst.data.copy_(src_split[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous())
......@@ -78,47 +82,49 @@ class ReplaceWithTensorSlicing:
torch.cat([qkv_s[i] for qkv_s in qkv_split],
axis=0) for i in range(len(qkv_split[0]))
]
dst.data.copy_(bias_split[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous())
dst.data.copy_(bias_split[self.gpu_index].contiguous())
else:
dst.data.copy_(src_split[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous())
dst.data.copy_(src_split[self.gpu_index].contiguous())
return torch.nn.parameter.Parameter(dst)
dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
if hasattr(src, 'scale'):
dst.scale = src.scale
return dst
def copy(self, dst, src):
def copy(self, dst, src, int8=False):
if src is None:
return src
assert not dst.data.is_meta # the torch.Tensor.copy_ method used below will silently fail on meta tensors
outer_dim = 0 if int8 else 1
inner_dim = 1 if int8 else 0
src_shape = src.shape
dst_shape = dst.shape
if (len(src_shape) == 2 and len(dst_shape) == 2):
if src_shape[0] == dst_shape[0] and src_shape[1] == dst_shape[1]:
dst.data.copy_(src)
if src_shape[inner_dim] == dst_shape[
self.in_dim] and src_shape[outer_dim] == dst_shape[self.out_dim]:
dst = dst.reshape(-1).data.copy_(src.data.reshape(-1)).reshape(src.shape)
else:
if src_shape[self.in_dim] != dst_shape[self.in_dim]:
self.merge_assert(src_shape[self.in_dim], dst_shape[self.in_dim])
if src_shape[inner_dim] != dst_shape[self.in_dim]:
self.merge_assert(src_shape[inner_dim], dst_shape[self.in_dim])
weight_split = torch.split(
src,
dst_shape[self.in_dim],
dim=self.in_dim)[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous()
dim=inner_dim)[self.gpu_index].contiguous()
else:
self.merge_assert(src_shape[self.out_dim], dst_shape[self.out_dim])
self.merge_assert(src_shape[outer_dim], dst_shape[self.out_dim])
weight_split = torch.split(
src.data,
dst_shape[self.out_dim],
dim=self.out_dim)[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous()
dst.data.copy_(weight_split.contiguous())
dim=outer_dim)[self.gpu_index].contiguous()
dst = dst.reshape(-1).data.copy_(weight_split.reshape(-1)).reshape(
weight_split.shape)
else:
if src_shape[0] == dst_shape[0]:
dst.data.copy_(src)
else:
bias_split = torch.split(
src.data,
dst_shape[-1])[self.gpu_index].to(
get_accelerator().current_device_name()).contiguous()
bias_split = torch.split(src.data,
dst_shape[-1])[self.gpu_index].contiguous()
dst.data.copy_(bias_split)
dst = torch.nn.parameter.Parameter(dst, requires_grad=False)
if hasattr(src, 'scale'):
......@@ -142,12 +148,12 @@ def get_transformer_name(replaced_module):
class GroupQuantizer:
def __init__(self, q_int8=True, group_size=1, num_bits=8):
def __init__(self, q_int8=True, group_size=1, num_bits=8, num_groups=0):
self.group_size = group_size
self.num_bits = num_bits
self.q_int8 = q_int8
# TODO(jeff): need to check w. Reza on why this is needed when changing tp size w. bloom
self.num_groups = 32
self.num_groups = num_groups
def quantize(self, inputs, qkv=True, count=1, parallel_dim=0):
if not self.q_int8 or not qkv:
......@@ -155,7 +161,8 @@ class GroupQuantizer:
inputs.scale = torch.empty(1)
return inputs
q_range = 2**self.num_bits
num_groups = inputs.shape[0] // self.group_size
num_groups = self.num_groups if self.num_groups > 0 else inputs.shape[
0] // self.group_size
inputs = inputs.to(get_accelerator().current_device_name())
input_flat = inputs.reshape(num_groups, -1).contiguous()
input_min = torch.min(input_flat, dim=1, keepdim=True)[0].float()
......@@ -164,7 +171,6 @@ class GroupQuantizer:
input_flat = (input_flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
inputs_q = input_flat.reshape(inputs.shape).to(torch.int8).contiguous()
out = torch.nn.Parameter(inputs_q, requires_grad=False)
#print(inputs.shape)
inputs_split = inputs.split(inputs.shape[parallel_dim] // 2, dim=parallel_dim)
input_flat = [
inputs_split[i].reshape(num_groups,
......@@ -385,7 +391,7 @@ def replace_transformer_layer(orig_layer_impl,
# 10. copy the tensors from the model-specific container to the new module
_container.copy_data_to_new_module()
# 11. set globals for generic checkpoint loading
# 11. set global for generic checkpoint loading
global container_g
if container_g is None:
......@@ -552,15 +558,13 @@ def replace_transformer_layer(orig_layer_impl,
checkpoint[i]),
map_location='cpu')
]
load_model_with_checkpoint(
replaced_module,
sd,
mp_replace,
ckpt_type,
quantizer,
param_names=container_g.get_param_names(),
transformer_config=container_g.ds_model_config,
megatron_v2=container_g.megatron_v2)
load_model_with_checkpoint(replaced_module,
sd,
mp_replace,
ckpt_type,
ckpt_mp_size,
quantizer,
replace_policy=container_g.policy)
pbar.update(1)
else:
import gc
......@@ -584,16 +588,14 @@ def replace_transformer_layer(orig_layer_impl,
torch.load(ckpt_file,
map_location='cpu') for ckpt_file in ckpt_files
]
load_model_with_checkpoint(
replaced_module,
sds,
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size),
param_names=container_g.get_param_names(),
transformer_config=container_g.ds_model_config,
megatron_v2=container_g.megatron_v2)
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
ckpt_type,
ckpt_mp_size,
quantizer,
int(rank % tp_split_size),
replace_policy=container_g.policy)
sds = [None for _ in sds]
gc.collect()
......@@ -608,16 +610,14 @@ def replace_transformer_layer(orig_layer_impl,
checkpoint["non_tp"][i]
) if base_dir1 else checkpoint["non_tp"][i]
sds = [torch.load(ckpt_file, map_location='cpu')]
load_model_with_checkpoint(
replaced_module,
sds,
mp_replace,
ckpt_type,
quantizer,
int(rank % tp_split_size),
param_names=container_g.get_param_names(),
transformer_config=container_g.ds_model_config,
megatron_v2=container_g.megatron_v2)
load_model_with_checkpoint(replaced_module,
sds,
mp_replace,
ckpt_type,
ckpt_mp_size,
quantizer,
int(rank % tp_split_size),
replace_policy=container_g.policy)
sds = [None for _ in sds]
gc.collect()
print(f"checkpoint loading time at rank {rank}: {time.time()-start_time} sec")
......
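The GroupQuantizer change above drops the hard-coded 32 groups and takes num_groups from whatever the checkpoint implies (set via weight_quantizer.num_groups = scale.shape[0] in load_checkpoint.py). The per-group symmetric quantization it performs boils down to the following arithmetic; a hedged standalone sketch, not the class itself:

import torch

def group_quantize(weight, num_groups, num_bits=8):
    # Symmetric per-group INT8 quantization: one scale per group of rows.
    q_range = 2 ** num_bits
    flat = weight.reshape(num_groups, -1).contiguous()
    scale = torch.max(flat.min(dim=1, keepdim=True)[0].abs(),
                      flat.max(dim=1, keepdim=True)[0].abs()) * 2.0 / q_range
    q = (flat / scale).round().clamp(-q_range // 2, q_range // 2 - 1)
    return q.reshape(weight.shape).to(torch.int8), scale

w = torch.randn(64, 16)
q_w, q_scale = group_quantize(w, num_groups=8)
assert q_scale.shape[0] == 8   # the shape load_checkpoint.py reads back into num_groups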