sd_models.py 17.5 KB
Newer Older
1
import collections
2 3
import os.path
import sys
4
import gc
5
import torch
6
import re
A
AUTOMATIC 已提交
7
import safetensors.torch
8
from omegaconf import OmegaConf
J
Jay Smith 已提交
9 10 11
from os import mkdir
from urllib import request
import ldm.modules.midas as midas
12 13 14

from ldm.util import instantiate_from_config

15
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
16
from modules.paths import models_path
17 18
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
19 20

# Sub-folder of the shared models directory that holds Stable Diffusion checkpoints.
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))

# title -> CheckpointInfo for every registered checkpoint (filled by CheckpointInfo.register()).
checkpoints_list = dict()
# Every alias of a checkpoint (hashes, names, titles) -> CheckpointInfo. Name kept as-is (sic).
checkpoint_alisases = dict()
# CheckpointInfo -> cached state dict; oldest entries are evicted first by load_model_weights().
checkpoints_loaded = collections.OrderedDict()
26

A
AUTOMATIC 已提交
27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42

class CheckpointInfo:
    """Metadata for one checkpoint file plus every alias it can be looked up by.

    ``name`` is the file's path relative to the checkpoint directory it lives in (the
    ``--ckpt-dir`` directory or ``model_path``), or just its basename otherwise. ``title`` is
    the display/setting name; it gains a ``[shorthash]`` suffix once the sha256 is known.
    """

    def __init__(self, filename):
        self.filename = filename
        abspath = os.path.abspath(filename)

        # Strip only the leading directory prefix. str.replace() would also delete any later
        # occurrence of the same text inside the path.
        # NOTE(review): ckpt_dir is compared as given, so a relative ckpt_dir never matches
        # this absolute path and the basename is used instead - confirm that is intended.
        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
            name = abspath[len(shared.cmd_opts.ckpt_dir):]
        elif abspath.startswith(model_path):
            name = abspath[len(model_path):]
        else:
            name = os.path.basename(filename)

        if name.startswith(("\\", "/")):
            name = name[1:]

        self.name = name
        self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
        self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        # Legacy, collision-prone 8-character hash (see model_hash).
        self.hash = model_hash(filename)

        # Full sha256 only if already cached; stays None until calculate_shorthash() runs.
        self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name)
        self.shorthash = self.sha256[0:10] if self.sha256 else None

        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'

        # Every string that should resolve to this checkpoint via checkpoint_alisases.
        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

    def register(self):
        """Store this checkpoint in checkpoints_list (by title) and checkpoint_alisases (by every id)."""
        checkpoints_list[self.title] = self
        for alias in self.ids:
            checkpoint_alisases[alias] = self

    def calculate_shorthash(self):
        """Compute the full sha256, add hash-based aliases and re-register under the hashed title.

        Returns:
            The 10-character short hash, or None when hashes.sha256 returned None.
        """
        self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
        if self.sha256 is None:
            return

        self.shorthash = self.sha256[0:10]

        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

        # Remove the entry under the old title before re-registering. The old title may be
        # absent (this object was never registered, or list_models() rebuilt the list since),
        # so tolerate a missing key instead of raising KeyError.
        checkpoints_list.pop(self.title, None)
        self.title = f'{self.name} [{self.shorthash}]'
        self.register()

        return self.shorthash


77 78 79
# Silence the noisy "Some weights of the model checkpoint were not used when initializing..."
# message printed at start-up. This is best-effort only: any failure (for example transformers
# not being importable) simply leaves the default verbosity in place.
try:
    from transformers import logging, CLIPModel

    logging.set_verbosity_error()
except Exception:
    pass


87
def setup_model():
    """Prepare checkpoint handling at start-up.

    Creates the Stable-diffusion model directory when it is missing, builds the checkpoint
    registry and installs the MiDaS auto-download wrapper.
    """
    model_dir_missing = not os.path.exists(model_path)
    if model_dir_missing:
        os.makedirs(model_path)

    list_models()
    enable_midas_autodownload()
93 94


A
AUTOMATIC 已提交
95 96 97 98 99 100 101 102
def checkpoint_tiles():
    """Return the titles of all registered checkpoints in natural sort order.

    Runs of ASCII digits compare numerically ("v2" sorts before "v10"); all other text
    compares case-insensitively.
    """

    def alphanumeric_key(key):
        # re.split() with one capturing group alternates plain text (even indices) with the
        # captured ASCII digit runs (odd indices). Convert by position instead of using
        # str.isdigit(): that is also true for non-ASCII digits such as "²" (int() rejects
        # it) or "１" (int() accepts it, and comparing that int with another title's str at
        # the same position raises TypeError).
        parts = re.split('([0-9]+)', key)
        return [int(part) if index % 2 else part.lower() for index, part in enumerate(parts)]

    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
103 104


105 106
def list_models():
    """Rebuild checkpoints_list and checkpoint_alisases from disk.

    Looks for .ckpt/.safetensors files (excluding *.vae.ckpt / *.vae.safetensors) in
    model_path and shared.cmd_opts.ckpt_dir via modelloader.load_models. The --ckpt file is
    registered first when it exists, and its title becomes the sd_model_checkpoint setting.
    A download URL is passed to modelloader only when downloads are allowed and --ckpt equals
    shared.sd_model_file (when the download actually happens is up to modelloader).
    """
    checkpoints_list.clear()
    checkpoint_alisases.clear()

    cmd_ckpt = shared.cmd_opts.ckpt
    if shared.cmd_opts.no_download_sd_model or cmd_ckpt != shared.sd_model_file:
        model_url = None
    else:
        model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors"

    model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])

    # Test for None first: os.path.exists(None) raises TypeError instead of returning False.
    if cmd_ckpt is not None and os.path.exists(cmd_ckpt):
        checkpoint_info = CheckpointInfo(cmd_ckpt)
        checkpoint_info.register()

        shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (possibly it was moved to {model_path}): {cmd_ckpt}", file=sys.stderr)

    for filename in model_list:
        checkpoint_info = CheckpointInfo(filename)
        checkpoint_info.register()

129

A
AUTOMATIC 已提交
130 131 132
def get_closet_checkpoint_match(search_string):
    """Resolve *search_string* to a CheckpointInfo, or None when nothing matches.

    An exact alias (hash, name, title, ...) wins outright. Otherwise the registered checkpoint
    with the shortest title containing *search_string* is returned; ties go to the one that
    was registered first.
    """
    exact = checkpoint_alisases.get(search_string, None)
    if exact is not None:
        return exact

    substring_hits = (info for info in checkpoints_list.values() if search_string in info.title)
    return min(substring_hits, key=lambda info: len(info.title), default=None)
140

141

142
def model_hash(filename):
    """old hash that only looks at a small part of the file and is prone to collisions

    Returns the first 8 hex digits of the sha256 of the 0x10000 bytes starting at offset
    0x100000 (fewer, or none, if the file is shorter), or 'NOFILE' if *filename* does not exist.
    """
    import hashlib

    window_start, window_size = 0x100000, 0x10000
    try:
        with open(filename, "rb") as stream:
            stream.seek(window_start)
            window = stream.read(window_size)
    except FileNotFoundError:
        return 'NOFILE'

    return hashlib.sha256(window).hexdigest()[:8]


def select_checkpoint():
    """Return the CheckpointInfo to load for the current sd_model_checkpoint setting.

    Looks the setting up in checkpoint_alisases. If that fails, falls back to the first
    registered checkpoint and prints a warning. If no checkpoints are registered at all, prints
    where it looked and exits the process with status 1.
    """
    model_checkpoint = shared.opts.sd_model_checkpoint

    checkpoint_info = checkpoint_alisases.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if not checkpoints_list:
        print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
        if shared.cmd_opts.ckpt is not None:
            print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
        print(f" - directory {model_path}", file=sys.stderr)
        if shared.cmd_opts.ckpt_dir is not None:
            print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
        print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
        # sys.exit() rather than the site-module exit() helper, which is not guaranteed to
        # exist (e.g. under `python -S` or in embedded/frozen interpreters).
        sys.exit(1)

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info


181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
# Legacy state-dict key prefixes and their replacements: each inserts "text_model." after
# "cond_stage_model.transformer.". (Name misspelling kept as-is.)
chckpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_checkpoint_dict_key(k):
    """Rewrite the prefix of state-dict key *k* according to chckpoint_dict_replacements.

    Entries are tried in table order against the (possibly already rewritten) key; a key that
    matches no prefix is returned unchanged.
    """
    for old_prefix, new_prefix in chckpoint_dict_replacements.items():
        if not k.startswith(old_prefix):
            continue
        k = new_prefix + k[len(old_prefix):]

    return k


196
def get_state_dict_from_checkpoint(pl_sd):
    """Unwrap a loaded checkpoint dict and normalise its keys.

    If the dict wraps its weights under a "state_dict" key, that inner dict is used, and any
    further "state_dict" key inside it is dropped. Every remaining key is passed through
    transform_checkpoint_dict_key. The returned object is that (unwrapped) dict itself,
    rewritten in place.
    """
    pl_sd = pl_sd.pop("state_dict", pl_sd)
    pl_sd.pop("state_dict", None)

    transformed = ((transform_checkpoint_dict_key(key), value) for key, value in pl_sd.items())
    renamed = {new_key: value for new_key, value in transformed if new_key is not None}

    pl_sd.clear()
    pl_sd.update(renamed)

    return pl_sd
211 212


213 214 215
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    """Read a checkpoint file and return its normalised state dict.

    .safetensors files are loaded with safetensors onto map_location, else
    shared.weight_load_location, else devices.get_optimal_device_name(). Any other extension
    goes through torch.load with map_location (or shared.weight_load_location). When
    print_global_state is set and the file has a "global_step" entry, that entry is printed.
    """
    extension = os.path.splitext(checkpoint_file)[1]
    if extension.lower() != ".safetensors":
        # NOTE(review): torch.load unpickles the file and can execute code embedded in it;
        # only load .ckpt files from trusted sources.
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
    else:
        target_device = map_location or shared.weight_load_location or devices.get_optimal_device_name()
        pl_sd = safetensors.torch.load_file(checkpoint_file, device=target_device)

    if print_global_state and "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    return get_state_dict_from_checkpoint(pl_sd)


228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    """Return the state dict for *checkpoint_info*, using checkpoints_loaded when possible.

    The short hash is always (re)computed first. On a cache miss the file is read from disk
    and "load weights from disk" is recorded on *timer*. Nothing is added to the cache here.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if checkpoint_info not in checkpoints_loaded:
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
        state_dict = read_state_dict(checkpoint_info.filename)
        timer.record("load weights from disk")
        return state_dict

    # use checkpoint cache
    print(f"Loading weights [{sd_model_hash}] from cache")
    return checkpoints_loaded[checkpoint_info]


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    """Load *checkpoint_info*'s weights into *model* and apply precision and VAE settings.

    Args:
        model: the instantiated model to load into.
        checkpoint_info: checkpoint being loaded; its title becomes the sd_model_checkpoint setting.
        state_dict: weights already loaded by the caller, or None to fetch them with
            get_checkpoint_state_dict (cache or disk).
        timer: Timer that receives one record per step.

    Side effects: updates shared.opts.data, the checkpoints_loaded cache, the devices.dtype*
    and devices.unet_needs_upcast globals, attributes on *model*, and reloads the VAE via sd_vae.
    """
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    # strict=False: keys missing from, or unexpected by, the model do not raise.
    model.load_state_dict(state_dict, strict=False)
    del state_dict
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)
        timer.record("apply channels_last")

    if not shared.cmd_opts.no_half:
        vae = model.first_stage_model
        depth_model = getattr(model, 'depth_model', None)

        # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
        if shared.cmd_opts.no_half_vae:
            model.first_stage_model = None
        # with --upcast-sampling, don't convert the depth model weights to float16
        if shared.cmd_opts.upcast_sampling and depth_model is not None:
            model.depth_model = None

        try:
            model.half()
        finally:
            # Re-attach the detached submodules even if half() fails, so the model is never
            # left without its VAE / depth model.
            model.first_stage_model = vae
            if depth_model is not None:
                model.depth_model = depth_model

        timer.record("apply half()")

    devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
    devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
    devices.dtype_unet = model.model.diffusion_model.dtype
    devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16

    model.first_stage_model.to(devices.dtype_vae)
    timer.record("apply dtype to VAE")

    # clean up cache if limit is reached (oldest entries first)
    while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
        checkpoints_loaded.popitem(last=False)

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_info.filename
    model.sd_checkpoint_info = checkpoint_info
    shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256

    model.logvar = model.logvar.to(devices.device)  # fix for training

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
    sd_vae.load_vae(model, vae_file, vae_source)
    timer.record("load VAE")
307

308

J
Jay Smith 已提交
309 310 311 312 313 314 315 316 317 318
def enable_midas_autodownload():
    """
    Gives the ldm.modules.midas.api.load_model function automatic downloading.

    When the 512-depth-ema model, and other future models like it, is loaded,
    it calls midas.api.load_model to load the associated midas depth model.
    This function applies a wrapper to download the model to the correct
    location automatically.

    Safe to call more than once: the unwrapped loader is captured only on the first call, so
    the wrapper never ends up calling itself.
    """

    midas_path = os.path.join(paths.models_path, 'midas')

    # stable-diffusion-stability-ai hard-codes the midas model path to
    # a location that differs from where other scripts using this model look.
    # HACK: Overriding the path here.
    for k, v in midas.api.ISL_PATHS.items():
        file_name = os.path.basename(v)
        midas.api.ISL_PATHS[k] = os.path.join(midas_path, file_name)

    midas_urls = {
        "dpt_large": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt",
        "dpt_hybrid": "https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt",
        "midas_v21": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21-f6b98070.pt",
        "midas_v21_small": "https://github.com/AlexeyAB/MiDaS/releases/download/midas_dpt/midas_v21_small-70d6b9c8.pt",
    }

    # Capture the original loader only once. On a repeat call midas.api.load_model is already
    # load_model_wrapper; storing that as load_model_inner would make the wrapper recurse forever.
    if not hasattr(midas.api, "load_model_inner"):
        midas.api.load_model_inner = midas.api.load_model

    def load_model_wrapper(model_type):
        path = midas.api.ISL_PATHS[model_type]
        if not os.path.exists(path):
            # makedirs also creates missing parent directories, which mkdir would fail on.
            os.makedirs(midas_path, exist_ok=True)

            print(f"Downloading midas model weights for {model_type} to {path}")
            # Download under a temporary name and move it into place only once complete, so an
            # interrupted download cannot leave a truncated file that the exists() check above
            # would accept next time.
            tmp_path = f"{path}.tmp"
            try:
                request.urlretrieve(midas_urls[model_type], tmp_path)
                os.replace(tmp_path, path)
            finally:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            print(f"{model_type} downloaded")

        return midas.api.load_model_inner(model_type)

    midas.api.load_model = load_model_wrapper

351

352
def repair_config(sd_config):
    """Patch a freshly loaded model config in place before the model is instantiated.

    Defaults model.params.use_ema to False when it is absent. The UNet's use_fp16 flag is set
    to False with --no-half or True with --upcast-sampling (no-half takes precedence);
    otherwise it is left untouched.
    """
    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False

    if shared.cmd_opts.no_half:
        fp16_override = False
    elif shared.cmd_opts.upcast_sampling:
        fp16_override = True
    else:
        return

    sd_config.model.params.unet_config.params.use_fp16 = fp16_override
361

362

363 364 365
# State-dict keys that load_model() checks to decide whether the checkpoint already carries
# text-encoder (CLIP) weights; if either is present, CLIP initialization is disabled while the
# model is created. One key per layout - presumably SD1.x and SD2.x respectively (inferred
# from the names; confirm).
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'

366
def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
    """Build a new model for a checkpoint, load its weights and install it as shared.sd_model.

    Args:
        checkpoint_info: checkpoint to load; defaults to select_checkpoint().
        already_loaded_state_dict: weights already read by the caller; if None they are
            fetched with get_checkpoint_state_dict (cache or disk).
        time_taken_to_load_state_dict: accepted for callers that pass it; currently unused.

    Returns:
        The new model (also stored in shared.sd_model).
    """
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    # Release the current model before building a new one.
    if shared.sd_model:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        shared.sd_model = None
        gc.collect()
        devices.torch_gc()

    do_inpainting_hijack()

    timer = Timer()

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
    else:
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict

    timer.record("find config")

    sd_config = OmegaConf.load(checkpoint_config)
    repair_config(sd_config)

    timer.record("load config")

    print(f"Creating model from config: {checkpoint_config}")

    sd_model = None
    try:
        # Fast path under DisableInitialization (CLIP is included when the checkpoint carries
        # its own text-encoder weights); presumably skips weight init that load_model_weights
        # overwrites anyway.
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
            sd_model = instantiate_from_config(sd_config.model)
    except Exception as e:
        # Not fatal - the slow path below retries - but report the cause instead of
        # swallowing it. Only the repr is kept, so the exception/traceback is not held alive.
        print(f"Fast model creation failed: {e!r}", file=sys.stderr)

    if sd_model is None:
        print('Failed to create model quickly; will retry using slow method.', file=sys.stderr)
        sd_model = instantiate_from_config(sd_config.model)

    sd_model.used_config = checkpoint_config

    timer.record("create model")

    load_model_weights(sd_model, checkpoint_info, state_dict, timer)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)

    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)

    timer.record("hijack")

    sd_model.eval()
    shared.sd_model = sd_model

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

    timer.record("load textual inversion embeddings")

    script_callbacks.model_loaded_callback(sd_model)

    timer.record("scripts callbacks")

    print(f"Model loaded in {timer.summary()}.")

    return sd_model


441
def reload_model_weights(sd_model=None, info=None):
    """Switch the current model to another checkpoint, reusing the model object when possible.

    If there is no current model, or the target checkpoint needs a different config, a fresh
    model is built with load_model(). Otherwise the weights are swapped in place; if that fails,
    the previous checkpoint's weights are restored and the error is re-raised.

    Returns:
        The model holding the requested weights. When *sd_model* already holds them, it is
        returned unchanged.
    """
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if not sd_model:
        sd_model = shared.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            # Already loaded: return the model, consistent with the other exits (was None).
            return sd_model

        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
            lowvram.send_everything_to_cpu()
        else:
            sd_model.to(devices.cpu)

        sd_hijack.model_hijack.undo_hijack(sd_model)

    timer = Timer()

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)

    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        del sd_model
        checkpoints_loaded.clear()
        # get_checkpoint_state_dict only records "load weights from disk" on a cache miss; on a
        # cache hit the key is absent, so indexing it directly would raise KeyError.
        # (Assumes Timer.records is a dict keyed by category name.)
        load_time = timer.records.get("load weights from disk", 0)
        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=load_time)
        return shared.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    except Exception:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info, None, timer)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        timer.record("hijack")

        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

    print(f"Weights loaded in {timer.summary()}.")

    return sd_model