diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml index 3a59170b42abef5d3b172ef9608ed66c37464ef2..c4af4e5a95358821d94584a857a2ab0be4027c36 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/formatting.yml @@ -27,7 +27,7 @@ jobs: - name: Install deepspeed run: | - pip install .[dev,autotuning] + pip install .[dev,autotuning,triton] ds_report - name: Formatting checks diff --git a/blogs/assets/images/triton-bert-base-latency.png b/blogs/assets/images/triton-bert-base-latency.png new file mode 100644 index 0000000000000000000000000000000000000000..4f733f4d1afe0e7720ff0e3e93dc842f2fce1fe7 Binary files /dev/null and b/blogs/assets/images/triton-bert-base-latency.png differ diff --git a/blogs/assets/images/triton-bert-large-latency.png b/blogs/assets/images/triton-bert-large-latency.png new file mode 100644 index 0000000000000000000000000000000000000000..d82dc0ccac51392d6ad8062031896dffef6cb039 Binary files /dev/null and b/blogs/assets/images/triton-bert-large-latency.png differ diff --git a/blogs/deepspeed-triton/README.md b/blogs/deepspeed-triton/README.md new file mode 100644 index 0000000000000000000000000000000000000000..071b5d4bc6d038bf49e36cae30a7544fed2bd776 --- /dev/null +++ b/blogs/deepspeed-triton/README.md @@ -0,0 +1,95 @@ +# DeepSpeed with Triton compiler + +# 1. Overview + +We have integrated [Triton](https://github.com/openai/triton), an open source compiler for GPU programming, into DeepSpeed, which further boosts the inference speed of BERT-like models in float16 precision. +By replacing some CUDA kernels or torch operators with Triton kernels, we achieved 1.14\~1.68x speedup (or 12\~41% latency reduction) for different models and GPUs, as shown in Table 1. + +
+ +| Hardware | Bert-base | Bert-large | Roberta-base | Roberta-large | +|----------|:------:|:------:|:------:|:------:| +| A100 |1.65x | 1.68x | 1.53x | 1.61x | +| V100 | 1.29x | 1.14x | 1.23x | 1.21x | + +Table 1. The average speedup (see NOTE below for more detail) + + +
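The speedups in Table 1 are computed from P90 latencies measured on a Huggingface `fill-mask` pipeline, as described in the NOTE at the end of this post. The snippet below is a rough sketch of such a measurement, not the exact benchmark script (that is `triton-bert-benchmark.py` in DeepSpeedExamples, shown later); the prompt text, warmup count, and number of timed runs are illustrative choices.

```python
# Rough sketch: estimate a Table-1 style speedup as
# (baseline P90 latency) / (DeepSpeed-Triton P90 latency).
# Prompt, warmup, and run counts are illustrative, not the benchmark settings.
import time
import numpy as np
import torch
import deepspeed
from transformers import pipeline


def p90_latency(pipe, text, warmup=10, runs=100):
    for _ in range(warmup):
        pipe(text)
    samples = []
    for _ in range(runs):
        torch.cuda.synchronize()
        start = time.perf_counter()
        pipe(text)
        torch.cuda.synchronize()
        samples.append(time.perf_counter() - start)
    return float(np.percentile(samples, 90))


text = "Paris is the [MASK] of France."

# Baseline: Huggingface transformers without any optimization.
pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0)
baseline_p90 = p90_latency(pipe, text)

# DeepSpeed with Triton kernels, using the same flags as in Section 2 below.
pipe.model = deepspeed.init_inference(pipe.model,
                                      dtype=torch.float16,
                                      replace_with_kernel_inject=True,
                                      enable_cuda_graph=True,
                                      use_triton=True,
                                      triton_autotune=True,
                                      max_out_tokens=pipe.tokenizer.model_max_length)
triton_p90 = p90_latency(pipe, text)

print(f"speedup: {baseline_p90 / triton_p90:.2f}x")
```

In the actual experiments, this measurement is repeated for sequence lengths from 8 to 512 and the P90 latencies are averaged over that range (see NOTE).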
For those transformer operators in float16, we have implemented kernels written in the Triton language that replace the ordinary CUDA kernels or torch operators. +The Triton kernels we implemented include softmax, layer-normalization, residual-addition and all the matrix multiplications except the MLP layers (see NOTE below for details; a short standalone usage sketch of these kernels follows the figures below). +In our experiments, the Triton kernels help reduce the average latency (over different sequence lengths) by 6\~24% (depending on model and hardware) when compared to the latency with CUDA-only kernels. + + +The figures below show the latency reduction in more detail. +Figure 1 visualizes the latency reduction across different sequence lengths on an A100 GPU for the Bert-base model. +The baseline (blue) is from Huggingface transformers without any kernel injection, the orange is from DeepSpeed with CUDA-only kernels, and the gray is from DeepSpeed with Triton kernels. +Figure 2 shows the same plot for the Bert-large model on an A100 GPU. +
+ +triton-bert-base-latency + +*Figure 1: Normalized P90 latency for Bert-base model in A100 GPU across different sequence lengths* + +triton-bert-large-latency + +*Figure 2: Normalized P90 latency for Bert-large model in A100 GPU across different sequence lengths* + +
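The Triton kernels listed above are also exposed as standalone functions under `deepspeed.ops.transformer.inference.triton`, so they can be exercised directly on fp16 CUDA tensors. The snippet below is a minimal illustration; the shapes, epsilon, and tolerance are arbitrary example values rather than settings used in our benchmarks.

```python
# Minimal illustration: call the standalone Triton kernels on fp16 CUDA tensors.
# Shapes, eps, and tolerances are arbitrary example values.
import torch
from deepspeed.ops.transformer.inference.triton import gelu, layer_norm, softmax

hidden = 768
x = torch.randn(8, 128, hidden, dtype=torch.float16, device='cuda')
gamma = torch.ones(hidden, dtype=torch.float16, device='cuda')
beta = torch.zeros(hidden, dtype=torch.float16, device='cuda')

y_ln = layer_norm(x, gamma, beta, 1e-5)   # fused layer-normalization
y_sm = softmax(x)                         # softmax over the last dimension
y_gelu = gelu(x)                          # elementwise GeLU

# Sanity check against the PyTorch reference computed in fp32.
ref = torch.nn.functional.layer_norm(x.float(), (hidden, ), gamma.float(), beta.float(), 1e-5)
print(torch.allclose(y_ln.float(), ref, atol=1e-2, rtol=1e-2))
```

Inside the inference engine, these same kernels are wired in through `TritonSelfAttention` and `TritonMLP` when `use_triton` is enabled, as shown in the code changes below.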
+ + +Next, we dive deeper into this new feature in DeepSpeed. + +# 2. How to use Triton in Deepspeed + +You can enable Triton compilers to optimize these kernels by setting a flag in the DeepSpeed config file. + +``` +pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0) +pipe.model = deepspeed.init_inference(pipe.model, + dtype=torch.float16, + replace_with_kernel_inject=True, + enable_cuda_graph=True, + use_triton=True, + triton_autotune=True, + max_out_tokens=pipe.tokenizer.model_max_length) +``` + + +## Running BERT inference with Triton kernels + +We use an example of Bert-base here. + +```python +pip install deepspeed[triton] + +git clone https://github.com/microsoft/DeepSpeedExamples.git +cd DeepSpeedExamples/inference/huggingface/fill-mask + +deepspeed --num_gpus 1 test-bert.py --triton +``` + +To run a performance benchmark, you can use the following command: + +```python +pip install deepspeed[triton] + +git clone https://github.com/microsoft/DeepSpeedExamples.git +cd DeepSpeedExamples/benchmarks/inference + +deepspeed --num_gpus 1 triton-bert-benchmark.py --model bert-base-cased --dtype fp16 --kernel-inject --deepspeed --graphs --triton +``` + +# NOTE + +* For more information on how to use DeepSpeed, please visit our [GitHub Page](https://github.com/microsoft/DeepSpeedExamples) and our [website](https://www.deepspeed.ai/), where you can find blog posts, tutorials, and documentation. + +* This feature is currently only supported for BERT, Roberta and other BERT-like models, and not for text-generation models yet. + +* To achieve the best performance with Triton optimization, you need to activate CUDA graph and ‘triton_autotune’ in the DeepSpeed config. CUDA graph prevents the overhead of JIT compilation and a deep call stack in Triton. ‘triton_autotune’ executes an initial step to find the most suitable parameters for Triton kernels, which may take some time. + +* We used [Triton 2.0.0.post1 release](https://pypi.org/project/triton/2.0.0.post1/) in our experiments. + +* In our experiments, we used a batch size of 1, a sequence length range of 8 to 512, and a ‘fill-mask’ task. Table 1 shows the average P90 latency over the entire sequence length range, while Figures 1 and 2 show the P90 latency for specific sub-ranges. The baseline is the Huggingface transformers without any optimization. The speedup is calculated as (baseline P90 latency)/(DeepSpeed-Triton P90 Latency). We found that the CUDA kernel in MLP performed better than the Triton kernel in our experiments, so we used a hybrid approach that combines both kernels when Triton is enabled in the DeepSpeed config. diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py index 255dacdccf6ef718ff7f40a6970c1feada6e2425..e060ebcb9734780704123910f59a1ae56da2144e 100755 --- a/deepspeed/__init__.py +++ b/deepspeed/__init__.py @@ -12,6 +12,12 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler from packaging import version as pkg_version +try: + import triton # noqa: F401 + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + from . import ops from . 
import module_inject diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py index 3474578e6a118b44d8086f688dc84fc67482c81b..7f1dcb3dfe06a8e38ee61288b2c62e21cbb0d17e 100644 --- a/deepspeed/inference/config.py +++ b/deepspeed/inference/config.py @@ -4,6 +4,7 @@ # DeepSpeed Team import torch +import deepspeed from deepspeed.runtime.config_utils import DeepSpeedConfigModel from deepspeed.runtime.zero.config import DeepSpeedZeroConfig from pydantic import Field @@ -152,6 +153,18 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): can run faster using the graph replay method. """ + use_triton: bool = False + """ + Use this flag to use triton kernels for inference ops. + """ + + triton_autotune: bool = False + """ + Use this flag to enable triton autotuning. + Turning it on is better for performance but increase the 1st runtime for + autotuning. + """ + zero: DeepSpeedZeroConfig = {} """ ZeRO configuration to use with the Inference Engine. Expects a dictionary @@ -279,6 +292,12 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel): return DeepSpeedMoEConfig(moe=field_value) return field_value + @validator("use_triton") + def has_triton(cls, field_value, values): + if field_value and not deepspeed.HAS_TRITON: + raise ValueError('Triton needs to be installed to use deepspeed with triton kernels') + return field_value + class Config: # Get the str representation of the datatype for serialization json_encoders = {torch.dtype: lambda x: str(x)} diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py index d2f9bf9cf0b26f8ba360a458eeb4561eff516e0a..ab502551e702bd3fb447f4176dde78f709fec751 100755 --- a/deepspeed/inference/engine.py +++ b/deepspeed/inference/engine.py @@ -607,6 +607,7 @@ class InferenceEngine(Module): else: self._create_cuda_graph(*inputs, **kwargs) outputs = self._graph_replay(*inputs, **kwargs) + else: outputs = self.module(*inputs, **kwargs) diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py index b91432461fcb1d0554c6f79f75a2cbdb47bc20fd..a41df58ad0591bfb04495da820b210cb8ad11727 100644 --- a/deepspeed/model_implementations/transformers/ds_transformer.py +++ b/deepspeed/model_implementations/transformers/ds_transformer.py @@ -12,6 +12,10 @@ from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder +import deepspeed +if deepspeed.HAS_TRITON: + from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP + from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention inference_module = None @@ -55,14 +59,24 @@ class DeepSpeedTransformerInference(nn.Module): if DeepSpeedTransformerInference.layer_id == 1: log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0]) + if deepspeed.HAS_TRITON and self.config.use_triton: + log_dist(f"Injecting Triton kernels ...", [0]) if self.config.bigscience_bloom: self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count) + assert not self.config.use_triton else: - self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, - merge_count) - self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count, - mlp_extra_grouping) + if deepspeed.HAS_TRITON and 
self.config.use_triton: + self.attention = TritonSelfAttention(self.config) + else: + self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, + merge_count) + + if deepspeed.HAS_TRITON and self.config.use_triton: + self.mlp = TritonMLP(self.config) + else: + self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count, + mlp_extra_grouping) device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu' if self.config.set_empty_params: diff --git a/deepspeed/module_inject/containers/base.py b/deepspeed/module_inject/containers/base.py index a520664793ca61a3168b9f29b49fe902ba8d91f0..62b78a091fd5fe31f78aa5aa92b78bc653f6502a 100644 --- a/deepspeed/module_inject/containers/base.py +++ b/deepspeed/module_inject/containers/base.py @@ -8,6 +8,7 @@ from abc import ABC import torch +import deepspeed from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig from deepspeed.accelerator import get_accelerator @@ -79,6 +80,10 @@ class BaseTransformerContainer(ABC): self.input_nb = None self.mp_group = None + self.use_triton = False + + # Triton + self.use_triton = config.use_triton and deepspeed.HAS_TRITON def create_ds_model_config(self): self.set_hidden_heads(*self.policy.get_hidden_heads()) @@ -110,7 +115,14 @@ class BaseTransformerContainer(ABC): use_mup=self.use_mup, return_single_tuple=self.return_single_tuple, set_empty_params=self.config.set_empty_params, - transposed_mode=self.config.transposed_mode) + transposed_mode=self.config.transposed_mode, + use_triton=self.use_triton, + triton_autotune=self.config.triton_autotune) + + if self.use_triton and deepspeed.HAS_TRITON: + if not self.config.triton_autotune: + from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul + fp16_matmul.skip_autotune() return self.ds_model_config diff --git a/deepspeed/module_inject/containers/bert.py b/deepspeed/module_inject/containers/bert.py index 967a02276be607957b28edcb5c4402f4b3a7e4df..20ae575f45144733a82b609eed21284746cb95d0 100644 --- a/deepspeed/module_inject/containers/bert.py +++ b/deepspeed/module_inject/containers/bert.py @@ -18,6 +18,7 @@ class DS_BERTContainer(BaseTransformerContainer): # All model specific things should be defined here instead of the base class. self.return_tuple = True self.triangular_masking = False + self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON def create_module(self, config=None): _config = config if config is not None else self.ds_model_config diff --git a/deepspeed/module_inject/containers/distil_bert.py b/deepspeed/module_inject/containers/distil_bert.py index 2acd144cac044a880c8480d950acc08ff0fc6877..ecd0562438b5ac634cf8b4536fd3413d0f9ed9d8 100644 --- a/deepspeed/module_inject/containers/distil_bert.py +++ b/deepspeed/module_inject/containers/distil_bert.py @@ -18,6 +18,7 @@ class DS_DistilBERTContainer(BaseTransformerContainer): # All model specific things should be defined here instead of the base class. 
self.triangular_masking = False self.return_single_tuple = True + self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON def create_module(self, config=None): _config = config if config is not None else self.ds_model_config diff --git a/deepspeed/ops/transformer/inference/config.py b/deepspeed/ops/transformer/inference/config.py index 261523529d0b15104c1933fba34e9a6d92632f86..12007f6c14ca67f51816143276c83041b0800705 100644 --- a/deepspeed/ops/transformer/inference/config.py +++ b/deepspeed/ops/transformer/inference/config.py @@ -44,6 +44,7 @@ class DeepSpeedInferenceConfig(TransformerConfig): scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation. return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture. + use_triton: This flag is to enable triton kernels in inference or not. """ def __init__(self, @@ -77,7 +78,9 @@ class DeepSpeedInferenceConfig(TransformerConfig): scale_attn_by_inverse_layer_idx=False, return_single_tuple=False, set_empty_params=False, - transposed_mode=False): + transposed_mode=False, + use_triton=False, + triton_autotune=False): super(DeepSpeedInferenceConfig, self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads, num_hidden_layers) @@ -109,6 +112,8 @@ class DeepSpeedInferenceConfig(TransformerConfig): self.return_single_tuple = return_single_tuple self.set_empty_params = set_empty_params self.transposed_mode = transposed_mode + self.use_triton = use_triton + self.triton_autotune = triton_autotune @classmethod def from_dict(cls, json_object): diff --git a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py index 1d2dc06d0b00865033650070266b146766fd2132..63323c150752b9b98e1120e427251ced382b115a 100644 --- a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py @@ -6,6 +6,7 @@ import torch from ..config import DeepSpeedInferenceConfig from .base import BaseOp +import deepspeed class GELUGemmOp(BaseOp): @@ -14,9 +15,13 @@ class GELUGemmOp(BaseOp): super(GELUGemmOp, self).__init__(config) try: if self.config.dtype in [torch.float16, torch.int8]: - self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore + if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16: + from deepspeed.ops.transformer.inference.triton.ops import fused_gemm_gelu as _triton_fused_gemm_gelu + self.fused_gemm_gelu = _triton_fused_gemm_gelu # type: ignore + else: + self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore elif self.config.dtype == torch.bfloat16: - self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16 + self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16 # type: ignore else: self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32 # type: ignore except AttributeError: diff --git a/deepspeed/ops/transformer/inference/op_binding/linear.py b/deepspeed/ops/transformer/inference/op_binding/linear.py index 6ea15d608f42f33cbf422b02812a93286d7498ec..e970b562c6d6512d2d0de3d8aa564fbe2bf13d3b 100644 --- a/deepspeed/ops/transformer/inference/op_binding/linear.py +++ b/deepspeed/ops/transformer/inference/op_binding/linear.py @@ -6,6 +6,7 @@ import torch from ..config import DeepSpeedInferenceConfig 
from .base import BaseOp +import deepspeed class LinearOp(BaseOp): @@ -14,6 +15,14 @@ class LinearOp(BaseOp): super(LinearOp, self).__init__(config) try: if self.config.dtype in [torch.float16, torch.int8]: + if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16: + from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func + self.linear_func = _triton_linear_func + triton_autotune = config.triton_autotune and config.layer_id == 0 + if triton_autotune: + __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size) + else: + self.linear_func = self.inference_module.linear_layer_fp16 self.linear_func = self.inference_module.linear_layer_fp16 elif self.config.dtype == torch.bfloat16: self.linear_func = self.inference_module.linear_layer_bf16 @@ -37,3 +46,15 @@ class LinearOp(BaseOp): qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, self.config.transposed_mode) return qkv_out + + @staticmethod + def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16): + from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul + seqlen = [(min_seqlen + i) + for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)] + Fp16Matmul._read_autotune_table() + for N in seqlen: + A = torch.randn((N, hidden_size), dtype=dtype, device='cuda') + B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda') + matmul(A, B) + Fp16Matmul._update_autotune_table() diff --git a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py index e3e372d60080ad85ff0b6ca77db7efcf390957a1..d5e12cb9a80162859288aebac8bdd5aee5e52a6e 100644 --- a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py @@ -19,7 +19,9 @@ class MLPGemmOp(BaseOp): super(MLPGemmOp, self).__init__(config) try: if self.config.norm_type == NormType.LayerNorm: - if self.config.dtype in [torch.float16, torch.int8]: + if self.config.dtype in [ + torch.float16, torch.int8 + ]: # non-triton cuda kernel has a higher performance in MLP than mlp_gemm_func in triton.ops self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16 # type: ignore elif self.config.dtype == torch.bfloat16: self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16 diff --git a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py index a6b253930ede37edffe3ba7db97743f88875a51f..dca935c1eb118f0245edcc0dd61e9b19ef2ee641 100644 --- a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py +++ b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py @@ -8,6 +8,7 @@ import torch import torch.nn.functional as F from ..config import DeepSpeedInferenceConfig from .base import BaseOp +import deepspeed from deepspeed.utils.types import NormType @@ -18,7 +19,14 @@ class QKVGemmOp(BaseOp): try: if self.config.norm_type == NormType.LayerNorm: if self.config.dtype in [torch.float16, torch.int8]: - self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16 # type: ignore + if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16: + from deepspeed.ops.transformer.inference.triton.ops import qkv_gemm_func as _triton_qkv_gemm_func + self.qkv_gemm_func = _triton_qkv_gemm_func + triton_autotune = config.triton_autotune and config.layer_id == 0 + if 
triton_autotune: + __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size) + else: + self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16 # type: ignore elif self.config.dtype == torch.bfloat16: self.qkv_gemm_func = self.inference_module.qkv_gemm_bf16 else: @@ -36,6 +44,18 @@ class QKVGemmOp(BaseOp): elif self.config.norm_type == NormType.RMSNorm: self.qkv_gemm_func = self.rms_qkv_gemm_fallback + @staticmethod + def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16): + from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul + seqlen = [(min_seqlen + i) + for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)] + Fp16Matmul._read_autotune_table() + for N in seqlen: + A = torch.randn((N, hidden_size), dtype=dtype, device='cuda') + B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda') + matmul(A, B) + Fp16Matmul._update_autotune_table() + def qkv_gemm_fallback(self, input, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose): if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose: inp_norm = F.layer_norm(input, (input.shape[2], ), gamma, beta, eps) diff --git a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py index cc40633ab0eeb8a892b35bfd77f5f4c9c2d0610a..011be859634d5735937f32dbea861f455200fff3 100644 --- a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py +++ b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py @@ -7,6 +7,7 @@ import os import torch from ..config import DeepSpeedInferenceConfig from .base import BaseOp +import deepspeed class VectorMatMulOp(BaseOp): @@ -14,7 +15,16 @@ class VectorMatMulOp(BaseOp): def __init__(self, config: DeepSpeedInferenceConfig): super(VectorMatMulOp, self).__init__(config) try: - if self.config.dtype in [torch.float16, torch.int8]: + if self.config.dtype == torch.float16: + if deepspeed.HAS_TRITON and config.use_triton: + from deepspeed.ops.transformer.inference.triton.ops import vector_matmul_func as _triton_vector_matmul_func + self.vector_matmul_func = _triton_vector_matmul_func + triton_autotune = config.triton_autotune and config.layer_id == 0 + if triton_autotune: + __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size) + else: + self.vector_matmul_func = self.inference_module.vector_matmul_fp16 + elif self.config.dtype == torch.int8: self.vector_matmul_func = self.inference_module.vector_matmul_fp16 elif self.config.dtype == torch.bfloat16: self.vector_matmul_func = self.inference_module.vector_matmul_bf16 @@ -34,3 +44,15 @@ class VectorMatMulOp(BaseOp): q_int8 = self.config.dtype == torch.int8 output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8, self.config.transposed_mode) return output + + @staticmethod + def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16): + from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul + seqlen = [(min_seqlen + i) + for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)] + Fp16Matmul._read_autotune_table() + for N in seqlen: + A = torch.randn((N, hidden_size), dtype=dtype, device='cuda') + B = torch.randn((hidden_size, hidden_size), dtype=dtype, device='cuda') + matmul(A, B) + Fp16Matmul._update_autotune_table() diff --git a/deepspeed/ops/transformer/inference/triton/__init__.py 
b/deepspeed/ops/transformer/inference/triton/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..b7d1968df62a99849992c8b1e93698d9e51cec30 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .residual_add import residual_add_bias +from .layer_norm import layer_norm, layer_norm_residual +from .gelu import gelu +from .softmax import softmax +from .ops import * +from .matmul_ext import fp16_matmul, matmul_4d, score_4d_matmul, context_4d_matmul diff --git a/deepspeed/ops/transformer/inference/triton/attention.py b/deepspeed/ops/transformer/inference/triton/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3bb0f47f413c9980b422abb772f8b5252f746ee2 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/attention.py @@ -0,0 +1,229 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import math +import torch +import torch.nn as nn +from deepspeed.accelerator import get_accelerator +from deepspeed import comm as dist +from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp +from deepspeed.ops.transformer.inference.triton import ( + softmax, + score_4d_matmul, + context_4d_matmul, +) + +minus_inf = -10000.0 + + +class TritonSelfAttention(nn.Module): + num_layers = 0 + + def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, qkv_merging=False): + super(TritonSelfAttention, self).__init__() + self.config = config + data_type = self.config.dtype + data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype + assert data_type_fp == torch.half, "triton supports fp16 data_type_fp" + + self.config.layer_id = TritonSelfAttention.num_layers + TritonSelfAttention.num_layers = TritonSelfAttention.num_layers + 1 + device = get_accelerator().current_device_name() #if config.bigscience_bloom else 'cpu' + + assert config.mp_size == 1, "mp_size has to be 1 with triton attention yet" + if self.config.set_empty_params: + self.attn_qw = None + self.attn_qb = None + self.attn_kw = None + self.attn_kb = None + self.attn_vw = None + self.attn_vb = None + self.attn_qkvw = None + self.attn_qkvb = None + self.attn_ow = None + self.attn_ob = None + else: + qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3 + self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size, + qkv_size_per_partition, + dtype=data_type, + device=device), + requires_grad=False) + self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device), + requires_grad=False) + # self-ouput weights + out_size_per_partition = self.config.hidden_size // self.config.mp_size + self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + + self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + + self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size + self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size + self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads + + self.mp_group = mp_group + self.use_flash = False + + # used for quantization + self.q_scales = q_scales + self.q_groups = 
q_groups + self.merge_count = int(math.log2(merge_count)) + + self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads) + if not config.use_mup: + self.norm_factor = math.sqrt(self.norm_factor) + + if self.config.scale_attn_by_inverse_layer_idx is True: + self.norm_factor *= math.sqrt(self.config.layer_id + 1) + # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191 + + triton_autotune = self.config.triton_autotune and self.config.layer_id == 0 + self.qkv_func = QKVGemmOp(config) + self.score_context_func = SoftmaxContextOp(config) + self.linear_func = LinearOp(config) + self.vector_matmul_func = VectorMatMulOp(config) + + self.hidden_size = config.hidden_size + self.head_size = config.hidden_size // config.heads + self.scale = (1 / self.norm_factor / self.norm_factor if self.config.scale_attention else 1.0 + ) # making it back to 1/sqrt(head_size) + self.triangular_masking = self.config.triangular_masking + + # triton autotune table update for score/context matmul + if triton_autotune: + print(f"running triton autotune for attention") + __class__._triton_autotune(2, self.config.max_out_tokens, self.head_size, self.config.hidden_size, + self.triangular_masking, self.scale) + + @staticmethod + def _triton_autotune(min_seqlen, + max_seqlen, + head_size, + hidden_size, + triangular_masking, + scale, + dtype=torch.float16): + from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, score_4d_matmul, context_4d_matmul + seqlen = [(min_seqlen + i) + for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)] + Fp16Matmul._read_autotune_table() + for N in seqlen: + qkv = torch.randn((1, N, 3 * hidden_size), dtype=dtype, device='cuda') + output = score_4d_matmul(qkv, head_size, triangular_masking, scale) + context_4d_matmul(output, qkv, head_size) + Fp16Matmul._update_autotune_table() + + def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi): + if isinstance(qkv_out, list): + qkv_out = qkv_out[0] + + no_masking = input_mask is None + + if no_masking: + input_mask = torch.empty(1) + + attn_key_value = self.score_context_func( + query_key_value=qkv_out, + attn_mask=((1 - input_mask).to(qkv_out.dtype) * + minus_inf) if input_mask.dtype == torch.int64 else input_mask, + heads=self.num_attention_heads_per_partition, + norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0), + no_masking=no_masking, + layer_id=self.config.layer_id, + num_layers=TritonSelfAttention.num_layers, + alibi=alibi) + + context_layer, key_layer, value_layer = attn_key_value + return context_layer, key_layer, value_layer + + def forward( + self, + input, + input_mask, + head_mask=None, + layer_past=None, + get_present=False, # not used + encoder_hidden_states=None, # not used + encoder_attention_mask=None, # not used + triangularutput_attentions=False, # not used + norm_w=None, + norm_b=None, + alibi=None, + use_triton_attention=True): + + if not self.config.pre_layer_norm: + qkv_out = self.linear_func(input=input, + weight=self.attn_qkvw, + bias=self.attn_qkvb, + add_bias=self.attn_qkvb is not None, + do_flash_attn=False, + num_heads=self.num_attention_heads_per_partition, + num_layers=TritonSelfAttention.num_layers) + qkv = qkv_out + else: + qkv_out = self.qkv_func(input=input, + weight=self.attn_qkvw, + bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b), + gamma=norm_w, + beta=norm_b) + qkv = qkv_out[0] + + if use_triton_attention and (alibi is 
None): + context_layer = compute_attention(qkv=qkv, + input_mask=input_mask, + scale=self.scale, + layer_past=layer_past, + alibi=alibi, + head_size=self.head_size, + use_triton_flash=self.use_flash, + use_cuda_flash=False, + triangular=self.triangular_masking) + key_layer, value_layer = qkv[:, :, self.hidden_size:2 * self.hidden_size], qkv[:, :, 2 * self.hidden_size:] + else: + context_layer, key_layer, value_layer = self.ds_compute_attention(qkv_out=qkv_out, + input_mask=input_mask, + layer_past=layer_past, + alibi=alibi) + output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow) + + inp_norm = qkv_out[-1] + + if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: + dist.all_reduce(output, group=self.mp_group) + + return (output, key_layer, value_layer, context_layer, inp_norm) + + +global inference_module + + +def compute_attention(qkv, + input_mask, + layer_past, + alibi, + scale, + head_size, + triangular=False, + use_cuda_flash=False, + use_triton_flash=False, + use_ds_attention=False): + if isinstance(qkv, list): + qkv = qkv[0] + + #assert layer_past is None, "layer_past not supported in triton yet" + assert alibi is None, "layer_past not supported in alibi yet" + output = score_4d_matmul(qkv, head_size, triangular, scale) + if triangular: + output = softmax(output) + else: + output = softmax(output, input_mask) + output = context_4d_matmul(output, qkv, head_size) + + return output diff --git a/deepspeed/ops/transformer/inference/triton/gelu.py b/deepspeed/ops/transformer/inference/triton/gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..17b943fb78bc44d22d602352e947aefdd1faa264 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/gelu.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import triton.language as tl + + +@triton.jit +def gelu_functor(x): + # Using approximation introduces greater parity errors. + # return tl.sigmoid(1.702 * x) * x + return x * 0.5 * (1.0 + tl.libdevice.erf(x / 1.41421356237)) + + +@triton.jit +def gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + output = gelu_functor(x) + tl.store(output_ptr + offsets, output, mask=mask) + + +def gelu(activations: torch.Tensor) -> torch.Tensor: + assert activations.is_contiguous() + assert activations.is_cuda + + output = torch.empty_like(activations) + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) + gelu_kernel[grid](activations, output, n_elements, BLOCK_SIZE=1024) + return output diff --git a/deepspeed/ops/transformer/inference/triton/layer_norm.py b/deepspeed/ops/transformer/inference/triton/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d3f313d2ac3d8205702dc2ceb82856154e9ddb2c --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/layer_norm.py @@ -0,0 +1,249 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import triton.language as tl +''' +layer-normalization +modified the triton kernel in +https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/05-layer-norm.py +''' + + +@triton.jit +def layer_norm_kernel( + Out, + A, + Weight, + Bias, + stride, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + # position of elements processed by this program + row = tl.program_id(0) + Out += row * stride + A += row * stride + # compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32) + a = tl.where(cols < N, a - mean, 0.0) + _var += a * a + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # multiply by weight and add bias + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + weight = tl.load(Weight + cols, mask=mask) + bias = tl.load(Bias + cols, mask=mask) + a = tl.load(A + cols, mask=mask, other=0.0).to(tl.float32) + a_hat = (a - mean) * rstd + out = a_hat * weight + bias + # # write-back + tl.store(Out + cols, out, mask=mask) + + +@triton.jit +def layer_norm_residual_kernel( + Out, + A, + Residual, + ln_input, + Weight, + Bias, + stride, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + # position of elements processed by this program + row = tl.program_id(0) + Out += row * stride + A += row * stride + Residual += row * stride + ln_input += row * stride + # compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32) + res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32) + a = a + res + tl.store(ln_input + cols, a, mask=cols < N) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32) + a = tl.where(cols < N, a - mean, 0.0) + _var += a * a + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # multiply by weight and add bias + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + weight = tl.load(Weight + cols, mask=mask) + bias = tl.load(Bias + cols, mask=mask) + a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32) + a_hat = (a - mean) * rstd + out = a_hat * weight + bias + # write-back + tl.store(Out + cols, out, mask=mask) + + +@triton.jit +def layer_norm_residual_bias_kernel( + Out, + A, + Residual, + InputBias, + ln_input, + Weight, + Bias, + stride, + N, + eps, + BLOCK_SIZE: tl.constexpr, +): + # position of elements processed by this program + row = tl.program_id(0) + Out += row * stride + A += row * stride + Residual += row * stride + ln_input += row * stride + # compute mean + mean = 0 + _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(A + cols, mask=cols < N, 
other=0.0).to(tl.float32) + res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32) + b = tl.load(InputBias + cols, mask=cols < N, other=0.0).to(tl.float32) + a = a + b + res + tl.store(ln_input + cols, a, mask=cols < N) + _mean += a + mean = tl.sum(_mean, axis=0) / N + # compute variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32) + a = tl.where(cols < N, a - mean, 0.0) + _var += a * a + var = tl.sum(_var, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + # multiply by weight and add bias + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + weight = tl.load(Weight + cols, mask=mask) + bias = tl.load(Bias + cols, mask=mask) + a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32) + a_hat = (a - mean) * rstd + out = a_hat * weight + bias + # write-back + tl.store(Out + cols, out, mask=mask) + + +def layer_norm(a, weight, bias, eps): + assert a.is_contiguous() + assert weight.is_contiguous() + assert bias.is_contiguous() + + # allocate output + out = torch.empty_like(a) + # reshape input data into 2D tensor + a_arg = a.view(-1, a.shape[-1]) + M, N = a_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // a.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192 + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + layer_norm_kernel[(M, )]( + out, + a_arg, + weight, + bias, + a_arg.stride(0), + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return out + + +def layer_norm_residual(a, input_bias, residual, weight, bias, eps): + assert a.is_contiguous() + assert weight.is_contiguous() + assert bias.is_contiguous() + assert residual.is_contiguous() + + # allocate output and scratch-pad for residual addition + out = torch.empty_like(a) + ln_input = torch.empty_like(a) + # reshape input data into 2D tensor + a_arg = a.view(-1, a.shape[-1]) + residual = residual.view(-1, residual.shape[-1]) + M, N = a_arg.shape + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // a.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + BLOCK_SIZE = max(BLOCK_SIZE, 128) + BLOCK_SIZE = min(BLOCK_SIZE, 4096) + BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192 + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + if input_bias is None: + layer_norm_residual_kernel[(M, )]( + out, + a_arg, + residual, + ln_input, + weight, + bias, + a_arg.stride(0), + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + else: + layer_norm_residual_bias_kernel[(M, )]( + out, + a_arg, + residual, + input_bias, + ln_input, + weight, + bias, + a_arg.stride(0), + N, + eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + ) + return out diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py new file mode 100644 index 0000000000000000000000000000000000000000..ad93831e1f84945a6f846426205c069b5652ec65 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py @@ -0,0 +1,436 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import os +from filelock import FileLock +import deepspeed.ops.transformer.inference.triton.triton_matmul_kernel as triton_matmul_kernel +import pickle +from io import open + + +# ----------------------------------------------------------------------------- +# util class/functions for triton +def _default_cache_dir(): + return os.path.join(os.environ["HOME"], ".triton", "autotune") + + +def bias_add_activation(C, bias=None, activation=""): + if bias is not None: + C += bias + # activation + if activation == "relu": + relu = torch.nn.Relu() + C = relu(C) + elif activation == "leaky_relu": + leaky_relu = torch.nn.LeakyReLU(0.01) + C = leaky_relu(C) + elif activation == "gelu": + sigmoid = torch.nn.Sigmoid() + C = sigmoid(1.702 * C) * C + elif activation == "sigmoid": + sigmoid = torch.nn.Sigmoid() + C = sigmoid(C) + return C + + +class AutotuneCacheManager: + """ + Cache manager for autotune + """ + + def __init__(self, key): + self.key = key + self.file_path = None + self.lock_path = None + # if caching is enabled, get the lock and bin path + self.cache_dir = os.environ.get('TRITON_CACHE_DIR', _default_cache_dir()) + if self.cache_dir: + os.makedirs(self.cache_dir, exist_ok=True) + if self.cache_dir: + self.file_path = os.path.join(self.cache_dir, self.key + ".pickle") + self.lock_path = self.file_path + ".lock" + + def has_file(self): + return self.file_path and os.path.exists(self.file_path) + + def put(self, table): + if self.file_path: + assert self.lock_path is not None + with FileLock(self.lock_path): + with open(self.file_path + ".tmp", 'wb') as handle: + pickle.dump(table, handle) + os.rename(self.file_path + ".tmp", self.file_path) + + def load(self): + if os.path.exists(self.file_path): + with open(self.file_path, 'rb') as handle: + loaded_dict = pickle.load(handle) + return loaded_dict + else: + return None + + +# ----------------------------------------------------------------------------- +# triton matmul class + + +class MatmulExt(torch.autograd.Function): + """ + a wrapper class that can call different triton matmul kernels depending on the input parameters + """ + + @staticmethod + def forward(A, B, bias=None, activation="", use_triton=True, update_autotune_table=False): + """ + A: input, activation matrix A + B: input, weight matrix B + """ + matmul = None + quantize_activation = False + Batch = 0 + + if len(A.shape) == 3: # if A is 3d-tensor where batch index is given as 0-axis + assert A.is_contiguous(), "matrix A must be contiguous" + Batch, M, K = A.shape + A = A.view(-1, K) + + # fp16 activation and fp16 weight matmul into fp16 output + matmul = fp16_matmul + C = matmul.forward(A, B, use_triton=use_triton, bias=bias, activation=activation) + + if matmul and update_autotune_table: + matmul._update_autotune_table() + + if Batch > 0: + C = C.view(Batch, M, -1) + + return C + + +class TritonMatmul(torch.autograd.Function): + """ + triton matmul kernel superclass + """ + + def __init__(self): + pass + + @staticmethod + def _ref_forward(A, B, ref_dtype=torch.float32): + C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype)) + return C + + @staticmethod + def _read_autotune_table(cache_key, triton_kernel): + cache_manager = AutotuneCacheManager(cache_key) + table = cache_manager.load() + if table: + triton_kernel.cache = table + + @staticmethod + def _write_autotune_table(cache_key, triton_kernel): + cache_manager = AutotuneCacheManager(cache_key) + 
cache_manager.put(triton_kernel.cache) + + @staticmethod + def _update_autotune_table(cache_key, triton_kernel): + cache_manager = AutotuneCacheManager(cache_key) + autotune_table = cache_manager.load() + if autotune_table is None: + autotune_table = dict() + autotune_table.update(triton_kernel.cache) # always overwrite with the new autotune results + cache_manager = AutotuneCacheManager(cache_key) + cache_manager.put(autotune_table) + + @staticmethod + def forward( + A, + B, + ref_dtype=torch.float32, # fp32 only + bias=None, + activation=""): + C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype)) + C = bias_add_activation(C, bias, activation) + return C + + +class Fp16Matmul(TritonMatmul): + """ + fp16 matrix multiplication kernel + dtypes: fp16 x fp16 = fp16 + """ + + _2d_kernel = triton_matmul_kernel._fp_matmul + _4d_kernel = triton_matmul_kernel.matmul_4d_kernel + _cache_stride = 32 + + def __init__(self, read_cache=True): + super().__init__() + if read_cache: + __class__._read_autotune_table() + + def skip_autotune(self): + __class__._2d_kernel.configs = [__class__._2d_kernel.configs[0]] + __class__._4d_kernel.configs = [__class__._4d_kernel.configs[0]] + + @staticmethod + def forward(A, B, use_triton=True, bias=None, activation=""): + if use_triton: + device = A.device + # handle non-contiguous inputs if necessary + if A.stride(0) > 1 and A.stride(1) > 1: + A = A.contiguous() + if B.stride(0) > 1 and B.stride(1) > 1: + B = B.contiguous() + # checks constraints + assert A.shape[1] == B.shape[0], "incompatible dimensions" + M, K = A.shape + _, N = B.shape + # allocates output + C = torch.empty((M, N), device=device, dtype=A.dtype) + # accumulator types + ACC_TYPE = triton.language.float32 if A.dtype in [torch.float16, torch.bfloat16, torch.float32 + ] else triton.language.int32 + # launch kernel + grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K']) + __class__._2d_kernel[grid](A, + B, + C, + M, + N, + K, + bias, + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(1), + C.stride(0), + C.stride(1), + M // __class__._cache_stride, + N // __class__._cache_stride, + K // __class__._cache_stride, + GROUP_M=8, + ACC_TYPE=ACC_TYPE, + BIAS_ADD=(0 if bias is None else 1), + ACTIVATION=activation) + else: + C = torch.matmul(A, B) + return C + + @staticmethod + def _matmul_4d(a, b): + assert a.shape[-1] == b.shape[-2], "incompatible dimensions" + assert a.is_contiguous(), "matrix A must be contiguous" + assert b.is_contiguous(), "matrix B must be contiguous" + + B, H, M, K = a.shape + B, H, K, N = b.shape + + assert K > 1, "inner-product dimension K should be larger than 1" + + c = torch.empty((B, H, M, N), device=a.device, dtype=a.dtype) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), + H, + B, + ) + + __class__._4d_kernel[grid]( + a, + b, + c, + M, + N, + K, + M // __class__._cache_stride, + N // __class__._cache_stride, + K // __class__._cache_stride, + a.stride(0), + a.stride(1), + a.stride(2), + a.stride(3), + b.stride(0), + b.stride(1), + b.stride(2), + b.stride(3), + c.stride(0), + c.stride(1), + c.stride(2), + c.stride(3), + scale=-1.0, + MASK=False, + ) + return c + + @staticmethod + def _score_4d_matmul(input, head_size, input_mask, scale=-1.0): + assert input.is_contiguous(), "matrix input must be contiguous" + + batches = input.shape[0] + d_model = input.shape[-1] // 3 + num_of_heads = d_model // head_size + + q = input[:, :, :d_model] + k = input[:, :, d_model:d_model * 
2] + + q = q.view(batches, -1, num_of_heads, head_size) + k = k.view(batches, -1, num_of_heads, head_size) + + # checks constraints + assert q.shape == k.shape, "incompatible dimensions" + B, M, H, K = q.shape + B, N, H, K = k.shape + + assert K > 1, "inner-product dimension K should be larger than 1" + + # allocates output + output = torch.empty((B, H, M, N), device=q.device, dtype=q.dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + H, + B, + ) + __class__._4d_kernel[grid]( + q, + k, + output, + M, + N, + K, + M // __class__._cache_stride, + N // __class__._cache_stride, + K // __class__._cache_stride, + q.stride(0), + q.stride(2), + q.stride(1), + q.stride(3), + k.stride(0), + k.stride(2), + k.stride(3), + k.stride(1), + output.stride(0), + output.stride(1), + output.stride(2), + output.stride(3), + scale=scale, + MASK=False, + ) + return output + + @staticmethod + def _context_4d_matmul(prob, input, head_size): + assert prob.is_contiguous(), "matrix prob must be contiguous" + assert input.is_contiguous(), "matrix input must be contiguous" + + batches = input.shape[0] + d_model = input.shape[-1] // 3 + num_of_heads = d_model // head_size + + v = input[:, :, d_model * 2:] + + v = v.view(batches, -1, num_of_heads, head_size) + + # checks constraints + assert (prob.shape[0] == v.shape[0] and prob.shape[1] == v.shape[2] and prob.shape[2] == v.shape[1] + and prob.shape[3] == v.shape[1]), "incompatible dimensions" + B, H, M, K = prob.shape + B, K, H, N = v.shape + + assert K > 1, "inner-product dimension K should be larger than 1" + + # allocates output + output = torch.empty((B, M, H, N), device=v.device, dtype=v.dtype) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + H, + B, + ) + + __class__._4d_kernel[grid]( + prob, + v, + output, + M, + N, + K, + M // __class__._cache_stride, + N // __class__._cache_stride, + K // __class__._cache_stride, + prob.stride(0), + prob.stride(1), + prob.stride(2), + prob.stride(3), + v.stride(0), + v.stride(2), + v.stride(1), + v.stride(3), + # Here we also transpose the output when writing to memory. 
+ output.stride(0), + output.stride(2), + output.stride(1), + output.stride(3), + scale=-1, + MASK=False, + ) + return output.view(batches, -1, d_model) + + @staticmethod + def _ref_forward(A, B, ref_dtype=torch.float32, bias=None, activation=""): + C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype)) + C = bias_add_activation(C, bias, activation) + return C + + @staticmethod + def _check_parity(A, + B, + output_dtype, + SA=None, + SB=None, + qblock_size=None, + ref_dtype=torch.float32, + tol=0.01, + use_triton=True, + bias=None, + activation=""): + torch_output = __class__._ref_forward(A, B, ref_dtype=ref_dtype, bias=bias, activation=activation) + triton_output = __class__.forward(A, B, use_triton=use_triton, bias=bias, activation=activation) + assert triton.testing.allclose(triton_output.cpu().type(torch_output.dtype), torch_output.cpu(), tol=tol) + print(f"{__class__.__name__}: PASSed the parity check") + return triton_output, torch_output + + @staticmethod + def _read_autotune_table(): + TritonMatmul._read_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel) + TritonMatmul._read_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel) + + @staticmethod + def _write_autotune_table(): + TritonMatmul._write_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel) + TritonMatmul._write_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel) + + @staticmethod + def _update_autotune_table(): + TritonMatmul._update_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel) + TritonMatmul._update_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel) + + +# ----------------------------------------------------------------------------- +# mapping +matmul = MatmulExt.forward +fp16_matmul = Fp16Matmul() +matmul_4d = fp16_matmul._matmul_4d +score_4d_matmul = fp16_matmul._score_4d_matmul +context_4d_matmul = fp16_matmul._context_4d_matmul + +# +import atexit + + +@atexit.register +def matmul_ext_update_autotune_table(): + fp16_matmul._update_autotune_table() diff --git a/deepspeed/ops/transformer/inference/triton/mlp.py b/deepspeed/ops/transformer/inference/triton/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..1708080b27efb0671d361b7cdeaff4b262cf0ce8 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/mlp.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import math +import torch.nn as nn +from deepspeed.accelerator import get_accelerator +from deepspeed import comm as dist +from ..op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp + + +class TritonMLP(nn.Module): + + def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False): + super(TritonMLP, self).__init__() + + self.config = config + data_type = self.config.dtype + data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype + device = get_accelerator().current_device_name() + self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + intm_size_per_partition = self.config.intermediate_size // self.config.mp_size + self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size, + intm_size_per_partition, + dtype=data_type, + device=device), + requires_grad=False) + self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device), + requires_grad=False) + self.output_w = nn.Parameter(torch.empty(intm_size_per_partition, + self.config.hidden_size, + dtype=data_type, + device=device), + requires_grad=False) + self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device), + requires_grad=False) + + # used for quantization + self.q_scales = q_scales + self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups + self.merge_count = int(math.log2(merge_count)) + self.mp_group = mp_group + + self.mlp_gemm_func = MLPGemmOp(config) + self.vector_matmul_func = VectorMatMulOp(config) + self.fused_gemm_gelu = GELUGemmOp(config) + self.residual_add_func = ResidualAddOp(config) + + def forward(self, input, residual, residual_norm, bias): + residual_add = None + if self.attn_nw is None: + output = self.fused_gemm_gelu(input=residual_norm, + weight=self.inter_w, + bias=self.inter_b, + weight_out=self.output_w) + else: + output, residual_add = self.mlp_gemm_func(input=input, + residual=residual, + input_bias=bias, + weight_interm=self.inter_w, + weight_out=self.output_w, + bias=self.inter_b, + gamma=self.attn_nw, + beta=self.attn_nb) + residual = self.residual_add_func(hidden_state=output, + residual=residual, + attention_output=input, + attention_bias=bias if bias is not None else self.output_b, + final_bias=self.output_b, + add_bias=bias is not None, + residual_add=residual_add) + + if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1: + dist.all_reduce(residual, group=self.mp_group) + + return residual diff --git a/deepspeed/ops/transformer/inference/triton/ops.py b/deepspeed/ops/transformer/inference/triton/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..dd87d08d4d2c6c0e4b80ab770ff9a1f03dd3138c --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/ops.py @@ -0,0 +1,131 @@ +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder +import deepspeed.ops.transformer.inference.triton.matmul_ext as matmul_ext +from deepspeed.ops.transformer.inference.triton.layer_norm import layer_norm, layer_norm_residual + +inference_module = None + + +def vector_matmul_func(input, weight, async_op, q_scale, q_int8, transposed_mode): + assert not transposed_mode and not async_op and not q_int8 + return matmul_ext.matmul(input, weight, bias=None, activation="", use_triton=True) + + +def fused_gemm_gelu(input, + weight, + weight_scale, + bias, + weight_out, + weight_out_scale, + epsilon, + pre_layer_norm, + q_int8, + async_op, + transposed_mode, + use_triton_ln=True): + assert not transposed_mode + + # activation + activation = "gelu" + + # intermediate fc in FF + intm_out = matmul_ext.matmul(input, weight, bias=bias, activation=activation, use_triton=True) + + # output fc in FF + ff_out = matmul_ext.matmul( + intm_out, + weight_out, + bias=None, + activation="", # bias added layer with residual_add + bias + layerNorm layer + use_triton=True) + return ff_out + + +def linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, transposed_mode=False): + assert not transposed_mode and not do_flash_attn + qkv_out = matmul_ext.matmul(input, weight, bias=(bias if add_bias else None), activation="", use_triton=True) + + return qkv_out + + +def mlp_gemm_func(input, + residual, + input_bias, + weight_interm, + weight_out, + bias, + gamma, + beta, + epsilon, + pre_layer_norm, + mlp_after_attn, + weight_interm_scale, + weight_out_scale, + q_int8, + mlp_act_func_type, + transposed_mode, + use_triton_ln=True): + assert not transposed_mode + + # residual add and layerNorm after attention + if use_triton_ln: + mlp_input = layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon) + else: + global inference_module + if inference_module is None: + inference_module = InferenceBuilder().load() + mlp_input = inference_module._layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon) + + # activation + if deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.GELU: + activation = "gelu" + elif deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.ReLU: + activation = "relu" + else: + activation = "" + + # intermediate fc in FF + intm_out = matmul_ext.matmul(mlp_input, weight_interm, bias=bias, activation=activation, use_triton=True) + # output fc in FF + ff_out = matmul_ext.matmul( + intm_out, + weight_out, + bias=None, + activation="", # bias added layer with residual_add + bias + layerNorm layer + use_triton=True) + + return ff_out, mlp_input + + +def qkv_gemm_func( + input, + weight, + q_scale, + bias, + gamma, + beta, + epsilon, + add_bias, + q_int8, + transposed_mode=False, + use_triton_ln=True, +): + + assert not transposed_mode + # residual add and layerNorm after attention + if use_triton_ln: + qkv_input = layer_norm(input, gamma, beta, epsilon) + else: + global inference_module + if inference_module is None: + inference_module = InferenceBuilder().load() + qkv_input = inference_module.layer_norm(input, gamma, beta, epsilon) + + qkv_out = matmul_ext.matmul(qkv_input, weight, bias=(bias if add_bias else None), activation="", use_triton=True) + + return qkv_out, qkv_input diff --git a/deepspeed/ops/transformer/inference/triton/residual_add.py 
b/deepspeed/ops/transformer/inference/triton/residual_add.py new file mode 100644 index 0000000000000000000000000000000000000000..9c8ff5400af5d768c0fdbadebb17555d2ab6ac53 --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/residual_add.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import triton.language as tl + + +@triton.jit +def residual_add_bias_kernel( + hidden_state_ptr, + residual_ptr, + attn_output_ptr, + hidden_state_size, + attn_bias_ptr, + final_bias_ptr, + bias_size, + output_ptr, + mp_size: tl.constexpr, + mlp_after_attn: tl.constexpr, + pre_attn_norm: tl.constexpr, + add_attn_bias: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(axis=0) + + block_start = pid * BLOCK_SIZE + + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < hidden_state_size + + bias_offsets = offsets % bias_size + bias_mask = bias_offsets < bias_size + + tl_hidden_state = tl.load(hidden_state_ptr + offsets, mask=mask) + tl_residual = tl.load(residual_ptr + offsets, mask=mask) + tl_attn_output = tl.load(attn_output_ptr + offsets, mask=mask) + tl_attn_bias = tl.load(attn_bias_ptr + bias_offsets, mask=bias_mask) + tl_final_bias = tl.load(final_bias_ptr + bias_offsets, mask=bias_mask) + + if mlp_after_attn: + if pre_attn_norm: + output = tl_hidden_state + (tl_residual + tl_final_bias + tl_attn_output + tl_attn_bias) / mp_size + else: + output = tl_hidden_state + tl_residual + tl_final_bias + else: + output = tl_hidden_state + tl_attn_output + (tl_residual + tl_final_bias) / mp_size + if add_attn_bias: + output += tl_attn_bias / mp_size + + tl.store(output_ptr + offsets, output, mask=mask) + + +def residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_output: torch.Tensor, + attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool, + add_attn_bias: bool, pre_attn_norm: bool): + # check that all tensors are on the same device + assert hidden_state.is_cuda and residual.is_cuda and attn_output.is_cuda \ + and attn_bias.is_cuda and final_bias.is_cuda + + # check that all tensors have the same dtype + assert hidden_state.dtype == residual.dtype == attn_output.dtype \ + == attn_bias.dtype == final_bias.dtype + + # check that all tensors have the right shape + assert hidden_state.shape == residual.shape == attn_output.shape + assert attn_bias.shape == final_bias.shape + assert attn_bias.shape[0] == hidden_state.shape[2] + + output = torch.empty_like(hidden_state) + + hidden_state_size = output.numel() + bias_size = attn_bias.numel() + + grid = lambda meta: (triton.cdiv(hidden_state_size, meta['BLOCK_SIZE']), ) + + residual_add_bias_kernel[grid](hidden_state, residual, attn_output, hidden_state_size,\ + attn_bias, final_bias, bias_size, output, mp_size, mlp_after_attn, pre_attn_norm, \ + add_attn_bias, \ + BLOCK_SIZE=1024) + + return output diff --git a/deepspeed/ops/transformer/inference/triton/softmax.py b/deepspeed/ops/transformer/inference/triton/softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..1ee10d63e6cf8bfa6723856b53b7ca9ec30d3fdd --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/softmax.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. 
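As a readability aid, the branchy arithmetic in `residual_add_bias_kernel` above corresponds element-wise to the following torch sketch (biases broadcast over the hidden dimension; this only mirrors the kernel, it does not define new behavior):

```python
import torch


def residual_add_bias_reference(hidden_state, residual, attn_output, attn_bias,
                                final_bias, mp_size, mlp_after_attn,
                                add_attn_bias, pre_attn_norm):
    if mlp_after_attn:
        if pre_attn_norm:
            # pre-layer-norm transformer: fold residual, attention output and both biases back in
            return hidden_state + (residual + final_bias + attn_output + attn_bias) / mp_size
        # post-layer-norm: only the residual and the final bias are added
        return hidden_state + residual + final_bias
    out = hidden_state + attn_output + (residual + final_bias) / mp_size
    if add_attn_bias:
        out = out + attn_bias / mp_size
    return out
```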
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import torch +import triton +import triton.language as tl +''' +softmax +modified the triton kernel in +https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py +''' + + +@triton.jit +def softmax_kernel(output_ptr, input_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr): + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + +@triton.jit +def masked_softmax_kernel(output_ptr, input_ptr, stride, mask_ptr, mask_stride, n_cols, BLOCK_SIZE: tl.constexpr): + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + mask_ptrs = mask_ptr + col_offsets + row_idx * mask_stride # mask_stride is 0 for 1d mask + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + row_minus_max = row_minus_max + mask + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * stride + output_ptrs = output_row_start_ptr + col_offsets + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + +def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + assert input.is_contiguous() + assert (dim == -1) or (dim == len(input.shape) - 1), "Only dim=-1 is supported" + + use_mask = False if mask is None else True + input_arg = input.view(-1, input.shape[-1]) + n_rows, n_cols = input_arg.shape + BLOCK_SIZE = max(triton.next_power_of_2(n_cols), 2) + num_warps = 4 + if BLOCK_SIZE >= 2048: + num_warps = 8 + if BLOCK_SIZE >= 4096: + num_warps = 16 + # Allocate output + output = torch.empty_like(input) + if use_mask: + assert mask.is_contiguous() + mask = mask.view(-1, mask.shape[-1]) + mask_stride = mask.shape[-1] if mask.shape[-2] > 1 else 0 + masked_softmax_kernel[(n_rows, )]( + output, + input, + input_arg.stride(0), + mask, + mask_stride, + n_cols, + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + else: + softmax_kernel[(n_rows, )]( + output, + input, + input_arg.stride(0), + n_cols, + num_warps=num_warps, + BLOCK_SIZE=BLOCK_SIZE, + ) + return output diff --git a/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py b/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..9d647ea090f60c7645a733495178f8f3b2e6545e --- /dev/null +++ b/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py @@ -0,0 +1,377 @@ +# Copyright (c) Microsoft Corporation. 
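A quick way to see what the softmax wrapper above computes (a CUDA GPU with Triton installed is assumed; shapes here are only illustrative): the optional mask is additive, as in attention masking, so the result should match a plain torch softmax over `scores + mask` up to fp16 noise.

```python
import torch
from deepspeed.ops.transformer.inference.triton import softmax

scores = torch.randn(2, 12, 1, 128, dtype=torch.float16, device="cuda")
mask = torch.zeros(1, 1, 1, 128, dtype=torch.float16, device="cuda")
mask[..., 64:] = float("-inf")  # hide the second half of each row

out_triton = softmax(scores, mask=mask)
out_torch = torch.softmax(scores.float() + mask.float(), dim=-1).half()
print((out_triton - out_torch).abs().max())  # expected to be small (fp16 rounding)
```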
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import triton +import triton.language as tl +from .gelu import gelu_functor +import torch + +AUTOTUNE_TOP_K = 10 +SKIP_AUTOTUNE = False + + +def _fp16_matmul_prune_config(configs, named_args, skip_autotune=SKIP_AUTOTUNE): + if skip_autotune: + configs = [configs[0]] + else: + configs = triton.ops.matmul_perf_model.early_config_prune(configs, named_args) + return configs + + +""" +fp16 matmul implementation is adapted from triton matmul: +https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/triton/ops/matmul.py +""" + + +@triton.autotune( + configs=[ + # basic configs for compute-bound matmuls + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 256, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_M': 256, + 'BLOCK_N': 128, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_M': 256, + 'BLOCK_N': 64, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 64, + 'BLOCK_N': 256, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 128, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 64, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 64, + 'BLOCK_N': 128, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 128, + 'BLOCK_N': 32, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_M': 64, + 'BLOCK_N': 32, + 'BLOCK_K': 32, + 'SPLIT_K': 1 + }, num_stages=5, num_warps=2), + ], + key=['CACHE_M', 'CACHE_N', 'CACHE_K'], + prune_configs_by={ + 'early_config_prune': _fp16_matmul_prune_config, + 'perf_model': None, + 'top_k': AUTOTUNE_TOP_K + }, +) +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, +}) +@triton.jit +def _fp_matmul( + A, + B, + C, + M, + N, + K, + bias, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + CACHE_M, + CACHE_N, + CACHE_K, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, + SPLIT_K: tl.constexpr, + EVEN_K: tl.constexpr, + ACC_TYPE: tl.constexpr, + BIAS_ADD: tl.constexpr, + ACTIVATION: tl.constexpr, +): + # matrix multiplication + pid = tl.program_id(0) + pid_z = tl.program_id(1) + grid_m = (M + BLOCK_M - 1) // BLOCK_M + grid_n = (N + BLOCK_N - 1) // BLOCK_N + # re-order program ID for better L2 performance + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // (group_size) + # do matrix multiplication + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + # pointers + A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak) + B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + for k in range(K, 0, -BLOCK_K * SPLIT_K): + if EVEN_K: + a = tl.load(A) + b = tl.load(B) + else: + a = tl.load(A, mask=rk[None, :] < k, other=0.) + b = tl.load(B, mask=rk[:, None] < k, other=0.) 
+ acc += tl.dot(a, b) + A += BLOCK_K * SPLIT_K * stride_ak + B += BLOCK_K * SPLIT_K * stride_bk + # bias addition + if BIAS_ADD: + bias_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + bias_ptr = bias + bias_offset + b = tl.load(bias_ptr, mask=bias_offset < N) + acc = acc + b[None, :] + # activation + if ACTIVATION == "relu": + acc = tl.where(acc >= 0, acc, 0) + elif ACTIVATION == "leaky_relu": + acc = tl.where(acc >= 0, acc, 0.01 * acc) + elif ACTIVATION == "gelu": + #acc = tl.sigmoid(1.702 * acc) * acc + acc = gelu_functor(acc) + elif ACTIVATION == "sigmoid": + acc = tl.sigmoid(acc) # sigmoid + acc = acc.to(C.dtype.element_ty) + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) + mask = (rm < M)[:, None] & (rn < N)[None, :] + # handles write-back with reduction-splitting + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +def matmul_4d_prune_config(configs, named_args, skip_autotune=SKIP_AUTOTUNE): + if skip_autotune: + configs = [configs[0]] + else: + device = torch.cuda.current_device() #ignore-cuda + capability = torch.cuda.get_device_capability() #ignore-cuda + # BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages + dtsize = named_args['a_ptr'].element_size() + dtype = named_args['a_ptr'].dtype + + # make sure we have enough smem + pruned_configs = [] + for config in configs: + kw = config.kwargs + BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \ + kw['BLOCK_SIZE_M'], kw['BLOCK_SIZE_N'], kw['BLOCK_SIZE_K'], config.num_stages + + triton.compiler.init_cuda_utils() + max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"] + required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize + if required_shared_memory <= max_shared_memory: + pruned_configs.append(config) + configs = pruned_configs + return configs + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8 + }, + num_stages=1, # this is mainly for unit test, to minimize the share memory usage + num_warps=8), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + }, + num_stages=5, + num_warps=2, + ), + ], + key=['CACHE_M', 'CACHE_N', 'CACHE_K'], + prune_configs_by={ + 'early_config_prune': matmul_4d_prune_config, + 'perf_model': None, + 'top_k': AUTOTUNE_TOP_K + }, +) +@triton.jit +def matmul_4d_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + # Matrix dimensions + M, + N, + K, + CACHE_M, + CACHE_N, + CACHE_K, + stride_ab, + stride_ah, + stride_am, + stride_ak, + stride_bb, + stride_bh, + 
stride_bk, + stride_bn, + stride_cb, + stride_ch, + stride_cm, + stride_cn, + scale, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MASK: tl.constexpr, +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + pid = tl.program_id(axis=0) + head = tl.program_id(axis=1) + batch = tl.program_id(axis=2) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + if MASK: + if (pid_m + 1) * BLOCK_SIZE_M - 1 < pid_n * BLOCK_SIZE_N: + c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.dtype.element_ty) - float("inf") + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :]) + tl.store(c_ptrs, c) + return + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah + + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)) + b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, K, BLOCK_SIZE_K): + a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K) + b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N) + a = tl.load(a_ptrs, mask=a_mask, other=0.) + b = tl.load(b_ptrs, mask=b_mask, other=0.) 
+ accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(c_ptr.dtype.element_ty) + if scale > 0: + c = c * scale.to(c_ptr.dtype.element_ty) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + if MASK: + c += tl.where(offs_cm[:, None] >= offs_cn[None, :], 0, float("-inf")) + c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index ce1dd3429eea4c9569e5212495a709a25698b339..731b67a26814685fed42f4932d6d9a196558ae60 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -13,4 +13,5 @@ sphinx-rtd-theme tensorboard torchvision transformers +triton wandb diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt index 848a7f7a485de46d9941b3de6919b42724f76409..ef8bfff774d660e92c6803c2892905aafcb9fbce 100644 --- a/requirements/requirements-inf.txt +++ b/requirements/requirements-inf.txt @@ -3,3 +3,4 @@ lm-eval==0.3.0 protobuf transformers transformers[sentencepiece] +triton diff --git a/requirements/requirements-triton.txt b/requirements/requirements-triton.txt new file mode 100644 index 0000000000000000000000000000000000000000..a59a965090a6473278f0cf9e7fd1d3cb9cb385c9 --- /dev/null +++ b/requirements/requirements-triton.txt @@ -0,0 +1 @@ +triton diff --git a/setup.py b/setup.py index 5d0aba18f2bb5fa2e514dc15338c2442b9cf09e4..1b5835f3cfe4805a21040bc9ec79262b5139e998 100755 --- a/setup.py +++ b/setup.py @@ -67,7 +67,8 @@ extras_require = { 'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'), 'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'), 'inf': fetch_requirements('requirements/requirements-inf.txt'), - 'sd': fetch_requirements('requirements/requirements-sd.txt') + 'sd': fetch_requirements('requirements/requirements-sd.txt'), + 'triton': fetch_requirements('requirements/requirements-triton.txt'), } # Add specific cupy version to both onebit extension variants. 
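Before moving on to the test changes, it may help to note what `matmul_4d_kernel` above computes per batch and head; a semantics-only torch sketch (the `scale > 0` and causal `MASK` handling mirror the kernel, everything else here is illustrative):

```python
import torch


def matmul_4d_reference(a, b, scale=0.0, causal_mask=False):
    # a: (B, H, M, K), b: (B, H, K, N) -> c: (B, H, M, N)
    c = torch.matmul(a, b)
    if scale > 0:
        c = c * scale
    if causal_mask:
        m, n = c.shape[-2], c.shape[-1]
        # positions above the diagonal (row < column) are set to -inf
        upper = torch.triu(torch.ones(m, n, dtype=torch.bool, device=c.device), diagonal=1)
        c = c.masked_fill(upper, float("-inf"))
    return c
```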
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py index c42deb3dd6d743a63f49c93177640bdf98cb0276..f1fac5ba1ecbe2b545ec64dd69de9fb006414101 100644 --- a/tests/unit/inference/test_inference.py +++ b/tests/unit/inference/test_inference.py @@ -95,13 +95,18 @@ def enable_cuda_graph(request): return request.param +@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"]) +def enable_triton(request): + return request.param + + """ This fixture will validate the configuration """ @pytest.fixture() -def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph): +def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_triton): model, task = model_w_task msg = "" if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"): @@ -125,6 +130,12 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph): msg = f"Bloom models only support half precision, cannot use dtype {dtype}" elif ("bert" not in model.lower()) and enable_cuda_graph: msg = "Non bert/roberta models do no support CUDA Graph" + elif enable_triton and not (dtype in [torch.half]): + msg = "Triton is for fp16" + elif enable_triton and not deepspeed.HAS_TRITON: + msg = "triton needs to be installed for the test" + elif ("bert" not in model.lower()) and enable_triton: + msg = "Triton kernels do not support Non bert/roberta models yet" return msg @@ -249,16 +260,16 @@ Tests class TestModelTask(DistributedTest): world_size = 1 - def test( - self, - model_w_task, - dtype, - enable_cuda_graph, - query, - inf_kwargs, - assert_fn, - invalid_model_task_config, - ): + def test(self, + model_w_task, + dtype, + enable_cuda_graph, + enable_triton, + query, + inf_kwargs, + assert_fn, + invalid_model_task_config, + perf_meas=True): if invalid_model_task_config: pytest.skip(invalid_model_task_config) @@ -284,13 +295,18 @@ class TestModelTask(DistributedTest): get_accelerator().synchronize() bs_time = time.time() - start - pipe.model = deepspeed.init_inference( - pipe.model, - mp_size=1, - dtype=dtype, - replace_with_kernel_inject=True, - enable_cuda_graph=enable_cuda_graph, - ) + args = { + 'mp_size': 1, + 'dtype': dtype, + 'replace_with_kernel_inject': True, + 'enable_cuda_graph': enable_cuda_graph, + 'use_triton': enable_triton, + 'triton_autotune': False, + } + if pipe.tokenizer.model_max_length < deepspeed.ops.transformer.inference.config.DeepSpeedInferenceConfig( + ).max_out_tokens: + args.update({'max_out_tokens': pipe.tokenizer.model_max_length}) + pipe.model = deepspeed.init_inference(pipe.model, **args) check_injection(pipe.model) # Warm-up queries for perf measurement #for i in range(10): @@ -301,6 +317,11 @@ class TestModelTask(DistributedTest): get_accelerator().synchronize() ds_time = time.time() - start + if perf_meas: + print( + f"model={model}, task={task}, dtype={dtype}, cuda_graph={enable_cuda_graph}, triton={enable_triton}, bs_time={bs_time}, ds_time={ds_time}" + ) + # facebook/opt* and some bigscient/bloom* models are not matching # baseline exactly, adding an exception to them for now if ("opt" in model) or ("bloom" in model): @@ -309,6 +330,7 @@ class TestModelTask(DistributedTest): # These performance tests are only measuring the time for a single # inference request, we just want to check that performance isn't terrible #assert ds_time <= (bs_time * 1.1) + assert assert_fn(bs_output, ds_output) diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py new file mode 
100644 index 0000000000000000000000000000000000000000..760223115a280eca5b756372ca0520953754da10 --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_attention.py @@ -0,0 +1,73 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed + + +# reference implementation +def ref_torch_attention(q, k, v, mask, sm_scale): + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + p = torch.softmax(p.float() + mask, dim=-1).half() + ref_out = torch.matmul(p, v) + return ref_out + + +# test attention operator +@pytest.mark.inference_ops +@pytest.mark.parametrize("Z", [1]) # batch +@pytest.mark.parametrize("H", [12]) # heads +@pytest.mark.parametrize("N_CTX", [4, 128]) # sequence length +@pytest.mark.parametrize("D_HEAD", [64, 128]) +@pytest.mark.parametrize("causal", [True, False]) +def test_attention(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + if not deepspeed.HAS_TRITON: + pytest.skip("triton has to be installed for the test") + + # skip autotune in testing + from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul + fp16_matmul.skip_autotune() + + import triton + from deepspeed.ops.transformer.inference.triton.attention import compute_attention + torch.manual_seed(20) + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5) + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5) + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5) + sm_scale = 0.3 + + # reference implementation + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + score = p + mask = torch.zeros((Z, H, N_CTX, N_CTX), dtype=dtype, device="cuda") + M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda")) + if causal: + for z in range(Z): + for h in range(H): + mask[:, :, M == 0] = float("-inf") + p = torch.softmax(p.float() + mask, dim=-1).half() + softmax_out = p + ref_out = torch.matmul(p, v) + context = ref_out + + # adjust it to expected tensor format and run test + qkv = torch.randn((Z, N_CTX, 3 * H * D_HEAD), dtype=dtype, device='cuda', requires_grad=False) + qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD)) + qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD)) + qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD)) + tri_out = compute_attention(qkv, + input_mask=mask, + layer_past=None, + alibi=None, + scale=sm_scale, + head_size=D_HEAD, + triangular=False, + use_cuda_flash=False, + use_triton_flash=False, + use_ds_attention=False) + tri_out = tri_out.reshape((Z, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3) + triton.testing.allclose(ref_out, tri_out) + triton.testing.assert_almost_equal(ref_out, tri_out) diff --git a/tests/unit/ops/transformer/inference/test_gelu.py b/tests/unit/ops/transformer/inference/test_gelu.py new file mode 100644 index 0000000000000000000000000000000000000000..de924848bfb4a42de52cc4c84aca024805e861a6 --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_gelu.py @@ -0,0 +1,70 @@ +# Copyright (c) Microsoft Corporation.
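The attention test above flattens per-head q/k/v into the packed `(batch, seq, 3 * heads * head_dim)` layout that `compute_attention` consumes; a small helper that captures the same reshuffling (illustrative only, not part of the PR):

```python
import torch


def pack_qkv(q, k, v):
    # q, k, v: (Z, H, N_CTX, D_HEAD) -> packed (Z, N_CTX, 3 * H * D_HEAD)
    z, h, n_ctx, d_head = q.shape

    def to_rows(t):
        return t.permute(0, 2, 1, 3).contiguous().reshape(z, n_ctx, h * d_head)

    return torch.cat([to_rows(q), to_rows(k), to_rows(v)], dim=-1)
```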
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + +inference_module = None +torch_minor_version = None + + +def allclose(x, y): + assert x.dtype == y.dtype + rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] + return torch.allclose(x, y, rtol=rtol, atol=atol) + + +def version_appropriate_gelu(activations): + global torch_minor_version + if torch_minor_version is None: + torch_minor_version = int(torch.__version__.split('.')[1]) + # gelu's approximate='tanh' argument is only available in torch >= 1.12 + if torch_minor_version < 12: + return torch.nn.functional.gelu(activations) + else: + return torch.nn.functional.gelu(activations, approximate='tanh') + + +def run_gelu_reference(activations): + # Expected behavior is that of casting to float32 internally and using the tanh approximation + return version_appropriate_gelu(activations.to(torch.float32)).to(activations.dtype) + + +def run_gelu_ds(activations, use_triton_ops=False): + if use_triton_ops: + from deepspeed.ops.transformer.inference.triton import gelu + return gelu(activations) + + channels = activations.shape[-1] + bias = torch.zeros((channels), dtype=activations.dtype, device='cuda') + global inference_module + if inference_module is None: + inference_module = InferenceBuilder().load() + if activations.dtype == torch.float16: + return inference_module.bias_gelu_fp16(activations, bias) + else: + return inference_module.bias_gelu_fp32(activations, bias) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("batch", [1, 2]) +@pytest.mark.parametrize("sequence", [1, 128, 255]) +@pytest.mark.parametrize("channels", [512, 1232, 4096]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("use_triton_ops", [True, False]) +def test_gelu(batch, sequence, channels, dtype, use_triton_ops): + activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda') + activations_ref = activations_ds.clone().detach() + + if not deepspeed.HAS_TRITON and use_triton_ops: + pytest.skip("triton has to be installed for the test") + ds_out = run_gelu_ds(activations_ds, use_triton_ops) + ref_out = run_gelu_reference(activations_ref) + assert (allclose(ds_out, ref_out)) diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py index 25728d41a0cd80c14b01425eba58f92b91def421..952904a7847cd855bd7e195ee5480236c0494c32 100644 --- a/tests/unit/ops/transformer/inference/test_layer_norm.py +++ b/tests/unit/ops/transformer/inference/test_layer_norm.py @@ -9,6 +9,14 @@ import pytest from deepspeed.accelerator import get_accelerator from deepspeed.ops.op_builder import InferenceBuilder from .inference_test_utils import allclose, get_dtypes +try: + import triton # noqa: F401 + from deepspeed.ops.transformer.inference.triton import ( + layer_norm, + layer_norm_residual, + ) +except ImportError: + print("triton import failed") if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) @@ -30,19 +38,32 @@ def ds_implementation(vals, gamma, beta, epsilon): return inference_module.layer_norm(vals, gamma, beta, epsilon) +def ds_triton_implementation(vals, gamma, beta,
epsilon) + + @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 32]) @pytest.mark.parametrize("seq_len", [1, 128]) @pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432]) @pytest.mark.parametrize("dtype", get_dtypes()) -def test_layer_norm(batch, seq_len, channels, dtype): +@pytest.mark.parametrize("use_triton_ops", [False, True]) +def test_layer_norm(batch, seq_len, channels, dtype, use_triton_ops): + if not deepspeed.HAS_TRITON and use_triton_ops: + pytest.skip("triton has to be installed for the test") + vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name()) gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) epsilon = 1e-5 ref_output = ref_implementation(vals, gamma, beta, epsilon, channels, dtype) - new_output = ds_implementation(vals, gamma, beta, epsilon) + if use_triton_ops: + new_output = ds_triton_implementation(vals, gamma, beta, epsilon) + if dtype != torch.float16: # fp16 supported in triton + return + else: + new_output = ds_implementation(vals, gamma, beta, epsilon) if not allclose(new_output, ref_output): #print(new_output - ref_output) @@ -68,12 +89,20 @@ def residual_ds_implementation(vals, bias, res, gamma, beta, epsilon): return inference_module._layer_norm_residual(vals, bias, res, gamma, beta, epsilon) +def residual_ds_triton_implementation(vals, bias, res, gamma, beta, epsilon): + return layer_norm_residual(vals, bias, res, gamma, beta, epsilon) + + @pytest.mark.inference_ops @pytest.mark.parametrize("batch", [1, 32]) @pytest.mark.parametrize("seq_len", [1, 128]) @pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432]) @pytest.mark.parametrize("dtype", get_dtypes()) -def test_layer_norm_residual(batch, seq_len, channels, dtype): +@pytest.mark.parametrize("use_triton_ops", [False, True]) +def test_layer_norm_residual(batch, seq_len, channels, dtype, use_triton_ops): + if not deepspeed.HAS_TRITON and use_triton_ops: + pytest.skip("triton has to be installed for the test") + vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name()) residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name()) bias = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name()) @@ -81,7 +110,13 @@ def test_layer_norm_residual(batch, seq_len, channels, dtype): beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name()) epsilon = 1e-5 - new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon) + if use_triton_ops: + new_output = residual_ds_triton_implementation(vals, bias, residual, gamma, beta, epsilon) + if dtype != torch.float16: # fp16 supported in triton + return + else: + new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon) + ref_output = residual_ref_implementation(vals, bias, residual, gamma, beta, epsilon, channels, dtype) print((new_output - ref_output).abs().max()) @@ -129,3 +164,38 @@ def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype): assert allclose(ds_res_output, norm_res_output) assert allclose(ds_norm_output, ref_norm_output) + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("M", [4]) +@pytest.mark.parametrize("N", [4]) +@pytest.mark.parametrize("dtype", [torch.float16]) 
+@pytest.mark.parametrize("residual", [True, False]) +@pytest.mark.parametrize("input_bias", [True, False]) +def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='cuda'): + if not deepspeed.HAS_TRITON: + pytest.skip("triton has to be installed for the test") + torch.manual_seed(0) + # create data + x_shape = (M, N) + w_shape = (x_shape[-1], ) + weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False) + bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False) + x_bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda') + dy = .1 * torch.randn_like(x) + if residual: + res = torch.rand(x_shape, dtype=dtype, device='cuda', requires_grad=False) + else: + res = torch.zeros(x_shape, dtype=dtype, device='cuda', requires_grad=False) + x.requires_grad_(True) + # forward pass + if residual or input_bias: + y_tri = layer_norm_residual(x, x_bias if input_bias else None, res, weight, bias, eps) + else: + y_tri = layer_norm(x, weight, bias, eps) + y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias, + eps).to(dtype) + # compare + #print(f"y_tri={y_tri}, y_ref={y_ref}") + triton.testing.assert_almost_equal(y_tri, y_ref) diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..804a85750a3a0259a9d9cef0c9b967ab1955027e --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_matmul.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + +inference_module = None +torch_minor_version = None + + +def allclose(x, y): + assert x.dtype == y.dtype + rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (5e-2, 2e-3)}[x.dtype] + return torch.allclose(x, y, rtol=rtol, atol=atol) + + +def run_matmul_ref(a, b): + return torch.matmul(a, b) + + +def run_matmul_ds(a, b, use_triton_ops=False): + if use_triton_ops: + from deepspeed.ops.transformer.inference.triton import matmul_4d as matmul + return matmul(a, b) + + assert use_triton_ops, "Only triton softmax is supported for now" + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("B", [1, 2]) +@pytest.mark.parametrize("H", [1, 2, 16]) +@pytest.mark.parametrize("M", [1, 7, 8, 128]) +@pytest.mark.parametrize("K", [2, 5, 16, 128]) +@pytest.mark.parametrize("N", [1, 2, 8, 512]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("use_triton_ops", [True]) +def test_matmul_4d(B, H, M, K, N, dtype, use_triton_ops): + if not deepspeed.HAS_TRITON and use_triton_ops: + pytest.skip("triton has to be installed for the test") + + # skip autotune in testing + from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul + fp16_matmul.skip_autotune() + + a_ds = torch.randn((B, H, M, K), dtype=dtype, device='cuda') + b_ds = torch.randn((B, H, K, N), dtype=dtype, device='cuda') + a_ref = a_ds.clone().detach() + b_ref = b_ds.clone().detach() + + ds_out = run_matmul_ds(a_ds, b_ds, use_triton_ops) + ref_out = run_matmul_ref(a_ref, b_ref) + assert (allclose(ds_out, ref_out)) diff --git 
a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py index 1a9d8975852c0ae2afdba9fc3ae15e3c24a7913b..c2952f74ff2d7414fcaf6ae4524c50941979d475 100644 --- a/tests/unit/ops/transformer/inference/test_residual_add.py +++ b/tests/unit/ops/transformer/inference/test_residual_add.py @@ -74,8 +74,11 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f @pytest.mark.parametrize("add_bias", [True, False]) @pytest.mark.parametrize("mp_size", [1, 2]) @pytest.mark.parametrize("pre_attn_norm", [True, False]) +@pytest.mark.parametrize("use_triton_ops", [True, False]) def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size, - pre_attn_norm): + pre_attn_norm, use_triton_ops): + if not deepspeed.HAS_TRITON and use_triton_ops and dtype == torch.float16: + pytest.skip("triton has to be installed for the test") ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name()) @@ -90,6 +93,9 @@ def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_ ds_out, residual, attn_output, attn_bias, final_bias, mp_size, mlp_after_attn, add_bias, pre_attn_norm ] + if use_triton_ops: + from deepspeed.ops.transformer.inference.triton import residual_add_bias + ds_out = residual_add_bias(*res_add_args) if dtype == torch.float16: ds_out = inference_module.residual_add_bias_fp16(*res_add_args) elif dtype == torch.float32: @@ -97,7 +103,12 @@ def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_ elif dtype == torch.bfloat16: ds_out = inference_module.residual_add_bias_bf16(*res_add_args) else: - raise ValueError(f"Unsupported dtype: {dtype}") + if dtype == torch.float16: + ds_out = inference_module.residual_add_bias_fp16(*res_add_args) + elif dtype == torch.float32: + ds_out = inference_module.residual_add_bias_fp32(*res_add_args) + else: + raise ValueError(f"Unsupported dtype: {dtype}") if not allclose(ds_out, ref_out): print((ds_out - ref_out).abs().max()) diff --git a/tests/unit/ops/transformer/inference/test_softmax.py b/tests/unit/ops/transformer/inference/test_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..76046f31e01ad122b253a058b6040ed988c17cf1 --- /dev/null +++ b/tests/unit/ops/transformer/inference/test_softmax.py @@ -0,0 +1,51 @@ +# Copyright (c) Microsoft Corporation. 
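For completeness, the Triton layer-norm wrappers exercised in the tests above can also be sanity-checked directly against torch; fp16 on a CUDA device is assumed, since that is the only dtype the Triton path covers here:

```python
import torch
from deepspeed.ops.transformer.inference.triton import layer_norm

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
w = torch.ones(1024, dtype=torch.float16, device="cuda")
b = torch.zeros(1024, dtype=torch.float16, device="cuda")

y_tri = layer_norm(x, w, b, 1e-5)
y_ref = torch.nn.functional.layer_norm(x, (1024, ), w, b, 1e-5)
print((y_tri - y_ref).abs().max())  # expected to be within fp16 tolerance
```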
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import pytest +import torch +import deepspeed +from deepspeed.ops.op_builder import InferenceBuilder + +if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: + pytest.skip("Inference ops are not available on this system", allow_module_level=True) + +inference_module = None +torch_minor_version = None + + +def allclose(x, y): + assert x.dtype == y.dtype + rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype] + return torch.allclose(x, y, rtol=rtol, atol=atol) + + +def run_softmax_reference(input): + return torch.nn.functional.softmax(input, dim=-1) + + +def run_softmax_ds(input, use_triton_ops=False): + if use_triton_ops: + from deepspeed.ops.transformer.inference.triton import softmax + # return torch.empty_like(input) + return softmax(input) + + assert use_triton_ops, "Only triton softmax is supported for now" + + +@pytest.mark.inference_ops +@pytest.mark.parametrize("batch", [1, 2]) +@pytest.mark.parametrize("sequence", [1, 128, 255, 1232]) +@pytest.mark.parametrize("channels", [512, 4096]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("use_triton_ops", [True]) +def test_softmax(batch, sequence, channels, dtype, use_triton_ops): + if not deepspeed.HAS_TRITON and use_triton_ops: + pytest.skip("triton has to be installed for the test") + input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda') + input_ref = input_ds.clone().detach() + + ds_out = run_softmax_ds(input_ds, use_triton_ops) + ref_out = run_softmax_reference(input_ref) + assert (allclose(ds_out, ref_out))
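Finally, to run just the inference-op tests touched by this change locally, something along these lines should work (the `inference_ops` marker and test paths come from the diff; a CUDA GPU with Triton installed is assumed):

```python
import pytest

# select the inference op tests via their pytest marker
pytest.main(["-m", "inference_ops", "tests/unit/ops/transformer/inference/"])
```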