hypernetwork.py 20.3 KB
Newer Older
D
discus0434 已提交
1
import csv
A
AUTOMATIC 已提交
2 3 4 5 6 7
import datetime
import glob
import html
import os
import sys
import traceback
8
import inspect
A
AUTOMATIC 已提交
9

D
discus0434 已提交
10
import modules.textual_inversion.dataset
A
AUTOMATIC 已提交
11
import torch
D
discus0434 已提交
12
import tqdm
D
update  
discus0434 已提交
13
from einops import rearrange, repeat
D
discus0434 已提交
14 15
from ldm.util import default
from modules import devices, processing, sd_models, shared
16
from modules.textual_inversion import textual_inversion
17
from modules.textual_inversion.learn_schedule import LearnRateScheduler
D
discus0434 已提交
18
from torch import einsum
19
from torch.nn.init import normal_, xavier_normal_, xavier_uniform_, kaiming_normal_, kaiming_uniform_, zeros_
20

A
AngelBottomless 已提交
21
from collections import defaultdict, deque
A
AngelBottomless 已提交
22
from statistics import stdev, mean
23

24

A
AUTOMATIC 已提交
25
class HypernetworkModule(torch.nn.Module):
    """Residual MLP that transforms a context tensor.

    ``forward(x)`` returns ``x + linear(x) * multiplier`` where ``linear`` is a
    ``torch.nn.Sequential`` of fully-connected layers, optionally interleaved
    with an activation, layer normalization and dropout.
    """

    # Strength of the learned offset in forward(). This is a class attribute,
    # so assigning HypernetworkModule.multiplier affects every instance.
    multiplier = 1.0

    # Names accepted for ``activation_func``: a few lowercase aliases (note that
    # "swish" maps to Hardswish) plus every class defined in
    # torch.nn.modules.activation, registered under its own class name.
    activation_dict = {
        "relu": torch.nn.ReLU,
        "leakyrelu": torch.nn.LeakyReLU,
        "elu": torch.nn.ELU,
        "swish": torch.nn.Hardswish,
        "tanh": torch.nn.Tanh,
        "sigmoid": torch.nn.Sigmoid,
    }
    activation_dict.update({cls_name: cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal', add_layer_norm=False, use_dropout=False):
        """Build the layer stack and either load or initialize its weights.

        Args:
            dim: Base width; layer ``i`` maps ``int(dim * layer_structure[i])``
                features to ``int(dim * layer_structure[i + 1])`` features.
            state_dict: Optional weights to load. Legacy key names are renamed
                in place (see ``fix_old_state_dict``) before loading.
            layer_structure: Sequence of width multipliers; must start and end
                with 1.
            activation_func: Key of ``activation_dict``, or ``"linear"``/``None``
                for no activation.
            weight_init: One of ``'Normal'``, ``'XavierUniform'``,
                ``'XavierNormal'``, ``'KaimingUniform'``, ``'KaimingNormal'``.
                Only used when ``state_dict`` is None.
            add_layer_norm: Append a LayerNorm after every linear layer.
            use_dropout: Append ``Dropout(p=0.3)`` after the early layers.

        Raises:
            AssertionError: If ``layer_structure`` is missing or malformed.
            RuntimeError: If ``activation_func`` is not supported.
            KeyError: If ``weight_init`` is not a known initialization.
        """
        super().__init__()

        assert layer_structure is not None, "layer_structure must not be None"
        assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
        assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"

        linears = []
        for i in range(len(layer_structure) - 1):

            # Add a fully-connected layer
            linears.append(torch.nn.Linear(int(dim * layer_structure[i]), int(dim * layer_structure[i+1])))

            # Add an activation func
            if activation_func == "linear" or activation_func is None:
                pass
            elif activation_func in self.activation_dict:
                linears.append(self.activation_dict[activation_func]())
            else:
                raise RuntimeError(f'hypernetwork uses an unsupported activation function: {activation_func}')

            # Add layer normalization
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * layer_structure[i+1])))

            # Add dropout except after the last two linear layers
            if use_dropout and i < len(layer_structure) - 3:
                linears.append(torch.nn.Dropout(p=0.3))

        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            self.fix_old_state_dict(state_dict)
            self.load_state_dict(state_dict)
        else:
            # Kaiming initializers need to know which nonlinearity follows.
            nonlinearity = 'leaky_relu' if 'leakyrelu' == activation_func else 'relu'
            for layer in self.linear:
                if not isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm)):
                    continue

                w, b = layer.weight.data, layer.bias.data
                # LayerNorm parameters always get the small-normal init,
                # regardless of the requested scheme.
                if weight_init == "Normal" or isinstance(layer, torch.nn.LayerNorm):
                    normal_(w, mean=0.0, std=0.01)
                    normal_(b, mean=0.0, std=0.005)
                elif weight_init == 'XavierUniform':
                    xavier_uniform_(w)
                    zeros_(b)
                elif weight_init == 'XavierNormal':
                    xavier_normal_(w)
                    zeros_(b)
                elif weight_init == 'KaimingUniform':
                    kaiming_uniform_(w, nonlinearity=nonlinearity)
                    zeros_(b)
                elif weight_init == 'KaimingNormal':
                    kaiming_normal_(w, nonlinearity=nonlinearity)
                    zeros_(b)
                else:
                    raise KeyError(f"Key {weight_init} is not defined as initialization!")
        self.to(devices.device)

    def fix_old_state_dict(self, state_dict):
        """Rename legacy ``linear1``/``linear2`` keys of *state_dict* in place.

        Keys that are absent (or whose value is None) are left untouched.
        """
        changes = {
            'linear1.bias': 'linear.0.bias',
            'linear1.weight': 'linear.0.weight',
            'linear2.bias': 'linear.1.bias',
            'linear2.weight': 'linear.1.weight',
        }

        for old_key, new_key in changes.items():
            value = state_dict.get(old_key, None)
            if value is None:
                continue

            del state_dict[old_key]
            state_dict[new_key] = value

    def forward(self, x):
        """Return *x* plus the layer stack's output scaled by ``multiplier``."""
        return x + self.linear(x) * self.multiplier

    def trainables(self):
        """Return the weight and bias of every Linear and LayerNorm layer."""
        params = []
        for layer in self.linear:
            if isinstance(layer, (torch.nn.Linear, torch.nn.LayerNorm)):
                params += [layer.weight, layer.bias]
        return params
A
AUTOMATIC 已提交
119 120 121 122


def apply_strength(value=None):
    """Set the strength multiplier shared by all hypernetwork modules.

    When *value* is None, the ``sd_hypernetwork_strength`` option is used.
    """
    if value is None:
        value = shared.opts.sd_hypernetwork_strength
    HypernetworkModule.multiplier = value
A
AUTOMATIC 已提交
123 124 125 126 127 128


class Hypernetwork:
    """A named set of HypernetworkModule pairs keyed by context width.

    ``layers[size]`` holds a 2-tuple of modules; apply_hypernetwork() uses the
    first for the key context and the second for the value context.
    """

    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False):
        """Create a hypernetwork with a fresh module pair per enabled size.

        Args:
            name: Display name; ``load`` derives one from the file if None.
            enable_sizes: Iterable of context widths to create modules for.
                None or empty creates no modules (e.g. before ``load``).
            layer_structure, activation_func, weight_init, add_layer_norm,
            use_dropout: Forwarded unchanged to every HypernetworkModule.
        """
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.activation_func = activation_func
        self.weight_init = weight_init
        self.add_layer_norm = add_layer_norm
        self.use_dropout = use_dropout

        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
            )

    def weights(self):
        """Return all trainable parameters, switching every module to train mode."""
        res = []

        for layers in self.layers.values():
            for layer in layers:
                layer.train()
                res += layer.trainables()

        return res

    def save(self, filename):
        """Write the module weights and metadata to *filename* via ``torch.save``."""
        state_dict = {}

        # Integer keys carry the per-size (key, value) module weights ...
        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        # ... string keys carry the metadata needed to rebuild the modules.
        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['activation_func'] = self.activation_func
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['weight_initialization'] = self.weight_init
        state_dict['use_dropout'] = self.use_dropout
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

        torch.save(state_dict, filename)

    def load(self, filename):
        """Populate this hypernetwork from a file written by ``save``.

        Metadata keys missing from the file fall back to defaults
        (``[1, 2, 1]`` structure, no activation, ``'Normal'`` init, no layer
        norm, no dropout).
        """
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        # NOTE: torch.load unpickles the file; only load files from sources you trust.
        state_dict = torch.load(filename, map_location='cpu')

        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        print(self.layer_structure)
        self.activation_func = state_dict.get('activation_func', None)
        print(f"Activation function is {self.activation_func}")
        self.weight_init = state_dict.get('weight_initialization', 'Normal')
        print(f"Weight initialization is {self.weight_init}")
        self.add_layer_norm = state_dict.get('is_layer_norm', False)
        print(f"Layer norm is set to {self.add_layer_norm}")
        self.use_dropout = state_dict.get('use_dropout', False)
        print(f"Dropout usage is set to {self.use_dropout}")

        # Only integer keys are module weights; everything else is metadata.
        for size, sd in state_dict.items():
            if isinstance(size, int):
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.activation_func, self.weight_init, self.add_layer_norm, self.use_dropout),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)


A
AUTOMATIC 已提交
207
def list_hypernetworks(path):
    """Map hypernetwork names to the ``.pt`` files found recursively under *path*.

    The name is the file's basename without extension. Files are visited in
    sorted path order so that, when two files share a basename, the winner
    (the last one in sorted order) does not depend on directory listing order.
    """
    res = {}
    for filename in sorted(glob.iglob(os.path.join(path, '**/*.pt'), recursive=True)):
        name = os.path.splitext(os.path.basename(filename))[0]
        res[name] = filename
    return res
A
AUTOMATIC 已提交
213

A
AUTOMATIC 已提交
214 215 216 217 218

def load_hypernetwork(filename):
    """Load the hypernetwork registered as *filename* into ``shared.loaded_hypernetwork``.

    *filename* is looked up in ``shared.hypernetworks``. An unknown name
    unloads the current hypernetwork. A load error is reported on stderr and
    leaves no hypernetwork loaded, so a half-populated network is never used.
    """
    path = shared.hypernetworks.get(filename, None)
    if path is None:
        if shared.loaded_hypernetwork is not None:
            print("Unloading hypernetwork")

        shared.loaded_hypernetwork = None
        return

    print(f"Loading hypernetwork {filename}")
    try:
        # Load into a local first so a failure part-way through cannot publish
        # a partially-populated network.
        hypernetwork = Hypernetwork()
        hypernetwork.load(path)
    except Exception:
        print(f"Error loading hypernetwork {path}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        shared.loaded_hypernetwork = None
    else:
        shared.loaded_hypernetwork = hypernetwork
A
AUTOMATIC 已提交
231 232


M
Milly 已提交
233 234 235 236 237 238 239 240 241 242 243
def find_closest_hypernetwork_name(search: str):
    """Return the shortest known hypernetwork name containing *search*.

    Matching is case-insensitive. Returns None when *search* is empty or no
    name in ``shared.hypernetworks`` contains it; on equal length the name
    encountered first wins.
    """
    if not search:
        return None

    needle = search.lower()
    candidates = [name for name in shared.hypernetworks if needle in name.lower()]
    return min(candidates, key=len) if candidates else None


A
AUTOMATIC 已提交
244 245
def apply_hypernetwork(hypernetwork, context, layer=None):
    """Run *context* through the hypernetwork modules matching its width.

    The module pair is selected by ``context.shape[2]``. When *hypernetwork*
    is None or has no modules for that width, the context is returned
    unchanged for both outputs.

    Args:
        hypernetwork: Object with a ``layers`` mapping, or None.
        context: Input whose ``shape[2]`` selects the module pair.
        layer: Optional object that receives the selected modules as
            ``hyper_k`` and ``hyper_v`` attributes.

    Returns:
        Tuple ``(context_k, context_v)``.
    """
    layers_by_size = hypernetwork.layers if hypernetwork is not None else {}
    hypernetwork_layers = layers_by_size.get(context.shape[2], None)

    if hypernetwork_layers is None:
        return context, context

    if layer is not None:
        layer.hyper_k = hypernetwork_layers[0]
        layer.hyper_v = hypernetwork_layers[1]

    context_k = hypernetwork_layers[0](context)
    context_v = hypernetwork_layers[1](context)
    return context_k, context_v
A
AUTOMATIC 已提交
257 258


A
AUTOMATIC 已提交
259 260 261 262 263
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass with the loaded hypernetwork applied.

    The context (``x`` itself when *context* is None) is passed through
    apply_hypernetwork() before the key and value projections.
    NOTE(review): the name suggests this is installed in place of a
    CrossAttention module's ``forward`` - confirm against the caller.

    Args:
        self: Attention module providing ``heads``, ``scale``, ``to_q``,
            ``to_k``, ``to_v`` and ``to_out``.
        x: Query input; projected tensors are laid out as ``b n (h d)``.
        context: Optional key/value input.
        mask: Optional boolean mask; positions where it is False are filled
            with the most negative finite value before the softmax.
    """
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
    k = self.to_k(context_k)
    v = self.to_v(context_v)

    # Fold the heads into the batch dimension.
    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        max_neg_value = -torch.finfo(sim.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=h)
        sim.masked_fill_(~mask, max_neg_value)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    # Unfold the heads back into the feature dimension.
    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
    return self.to_out(out)


287 288 289 290 291 292 293 294 295 296 297 298 299 300
def stack_conds(conds):
    """Stack conditioning tensors, padding shorter ones to a common length.

    Each tensor shorter than the longest (along dimension 0) is extended by
    repeating its last row, the same approach as reconstruct_multicond_batch.
    The input sequence is not modified.

    Args:
        conds: Non-empty sequence of 2-D tensors that differ only in their
            first dimension.

    Returns:
        A tensor of shape ``(len(conds), max_len, width)``.
    """
    token_count = max(x.shape[0] for x in conds)

    padded = []
    for cond in conds:
        missing = token_count - cond.shape[0]
        if missing > 0:
            # Pad on a new tensor; the caller's list keeps its originals.
            cond = torch.vstack([cond, cond[-1:].repeat([missing, 1])])
        padded.append(cond)

    return torch.stack(padded)

301

A
AngelBottomless 已提交
302
def _mean_with_standard_error(label, values):
    """Format ``label:mean±(standard error)`` for a non-empty list of numbers."""
    # stdev() needs at least two samples; report zero spread for a single one.
    std = stdev(values) if len(values) >= 2 else 0
    return f"{label}:{mean(values):.3f}\u00B1({std / (len(values) ** 0.5):.3f})"


def statistics(data):
    """Summarize a loss series over its whole length and its last 32 entries.

    Args:
        data: Non-empty iterable of numbers (list, deque, ...).

    Returns:
        Tuple ``(total_information, recent_information)`` of strings formatted
        as ``mean±(stdev / sqrt(n))`` with three decimals.

    Raises:
        statistics.StatisticsError: If *data* is empty.
    """
    data = list(data)  # allow deques and other non-sliceable iterables
    total_information = _mean_with_standard_error("loss", data)
    recent_information = _mean_with_standard_error("recent 32 loss", data[-32:])
    return total_information, recent_information


def report_statistics(loss_info:dict):
    """Print loss statistics per file, ordered by ascending average loss.

    Args:
        loss_info: Mapping of file name to a collection of loss values.

    Reporting is best-effort: a failure for one file is printed and the
    remaining files are still reported. Files with no recorded losses sort
    last instead of aborting the whole report.
    """
    def average_loss(key):
        series = loss_info[key]
        return sum(series) / len(series) if len(series) else float("inf")

    for key in sorted(loss_info, key=average_loss):
        try:
            print("Loss statistics for file " + key)
            info, recent = statistics(list(loss_info[key]))
            print(info)
            print(recent)
        except Exception as e:
            print(e)
A
AngelBottomless 已提交
327 328 329



330
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    """Train the named hypernetwork and save it to the hypernetwork directory.

    The hypernetwork is loaded into ``shared.loaded_hypernetwork``, trained
    with AdamW against ``shared.sd_model`` on the dataset under *data_root*,
    and saved when the schedule finishes or training is interrupted.

    Args:
        hypernetwork_name: Key into ``shared.hypernetworks``; also the base
            name for saved files.
        learn_rate: Passed to LearnRateScheduler together with *steps*.
        batch_size, data_root, training_width, training_height, template_file:
            Passed to the PersonalizedBase dataset.
        log_directory: Root for logs; a ``<date>/<hypernetwork_name>``
            subdirectory is used.
        steps: Total number of training steps.
        create_image_every: Preview-image interval in steps; <= 0 disables.
        save_hypernetwork_every: Intermediate-save interval in steps;
            <= 0 disables.
        preview_from_txt2img: If true, previews use the ``preview_*``
            arguments; otherwise the first entry's prompt and 20 steps.

    Returns:
        Tuple ``(hypernetwork, filename)``; *filename* is the final save path.

    Raises:
        AssertionError: If no hypernetwork is selected, the name is unknown,
            or no gradient reaches the weights for 10 consecutive steps.
        RuntimeError: If the loss becomes NaN.
    """
    # Validate before doing any work, including the lazy import below.
    assert hypernetwork_name, 'hypernetwork not selected'

    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    path = shared.hypernetworks.get(hypernetwork_name, None)
    assert path is not None, f'hypernetwork not found: {hypernetwork_name}'
    shared.loaded_hypernetwork = Hypernetwork()
    shared.loaded_hypernetwork.load(path)

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
    if unload:
        # The dataset has been prepared; move these sub-models to the CPU while training.
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    hypernetwork = shared.loaded_hypernetwork
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    size = len(ds.indexes)
    # Per-file history of recent losses, capped at 1024 entries each.
    loss_dict = defaultdict(lambda: deque(maxlen=1024))
    losses = torch.zeros((size,))
    previous_mean_losses = [0]
    previous_mean_loss = 0
    print("Mean loss of {} elements".format(size))

    last_saved_file = "<none>"
    last_saved_image = "<none>"

    initial_step = hypernetwork.step or 0
    if initial_step > steps:
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    # AdamW is the only optimizer currently supported.
    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)

    steps_without_grad = 0

    pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step)
    for i, entries in pbar:
        hypernetwork.step = i + initial_step
        if len(loss_dict) > 0:
            # Latest loss per file, as recorded up to the previous step.
            previous_mean_losses = [history[-1] for history in loss_dict.values()]
            previous_mean_loss = mean(previous_mean_losses)

        scheduler.apply(optimizer, hypernetwork.step)
        if scheduler.finished:
            break

        if shared.state.interrupted:
            break

        with torch.autocast("cuda"):
            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
            loss = shared.sd_model(x, c)[0]
            del x
            del c

            loss_value = loss.item()
            losses[hypernetwork.step % losses.shape[0]] = loss_value
            for entry in entries:
                loss_dict[entry.filename].append(loss_value)

            optimizer.zero_grad()
            # Force the probe weight's grad to None so the check below detects
            # whether backward() actually produced a gradient.
            weights[0].grad = None
            loss.backward()

            if weights[0].grad is None:
                steps_without_grad += 1
            else:
                steps_without_grad = 0
            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'

            optimizer.step()

        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
            raise RuntimeError("Loss diverged.")

        if len(previous_mean_losses) > 1:
            std = stdev(previous_mean_losses)
        else:
            std = 0
        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}\u00B1({std / (len(previous_mean_losses) ** 0.5):.3f})"
        pbar.set_description(dataset_loss_info)

        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            # Before saving, change name to match current checkpoint.
            hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
            hypernetwork.save(last_saved_file)

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{previous_mean_loss:.7f}",
            "learn_rate": scheduler.learn_rate
        })

        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
            forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
            last_saved_image = os.path.join(images_dir, forced_filename)

            optimizer.zero_grad()
            # Previews need the sub-models on the compute device.
            shared.sd_model.cond_stage_model.to(devices.device)
            shared.sd_model.first_stage_model.to(devices.device)

            p = processing.StableDiffusionProcessingTxt2Img(
                sd_model=shared.sd_model,
                do_not_save_grid=True,
                do_not_save_samples=True,
            )

            if preview_from_txt2img:
                p.prompt = preview_prompt
                p.negative_prompt = preview_negative_prompt
                p.steps = preview_steps
                p.sampler_index = preview_sampler_index
                p.cfg_scale = preview_cfg_scale
                p.seed = preview_seed
                p.width = preview_width
                p.height = preview_height
            else:
                p.prompt = entries[0].cond_text
                p.steps = 20

            preview_text = p.prompt

            processed = processing.process_images(p)
            image = processed.images[0] if len(processed.images) > 0 else None

            if unload:
                shared.sd_model.cond_stage_model.to(devices.cpu)
                shared.sd_model.first_stage_model.to(devices.cpu)

            if image is not None:
                shared.state.current_image = image
                last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename)
                last_saved_image += f", prompt: {preview_text}"

        shared.state.job_no = hypernetwork.step

        shared.state.textinfo = f"""
<p>
Loss: {previous_mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""

    report_statistics(loss_dict)
    checkpoint = sd_models.select_checkpoint()

    hypernetwork.sd_checkpoint = checkpoint.hash
    hypernetwork.sd_checkpoint_name = checkpoint.model_name
    # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
    hypernetwork.name = hypernetwork_name
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
    hypernetwork.save(filename)

    return hypernetwork, filename