sd_models.py
import glob
import hashlib
import os.path
import sys
from collections import namedtuple

import torch
from omegaconf import OmegaConf


from ldm.util import instantiate_from_config

from modules import shared

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
checkpoints_list = {}

try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.

    from transformers import logging

    logging.set_verbosity_error()
except Exception:
    pass


def list_models():
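    """Rebuild checkpoints_list from the --ckpt argument and any *.ckpt files under --ckpt-dir."""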
    checkpoints_list.clear()

    model_dir = os.path.abspath(shared.cmd_opts.ckpt_dir)

    def modeltitle(path, h):
        abspath = os.path.abspath(path)

        if abspath.startswith(model_dir):
            name = abspath.replace(model_dir, '')
        else:
            name = os.path.basename(path)

        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]

        return f'{name} [{h}]'

    cmd_ckpt = shared.cmd_opts.ckpt
    if cmd_ckpt is not None and os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        title = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)

    if os.path.exists(model_dir):
        for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
            h = model_hash(filename)
            title = modeltitle(filename, h)
            checkpoints_list[title] = CheckpointInfo(filename, title, h)

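# Illustrative shape of an entry after list_models() runs (path and hash are
# made up, not real files):
#   checkpoints_list['sd-v1-4.ckpt [abcd1234]'] == CheckpointInfo(
#       filename='models/sd-v1-4.ckpt', title='sd-v1-4.ckpt [abcd1234]', hash='abcd1234')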

def model_hash(filename):
    """Return the short hash shown in checkpoint titles: the first 8 hex digits
    of SHA-256 over the 64 KiB of the file starting at offset 0x100000."""
    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()

            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return 'NOFILE'

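# Hedged example (hypothetical path): model_hash('models/model.ckpt') returns
# something like 'abcd1234'. For files shorter than 1 MiB the read past EOF
# yields b'', so the result is the first 8 hex digits of SHA-256 of the empty
# string: 'e3b0c442'.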

def select_checkpoint():
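    """Return the CheckpointInfo chosen in settings, falling back to the first known checkpoint; None if there are none."""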
    model_checkpoint = shared.opts.sd_model_checkpoint
    checkpoint_info = checkpoints_list.get(model_checkpoint, None)
    if checkpoint_info is not None:
        return checkpoint_info

    if len(checkpoints_list) == 0:
        print(f"Checkpoint {model_checkpoint} not found and no other checkpoints found", file=sys.stderr)
        return None

    checkpoint_info = next(iter(checkpoints_list.values()))
    if model_checkpoint is not None:
        print(f"Checkpoint {model_checkpoint} not found; loading fallback {checkpoint_info.title}", file=sys.stderr)

    return checkpoint_info


def load_model_weights(model, checkpoint_file, sd_model_hash):
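    """Load a checkpoint's state_dict into model and record the hash and path on the model object."""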
    print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")

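    # A Lightning-style .ckpt is a dict with at least 'state_dict' (parameter
    # name -> tensor); map_location='cpu' keeps tensors on the CPU so the move
    # to the GPU (possibly in half precision) happens later.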
    pl_sd = torch.load(checkpoint_file, map_location="cpu")
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")
    sd = pl_sd["state_dict"]

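    # strict=False tolerates missing or unexpected keys, so checkpoints that
    # carry extra entries (e.g. EMA weights) still load everything that matches.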
    model.load_state_dict(sd, strict=False)

    if shared.cmd_opts.opt_channelslast:
        model.to(memory_format=torch.channels_last)

    if not shared.cmd_opts.no_half:
        model.half()

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_file


def load_model():
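    """Instantiate the Stable Diffusion model from the OmegaConf config and load the selected checkpoint's weights."""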
    from modules import lowvram, sd_hijack
    checkpoint_info = select_checkpoint()
    if checkpoint_info is None:
        raise RuntimeError("No Stable Diffusion checkpoint found to load")

    sd_config = OmegaConf.load(shared.cmd_opts.config)
    sd_model = instantiate_from_config(sd_config.model)
    load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)

    sd_hijack.model_hijack.hijack(sd_model)

    sd_model.eval()

    print(f"Model loaded.")
    return sd_model

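# Typical startup flow (a sketch; assumes shared.cmd_opts has been parsed by
# the launcher): shared.sd_model = load_model(); changing the checkpoint in
# settings later calls reload_model_weights(shared.sd_model) below.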

def reload_model_weights(sd_model, info=None):
    """Swap an already-built model's weights for those of another checkpoint, handling low-VRAM device moves."""
    from modules import lowvram, devices
    checkpoint_info = info or select_checkpoint()

    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        return

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        sd_model.to(devices.cpu)

    load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print(f"Weights loaded.")
    return sd_model
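

# A minimal smoke test (not part of the original module; assumes shared.cmd_opts
# and shared.opts were initialized by the webui launcher, which may not hold
# when running this file standalone):
if __name__ == "__main__":
    list_models()
    for title, info in checkpoints_list.items():
        print(title, "->", info.filename)
    picked = select_checkpoint()
    print("would load:", picked.title if picked is not None else None)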