未验证 提交 24e07399 编写于 作者: J Jeff Rasley 提交者: GitHub

update SA comp check to fix torch-cpu issue (#631)

上级 81aeea36
......@@ -25,6 +25,18 @@ class SparseAttnBuilder(OpBuilder):
command_status = list(map(self.command_exists, required_commands))
deps_compatible = all(command_status)
# torch-cpu will not have a cuda version
if torch.version.cuda is None:
cuda_compatible = False
self.warning(f"{self.NAME} cuda is not available from torch")
else:
major, minor = torch.version.cuda.split('.')[:2]
cuda_compatible = int(major) == 10 and int(minor) >= 1
if not cuda_compatible:
self.warning(
f"{self.NAME} requires CUDA version 10.1+, does not currently support >=11 or <10.1"
)
TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
torch_compatible = TORCH_MAJOR == 1 and TORCH_MINOR >= 5
......@@ -33,4 +45,5 @@ class SparseAttnBuilder(OpBuilder):
f'{self.NAME} requires a torch version >= 1.5 but detected {TORCH_MAJOR}.{TORCH_MINOR}'
)
return super().is_compatible() and deps_compatible and torch_compatible
return super().is_compatible(
) and deps_compatible and torch_compatible and cuda_compatible
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册