Unverified commit 799760ab authored by AUTOMATIC1111, committed by GitHub

Merge pull request #11722 from akx/mps-gc-fix

Fix MPS cache cleanup
......@@ -54,8 +54,9 @@ def torch_gc():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
elif has_mps() and hasattr(torch.mps, 'empty_cache'):
torch.mps.empty_cache()
if has_mps():
mac_specific.torch_mps_gc()
def enable_tf32():
......
import logging
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
log = logging.getLogger()
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
......@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()
# Probe once at import time and cache the result; MPS availability cannot
# change during a run, so callers test this flag instead of re-probing torch.
has_mps = check_for_mps()
def torch_mps_gc() -> None:
    """Release cached MPS (Apple Metal) memory held by PyTorch.

    Best-effort: any failure (missing torch.mps module on older PyTorch,
    runtime errors from the backend) is logged with a traceback rather
    than propagated, so garbage collection never crashes the caller.
    """
    try:
        # Import lazily inside the guard: on PyTorch builds without the
        # mps module the ImportError is handled like any other failure.
        import torch.mps
        torch.mps.empty_cache()
    except Exception:
        log.warning("MPS garbage collection failed", exc_info=True)
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register