diff --git a/op_builder/builder.py b/op_builder/builder.py index 44d6a440c05630cdb646b59520cfac6c5623735e..ea10320529ff74c2d116378bd57c49b3faa2771a 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -36,9 +36,6 @@ else: def installed_cuda_version(name=""): - import torch.cuda - if not torch.cuda.is_available(): - return 0, 0 import torch.utils.cpp_extension cuda_home = torch.utils.cpp_extension.CUDA_HOME assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" @@ -78,8 +75,6 @@ cuda_minor_mismatch_ok = { def assert_no_cuda_mismatch(name=""): cuda_major, cuda_minor = installed_cuda_version(name) - if cuda_minor == 0 and cuda_major == 0: - return False sys_cuda_version = f'{cuda_major}.{cuda_minor}' torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) # This is a show-stopping error, should probably not proceed past this @@ -344,10 +339,11 @@ class OpBuilder(ABC): def is_cuda_enable(self): try: - if torch.cuda.is_available(): - return '-D__ENABLE_CUDA__' - except: - print(f"{WARNING} {self.name} torch.cuda is missing, only cpu ops can be compiled!") + assert_no_cuda_mismatch(self.name) + return '-D__ENABLE_CUDA__' + except BaseException: + print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " + "only cpu ops can be compiled!") return '-D__DISABLE_CUDA__' return '-D__DISABLE_CUDA__' @@ -459,7 +455,11 @@ class OpBuilder(ABC): raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.") if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch(): - self.build_for_cpu = not assert_no_cuda_mismatch(self.name) + try: + assert_no_cuda_mismatch(self.name) + self.build_for_cpu = False + except BaseException: + self.build_for_cpu = True self.jit_mode = True from torch.utils.cpp_extension import load @@ -579,7 +579,12 @@ class CUDAOpBuilder(OpBuilder): return super().is_compatible(verbose) def builder(self): - self.build_for_cpu = not assert_no_cuda_mismatch(self.name) + try: + assert_no_cuda_mismatch(self.name) + self.build_for_cpu = False + except BaseException: + self.build_for_cpu = True + if self.build_for_cpu: from torch.utils.cpp_extension import CppExtension as ExtensionBuilder else: