extra_networks_lora.py 2.3 KB
Newer Older
1
from modules import extra_networks, shared
2
import networks
A
AUTOMATIC 已提交
3

4

A
AUTOMATIC 已提交
5 6 7 8 9
class ExtraNetworkLora(extra_networks.ExtraNetwork):
    """Extra-network handler registered under the name ``lora``.

    Collects the parameters of every ``<lora:...>`` entry for a generation,
    hands them to ``networks.load_networks`` in one call and, when enabled,
    records the short hashes of the loaded networks in the generation infotext.
    """

    def __init__(self):
        super().__init__('lora')

    def activate(self, p, params_list):
        """Load the LoRA networks requested for this generation.

        Args:
            p: the processing object. ``p.all_prompts`` is rewritten when the
                configured default LoRA is injected, and
                ``p.extra_generation_params["Lora hashes"]`` is set when hash
                reporting is enabled and at least one hash is available.
            params_list: list of ``extra_networks.ExtraNetworkParams``. It is
                mutated in place when the default LoRA is appended. For each
                entry the positional arguments are ``name[, te[, unet[, dyn]]]``,
                and the named keys ``te``, ``unet`` and ``dyn`` override the
                positional values. The text-encoder multiplier defaults to 1.0.
                The UNet multiplier defaults to the (possibly overridden)
                text-encoder multiplier. ``dyn`` defaults to ``None``.

        Raises:
            AssertionError: if an entry in ``params_list`` has no items.
            ValueError: if a multiplier or ``dyn`` value cannot be converted
                with ``float()`` / ``int()``.
        """
        default_lora = shared.opts.sd_lora

        # Inject the user-configured default LoRA unless the prompt already names it.
        # The "None" string is the setting's "disabled" value. items[0] is guarded so
        # that an empty entry reaches the assertion below instead of raising IndexError.
        if (
            default_lora != "None"
            and default_lora in networks.available_networks
            and not any(bool(params.items) and params.items[0] == default_lora for params in params_list)
        ):
            default_multiplier = shared.opts.extra_networks_default_multiplier
            p.all_prompts = [prompt + f"<lora:{default_lora}:{default_multiplier}>" for prompt in p.all_prompts]
            params_list.append(extra_networks.ExtraNetworkParams(items=[default_lora, default_multiplier]))

        names = []
        te_multipliers = []
        unet_multipliers = []
        dyn_dims = []

        for params in params_list:
            assert params.items, "LoRA extra-network entry has no arguments"

            positional = params.positional
            named = params.named

            names.append(positional[0])

            te_multiplier = float(positional[1]) if len(positional) > 1 else 1.0
            te_multiplier = float(named.get("te", te_multiplier))

            # The UNet multiplier falls back to the text-encoder multiplier when not given.
            unet_multiplier = float(positional[2]) if len(positional) > 2 else te_multiplier
            unet_multiplier = float(named.get("unet", unet_multiplier))

            dyn_dim = int(positional[3]) if len(positional) > 3 else None
            dyn_dim = int(named["dyn"]) if "dyn" in named else dyn_dim

            te_multipliers.append(te_multiplier)
            unet_multipliers.append(unet_multiplier)
            dyn_dims.append(dyn_dim)

        networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)

        if shared.opts.lora_add_hashes_to_infotext:
            network_hashes = []
            for item in networks.loaded_networks:
                # Networks without a known short hash or without the name used in
                # the prompt cannot be reported, so they are skipped.
                shorthash = item.network_on_disk.shorthash
                if not shorthash:
                    continue

                alias = item.mentioned_name
                if not alias:
                    continue

                # ':' and ',' are the separators of the "alias: hash, ..." value
                # built below, so they are stripped from the alias.
                alias = alias.replace(":", "").replace(",", "")

                network_hashes.append(f"{alias}: {shorthash}")

            if network_hashes:
                p.extra_generation_params["Lora hashes"] = ", ".join(network_hashes)

    def deactivate(self, p):
        """No per-generation cleanup is performed for LoRA networks here."""
        pass