Unverified commit 5e16eb2c authored by Wang, Yi and committed by GitHub

enable autoTP for mpt in huggingface model hub without trust_remote_code (#4062)

see https://github.com/huggingface/transformers/tree/main/src/transformers/models/mpt

Co-authored-by: Molly Smith <112220543+molly-smith@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
Parent 8fb111c0
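For context, a minimal usage sketch of what this change enables: loading the transformers-native MPT implementation from the Hugging Face hub (no trust_remote_code) and sharding it with DeepSpeed AutoTP. The checkpoint name and tensor-parallel degree below are illustrative, not part of this commit.

# Illustrative only; launch with e.g. `deepspeed --num_gpus 2 run_mpt.py`.
import os
import torch
import deepspeed
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mosaicml/mpt-7b"  # assumed checkpoint; any hub MPT model served by the native transformers code
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)

# AutoTP path: no kernel-injection policy, weights are sliced across the launched ranks.
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": int(os.getenv("WORLD_SIZE", "1"))},
                                  dtype=torch.float16,
                                  replace_with_kernel_inject=False)

inputs = tokenizer("DeepSpeed is", return_tensors="pt").to(engine.module.device)
print(tokenizer.decode(engine.module.generate(**inputs, max_new_tokens=20)[0]))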
@@ -26,7 +26,7 @@ from ..module_inject.policy import TransformerPolicy
 from ..module_inject.auto_tp import AutoTP
 from ..module_inject.replace_policy import generic_policies
-from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor
+from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor
 DS_INFERENCE_ENABLED = False
 from torch import nn
@@ -187,6 +187,9 @@ class InferenceEngine(Module):
         if hasattr(self.module, 'transformer'):
             if hasattr(self.module.transformer, 'build_alibi_tensor'):
                 self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
+            if hasattr(self.module.transformer, 'build_mpt_alibi_tensor'):
+                self.module.transformer.build_mpt_alibi_tensor_orig = self.module.transformer.build_mpt_alibi_tensor
+                self.module.transformer.__class__.build_mpt_alibi_tensor = build_mpt_alibi_tensor

     def build_attn_bias(self):
         if hasattr(self.module, 'transformer'):
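The hunk above saves the original bound build_mpt_alibi_tensor on the instance and installs the tensor-parallel replacement on the class, so Python still supplies `self` when the patched method is called. A toy illustration of that pattern (all names below are made up for the example):

class Transformer:
    def build_alibi(self, num_heads):
        return f"original alibi for {num_heads} heads"


def sharded_build_alibi(self, num_heads):
    # The replacement reaches the original implementation through the saved reference.
    return self.build_alibi_orig(num_heads) + " (sharded)"


t = Transformer()
t.build_alibi_orig = t.build_alibi             # keep a handle to the original bound method
Transformer.build_alibi = sharded_build_alibi  # patch the class so `self` is passed automatically
print(t.build_alibi(8))                        # -> original alibi for 8 heads (sharded)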
@@ -296,21 +296,6 @@ class AutoTP():
         if getattr(child, "replaced", False) == True:
             return
         weight_shape = child.weight.shape
-        if name == 'attn.Wqkv' and self.module._get_name() == 'MPTBlock':
-            # MPT block qkv weight's allocation is different from other models, it's [3,num_head,head_dim,hidden_size]
-            # instead of [num_head,3,head_dim,hidden_size]
-            new_weight = torch.empty((
-                weight_shape[0] // self.mp_size,
-                weight_shape[1],
-            ),
-                                     device=child.weight.device,
-                                     dtype=child.weight.dtype)
-            reversed_dim = True
-            mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group, out_dim=0)
-            # todo: can we remove new tensor allocation if we use strided copy?
-            mp_replace.strided_copy(new_weight, child.weight.data, num_splits=3, int8=reversed_dim)
-            setattr(child, "replaced", True)
-            return LinearLayer(weight=new_weight.to(get_accelerator().current_device_name()), bias=None)
         mp_replace = ReplaceWithTensorSlicing(mp_group=self.mp_group)
         if name in self.all_reduce_linears:
             # if conv_linear_layer [weight_shape[1], weight_shape[0] // mp_size]
@@ -76,3 +76,18 @@ def build_mpt_atten_bias_tensor(self,
         offset = dist.get_rank() * num_heads_per_rank
         attn_bias = attn_bias[:, offset:num_heads_per_rank + offset, :, :]
     return attn_bias, attention_mask
+
+
+def build_mpt_alibi_tensor(self, num_heads, sequence_length, alibi_bias_max=8, device=None) -> torch.Tensor:
+    r"""
+    Link to paper: https://arxiv.org/abs/2108.12409 - Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation. This implementation has been copied from
+    the alibi implementation of MPT source code that led to slightly different results than the Bloom alibi:
+    https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L292
+    """
+    alibi = self.build_mpt_alibi_tensor_orig(num_heads, sequence_length, alibi_bias_max, device)
+    if dist.is_initialized():
+        num_heads_per_rank = int(num_heads / dist.get_world_size())
+        offset = dist.get_rank() * num_heads_per_rank
+        alibi = alibi[offset:num_heads_per_rank + offset, :, :]
+    return alibi
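The new wrapper above slices the ALiBi bias along the head dimension so each tensor-parallel rank keeps only the rows for its local heads. A standalone sketch of the same arithmetic with illustrative shapes (no torch.distributed required):

import torch

# Assumed layout: the native MPT implementation returns an alibi tensor shaped
# [num_heads, 1, sequence_length]; the values here are placeholders.
num_heads, seq_len, world_size = 32, 16, 4
alibi = torch.randn(num_heads, 1, seq_len)

num_heads_per_rank = num_heads // world_size            # 8 heads per rank
shards = []
for rank in range(world_size):
    offset = rank * num_heads_per_rank                  # same formula as build_mpt_alibi_tensor
    shards.append(alibi[offset:num_heads_per_rank + offset, :, :])

# Each rank holds a contiguous block of heads; together the shards cover the full tensor.
assert torch.equal(torch.cat(shards, dim=0), alibi)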
@@ -16,7 +16,7 @@ def split_by_qkvlist_and_refuse(qkv_list, split_size, split_dim=0, cat_dim=0):
 def require_tp_fused_qkvw(name, mp_size):
-    fused_qkvw_name_list = ['qkv_proj', 'query_key_value']
+    fused_qkvw_name_list = ['qkv_proj', 'query_key_value', 'attn.Wqkv']
     if mp_size == 1:
         return False
@@ -33,6 +33,8 @@ def prepare_tp_fused_qkvw(module_str, src, mp_size, gpu_index):
         'CodeGenBlock': 'codegentype',
         'BloomBlock': 'bloomtype',
         'GLMBlock': 'glmtype',
+        "MPTBlock": 'glmtype',
+        "MptBlock": 'glmtype',
     }

     def _codegen_type_transpose(input, mp_size, codegen_mp_num=4):
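Mapping 'MPTBlock'/'MptBlock' to 'glmtype' treats the fused attn.Wqkv weight as q, k and v stacked along the output dimension: each of the three chunks is sliced per rank and the local slices are fused back together. A rough sketch of that split, with illustrative shapes rather than the DeepSpeed implementation itself:

import torch


def shard_fused_qkv(weight: torch.Tensor, mp_size: int, gpu_index: int) -> torch.Tensor:
    """Return one rank's shard of a fused [3 * hidden, hidden] QKV weight.

    q, k and v are assumed to be stacked along dim 0; each is sliced along its
    output dimension and the three local slices are concatenated again.
    """
    q, k, v = torch.chunk(weight, 3, dim=0)              # un-fuse into the three projections
    local = [torch.chunk(t, mp_size, dim=0)[gpu_index]   # this rank's slice of each projection
             for t in (q, k, v)]
    return torch.cat(local, dim=0)                       # re-fused local weight


hidden = 8
wqkv = torch.randn(3 * hidden, hidden)
shards = [shard_fused_qkv(wqkv, mp_size=2, gpu_index=i) for i in range(2)]
assert shards[0].shape == (3 * hidden // 2, hidden)      # [12, 8] per rank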