diff --git a/.github/workflows/formatting.yml b/.github/workflows/formatting.yml
index 3a59170b42abef5d3b172ef9608ed66c37464ef2..c4af4e5a95358821d94584a857a2ab0be4027c36 100644
--- a/.github/workflows/formatting.yml
+++ b/.github/workflows/formatting.yml
@@ -27,7 +27,7 @@ jobs:
- name: Install deepspeed
run: |
- pip install .[dev,autotuning]
+ pip install .[dev,autotuning,triton]
ds_report
- name: Formatting checks
diff --git a/blogs/assets/images/triton-bert-base-latency.png b/blogs/assets/images/triton-bert-base-latency.png
new file mode 100644
index 0000000000000000000000000000000000000000..4f733f4d1afe0e7720ff0e3e93dc842f2fce1fe7
Binary files /dev/null and b/blogs/assets/images/triton-bert-base-latency.png differ
diff --git a/blogs/assets/images/triton-bert-large-latency.png b/blogs/assets/images/triton-bert-large-latency.png
new file mode 100644
index 0000000000000000000000000000000000000000..d82dc0ccac51392d6ad8062031896dffef6cb039
Binary files /dev/null and b/blogs/assets/images/triton-bert-large-latency.png differ
diff --git a/blogs/deepspeed-triton/README.md b/blogs/deepspeed-triton/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..071b5d4bc6d038bf49e36cae30a7544fed2bd776
--- /dev/null
+++ b/blogs/deepspeed-triton/README.md
@@ -0,0 +1,95 @@
+# DeepSpeed with Triton compiler
+
+# 1. Overview
+
+We have integrated [Triton](https://github.com/openai/triton), an open-source compiler for GPU programming, into DeepSpeed, further boosting the inference speed of BERT-like models in float16 precision.
+By replacing some CUDA kernels or torch operators with Triton kernels, we achieved a 1.14\~1.68x speedup (a 12\~41% latency reduction) across different models and GPUs, as shown in Table 1.
+
+
+
+| Hardware | Bert-base | Bert-large | Roberta-base | Roberta-large |
+|----------|:------:|:------:|:------:|:------:|
+| A100 |1.65x | 1.68x | 1.53x | 1.61x |
+| V100 | 1.29x | 1.14x | 1.23x | 1.21x |
+
+Table 1. The average speedup (see NOTE below for more detail)
+
+
+
+
+For transformer operators in float16, we have implemented kernels written in the Triton language that replace the ordinary CUDA kernels or torch operators.
+The Triton kernels we implemented include softmax, layer normalization, residual addition, and all the matrix multiplications except the MLP layers (see NOTE below for details).
+In our experiments, the Triton kernels reduce the average latency (over different sequence lengths) by 6\~24% (depending on model and hardware) compared to the latency with CUDA-only kernels.
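+
+As a minimal sketch (assuming a CUDA device and the `triton` extra installed), the standalone kernels exported from `deepspeed.ops.transformer.inference.triton` can also be invoked directly:
+
+```python
+import torch
+from deepspeed.ops.transformer.inference.triton import gelu
+
+# the kernel expects a contiguous CUDA tensor; we use fp16 as in inference
+x = torch.randn(8, 1024, dtype=torch.float16, device='cuda')
+y = gelu(x)  # elementwise GELU computed by the Triton kernel
+```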
+
+
+The figures below show the latency reduction in more detail.
+Figure 1 visualizes the latency reduction across different sequence lengths on an A100 GPU for the Bert-base model.
+The baseline (blue) is Huggingface transformers without any kernel injection, the orange is DeepSpeed with CUDA-only kernels, and the gray is DeepSpeed with Triton kernels.
+Figure 2 shows the same plot for the Bert-large model on an A100 GPU.
+
+
+
+
+![triton-bert-base-latency](../assets/images/triton-bert-base-latency.png)
+
+*Figure 1: Normalized P90 latency for the Bert-base model on an A100 GPU across different sequence lengths*
+
+
+![triton-bert-large-latency](../assets/images/triton-bert-large-latency.png)
+
+*Figure 2: Normalized P90 latency for the Bert-large model on an A100 GPU across different sequence lengths*
+
+
+
+
+Next, we dive deeper into this new feature in DeepSpeed.
+
+# 2. How to use Triton in DeepSpeed
+
+You can enable the Triton kernels by setting `use_triton=True` (and optionally `triton_autotune=True`) in the DeepSpeed inference config, as shown below.
+
+```python
+import torch
+import deepspeed
+from transformers import pipeline
+
+pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0)
+pipe.model = deepspeed.init_inference(pipe.model,
+ dtype=torch.float16,
+ replace_with_kernel_inject=True,
+ enable_cuda_graph=True,
+ use_triton=True,
+ triton_autotune=True,
+ max_out_tokens=pipe.tokenizer.model_max_length)
+```
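+
+Once initialized, the pipeline is used as usual; for example (a minimal sketch, where the output fields follow the standard Huggingface `fill-mask` API):
+
+```python
+print(pipe("Paris is the [MASK] of France.")[0]["token_str"])
+```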
+
+
+## Running BERT inference with Triton kernels
+
+Here is an example using the Bert-base model:
+
+```bash
+pip install deepspeed[triton]
+
+git clone https://github.com/microsoft/DeepSpeedExamples.git
+cd DeepSpeedExamples/inference/huggingface/fill-mask
+
+deepspeed --num_gpus 1 test-bert.py --triton
+```
+
+To run a performance benchmark, you can use the following command:
+
+```bash
+pip install deepspeed[triton]
+
+git clone https://github.com/microsoft/DeepSpeedExamples.git
+cd DeepSpeedExamples/benchmarks/inference
+
+deepspeed --num_gpus 1 triton-bert-benchmark.py --model bert-base-cased --dtype fp16 --kernel-inject --deepspeed --graphs --triton
+```
+
+# NOTE
+
+* For more information on how to use DeepSpeed, please visit our [GitHub Page](https://github.com/microsoft/DeepSpeedExamples) and our [website](https://www.deepspeed.ai/), where you can find blog posts, tutorials, and documentation.
+
+* This feature currently supports BERT, Roberta, and other BERT-like models; text-generation models are not yet supported.
+
+* To achieve the best performance with Triton optimization, you need to enable CUDA graph and `triton_autotune` in the DeepSpeed config. CUDA graph avoids the overhead of JIT compilation and Triton's deep call stack, while `triton_autotune` runs an initial step to find the best-performing kernel parameters, which may take some time.
+
+* We used [Triton 2.0.0.post1 release](https://pypi.org/project/triton/2.0.0.post1/) in our experiments.
+
+* In our experiments, we used a batch size of 1, a sequence length range of 8 to 512, and a `fill-mask` task. Table 1 shows the average P90 latency over the entire sequence length range, while Figures 1 and 2 show the P90 latency for specific sub-ranges. The baseline is Huggingface transformers without any optimization. The speedup is calculated as (baseline P90 latency)/(DeepSpeed-Triton P90 latency). We found that the CUDA kernel in MLP performed better than the Triton kernel in our experiments, so we use a hybrid approach that combines both kernels when Triton is enabled in the DeepSpeed config.
diff --git a/deepspeed/__init__.py b/deepspeed/__init__.py
index 255dacdccf6ef718ff7f40a6970c1feada6e2425..e060ebcb9734780704123910f59a1ae56da2144e 100755
--- a/deepspeed/__init__.py
+++ b/deepspeed/__init__.py
@@ -12,6 +12,12 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from packaging import version as pkg_version
+try:
+ import triton # noqa: F401
+ HAS_TRITON = True
+except ImportError:
+ HAS_TRITON = False
+
from . import ops
from . import module_inject
diff --git a/deepspeed/inference/config.py b/deepspeed/inference/config.py
index 3474578e6a118b44d8086f688dc84fc67482c81b..7f1dcb3dfe06a8e38ee61288b2c62e21cbb0d17e 100644
--- a/deepspeed/inference/config.py
+++ b/deepspeed/inference/config.py
@@ -4,6 +4,7 @@
# DeepSpeed Team
import torch
+import deepspeed
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
@@ -152,6 +153,18 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
can run faster using the graph replay method.
"""
+ use_triton: bool = False
+ """
+ Use this flag to use triton kernels for inference ops.
+ """
+
+ triton_autotune: bool = False
+ """
+ Use this flag to enable triton autotuning.
+ Enabling it improves performance but increases the first run's latency while autotuning runs.
+ """
+
zero: DeepSpeedZeroConfig = {}
"""
ZeRO configuration to use with the Inference Engine. Expects a dictionary
@@ -279,6 +292,12 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
return DeepSpeedMoEConfig(moe=field_value)
return field_value
+ @validator("use_triton")
+ def has_triton(cls, field_value, values):
+ if field_value and not deepspeed.HAS_TRITON:
+ raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
+ return field_value
+
class Config:
# Get the str representation of the datatype for serialization
json_encoders = {torch.dtype: lambda x: str(x)}
diff --git a/deepspeed/inference/engine.py b/deepspeed/inference/engine.py
index d2f9bf9cf0b26f8ba360a458eeb4561eff516e0a..ab502551e702bd3fb447f4176dde78f709fec751 100755
--- a/deepspeed/inference/engine.py
+++ b/deepspeed/inference/engine.py
@@ -607,6 +607,7 @@ class InferenceEngine(Module):
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
+
else:
outputs = self.module(*inputs, **kwargs)
diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py
index b91432461fcb1d0554c6f79f75a2cbdb47bc20fd..a41df58ad0591bfb04495da820b210cb8ad11727 100644
--- a/deepspeed/model_implementations/transformers/ds_transformer.py
+++ b/deepspeed/model_implementations/transformers/ds_transformer.py
@@ -12,6 +12,10 @@ from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
+import deepspeed
+if deepspeed.HAS_TRITON:
+ from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
+ from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention
inference_module = None
@@ -55,14 +59,24 @@ class DeepSpeedTransformerInference(nn.Module):
if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
+ if deepspeed.HAS_TRITON and self.config.use_triton:
+ log_dist(f"Injecting Triton kernels ...", [0])
if self.config.bigscience_bloom:
self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
+ assert not self.config.use_triton, "Triton kernels are not supported for Bloom"
else:
- self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
- merge_count)
- self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
- mlp_extra_grouping)
+ if deepspeed.HAS_TRITON and self.config.use_triton:
+ self.attention = TritonSelfAttention(self.config)
+ else:
+ self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
+ merge_count)
+
+ if deepspeed.HAS_TRITON and self.config.use_triton:
+ self.mlp = TritonMLP(self.config)
+ else:
+ self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
+ mlp_extra_grouping)
device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu'
if self.config.set_empty_params:
diff --git a/deepspeed/module_inject/containers/base.py b/deepspeed/module_inject/containers/base.py
index a520664793ca61a3168b9f29b49fe902ba8d91f0..62b78a091fd5fe31f78aa5aa92b78bc653f6502a 100644
--- a/deepspeed/module_inject/containers/base.py
+++ b/deepspeed/module_inject/containers/base.py
@@ -8,6 +8,7 @@ from abc import ABC
import torch
+import deepspeed
from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.accelerator import get_accelerator
@@ -79,6 +80,10 @@ class BaseTransformerContainer(ABC):
self.input_nb = None
self.mp_group = None
+ # Triton
+ self.use_triton = config.use_triton and deepspeed.HAS_TRITON
def create_ds_model_config(self):
self.set_hidden_heads(*self.policy.get_hidden_heads())
@@ -110,7 +115,14 @@ class BaseTransformerContainer(ABC):
use_mup=self.use_mup,
return_single_tuple=self.return_single_tuple,
set_empty_params=self.config.set_empty_params,
- transposed_mode=self.config.transposed_mode)
+ transposed_mode=self.config.transposed_mode,
+ use_triton=self.use_triton,
+ triton_autotune=self.config.triton_autotune)
+
+ if self.use_triton and deepspeed.HAS_TRITON:
+ if not self.config.triton_autotune:
+ from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
+ fp16_matmul.skip_autotune()
return self.ds_model_config
diff --git a/deepspeed/module_inject/containers/bert.py b/deepspeed/module_inject/containers/bert.py
index 967a02276be607957b28edcb5c4402f4b3a7e4df..20ae575f45144733a82b609eed21284746cb95d0 100644
--- a/deepspeed/module_inject/containers/bert.py
+++ b/deepspeed/module_inject/containers/bert.py
@@ -18,6 +18,7 @@ class DS_BERTContainer(BaseTransformerContainer):
# All model specific things should be defined here instead of the base class.
self.return_tuple = True
self.triangular_masking = False
+ self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON
def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
diff --git a/deepspeed/module_inject/containers/distil_bert.py b/deepspeed/module_inject/containers/distil_bert.py
index 2acd144cac044a880c8480d950acc08ff0fc6877..ecd0562438b5ac634cf8b4536fd3413d0f9ed9d8 100644
--- a/deepspeed/module_inject/containers/distil_bert.py
+++ b/deepspeed/module_inject/containers/distil_bert.py
@@ -18,6 +18,7 @@ class DS_DistilBERTContainer(BaseTransformerContainer):
# All model specific things should be defined here instead of the base class.
self.triangular_masking = False
self.return_single_tuple = True
+ self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON
def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
diff --git a/deepspeed/ops/transformer/inference/config.py b/deepspeed/ops/transformer/inference/config.py
index 261523529d0b15104c1933fba34e9a6d92632f86..12007f6c14ca67f51816143276c83041b0800705 100644
--- a/deepspeed/ops/transformer/inference/config.py
+++ b/deepspeed/ops/transformer/inference/config.py
@@ -44,6 +44,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
+ use_triton: Whether to use Triton kernels for inference ops.
"""
def __init__(self,
@@ -77,7 +78,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
scale_attn_by_inverse_layer_idx=False,
return_single_tuple=False,
set_empty_params=False,
- transposed_mode=False):
+ transposed_mode=False,
+ use_triton=False,
+ triton_autotune=False):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
@@ -109,6 +112,8 @@ class DeepSpeedInferenceConfig(TransformerConfig):
self.return_single_tuple = return_single_tuple
self.set_empty_params = set_empty_params
self.transposed_mode = transposed_mode
+ self.use_triton = use_triton
+ self.triton_autotune = triton_autotune
@classmethod
def from_dict(cls, json_object):
diff --git a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
index 1d2dc06d0b00865033650070266b146766fd2132..63323c150752b9b98e1120e427251ced382b115a 100644
--- a/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
+++ b/deepspeed/ops/transformer/inference/op_binding/gelu_gemm.py
@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
+import deepspeed
class GELUGemmOp(BaseOp):
@@ -14,9 +15,13 @@ class GELUGemmOp(BaseOp):
super(GELUGemmOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
- self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
+ if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
+ from deepspeed.ops.transformer.inference.triton.ops import fused_gemm_gelu as _triton_fused_gemm_gelu
+ self.fused_gemm_gelu = _triton_fused_gemm_gelu # type: ignore
+ else:
+ self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
- self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16
+ self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16 # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32 # type: ignore
except AttributeError:
diff --git a/deepspeed/ops/transformer/inference/op_binding/linear.py b/deepspeed/ops/transformer/inference/op_binding/linear.py
index 6ea15d608f42f33cbf422b02812a93286d7498ec..e970b562c6d6512d2d0de3d8aa564fbe2bf13d3b 100644
--- a/deepspeed/ops/transformer/inference/op_binding/linear.py
+++ b/deepspeed/ops/transformer/inference/op_binding/linear.py
@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
+import deepspeed
class LinearOp(BaseOp):
@@ -14,6 +15,14 @@ class LinearOp(BaseOp):
super(LinearOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
+ if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
+ from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func
+ self.linear_func = _triton_linear_func
+ triton_autotune = config.triton_autotune and config.layer_id == 0
+ if triton_autotune:
+ __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
+ else:
+ self.linear_func = self.inference_module.linear_layer_fp16
- self.linear_func = self.inference_module.linear_layer_fp16
elif self.config.dtype == torch.bfloat16:
self.linear_func = self.inference_module.linear_layer_bf16
@@ -37,3 +46,15 @@ class LinearOp(BaseOp):
qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
self.config.transposed_mode)
return qkv_out
+
+ @staticmethod
+ def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
+ from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
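+ # warm the matmul autotune cache across the sequence-length range, stepping by the cache stride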
+ seqlen = [(min_seqlen + i)
+ for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
+ Fp16Matmul._read_autotune_table()
+ for N in seqlen:
+ A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
+ B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
+ matmul(A, B)
+ Fp16Matmul._update_autotune_table()
diff --git a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py
index e3e372d60080ad85ff0b6ca77db7efcf390957a1..d5e12cb9a80162859288aebac8bdd5aee5e52a6e 100644
--- a/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py
+++ b/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py
@@ -19,7 +19,9 @@ class MLPGemmOp(BaseOp):
super(MLPGemmOp, self).__init__(config)
try:
if self.config.norm_type == NormType.LayerNorm:
- if self.config.dtype in [torch.float16, torch.int8]:
+ # the non-Triton CUDA kernel outperforms Triton's mlp_gemm_func for the MLP, so it is used even when Triton is enabled
+ if self.config.dtype in [torch.float16, torch.int8]:
self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16
diff --git a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py
index a6b253930ede37edffe3ba7db97743f88875a51f..dca935c1eb118f0245edcc0dd61e9b19ef2ee641 100644
--- a/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py
+++ b/deepspeed/ops/transformer/inference/op_binding/qkv_gemm.py
@@ -8,6 +8,7 @@ import torch
import torch.nn.functional as F
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
+import deepspeed
from deepspeed.utils.types import NormType
@@ -18,7 +19,14 @@ class QKVGemmOp(BaseOp):
try:
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
- self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16 # type: ignore
+ if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
+ from deepspeed.ops.transformer.inference.triton.ops import qkv_gemm_func as _triton_qkv_gemm_func
+ self.qkv_gemm_func = _triton_qkv_gemm_func
+ triton_autotune = config.triton_autotune and config.layer_id == 0
+ if triton_autotune:
+ __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
+ else:
+ self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.qkv_gemm_func = self.inference_module.qkv_gemm_bf16
else:
@@ -36,6 +44,18 @@ class QKVGemmOp(BaseOp):
elif self.config.norm_type == NormType.RMSNorm:
self.qkv_gemm_func = self.rms_qkv_gemm_fallback
+ @staticmethod
+ def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
+ from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
+ seqlen = [(min_seqlen + i)
+ for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
+ Fp16Matmul._read_autotune_table()
+ for N in seqlen:
+ A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
+ B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
+ matmul(A, B)
+ Fp16Matmul._update_autotune_table()
+
def qkv_gemm_fallback(self, input, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose):
if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose:
inp_norm = F.layer_norm(input, (input.shape[2], ), gamma, beta, eps)
diff --git a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py
index cc40633ab0eeb8a892b35bfd77f5f4c9c2d0610a..011be859634d5735937f32dbea861f455200fff3 100644
--- a/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py
+++ b/deepspeed/ops/transformer/inference/op_binding/vector_matmul.py
@@ -7,6 +7,7 @@ import os
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
+import deepspeed
class VectorMatMulOp(BaseOp):
@@ -14,7 +15,16 @@ class VectorMatMulOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(VectorMatMulOp, self).__init__(config)
try:
- if self.config.dtype in [torch.float16, torch.int8]:
+ if self.config.dtype == torch.float16:
+ if deepspeed.HAS_TRITON and config.use_triton:
+ from deepspeed.ops.transformer.inference.triton.ops import vector_matmul_func as _triton_vector_matmul_func
+ self.vector_matmul_func = _triton_vector_matmul_func
+ triton_autotune = config.triton_autotune and config.layer_id == 0
+ if triton_autotune:
+ __class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
+ else:
+ self.vector_matmul_func = self.inference_module.vector_matmul_fp16
+ elif self.config.dtype == torch.int8:
self.vector_matmul_func = self.inference_module.vector_matmul_fp16
elif self.config.dtype == torch.bfloat16:
self.vector_matmul_func = self.inference_module.vector_matmul_bf16
@@ -34,3 +44,15 @@ class VectorMatMulOp(BaseOp):
q_int8 = self.config.dtype == torch.int8
output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8, self.config.transposed_mode)
return output
+
+ @staticmethod
+ def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
+ from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
+ seqlen = [(min_seqlen + i)
+ for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
+ Fp16Matmul._read_autotune_table()
+ for N in seqlen:
+ A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
+ B = torch.randn((hidden_size, hidden_size), dtype=dtype, device='cuda')
+ matmul(A, B)
+ Fp16Matmul._update_autotune_table()
diff --git a/deepspeed/ops/transformer/inference/triton/__init__.py b/deepspeed/ops/transformer/inference/triton/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..b7d1968df62a99849992c8b1e93698d9e51cec30
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/__init__.py
@@ -0,0 +1,11 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+from .residual_add import residual_add_bias
+from .layer_norm import layer_norm, layer_norm_residual
+from .gelu import gelu
+from .softmax import softmax
+from .ops import *
+from .matmul_ext import fp16_matmul, matmul_4d, score_4d_matmul, context_4d_matmul
diff --git a/deepspeed/ops/transformer/inference/triton/attention.py b/deepspeed/ops/transformer/inference/triton/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bb0f47f413c9980b422abb772f8b5252f746ee2
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/attention.py
@@ -0,0 +1,229 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import math
+import torch
+import torch.nn as nn
+from deepspeed.accelerator import get_accelerator
+from deepspeed import comm as dist
+from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp
+from deepspeed.ops.transformer.inference.triton import (
+ softmax,
+ score_4d_matmul,
+ context_4d_matmul,
+)
+
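+# large negative value used in place of -inf when masking attention scores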
+minus_inf = -10000.0
+
+
+class TritonSelfAttention(nn.Module):
+ num_layers = 0
+
+ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, qkv_merging=False):
+ super(TritonSelfAttention, self).__init__()
+ self.config = config
+ data_type = self.config.dtype
+ data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
+ assert data_type_fp == torch.half, "Triton attention only supports fp16"
+
+ self.config.layer_id = TritonSelfAttention.num_layers
+ TritonSelfAttention.num_layers = TritonSelfAttention.num_layers + 1
+ device = get_accelerator().current_device_name() #if config.bigscience_bloom else 'cpu'
+
+ assert config.mp_size == 1, "Triton attention does not support mp_size > 1 yet"
+ if self.config.set_empty_params:
+ self.attn_qw = None
+ self.attn_qb = None
+ self.attn_kw = None
+ self.attn_kb = None
+ self.attn_vw = None
+ self.attn_vb = None
+ self.attn_qkvw = None
+ self.attn_qkvb = None
+ self.attn_ow = None
+ self.attn_ob = None
+ else:
+ qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
+ self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
+ qkv_size_per_partition,
+ dtype=data_type,
+ device=device),
+ requires_grad=False)
+ self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
+ requires_grad=False)
+ # self-output weights
+ out_size_per_partition = self.config.hidden_size // self.config.mp_size
+ self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
+ self.config.hidden_size,
+ dtype=data_type,
+ device=device),
+ requires_grad=False)
+
+ self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
+ requires_grad=False)
+
+ self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
+ self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
+ self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads
+
+ self.mp_group = mp_group
+ self.use_flash = False
+
+ # used for quantization
+ self.q_scales = q_scales
+ self.q_groups = q_groups
+ self.merge_count = int(math.log2(merge_count))
+
+ self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads)
+ if not config.use_mup:
+ self.norm_factor = math.sqrt(self.norm_factor)
+
+ if self.config.scale_attn_by_inverse_layer_idx is True:
+ self.norm_factor *= math.sqrt(self.config.layer_id + 1)
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191
+
+ triton_autotune = self.config.triton_autotune and self.config.layer_id == 0
+ self.qkv_func = QKVGemmOp(config)
+ self.score_context_func = SoftmaxContextOp(config)
+ self.linear_func = LinearOp(config)
+ self.vector_matmul_func = VectorMatMulOp(config)
+
+ self.hidden_size = config.hidden_size
+ self.head_size = config.hidden_size // config.heads
+ self.scale = (1 / self.norm_factor / self.norm_factor if self.config.scale_attention else 1.0
+ ) # making it back to 1/sqrt(head_size)
+ self.triangular_masking = self.config.triangular_masking
+
+ # triton autotune table update for score/context matmul
+ if triton_autotune:
+ print(f"running triton autotune for attention")
+ __class__._triton_autotune(2, self.config.max_out_tokens, self.head_size, self.config.hidden_size,
+ self.triangular_masking, self.scale)
+
+ @staticmethod
+ def _triton_autotune(min_seqlen,
+ max_seqlen,
+ head_size,
+ hidden_size,
+ triangular_masking,
+ scale,
+ dtype=torch.float16):
+ from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, score_4d_matmul, context_4d_matmul
+ seqlen = [(min_seqlen + i)
+ for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
+ Fp16Matmul._read_autotune_table()
+ for N in seqlen:
+ qkv = torch.randn((1, N, 3 * hidden_size), dtype=dtype, device='cuda')
+ output = score_4d_matmul(qkv, head_size, triangular_masking, scale)
+ context_4d_matmul(output, qkv, head_size)
+ Fp16Matmul._update_autotune_table()
+
+ def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi):
+ if isinstance(qkv_out, list):
+ qkv_out = qkv_out[0]
+
+ no_masking = input_mask is None
+
+ if no_masking:
+ input_mask = torch.empty(1)
+
+ attn_key_value = self.score_context_func(
+ query_key_value=qkv_out,
+ attn_mask=((1 - input_mask).to(qkv_out.dtype) *
+ minus_inf) if input_mask.dtype == torch.int64 else input_mask,
+ heads=self.num_attention_heads_per_partition,
+ norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0),
+ no_masking=no_masking,
+ layer_id=self.config.layer_id,
+ num_layers=TritonSelfAttention.num_layers,
+ alibi=alibi)
+
+ context_layer, key_layer, value_layer = attn_key_value
+ return context_layer, key_layer, value_layer
+
+ def forward(
+ self,
+ input,
+ input_mask,
+ head_mask=None,
+ layer_past=None,
+ get_present=False, # not used
+ encoder_hidden_states=None, # not used
+ encoder_attention_mask=None, # not used
+ output_attentions=False, # not used
+ norm_w=None,
+ norm_b=None,
+ alibi=None,
+ use_triton_attention=True):
+
+ if not self.config.pre_layer_norm:
+ qkv_out = self.linear_func(input=input,
+ weight=self.attn_qkvw,
+ bias=self.attn_qkvb,
+ add_bias=self.attn_qkvb is not None,
+ do_flash_attn=False,
+ num_heads=self.num_attention_heads_per_partition,
+ num_layers=TritonSelfAttention.num_layers)
+ qkv = qkv_out
+ else:
+ qkv_out = self.qkv_func(input=input,
+ weight=self.attn_qkvw,
+ bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b),
+ gamma=norm_w,
+ beta=norm_b)
+ qkv = qkv_out[0]
+
+ if use_triton_attention and (alibi is None):
+ context_layer = compute_attention(qkv=qkv,
+ input_mask=input_mask,
+ scale=self.scale,
+ layer_past=layer_past,
+ alibi=alibi,
+ head_size=self.head_size,
+ use_triton_flash=self.use_flash,
+ use_cuda_flash=False,
+ triangular=self.triangular_masking)
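+ # K and V occupy the middle and last thirds of the fused QKV projection; returned for the present cache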
+ key_layer, value_layer = qkv[:, :, self.hidden_size:2 * self.hidden_size], qkv[:, :, 2 * self.hidden_size:]
+ else:
+ context_layer, key_layer, value_layer = self.ds_compute_attention(qkv_out=qkv_out,
+ input_mask=input_mask,
+ layer_past=layer_past,
+ alibi=alibi)
+ output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)
+
+ inp_norm = qkv_out[-1]
+
+ if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
+ dist.all_reduce(output, group=self.mp_group)
+
+ return (output, key_layer, value_layer, context_layer, inp_norm)
+
+
+global inference_module
+
+
+def compute_attention(qkv,
+ input_mask,
+ layer_past,
+ alibi,
+ scale,
+ head_size,
+ triangular=False,
+ use_cuda_flash=False,
+ use_triton_flash=False,
+ use_ds_attention=False):
+ if isinstance(qkv, list):
+ qkv = qkv[0]
+
+ #assert layer_past is None, "layer_past not supported in triton yet"
+ assert alibi is None, "alibi is not supported in triton yet"
+ output = score_4d_matmul(qkv, head_size, triangular, scale)
+ if triangular:
+ output = softmax(output)
+ else:
+ output = softmax(output, input_mask)
+ output = context_4d_matmul(output, qkv, head_size)
+
+ return output
diff --git a/deepspeed/ops/transformer/inference/triton/gelu.py b/deepspeed/ops/transformer/inference/triton/gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..17b943fb78bc44d22d602352e947aefdd1faa264
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/gelu.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def gelu_functor(x):
+ # Using approximation introduces greater parity errors.
+ # return tl.sigmoid(1.702 * x) * x
+ return x * 0.5 * (1.0 + tl.libdevice.erf(x / 1.41421356237))
+
+
+@triton.jit
+def gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+ pid = tl.program_id(axis=0)
+ block_start = pid * BLOCK_SIZE
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < n_elements
+ x = tl.load(x_ptr + offsets, mask=mask)
+ output = gelu_functor(x)
+ tl.store(output_ptr + offsets, output, mask=mask)
+
+
+def gelu(activations: torch.Tensor) -> torch.Tensor:
+ assert activations.is_contiguous()
+ assert activations.is_cuda
+
+ output = torch.empty_like(activations)
+ n_elements = output.numel()
+ grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
+ gelu_kernel[grid](activations, output, n_elements, BLOCK_SIZE=1024)
+ return output
diff --git a/deepspeed/ops/transformer/inference/triton/layer_norm.py b/deepspeed/ops/transformer/inference/triton/layer_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3f313d2ac3d8205702dc2ceb82856154e9ddb2c
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/layer_norm.py
@@ -0,0 +1,249 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import triton
+import triton.language as tl
+'''
+layer-normalization
+modified the triton kernel in
+https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/05-layer-norm.py
+'''
+
+
+@triton.jit
+def layer_norm_kernel(
+ Out,
+ A,
+ Weight,
+ Bias,
+ stride,
+ N,
+ eps,
+ BLOCK_SIZE: tl.constexpr,
+):
+ # position of elements processed by this program
+ row = tl.program_id(0)
+ Out += row * stride
+ A += row * stride
+ # compute mean
+ mean = 0
+ _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)
+ _mean += a
+ mean = tl.sum(_mean, axis=0) / N
+ # compute variance
+ _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)
+ a = tl.where(cols < N, a - mean, 0.0)
+ _var += a * a
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # multiply by weight and add bias
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ weight = tl.load(Weight + cols, mask=mask)
+ bias = tl.load(Bias + cols, mask=mask)
+ a = tl.load(A + cols, mask=mask, other=0.0).to(tl.float32)
+ a_hat = (a - mean) * rstd
+ out = a_hat * weight + bias
+ # # write-back
+ tl.store(Out + cols, out, mask=mask)
+
+
+@triton.jit
+def layer_norm_residual_kernel(
+ Out,
+ A,
+ Residual,
+ ln_input,
+ Weight,
+ Bias,
+ stride,
+ N,
+ eps,
+ BLOCK_SIZE: tl.constexpr,
+):
+ # position of elements processed by this program
+ row = tl.program_id(0)
+ Out += row * stride
+ A += row * stride
+ Residual += row * stride
+ ln_input += row * stride
+ # compute mean
+ mean = 0
+ _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)
+ res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32)
+ a = a + res
+ tl.store(ln_input + cols, a, mask=cols < N)
+ _mean += a
+ mean = tl.sum(_mean, axis=0) / N
+ # compute variance
+ _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32)
+ a = tl.where(cols < N, a - mean, 0.0)
+ _var += a * a
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # multiply by weight and add bias
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ weight = tl.load(Weight + cols, mask=mask)
+ bias = tl.load(Bias + cols, mask=mask)
+ a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32)
+ a_hat = (a - mean) * rstd
+ out = a_hat * weight + bias
+ # write-back
+ tl.store(Out + cols, out, mask=mask)
+
+
+@triton.jit
+def layer_norm_residual_bias_kernel(
+ Out,
+ A,
+ Residual,
+ InputBias,
+ ln_input,
+ Weight,
+ Bias,
+ stride,
+ N,
+ eps,
+ BLOCK_SIZE: tl.constexpr,
+):
+ # position of elements processed by this program
+ row = tl.program_id(0)
+ Out += row * stride
+ A += row * stride
+ Residual += row * stride
+ ln_input += row * stride
+ # compute mean
+ mean = 0
+ _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)
+ res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32)
+ b = tl.load(InputBias + cols, mask=cols < N, other=0.0).to(tl.float32)
+ a = a + b + res
+ tl.store(ln_input + cols, a, mask=cols < N)
+ _mean += a
+ mean = tl.sum(_mean, axis=0) / N
+ # compute variance
+ _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32)
+ a = tl.where(cols < N, a - mean, 0.0)
+ _var += a * a
+ var = tl.sum(_var, axis=0) / N
+ rstd = 1 / tl.sqrt(var + eps)
+ # multiply by weight and add bias
+ for off in range(0, N, BLOCK_SIZE):
+ cols = off + tl.arange(0, BLOCK_SIZE)
+ mask = cols < N
+ weight = tl.load(Weight + cols, mask=mask)
+ bias = tl.load(Bias + cols, mask=mask)
+ a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32)
+ a_hat = (a - mean) * rstd
+ out = a_hat * weight + bias
+ # write-back
+ tl.store(Out + cols, out, mask=mask)
+
+
+def layer_norm(a, weight, bias, eps):
+ assert a.is_contiguous()
+ assert weight.is_contiguous()
+ assert bias.is_contiguous()
+
+ # allocate output
+ out = torch.empty_like(a)
+ # reshape input data into 2D tensor
+ a_arg = a.view(-1, a.shape[-1])
+ M, N = a_arg.shape
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // a.element_size()
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ BLOCK_SIZE = max(BLOCK_SIZE, 128)
+ BLOCK_SIZE = min(BLOCK_SIZE, 4096)
+ BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ layer_norm_kernel[(M, )](
+ out,
+ a_arg,
+ weight,
+ bias,
+ a_arg.stride(0),
+ N,
+ eps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=num_warps,
+ )
+ return out
+
+
+def layer_norm_residual(a, input_bias, residual, weight, bias, eps):
+ assert a.is_contiguous()
+ assert weight.is_contiguous()
+ assert bias.is_contiguous()
+ assert residual.is_contiguous()
+
+ # allocate output and scratch-pad for residual addition
+ out = torch.empty_like(a)
+ ln_input = torch.empty_like(a)
+ # reshape input data into 2D tensor
+ a_arg = a.view(-1, a.shape[-1])
+ residual = residual.view(-1, residual.shape[-1])
+ M, N = a_arg.shape
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // a.element_size()
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
+ BLOCK_SIZE = max(BLOCK_SIZE, 128)
+ BLOCK_SIZE = min(BLOCK_SIZE, 4096)
+ BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192
+ # heuristics for number of warps
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ if input_bias is None:
+ layer_norm_residual_kernel[(M, )](
+ out,
+ a_arg,
+ residual,
+ ln_input,
+ weight,
+ bias,
+ a_arg.stride(0),
+ N,
+ eps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=num_warps,
+ )
+ else:
+ layer_norm_residual_bias_kernel[(M, )](
+ out,
+ a_arg,
+ residual,
+ input_bias,
+ ln_input,
+ weight,
+ bias,
+ a_arg.stride(0),
+ N,
+ eps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ num_warps=num_warps,
+ )
+ return out
diff --git a/deepspeed/ops/transformer/inference/triton/matmul_ext.py b/deepspeed/ops/transformer/inference/triton/matmul_ext.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad93831e1f84945a6f846426205c069b5652ec65
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/matmul_ext.py
@@ -0,0 +1,436 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import triton
+import os
+from filelock import FileLock
+import deepspeed.ops.transformer.inference.triton.triton_matmul_kernel as triton_matmul_kernel
+import pickle
+from io import open
+
+
+# -----------------------------------------------------------------------------
+# util class/functions for triton
+def _default_cache_dir():
+ return os.path.join(os.environ["HOME"], ".triton", "autotune")
+
+
+def bias_add_activation(C, bias=None, activation=""):
+ if bias is not None:
+ C += bias
+ # activation
+ if activation == "relu":
+ relu = torch.nn.ReLU()
+ C = relu(C)
+ elif activation == "leaky_relu":
+ leaky_relu = torch.nn.LeakyReLU(0.01)
+ C = leaky_relu(C)
+ elif activation == "gelu":
+ sigmoid = torch.nn.Sigmoid()
+ C = sigmoid(1.702 * C) * C
+ elif activation == "sigmoid":
+ sigmoid = torch.nn.Sigmoid()
+ C = sigmoid(C)
+ return C
+
+
+class AutotuneCacheManager:
+ """
+ Cache manager for autotune
+ """
+
+ def __init__(self, key):
+ self.key = key
+ self.file_path = None
+ self.lock_path = None
+ # if caching is enabled, get the lock and bin path
+ self.cache_dir = os.environ.get('TRITON_CACHE_DIR', _default_cache_dir())
+ if self.cache_dir:
+ os.makedirs(self.cache_dir, exist_ok=True)
+ if self.cache_dir:
+ self.file_path = os.path.join(self.cache_dir, self.key + ".pickle")
+ self.lock_path = self.file_path + ".lock"
+
+ def has_file(self):
+ return self.file_path and os.path.exists(self.file_path)
+
+ def put(self, table):
+ if self.file_path:
+ assert self.lock_path is not None
+ with FileLock(self.lock_path):
+ with open(self.file_path + ".tmp", 'wb') as handle:
+ pickle.dump(table, handle)
+ os.rename(self.file_path + ".tmp", self.file_path)
+
+ def load(self):
+ if os.path.exists(self.file_path):
+ with open(self.file_path, 'rb') as handle:
+ loaded_dict = pickle.load(handle)
+ return loaded_dict
+ else:
+ return None
+
+
+# -----------------------------------------------------------------------------
+# triton matmul class
+
+
+class MatmulExt(torch.autograd.Function):
+ """
+ a wrapper class that can call different triton matmul kernels depending on the input parameters
+ """
+
+ @staticmethod
+ def forward(A, B, bias=None, activation="", use_triton=True, update_autotune_table=False):
+ """
+ A: input, activation matrix A
+ B: input, weight matrix B
+ """
+ matmul = None
+ quantize_activation = False
+ Batch = 0
+
+ if len(A.shape) == 3: # if A is 3d-tensor where batch index is given as 0-axis
+ assert A.is_contiguous(), "matrix A must be contiguous"
+ Batch, M, K = A.shape
+ A = A.view(-1, K)
+
+ # fp16 activation and fp16 weight matmul into fp16 output
+ matmul = fp16_matmul
+ C = matmul.forward(A, B, use_triton=use_triton, bias=bias, activation=activation)
+
+ if matmul and update_autotune_table:
+ matmul._update_autotune_table()
+
+ if Batch > 0:
+ C = C.view(Batch, M, -1)
+
+ return C
+
+
+class TritonMatmul(torch.autograd.Function):
+ """
+ triton matmul kernel superclass
+ """
+
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def _ref_forward(A, B, ref_dtype=torch.float32):
+ C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
+ return C
+
+ @staticmethod
+ def _read_autotune_table(cache_key, triton_kernel):
+ cache_manager = AutotuneCacheManager(cache_key)
+ table = cache_manager.load()
+ if table:
+ triton_kernel.cache = table
+
+ @staticmethod
+ def _write_autotune_table(cache_key, triton_kernel):
+ cache_manager = AutotuneCacheManager(cache_key)
+ cache_manager.put(triton_kernel.cache)
+
+ @staticmethod
+ def _update_autotune_table(cache_key, triton_kernel):
+ cache_manager = AutotuneCacheManager(cache_key)
+ autotune_table = cache_manager.load()
+ if autotune_table is None:
+ autotune_table = dict()
+ autotune_table.update(triton_kernel.cache) # always overwrite with the new autotune results
+ cache_manager = AutotuneCacheManager(cache_key)
+ cache_manager.put(autotune_table)
+
+ @staticmethod
+ def forward(
+ A,
+ B,
+ ref_dtype=torch.float32, # fp32 only
+ bias=None,
+ activation=""):
+ C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
+ C = bias_add_activation(C, bias, activation)
+ return C
+
+
+class Fp16Matmul(TritonMatmul):
+ """
+ fp16 matrix multiplication kernel
+ dtypes: fp16 x fp16 = fp16
+ """
+
+ _2d_kernel = triton_matmul_kernel._fp_matmul
+ _4d_kernel = triton_matmul_kernel.matmul_4d_kernel
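+ # shape dims are divided by this stride when passed to the kernels, bucketing autotune cache entries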
+ _cache_stride = 32
+
+ def __init__(self, read_cache=True):
+ super().__init__()
+ if read_cache:
+ __class__._read_autotune_table()
+
+ def skip_autotune(self):
+ __class__._2d_kernel.configs = [__class__._2d_kernel.configs[0]]
+ __class__._4d_kernel.configs = [__class__._4d_kernel.configs[0]]
+
+ @staticmethod
+ def forward(A, B, use_triton=True, bias=None, activation=""):
+ if use_triton:
+ device = A.device
+ # handle non-contiguous inputs if necessary
+ if A.stride(0) > 1 and A.stride(1) > 1:
+ A = A.contiguous()
+ if B.stride(0) > 1 and B.stride(1) > 1:
+ B = B.contiguous()
+ # checks constraints
+ assert A.shape[1] == B.shape[0], "incompatible dimensions"
+ M, K = A.shape
+ _, N = B.shape
+ # allocates output
+ C = torch.empty((M, N), device=device, dtype=A.dtype)
+ # accumulator types
+ ACC_TYPE = triton.language.float32 if A.dtype in [torch.float16, torch.bfloat16, torch.float32
+ ] else triton.language.int32
+ # launch kernel
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
+ __class__._2d_kernel[grid](A,
+ B,
+ C,
+ M,
+ N,
+ K,
+ bias,
+ A.stride(0),
+ A.stride(1),
+ B.stride(0),
+ B.stride(1),
+ C.stride(0),
+ C.stride(1),
+ M // __class__._cache_stride,
+ N // __class__._cache_stride,
+ K // __class__._cache_stride,
+ GROUP_M=8,
+ ACC_TYPE=ACC_TYPE,
+ BIAS_ADD=(0 if bias is None else 1),
+ ACTIVATION=activation)
+ else:
+ C = torch.matmul(A, B)
+ return C
+
+ @staticmethod
+ def _matmul_4d(a, b):
+ assert a.shape[-1] == b.shape[-2], "incompatible dimensions"
+ assert a.is_contiguous(), "matrix A must be contiguous"
+ assert b.is_contiguous(), "matrix B must be contiguous"
+
+ B, H, M, K = a.shape
+ B, H, K, N = b.shape
+
+ assert K > 1, "inner-product dimension K should be larger than 1"
+
+ c = torch.empty((B, H, M, N), device=a.device, dtype=a.dtype)
+
+ grid = lambda META: (
+ triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
+ H,
+ B,
+ )
+
+ __class__._4d_kernel[grid](
+ a,
+ b,
+ c,
+ M,
+ N,
+ K,
+ M // __class__._cache_stride,
+ N // __class__._cache_stride,
+ K // __class__._cache_stride,
+ a.stride(0),
+ a.stride(1),
+ a.stride(2),
+ a.stride(3),
+ b.stride(0),
+ b.stride(1),
+ b.stride(2),
+ b.stride(3),
+ c.stride(0),
+ c.stride(1),
+ c.stride(2),
+ c.stride(3),
+ scale=-1.0,
+ MASK=False,
+ )
+ return c
+
+ @staticmethod
+ def _score_4d_matmul(input, head_size, input_mask, scale=-1.0):
+ assert input.is_contiguous(), "matrix input must be contiguous"
+
+ batches = input.shape[0]
+ d_model = input.shape[-1] // 3
+ num_of_heads = d_model // head_size
+
+ q = input[:, :, :d_model]
+ k = input[:, :, d_model:d_model * 2]
+
+ q = q.view(batches, -1, num_of_heads, head_size)
+ k = k.view(batches, -1, num_of_heads, head_size)
+
+ # checks constraints
+ assert q.shape == k.shape, "incompatible dimensions"
+ B, M, H, K = q.shape
+ B, N, H, K = k.shape
+
+ assert K > 1, "inner-product dimension K should be larger than 1"
+
+ # allocates output
+ output = torch.empty((B, H, M, N), device=q.device, dtype=q.dtype)
+ grid = lambda META: (
+ triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+ H,
+ B,
+ )
+ __class__._4d_kernel[grid](
+ q,
+ k,
+ output,
+ M,
+ N,
+ K,
+ M // __class__._cache_stride,
+ N // __class__._cache_stride,
+ K // __class__._cache_stride,
+ q.stride(0),
+ q.stride(2),
+ q.stride(1),
+ q.stride(3),
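+ # k's strides are permuted relative to q so the kernel reads K transposed (computing Q @ K^T)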
+ k.stride(0),
+ k.stride(2),
+ k.stride(3),
+ k.stride(1),
+ output.stride(0),
+ output.stride(1),
+ output.stride(2),
+ output.stride(3),
+ scale=scale,
+ MASK=False,
+ )
+ return output
+
+ @staticmethod
+ def _context_4d_matmul(prob, input, head_size):
+ assert prob.is_contiguous(), "matrix prob must be contiguous"
+ assert input.is_contiguous(), "matrix input must be contiguous"
+
+ batches = input.shape[0]
+ d_model = input.shape[-1] // 3
+ num_of_heads = d_model // head_size
+
+ v = input[:, :, d_model * 2:]
+
+ v = v.view(batches, -1, num_of_heads, head_size)
+
+ # checks constraints
+ assert (prob.shape[0] == v.shape[0] and prob.shape[1] == v.shape[2] and prob.shape[2] == v.shape[1]
+ and prob.shape[3] == v.shape[1]), "incompatible dimensions"
+ B, H, M, K = prob.shape
+ B, K, H, N = v.shape
+
+ assert K > 1, "inner-product dimension K should be larger than 1"
+
+ # allocates output
+ output = torch.empty((B, M, H, N), device=v.device, dtype=v.dtype)
+ grid = lambda META: (
+ triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+ H,
+ B,
+ )
+
+ __class__._4d_kernel[grid](
+ prob,
+ v,
+ output,
+ M,
+ N,
+ K,
+ M // __class__._cache_stride,
+ N // __class__._cache_stride,
+ K // __class__._cache_stride,
+ prob.stride(0),
+ prob.stride(1),
+ prob.stride(2),
+ prob.stride(3),
+ v.stride(0),
+ v.stride(2),
+ v.stride(1),
+ v.stride(3),
+ # Here we also transpose the output when writing to memory.
+ output.stride(0),
+ output.stride(2),
+ output.stride(1),
+ output.stride(3),
+ scale=-1,
+ MASK=False,
+ )
+ return output.view(batches, -1, d_model)
+
+ @staticmethod
+ def _ref_forward(A, B, ref_dtype=torch.float32, bias=None, activation=""):
+ C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
+ C = bias_add_activation(C, bias, activation)
+ return C
+
+ @staticmethod
+ def _check_parity(A,
+ B,
+ output_dtype,
+ SA=None,
+ SB=None,
+ qblock_size=None,
+ ref_dtype=torch.float32,
+ tol=0.01,
+ use_triton=True,
+ bias=None,
+ activation=""):
+ torch_output = __class__._ref_forward(A, B, ref_dtype=ref_dtype, bias=bias, activation=activation)
+ triton_output = __class__.forward(A, B, use_triton=use_triton, bias=bias, activation=activation)
+ assert triton.testing.allclose(triton_output.cpu().type(torch_output.dtype), torch_output.cpu(), tol=tol)
+ print(f"{__class__.__name__}: PASSed the parity check")
+ return triton_output, torch_output
+
+ @staticmethod
+ def _read_autotune_table():
+ TritonMatmul._read_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
+ TritonMatmul._read_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)
+
+ @staticmethod
+ def _write_autotune_table():
+ TritonMatmul._write_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
+ TritonMatmul._write_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)
+
+ @staticmethod
+ def _update_autotune_table():
+ TritonMatmul._update_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
+ TritonMatmul._update_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)
+
+
+# -----------------------------------------------------------------------------
+# mapping
+matmul = MatmulExt.forward
+fp16_matmul = Fp16Matmul()
+matmul_4d = fp16_matmul._matmul_4d
+score_4d_matmul = fp16_matmul._score_4d_matmul
+context_4d_matmul = fp16_matmul._context_4d_matmul
+
+# persist any newly tuned matmul configs when the process exits
+import atexit
+
+
+@atexit.register
+def matmul_ext_update_autotune_table():
+ fp16_matmul._update_autotune_table()
diff --git a/deepspeed/ops/transformer/inference/triton/mlp.py b/deepspeed/ops/transformer/inference/triton/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..1708080b27efb0671d361b7cdeaff4b262cf0ce8
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/mlp.py
@@ -0,0 +1,81 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import math
+import torch.nn as nn
+from deepspeed.accelerator import get_accelerator
+from deepspeed import comm as dist
+from ..op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp
+
+
+class TritonMLP(nn.Module):
+
+ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False):
+ super(TritonMLP, self).__init__()
+
+ self.config = config
+ data_type = self.config.dtype
+ data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
+ device = get_accelerator().current_device_name()
+ self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
+ requires_grad=False)
+ self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
+ requires_grad=False)
+ intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
+ self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
+ intm_size_per_partition,
+ dtype=data_type,
+ device=device),
+ requires_grad=False)
+ self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
+ requires_grad=False)
+ self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
+ self.config.hidden_size,
+ dtype=data_type,
+ device=device),
+ requires_grad=False)
+ self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
+ requires_grad=False)
+
+ # used for quantization
+ self.q_scales = q_scales
+ self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups
+ self.merge_count = int(math.log2(merge_count))
+ self.mp_group = mp_group
+
+ self.mlp_gemm_func = MLPGemmOp(config)
+ self.vector_matmul_func = VectorMatMulOp(config)
+ self.fused_gemm_gelu = GELUGemmOp(config)
+ self.residual_add_func = ResidualAddOp(config)
+
+ def forward(self, input, residual, residual_norm, bias):
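+ # two paths: fused GEMM+GeLU on the pre-normalized input, or fused LayerNorm+GEMM via mlp_gemm_func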
+ residual_add = None
+ if self.attn_nw is None:
+ output = self.fused_gemm_gelu(input=residual_norm,
+ weight=self.inter_w,
+ bias=self.inter_b,
+ weight_out=self.output_w)
+ else:
+ output, residual_add = self.mlp_gemm_func(input=input,
+ residual=residual,
+ input_bias=bias,
+ weight_interm=self.inter_w,
+ weight_out=self.output_w,
+ bias=self.inter_b,
+ gamma=self.attn_nw,
+ beta=self.attn_nb)
+ residual = self.residual_add_func(hidden_state=output,
+ residual=residual,
+ attention_output=input,
+ attention_bias=bias if bias is not None else self.output_b,
+ final_bias=self.output_b,
+ add_bias=bias is not None,
+ residual_add=residual_add)
+
+ if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
+ dist.all_reduce(residual, group=self.mp_group)
+
+ return residual
diff --git a/deepspeed/ops/transformer/inference/triton/ops.py b/deepspeed/ops/transformer/inference/triton/ops.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd87d08d4d2c6c0e4b80ab770ff9a1f03dd3138c
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/ops.py
@@ -0,0 +1,131 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import deepspeed
+from deepspeed.ops.op_builder import InferenceBuilder
+import deepspeed.ops.transformer.inference.triton.matmul_ext as matmul_ext
+from deepspeed.ops.transformer.inference.triton.layer_norm import layer_norm, layer_norm_residual
+
+inference_module = None
+
+
+def vector_matmul_func(input, weight, async_op, q_scale, q_int8, transposed_mode):
+ assert not transposed_mode and not async_op and not q_int8
+ return matmul_ext.matmul(input, weight, bias=None, activation="", use_triton=True)
+
+
+def fused_gemm_gelu(input,
+ weight,
+ weight_scale,
+ bias,
+ weight_out,
+ weight_out_scale,
+ epsilon,
+ pre_layer_norm,
+ q_int8,
+ async_op,
+ transposed_mode,
+ use_triton_ln=True):
+ assert not transposed_mode
+
+ # activation
+ activation = "gelu"
+
+ # intermediate fc in FF
+ intm_out = matmul_ext.matmul(input, weight, bias=bias, activation=activation, use_triton=True)
+
+ # output fc in FF
+ ff_out = matmul_ext.matmul(
+ intm_out,
+ weight_out,
+ bias=None,
+ activation="", # bias added layer with residual_add + bias + layerNorm layer
+ use_triton=True)
+ return ff_out
+
+
+def linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, transposed_mode=False):
+ assert not transposed_mode and not do_flash_attn
+ qkv_out = matmul_ext.matmul(input, weight, bias=(bias if add_bias else None), activation="", use_triton=True)
+
+ return qkv_out
+
+
+def mlp_gemm_func(input,
+ residual,
+ input_bias,
+ weight_interm,
+ weight_out,
+ bias,
+ gamma,
+ beta,
+ epsilon,
+ pre_layer_norm,
+ mlp_after_attn,
+ weight_interm_scale,
+ weight_out_scale,
+ q_int8,
+ mlp_act_func_type,
+ transposed_mode,
+ use_triton_ln=True):
+ assert not transposed_mode
+
+ # residual add and layerNorm after attention
+ if use_triton_ln:
+ mlp_input = layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)
+ else:
+ global inference_module
+ if inference_module is None:
+ inference_module = InferenceBuilder().load()
+ mlp_input = inference_module._layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)
+
+ # activation
+ if deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.GELU:
+ activation = "gelu"
+ elif deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.ReLU:
+ activation = "relu"
+ else:
+ activation = ""
+
+ # intermediate fc in FF
+ intm_out = matmul_ext.matmul(mlp_input, weight_interm, bias=bias, activation=activation, use_triton=True)
+ # output fc in FF
+ ff_out = matmul_ext.matmul(
+ intm_out,
+ weight_out,
+ bias=None,
+ activation="", # bias added layer with residual_add + bias + layerNorm layer
+ use_triton=True)
+
+ return ff_out, mlp_input
+
+
+def qkv_gemm_func(
+ input,
+ weight,
+ q_scale,
+ bias,
+ gamma,
+ beta,
+ epsilon,
+ add_bias,
+ q_int8,
+ transposed_mode=False,
+ use_triton_ln=True,
+):
+
+ assert not transposed_mode
+ # residual add and layerNorm after attention
+ if use_triton_ln:
+ qkv_input = layer_norm(input, gamma, beta, epsilon)
+ else:
+ global inference_module
+ if inference_module is None:
+ inference_module = InferenceBuilder().load()
+ qkv_input = inference_module.layer_norm(input, gamma, beta, epsilon)
+
+ qkv_out = matmul_ext.matmul(qkv_input, weight, bias=(bias if add_bias else None), activation="", use_triton=True)
+
+ return qkv_out, qkv_input
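
These functions keep the argument lists of their CUDA counterparts so the Op wrapper classes can dispatch on the Triton config flag without changing call sites; several arguments (`q_scale`, `async_op`, `weight_scale`, ...) are accepted only for signature parity. A quick smoke test, assuming a CUDA device with triton installed (the shapes are illustrative):

```python
import torch
from deepspeed.ops.transformer.inference.triton.ops import qkv_gemm_func

hidden = 768
x = torch.randn(1, 128, hidden, dtype=torch.float16, device="cuda")
w = torch.randn(hidden, 3 * hidden, dtype=torch.float16, device="cuda") * 0.02
b = torch.zeros(3 * hidden, dtype=torch.float16, device="cuda")
gamma = torch.ones(hidden, dtype=torch.float16, device="cuda")
beta = torch.zeros(hidden, dtype=torch.float16, device="cuda")

# fused layer-norm + QKV projection, all in Triton
qkv_out, qkv_input = qkv_gemm_func(x, w, None, b, gamma, beta,
                                   epsilon=1e-5, add_bias=True, q_int8=False)
assert qkv_out.shape == (1, 128, 3 * hidden)
```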
diff --git a/deepspeed/ops/transformer/inference/triton/residual_add.py b/deepspeed/ops/transformer/inference/triton/residual_add.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8ff5400af5d768c0fdbadebb17555d2ab6ac53
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/residual_add.py
@@ -0,0 +1,84 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def residual_add_bias_kernel(
+ hidden_state_ptr,
+ residual_ptr,
+ attn_output_ptr,
+ hidden_state_size,
+ attn_bias_ptr,
+ final_bias_ptr,
+ bias_size,
+ output_ptr,
+ mp_size: tl.constexpr,
+ mlp_after_attn: tl.constexpr,
+ pre_attn_norm: tl.constexpr,
+ add_attn_bias: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(axis=0)
+
+ block_start = pid * BLOCK_SIZE
+
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
+ mask = offsets < hidden_state_size
+
+ bias_offsets = offsets % bias_size
+ bias_mask = bias_offsets < bias_size
+
+ tl_hidden_state = tl.load(hidden_state_ptr + offsets, mask=mask)
+ tl_residual = tl.load(residual_ptr + offsets, mask=mask)
+ tl_attn_output = tl.load(attn_output_ptr + offsets, mask=mask)
+ tl_attn_bias = tl.load(attn_bias_ptr + bias_offsets, mask=bias_mask)
+ tl_final_bias = tl.load(final_bias_ptr + bias_offsets, mask=bias_mask)
+
+ if mlp_after_attn:
+ if pre_attn_norm:
+ output = tl_hidden_state + (tl_residual + tl_final_bias + tl_attn_output + tl_attn_bias) / mp_size
+ else:
+ output = tl_hidden_state + tl_residual + tl_final_bias
+ else:
+ output = tl_hidden_state + tl_attn_output + (tl_residual + tl_final_bias) / mp_size
+ if add_attn_bias:
+ output += tl_attn_bias / mp_size
+
+ tl.store(output_ptr + offsets, output, mask=mask)
+
+
+def residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_output: torch.Tensor,
+ attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool,
+ add_attn_bias: bool, pre_attn_norm: bool):
+ # check that all tensors are on the same device
+ assert hidden_state.is_cuda and residual.is_cuda and attn_output.is_cuda \
+ and attn_bias.is_cuda and final_bias.is_cuda
+
+ # check that all tensors have the same dtype
+ assert hidden_state.dtype == residual.dtype == attn_output.dtype \
+ == attn_bias.dtype == final_bias.dtype
+
+ # check that all tensors have the right shape
+ assert hidden_state.shape == residual.shape == attn_output.shape
+ assert attn_bias.shape == final_bias.shape
+ assert attn_bias.shape[0] == hidden_state.shape[2]
+
+ output = torch.empty_like(hidden_state)
+
+ hidden_state_size = output.numel()
+ bias_size = attn_bias.numel()
+
+ grid = lambda meta: (triton.cdiv(hidden_state_size, meta['BLOCK_SIZE']), )
+
+ residual_add_bias_kernel[grid](hidden_state, residual, attn_output, hidden_state_size,\
+ attn_bias, final_bias, bias_size, output, mp_size, mlp_after_attn, pre_attn_norm, \
+ add_attn_bias, \
+ BLOCK_SIZE=1024)
+
+ return output
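
The kernel folds the different residual/bias accumulation variants into a single elementwise pass. A plain-PyTorch reference of the same branch logic, useful for eyeballing the semantics (illustrative; it mirrors the kernel above, it is not part of the PR):

```python
import torch

def residual_add_bias_reference(hidden_state, residual, attn_output, attn_bias,
                                final_bias, mp_size, mlp_after_attn,
                                add_attn_bias, pre_attn_norm):
    if mlp_after_attn:
        if pre_attn_norm:
            return hidden_state + (residual + final_bias + attn_output + attn_bias) / mp_size
        return hidden_state + residual + final_bias
    out = hidden_state + attn_output + (residual + final_bias) / mp_size
    if add_attn_bias:
        out = out + attn_bias / mp_size
    return out
```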
diff --git a/deepspeed/ops/transformer/inference/triton/softmax.py b/deepspeed/ops/transformer/inference/triton/softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee10d63e6cf8bfa6723856b53b7ca9ec30d3fdd
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/softmax.py
@@ -0,0 +1,89 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import torch
+import triton
+import triton.language as tl
+'''
+softmax
+modified the triton kernel in
+https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
+'''
+
+
+@triton.jit
+def softmax_kernel(output_ptr, input_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
+ row_idx = tl.program_id(0)
+ row_start_ptr = input_ptr + row_idx * stride
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ input_ptrs = row_start_ptr + col_offsets
+ row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+ row_minus_max = row - tl.max(row, axis=0)
+ numerator = tl.exp(row_minus_max)
+ denominator = tl.sum(numerator, axis=0)
+ softmax_output = numerator / denominator
+ output_row_start_ptr = output_ptr + row_idx * stride
+ output_ptrs = output_row_start_ptr + col_offsets
+ tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
+
+
+@triton.jit
+def masked_softmax_kernel(output_ptr, input_ptr, stride, mask_ptr, mask_stride, n_cols, BLOCK_SIZE: tl.constexpr):
+ row_idx = tl.program_id(0)
+ row_start_ptr = input_ptr + row_idx * stride
+ col_offsets = tl.arange(0, BLOCK_SIZE)
+ input_ptrs = row_start_ptr + col_offsets
+ mask_ptrs = mask_ptr + col_offsets + row_idx * mask_stride # mask_stride is 0 for 1d mask
+ row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
+ mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
+ row_minus_max = row - tl.max(row, axis=0)
+ row_minus_max = row_minus_max + mask
+ numerator = tl.exp(row_minus_max)
+ denominator = tl.sum(numerator, axis=0)
+ softmax_output = numerator / denominator
+ output_row_start_ptr = output_ptr + row_idx * stride
+ output_ptrs = output_row_start_ptr + col_offsets
+ tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
+
+
+def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
+ assert input.is_contiguous()
+ assert (dim == -1) or (dim == len(input.shape) - 1), "Only dim=-1 is supported"
+
+ use_mask = False if mask is None else True
+ input_arg = input.view(-1, input.shape[-1])
+ n_rows, n_cols = input_arg.shape
+ BLOCK_SIZE = max(triton.next_power_of_2(n_cols), 2)
+ num_warps = 4
+ if BLOCK_SIZE >= 2048:
+ num_warps = 8
+ if BLOCK_SIZE >= 4096:
+ num_warps = 16
+ # Allocate output
+ output = torch.empty_like(input)
+ if use_mask:
+ assert mask.is_contiguous()
+ mask = mask.view(-1, mask.shape[-1])
+ mask_stride = mask.shape[-1] if mask.shape[-2] > 1 else 0
+ masked_softmax_kernel[(n_rows, )](
+ output,
+ input,
+ input_arg.stride(0),
+ mask,
+ mask_stride,
+ n_cols,
+ num_warps=num_warps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+ else:
+ softmax_kernel[(n_rows, )](
+ output,
+ input,
+ input_arg.stride(0),
+ n_cols,
+ num_warps=num_warps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+ return output
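
`softmax` launches one program per row, so it accepts any input shape and reduces over the last dimension; the optional additive mask is broadcast per row. A quick check against torch, assuming fp16 inputs on a CUDA device with triton installed:

```python
import torch
from deepspeed.ops.transformer.inference.triton import softmax

x = torch.randn(2, 12, 128, 128, dtype=torch.float16, device="cuda")
out = softmax(x)
# the kernel accumulates in fp32 internally, so compare against an fp32 reference
ref = torch.nn.functional.softmax(x.float(), dim=-1).half()
assert torch.allclose(out, ref, rtol=3e-2, atol=2e-3)
```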
diff --git a/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py b/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d647ea090f60c7645a733495178f8f3b2e6545e
--- /dev/null
+++ b/deepspeed/ops/transformer/inference/triton/triton_matmul_kernel.py
@@ -0,0 +1,377 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import triton
+import triton.language as tl
+from .gelu import gelu_functor
+import torch
+
+AUTOTUNE_TOP_K = 10
+SKIP_AUTOTUNE = False
+
+
+def _fp16_matmul_prune_config(configs, named_args, skip_autotune=SKIP_AUTOTUNE):
+ if skip_autotune:
+ configs = [configs[0]]
+ else:
+ configs = triton.ops.matmul_perf_model.early_config_prune(configs, named_args)
+ return configs
+
+
+"""
+fp16 matmul implementation is adapted from triton matmul:
+https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/triton/ops/matmul.py
+"""
+
+
+@triton.autotune(
+ configs=[
+ # basic configs for compute-bound matmuls
+ triton.Config({
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 256,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=3, num_warps=8),
+ triton.Config({
+ 'BLOCK_M': 256,
+ 'BLOCK_N': 128,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=3, num_warps=8),
+ triton.Config({
+ 'BLOCK_M': 256,
+ 'BLOCK_N': 64,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=4, num_warps=4),
+ triton.Config({
+ 'BLOCK_M': 64,
+ 'BLOCK_N': 256,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=4, num_warps=4),
+ triton.Config({
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 128,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=4, num_warps=4),
+ triton.Config({
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 64,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=4, num_warps=4),
+ triton.Config({
+ 'BLOCK_M': 64,
+ 'BLOCK_N': 128,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=4, num_warps=4),
+ triton.Config({
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 32,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=4, num_warps=4),
+ triton.Config({
+ 'BLOCK_M': 64,
+ 'BLOCK_N': 32,
+ 'BLOCK_K': 32,
+ 'SPLIT_K': 1
+ }, num_stages=5, num_warps=2),
+ ],
+ key=['CACHE_M', 'CACHE_N', 'CACHE_K'],
+ prune_configs_by={
+ 'early_config_prune': _fp16_matmul_prune_config,
+ 'perf_model': None,
+ 'top_k': AUTOTUNE_TOP_K
+ },
+)
+@triton.heuristics({
+ 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
+})
+@triton.jit
+def _fp_matmul(
+ A,
+ B,
+ C,
+ M,
+ N,
+ K,
+ bias,
+ stride_am,
+ stride_ak,
+ stride_bk,
+ stride_bn,
+ stride_cm,
+ stride_cn,
+ CACHE_M,
+ CACHE_N,
+ CACHE_K,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+ SPLIT_K: tl.constexpr,
+ EVEN_K: tl.constexpr,
+ ACC_TYPE: tl.constexpr,
+ BIAS_ADD: tl.constexpr,
+ ACTIVATION: tl.constexpr,
+):
+ # matrix multiplication
+ pid = tl.program_id(0)
+ pid_z = tl.program_id(1)
+ grid_m = (M + BLOCK_M - 1) // BLOCK_M
+ grid_n = (N + BLOCK_N - 1) // BLOCK_N
+ # re-order program ID for better L2 performance
+ width = GROUP_M * grid_n
+ group_id = pid // width
+ group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
+ pid_m = group_id * GROUP_M + (pid % group_size)
+ pid_n = (pid % width) // (group_size)
+ # do matrix multiplication
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
+ rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
+ rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
+ # pointers
+ A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
+ B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
+ acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
+ for k in range(K, 0, -BLOCK_K * SPLIT_K):
+ if EVEN_K:
+ a = tl.load(A)
+ b = tl.load(B)
+ else:
+ a = tl.load(A, mask=rk[None, :] < k, other=0.)
+ b = tl.load(B, mask=rk[:, None] < k, other=0.)
+ acc += tl.dot(a, b)
+ A += BLOCK_K * SPLIT_K * stride_ak
+ B += BLOCK_K * SPLIT_K * stride_bk
+ # bias addition
+ if BIAS_ADD:
+ bias_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ bias_ptr = bias + bias_offset
+ b = tl.load(bias_ptr, mask=bias_offset < N)
+ acc = acc + b[None, :]
+ # activation
+ if ACTIVATION == "relu":
+ acc = tl.where(acc >= 0, acc, 0)
+ elif ACTIVATION == "leaky_relu":
+ acc = tl.where(acc >= 0, acc, 0.01 * acc)
+ elif ACTIVATION == "gelu":
+ #acc = tl.sigmoid(1.702 * acc) * acc
+ acc = gelu_functor(acc)
+ elif ACTIVATION == "sigmoid":
+ acc = tl.sigmoid(acc) # sigmoid
+ acc = acc.to(C.dtype.element_ty)
+ # rematerialize rm and rn to save registers
+ rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
+ mask = (rm < M)[:, None] & (rn < N)[None, :]
+ # handles write-back with reduction-splitting
+ if SPLIT_K == 1:
+ tl.store(C, acc, mask=mask)
+ else:
+ tl.atomic_add(C, acc, mask=mask)
+
+
+def matmul_4d_prune_config(configs, named_args, skip_autotune=SKIP_AUTOTUNE):
+ if skip_autotune:
+ configs = [configs[0]]
+ else:
+ device = torch.cuda.current_device() #ignore-cuda
+ dtsize = named_args['a_ptr'].element_size()
+
+ # make sure we have enough smem
+ pruned_configs = []
+ for config in configs:
+ kw = config.kwargs
+ BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
+ kw['BLOCK_SIZE_M'], kw['BLOCK_SIZE_N'], kw['BLOCK_SIZE_K'], config.num_stages
+
+ triton.compiler.init_cuda_utils()
+ max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
+ required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
+ if required_shared_memory <= max_shared_memory:
+ pruned_configs.append(config)
+ configs = pruned_configs
+ return configs
+
+
+@triton.autotune(
+ configs=[
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 64,
+ "GROUP_SIZE_M": 8
+ },
+ num_stages=1, # mainly for unit tests, to minimize shared-memory usage
+ num_warps=8),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=4,
+ num_warps=4,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 32,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=5,
+ num_warps=2,
+ ),
+ triton.Config(
+ {
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "BLOCK_SIZE_K": 32,
+ "GROUP_SIZE_M": 8,
+ },
+ num_stages=5,
+ num_warps=2,
+ ),
+ ],
+ key=['CACHE_M', 'CACHE_N', 'CACHE_K'],
+ prune_configs_by={
+ 'early_config_prune': matmul_4d_prune_config,
+ 'perf_model': None,
+ 'top_k': AUTOTUNE_TOP_K
+ },
+)
+@triton.jit
+def matmul_4d_kernel(
+ # Pointers to matrices
+ a_ptr,
+ b_ptr,
+ c_ptr,
+ # Matrix dimensions
+ M,
+ N,
+ K,
+ CACHE_M,
+ CACHE_N,
+ CACHE_K,
+ stride_ab,
+ stride_ah,
+ stride_am,
+ stride_ak,
+ stride_bb,
+ stride_bh,
+ stride_bk,
+ stride_bn,
+ stride_cb,
+ stride_ch,
+ stride_cm,
+ stride_cn,
+ scale,
+ # Meta-parameters
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ MASK: tl.constexpr,
+):
+ """Kernel for computing the matmul C = A x B.
+ A has shape (M, K), B has shape (K, N) and C has shape (M, N)
+ """
+ pid = tl.program_id(axis=0)
+ head = tl.program_id(axis=1)
+ batch = tl.program_id(axis=2)
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if MASK:
+ if (pid_m + 1) * BLOCK_SIZE_M - 1 < pid_n * BLOCK_SIZE_N:
+ c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.dtype.element_ty) - float("inf")
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] +
+ stride_cn * offs_cn[None, :])
+ tl.store(c_ptrs, c)
+ return
+
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
+ (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
+ b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, K, BLOCK_SIZE_K):
+ a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
+ b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
+ a = tl.load(a_ptrs, mask=a_mask, other=0.)
+ b = tl.load(b_ptrs, mask=b_mask, other=0.)
+ accumulator += tl.dot(a, b)
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+
+ c = accumulator.to(c_ptr.dtype.element_ty)
+ if scale > 0:
+ c = c * scale.to(c_ptr.dtype.element_ty)
+
+ offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+ if MASK:
+ c += tl.where(offs_cm[:, None] >= offs_cn[None, :], 0, float("-inf"))
+ c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] +
+ stride_cn * offs_cn[None, :])
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+ tl.store(c_ptrs, c, mask=c_mask)
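
Both kernels here are launched from the `matmul_ext` wrappers (added elsewhere in this PR). For `matmul_4d_kernel` the launch grid is three-dimensional: (M, N) tiles on axis 0, the head and batch dimensions on axes 1 and 2. A hedged sketch of such a wrapper, assuming contiguous fp16 `(B, H, M, K) x (B, H, K, N)` inputs; the `CACHE_*` bucketing shown is illustrative, not the exact scheme `matmul_ext` uses:

```python
import torch
import triton
from deepspeed.ops.transformer.inference.triton.triton_matmul_kernel import matmul_4d_kernel

def matmul_4d(a, b, scale=-1.0, mask=False):
    # a: (B, H, M, K), b: (B, H, K, N), both fp16 and contiguous
    B, H, M, K = a.shape
    N = b.shape[-1]
    c = torch.empty((B, H, M, N), dtype=a.dtype, device=a.device)
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]), H, B)
    matmul_4d_kernel[grid](
        a, b, c, M, N, K,
        M // 32, N // 32, K // 32,  # CACHE_M/N/K: bucketed shape keys for the autotuner (illustrative)
        a.stride(0), a.stride(1), a.stride(2), a.stride(3),
        b.stride(0), b.stride(1), b.stride(2), b.stride(3),
        c.stride(0), c.stride(1), c.stride(2), c.stride(3),
        scale, MASK=mask)
    return c
```

With `MASK=True` the kernel writes `-inf` above the diagonal, so causal attention scores come out of the same launch without a separate masking pass.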
diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt
index ce1dd3429eea4c9569e5212495a709a25698b339..731b67a26814685fed42f4932d6d9a196558ae60 100644
--- a/requirements/requirements-dev.txt
+++ b/requirements/requirements-dev.txt
@@ -13,4 +13,5 @@ sphinx-rtd-theme
tensorboard
torchvision
transformers
+triton
wandb
diff --git a/requirements/requirements-inf.txt b/requirements/requirements-inf.txt
index 848a7f7a485de46d9941b3de6919b42724f76409..ef8bfff774d660e92c6803c2892905aafcb9fbce 100644
--- a/requirements/requirements-inf.txt
+++ b/requirements/requirements-inf.txt
@@ -3,3 +3,4 @@ lm-eval==0.3.0
protobuf
transformers
transformers[sentencepiece]
+triton
diff --git a/requirements/requirements-triton.txt b/requirements/requirements-triton.txt
new file mode 100644
index 0000000000000000000000000000000000000000..a59a965090a6473278f0cf9e7fd1d3cb9cb385c9
--- /dev/null
+++ b/requirements/requirements-triton.txt
@@ -0,0 +1 @@
+triton
diff --git a/setup.py b/setup.py
index 5d0aba18f2bb5fa2e514dc15338c2442b9cf09e4..1b5835f3cfe4805a21040bc9ec79262b5139e998 100755
--- a/setup.py
+++ b/setup.py
@@ -67,7 +67,8 @@ extras_require = {
'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'),
'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'),
'inf': fetch_requirements('requirements/requirements-inf.txt'),
- 'sd': fetch_requirements('requirements/requirements-sd.txt')
+ 'sd': fetch_requirements('requirements/requirements-sd.txt'),
+ 'triton': fetch_requirements('requirements/requirements-triton.txt'),
}
# Add specific cupy version to both onebit extension variants.
diff --git a/tests/unit/inference/test_inference.py b/tests/unit/inference/test_inference.py
index c42deb3dd6d743a63f49c93177640bdf98cb0276..f1fac5ba1ecbe2b545ec64dd69de9fb006414101 100644
--- a/tests/unit/inference/test_inference.py
+++ b/tests/unit/inference/test_inference.py
@@ -95,13 +95,18 @@ def enable_cuda_graph(request):
return request.param
+@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
+def enable_triton(request):
+ return request.param
+
+
"""
This fixture will validate the configuration
"""
@pytest.fixture()
-def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph):
+def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_triton):
model, task = model_w_task
msg = ""
if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
@@ -125,6 +130,12 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph):
msg = f"Bloom models only support half precision, cannot use dtype {dtype}"
elif ("bert" not in model.lower()) and enable_cuda_graph:
msg = "Non bert/roberta models do no support CUDA Graph"
+ elif enable_triton and dtype != torch.half:
+ msg = "Triton kernels only support fp16"
+ elif enable_triton and not deepspeed.HAS_TRITON:
+ msg = "triton needs to be installed for the test"
+ elif ("bert" not in model.lower()) and enable_triton:
+ msg = "Triton kernels do not support non-bert/roberta models yet"
return msg
@@ -249,16 +260,16 @@ Tests
class TestModelTask(DistributedTest):
world_size = 1
- def test(
- self,
- model_w_task,
- dtype,
- enable_cuda_graph,
- query,
- inf_kwargs,
- assert_fn,
- invalid_model_task_config,
- ):
+ def test(self,
+ model_w_task,
+ dtype,
+ enable_cuda_graph,
+ enable_triton,
+ query,
+ inf_kwargs,
+ assert_fn,
+ invalid_model_task_config,
+ perf_meas=True):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
@@ -284,13 +295,18 @@ class TestModelTask(DistributedTest):
get_accelerator().synchronize()
bs_time = time.time() - start
- pipe.model = deepspeed.init_inference(
- pipe.model,
- mp_size=1,
- dtype=dtype,
- replace_with_kernel_inject=True,
- enable_cuda_graph=enable_cuda_graph,
- )
+ args = {
+ 'mp_size': 1,
+ 'dtype': dtype,
+ 'replace_with_kernel_inject': True,
+ 'enable_cuda_graph': enable_cuda_graph,
+ 'use_triton': enable_triton,
+ 'triton_autotune': False,
+ }
+ if pipe.tokenizer.model_max_length < deepspeed.ops.transformer.inference.config.DeepSpeedInferenceConfig(
+ ).max_out_tokens:
+ args.update({'max_out_tokens': pipe.tokenizer.model_max_length})
+ pipe.model = deepspeed.init_inference(pipe.model, **args)
check_injection(pipe.model)
# Warm-up queries for perf measurement
#for i in range(10):
@@ -301,6 +317,11 @@ class TestModelTask(DistributedTest):
get_accelerator().synchronize()
ds_time = time.time() - start
+ if perf_meas:
+ print(
+ f"model={model}, task={task}, dtype={dtype}, cuda_graph={enable_cuda_graph}, triton={enable_triton}, bs_time={bs_time}, ds_time={ds_time}"
+ )
+
# facebook/opt* and some bigscient/bloom* models are not matching
# baseline exactly, adding an exception to them for now
if ("opt" in model) or ("bloom" in model):
@@ -309,6 +330,7 @@ class TestModelTask(DistributedTest):
# These performance tests are only measuring the time for a single
# inference request, we just want to check that performance isn't terrible
#assert ds_time <= (bs_time * 1.1)
+
assert assert_fn(bs_output, ds_output)
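
The `args` dict above exercises the new user-facing flags end to end. Outside the test harness, enabling the Triton kernels for a Hugging Face pipeline looks roughly like this (`triton_autotune=True` is the setting you would typically want outside of tests):

```python
import torch
import deepspeed
from transformers import pipeline

pipe = pipeline("fill-mask", model="bert-base-cased", device=0)
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=1,
                                      dtype=torch.half,
                                      replace_with_kernel_inject=True,
                                      use_triton=True,
                                      triton_autotune=True)
print(pipe("Paris is the [MASK] of France."))
```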
diff --git a/tests/unit/ops/transformer/inference/test_attention.py b/tests/unit/ops/transformer/inference/test_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..760223115a280eca5b756372ca0520953754da10
--- /dev/null
+++ b/tests/unit/ops/transformer/inference/test_attention.py
@@ -0,0 +1,73 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+
+
+# reference implementation
+def ref_torch_attention(q, k, v, mask, sm_scale):
+ p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
+ p = torch.softmax(p.float() + mask, dim=-1).half()
+ ref_out = torch.matmul(p, v)
+ return ref_out
+
+
+# test attention operator
+@pytest.mark.inference_ops
+@pytest.mark.parametrize("Z", [1]) # batch
+@pytest.mark.parametrize("H", [12]) # heads
+@pytest.mark.parametrize("N_CTX", [4, 128]) # sequence length
+@pytest.mark.parametrize("D_HEAD", [64, 128])
+@pytest.mark.parametrize("causal", [True, False])
+def test_attention(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
+ if not deepspeed.HAS_TRITON:
+ pytest.skip("triton has to be installed for the test")
+
+ # skip autotune in testing
+ from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
+ fp16_matmul.skip_autotune()
+
+ import triton
+ from deepspeed.ops.transformer.inference.triton.attention import compute_attention
+ torch.manual_seed(20)
+ q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+ k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+ v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
+ sm_scale = 0.3
+
+ # reference implementation
+ mask = torch.zeros((Z, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
+ if causal:
+ M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
+ mask[:, :, M == 0] = float("-inf") # broadcasts over the batch and head dims
+ ref_out = ref_torch_attention(q, k, v, mask, sm_scale)
+
+ # adjust it to expected tensor format and run test
+ qkv = torch.randn((Z, N_CTX, 3 * H * D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
+ qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
+ qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
+ qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
+ tri_out = compute_attention(qkv,
+ input_mask=mask,
+ layer_past=None,
+ alibi=None,
+ scale=sm_scale,
+ head_size=D_HEAD,
+ triangular=False,
+ use_cuda_flash=False,
+ use_triton_flash=False,
+ use_ds_attention=False)
+ tri_out = tri_out.reshape((Z, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
+ triton.testing.assert_almost_equal(ref_out, tri_out)
diff --git a/tests/unit/ops/transformer/inference/test_gelu.py b/tests/unit/ops/transformer/inference/test_gelu.py
new file mode 100644
index 0000000000000000000000000000000000000000..de924848bfb4a42de52cc4c84aca024805e861a6
--- /dev/null
+++ b/tests/unit/ops/transformer/inference/test_gelu.py
@@ -0,0 +1,70 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+from deepspeed.ops.op_builder import InferenceBuilder
+
+if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
+ pytest.skip("Inference ops are not available on this system", allow_module_level=True)
+
+inference_module = None
+torch_minor_version = None
+
+
+def allclose(x, y):
+ assert x.dtype == y.dtype
+ rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
+ return torch.allclose(x, y, rtol=rtol, atol=atol)
+
+
+def version_appropriate_gelu(activations):
+ global torch_minor_version
+ if torch_minor_version is None:
+ torch_minor_version = int(torch.__version__.split('.')[1])
+ # torch < 1.12 has no approximate= kwarg; newer versions select the tanh approximation explicitly
+ if torch_minor_version < 12:
+ return torch.nn.functional.gelu(activations)
+ else:
+ return torch.nn.functional.gelu(activations, approximate='tanh')
+
+
+def run_gelu_reference(activations):
+ # Expected behavior is that of casting to float32 internally and using the tanh approximation
+ return version_appropriate_gelu(activations.to(torch.float32)).to(activations.dtype)
+
+
+def run_gelu_ds(activations, use_triton_ops=False):
+ if use_triton_ops:
+ from deepspeed.ops.transformer.inference.triton import gelu
+ return gelu(activations)
+
+ channels = activations.shape[-1]
+ bias = torch.zeros((channels), dtype=activations.dtype, device='cuda')
+ global inference_module
+ if inference_module is None:
+ inference_module = InferenceBuilder().load()
+ if activations.dtype == torch.float16:
+ return inference_module.bias_gelu_fp16(activations, bias)
+ else:
+ return inference_module.bias_gelu_fp32(activations, bias)
+
+
+@pytest.mark.inference_ops
+@pytest.mark.parametrize("batch", [1, 2])
+@pytest.mark.parametrize("sequence", [1, 128, 255])
+@pytest.mark.parametrize("channels", [512, 1232, 4096])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("use_triton_ops", [True, False])
+def test_gelu(batch, sequence, channels, dtype, use_triton_ops):
+ if not deepspeed.HAS_TRITON and use_triton_ops:
+ pytest.skip("triton has to be installed for the test")
+
+ activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda')
+ activations_ref = activations_ds.clone().detach()
+ ds_out = run_gelu_ds(activations_ds, use_triton_ops)
+ ref_out = run_gelu_reference(activations_ref)
+ assert (allclose(ds_out, ref_out))
diff --git a/tests/unit/ops/transformer/inference/test_layer_norm.py b/tests/unit/ops/transformer/inference/test_layer_norm.py
index 25728d41a0cd80c14b01425eba58f92b91def421..952904a7847cd855bd7e195ee5480236c0494c32 100644
--- a/tests/unit/ops/transformer/inference/test_layer_norm.py
+++ b/tests/unit/ops/transformer/inference/test_layer_norm.py
@@ -9,6 +9,14 @@ import pytest
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
from .inference_test_utils import allclose, get_dtypes
+try:
+ import triton # noqa: F401
+ from deepspeed.ops.transformer.inference.triton import (
+ layer_norm,
+ layer_norm_residual,
+ )
+except ImportError:
+ print("triton import failed")
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
@@ -30,19 +38,32 @@ def ds_implementation(vals, gamma, beta, epsilon):
return inference_module.layer_norm(vals, gamma, beta, epsilon)
+def ds_triton_implementation(vals, gamma, beta, epsilon):
+ return layer_norm(vals, gamma, beta, epsilon)
+
+
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 32])
@pytest.mark.parametrize("seq_len", [1, 128])
@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432])
@pytest.mark.parametrize("dtype", get_dtypes())
-def test_layer_norm(batch, seq_len, channels, dtype):
+@pytest.mark.parametrize("use_triton_ops", [False, True])
+def test_layer_norm(batch, seq_len, channels, dtype, use_triton_ops):
+ if not deepspeed.HAS_TRITON and use_triton_ops:
+ pytest.skip("triton has to be installed for the test")
+
vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name())
beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name())
epsilon = 1e-5
ref_output = ref_implementation(vals, gamma, beta, epsilon, channels, dtype)
- new_output = ds_implementation(vals, gamma, beta, epsilon)
+ if use_triton_ops:
+ new_output = ds_triton_implementation(vals, gamma, beta, epsilon)
+ if dtype != torch.float16: # triton layer_norm supports fp16 only; skip the comparison
+ return
+ else:
+ new_output = ds_implementation(vals, gamma, beta, epsilon)
if not allclose(new_output, ref_output):
#print(new_output - ref_output)
@@ -68,12 +89,20 @@ def residual_ds_implementation(vals, bias, res, gamma, beta, epsilon):
return inference_module._layer_norm_residual(vals, bias, res, gamma, beta, epsilon)
+def residual_ds_triton_implementation(vals, bias, res, gamma, beta, epsilon):
+ return layer_norm_residual(vals, bias, res, gamma, beta, epsilon)
+
+
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 32])
@pytest.mark.parametrize("seq_len", [1, 128])
@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432])
@pytest.mark.parametrize("dtype", get_dtypes())
-def test_layer_norm_residual(batch, seq_len, channels, dtype):
+@pytest.mark.parametrize("use_triton_ops", [False, True])
+def test_layer_norm_residual(batch, seq_len, channels, dtype, use_triton_ops):
+ if not deepspeed.HAS_TRITON and use_triton_ops:
+ pytest.skip("triton has to be installed for the test")
+
vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
bias = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name())
@@ -81,7 +110,13 @@ def test_layer_norm_residual(batch, seq_len, channels, dtype):
beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name())
epsilon = 1e-5
- new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon)
+ if use_triton_ops:
+ new_output = residual_ds_triton_implementation(vals, bias, residual, gamma, beta, epsilon)
+ if dtype != torch.float16: # triton layer_norm supports fp16 only; skip the comparison
+ return
+ else:
+ new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon)
+
ref_output = residual_ref_implementation(vals, bias, residual, gamma, beta, epsilon, channels, dtype)
print((new_output - ref_output).abs().max())
@@ -129,3 +164,38 @@ def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype):
assert allclose(ds_res_output, norm_res_output)
assert allclose(ds_norm_output, ref_norm_output)
+
+
+@pytest.mark.inference_ops
+@pytest.mark.parametrize("M", [4])
+@pytest.mark.parametrize("N", [4])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("residual", [True, False])
+@pytest.mark.parametrize("input_bias", [True, False])
+def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='cuda'):
+ if not deepspeed.HAS_TRITON:
+ pytest.skip("triton has to be installed for the test")
+ torch.manual_seed(0)
+ # create data
+ x_shape = (M, N)
+ w_shape = (x_shape[-1], )
+ weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
+ bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
+ x_bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
+ x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
+ if residual:
+ res = torch.rand(x_shape, dtype=dtype, device='cuda', requires_grad=False)
+ else:
+ res = torch.zeros(x_shape, dtype=dtype, device='cuda', requires_grad=False)
+ # forward pass
+ if residual or input_bias:
+ y_tri = layer_norm_residual(x, x_bias if input_bias else None, res, weight, bias, eps)
+ else:
+ y_tri = layer_norm(x, weight, bias, eps)
+ y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias,
+ eps).to(dtype)
+ # compare
+ #print(f"y_tri={y_tri}, y_ref={y_ref}")
+ triton.testing.assert_almost_equal(y_tri, y_ref)
diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py
new file mode 100644
index 0000000000000000000000000000000000000000..804a85750a3a0259a9d9cef0c9b967ab1955027e
--- /dev/null
+++ b/tests/unit/ops/transformer/inference/test_matmul.py
@@ -0,0 +1,59 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+from deepspeed.ops.op_builder import InferenceBuilder
+
+if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
+ pytest.skip("Inference ops are not available on this system", allow_module_level=True)
+
+inference_module = None
+torch_minor_version = None
+
+
+def allclose(x, y):
+ assert x.dtype == y.dtype
+ rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (5e-2, 2e-3)}[x.dtype]
+ return torch.allclose(x, y, rtol=rtol, atol=atol)
+
+
+def run_matmul_ref(a, b):
+ return torch.matmul(a, b)
+
+
+def run_matmul_ds(a, b, use_triton_ops=False):
+ if use_triton_ops:
+ from deepspeed.ops.transformer.inference.triton import matmul_4d as matmul
+ return matmul(a, b)
+
+ assert use_triton_ops, "Only triton matmul is supported for now"
+
+
+@pytest.mark.inference_ops
+@pytest.mark.parametrize("B", [1, 2])
+@pytest.mark.parametrize("H", [1, 2, 16])
+@pytest.mark.parametrize("M", [1, 7, 8, 128])
+@pytest.mark.parametrize("K", [2, 5, 16, 128])
+@pytest.mark.parametrize("N", [1, 2, 8, 512])
+@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("use_triton_ops", [True])
+def test_matmul_4d(B, H, M, K, N, dtype, use_triton_ops):
+ if not deepspeed.HAS_TRITON and use_triton_ops:
+ pytest.skip("triton has to be installed for the test")
+
+ # skip autotune in testing
+ from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
+ fp16_matmul.skip_autotune()
+
+ a_ds = torch.randn((B, H, M, K), dtype=dtype, device='cuda')
+ b_ds = torch.randn((B, H, K, N), dtype=dtype, device='cuda')
+ a_ref = a_ds.clone().detach()
+ b_ref = b_ds.clone().detach()
+
+ ds_out = run_matmul_ds(a_ds, b_ds, use_triton_ops)
+ ref_out = run_matmul_ref(a_ref, b_ref)
+ assert (allclose(ds_out, ref_out))
diff --git a/tests/unit/ops/transformer/inference/test_residual_add.py b/tests/unit/ops/transformer/inference/test_residual_add.py
index 1a9d8975852c0ae2afdba9fc3ae15e3c24a7913b..c2952f74ff2d7414fcaf6ae4524c50941979d475 100644
--- a/tests/unit/ops/transformer/inference/test_residual_add.py
+++ b/tests/unit/ops/transformer/inference/test_residual_add.py
@@ -74,8 +74,11 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f
@pytest.mark.parametrize("add_bias", [True, False])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("pre_attn_norm", [True, False])
+@pytest.mark.parametrize("use_triton_ops", [True, False])
def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size,
- pre_attn_norm):
+ pre_attn_norm, use_triton_ops):
+ if not deepspeed.HAS_TRITON and use_triton_ops:
+ pytest.skip("triton has to be installed for the test")
ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
@@ -90,6 +93,9 @@ def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_
ds_out, residual, attn_output, attn_bias, final_bias, mp_size, mlp_after_attn, add_bias, pre_attn_norm
]
- if dtype == torch.float16:
+ if use_triton_ops:
+ from deepspeed.ops.transformer.inference.triton import residual_add_bias
+ ds_out = residual_add_bias(*res_add_args)
+ elif dtype == torch.float16:
ds_out = inference_module.residual_add_bias_fp16(*res_add_args)
elif dtype == torch.float32:
@@ -97,7 +103,12 @@ def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_
elif dtype == torch.bfloat16:
ds_out = inference_module.residual_add_bias_bf16(*res_add_args)
else:
raise ValueError(f"Unsupported dtype: {dtype}")
if not allclose(ds_out, ref_out):
print((ds_out - ref_out).abs().max())
diff --git a/tests/unit/ops/transformer/inference/test_softmax.py b/tests/unit/ops/transformer/inference/test_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..76046f31e01ad122b253a058b6040ed988c17cf1
--- /dev/null
+++ b/tests/unit/ops/transformer/inference/test_softmax.py
@@ -0,0 +1,51 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import pytest
+import torch
+import deepspeed
+from deepspeed.ops.op_builder import InferenceBuilder
+
+if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
+ pytest.skip("Inference ops are not available on this system", allow_module_level=True)
+
+inference_module = None
+torch_minor_version = None
+
+
+def allclose(x, y):
+ assert x.dtype == y.dtype
+ rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
+ return torch.allclose(x, y, rtol=rtol, atol=atol)
+
+
+def run_softmax_reference(input):
+ return torch.nn.functional.softmax(input, dim=-1)
+
+
+def run_softmax_ds(input, use_triton_ops=False):
+ if use_triton_ops:
+ from deepspeed.ops.transformer.inference.triton import softmax
+ return softmax(input)
+
+ assert use_triton_ops, "Only triton softmax is supported for now"
+
+
+@pytest.mark.inference_ops
+@pytest.mark.parametrize("batch", [1, 2])
+@pytest.mark.parametrize("sequence", [1, 128, 255, 1232])
+@pytest.mark.parametrize("channels", [512, 4096])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
+@pytest.mark.parametrize("use_triton_ops", [True])
+def test_softmax(batch, sequence, channels, dtype, use_triton_ops):
+ if not deepspeed.HAS_TRITON and use_triton_ops:
+ pytest.skip("triton has to be installed for the test")
+ input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda')
+ input_ref = input_ds.clone().detach()
+
+ ds_out = run_softmax_ds(input_ds, use_triton_ops)
+ ref_out = run_softmax_reference(input_ref)
+ assert (allclose(ds_out, ref_out))