Unverified commit 4dc65f7b authored by stephen youn, committed by GitHub

DeepSpeed-Triton for Inference (#3748)

Co-authored-by: Stephen Youn <styoun@microsoft.com>
Co-authored-by: Arash Bakhtiari <arash@bakhtiari.org>
Co-authored-by: Cheng Li <pistasable@gmail.com>
Co-authored-by: Ethan Doe <yidoe@microsoft.com>
Co-authored-by: yidoe <68296935+yidoe@users.noreply.github.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Parent 2c62cb4c
......@@ -27,7 +27,7 @@ jobs:
- name: Install deepspeed
run: |
pip install .[dev,autotuning]
pip install .[dev,autotuning,triton]
ds_report
- name: Formatting checks
......
# DeepSpeed with Triton compiler
# 1. Overview
We have integrated [Triton](https://github.com/openai/triton), an open source compiler for GPU programming, into DeepSpeed, which further boosts the inference speed of BERT-like models in float16 precision.
By replacing some CUDA kernels or torch operators with Triton kernels, we achieved 1.14\~1.68x speedup (or 12\~41% latency reduction) for different models and GPUs, as shown in Table 1.
<div align="center">
| Hardware | Bert-base | Bert-large | Roberta-base | Roberta-large |
|----------|:------:|:------:|:------:|:------:|
| A100 |1.65x | 1.68x | 1.53x | 1.61x |
| V100 | 1.29x | 1.14x | 1.23x | 1.21x |
Table 1. The average speedup (see NOTE below for more detail)
</div>
For transformer operators in float16, we have implemented kernels written in the Triton language that replace ordinary CUDA kernels or torch operators.
The Triton kernels we implemented include softmax, layer-normalization, residual-addition and all the matrix multiplications except MLP layers (see NOTE below for details).
In our experiments, Triton kernels help to reduce the average latency (over different sequence lengths) by 6\~24% (depending on model and hardware) when compared to the latency with CUDA-only kernels.
Figures below show the latency reduction in more detail.
Figure 1 visualizes latency reduction in different sequence lengths in A100 GPU for Bert-base model.
The baseline (blue) is from Hugging Face Transformers without any kernel injection, the orange is from DeepSpeed with CUDA-only kernels, and the gray is from DeepSpeed with Triton kernels.
Figure 2 shows the same plot for Bert-large model in A100 GPU.
<div align="center">
<img src="../assets/images/triton-bert-base-latency.png" width="500px" alt="triton-bert-base-latency"/>
*Figure 1: Normalized P90 latency for Bert-base model in A100 GPU across different sequence lengths*
<img src="../assets/images/triton-bert-large-latency.png" width="500px" alt="triton-bert-large-latency"/>
*Figure 2: Normalized P90 latency for Bert-large model in A100 GPU across different sequence lengths*
</div>
Next, we dive deeper into this new feature in DeepSpeed.
# 2. How to use Triton in DeepSpeed
You can enable Triton kernels by setting the `use_triton` flag in the DeepSpeed inference config, as shown below.
```python
import torch
import deepspeed
from transformers import pipeline

pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0)
pipe.model = deepspeed.init_inference(pipe.model,
dtype=torch.float16,
replace_with_kernel_inject=True,
enable_cuda_graph=True,
use_triton=True,
triton_autotune=True,
max_out_tokens=pipe.tokenizer.model_max_length)
```
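Equivalently, the same settings can be passed as a config dictionary. The sketch below is illustrative and assumes the dict-based `config` argument of `deepspeed.init_inference`; the field names mirror the keyword arguments above.
```python
import torch
import deepspeed
from transformers import pipeline

pipe = pipeline('fill-mask', model='bert-base-cased', framework='pt', device=0)
ds_config = {
    "dtype": torch.float16,
    "replace_with_kernel_inject": True,
    "enable_cuda_graph": True,
    "use_triton": True,
    "triton_autotune": True,
    "max_out_tokens": pipe.tokenizer.model_max_length,
}
pipe.model = deepspeed.init_inference(pipe.model, config=ds_config)
```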
## Running BERT inference with Triton kernels
We use Bert-base as an example here.
```bash
pip install deepspeed[triton]
git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/inference/huggingface/fill-mask
deepspeed --num_gpus 1 test-bert.py --triton
```
To run a performance benchmark, you can use the following command:
```bash
pip install deepspeed[triton]
git clone https://github.com/microsoft/DeepSpeedExamples.git
cd DeepSpeedExamples/benchmarks/inference
deepspeed --num_gpus 1 triton-bert-benchmark.py --model bert-base-cased --dtype fp16 --kernel-inject --deepspeed --graphs --triton
```
# NOTE
<!-- **_NOTE:_** -->
* For more information on how to use DeepSpeed, please visit our [GitHub Page](https://github.com/microsoft/DeepSpeedExamples) and our [website](https://www.deepspeed.ai/), where you can find blog posts, tutorials, and documentation.
* This feature is currently only supported for BERT, Roberta and other BERT-like models, and not for text-generation models yet.
* To achieve the best performance with the Triton optimization, you need to enable CUDA graphs and `triton_autotune` in the DeepSpeed config. CUDA graphs avoid the overhead of Triton's JIT compilation and deep call stack. `triton_autotune` runs an initial step to find the most suitable parameters for the Triton kernels, which may take some time.
* We used [Triton 2.0.0.post1 release](https://pypi.org/project/triton/2.0.0.post1/) in our experiments.
* In our experiments, we used a batch size of 1, a sequence length range of 8 to 512, and a `fill-mask` task. Table 1 shows the average P90 latency over the entire sequence-length range, while Figures 1 and 2 show the P90 latency for specific sub-ranges. The baseline is Hugging Face Transformers without any optimization. The speedup is calculated as (baseline P90 latency)/(DeepSpeed-Triton P90 latency); a small sketch of this calculation follows below. We found that the CUDA kernel in the MLP performed better than the Triton kernel in our experiments, so we use a hybrid approach that combines both kernels when Triton is enabled in the DeepSpeed config.
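For concreteness, here is a minimal sketch (not from this PR) of how such a speedup figure can be computed from per-sequence-length latency samples; `baseline_ms` and `triton_ms` are hypothetical mappings from sequence length to lists of per-iteration latencies in milliseconds.
```python
import numpy as np

def p90(samples_ms):
    """P90 latency (ms) from a list of per-iteration timings."""
    return float(np.percentile(samples_ms, 90))

def average_speedup(baseline_ms, triton_ms):
    """Average of (baseline P90) / (DeepSpeed-Triton P90) over sequence lengths."""
    ratios = [p90(baseline_ms[s]) / p90(triton_ms[s]) for s in sorted(baseline_ms)]
    return sum(ratios) / len(ratios)
```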
......@@ -12,6 +12,12 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from packaging import version as pkg_version
try:
import triton # noqa: F401
HAS_TRITON = True
except ImportError:
HAS_TRITON = False
from . import ops
from . import module_inject
......
......@@ -4,6 +4,7 @@
# DeepSpeed Team
import torch
import deepspeed
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
from deepspeed.runtime.zero.config import DeepSpeedZeroConfig
from pydantic import Field
......@@ -152,6 +153,18 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
can run faster using the graph replay method.
"""
use_triton: bool = False
"""
Use this flag to enable Triton kernels for inference ops.
"""
triton_autotune: bool = False
"""
Use this flag to enable triton autotuning.
Turning it on is better for performance but increases the first-run time due to autotuning.
"""
zero: DeepSpeedZeroConfig = {}
"""
ZeRO configuration to use with the Inference Engine. Expects a dictionary
......@@ -279,6 +292,12 @@ class DeepSpeedInferenceConfig(DeepSpeedConfigModel):
return DeepSpeedMoEConfig(moe=field_value)
return field_value
@validator("use_triton")
def has_triton(cls, field_value, values):
if field_value and not deepspeed.HAS_TRITON:
raise ValueError('Triton needs to be installed to use deepspeed with triton kernels')
return field_value
class Config:
# Get the str representation of the datatype for serialization
json_encoders = {torch.dtype: lambda x: str(x)}
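# Illustrative sketch (not part of this diff): constructing the config directly
# exercises the `has_triton` validator above. The import path below is an
# assumption about where DeepSpeedInferenceConfig is defined.
def _demo_use_triton_validation():
    import torch
    from deepspeed.inference.config import DeepSpeedInferenceConfig

    # Raises ValueError('Triton needs to be installed ...') if the triton
    # package is not importable; otherwise returns a validated config.
    cfg = DeepSpeedInferenceConfig(dtype=torch.float16, use_triton=True, triton_autotune=True)
    return cfg.use_triton, cfg.triton_autotune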
......@@ -607,6 +607,7 @@ class InferenceEngine(Module):
else:
self._create_cuda_graph(*inputs, **kwargs)
outputs = self._graph_replay(*inputs, **kwargs)
else:
outputs = self.module(*inputs, **kwargs)
......
......@@ -12,6 +12,10 @@ from deepspeed.ops.transformer.inference.ds_mlp import DeepSpeedMLP
from deepspeed.ops.transformer.inference.ds_attention import DeepSpeedSelfAttention, BloomSelfAttention
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed
if deepspeed.HAS_TRITON:
from deepspeed.ops.transformer.inference.triton.mlp import TritonMLP
from deepspeed.ops.transformer.inference.triton.attention import TritonSelfAttention
inference_module = None
......@@ -55,14 +59,24 @@ class DeepSpeedTransformerInference(nn.Module):
if DeepSpeedTransformerInference.layer_id == 1:
log_dist(f"DeepSpeed-Inference config: {self.config.__dict__}", [0])
if deepspeed.HAS_TRITON and self.config.use_triton:
log_dist(f"Injecting Triton kernels ...", [0])
if self.config.bigscience_bloom:
self.attention = BloomSelfAttention(self.config, mp_group, quantize_scales, quantize_groups, merge_count)
assert not self.config.use_triton
else:
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)
if deepspeed.HAS_TRITON and self.config.use_triton:
self.attention = TritonSelfAttention(self.config)
else:
self.attention = DeepSpeedSelfAttention(self.config, mp_group, quantize_scales, quantize_groups,
merge_count)
if deepspeed.HAS_TRITON and self.config.use_triton:
self.mlp = TritonMLP(self.config)
else:
self.mlp = DeepSpeedMLP(self.config, mp_group, quantize_scales, quantize_groups, merge_count,
mlp_extra_grouping)
device = get_accelerator().current_device_name() # if config.bigscience_bloom else 'cpu'
if self.config.set_empty_params:
......
......@@ -8,6 +8,7 @@ from abc import ABC
import torch
import deepspeed
from deepspeed.ops.transformer.inference.config import DeepSpeedInferenceConfig
from deepspeed.accelerator import get_accelerator
......@@ -79,6 +80,10 @@ class BaseTransformerContainer(ABC):
self.input_nb = None
self.mp_group = None
self.use_triton = False
# Triton
self.use_triton = config.use_triton and deepspeed.HAS_TRITON
def create_ds_model_config(self):
self.set_hidden_heads(*self.policy.get_hidden_heads())
......@@ -110,7 +115,14 @@ class BaseTransformerContainer(ABC):
use_mup=self.use_mup,
return_single_tuple=self.return_single_tuple,
set_empty_params=self.config.set_empty_params,
transposed_mode=self.config.transposed_mode)
transposed_mode=self.config.transposed_mode,
use_triton=self.use_triton,
triton_autotune=self.config.triton_autotune)
if self.use_triton and deepspeed.HAS_TRITON:
if not self.config.triton_autotune:
from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
fp16_matmul.skip_autotune()
return self.ds_model_config
......
......@@ -18,6 +18,7 @@ class DS_BERTContainer(BaseTransformerContainer):
# All model specific things should be defined here instead of the base class.
self.return_tuple = True
self.triangular_masking = False
self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON
def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
......
......@@ -18,6 +18,7 @@ class DS_DistilBERTContainer(BaseTransformerContainer):
# All model specific things should be defined here instead of the base class.
self.triangular_masking = False
self.return_single_tuple = True
self.use_triton = kwargs['config'].use_triton and deepspeed.HAS_TRITON
def create_module(self, config=None):
_config = config if config is not None else self.ds_model_config
......
......@@ -44,6 +44,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
scale_attention: If true, both q and k are scaled by 1/sqrt(attention_heads) before attention computation.
return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
use_triton: This flag enables Triton kernels for inference.
"""
def __init__(self,
......@@ -77,7 +78,9 @@ class DeepSpeedInferenceConfig(TransformerConfig):
scale_attn_by_inverse_layer_idx=False,
return_single_tuple=False,
set_empty_params=False,
transposed_mode=False):
transposed_mode=False,
use_triton=False,
triton_autotune=False):
super(DeepSpeedInferenceConfig,
self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
num_hidden_layers)
......@@ -109,6 +112,8 @@ class DeepSpeedInferenceConfig(TransformerConfig):
self.return_single_tuple = return_single_tuple
self.set_empty_params = set_empty_params
self.transposed_mode = transposed_mode
self.use_triton = use_triton
self.triton_autotune = triton_autotune
@classmethod
def from_dict(cls, json_object):
......
......@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed
class GELUGemmOp(BaseOp):
......@@ -14,9 +15,13 @@ class GELUGemmOp(BaseOp):
super(GELUGemmOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import fused_gemm_gelu as _triton_fused_gemm_gelu
self.fused_gemm_gelu = _triton_fused_gemm_gelu # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_bf16 # type: ignore
else:
self.fused_gemm_gelu = self.inference_module.fused_gemm_gelu_fp32 # type: ignore
except AttributeError:
......
......@@ -6,6 +6,7 @@
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed
class LinearOp(BaseOp):
......@@ -14,6 +15,14 @@ class LinearOp(BaseOp):
super(LinearOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import linear_func as _triton_linear_func
self.linear_func = _triton_linear_func
triton_autotune = config.triton_autotune and config.layer_id == 0
if triton_autotune:
__class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
else:
self.linear_func = self.inference_module.linear_layer_fp16
elif self.config.dtype == torch.bfloat16:
self.linear_func = self.inference_module.linear_layer_bf16
......@@ -37,3 +46,15 @@ class LinearOp(BaseOp):
qkv_out = self.linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads,
self.config.transposed_mode)
return qkv_out
@staticmethod
def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
seqlen = [(min_seqlen + i)
for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
Fp16Matmul._read_autotune_table()
for N in seqlen:
A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
matmul(A, B)
Fp16Matmul._update_autotune_table()
......@@ -19,7 +19,9 @@ class MLPGemmOp(BaseOp):
super(MLPGemmOp, self).__init__(config)
try:
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
if self.config.dtype in [
torch.float16, torch.int8
]: # non-triton cuda kernel has a higher performance in MLP than mlp_gemm_func in triton.ops
self.mlp_gemm_func = self.inference_module.mlp_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.mlp_gemm_func = self.inference_module.mlp_gemm_bf16
......
......@@ -8,6 +8,7 @@ import torch
import torch.nn.functional as F
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed
from deepspeed.utils.types import NormType
......@@ -18,7 +19,14 @@ class QKVGemmOp(BaseOp):
try:
if self.config.norm_type == NormType.LayerNorm:
if self.config.dtype in [torch.float16, torch.int8]:
self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16 # type: ignore
if deepspeed.HAS_TRITON and self.config.use_triton and self.config.dtype == torch.float16:
from deepspeed.ops.transformer.inference.triton.ops import qkv_gemm_func as _triton_qkv_gemm_func
self.qkv_gemm_func = _triton_qkv_gemm_func
triton_autotune = config.triton_autotune and config.layer_id == 0
if triton_autotune:
__class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
else:
self.qkv_gemm_func = self.inference_module.qkv_gemm_fp16 # type: ignore
elif self.config.dtype == torch.bfloat16:
self.qkv_gemm_func = self.inference_module.qkv_gemm_bf16
else:
......@@ -36,6 +44,18 @@ class QKVGemmOp(BaseOp):
elif self.config.norm_type == NormType.RMSNorm:
self.qkv_gemm_func = self.rms_qkv_gemm_fallback
@staticmethod
def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
seqlen = [(min_seqlen + i)
for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
Fp16Matmul._read_autotune_table()
for N in seqlen:
A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
B = torch.randn((hidden_size, 3 * hidden_size), dtype=dtype, device='cuda')
matmul(A, B)
Fp16Matmul._update_autotune_table()
def qkv_gemm_fallback(self, input, weight, q_scale, bias, gamma, beta, eps, add_bias, q_int8, transpose):
if os.environ.get('DS_KI_FALLBACK') == 'True' and not transpose:
inp_norm = F.layer_norm(input, (input.shape[2], ), gamma, beta, eps)
......
......@@ -7,6 +7,7 @@ import os
import torch
from ..config import DeepSpeedInferenceConfig
from .base import BaseOp
import deepspeed
class VectorMatMulOp(BaseOp):
......@@ -14,7 +15,16 @@ class VectorMatMulOp(BaseOp):
def __init__(self, config: DeepSpeedInferenceConfig):
super(VectorMatMulOp, self).__init__(config)
try:
if self.config.dtype in [torch.float16, torch.int8]:
if self.config.dtype == torch.float16:
if deepspeed.HAS_TRITON and config.use_triton:
from deepspeed.ops.transformer.inference.triton.ops import vector_matmul_func as _triton_vector_matmul_func
self.vector_matmul_func = _triton_vector_matmul_func
triton_autotune = config.triton_autotune and config.layer_id == 0
if triton_autotune:
__class__._triton_autotune(2, self.config.max_out_tokens, self.config.hidden_size)
else:
self.vector_matmul_func = self.inference_module.vector_matmul_fp16
elif self.config.dtype == torch.int8:
self.vector_matmul_func = self.inference_module.vector_matmul_fp16
elif self.config.dtype == torch.bfloat16:
self.vector_matmul_func = self.inference_module.vector_matmul_bf16
......@@ -34,3 +44,15 @@ class VectorMatMulOp(BaseOp):
q_int8 = self.config.dtype == torch.int8
output = self.vector_matmul_func(input, weight, async_op, q_scale, q_int8, self.config.transposed_mode)
return output
@staticmethod
def _triton_autotune(min_seqlen, max_seqlen, hidden_size, dtype=torch.float16):
from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, matmul
seqlen = [(min_seqlen + i)
for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
Fp16Matmul._read_autotune_table()
for N in seqlen:
A = torch.randn((N, hidden_size), dtype=dtype, device='cuda')
B = torch.randn((hidden_size, hidden_size), dtype=dtype, device='cuda')
matmul(A, B)
Fp16Matmul._update_autotune_table()
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .residual_add import residual_add_bias
from .layer_norm import layer_norm, layer_norm_residual
from .gelu import gelu
from .softmax import softmax
from .ops import *
from .matmul_ext import fp16_matmul, matmul_4d, score_4d_matmul, context_4d_matmul
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import math
import torch
import torch.nn as nn
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
from deepspeed.ops.transformer.inference.op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp
from deepspeed.ops.transformer.inference.triton import (
softmax,
score_4d_matmul,
context_4d_matmul,
)
minus_inf = -10000.0
class TritonSelfAttention(nn.Module):
num_layers = 0
def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, qkv_merging=False):
super(TritonSelfAttention, self).__init__()
self.config = config
data_type = self.config.dtype
data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
assert data_type_fp == torch.half, "Triton attention supports only fp16"
self.config.layer_id = TritonSelfAttention.num_layers
TritonSelfAttention.num_layers = TritonSelfAttention.num_layers + 1
device = get_accelerator().current_device_name() #if config.bigscience_bloom else 'cpu'
assert config.mp_size == 1, "mp_size has to be 1 for Triton attention for now"
if self.config.set_empty_params:
self.attn_qw = None
self.attn_qb = None
self.attn_kw = None
self.attn_kb = None
self.attn_vw = None
self.attn_vb = None
self.attn_qkvw = None
self.attn_qkvb = None
self.attn_ow = None
self.attn_ob = None
else:
qkv_size_per_partition = (self.config.hidden_size // self.config.mp_size) * 3
self.attn_qkvw = nn.Parameter(torch.empty(self.config.hidden_size,
qkv_size_per_partition,
dtype=data_type,
device=device),
requires_grad=False)
self.attn_qkvb = nn.Parameter(torch.empty(qkv_size_per_partition, dtype=data_type_fp, device=device),
requires_grad=False)
# self-output weights
out_size_per_partition = self.config.hidden_size // self.config.mp_size
self.attn_ow = nn.Parameter(torch.empty(out_size_per_partition,
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.attn_ob = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
self.num_attention_heads_per_partition = self.config.heads // self.config.mp_size
self.hidden_size_per_partition = self.config.hidden_size // self.config.mp_size
self.hidden_size_per_attention_head = self.config.hidden_size // self.config.heads
self.mp_group = mp_group
self.use_flash = False
# used for quantization
self.q_scales = q_scales
self.q_groups = q_groups
self.merge_count = int(math.log2(merge_count))
self.norm_factor = math.sqrt(self.config.hidden_size // self.config.heads)
if not config.use_mup:
self.norm_factor = math.sqrt(self.norm_factor)
if self.config.scale_attn_by_inverse_layer_idx is True:
self.norm_factor *= math.sqrt(self.config.layer_id + 1)
# https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/gpt2/modeling_gpt2.py#L191
triton_autotune = self.config.triton_autotune and self.config.layer_id == 0
self.qkv_func = QKVGemmOp(config)
self.score_context_func = SoftmaxContextOp(config)
self.linear_func = LinearOp(config)
self.vector_matmul_func = VectorMatMulOp(config)
self.hidden_size = config.hidden_size
self.head_size = config.hidden_size // config.heads
self.scale = (1 / self.norm_factor / self.norm_factor if self.config.scale_attention else 1.0
) # making it back to 1/sqrt(head_size)
self.triangular_masking = self.config.triangular_masking
# triton autotune table update for score/context matmul
if triton_autotune:
print(f"running triton autotune for attention")
__class__._triton_autotune(2, self.config.max_out_tokens, self.head_size, self.config.hidden_size,
self.triangular_masking, self.scale)
@staticmethod
def _triton_autotune(min_seqlen,
max_seqlen,
head_size,
hidden_size,
triangular_masking,
scale,
dtype=torch.float16):
from deepspeed.ops.transformer.inference.triton.matmul_ext import Fp16Matmul, score_4d_matmul, context_4d_matmul
seqlen = [(min_seqlen + i)
for i in range(0, max_seqlen - min_seqlen + Fp16Matmul._cache_stride + 1, Fp16Matmul._cache_stride)]
Fp16Matmul._read_autotune_table()
for N in seqlen:
qkv = torch.randn((1, N, 3 * hidden_size), dtype=dtype, device='cuda')
output = score_4d_matmul(qkv, head_size, triangular_masking, scale)
context_4d_matmul(output, qkv, head_size)
Fp16Matmul._update_autotune_table()
def ds_compute_attention(self, qkv_out, input_mask, layer_past, alibi):
if isinstance(qkv_out, list):
qkv_out = qkv_out[0]
no_masking = input_mask is None
if no_masking:
input_mask = torch.empty(1)
attn_key_value = self.score_context_func(
query_key_value=qkv_out,
attn_mask=((1 - input_mask).to(qkv_out.dtype) *
minus_inf) if input_mask.dtype == torch.int64 else input_mask,
heads=self.num_attention_heads_per_partition,
norm_factor=(1 / self.norm_factor if self.config.scale_attention else 1.0),
no_masking=no_masking,
layer_id=self.config.layer_id,
num_layers=TritonSelfAttention.num_layers,
alibi=alibi)
context_layer, key_layer, value_layer = attn_key_value
return context_layer, key_layer, value_layer
def forward(
self,
input,
input_mask,
head_mask=None,
layer_past=None,
get_present=False, # not used
encoder_hidden_states=None, # not used
encoder_attention_mask=None, # not used
output_attentions=False, # not used
norm_w=None,
norm_b=None,
alibi=None,
use_triton_attention=True):
if not self.config.pre_layer_norm:
qkv_out = self.linear_func(input=input,
weight=self.attn_qkvw,
bias=self.attn_qkvb,
add_bias=self.attn_qkvb is not None,
do_flash_attn=False,
num_heads=self.num_attention_heads_per_partition,
num_layers=TritonSelfAttention.num_layers)
qkv = qkv_out
else:
qkv_out = self.qkv_func(input=input,
weight=self.attn_qkvw,
bias=(self.attn_qkvb if self.attn_qkvb is not None else norm_b),
gamma=norm_w,
beta=norm_b)
qkv = qkv_out[0]
if use_triton_attention and (alibi is None):
context_layer = compute_attention(qkv=qkv,
input_mask=input_mask,
scale=self.scale,
layer_past=layer_past,
alibi=alibi,
head_size=self.head_size,
use_triton_flash=self.use_flash,
use_cuda_flash=False,
triangular=self.triangular_masking)
key_layer, value_layer = qkv[:, :, self.hidden_size:2 * self.hidden_size], qkv[:, :, 2 * self.hidden_size:]
else:
context_layer, key_layer, value_layer = self.ds_compute_attention(qkv_out=qkv_out,
input_mask=input_mask,
layer_past=layer_past,
alibi=alibi)
output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)
inp_norm = qkv_out[-1]
if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
dist.all_reduce(output, group=self.mp_group)
return (output, key_layer, value_layer, context_layer, inp_norm)
global inference_module
def compute_attention(qkv,
input_mask,
layer_past,
alibi,
scale,
head_size,
triangular=False,
use_cuda_flash=False,
use_triton_flash=False,
use_ds_attention=False):
if isinstance(qkv, list):
qkv = qkv[0]
#assert layer_past is None, "layer_past not supported in triton yet"
assert alibi is None, "alibi is not supported with Triton attention yet"
output = score_4d_matmul(qkv, head_size, triangular, scale)
if triangular:
output = softmax(output)
else:
output = softmax(output, input_mask)
output = context_4d_matmul(output, qkv, head_size)
return output
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import triton
import triton.language as tl
@triton.jit
def gelu_functor(x):
# Using approximation introduces greater parity errors.
# return tl.sigmoid(1.702 * x) * x
return x * 0.5 * (1.0 + tl.libdevice.erf(x / 1.41421356237))
@triton.jit
def gelu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
output = gelu_functor(x)
tl.store(output_ptr + offsets, output, mask=mask)
def gelu(activations: torch.Tensor) -> torch.Tensor:
assert activations.is_contiguous()
assert activations.is_cuda
output = torch.empty_like(activations)
n_elements = output.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
gelu_kernel[grid](activations, output, n_elements, BLOCK_SIZE=1024)
return output
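# Illustrative usage sketch (hypothetical, not part of this file): the Triton GELU
# above should closely match PyTorch's exact erf-based GELU on a CUDA fp16 tensor.
def _demo_gelu_parity():
    x = torch.randn(4096, dtype=torch.float16, device="cuda")
    y_triton = gelu(x)                                    # Triton kernel above
    y_torch = torch.nn.functional.gelu(x.float()).half()  # exact erf-based reference
    return torch.allclose(y_triton, y_torch, atol=1e-2, rtol=1e-2)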
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import triton
import triton.language as tl
'''
layer-normalization
modified the triton kernel in
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/05-layer-norm.py
'''
@triton.jit
def layer_norm_kernel(
Out,
A,
Weight,
Bias,
stride,
N,
eps,
BLOCK_SIZE: tl.constexpr,
):
# position of elements processed by this program
row = tl.program_id(0)
Out += row * stride
A += row * stride
# compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.0)
_var += a * a
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(A + cols, mask=mask, other=0.0).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# write-back
tl.store(Out + cols, out, mask=mask)
@triton.jit
def layer_norm_residual_kernel(
Out,
A,
Residual,
ln_input,
Weight,
Bias,
stride,
N,
eps,
BLOCK_SIZE: tl.constexpr,
):
# position of elements processed by this program
row = tl.program_id(0)
Out += row * stride
A += row * stride
Residual += row * stride
ln_input += row * stride
# compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)
res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32)
a = a + res
tl.store(ln_input + cols, a, mask=cols < N)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.0)
_var += a * a
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# write-back
tl.store(Out + cols, out, mask=mask)
@triton.jit
def layer_norm_residual_bias_kernel(
Out,
A,
Residual,
InputBias,
ln_input,
Weight,
Bias,
stride,
N,
eps,
BLOCK_SIZE: tl.constexpr,
):
# position of elements processed by this program
row = tl.program_id(0)
Out += row * stride
A += row * stride
Residual += row * stride
ln_input += row * stride
# compute mean
mean = 0
_mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(A + cols, mask=cols < N, other=0.0).to(tl.float32)
res = tl.load(Residual + cols, mask=cols < N, other=0.0).to(tl.float32)
b = tl.load(InputBias + cols, mask=cols < N, other=0.0).to(tl.float32)
a = a + b + res
tl.store(ln_input + cols, a, mask=cols < N)
_mean += a
mean = tl.sum(_mean, axis=0) / N
# compute variance
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
a = tl.load(ln_input + cols, mask=cols < N, other=0.0).to(tl.float32)
a = tl.where(cols < N, a - mean, 0.0)
_var += a * a
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
# multiply by weight and add bias
for off in range(0, N, BLOCK_SIZE):
cols = off + tl.arange(0, BLOCK_SIZE)
mask = cols < N
weight = tl.load(Weight + cols, mask=mask)
bias = tl.load(Bias + cols, mask=mask)
a = tl.load(ln_input + cols, mask=mask, other=0.0).to(tl.float32)
a_hat = (a - mean) * rstd
out = a_hat * weight + bias
# write-back
tl.store(Out + cols, out, mask=mask)
def layer_norm(a, weight, bias, eps):
assert a.is_contiguous()
assert weight.is_contiguous()
assert bias.is_contiguous()
# allocate output
out = torch.empty_like(a)
# reshape input data into 2D tensor
a_arg = a.view(-1, a.shape[-1])
M, N = a_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
layer_norm_kernel[(M, )](
out,
a_arg,
weight,
bias,
a_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return out
def layer_norm_residual(a, input_bias, residual, weight, bias, eps):
assert a.is_contiguous()
assert weight.is_contiguous()
assert bias.is_contiguous()
assert residual.is_contiguous()
# allocate output and scratch-pad for residual addition
out = torch.empty_like(a)
ln_input = torch.empty_like(a)
# reshape input data into 2D tensor
a_arg = a.view(-1, a.shape[-1])
residual = residual.view(-1, residual.shape[-1])
M, N = a_arg.shape
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // a.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
BLOCK_SIZE = max(BLOCK_SIZE, 128)
BLOCK_SIZE = min(BLOCK_SIZE, 4096)
BLOCK_SIZE = BLOCK_SIZE if N <= 4096 else 8192
# heuristics for number of warps
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
if input_bias is None:
layer_norm_residual_kernel[(M, )](
out,
a_arg,
residual,
ln_input,
weight,
bias,
a_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
else:
layer_norm_residual_bias_kernel[(M, )](
out,
a_arg,
residual,
input_bias,
ln_input,
weight,
bias,
a_arg.stride(0),
N,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return out
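# Illustrative parity sketch (hypothetical, not part of this file): the fused Triton
# layer norm should agree with torch.nn.functional.layer_norm within fp16 tolerance.
def _demo_layer_norm_parity(hidden=1024):
    a = torch.randn(8, 128, hidden, dtype=torch.float16, device="cuda")
    w = torch.randn(hidden, dtype=torch.float16, device="cuda")
    b = torch.randn(hidden, dtype=torch.float16, device="cuda")
    ref = torch.nn.functional.layer_norm(a.float(), (hidden, ), w.float(), b.float(), 1e-5).half()
    out = layer_norm(a, w, b, 1e-5)
    return torch.allclose(out, ref, atol=1e-2, rtol=1e-2)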
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import triton
import os
from filelock import FileLock
import deepspeed.ops.transformer.inference.triton.triton_matmul_kernel as triton_matmul_kernel
import pickle
from io import open
# -----------------------------------------------------------------------------
# util class/functions for triton
def _default_cache_dir():
return os.path.join(os.environ["HOME"], ".triton", "autotune")
def bias_add_activation(C, bias=None, activation=""):
if bias is not None:
C += bias
# activation
if activation == "relu":
relu = torch.nn.ReLU()
C = relu(C)
elif activation == "leaky_relu":
leaky_relu = torch.nn.LeakyReLU(0.01)
C = leaky_relu(C)
elif activation == "gelu":
sigmoid = torch.nn.Sigmoid()
C = sigmoid(1.702 * C) * C
elif activation == "sigmoid":
sigmoid = torch.nn.Sigmoid()
C = sigmoid(C)
return C
class AutotuneCacheManager:
"""
Cache manager for autotune
"""
def __init__(self, key):
self.key = key
self.file_path = None
self.lock_path = None
# if caching is enabled, get the lock and bin path
self.cache_dir = os.environ.get('TRITON_CACHE_DIR', _default_cache_dir())
if self.cache_dir:
os.makedirs(self.cache_dir, exist_ok=True)
if self.cache_dir:
self.file_path = os.path.join(self.cache_dir, self.key + ".pickle")
self.lock_path = self.file_path + ".lock"
def has_file(self):
return self.file_path and os.path.exists(self.file_path)
def put(self, table):
if self.file_path:
assert self.lock_path is not None
with FileLock(self.lock_path):
with open(self.file_path + ".tmp", 'wb') as handle:
pickle.dump(table, handle)
os.rename(self.file_path + ".tmp", self.file_path)
def load(self):
if os.path.exists(self.file_path):
with open(self.file_path, 'rb') as handle:
loaded_dict = pickle.load(handle)
return loaded_dict
else:
return None
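# Illustrative usage sketch (hypothetical, not part of this file): round-trip a small
# autotune table through the cache; "example_kernel" is a made-up cache key and the
# table contents are placeholders.
def _demo_autotune_cache_roundtrip():
    manager = AutotuneCacheManager("example_kernel")
    manager.put({"example_config": "BLOCK_M=128,BLOCK_N=128"})  # stored under TRITON_CACHE_DIR
    assert manager.has_file()
    return manager.load()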
# -----------------------------------------------------------------------------
# triton matmul class
class MatmulExt(torch.autograd.Function):
"""
a wrapper class that can call different triton matmul kernels depending on the input parameters
"""
@staticmethod
def forward(A, B, bias=None, activation="", use_triton=True, update_autotune_table=False):
"""
A: input, activation matrix A
B: input, weight matrix B
"""
matmul = None
quantize_activation = False
Batch = 0
if len(A.shape) == 3: # if A is 3d-tensor where batch index is given as 0-axis
assert A.is_contiguous(), "matrix A must be contiguous"
Batch, M, K = A.shape
A = A.view(-1, K)
# fp16 activation and fp16 weight matmul into fp16 output
matmul = fp16_matmul
C = matmul.forward(A, B, use_triton=use_triton, bias=bias, activation=activation)
if matmul and update_autotune_table:
matmul._update_autotune_table()
if Batch > 0:
C = C.view(Batch, M, -1)
return C
class TritonMatmul(torch.autograd.Function):
"""
triton matmul kernel superclass
"""
def __init__(self):
pass
@staticmethod
def _ref_forward(A, B, ref_dtype=torch.float32):
C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
return C
@staticmethod
def _read_autotune_table(cache_key, triton_kernel):
cache_manager = AutotuneCacheManager(cache_key)
table = cache_manager.load()
if table:
triton_kernel.cache = table
@staticmethod
def _write_autotune_table(cache_key, triton_kernel):
cache_manager = AutotuneCacheManager(cache_key)
cache_manager.put(triton_kernel.cache)
@staticmethod
def _update_autotune_table(cache_key, triton_kernel):
cache_manager = AutotuneCacheManager(cache_key)
autotune_table = cache_manager.load()
if autotune_table is None:
autotune_table = dict()
autotune_table.update(triton_kernel.cache) # always overwrite with the new autotune results
cache_manager = AutotuneCacheManager(cache_key)
cache_manager.put(autotune_table)
@staticmethod
def forward(
A,
B,
ref_dtype=torch.float32, # fp32 only
bias=None,
activation=""):
C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
C = bias_add_activation(C, bias, activation)
return C
class Fp16Matmul(TritonMatmul):
"""
fp16 matrix multiplication kernel
dtypes: fp16 x fp16 = fp16
"""
_2d_kernel = triton_matmul_kernel._fp_matmul
_4d_kernel = triton_matmul_kernel.matmul_4d_kernel
_cache_stride = 32
def __init__(self, read_cache=True):
super().__init__()
if read_cache:
__class__._read_autotune_table()
def skip_autotune(self):
__class__._2d_kernel.configs = [__class__._2d_kernel.configs[0]]
__class__._4d_kernel.configs = [__class__._4d_kernel.configs[0]]
@staticmethod
def forward(A, B, use_triton=True, bias=None, activation=""):
if use_triton:
device = A.device
# handle non-contiguous inputs if necessary
if A.stride(0) > 1 and A.stride(1) > 1:
A = A.contiguous()
if B.stride(0) > 1 and B.stride(1) > 1:
B = B.contiguous()
# checks constraints
assert A.shape[1] == B.shape[0], "incompatible dimensions"
M, K = A.shape
_, N = B.shape
# allocates output
C = torch.empty((M, N), device=device, dtype=A.dtype)
# accumulator types
ACC_TYPE = triton.language.float32 if A.dtype in [torch.float16, torch.bfloat16, torch.float32
] else triton.language.int32
# launch kernel
grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
__class__._2d_kernel[grid](A,
B,
C,
M,
N,
K,
bias,
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(1),
C.stride(0),
C.stride(1),
M // __class__._cache_stride,
N // __class__._cache_stride,
K // __class__._cache_stride,
GROUP_M=8,
ACC_TYPE=ACC_TYPE,
BIAS_ADD=(0 if bias is None else 1),
ACTIVATION=activation)
else:
C = torch.matmul(A, B)
return C
@staticmethod
def _matmul_4d(a, b):
assert a.shape[-1] == b.shape[-2], "incompatible dimensions"
assert a.is_contiguous(), "matrix A must be contiguous"
assert b.is_contiguous(), "matrix B must be contiguous"
B, H, M, K = a.shape
B, H, K, N = b.shape
assert K > 1, "inner-product dimension K should be larger than 1"
c = torch.empty((B, H, M, N), device=a.device, dtype=a.dtype)
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),
H,
B,
)
__class__._4d_kernel[grid](
a,
b,
c,
M,
N,
K,
M // __class__._cache_stride,
N // __class__._cache_stride,
K // __class__._cache_stride,
a.stride(0),
a.stride(1),
a.stride(2),
a.stride(3),
b.stride(0),
b.stride(1),
b.stride(2),
b.stride(3),
c.stride(0),
c.stride(1),
c.stride(2),
c.stride(3),
scale=-1.0,
MASK=False,
)
return c
@staticmethod
def _score_4d_matmul(input, head_size, input_mask, scale=-1.0):
assert input.is_contiguous(), "matrix input must be contiguous"
batches = input.shape[0]
d_model = input.shape[-1] // 3
num_of_heads = d_model // head_size
q = input[:, :, :d_model]
k = input[:, :, d_model:d_model * 2]
q = q.view(batches, -1, num_of_heads, head_size)
k = k.view(batches, -1, num_of_heads, head_size)
# checks constraints
assert q.shape == k.shape, "incompatible dimensions"
B, M, H, K = q.shape
B, N, H, K = k.shape
assert K > 1, "inner-product dimension K should be larger than 1"
# allocates output
output = torch.empty((B, H, M, N), device=q.device, dtype=q.dtype)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
H,
B,
)
__class__._4d_kernel[grid](
q,
k,
output,
M,
N,
K,
M // __class__._cache_stride,
N // __class__._cache_stride,
K // __class__._cache_stride,
q.stride(0),
q.stride(2),
q.stride(1),
q.stride(3),
k.stride(0),
k.stride(2),
k.stride(3),
k.stride(1),
output.stride(0),
output.stride(1),
output.stride(2),
output.stride(3),
scale=scale,
MASK=False,
)
return output
@staticmethod
def _context_4d_matmul(prob, input, head_size):
assert prob.is_contiguous(), "matrix prob must be contiguous"
assert input.is_contiguous(), "matrix input must be contiguous"
batches = input.shape[0]
d_model = input.shape[-1] // 3
num_of_heads = d_model // head_size
v = input[:, :, d_model * 2:]
v = v.view(batches, -1, num_of_heads, head_size)
# checks constraints
assert (prob.shape[0] == v.shape[0] and prob.shape[1] == v.shape[2] and prob.shape[2] == v.shape[1]
and prob.shape[3] == v.shape[1]), "incompatible dimensions"
B, H, M, K = prob.shape
B, K, H, N = v.shape
assert K > 1, "inner-product dimension K should be larger than 1"
# allocates output
output = torch.empty((B, M, H, N), device=v.device, dtype=v.dtype)
grid = lambda META: (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
H,
B,
)
__class__._4d_kernel[grid](
prob,
v,
output,
M,
N,
K,
M // __class__._cache_stride,
N // __class__._cache_stride,
K // __class__._cache_stride,
prob.stride(0),
prob.stride(1),
prob.stride(2),
prob.stride(3),
v.stride(0),
v.stride(2),
v.stride(1),
v.stride(3),
# Here we also transpose the output when writing to memory.
output.stride(0),
output.stride(2),
output.stride(1),
output.stride(3),
scale=-1,
MASK=False,
)
return output.view(batches, -1, d_model)
@staticmethod
def _ref_forward(A, B, ref_dtype=torch.float32, bias=None, activation=""):
C = torch.matmul(A.type(ref_dtype), B.type(ref_dtype))
C = bias_add_activation(C, bias, activation)
return C
@staticmethod
def _check_parity(A,
B,
output_dtype,
SA=None,
SB=None,
qblock_size=None,
ref_dtype=torch.float32,
tol=0.01,
use_triton=True,
bias=None,
activation=""):
torch_output = __class__._ref_forward(A, B, ref_dtype=ref_dtype, bias=bias, activation=activation)
triton_output = __class__.forward(A, B, use_triton=use_triton, bias=bias, activation=activation)
assert triton.testing.allclose(triton_output.cpu().type(torch_output.dtype), torch_output.cpu(), tol=tol)
print(f"{__class__.__name__}: PASSed the parity check")
return triton_output, torch_output
@staticmethod
def _read_autotune_table():
TritonMatmul._read_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
TritonMatmul._read_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)
@staticmethod
def _write_autotune_table():
TritonMatmul._write_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
TritonMatmul._write_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)
@staticmethod
def _update_autotune_table():
TritonMatmul._update_autotune_table(__class__.__name__ + "_2d_kernel", __class__._2d_kernel)
TritonMatmul._update_autotune_table(__class__.__name__ + "_4d_kernel", __class__._4d_kernel)
# -----------------------------------------------------------------------------
# mapping
matmul = MatmulExt.forward
fp16_matmul = Fp16Matmul()
matmul_4d = fp16_matmul._matmul_4d
score_4d_matmul = fp16_matmul._score_4d_matmul
context_4d_matmul = fp16_matmul._context_4d_matmul
#
import atexit
@atexit.register
def matmul_ext_update_autotune_table():
fp16_matmul._update_autotune_table()
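# Illustrative usage sketch (hypothetical, not part of this file): the module-level
# `matmul` wrapper above dispatches fp16 inputs (2D or batched 3D) to the Triton
# kernel and should match cuBLAS within fp16 tolerance.
def _demo_fp16_matmul_parity():
    A = torch.randn(4, 128, 256, dtype=torch.float16, device="cuda")  # (batch, M, K)
    B = torch.randn(256, 512, dtype=torch.float16, device="cuda")     # (K, N)
    C = matmul(A, B)            # Triton fp16 matmul -> shape (4, 128, 512)
    ref = torch.matmul(A, B)    # cuBLAS reference
    return torch.allclose(C, ref, atol=1e-1, rtol=1e-2)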
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import math
import torch.nn as nn
from deepspeed.accelerator import get_accelerator
from deepspeed import comm as dist
from ..op_binding import MLPGemmOp, VectorMatMulOp, GELUGemmOp, ResidualAddOp
class TritonMLP(nn.Module):
def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count=1, mlp_extra_grouping=False):
super(TritonMLP, self).__init__()
self.config = config
data_type = self.config.dtype
data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
device = get_accelerator().current_device_name()
self.attn_nw = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
self.attn_nb = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
intm_size_per_partition = self.config.intermediate_size // self.config.mp_size
self.inter_w = nn.Parameter(torch.empty(self.config.hidden_size,
intm_size_per_partition,
dtype=data_type,
device=device),
requires_grad=False)
self.inter_b = nn.Parameter(torch.empty(intm_size_per_partition, dtype=data_type_fp, device=device),
requires_grad=False)
self.output_w = nn.Parameter(torch.empty(intm_size_per_partition,
self.config.hidden_size,
dtype=data_type,
device=device),
requires_grad=False)
self.output_b = nn.Parameter(torch.empty(self.config.hidden_size, dtype=data_type_fp, device=device),
requires_grad=False)
# used for quantization
self.q_scales = q_scales
self.q_groups = q_groups * 2 if mlp_extra_grouping else q_groups
self.merge_count = int(math.log2(merge_count))
self.mp_group = mp_group
self.mlp_gemm_func = MLPGemmOp(config)
self.vector_matmul_func = VectorMatMulOp(config)
self.fused_gemm_gelu = GELUGemmOp(config)
self.residual_add_func = ResidualAddOp(config)
def forward(self, input, residual, residual_norm, bias):
residual_add = None
if self.attn_nw is None:
output = self.fused_gemm_gelu(input=residual_norm,
weight=self.inter_w,
bias=self.inter_b,
weight_out=self.output_w)
else:
output, residual_add = self.mlp_gemm_func(input=input,
residual=residual,
input_bias=bias,
weight_interm=self.inter_w,
weight_out=self.output_w,
bias=self.inter_b,
gamma=self.attn_nw,
beta=self.attn_nb)
residual = self.residual_add_func(hidden_state=output,
residual=residual,
attention_output=input,
attention_bias=bias if bias is not None else self.output_b,
final_bias=self.output_b,
add_bias=bias is not None,
residual_add=residual_add)
if self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
dist.all_reduce(residual, group=self.mp_group)
return residual
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
import deepspeed.ops.transformer.inference.triton.matmul_ext as matmul_ext
from deepspeed.ops.transformer.inference.triton.layer_norm import layer_norm, layer_norm_residual
inference_module = None
def vector_matmul_func(input, weight, async_op, q_scale, q_int8, transposed_mode):
assert not transposed_mode and not async_op and not q_int8
return matmul_ext.matmul(input, weight, bias=None, activation="", use_triton=True)
def fused_gemm_gelu(input,
weight,
weight_scale,
bias,
weight_out,
weight_out_scale,
epsilon,
pre_layer_norm,
q_int8,
async_op,
transposed_mode,
use_triton_ln=True):
assert not transposed_mode
# activation
activation = "gelu"
# intermediate fc in FF
intm_out = matmul_ext.matmul(input, weight, bias=bias, activation=activation, use_triton=True)
# output fc in FF
ff_out = matmul_ext.matmul(
intm_out,
weight_out,
bias=None,
activation="", # bias added layer with residual_add + bias + layerNorm layer
use_triton=True)
return ff_out
def linear_func(input, weight, bias, add_bias, do_flash_attn, num_heads, transposed_mode=False):
assert not transposed_mode and not do_flash_attn
qkv_out = matmul_ext.matmul(input, weight, bias=(bias if add_bias else None), activation="", use_triton=True)
return qkv_out
def mlp_gemm_func(input,
residual,
input_bias,
weight_interm,
weight_out,
bias,
gamma,
beta,
epsilon,
pre_layer_norm,
mlp_after_attn,
weight_interm_scale,
weight_out_scale,
q_int8,
mlp_act_func_type,
transposed_mode,
use_triton_ln=True):
assert not transposed_mode
# residual add and layerNorm after attention
if use_triton_ln:
mlp_input = layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)
else:
global inference_module
if inference_module is None:
inference_module = InferenceBuilder().load()
mlp_input = inference_module._layer_norm_residual(input, input_bias, residual, gamma, beta, epsilon)
# activation
if deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.GELU:
activation = "gelu"
elif deepspeed.utils.types.ActivationFuncType(mlp_act_func_type) == deepspeed.utils.types.ActivationFuncType.ReLU:
activation = "relu"
else:
activation = ""
# intermediate fc in FF
intm_out = matmul_ext.matmul(mlp_input, weight_interm, bias=bias, activation=activation, use_triton=True)
# output fc in FF
ff_out = matmul_ext.matmul(
intm_out,
weight_out,
bias=None,
activation="", # bias added layer with residual_add + bias + layerNorm layer
use_triton=True)
return ff_out, mlp_input
def qkv_gemm_func(
input,
weight,
q_scale,
bias,
gamma,
beta,
epsilon,
add_bias,
q_int8,
transposed_mode=False,
use_triton_ln=True,
):
assert not transposed_mode
# residual add and layerNorm after attention
if use_triton_ln:
qkv_input = layer_norm(input, gamma, beta, epsilon)
else:
global inference_module
if inference_module is None:
inference_module = InferenceBuilder().load()
qkv_input = inference_module.layer_norm(input, gamma, beta, epsilon)
qkv_out = matmul_ext.matmul(qkv_input, weight, bias=(bias if add_bias else None), activation="", use_triton=True)
return qkv_out, qkv_input
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import triton
import triton.language as tl
@triton.jit
def residual_add_bias_kernel(
hidden_state_ptr,
residual_ptr,
attn_output_ptr,
hidden_state_size,
attn_bias_ptr,
final_bias_ptr,
bias_size,
output_ptr,
mp_size: tl.constexpr,
mlp_after_attn: tl.constexpr,
pre_attn_norm: tl.constexpr,
add_attn_bias: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < hidden_state_size
bias_offsets = offsets % bias_size
bias_mask = bias_offsets < bias_size
tl_hidden_state = tl.load(hidden_state_ptr + offsets, mask=mask)
tl_residual = tl.load(residual_ptr + offsets, mask=mask)
tl_attn_output = tl.load(attn_output_ptr + offsets, mask=mask)
tl_attn_bias = tl.load(attn_bias_ptr + bias_offsets, mask=bias_mask)
tl_final_bias = tl.load(final_bias_ptr + bias_offsets, mask=bias_mask)
if mlp_after_attn:
if pre_attn_norm:
output = tl_hidden_state + (tl_residual + tl_final_bias + tl_attn_output + tl_attn_bias) / mp_size
else:
output = tl_hidden_state + tl_residual + tl_final_bias
else:
output = tl_hidden_state + tl_attn_output + (tl_residual + tl_final_bias) / mp_size
if add_attn_bias:
output += tl_attn_bias / mp_size
tl.store(output_ptr + offsets, output, mask=mask)
def residual_add_bias(hidden_state: torch.Tensor, residual: torch.Tensor, attn_output: torch.Tensor,
attn_bias: torch.Tensor, final_bias: torch.Tensor, mp_size: int, mlp_after_attn: bool,
add_attn_bias: bool, pre_attn_norm: bool):
# check that all tensors are on the same device
assert hidden_state.is_cuda and residual.is_cuda and attn_output.is_cuda \
and attn_bias.is_cuda and final_bias.is_cuda
# check that all tensors have the same dtype
assert hidden_state.dtype == residual.dtype == attn_output.dtype \
== attn_bias.dtype == final_bias.dtype
# check that all tensors have the right shape
assert hidden_state.shape == residual.shape == attn_output.shape
assert attn_bias.shape == final_bias.shape
assert attn_bias.shape[0] == hidden_state.shape[2]
output = torch.empty_like(hidden_state)
hidden_state_size = output.numel()
bias_size = attn_bias.numel()
grid = lambda meta: (triton.cdiv(hidden_state_size, meta['BLOCK_SIZE']), )
residual_add_bias_kernel[grid](hidden_state, residual, attn_output, hidden_state_size,\
attn_bias, final_bias, bias_size, output, mp_size, mlp_after_attn, pre_attn_norm, \
add_attn_bias, \
BLOCK_SIZE=1024)
return output
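# Illustrative usage sketch (hypothetical, not part of this file): hidden states of
# shape (batch, seq_len, hidden) with per-channel attention and final biases.
def _demo_residual_add_bias():
    batch, seq_len, hidden = 1, 128, 1024
    hs = torch.randn(batch, seq_len, hidden, dtype=torch.float16, device="cuda")
    res = torch.randn_like(hs)
    attn_out = torch.randn_like(hs)
    attn_bias = torch.randn(hidden, dtype=torch.float16, device="cuda")
    final_bias = torch.randn(hidden, dtype=torch.float16, device="cuda")
    out = residual_add_bias(hs, res, attn_out, attn_bias, final_bias,
                            mp_size=1, mlp_after_attn=True,
                            add_attn_bias=True, pre_attn_norm=True)
    return out.shape  # torch.Size([1, 128, 1024])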
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import triton
import triton.language as tl
'''
softmax
modified the triton kernel in
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py
'''
@triton.jit
def softmax_kernel(output_ptr, input_ptr, stride, n_cols, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
output_row_start_ptr = output_ptr + row_idx * stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
@triton.jit
def masked_softmax_kernel(output_ptr, input_ptr, stride, mask_ptr, mask_stride, n_cols, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
row_start_ptr = input_ptr + row_idx * stride
col_offsets = tl.arange(0, BLOCK_SIZE)
input_ptrs = row_start_ptr + col_offsets
mask_ptrs = mask_ptr + col_offsets + row_idx * mask_stride # mask_stride is 0 for 1d mask
row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32)
mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)
row_minus_max = row - tl.max(row, axis=0)
row_minus_max = row_minus_max + mask
numerator = tl.exp(row_minus_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
output_row_start_ptr = output_ptr + row_idx * stride
output_ptrs = output_row_start_ptr + col_offsets
tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)
def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:
assert input.is_contiguous()
assert (dim == -1) or (dim == len(input.shape) - 1), "Only dim=-1 is supported"
use_mask = False if mask is None else True
input_arg = input.view(-1, input.shape[-1])
n_rows, n_cols = input_arg.shape
BLOCK_SIZE = max(triton.next_power_of_2(n_cols), 2)
num_warps = 4
if BLOCK_SIZE >= 2048:
num_warps = 8
if BLOCK_SIZE >= 4096:
num_warps = 16
# Allocate output
output = torch.empty_like(input)
if use_mask:
assert mask.is_contiguous()
mask = mask.view(-1, mask.shape[-1])
mask_stride = mask.shape[-1] if mask.shape[-2] > 1 else 0
masked_softmax_kernel[(n_rows, )](
output,
input,
input_arg.stride(0),
mask,
mask_stride,
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
else:
softmax_kernel[(n_rows, )](
output,
input,
input_arg.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
return output
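# Illustrative parity sketch (hypothetical, not part of this file): the Triton softmax
# over the last dimension should match torch.softmax for an fp16 attention-score tensor.
def _demo_softmax_parity():
    scores = torch.randn(2, 16, 128, 128, dtype=torch.float16, device="cuda")
    out = softmax(scores)                              # no mask
    ref = torch.softmax(scores.float(), dim=-1).half()
    return torch.allclose(out, ref, atol=1e-2, rtol=1e-2)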
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import triton
import triton.language as tl
from .gelu import gelu_functor
import torch
AUTOTUNE_TOP_K = 10
SKIP_AUTOTUNE = False
def _fp16_matmul_prune_config(configs, named_args, skip_autotune=SKIP_AUTOTUNE):
if skip_autotune:
configs = [configs[0]]
else:
configs = triton.ops.matmul_perf_model.early_config_prune(configs, named_args)
return configs
"""
fp16 matmul implementation is adapted from triton matmul:
https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/triton/ops/matmul.py
"""
@triton.autotune(
configs=[
# basic configs for compute-bound matmuls
triton.Config({
'BLOCK_M': 128,
'BLOCK_N': 256,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=3, num_warps=8),
triton.Config({
'BLOCK_M': 256,
'BLOCK_N': 128,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=3, num_warps=8),
triton.Config({
'BLOCK_M': 256,
'BLOCK_N': 64,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 256,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_M': 128,
'BLOCK_N': 128,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_M': 128,
'BLOCK_N': 64,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 128,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_M': 128,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=4, num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 32,
'BLOCK_K': 32,
'SPLIT_K': 1
}, num_stages=5, num_warps=2),
],
key=['CACHE_M', 'CACHE_N', 'CACHE_K'],
prune_configs_by={
'early_config_prune': _fp16_matmul_prune_config,
'perf_model': None,
'top_k': AUTOTUNE_TOP_K
},
)
@triton.heuristics({
'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
})
@triton.jit
def _fp_matmul(
A,
B,
C,
M,
N,
K,
bias,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
CACHE_M,
CACHE_N,
CACHE_K,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_M: tl.constexpr,
SPLIT_K: tl.constexpr,
EVEN_K: tl.constexpr,
ACC_TYPE: tl.constexpr,
BIAS_ADD: tl.constexpr,
ACTIVATION: tl.constexpr,
):
# matrix multiplication
pid = tl.program_id(0)
pid_z = tl.program_id(1)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
# do matrix multiplication
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
# pointers
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K * SPLIT_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b)
A += BLOCK_K * SPLIT_K * stride_ak
B += BLOCK_K * SPLIT_K * stride_bk
# bias addition
if BIAS_ADD:
bias_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
bias_ptr = bias + bias_offset
b = tl.load(bias_ptr, mask=bias_offset < N)
acc = acc + b[None, :]
# activation
if ACTIVATION == "relu":
acc = tl.where(acc >= 0, acc, 0)
elif ACTIVATION == "leaky_relu":
acc = tl.where(acc >= 0, acc, 0.01 * acc)
elif ACTIVATION == "gelu":
#acc = tl.sigmoid(1.702 * acc) * acc
acc = gelu_functor(acc)
elif ACTIVATION == "sigmoid":
acc = tl.sigmoid(acc) # sigmoid
acc = acc.to(C.dtype.element_ty)
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
mask = (rm < M)[:, None] & (rn < N)[None, :]
# handles write-back with reduction-splitting
if SPLIT_K == 1:
tl.store(C, acc, mask=mask)
else:
tl.atomic_add(C, acc, mask=mask)
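The "re-order program ID for better L2 performance" step at the top of `_fp_matmul` groups `GROUP_M` rows of output tiles so that consecutive programs touch the same A rows and B columns. A standalone plain-Python sketch of that swizzle, with hypothetical grid sizes, purely to visualize the traversal order:

```python
# Illustration of the grouped program-id swizzle used in _fp_matmul.
# With GROUP_M = 2 on a 4 x 4 tile grid, consecutive pids sweep a 2-row band
# across all columns before moving to the next band (better L2 reuse of tiles).
GROUP_M, grid_m, grid_n = 2, 4, 4
width = GROUP_M * grid_n
for pid in range(grid_m * grid_n):
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // group_size
    print(pid, "->", (pid_m, pid_n))
# pids 0..7 cover rows {0, 1} x cols 0..3, pids 8..15 cover rows {2, 3} x cols 0..3
```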
def matmul_4d_prune_config(configs, named_args, skip_autotune=SKIP_AUTOTUNE):
if skip_autotune:
configs = [configs[0]]
else:
device = torch.cuda.current_device() #ignore-cuda
capability = torch.cuda.get_device_capability() #ignore-cuda
# BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, num_warps, num_stages
dtsize = named_args['a_ptr'].element_size()
dtype = named_args['a_ptr'].dtype
# make sure we have enough smem
pruned_configs = []
for config in configs:
kw = config.kwargs
BLOCK_M, BLOCK_N, BLOCK_K, num_stages = \
kw['BLOCK_SIZE_M'], kw['BLOCK_SIZE_N'], kw['BLOCK_SIZE_K'], config.num_stages
triton.compiler.init_cuda_utils()
max_shared_memory = triton.compiler.cuda_utils.get_device_properties(device)["max_shared_mem"]
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
if required_shared_memory <= max_shared_memory:
pruned_configs.append(config)
configs = pruned_configs
return configs
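The shared-memory check in `matmul_4d_prune_config` follows the usual pipelined-tiling estimate: each stage holds one `BLOCK_M x BLOCK_K` tile of A and one `BLOCK_K x BLOCK_N` tile of B. A worked example of that arithmetic for one of the fp16 configs below (numbers for illustration only):

```python
# Shared-memory estimate used by matmul_4d_prune_config, evaluated by hand for
# the BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=32, num_stages=4 config:
BLOCK_M, BLOCK_N, BLOCK_K, num_stages, dtsize = 128, 128, 32, 4, 2  # fp16 -> 2 bytes
required_shared_memory = (BLOCK_M + BLOCK_N) * BLOCK_K * num_stages * dtsize
print(required_shared_memory)  # 65536 bytes (64 KiB)
# the config is kept only if this fits the device's reported "max_shared_mem"
```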
@triton.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8
},
num_stages=1,  # mainly for unit tests, to minimize shared memory usage
num_warps=8),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
triton.Config(
{
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=5,
num_warps=2,
),
],
key=['CACHE_M', 'CACHE_N', 'CACHE_K'],
prune_configs_by={
'early_config_prune': matmul_4d_prune_config,
'perf_model': None,
'top_k': AUTOTUNE_TOP_K
},
)
@triton.jit
def matmul_4d_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
# Matrix dimensions
M,
N,
K,
CACHE_M,
CACHE_N,
CACHE_K,
stride_ab,
stride_ah,
stride_am,
stride_ak,
stride_bb,
stride_bh,
stride_bk,
stride_bn,
stride_cb,
stride_ch,
stride_cm,
stride_cn,
scale,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MASK: tl.constexpr,
):
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
pid = tl.program_id(axis=0)
head = tl.program_id(axis=1)
batch = tl.program_id(axis=2)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
if MASK:
if (pid_m + 1) * BLOCK_SIZE_M - 1 < pid_n * BLOCK_SIZE_N:
c = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.dtype.element_ty) - float("inf")
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] +
stride_cn * offs_cn[None, :])
tl.store(c_ptrs, c)
return
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = (a_ptr + batch * stride_ab + head * stride_ah +
(offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak))
b_ptrs = (b_ptr + batch * stride_bb + head * stride_bh +
(offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn))
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, K, BLOCK_SIZE_K):
a_mask = (offs_am[:, None] < M) & (offs_k[None, :] + k < K)
b_mask = (offs_k[:, None] + k < K) & (offs_bn[None, :] < N)
a = tl.load(a_ptrs, mask=a_mask, other=0.)
b = tl.load(b_ptrs, mask=b_mask, other=0.)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
c = accumulator.to(c_ptr.dtype.element_ty)
if scale > 0:
c = c * scale.to(c_ptr.dtype.element_ty)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
if MASK:
c += tl.where(offs_cm[:, None] >= offs_cn[None, :], 0, float("-inf"))
c_ptrs = (c_ptr + batch * stride_cb + head * stride_ch + stride_cm * offs_cm[:, None] +
stride_cn * offs_cn[None, :])
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
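When `MASK=True`, the kernel produces the causal attention-score layout: output tiles that lie entirely above the diagonal are filled with `-inf` and skipped early, and the remaining tiles get an element-wise triangular `-inf` mask after the scaled matmul. A small PyTorch sketch of the equivalent result (illustrative shapes and scale only):

```python
# Reference for the MASK=True output of matmul_4d_kernel: a scaled batched matmul
# with the strictly upper-triangular part set to -inf (causal masking).
import torch

B, H, M, K, N = 1, 2, 8, 16, 8
a = torch.randn(B, H, M, K, dtype=torch.float16, device="cuda")
b = torch.randn(B, H, K, N, dtype=torch.float16, device="cuda")
scale = 0.125
ref = torch.matmul(a, b) * scale
causal = torch.tril(torch.ones(M, N, dtype=torch.bool, device="cuda"))
ref = ref.masked_fill(~causal, float("-inf"))
```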
......@@ -13,4 +13,5 @@ sphinx-rtd-theme
tensorboard
torchvision
transformers
triton
wandb
......@@ -3,3 +3,4 @@ lm-eval==0.3.0
protobuf
transformers
transformers[sentencepiece]
triton
......@@ -67,7 +67,8 @@ extras_require = {
'sparse_attn': fetch_requirements('requirements/requirements-sparse_attn.txt'),
'sparse': fetch_requirements('requirements/requirements-sparse_pruning.txt'),
'inf': fetch_requirements('requirements/requirements-inf.txt'),
'sd': fetch_requirements('requirements/requirements-sd.txt')
'sd': fetch_requirements('requirements/requirements-sd.txt'),
'triton': fetch_requirements('requirements/requirements-triton.txt'),
}
# Add specific cupy version to both onebit extension variants.
......
......@@ -95,13 +95,18 @@ def enable_cuda_graph(request):
return request.param
@pytest.fixture(params=[True, False], ids=["Triton", "noTriton"])
def enable_triton(request):
return request.param
"""
This fixture will validate the configuration
"""
@pytest.fixture()
def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph):
def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph, enable_triton):
model, task = model_w_task
msg = ""
if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"):
......@@ -125,6 +130,12 @@ def invalid_model_task_config(model_w_task, dtype, enable_cuda_graph):
msg = f"Bloom models only support half precision, cannot use dtype {dtype}"
elif ("bert" not in model.lower()) and enable_cuda_graph:
msg = "Non bert/roberta models do no support CUDA Graph"
elif enable_triton and dtype not in [torch.half]:
msg = "Triton kernels only support fp16"
elif enable_triton and not deepspeed.HAS_TRITON:
msg = "triton needs to be installed for the test"
elif ("bert" not in model.lower()) and enable_triton:
msg = "Triton kernels do not support Non bert/roberta models yet"
return msg
......@@ -249,16 +260,16 @@ Tests
class TestModelTask(DistributedTest):
world_size = 1
def test(
self,
model_w_task,
dtype,
enable_cuda_graph,
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
):
def test(self,
model_w_task,
dtype,
enable_cuda_graph,
enable_triton,
query,
inf_kwargs,
assert_fn,
invalid_model_task_config,
perf_meas=True):
if invalid_model_task_config:
pytest.skip(invalid_model_task_config)
......@@ -284,13 +295,18 @@ class TestModelTask(DistributedTest):
get_accelerator().synchronize()
bs_time = time.time() - start
pipe.model = deepspeed.init_inference(
pipe.model,
mp_size=1,
dtype=dtype,
replace_with_kernel_inject=True,
enable_cuda_graph=enable_cuda_graph,
)
args = {
'mp_size': 1,
'dtype': dtype,
'replace_with_kernel_inject': True,
'enable_cuda_graph': enable_cuda_graph,
'use_triton': enable_triton,
'triton_autotune': False,
}
if pipe.tokenizer.model_max_length < deepspeed.ops.transformer.inference.config.DeepSpeedInferenceConfig(
).max_out_tokens:
args.update({'max_out_tokens': pipe.tokenizer.model_max_length})
pipe.model = deepspeed.init_inference(pipe.model, **args)
check_injection(pipe.model)
# Warm-up queries for perf measurement
#for i in range(10):
......@@ -301,6 +317,11 @@ class TestModelTask(DistributedTest):
get_accelerator().synchronize()
ds_time = time.time() - start
if perf_meas:
print(
f"model={model}, task={task}, dtype={dtype}, cuda_graph={enable_cuda_graph}, triton={enable_triton}, bs_time={bs_time}, ds_time={ds_time}"
)
# facebook/opt* and some bigscience/bloom* models are not matching the
# baseline exactly, so add an exception for them for now
if ("opt" in model) or ("bloom" in model):
......@@ -309,6 +330,7 @@ class TestModelTask(DistributedTest):
# These performance tests are only measuring the time for a single
# inference request, we just want to check that performance isn't terrible
#assert ds_time <= (bs_time * 1.1)
assert assert_fn(bs_output, ds_output)
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pytest
import torch
import deepspeed
# reference implementation
def ref_torch_attention(q, k, v, mask, sm_scale):
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
p = torch.softmax(p.float() + mask, dim=-1).half()
ref_out = torch.matmul(p, v)
return ref_out
# test attention operator
@pytest.mark.inference_ops
@pytest.mark.parametrize("Z", [1]) # batch
@pytest.mark.parametrize("H", [12]) # heads
@pytest.mark.parametrize("N_CTX", [4, 128]) # sequence length
@pytest.mark.parametrize("D_HEAD", [64, 128])
@pytest.mark.parametrize("causal", [True, False])
def test_attention(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16):
if not deepspeed.HAS_TRITON:
pytest.skip("triton has to be installed for the test")
# skip autotune in testing
from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
fp16_matmul.skip_autotune()
import triton
from deepspeed.ops.transformer.inference.triton.attention import compute_attention
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0, std=.5)
sm_scale = 0.3
# reference implementation
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
score = p
mask = torch.zeros((Z, H, N_CTX, N_CTX), dtype=dtype, device="cuda")
M = torch.tril(torch.ones((N_CTX, N_CTX), device="cuda"))
if causal:
for z in range(Z):
for h in range(H):
mask[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float() + mask, dim=-1).half()
softmax_out = p
ref_out = torch.matmul(p, v)
context = ref_out
# adjust it to expected tensor format and run test
qkv = torch.randn((Z, N_CTX, 3 * H * D_HEAD), dtype=dtype, device='cuda', requires_grad=False)
qkv[:, :, :H * D_HEAD] = q.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
qkv[:, :, 1 * H * D_HEAD:2 * H * D_HEAD] = k.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
qkv[:, :, 2 * H * D_HEAD:] = v.permute(0, 2, 1, 3).contiguous().reshape((Z, N_CTX, H * D_HEAD))
tri_out = compute_attention(qkv,
input_mask=mask,
layer_past=None,
alibi=None,
scale=sm_scale,
head_size=D_HEAD,
triangular=False,
use_cuda_flash=False,
use_triton_flash=False,
use_ds_attention=False)
tri_out = tri_out.reshape((Z, N_CTX, H, D_HEAD)).permute(0, 2, 1, 3)
triton.testing.allclose(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_out, tri_out)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
inference_module = None
torch_minor_version = None
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
return torch.allclose(x, y, rtol=rtol, atol=atol)
def version_appropriate_gelu(activations):
global torch_minor_version
if torch_minor_version is None:
torch_minor_version = int(torch.__version__.split('.')[1])
# torch < 1.12 does not support the 'approximate' argument of gelu
if torch_minor_version < 12:
return torch.nn.functional.gelu(activations)
else:
return torch.nn.functional.gelu(activations, approximate='tanh')
def run_gelu_reference(activations):
# Expected behavior is that of casting to float32 internally and using the tanh approximation
return version_appropriate_gelu(activations.to(torch.float32)).to(activations.dtype)
def run_gelu_ds(activations, use_triton_ops=False):
if use_triton_ops:
from deepspeed.ops.transformer.inference.triton import gelu
return gelu(activations)
channels = activations.shape[-1]
bias = torch.zeros((channels), dtype=activations.dtype, device='cuda')
global inference_module
if inference_module is None:
inference_module = InferenceBuilder().load()
if activations.dtype == torch.float16:
return inference_module.bias_gelu_fp16(activations, bias)
else:
return inference_module.bias_gelu_fp32(activations, bias)
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255])
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("use_triton_ops", [True, False])
def test_gelu(batch, sequence, channels, dtype, use_triton_ops):
activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda')
activations_ref = activations_ds.clone().detach()
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
ds_out = run_gelu_ds(activations_ds, use_triton_ops)
ref_out = run_gelu_reference(activations_ref)
assert (allclose(ds_out, ref_out))
......@@ -9,6 +9,14 @@ import pytest
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import InferenceBuilder
from .inference_test_utils import allclose, get_dtypes
try:
import triton # noqa: F401
from deepspeed.ops.transformer.inference.triton import (
layer_norm,
layer_norm_residual,
)
except ImportError:
print("triton import failed")
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
......@@ -30,19 +38,32 @@ def ds_implementation(vals, gamma, beta, epsilon):
return inference_module.layer_norm(vals, gamma, beta, epsilon)
def ds_triton_implementation(vals, gamma, beta, epsilon):
return layer_norm(vals, gamma, beta, epsilon)
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 32])
@pytest.mark.parametrize("seq_len", [1, 128])
@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_layer_norm(batch, seq_len, channels, dtype):
@pytest.mark.parametrize("use_triton_ops", [False, True])
def test_layer_norm(batch, seq_len, channels, dtype, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
gamma = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name())
beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name())
epsilon = 1e-5
ref_output = ref_implementation(vals, gamma, beta, epsilon, channels, dtype)
new_output = ds_implementation(vals, gamma, beta, epsilon)
if use_triton_ops:
new_output = ds_triton_implementation(vals, gamma, beta, epsilon)
if dtype != torch.float16:  # only fp16 is supported by the triton kernels
return
else:
new_output = ds_implementation(vals, gamma, beta, epsilon)
if not allclose(new_output, ref_output):
#print(new_output - ref_output)
......@@ -68,12 +89,20 @@ def residual_ds_implementation(vals, bias, res, gamma, beta, epsilon):
return inference_module._layer_norm_residual(vals, bias, res, gamma, beta, epsilon)
def residual_ds_triton_implementation(vals, bias, res, gamma, beta, epsilon):
return layer_norm_residual(vals, bias, res, gamma, beta, epsilon)
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 32])
@pytest.mark.parametrize("seq_len", [1, 128])
@pytest.mark.parametrize("channels", [384, 512, 768, 1024, 2048, 8192, 14432])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_layer_norm_residual(batch, seq_len, channels, dtype):
@pytest.mark.parametrize("use_triton_ops", [False, True])
def test_layer_norm_residual(batch, seq_len, channels, dtype, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
vals = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
residual = torch.randn((batch, seq_len, channels), dtype=dtype, device=get_accelerator().current_device_name())
bias = torch.randn((channels), dtype=dtype, device=get_accelerator().current_device_name())
......@@ -81,7 +110,13 @@ def test_layer_norm_residual(batch, seq_len, channels, dtype):
beta = torch.rand((channels), dtype=dtype, device=get_accelerator().current_device_name())
epsilon = 1e-5
new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon)
if use_triton_ops:
new_output = residual_ds_triton_implementation(vals, bias, residual, gamma, beta, epsilon)
if dtype != torch.float16:  # only fp16 is supported by the triton kernels
return
else:
new_output = residual_ds_implementation(vals, bias, residual, gamma, beta, epsilon)
ref_output = residual_ref_implementation(vals, bias, residual, gamma, beta, epsilon, channels, dtype)
print((new_output - ref_output).abs().max())
......@@ -129,3 +164,38 @@ def test_layer_norm_residual_store_pre_ln_res(batch, seq_len, channels, dtype):
assert allclose(ds_res_output, norm_res_output)
assert allclose(ds_norm_output, ref_norm_output)
@pytest.mark.inference_ops
@pytest.mark.parametrize("M", [4])
@pytest.mark.parametrize("N", [4])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("residual", [True, False])
@pytest.mark.parametrize("input_bias", [True, False])
def test_triton_layer_norm(M, N, dtype, residual, input_bias, eps=1e-5, device='cuda'):
if not deepspeed.HAS_TRITON:
pytest.skip("triton has to be installed for the test")
torch.manual_seed(0)
# create data
x_shape = (M, N)
w_shape = (x_shape[-1], )
weight = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
x_bias = torch.rand(w_shape, dtype=dtype, device='cuda', requires_grad=False)
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
dy = .1 * torch.randn_like(x)
if residual:
res = torch.rand(x_shape, dtype=dtype, device='cuda', requires_grad=False)
else:
res = torch.zeros(x_shape, dtype=dtype, device='cuda', requires_grad=False)
x.requires_grad_(True)
# forward pass
if residual or input_bias:
y_tri = layer_norm_residual(x, x_bias if input_bias else None, res, weight, bias, eps)
else:
y_tri = layer_norm(x, weight, bias, eps)
y_ref = torch.nn.functional.layer_norm(x + res + (x_bias if input_bias else 0), w_shape, weight, bias,
eps).to(dtype)
# compare
#print(f"y_tri={y_tri}, y_ref={y_ref}")
triton.testing.assert_almost_equal(y_tri, y_ref)
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
inference_module = None
torch_minor_version = None
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (5e-2, 2e-3)}[x.dtype]
return torch.allclose(x, y, rtol=rtol, atol=atol)
def run_matmul_ref(a, b):
return torch.matmul(a, b)
def run_matmul_ds(a, b, use_triton_ops=False):
if use_triton_ops:
from deepspeed.ops.transformer.inference.triton import matmul_4d as matmul
return matmul(a, b)
assert use_triton_ops, "Only triton matmul is supported for now"
@pytest.mark.inference_ops
@pytest.mark.parametrize("B", [1, 2])
@pytest.mark.parametrize("H", [1, 2, 16])
@pytest.mark.parametrize("M", [1, 7, 8, 128])
@pytest.mark.parametrize("K", [2, 5, 16, 128])
@pytest.mark.parametrize("N", [1, 2, 8, 512])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("use_triton_ops", [True])
def test_matmul_4d(B, H, M, K, N, dtype, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
# skip autotune in testing
from deepspeed.ops.transformer.inference.triton.matmul_ext import fp16_matmul
fp16_matmul.skip_autotune()
a_ds = torch.randn((B, H, M, K), dtype=dtype, device='cuda')
b_ds = torch.randn((B, H, K, N), dtype=dtype, device='cuda')
a_ref = a_ds.clone().detach()
b_ref = b_ds.clone().detach()
ds_out = run_matmul_ds(a_ds, b_ds, use_triton_ops)
ref_out = run_matmul_ref(a_ref, b_ref)
assert (allclose(ds_out, ref_out))
......@@ -74,8 +74,11 @@ def run_residual_add_reference(hidden_state, residual, attn_output, attn_bias, f
@pytest.mark.parametrize("add_bias", [True, False])
@pytest.mark.parametrize("mp_size", [1, 2])
@pytest.mark.parametrize("pre_attn_norm", [True, False])
@pytest.mark.parametrize("use_triton_ops", [True, False])
def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_after_attn, add_bias, mp_size,
pre_attn_norm):
pre_attn_norm, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops and dtype == torch.float16:
pytest.skip("triton has to be installed for the test")
ds_out = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
residual = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
attn_output = torch.randn((batch, sequence, hidden_dim), dtype=dtype, device=get_accelerator().device_name())
......@@ -90,6 +93,9 @@ def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_
ds_out, residual, attn_output, attn_bias, final_bias, mp_size, mlp_after_attn, add_bias, pre_attn_norm
]
if use_triton_ops:
from deepspeed.ops.transformer.inference.triton import residual_add_bias
ds_out = residual_add_bias(*res_add_args)
if dtype == torch.float16:
ds_out = inference_module.residual_add_bias_fp16(*res_add_args)
elif dtype == torch.float32:
......@@ -97,7 +103,12 @@ def test_residual_add(inference_module, batch, sequence, hidden_dim, dtype, mlp_
elif dtype == torch.bfloat16:
ds_out = inference_module.residual_add_bias_bf16(*res_add_args)
else:
raise ValueError(f"Unsupported dtype: {dtype}")
if dtype == torch.float16:
ds_out = inference_module.residual_add_bias_fp16(*res_add_args)
elif dtype == torch.float32:
ds_out = inference_module.residual_add_bias_fp32(*res_add_args)
else:
raise ValueError(f"Unsupported dtype: {dtype}")
if not allclose(ds_out, ref_out):
print((ds_out - ref_out).abs().max())
......
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
inference_module = None
torch_minor_version = None
def allclose(x, y):
assert x.dtype == y.dtype
rtol, atol = {torch.float32: (5e-4, 5e-5), torch.float16: (3e-2, 2e-3)}[x.dtype]
return torch.allclose(x, y, rtol=rtol, atol=atol)
def run_softmax_reference(input):
return torch.nn.functional.softmax(input, dim=-1)
def run_softmax_ds(input, use_triton_ops=False):
if use_triton_ops:
from deepspeed.ops.transformer.inference.triton import softmax
# return torch.empty_like(input)
return softmax(input)
assert use_triton_ops, "Only triton softmax is supported for now"
@pytest.mark.inference_ops
@pytest.mark.parametrize("batch", [1, 2])
@pytest.mark.parametrize("sequence", [1, 128, 255, 1232])
@pytest.mark.parametrize("channels", [512, 4096])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_triton_ops", [True])
def test_softmax(batch, sequence, channels, dtype, use_triton_ops):
if not deepspeed.HAS_TRITON and use_triton_ops:
pytest.skip("triton has to be installed for the test")
input_ds = torch.randn((batch, sequence, channels), dtype=dtype, device='cuda')
input_ref = input_ds.clone().detach()
ds_out = run_softmax_ds(input_ds, use_triton_ops)
ref_out = run_softmax_reference(input_ref)
assert (allclose(ds_out, ref_out))