""" Copyright 2020 The Microsoft DeepSpeed Team """ import warnings from .builder import OpBuilder from packaging import version as pkg_version class SparseAttnBuilder(OpBuilder): BUILD_VAR = "DS_BUILD_SPARSE_ATTN" NAME = "sparse_attn" def __init__(self): super().__init__(name=self.NAME) def absolute_name(self): return f'deepspeed.ops.sparse_attention.{self.NAME}_op' def sources(self): return ['csrc/sparse_attention/utils.cpp'] def cxx_args(self): return ['-O2', '-fopenmp'] def is_compatible(self, verbose=True): # Check to see if llvm and cmake are installed since they are dependencies #required_commands = ['llvm-config|llvm-config-9', 'cmake'] #command_status = list(map(self.command_exists, required_commands)) #deps_compatible = all(command_status) try: import torch except ImportError: self.warning(f"unable to import torch, please install it first") return False # torch-cpu will not have a cuda version if torch.version.cuda is None: cuda_compatible = False self.warning(f"{self.NAME} cuda is not available from torch") else: major, minor = torch.version.cuda.split('.')[:2] cuda_compatible = (int(major) == 10 and int(minor) >= 1) or (int(major) >= 11) if not cuda_compatible: self.warning(f"{self.NAME} requires CUDA version 10.1+") TORCH_MAJOR = int(torch.__version__.split('.')[0]) TORCH_MINOR = int(torch.__version__.split('.')[1]) torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5 if not torch_compatible: self.warning( f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}' ) try: import triton except ImportError: # auto-install of triton is broken on some systems, reverting to manual install for now # see this issue: https://github.com/microsoft/DeepSpeed/issues/1710 self.warning( f"please install triton==1.0.0 if you want to use sparse attention") return False installed_triton = pkg_version.parse(triton.__version__) if installed_triton != pkg_version.parse("1.0.0"): self.warning( f"using untested triton version ({installed_triton}), only 1.0.0 is known to be compatible" ) return False return super().is_compatible(verbose) and torch_compatible and cuda_compatible