diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index 74ea5ef7221458799f0732bef82e3ee9e4dde95b..de7ca5a711976d3f9718524c81c71234c28edbbe 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -34,6 +34,58 @@ from torch import nn
 INFERENCE_MODEL_TIMER = "model-forward-inference"
 
 
+def build_bloom_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
+    """
+    Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
+    relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
+    `softmax(l+a) = softmax(l)`. Based on
+    https://github.com/ofirpress/attention_with_linear_biases/blob/a35aaca144e0eb6b789dfcb46784c4b8e31b7983/fairseq/models/transformer.py#L742
+    TODO @thomasw21 this doesn't work as nicely due to the masking strategy, and so masking varies slightly.
+
+    Args:
+    Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
+        attention_mask (`torch.Tensor`):
+            Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
+        num_heads (`int`, *required*):
+            number of heads
+        dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
+            dtype of the output tensor
+    """
+    import math
+    batch_size, seq_length = attention_mask.shape
+    closest_power_of_2 = 2**math.floor(math.log2(num_heads))
+    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
+                        device=attention_mask.device,
+                        dtype=torch.float32)
+    powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
+    slopes = torch.pow(base, powers)
+
+    if closest_power_of_2 != num_heads:
+        extra_base = torch.tensor(2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
+                                  device=attention_mask.device,
+                                  dtype=torch.float32)
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+
+    # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
+    # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
+    # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
+    # => the query_length dimension will then be broadcasted correctly
+    # This is more or less identical to T5's relative position bias:
+    # https://github.com/huggingface/transformers/blob/f681437203baa7671de3174b0fa583c349d9d5e1/src/transformers/models/t5/modeling_t5.py#L527
+    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
+    alibi = slopes[..., None] * arange_tensor
+    if dist.is_initialized():
+        num_heads_per_rank = int(num_heads / dist.get_world_size())
+        offset = dist.get_rank() * num_heads_per_rank
+        alibi = alibi.view(batch_size, num_heads, 1, seq_length)
+        alibi = alibi[:, offset:num_heads_per_rank + offset, :, :]
+        return alibi.reshape(batch_size * num_heads_per_rank, 1, seq_length).to(dtype)
+    else:
+        return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
+
+
 class InferenceEngine(Module):
     inference_mp_group = None
     inference_ep_group = None
@@ -86,8 +138,14 @@ class InferenceEngine(Module):
         self.model_profile_enabled = False
         self._model_times = []
 
-        # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
-        self.remove_mask_prepare_for_bloom()
+        if not self.injection_dict and config.replace_with_kernel_inject:
+            # This is a hack to remove the prepare_mask function on HF side for BLOOM architecture
+            self.remove_mask_prepare_for_bloom()
+
+        if self.injection_dict or not config.replace_with_kernel_inject:
+            # This is a hack to redefine the alibi func due to TP
+            if config.tensor_parallel.tp_size > 1:
+                self.build_alibi_tensor()
 
         if get_accelerator().device_name() == 'cuda' and config.enable_cuda_graph:
             assert pkg_version.parse(torch.__version__) >= pkg_version.parse("1.10"), \
@@ -178,6 +236,11 @@ class InferenceEngine(Module):
         if hasattr(self.module.transformer, '_prepare_attn_mask'):
             self.module.transformer._prepare_attn_mask = lambda attention_mask, *args, **kwargs: attention_mask
 
+    def build_alibi_tensor(self):
+        if hasattr(self.module, 'transformer'):
+            if hasattr(self.module.transformer, 'build_alibi_tensor'):
+                self.module.transformer.build_alibi_tensor = build_bloom_alibi_tensor
+
     def _pre_forward_hook(self, module, *inputs, **kwargs):
         if self.use_cuda_events:
             self.timers(INFERENCE_MODEL_TIMER).start()
diff --git a/deepspeed/module_inject/auto_tp.py b/deepspeed/module_inject/auto_tp.py
index 578278c4425c021f74f2ab512fa38ce53a41c991..bf49df9781f56dcbacf8e4ef2888d9c5c437d6e4 100644
--- a/deepspeed/module_inject/auto_tp.py
+++ b/deepspeed/module_inject/auto_tp.py
@@ -32,7 +32,7 @@ class AutoTP():
         return mlist
 
     def supported(model):
-        unsupported = ['bloom', 'codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
+        unsupported = ['codegen', 'deberta', 'flaubert', 'fsmt', 'gpt2', 'led', 'longformer', 'xlm', 'xlnet']
         model = str(model)
         key = re.search(r": (.*?)Model", model)
         if key is None:
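
For reference, below is a minimal, single-process sketch of the ALiBi bias construction that this patch installs over `transformer.build_alibi_tensor`. It deliberately omits the tensor-parallel branch (the `dist.is_initialized()` path that slices heads per rank), and the names `toy_alibi` / `mask` are illustrative only, not part of the DeepSpeed or transformers API; only `torch` is required.

```python
# Minimal sketch of the ALiBi bias math from build_bloom_alibi_tensor,
# single-process only (no tensor-parallel head slicing). Hypothetical helper
# names; not part of DeepSpeed's public API.
import math
import torch


def toy_alibi(attention_mask: torch.Tensor, num_heads: int, dtype=torch.float32) -> torch.Tensor:
    batch_size, seq_length = attention_mask.shape
    # Per-head slopes: a geometric sequence anchored at the closest power of two.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = torch.pow(torch.tensor(base), torch.arange(1, 1 + closest_power_of_2))
    if closest_power_of_2 != num_heads:
        # Interleave extra slopes when num_heads is not a power of two.
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining, 2)
        slopes = torch.cat([slopes, torch.pow(torch.tensor(extra_base), extra_powers)])
    # Position indices that skip padded tokens, shaped (batch, 1, seq).
    arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
    alibi = slopes[..., None] * arange_tensor          # (batch, num_heads, seq)
    return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)


if __name__ == "__main__":
    mask = torch.ones(2, 5)            # batch of 2 sequences, length 5, no padding
    bias = toy_alibi(mask, num_heads=16)
    print(bias.shape)                  # torch.Size([32, 1, 5]) == (batch * heads, 1, seq)
```

Under tensor parallelism the patched function instead keeps only `num_heads / world_size` head slopes per rank, offset by `rank * num_heads_per_rank`, so each rank's bias matches its shard of attention heads; that is why the engine swaps the function in only when `config.tensor_parallel.tp_size > 1`.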