modelloader.py 6.1 KB
Newer Older
D
d8ahazard 已提交
1
import glob
D
d8ahazard 已提交
2
import os
D
d8ahazard 已提交
3
import shutil
D
d8ahazard 已提交
4
import importlib
D
d8ahazard 已提交
5 6
from urllib.parse import urlparse

D
d8ahazard 已提交
7
from modules import shared
8
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone
D
d8ahazard 已提交
9 10
from modules.paths import script_path, models_path

D
d8ahazard 已提交
11

12
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
    """
    A one-and-done loader to try finding the desired models in specified directories.

    @param download_name: Specify to download from model_url immediately.
    @param model_url: If no other models are found, this will be downloaded on upscale.
    @param model_path: The location to store/find models in.
    @param command_path: A command-line argument to search for models in first.
    @param ext_filter: An optional list of filename extensions to filter by
    @param ext_blacklist: An optional list of filename suffixes; matching paths are skipped
    @return: A list of paths containing the desired model(s). Lookup is best-effort: if an error
             occurs it is reported and whatever was collected so far is returned.
    """
    output = []
    # Mirrors `output` so the duplicate check is O(1) instead of a scan of the list.
    seen = set()

    try:
        # str.endswith accepts a tuple, checking every suffix in one call.
        blacklist = tuple(ext_blacklist or ())
        places = []

        if command_path is not None and command_path != model_path:
            pretrained_path = os.path.join(command_path, 'experiments/pretrained_models')
            if os.path.exists(pretrained_path):
                print(f"Appending path: {pretrained_path}")
                places.append(pretrained_path)
            elif os.path.exists(command_path):
                places.append(command_path)

        places.append(model_path)

        for place in places:
            for full_path in shared.walk_files(place, allowed_extensions=ext_filter):
                if os.path.islink(full_path) and not os.path.exists(full_path):
                    print(f"Skipping broken symlink: {full_path}")
                    continue
                if blacklist and full_path.endswith(blacklist):
                    continue
                if full_path not in seen:
                    seen.add(full_path)
                    output.append(full_path)

        # Fall back to the remote model only when nothing was found locally.
        if model_url is not None and len(output) == 0:
            if download_name is not None:
                from basicsr.utils.download_util import load_file_from_url
                dl = load_file_from_url(model_url, model_path, True, download_name)
                output.append(dl)
            else:
                # Defer the download: callers receive the URL itself.
                output.append(model_url)

    except Exception as e:
        # Best-effort: keep returning partial results, but no longer hide the failure.
        print(f"Error loading models from {model_path}: {e}")

    return output
D
d8ahazard 已提交
60 61 62 63 64 65 66 67 68


def friendly_name(file: str):
    """
    Return the base name, without its final extension, of a model path or URL.

    For http(s) URLs only the URL's path component is used, so query strings and
    fragments (e.g. "?download=true") do not end up in the name.

    @param file: A local file path or an http(s) URL.
    @return: The file name with its final extension removed.
    """
    parsed = urlparse(file)
    # Test the scheme rather than searching for "http" anywhere in the string: a local
    # path that merely contains "http" must not be URL-parsed, because the parsed path
    # would be cut off at any "?" or "#" in the file name.
    if parsed.scheme in ("http", "https"):
        file = parsed.path

    file = os.path.basename(file)
    model_name, _ = os.path.splitext(file)
    return model_name
D
d8ahazard 已提交
69 70 71


def cleanup_models():
    """
    Move model files from the source folders listed below into their folders under models_path.

    Each table entry is (source folder, destination folder, filename filter or None).
    The entries are handed to move_files one at a time, in table order.
    """
    root_path = script_path
    sd_folder = os.path.join(models_path, "Stable-diffusion")
    esrgan_folder = os.path.join(models_path, "ESRGAN")

    relocations = (
        (models_path, sd_folder, ".ckpt"),
        (models_path, sd_folder, ".safetensors"),
        (os.path.join(root_path, "ESRGAN"), esrgan_folder, None),
        (os.path.join(models_path, "BSRGAN"), esrgan_folder, ".pth"),
        (os.path.join(root_path, "gfpgan"), os.path.join(models_path, "GFPGAN"), None),
        (os.path.join(root_path, "SwinIR"), os.path.join(models_path, "SwinIR"), None),
        (
            os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/"),
            os.path.join(models_path, "LDSR"),
            None,
        ),
    )

    for source_dir, target_dir, suffix in relocations:
        move_files(source_dir, target_dir, suffix)


97
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
    """
    Move the regular files directly inside src_path into dest_path, then delete src_path if it ends up empty.

    dest_path is created when missing, even if src_path does not exist. Sub-directories of src_path
    are not descended into. This is best-effort: failures are printed, not raised.

    @param src_path: Folder to move files out of.
    @param dest_path: Folder to move files into.
    @param ext_filter: If given, only files whose names end with this suffix (e.g. ".ckpt") are moved.
    """
    try:
        os.makedirs(dest_path, exist_ok=True)
        if not os.path.exists(src_path):
            return

        for file in os.listdir(src_path):
            fullpath = os.path.join(src_path, file)
            if not os.path.isfile(fullpath):
                continue
            # Suffix match: a substring test would also move e.g. "model.ckpt.bak" for ".ckpt".
            if ext_filter is not None and not file.endswith(ext_filter):
                continue
            print(f"Moving {file} from {src_path} to {dest_path}.")
            try:
                shutil.move(fullpath, dest_path)
            except Exception as e:
                # e.g. a file with the same name already exists in dest_path; skip it, keep going.
                print(f"Failed to move {file} to {dest_path}: {e}")

        if len(os.listdir(src_path)) == 0:
            print(f"Removing empty folder: {src_path}")
            shutil.rmtree(src_path, ignore_errors=True)
    except Exception as e:
        print(f"Error moving files from {src_path} to {dest_path}: {e}")


def load_upscalers():
    """
    Discover every Upscaler subclass, instantiate it, and publish the combined scaler list.

    Each "*_model.py" file in the modules folder is imported first so that its Upscaler
    subclasses are registered. Each distinct subclass is then constructed with its matching
    "<name>_models_path" command-line option (None when absent). The result is stored in
    shared.sd_upscalers: None/Lanczos/Nearest scalers first, the rest sorted by lower-cased name.
    """
    # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
    # so we'll try to import any _model.py files before looking in __subclasses__
    suffix = "_model.py"
    modules_dir = os.path.join(shared.script_path, "modules")
    for file in os.listdir(modules_dir):
        # Suffix match only: a substring test would also pick up names like "x_model.py.bak"
        # and build a bogus module path from them.
        if not file.endswith(suffix):
            continue
        model_name = file[:-len(suffix)]
        full_model = f"modules.{model_name}_model"
        try:
            importlib.import_module(full_model)
        except Exception as e:
            # A broken upscaler module should not stop the others from loading, but say why it is missing.
            print(f"Error importing upscaler module {full_model}: {e}")

    datas = []
    commandline_options = vars(shared.cmd_opts)

    # some of upscaler classes will not go away after reloading their modules, and we'll end
    # up with two copies of those classes. The newest copy will always be the last in the list,
    # so we go from end to beginning and ignore duplicates
    used_classes = {}
    for cls in reversed(Upscaler.__subclasses__()):
        classname = str(cls)
        if classname not in used_classes:
            used_classes[classname] = cls

    for cls in reversed(used_classes.values()):
        name = cls.__name__
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        scaler = cls(commandline_options.get(cmd_name, None))
        datas += scaler.scalers

    shared.sd_upscalers = sorted(
        datas,
        # Special case: None, Lanczos and Nearest scalers get an empty key so they stay at the
        # beginning of the list (sorted() is stable, so their relative order is preserved).
        key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else ""
    )