未验证 提交 036c5d6d 编写于 作者: J Jinzhen Lin 提交者: GitHub

remove `torch.cuda.is_available()` check when compiling ops (#3085)

Co-authored-by: NOlatunji Ruwase <olruwase@microsoft.com>
Co-authored-by: NLogan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: NJeff Rasley <jerasley@microsoft.com>
上级 c381dd16
......@@ -36,9 +36,6 @@ else:
def installed_cuda_version(name=""):
import torch.cuda
if not torch.cuda.is_available():
return 0, 0
import torch.utils.cpp_extension
cuda_home = torch.utils.cpp_extension.CUDA_HOME
assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)"
......@@ -78,8 +75,6 @@ cuda_minor_mismatch_ok = {
def assert_no_cuda_mismatch(name=""):
cuda_major, cuda_minor = installed_cuda_version(name)
if cuda_minor == 0 and cuda_major == 0:
return False
sys_cuda_version = f'{cuda_major}.{cuda_minor}'
torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2])
# This is a show-stopping error, should probably not proceed past this
......@@ -344,10 +339,11 @@ class OpBuilder(ABC):
def is_cuda_enable(self):
try:
if torch.cuda.is_available():
return '-D__ENABLE_CUDA__'
except:
print(f"{WARNING} {self.name} torch.cuda is missing, only cpu ops can be compiled!")
assert_no_cuda_mismatch(self.name)
return '-D__ENABLE_CUDA__'
except BaseException:
print(f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, "
"only cpu ops can be compiled!")
return '-D__DISABLE_CUDA__'
return '-D__DISABLE_CUDA__'
......@@ -459,7 +455,11 @@ class OpBuilder(ABC):
raise RuntimeError(f"Unable to JIT load the {self.name} op due to ninja not being installed.")
if isinstance(self, CUDAOpBuilder) and not self.is_rocm_pytorch():
self.build_for_cpu = not assert_no_cuda_mismatch(self.name)
try:
assert_no_cuda_mismatch(self.name)
self.build_for_cpu = False
except BaseException:
self.build_for_cpu = True
self.jit_mode = True
from torch.utils.cpp_extension import load
......@@ -579,7 +579,12 @@ class CUDAOpBuilder(OpBuilder):
return super().is_compatible(verbose)
def builder(self):
self.build_for_cpu = not assert_no_cuda_mismatch(self.name)
try:
assert_no_cuda_mismatch(self.name)
self.build_for_cpu = False
except BaseException:
self.build_for_cpu = True
if self.build_for_cpu:
from torch.utils.cpp_extension import CppExtension as ExtensionBuilder
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册