import csv
import datetime
import glob
import html
import os
import sys
import traceback
import inspect

import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_

from collections import defaultdict, deque
from statistics import stdev, mean


optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
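# The dict above maps every optimizer class exposed by torch.optim (e.g. "AdamW", "SGD") to its class;
# the training code below looks up hypernetwork.optimizer_name in it to rebuild the optimizer.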

class HypernetworkModule(torch.nn.Module):
    activation_dict = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})
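    # Besides the explicit entries above, any class from torch.nn.modules.activation can be selected
    # by its lower-cased class name (e.g. "gelu" for torch.nn.GELU), since those classes are merged in here.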

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, activate_output=False, dropout_structure=None):
        super().__init__()

        self.multiplier = 1.0

        assert layer_structure is not None, "layer_structure must not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"

        linears = []
        for i in range(len(layer_structure) - 1):

            # Add a fully-connected layer
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))

            # Add an activation function after each layer, except the last one unless activate_output is set
            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
                pass
            elif activation_func in self.activation_dict:
                linears.append(self.activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

            # Add layer normalization
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

            # By this point, the dropout options have been parsed into dropout_structure, which is applied here.
            # Since dropout only ever follows a layer, the dropout structure starts with 0 and ends with 0.
            if dropout_structure is not None and dropout_structure[i+1] > 0:
                assert 0 < dropout_structure[i+1] < 1, "Dropout probability should be 0 or float between 0 and 1!"
                linears.append(torch.nn.Dropout(p=dropout_structure[i+1]))
            # Example: with layer_structure [1, 2, 1] there is no dropout unless last_layer_dropout is set; [1, 2, 2, 1] maps to [0, 0.3, 0, 0], or [0, 0.3, 0.3, 0] when last_layer_dropout is True.

        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
        else:
            for layer in self.linear:
                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                    w, b = layer.weight.data, layer.bias.data
                    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
                        normal_(w, mean=0.0, std=0.01)
                        normal_(b, mean=0.0, std=0)
                    elif weight_init == 'XavierUniform':
                        xavier_uniform_(w)
                        zeros_(b)
                    elif weight_init == 'XavierNormal':
                        xavier_normal_(w)
                        zeros_(b)
                    elif weight_init == 'KaimingUniform':
                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                        zeros_(b)
                    elif weight_init == 'KaimingNormal':
                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                        zeros_(b)
                    else:
                        raise KeyError(f"Key {weight_init} is not defined as initialization!")
        self.to(devices.device)

    def fix_old_state_dict(self, state_dict):
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }

        for fr, to in changes.items():
            x = state_dict.get(fr, None)
            if x is None:
                continue

            del state_dict[fr]
            state_dict[to] = x

    def forward(self, x):
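        # The module is residual: it returns its input plus a learned delta. During training the
        # delta is applied at full strength; at inference it is scaled by self.multiplier, which
        # Hypernetwork.set_multiplier() sets for each loaded hypernetwork.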
        return x + self.linear(x) * (self.multiplier if not self.training else 1)

    def trainables(self):
        layer_structure = []
        for layer in self.linear:
            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                layer_structure += [layer.weight, layer.bias]
        return layer_structure


# layer_structure: sequence whose length determines the dropout structure; use_dropout: master switch; last_layer_dropout: kept for compatibility with older hypernetworks.
def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout):
    if layer_structure is None:
        layer_structure = [1, 2, 1]
    if not use_dropout:
        return [0] * len(layer_structure)
    dropout_values = [0]
    dropout_values.extend([0.3] * (len(layer_structure) - 3))
    if last_layer_dropout:
        dropout_values.append(0.3)
    else:
        dropout_values.append(0)
    dropout_values.append(0)
    return dropout_values
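# Worked example (illustrative only): parse_dropout_structure([1, 2, 2, 1], True, False) returns
# [0, 0.3, 0, 0], and parse_dropout_structure([1, 2, 2, 1], True, True) returns [0, 0.3, 0.3, 0].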


class Hypernetwork:
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.activation_func = activation_func
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
        self.activate_output = activate_output
        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
        self.dropout_structure = kwargs.get('dropout_structure', None)
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)
        self.optimizer_name = None
        self.optimizer_state_dict = None
        self.optional_info = None

        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.activate_output, dropout_structure=self.dropout_structure),
            )
        self.eval()

    def weights(self):
        res = []
        for k, layers in self.layers.items():
            for layer in layers:
                res += layer.parameters()
        return res

    def train(self, mode=True):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.train(mode=mode)
                for param in layer.parameters():
                    param.requires_grad = mode

    def to(self, device):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.to(device)

        return self

    def set_multiplier(self, multiplier):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.multiplier = multiplier

        return self

    def eval(self):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def save(self, filename):
        state_dict = {}
        optimizer_saved_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['activation_func'] = self.activation_func
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['weight_initialization'] = self.weight_init
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        state_dict['activate_output'] = self.activate_output
        state_dict['use_dropout'] = self.use_dropout
        state_dict['dropout_structure'] = self.dropout_structure
        state_dict['last_layer_dropout'] = (self.dropout_structure[-2] != 0) if self.dropout_structure is not None else self.last_layer_dropout
        state_dict['optional_info'] = self.optional_info if self.optional_info else None

        if self.optimizer_name is not None:
            optimizer_saved_dict['optimizer_name'] = self.optimizer_name

        torch.save(state_dict, filename)
        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
            optimizer_saved_dict['hash'] = self.shorthash()
            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
            torch.save(optimizer_saved_dict, filename + '.optim')

    def load(self, filename):
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        state_dict = torch.load(filename, map_location='cpu')

        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        self.optional_info = state_dict.get('optional_info', None)
        self.activation_func = state_dict.get('activation_func', None)
        self.weight_init = state_dict.get('weight_initialization', 'Normal')
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        self.dropout_structure = state_dict.get('dropout_structure', None)
        self.use_dropout = True if self.dropout_structure is not None and any(self.dropout_structure) else state_dict.get('use_dropout', False)
        self.activate_output = state_dict.get('activate_output', True)
        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)
        # The dropout structure should have the same length as the layer structure; every value should be in [0, 1), and the last value must be 0.
        if self.dropout_structure is None:
            self.dropout_structure = parse_dropout_structure(self.layer_structure, self.use_dropout, self.last_layer_dropout)

        if shared.opts.print_hypernet_extra:
            if self.optional_info is not None:
                print(f"  INFO:\n {self.optional_info}\n")

            print(f"  Layer structure: {self.layer_structure}")
            print(f"  Activation function: {self.activation_func}")
            print(f"  Weight initialization: {self.weight_init}")
            print(f"  Layer norm: {self.add_layer_norm}")
            print(f"  Dropout usage: {self.use_dropout}" )
            print(f"  Activate last layer: {self.activate_output}")
            print(f"  Dropout structure: {self.dropout_structure}")

        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}

        if self.shorthash() == optimizer_saved_dict.get('hash', None):
            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        else:
            self.optimizer_state_dict = None
        if self.optimizer_state_dict:
            self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
            if shared.opts.print_hypernet_extra:
                print("Loaded existing optimizer from checkpoint")
                print(f"Optimizer name is {self.optimizer_name}")
        else:
            self.optimizer_name = "AdamW"
            if shared.opts.print_hypernet_extra:
                print("No saved optimizer exists in checkpoint")

        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.activate_output, self.dropout_structure),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)
        self.eval()

    def shorthash(self):
        sha256 = hashes.sha256(self.filename, f'hypernet/{self.name}')

        return sha256[0:10] if sha256 else None


def list_hypernetworks(path):
    res = {}
    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
        name = os.path.splitext(os.path.basename(filename))[0]
        # Prevent a hypothetical "None.pt" from being listed.
        if name != "None":
            res[name] = filename
    return res


def load_hypernetwork(name):
    path = shared.hypernetworks.get(name, None)

    if path is None:
        return None

    hypernetwork = Hypernetwork()

    try:
        hypernetwork.load(path)
    except Exception:
        print(f"Error loading hypernetwork {path}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return None

    return hypernetwork


def load_hypernetworks(names, multipliers=None):
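    # Reuses any hypernetworks that are already loaded, loads the rest by name from shared.hypernetworks,
    # and applies a per-hypernetwork strength. Hypothetical call: load_hypernetworks(["my_hypernet"], [0.8])
    # would load my_hypernet.pt (if it exists) at 80% strength into shared.loaded_hypernetworks.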
    already_loaded = {}

    for hypernetwork in shared.loaded_hypernetworks:
        if hypernetwork.name in names:
            already_loaded[hypernetwork.name] = hypernetwork

    shared.loaded_hypernetworks.clear()

    for i, name in enumerate(names):
        hypernetwork = already_loaded.get(name, None)
        if hypernetwork is None:
            hypernetwork = load_hypernetwork(name)

        if hypernetwork is None:
            continue

        hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0)
        shared.loaded_hypernetworks.append(hypernetwork)


def find_closest_hypernetwork_name(search: str):
    if not search:
        return None
    search = search.lower()
    applicable = [name for name in shared.hypernetworks if search in name.lower()]
    if not applicable:
        return None
    applicable = sorted(applicable, key=lambda name: len(name))
    return applicable[0]


def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None)

    if hypernetwork_layers is None:
        return context_k, context_v

    if layer is not None:
        layer.hyper_k = hypernetwork_layers[0]
        layer.hyper_v = hypernetwork_layers[1]

    context_k = hypernetwork_layers[0](context_k)
    context_v = hypernetwork_layers[1](context_v)
    return context_k, context_v


def apply_hypernetworks(hypernetworks, context, layer=None):
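    # Chains every loaded hypernetwork: each one transforms the key/value context produced by the
    # previous one before the tensors reach the cross-attention projections below.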
    context_k = context
    context_v = context
    for hypernetwork in hypernetworks:
        context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer)

    return context_k, context_v


def attention_CrossAttention_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self)
    k = self.to_k(context_k)
    v = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)


def stack_conds(conds):
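    # Conditionings with fewer tokens are padded by repeating their last vector so that prompts of
    # different lengths can be stacked into a single batch tensor (mirrors reconstruct_multicond_batch).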
    if len(conds) == 1:
        return torch.stack(conds)

    # same as in reconstruct_multicond_batch
    token_count = max([x.shape[0] for x in conds])
    for i in range(len(conds)):
        if conds[i].shape[0] != token_count:
            last_vector = conds[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
            conds[i] = torch.vstack([conds[i], last_vector_repeated])

    return torch.stack(conds)


def statistics(data):
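    # Formats the mean loss plus/minus its standard error (stdev / sqrt(n)) for the whole history and
    # for the most recent 32 entries.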
    if len(data) < 2:
        std = 0
    else:
        std = stdev(data)
    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std/ (len(data) ** 0.5):.3f})"
    recent_data = data[-32:]
    if len(recent_data) < 2:
        std = 0
    else:
        std = stdev(recent_data)
    recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
    return total_information, recent_information


def report_statistics(loss_info:dict):
    keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
    for key in keys:
        try:
            print("Loss statistics for file " + key)
            info, recent = statistics(list(loss_info[key]))
            print(info)
            print(recent)
        except Exception as e:
            print(e)


def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None):
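    # Builds and saves a freshly initialized hypernetwork .pt file. Hypothetical call with UI-style arguments:
    # create_hypernetwork("my_hypernet", ["320", "640", "768", "1280"], False, layer_structure="1, 2, 1",
    # activation_func="relu", weight_init="Normal") would write my_hypernet.pt to shared.cmd_opts.hypernetwork_dir.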
    # Remove illegal characters from name.
    name = "".join( x for x in name if (x.isalnum() or x in "._- "))
    assert name, "Name cannot be empty!"

    fn = os.path.join(shared.cmd_opts.hypernetwork_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    if type(layer_structure) == str:
        layer_structure = [float(x.strip()) for x in layer_structure.split(",")]

    if use_dropout and dropout_structure and type(dropout_structure) == str:
        dropout_structure = [float(x.strip()) for x in dropout_structure.split(",")]
    else:
        dropout_structure = [0] * len(layer_structure)

    hypernet = modules.hypernetworks.hypernetwork.Hypernetwork(
        name=name,
        enable_sizes=[int(x) for x in enable_sizes],
        layer_structure=layer_structure,
        activation_func=activation_func,
        weight_init=weight_init,
        add_layer_norm=add_layer_norm,
        use_dropout=use_dropout,
        dropout_structure=dropout_structure
    )
    hypernet.save(fn)

    shared.reload_hypernetworks()


def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    template_file = textual_inversion.textual_inversion_templates.get(template_filename, None)
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
    template_file = template_file.path

    path = shared.hypernetworks.get(hypernetwork_name, None)
    hypernetwork = Hypernetwork()
    hypernetwork.load(path)
    shared.loaded_hypernetworks = [hypernetwork]

    shared.state.job = "train-hypernetwork"
    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    checkpoint = sd_models.select_checkpoint()

    initial_step = hypernetwork.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)

    if shared.opts.save_training_settings_to_txt:
        saved_params = dict(
            model_name=checkpoint.model_name, model_hash=checkpoint.shorthash, num_of_dataset_images=len(ds),
            **{field: getattr(hypernetwork, field) for field in ['layer_structure', 'activation_func', 'weight_init', 'add_layer_norm', 'use_dropout', ]}
        )
        logging.save_settings_to_file(log_directory, {**saved_params, **locals()})

    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    old_parallel_processing_allowed = shared.parallel_processing_allowed

    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    weights = hypernetwork.weights()
    hypernetwork.train()

    # Use the optimizer type stored in the saved hypernetwork; this could also be exposed as a UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
        optimizer_name = hypernetwork.optimizer_name
apply  
aria1th 已提交
584
    else:
585
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
A
apply  
aria1th 已提交
586 587
        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
        optimizer_name = 'AdamW'
588

A
apply  
aria1th 已提交
589 590 591 592 593 594
    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
        except RuntimeError as e:
            print("Cannot resume from saved optimizer!")
            print(e)
A
AUTOMATIC 已提交
595

596 597 598 599 600 601 602 603 604 605 606
    scaler = torch.cuda.amp.GradScaler()
    
    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0 #internal
    # size = len(ds.indexes)
    # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
A
aria1th 已提交
607
    loss_logging = deque(maxlen=len(ds) * 3)  # this should be configurable parameter, this is 3 * epoch(dataset size)
608 609 610 611
    # losses = torch.zeros((size,))
    # previous_mean_losses = [0]
    # previous_mean_loss = 0
    # print("Mean loss of {} elements".format(size))
A
AUTOMATIC 已提交
612

613 614
    steps_without_grad = 0

615 616 617 618
    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"

619 620
    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
A
AUTOMATIC 已提交
621 622
        sd_hijack_checkpoint.add()

623 624 625 626 627 628 629 630 631 632 633 634 635 636 637
        for i in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, hypernetwork.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

638 639
                if clip_grad:
                    clip_grad_sched.step(hypernetwork.step)
A
AngelBottomless 已提交
640
                
641
                with devices.autocast():
642
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
S
Shondoit 已提交
643
                    w = batch.weight.to(devices.device, non_blocking=pin_memory)
644 645 646 647 648 649
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
S
Shondoit 已提交
650
                    loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
651 652
                    del x
                    del c
A
aria1th 已提交
653

654 655
                    _loss_step += loss.item()
                scaler.scale(loss).backward()
656
                
657 658 659
                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
A
aria1th 已提交
660
                loss_logging.append(_loss_step)
661 662 663
                if clip_grad:
                    clip_grad(weights, clip_grad_sched.learn_rate)
                
664 665 666 667 668 669 670 671 672
                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = hypernetwork.step + 1
A
AngelBottomless 已提交
673
                
674 675 676
                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

V
Vladimir Mandic 已提交
677 678
                description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}"
                pbar.set_description(description)
679 680 681 682 683 684 685 686 687 688
                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
                    hypernetwork.optimizer_name = optimizer_name
                    if shared.opts.save_optimizer_state:
                        hypernetwork.optimizer_state_dict = optimizer.state_dict()
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

689 690 691 692 693


                if shared.opts.training_enable_tensorboard:
                    epoch_num = hypernetwork.step // len(ds)
                    epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
A
aria1th 已提交
694
                    mean_loss = sum(loss_logging) / len(loss_logging)
695 696
                    textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)

697 698 699 700 701 702 703 704
                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)
A
aria1th 已提交
705 706 707 708 709
                    hypernetwork.eval()
                    rng_state = torch.get_rng_state()
                    cuda_rng_state = None
                    if torch.cuda.is_available():
                        cuda_rng_state = torch.cuda.get_rng_state_all()
710 711 712 713 714 715 716 717 718
                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                    )

A
AUTOMATIC 已提交
719 720
                    p.disable_extra_networks = True

721 722 723 724 725 726 727 728 729 730 731 732 733 734
                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height
A
aria1th 已提交
735

736
                    preview_text = p.prompt
A
aria1th 已提交
737

738 739
                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None
A
aria1th 已提交
740

741 742 743
                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)
A
aria1th 已提交
744 745 746 747
                    torch.set_rng_state(rng_state)
                    if torch.cuda.is_available():
                        torch.cuda.set_rng_state_all(cuda_rng_state)
                    hypernetwork.train()
748
                    if image is not None:
749
                        shared.state.assign_current_image(image)
A
aria1th 已提交
750 751 752 753
                        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                            textual_inversion.tensorboard_add_image(tensorboard_writer,
                                                                    f"Validation at epoch {epoch_num}", image,
                                                                    hypernetwork.step)
754 755
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"
A
AUTOMATIC 已提交
756

757
                shared.state.job_no = hypernetwork.step
A
AUTOMATIC 已提交
758

759
                shared.state.textinfo = f"""
A
AUTOMATIC 已提交
760
<p>
761
Loss: {loss_step:.7f}<br/>
F
flamelaw 已提交
762
Step: {steps_done}<br/>
763
Last prompt: {html.escape(batch.cond_text[0])}<br/>
D
DepFA 已提交
764
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
A
AUTOMATIC 已提交
765 766 767
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
768 769 770 771 772
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
A
aria1th 已提交
773
        hypernetwork.eval()
774
        #report_statistics(loss_dict)
A
AUTOMATIC 已提交
775 776 777
        sd_hijack_checkpoint.remove()


A
AUTOMATIC 已提交
778

779
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
A
apply  
aria1th 已提交
780 781 782
    hypernetwork.optimizer_name = optimizer_name
    if shared.opts.save_optimizer_state:
        hypernetwork.optimizer_state_dict = optimizer.state_dict()
783
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
A
AUTOMATIC 已提交
784

A
apply  
aria1th 已提交
785 786
    del optimizer
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
787 788
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)
789
    shared.parallel_processing_allowed = old_parallel_processing_allowed
A
AUTOMATIC 已提交
790 791 792

    return hypernetwork, filename

793 794 795 796 797
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
    old_hypernetwork_name = hypernetwork.name
    old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
    old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
    try:
A
AUTOMATIC 已提交
798
        hypernetwork.sd_checkpoint = checkpoint.shorthash
799 800 801 802 803 804 805 806
        hypernetwork.sd_checkpoint_name = checkpoint.model_name
        hypernetwork.name = hypernetwork_name
        hypernetwork.save(filename)
    except:
        hypernetwork.sd_checkpoint = old_sd_checkpoint
        hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
        hypernetwork.name = old_hypernetwork_name
        raise