# sd_models.py
import collections
import os.path
import sys
import gc
import threading

import torch
import re
import safetensors.torch
from omegaconf import OmegaConf
from os import mkdir
from urllib import request
import ldm.modules.midas as midas

from ldm.util import instantiate_from_config

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd

model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

checkpoints_list = {}
checkpoint_aliases = {}
checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
checkpoints_loaded = collections.OrderedDict()


class CheckpointInfo:
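    """Describes a single checkpoint file: paths, display names, hashes, and embedded metadata."""
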
    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)

        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
        elif abspath.startswith(model_path):
            name = abspath.replace(model_path, '')
        else:
            name = os.path.basename(filename)

        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        self.hash = model_hash(filename)

        self.sha256 = hashes.sha256_from_cache(self.filename, f"checkpoint/{name}")
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'

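        # every string that can be used to select this checkpoint: legacy hash, model
        # name, title, plain name, and sha256-based aliases once the hash is known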
        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

        self.metadata = {}

        _, ext = os.path.splitext(self.filename)
        if ext.lower() == ".safetensors":
            try:
                self.metadata = read_metadata_from_safetensors(filename)
            except Exception as e:
                errors.display(e, f"reading checkpoint metadata: {filename}")

    def register(self):
        checkpoints_list[self.title] = self
        for checkpoint_id in self.ids:
            checkpoint_aliases[checkpoint_id] = self

    def calculate_shorthash(self):
        self.sha256 = hashes.sha256(self.filename, f"checkpoint/{self.name}")
        if self.sha256 is None:
            return

        self.shorthash = self.sha256[0:10]

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

        checkpoints_list.pop(self.title)
        self.title = f'{self.name} [{self.shorthash}]'
        self.register()

        return self.shorthash


try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    from transformers import logging, CLIPModel  # noqa: F401

    logging.set_verbosity_error()
except Exception:
    pass


def setup_model():
    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()


def checkpoint_tiles():
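    # natural sort: digit runs compare numerically, so "model-2" comes before "model-10"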
    def convert(name):
        return int(name) if name.isdigit() else name.lower()

    def alphanumeric_key(key):
        return [convert(c) for c in re.split('([0-9]+)', key)]

    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)


def list_models():
    checkpoints_list.clear()
    checkpoint_aliases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file or os.path.exists(cmd_ckpt):
        model_url = None
    else:
        model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])

    if os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)

    for filename in sorted(model_list, key=str.lower):
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()


def get_closet_checkpoint_match(search_string):
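    """Return the checkpoint registered under search_string, or the one whose title
    contains it (shortest title wins); None if nothing matches.

    Illustrative call (the checkpoint name is hypothetical):
        get_closet_checkpoint_match("v1-5-pruned-emaonly")
    """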
    checkpoint_info = checkpoint_aliases.get(search_string, None)
    if checkpoint_info is not None:
        return checkpoint_info

    found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title))
    if found:
        return found[0]

    return None


def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions"""

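    # hashes a 64 KiB window starting at offset 0x100000 and keeps the first 8 hex digits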
    try:
        import hashlib

        with open(filename, "rb") as file:
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'


def select_checkpoint():
    """Raises `FileNotFoundError` if no checkpoints are found."""
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_aliases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        error_message = "No checkpoints found. When searching for checkpoints, looked at:"
        if shared.cmd_opts.ckpt is not None:
            error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}"
        error_message += f"\n - directory {model_path}"
        if shared.cmd_opts.ckpt_dir is not None:
            error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}"
        error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations."
        raise FileNotFoundError(error_message)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info


checkpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_checkpoint_dict_key(k):
    for text, replacement in checkpoint_dict_replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]

    return k


def get_state_dict_from_checkpoint(pl_sd):
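    # pytorch-lightning checkpoints nest the weights under a "state_dict" key; unwrap
    # it and drop any stray nested "state_dict" entry left behind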
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    sd = {}
    for k, v in pl_sd.items():
        new_key = transform_checkpoint_dict_key(k)

        if new_key is not None:
            sd[new_key] = v

    pl_sd.clear()
    pl_sd.update(sd)

    return pl_sd


def read_metadata_from_safetensors(filename):
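    """Read the __metadata__ block from a .safetensors file header as a dict.

    String values that themselves look like JSON objects are decoded; the rest
    are returned as-is.
    """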
    import json

    with open(filename, mode="rb") as file:
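        # safetensors layout: 8 bytes little-endian header length, then that many bytes of JSON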
        metadata_len = file.read(8)
        metadata_len = int.from_bytes(metadata_len, "little")
        json_start = file.read(2)

        assert metadata_len > 2 and json_start in (b'{"', b"{'"), f"{filename} is not a safetensors file"
        json_data = json_start + file.read(metadata_len-2)
        json_obj = json.loads(json_data)

        res = {}
        for k, v in json_obj.get("__metadata__", {}).items():
            res[k] = v
            if isinstance(v, str) and v[0:1] == '{':
                try:
                    res[k] = json.loads(v)
                except Exception:
                    pass

        return res


def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
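        # device priority: explicit map_location, then shared.weight_load_location, then the optimal device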
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            # read the whole file into memory instead of memory-mapping it
            with open(checkpoint_file, 'rb') as file:
                pl_sd = safetensors.torch.load(file.read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    sd = get_state_dict_from_checkpoint(pl_sd)
    return sd


def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info in checkpoints_loaded:
        # use checkpoint cache
        print(f"Loading weights [{sd_model_hash}] from cache")
        return checkpoints_loaded[checkpoint_info]

    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
    res = read_state_dict(checkpoint_info.filename)
    timer.record("load weights from disk")

    return res


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    model.is_sdxl = hasattr(model, 'conditioner')
    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
    model.is_sd1 = not model.is_sdxl and not model.is_sd2

    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if not shared.cmd_opts.no_half:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model:
            model.depth_model = None

        model.half()
        model.first_stage_model = vae
        if depth_model:
            model.depth_model = depth_model

        timer.record("apply half()")

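    # record the UNet dtype: fp16 for SDXL unless --no-half is set, otherwise whatever
    # dtype the diffusion model actually ended up with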
    devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    if hasattr(model, 'logvar'):
        model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")


def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            if not os.path.exists(midas_path):
                mkdir(midas_path)

            print(f"Downloading midas model weights for {model_type} to {path}")
            request.urlretrieve(midas_urls[model_type], path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper


def repair_config(sd_config):

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

    if hasattr(sd_config.model.params, 'unet_config'):
        if shared.cmd_opts.no_half:
            sd_config.model.params.unet_config.params.use_fp16 = False
        elif shared.cmd_opts.upcast_sampling:
            sd_config.model.params.unet_config.params.use_fp16 = True

    if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available:
        sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla"

    # For UnCLIP-L, override the hardcoded karlo directory
    if hasattr(sd_config.model.params, "noise_aug_config") and hasattr(sd_config.model.params.noise_aug_config.params, "clip_stats_path"):
        karlo_path = os.path.join(paths.models_path, 'karlo')
        sd_config.model.params.noise_aug_config.params.clip_stats_path = sd_config.model.params.noise_aug_config.params.clip_stats_path.replace("checkpoints/karlo_models", karlo_path)


sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight'
sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'


class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

    def get_sd_model(self):
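        # double-checked locking: fast path without the lock, then re-check under the
        # lock so only one thread performs the initial load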
        if self.was_loaded_at_least_once:
            return self.sd_model

        if self.sd_model is None:
            with self.lock:
                if self.sd_model is not None or self.was_loaded_at_least_once:
                    return self.sd_model

                try:
                    load_model()
                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
                    print("Stable diffusion model failed to load", file=sys.stderr)
                    self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v):
        self.sd_model = v


model_data = SdModelData()


def get_empty_cond(sd_model):
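    # SDXL models expose a 'conditioner'; SD1/SD2 models use cond_stage_model instead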
    if hasattr(sd_model, 'conditioner'):
        d = sd_model.get_learned_conditioning([""])
        return d['crossattn']
    else:
        return sd_model.cond_stage_model([""])


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    if model_data.sd_model:
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        gc.collect()
        devices.torch_gc()

    do_inpainting_hijack()

    timer = Timer()

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
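    # if the checkpoint ships its own CLIP weights, CLIP initialization can be skipped below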
    clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict)

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
            sd_model = instantiate_from_config(sd_config.model)
    except Exception:
        pass

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
        sd_model = instantiate_from_config(sd_config.model)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    load_model_weights(sd_model, checkpoint_info, state_dict, timer)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)

    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    model_data.sd_model = sd_model
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    with devices.autocast(), torch.no_grad():
        sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model)

    timer.record("calculate empty prompt")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model


def reload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if not sd_model:
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return

        sd_unet.apply_unet("None")

        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
            lowvram.send_everything_to_cpu()
        else:
            sd_model.to(devices.cpu)

        sd_hijack.model_hijack.undo_hijack(sd_model)

    timer = Timer()

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

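    # a different config means a different architecture; rebuild the model from scratch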
    if sd_model is None or checkpoint_config != sd_model.used_config:
        del sd_model
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

    print(f"Weights loaded in {timer.summary()}.")

    return sd_model


def unload_model_weights(sd_model=None, info=None):
    from modules import devices, sd_hijack
    timer = Timer()

    if model_data.sd_model:
        model_data.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        sd_model = None
        gc.collect()
        devices.torch_gc()

    print(f"Unloaded weights {timer.summary()}.")

    return sd_model


def apply_token_merging(sd_model, token_merging_ratio):
    """
    Applies speed and memory optimizations from tomesd.
    """

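    # a ratio of 0 disables ToMe; re-applying the same ratio is a no-op, and changing it
    # removes the previous patch before applying the new one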
    current_token_merging_ratio = getattr(sd_model, 'applied_token_merged_ratio', 0)

    if current_token_merging_ratio == token_merging_ratio:
        return

    if current_token_merging_ratio > 0:
        tomesd.remove_patch(sd_model)

    if token_merging_ratio > 0:
        tomesd.apply_patch(
            sd_model,
            ratio=token_merging_ratio,
            use_rand=False,  # can cause issues with some samplers
            merge_attn=True,
            merge_crossattn=False,
            merge_mlp=False
        )

    sd_model.applied_token_merged_ratio = token_merging_ratio