提交 64214255 编写于 作者: A AUTOMATIC

use commandline-supplied cuda device name instead of cuda:0 for safetensors PR...

use commandline-supplied cuda device name instead of cuda:0 for safetensors PR that doesn't fix anything
上级 68fbf455
...@@ -173,7 +173,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None ...@@ -173,7 +173,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
if extension.lower() == ".safetensors": if extension.lower() == ".safetensors":
device = map_location or shared.weight_load_location device = map_location or shared.weight_load_location
if device is None: if device is None:
device = "cuda:0" if torch.cuda.is_available() else "cpu" device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu"
pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
else: else:
pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册