import csv
import datetime
import glob
import html
import os
import sys
import traceback
import inspect

import modules.textual_inversion.dataset
import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
from modules import devices, processing, sd_models, shared, sd_samplers
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_

from collections import defaultdict, deque
from statistics import stdev, mean


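# Every optimizer class exposed by torch.optim (except the abstract Optimizer base),
# keyed by name, so a hypernetwork can record which optimizer it was trained with and
# training can recreate it from that name.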
optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}

class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0
    activation_dict = {
        "linear": torch.nn.Identity,
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
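    # Also expose every activation class from torch.nn.modules.activation under its
    # lowercase class name, so activations not listed above remain selectable.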
    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
        super().__init__()

        assert layer_structure is not None, "layer_structure must not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"

        linears = []
        for i in range(len(layer_structure) - 1):

            # Add a fully-connected layer
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))

            # Add an activation func except last layer
            if activation_func == "linear" or activation_func is None or (i >= len(layer_structure) - 2 and not activate_output):
                pass
            elif activation_func in self.activation_dict:
                linears.append(self.activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

            # Add layer normalization
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

            # Add dropout except last layer
            if use_dropout and (i < len(layer_structure) - 3 or last_layer_dropout and i < len(layer_structure) - 2):
                linears.append(torch.nn.Dropout(p=0.3))

        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
        else:
            for layer in self.linear:
                if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                    w, b = layer.weight.data, layer.bias.data
                    if weight_init == "Normal" or type(layer) == torch.nn.LayerNorm:
                        normal_(w, mean=0.0, std=0.01)
                        normal_(b, mean=0.0, std=0)
                    elif weight_init == 'XavierUniform':
                        xavier_uniform_(w)
                        zeros_(b)
                    elif weight_init == 'XavierNormal':
                        xavier_normal_(w)
                        zeros_(b)
                    elif weight_init == 'KaimingUniform':
                        kaiming_uniform_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                        zeros_(b)
                    elif weight_init == 'KaimingNormal':
                        kaiming_normal_(w, nonlinearity='leaky_relu' if 'leakyrelu' == activation_func else 'relu')
                        zeros_(b)
                    else:
                        raise KeyError(f"Key {weight_init} is not defined as initialization!")
        self.to(devices.device)

    def fix_old_state_dict(self, state_dict):
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }

        for fr, to in changes.items():
            x = state_dict.get(fr, None)
            if x is None:
                continue

            del state_dict[fr]
            state_dict[to] = x

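    # Residual formulation: the MLP output is scaled by the class-level multiplier
    # (hypernetwork strength) and added to the input, so a multiplier of 0 is a no-op.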
    def forward(self, x):
        return x + self.linear(x) * self.multiplier

    def trainables(self):
        layer_structure = []
        for layer in self.linear:
            if type(layer) == torch.nn.Linear or type(layer) == torch.nn.LayerNorm:
                layer_structure += [layer.weight, layer.bias]
        return layer_structure


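# Set the strength multiplier used by HypernetworkModule.forward; falls back to
# shared.opts.sd_hypernetwork_strength when no explicit value is given.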
def apply_strength(value=None):
    HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength


class Hypernetwork:
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, activate_output=False, **kwargs):
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.activation_func = activation_func
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout
        self.activate_output = activate_output
        self.last_layer_dropout = kwargs.get('last_layer_dropout', True)
        self.optimizer_name = None
        self.optimizer_state_dict = None

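        # One pair of modules per enabled context size: index 0 transforms the
        # cross-attention keys, index 1 the values (see apply_hypernetwork below).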
        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
            )

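    # Collect every trainable tensor from all module pairs; train() is called so
    # dropout stays active while training.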
    def weights(self):
        res = []

        for k, layers in self.layers.items():
            for layer in layers:
                layer.train()
                res += layer.trainables()

        return res

    def save(self, filename):
        state_dict = {}
        optimizer_saved_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['activation_func'] = self.activation_func
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['weight_initialization'] = self.weight_init
        state_dict['use_dropout'] = self.use_dropout
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name
        state_dict['activate_output'] = self.activate_output
        state_dict['last_layer_dropout'] = self.last_layer_dropout

        if self.optimizer_name is not None:
            optimizer_saved_dict['optimizer_name'] = self.optimizer_name

        torch.save(state_dict, filename)
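        # Optimizer state goes into a sidecar file (<filename>.optim) tagged with the
        # .pt file's hash, so load() can verify it belongs to this exact file.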
        if shared.opts.save_optimizer_state and self.optimizer_state_dict:
            optimizer_saved_dict['hash'] = sd_models.model_hash(filename)
            optimizer_saved_dict['optimizer_state_dict'] = self.optimizer_state_dict
            torch.save(optimizer_saved_dict, filename + '.optim')

    def load(self, filename):
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        state_dict = torch.load(filename, map_location='cpu')

        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        print(f"Layer structure is {self.layer_structure}")
        self.activation_func = state_dict.get('activation_func', None)
        print(f"Activation function is {self.activation_func}")
        self.weight_init = state_dict.get('weight_initialization', 'Normal')
        print(f"Weight initialization is {self.weight_init}")
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        print(f"Layer norm is set to {self.add_layer_norm}")
        self.use_dropout = state_dict.get('use_dropout', False)
        print(f"Dropout usage is set to {self.use_dropout}")
        self.activate_output = state_dict.get('activate_output', True)
        print(f"Activate last layer is set to {self.activate_output}")
        self.last_layer_dropout = state_dict.get('last_layer_dropout', False)

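        # Restore optimizer state from the .optim sidecar only if the hash recorded in
        # it matches the current .pt file; otherwise discard it.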
        optimizer_saved_dict = torch.load(self.filename + '.optim', map_location='cpu') if os.path.exists(self.filename + '.optim') else {}
        self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW')
        print(f"Optimizer name is {self.optimizer_name}")
        if sd_models.model_hash(filename) == optimizer_saved_dict.get('hash', None):
            self.optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
        else:
            self.optimizer_state_dict = None
        if self.optimizer_state_dict:
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")

        for size, sd in state_dict.items():
            if type(size) == int:
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init,
                                       self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)


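# Recursively collect hypernetwork .pt files under `path`, keyed as "name(hash)" using
# sd_models.model_hash of each file.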
def list_hypernetworks(path):
    res = {}
    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
        name = os.path.splitext(os.path.basename(filename))[0]
        # Prevent a hypothetical "None.pt" from being listed.
        if name != "None":
            res[name + f"({sd_models.model_hash(filename)})"] = filename
    return res


def load_hypernetwork(filename):
    path = shared.hypernetworks.get(filename, None)
    # Prevent any file named "None.pt" from being loaded.
    if path is not None and filename != "None":
        print(f"Loading hypernetwork {filename}")
        try:
            shared.loaded_hypernetwork = Hypernetwork()
            shared.loaded_hypernetwork.load(path)

        except Exception:
            print(f"Error loading hypernetwork {path}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
    else:
        if shared.loaded_hypernetwork is not None:
            print("Unloading hypernetwork")

        shared.loaded_hypernetwork = None


def find_closest_hypernetwork_name(search: str):
    if not search:
        return None
    search = search.lower()
    applicable = [name for name in shared.hypernetworks if search in name.lower()]
    if not applicable:
        return None
    applicable = sorted(applicable, key=lambda name: len(name))
    return applicable[0]


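# Select the module pair whose key matches the context's embedding width (context.shape[2])
# and return the transformed (key, value) contexts for cross-attention. If no pair exists
# for this width, the context passes through unchanged. When `layer` is given, the pair is
# also attached to it as hyper_k / hyper_v.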
def apply_hypernetwork(hypernetwork, context, layer=None):
    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

    if hypernetwork_layers is None:
        return context, context

    if layer is not None:
        layer.hyper_k = hypernetwork_layers[0]
        layer.hyper_v = hypernetwork_layers[1]

    context_k = hypernetwork_layers[0](context)
    context_v = hypernetwork_layers[1](context)
    return context_k, context_v


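# Drop-in replacement for ldm's CrossAttention.forward that routes the context through the
# loaded hypernetwork before the key/value projections; everything else follows the stock
# attention computation.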
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
    k = self.to_k(context_k)
    v = self.to_v(context_v)

    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)


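# Conditionings within a batch can have different token counts; pad the shorter ones by
# repeating their last vector so they can be stacked into a single tensor.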
def stack_conds(conds):
    if len(conds) == 1:
        return torch.stack(conds)

    # same as in reconstruct_multicond_batch
    token_count = max([x.shape[0] for x in conds])
    for i in range(len(conds)):
        if conds[i].shape[0] != token_count:
            last_vector = conds[i][-1:]
            last_vector_repeated = last_vector.repeat([token_count - conds[i].shape[0], 1])
            conds[i] = torch.vstack([conds[i], last_vector_repeated])

    return torch.stack(conds)


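# Return "mean±(standard error)" summaries of the loss, for the full history and for the
# 32 most recent values.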
def statistics(data):
    if len(data) < 2:
        std = 0
    else:
        std = stdev(data)
    total_information = f"loss:{mean(data):.3f}" + u"\u00B1" + f"({std / (len(data) ** 0.5):.3f})"
    recent_data = data[-32:]
    if len(recent_data) < 2:
        std = 0
    else:
        std = stdev(recent_data)
    recent_information = f"recent 32 loss:{mean(recent_data):.3f}" + u"\u00B1" + f"({std / (len(recent_data) ** 0.5):.3f})"
    return total_information, recent_information


def report_statistics(loss_info: dict):
    keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
    for key in keys:
        try:
            print("Loss statistics for file " + key)
            info, recent = statistics(list(loss_info[key]))
            print(info)
            print(recent)
        except Exception as e:
            print(e)



def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")

    path = shared.hypernetworks.get(hypernetwork_name, None)
    shared.loaded_hypernetwork = Hypernetwork()
    shared.loaded_hypernetwork.load(path)

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    hypernetwork_name = hypernetwork_name.rsplit('(', 1)[0]
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    hypernetwork = shared.loaded_hypernetwork
    checkpoint = sd_models.select_checkpoint()

    initial_step = hypernetwork.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
    
    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)
    
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    # Here we use optimizer from saved HN, or we can specify as UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
        optimizer_name = hypernetwork.optimizer_name
    else:
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
        optimizer_name = 'AdamW'

    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)
        except RuntimeError as e:
            print("Cannot resume from saved optimizer!")
            print(e)

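    # Mixed-precision training: forward passes run under autocast and gradients go through
    # GradScaler. The optimizer is stepped once per `gradient_step` micro-batches, so each
    # micro-batch loss below is divided by gradient_step.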
    scaler = torch.cuda.amp.GradScaler()
    
    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0 #internal
    # size = len(ds.indexes)
    # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
    # losses = torch.zeros((size,))
    # previous_mean_losses = [0]
    # previous_mean_loss = 0
    # print("Mean loss of {} elements".format(size))

    steps_without_grad = 0

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        for i in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, hypernetwork.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                with torch.autocast("cuda"):
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                    loss = shared.sd_model(x, c)[0] / gradient_step
                    del x
                    del c

                    _loss_step += loss.item()
                scaler.scale(loss).backward()
                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
                # scaler.unscale_(optimizer)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = hypernetwork.step + 1
                
                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
                    last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
                    hypernetwork.optimizer_name = optimizer_name
                    if shared.opts.save_optimizer_state:
                        hypernetwork.optimizer_state_dict = optimizer.state_dict()
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                    )

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    if image is not None:
                        shared.state.current_image = image
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                shared.state.job_no = hypernetwork.step

                shared.state.textinfo = f"""
<p>
Loss: {loss_step:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        #report_statistics(loss_dict)

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
    hypernetwork.optimizer_name = optimizer_name
    if shared.opts.save_optimizer_state:
        hypernetwork.optimizer_state_dict = optimizer.state_dict()
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
    del optimizer
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)

    return hypernetwork, filename

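# Temporarily stamp the current checkpoint's hash/model name and the target name onto the
# hypernetwork before saving; if saving fails, the previous values are restored and the
# exception re-raised.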
def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
    old_hypernetwork_name = hypernetwork.name
    old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
    old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
    try:
        hypernetwork.sd_checkpoint = checkpoint.hash
        hypernetwork.sd_checkpoint_name = checkpoint.model_name
        hypernetwork.name = hypernetwork_name
        hypernetwork.save(filename)
    except:
        hypernetwork.sd_checkpoint = old_sd_checkpoint
        hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
        hypernetwork.name = old_hypernetwork_name
        raise