Unverified commit d9b788d7 authored by Jeff Rasley, committed by GitHub

tweaks to ds-attn, distilbert policy, and mup (#2649)

Parent 6375cb3f
@@ -91,15 +91,17 @@ class DeepSpeedTransformerInference(nn.Module):
def forward(
self,
input,
input=None,
input_mask=None,
attention_mask=None,
attn_mask=None,
head_mask=None,
layer_past=None,
get_key_value=False,
get_present=False,
encoder_output=None,
enc_dec_attn_mask=None,
x=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
@@ -109,6 +111,13 @@ class DeepSpeedTransformerInference(nn.Module):
# This needs to be redesigned later!
layer_head_mask=None,
past_key_value=None):
if x is not None:
input = x
input_mask = (input_mask if attn_mask is None else
attn_mask) if attention_mask is None else attention_mask
# Allocate memory only on first layer forward
if self.config.layer_id == 0:
self.allocate_workspace(self.config.hidden_size,
@@ -167,7 +176,9 @@ class DeepSpeedTransformerInference(nn.Module):
if get_present:
output = (output, presents)
if self.config.return_tuple:
if self.config.return_single_tuple:
return (output, )
elif self.config.return_tuple:
return output if type(output) is tuple else (output, attn_mask)
else:
return output
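Note on the forward-signature change above: the kernel layer now accepts the argument names used by different Hugging Face model families (BERT-style attention_mask, DistilBERT-style x/attn_mask) and normalizes them to a single input/input_mask pair, while return_single_tuple makes the layer hand back a one-element tuple the way DistilBERT's TransformerBlock does. A minimal standalone restatement of the normalization, for illustration only (the helper name is hypothetical):

def normalize_inputs(input=None, x=None,
                     input_mask=None, attn_mask=None, attention_mask=None):
    # DistilBERT passes hidden states as `x`; fall back to it when `input` is absent.
    if x is not None:
        input = x
    # Precedence mirrors the nested conditional in the diff:
    # attention_mask first, then attn_mask, then input_mask.
    input_mask = (input_mask if attn_mask is None else
                  attn_mask) if attention_mask is None else attention_mask
    return input, input_mask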
@@ -6,7 +6,7 @@ import deepspeed.ops.transformer as transformer_inference
from deepspeed.ops.transformer.inference.diffusers_attention import DeepSpeedDiffusersAttention
from deepspeed.ops.transformer.inference.diffusers_transformer_block import DeepSpeedDiffusersTransformerBlock
from deepspeed.ops.transformer.inference.diffusers_2d_transformer import Diffusers2DTransformerConfig
from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy
from .replace_policy import HFBertLayerPolicy, HFGPT2LayerPolicy, BLOOMLayerPolicy, HFDistilBertLayerPolicy
from .replace_policy import replace_policies, generic_policies
from deepspeed import comm as dist
@@ -438,7 +438,8 @@ def replace_transformer_layer(orig_layer_impl,
q_int8=quantize,
return_tuple=(config.return_tuple
or (policy_cls is HFBertLayerPolicy)),
triangular_masking=(policy_cls is not HFBertLayerPolicy),
triangular_masking=(policy_cls is not HFBertLayerPolicy
and policy_cls is not HFDistilBertLayerPolicy),
local_attention=((model_config.attention_layers[layer_id] == "local")
if hasattr(model_config,
'attention_layers') else False),
@@ -451,7 +452,10 @@ def replace_transformer_layer(orig_layer_impl,
training_mp_size=config.training_mp_size,
bigscience_bloom=bigscience_bloom,
max_out_tokens=config.max_out_tokens,
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx)
scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
use_mup=policy_cls.use_mup if hasattr(policy_cls,
'use_mup') else False,
return_single_tuple=(policy_cls is HFDistilBertLayerPolicy))
global transformer_config_g
transformer_config_g = transformer_config
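The config built in replace_transformer_layer now also consults per-policy class attributes: use_mup is read off the policy class when it defines one (the hasattr expression above is equivalent to getattr(policy_cls, 'use_mup', False)), and return_single_tuple is enabled only for HFDistilBertLayerPolicy. As a sketch, a hypothetical policy could opt into muP attention scaling simply by declaring the attribute:

# Hypothetical example (class name is illustrative only): a policy opting into
# muP scaling via the class attribute replace_transformer_layer looks for.
from deepspeed.module_inject.replace_policy import MegatronLayerPolicy

class MyMupLayerPolicy(MegatronLayerPolicy):
    use_mup = True  # picked up as use_mup=policy_cls.use_mup when the config is built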
......
@@ -106,6 +106,7 @@ class TransformerPolicy(DSPolicy):
# whether or not the qkv is stored in the split-format
split_qkv=True):
super().__init__()
self.cuda_graph_supported = False
self.inference = inference
self.linear_layer = linear_layer
self.scale_attention = scale_attention
@@ -406,6 +407,7 @@ class MegatronLayerPolicy(TransformerPolicy):
version = 0
moe_type = 'standard'
megatron_v2 = True
use_mup = False
def __init__(self, client_module, inference=True):
super().__init__(inference, megatron_v2=MegatronLayerPolicy.megatron_v2)
@@ -731,6 +733,62 @@ class HFOPTLayerPolicy(TransformerPolicy):
self.split_qkv
class HFDistilBertLayerPolicy(TransformerPolicy):
_orig_layer_class = None
def __init__(self, client_module, inference=False, preln=False):
super().__init__(inference)
self.client_module = client_module
self.preln = preln
self.cuda_graph_supported = True
if HFDistilBertLayerPolicy._orig_layer_class is None:
try:
import transformers
HFDistilBertLayerPolicy._orig_layer_class = [
transformers.models.distilbert.modeling_distilbert.TransformerBlock,
]
except:
HFDistilBertLayerPolicy._orig_layer_class = None
def get_hidden_heads(self):
return self.client_module.attention.q_lin.weight.shape[1], \
self.client_module.attention.n_heads
def attention(self):
qw = self.client_module.attention.q_lin.weight
qb = self.client_module.attention.q_lin.bias
kw = self.client_module.attention.k_lin.weight
kb = self.client_module.attention.k_lin.bias
vw = self.client_module.attention.v_lin.weight
vb = self.client_module.attention.v_lin.bias
qkvw = Parameter(torch.cat((qw, kw, vw), dim=0))
qkvb = Parameter(torch.cat((qb, kb, vb), dim=0))
return self.linear_layer, \
qkvw, \
qkvb, \
self.client_module.attention.out_lin.weight, \
self.client_module.attention.out_lin.bias, \
self.scale_attention, \
False
def mlp(self):
intermediate_ff = self.client_module.ffn.lin1
return self.linear_layer, intermediate_ff.weight, intermediate_ff.bias, \
self.client_module.ffn.lin2.weight, \
self.client_module.ffn.lin2.bias
def layerNorm(self):
attention_layernorm = self.client_module.sa_layer_norm
transformer_layernorm = self.client_module.output_layer_norm
return attention_layernorm.weight, \
attention_layernorm.bias, \
transformer_layernorm.weight, \
transformer_layernorm.bias
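With the class above registered in replace_policies (see the list below), DistilBERT checkpoints can go through automatic kernel injection like the other supported architectures. A rough usage sketch, assuming a CUDA device and a DeepSpeed release from around this change (the exact init_inference keyword set varies by version):

import torch
import deepspeed
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
# Kernel injection walks the module tree; TransformerBlock instances are
# matched against HFDistilBertLayerPolicy via replace_policies.
engine = deepspeed.init_inference(model,
                                  mp_size=1,
                                  dtype=torch.half,
                                  replace_with_kernel_inject=True)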
# transformer-based policies
replace_policies = [
HFBertLayerPolicy,
@@ -742,6 +800,7 @@ replace_policies = [
BLOOMLayerPolicy,
HFOPTLayerPolicy,
HFCLIPLayerPolicy,
HFDistilBertLayerPolicy
]
# non-transformer-based policies
......
@@ -65,7 +65,10 @@ class DeepSpeedInferenceConfig(TransformerConfig):
training_mp_size=1,
bigscience_bloom=False,
max_out_tokens=1024,
scale_attn_by_inverse_layer_idx=False):
enable_qkv_quantization=False,
use_mup=False,
scale_attn_by_inverse_layer_idx=False,
return_single_tuple=False):
super(DeepSpeedInferenceConfig,
self).__init__(
hidden_size,
@@ -94,6 +97,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
self.bigscience_bloom = bigscience_bloom
self.max_out_tokens = max_out_tokens
self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
self.enable_qkv_quantization = enable_qkv_quantization
self.use_mup = use_mup
self.return_single_tuple = return_single_tuple
@classmethod
def from_dict(cls, json_object):
......
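A minimal sketch of constructing the config with the new fields, mirroring the transformer_inference import from the hunk above and assuming hidden_size and heads are accepted as keywords with the remaining fields left at their defaults (check the installed DeepSpeedInferenceConfig signature):

import deepspeed.ops.transformer as transformer_inference

config = transformer_inference.DeepSpeedInferenceConfig(
    hidden_size=768,
    heads=12,
    use_mup=True,             # muP attention scaling: 1/d_head instead of 1/sqrt(d_head)
    return_single_tuple=True  # DistilBERT-style (output,) return
)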
@@ -61,8 +61,9 @@ class DeepSpeedSelfAttention(nn.Module):
self.q_groups = q_groups
self.merge_count = int(math.log2(merge_count))
self.norm_factor = math.sqrt(
math.sqrt(self.config.hidden_size // self.config.heads))
self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads)
if not config.use_mup:
self.norm_factor = math.sqrt(self.norm_factor)
self.qkv_merging = qkv_merging
if self.config.scale_attn_by_inverse_layer_idx is True:
@@ -83,23 +84,16 @@ class DeepSpeedSelfAttention(nn.Module):
if no_masking:
input_mask = torch.empty(1)
if alibi is not None:
batch_heads = qkv_out.shape[0] * self.num_attention_heads_per_partition
offset = dist.get_rank() * batch_heads if dist.is_initialized() else 0
sliced_alibi = alibi[offset:batch_heads + offset, :, :]
else:
sliced_alibi = torch.empty(1)
attn_key_value = self.score_context_func(
query_key_value=qkv_out,
attn_mask=((1 - input_mask).to(qkv_out.dype) *
attn_mask=((1 - input_mask).to(qkv_out.dtype) *
minus_inf) if input_mask.dtype == torch.int64 else input_mask,
heads=self.num_attention_heads_per_partition,
norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0),
no_masking=no_masking,
layer_id=self.config.layer_id,
num_layers=DeepSpeedSelfAttention.num_layers,
alibi=sliced_alibi)
alibi=alibi)
context_layer, key_layer, value_layer = attn_key_value
return context_layer, key_layer, value_layer
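Note on the norm_factor change above: in the default path the factor is still (hidden_size // heads) ** 0.25, and the kernel appears to apply 1 / norm_factor on both the query and the key side (hence the double square root), so attention logits end up scaled by the usual 1 / sqrt(d_head). Under muP (μTransfer), attention logits are scaled by 1 / d_head instead, so the use_mup path simply skips the extra square root. A small illustration of the effective scales, as a sketch under that assumption:

import math

hidden_size, heads = 768, 12
d_head = hidden_size // heads  # per-head dimension

# Default: norm_factor = d_head ** 0.25; applied twice -> 1 / sqrt(d_head).
norm_factor_default = math.sqrt(math.sqrt(d_head))
effective_default = 1.0 / norm_factor_default ** 2   # == d_head ** -0.5

# muP: norm_factor = sqrt(d_head); applied twice -> 1 / d_head.
norm_factor_mup = math.sqrt(d_head)
effective_mup = 1.0 / norm_factor_mup ** 2           # == 1.0 / d_head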
@@ -123,7 +117,8 @@ class DeepSpeedSelfAttention(nn.Module):
bias=self.attn_qkvb,
add_bias=self.attn_qkvb is not None,
do_flash_attn=False,
num_heads=self.num_attention_heads_per_partition)
num_heads=self.num_attention_heads_per_partition,
num_layers=DeepSpeedSelfAttention.num_layers)
else:
qkv_out = self.qkv_func(
input=input,
@@ -132,7 +127,8 @@ class DeepSpeedSelfAttention(nn.Module):
gamma=norm_w,
beta=norm_b,
add_bias=(self.attn_qkvb is not None),
num_layers=DeepSpeedSelfAttention.num_layers)
num_layers=DeepSpeedSelfAttention.num_layers,
num_heads=self.num_attention_heads_per_partition)
context_layer, key_layer, value_layer = self.compute_attention(
qkv_out=qkv_out,
......
import torch
from deepspeed import comm as dist
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
@@ -19,9 +20,15 @@ class SoftmaxContextOp(BaseOp):
no_masking: bool,
layer_id: int,
num_layers: int,
alibi: torch.Tensor,
alibi_offset: int = None,
mp_size: int = None):
alibi: torch.Tensor):
if alibi is not None:
batch_heads = query_key_value.shape[0] * heads
offset = dist.get_rank() * batch_heads if dist.is_initialized() else 0
alibi = alibi[offset:batch_heads + offset, :, :]
else:
alibi = torch.empty(1)
output = self.softmax_context_func(query_key_value,
attn_mask,
self.config.rotary_dim,
......
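The per-rank ALiBi slicing removed from DeepSpeedSelfAttention.compute_attention now lives in SoftmaxContextOp.forward, so callers pass the full alibi tensor and each tensor-parallel rank selects only the bias rows for the heads it owns. A standalone restatement for illustration (the function name is hypothetical):

import torch
from deepspeed import comm as dist

def slice_alibi_for_rank(alibi, batch_size, heads_per_partition):
    # Mirrors the logic above: keep batch_size * local-heads rows, offset by rank.
    if alibi is None:
        return torch.empty(1)
    batch_heads = batch_size * heads_per_partition
    offset = dist.get_rank() * batch_heads if dist.is_initialized() else 0
    return alibi[offset:offset + batch_heads, :, :]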