hypernetwork.py 14.2 KB
Newer Older
A
AUTOMATIC 已提交
1 2 3 4 5 6
import datetime
import glob
import html
import os
import sys
import traceback
D
update  
discus0434 已提交
7 8
import tqdm
import csv
A
AUTOMATIC 已提交
9 10

import torch
D
update  
discus0434 已提交
11

12
from ldm.util import default
D
update  
discus0434 已提交
13 14 15 16 17
from modules import devices, shared, processing, sd_models
import torch
from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
18
from modules.textual_inversion import textual_inversion
19
from modules.textual_inversion.learn_schedule import LearnRateScheduler
20 21


A
AUTOMATIC 已提交
22
class HypernetworkModule(torch.nn.Module):
    """Residual MLP that transforms cross-attention context vectors of width ``dim``.

    ``forward(x)`` returns ``x + mlp(x) * multiplier``. The MLP is a stack of
    ``torch.nn.Linear`` layers whose widths are ``dim * layer_structure[i]``,
    each optionally followed by a ``torch.nn.LayerNorm``.
    """

    # Class-level value shared by every instance; apply_strength() assigns it.
    multiplier = 1.0

    def __init__(self, dim, state_dict=None, layer_structure=None, add_layer_norm=False):
        """Build the MLP, then load ``state_dict`` or initialise fresh weights.

        Args:
            dim: width of the context vectors this module transforms.
            state_dict: optional saved parameters, either in the current
                ``linear.N.*`` layout or the legacy ``linear1``/``linear2``
                layout (handled by ``try_load_previous``).
            layer_structure: width multipliers relative to ``dim``; must start
                and end with 1. When None it is inferred from ``state_dict``,
                or defaults to ``[1, 2, 1]`` when there is no state_dict.
            add_layer_norm: append a LayerNorm after every Linear layer.

        Raises:
            AssertionError: if an explicit ``layer_structure`` does not start
                and end with 1.
        """
        super().__init__()
        if layer_structure is not None:
            assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!"
            assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!"
        elif state_dict is not None:
            layer_structure = parse_layer_structure(dim, state_dict)
        else:
            # Nothing to infer a structure from; previously this passed None to
            # parse_layer_structure and crashed with a TypeError.
            layer_structure = [1, 2, 1]

        linears = []
        for size_in, size_out in zip(layer_structure, layer_structure[1:]):
            linears.append(torch.nn.Linear(int(dim * size_in), int(dim * size_out)))
            if add_layer_norm:
                linears.append(torch.nn.LayerNorm(int(dim * size_out)))

        self.linear = torch.nn.Sequential(*linears)

        if state_dict is not None:
            try:
                self.load_state_dict(state_dict)
            except RuntimeError:
                # Keys/shapes don't match the current layout: try the legacy one.
                self.try_load_previous(state_dict)
        else:
            # Small random weights and zero biases for every layer (LayerNorm
            # included), so the residual branch starts out contributing little.
            for layer in self.linear:
                layer.weight.data.normal_(mean = 0.0, std = 0.01)
                layer.bias.data.zero_()

        self.to(devices.device)

    def try_load_previous(self, state_dict):
        """Copy legacy ``linear1``/``linear2`` parameters into ``linear.0``/``linear.1``.

        Only meaningful when this module was built with exactly two Linear
        layers and no LayerNorm (with LayerNorm, ``linear.1`` is the norm).
        Raises KeyError if the expected keys are missing on either side.
        """
        states = self.state_dict()
        states['linear.0.bias'].copy_(state_dict['linear1.bias'])
        states['linear.0.weight'].copy_(state_dict['linear1.weight'])
        states['linear.1.bias'].copy_(state_dict['linear2.bias'])
        states['linear.1.weight'].copy_(state_dict['linear2.weight'])

    def forward(self, x):
        """Return ``x`` plus the MLP output scaled by the shared multiplier."""
        return x + self.linear(x) * self.multiplier

    def trainables(self):
        """Return the weight and bias tensors of every layer in the MLP."""
        params = []
        for layer in self.linear:
            params += [layer.weight, layer.bias]
        return params
A
AUTOMATIC 已提交
68 69 70 71


def apply_strength(value=None):
    """Set the output multiplier shared by all HypernetworkModule instances.

    When ``value`` is None the configured
    ``shared.opts.sd_hypernetwork_strength`` is used instead.
    """
    if value is None:
        value = shared.opts.sd_hypernetwork_strength
    HypernetworkModule.multiplier = value
A
AUTOMATIC 已提交
72 73


74 75 76
def parse_layer_structure(dim, state_dict):
    """Infer a layer_structure multiplier list from a ``linear.N.*`` state dict.

    Starts with ``[1]`` and, walking ``linear.0.weight``, ``linear.1.weight``,
    ... until a key is missing, appends ``out_features // dim`` for every
    Linear weight. One-dimensional weights (LayerNorm gains, which share the
    ``linear.N`` numbering when layer norm is enabled) are skipped so they are
    not mistaken for extra layers.
    """
    layer_structure = [1]
    i = 0
    while (key := "linear.{}.weight".format(i)) in state_dict:
        weight = state_dict[key]
        # Linear weights are (out_features, in_features); LayerNorm weights are 1-D.
        if weight.dim() == 2:
            layer_structure.append(len(weight) // dim)
        i += 1
    return layer_structure
84 85


A
AUTOMATIC 已提交
86 87 88 89
class Hypernetwork:
    """A named set of (key, value) HypernetworkModule pairs keyed by context width."""
    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None, layer_structure=None, add_layer_norm=False):
        """Create a hypernetwork with fresh module pairs for each size in ``enable_sizes``.

        ``layer_structure`` and ``add_layer_norm`` are forwarded to every
        HypernetworkModule and recorded so ``save`` can persist them.
        """
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.layer_structure = layer_structure
        self.add_layer_norm = add_layer_norm

        for size in enable_sizes or []:
            self.layers[size] = (
                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
                HypernetworkModule(size, None, self.layer_structure, self.add_layer_norm),
            )

    def weights(self):
        """Switch every module to train mode and return all their trainable tensors."""
        res = []
        for layers in self.layers.values():
            for layer in layers:
                layer.train()
                res += layer.trainables()
        return res

    def save(self, filename):
        """Write the module states and metadata to ``filename`` with ``torch.save``.

        Each size key maps to a ``(first_module_state, second_module_state)``
        tuple; metadata is stored under string keys alongside them.
        """
        state_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['layer_structure'] = self.layer_structure
        state_dict['is_layer_norm'] = self.add_layer_norm
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

        torch.save(state_dict, filename)

    def load(self, filename):
        """Populate this hypernetwork from a file written by ``save``.

        Files saved before ``layer_structure``/``is_layer_norm`` were recorded
        are still accepted: they fall back to ``[1, 2, 1]`` without layer norm.
        NOTE(review): ``[1, 2, 1]`` is assumed to match the legacy two-layer
        ``linear1``/``linear2`` layout that HypernetworkModule.try_load_previous
        remaps; confirm against the legacy module definition.
        """
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        state_dict = torch.load(filename, map_location='cpu')

        # Resolve the architecture first so the modules below are built with it;
        # indexing these keys directly raised KeyError for older files.
        self.layer_structure = state_dict.get('layer_structure', [1, 2, 1])
        self.add_layer_norm = state_dict.get('is_layer_norm', False)

        for size, sd in state_dict.items():
            # Integer keys are context widths; everything else is metadata.
            if type(size) == int:
                self.layers[size] = (
                    HypernetworkModule(size, sd[0], self.layer_structure, self.add_layer_norm),
                    HypernetworkModule(size, sd[1], self.layer_structure, self.add_layer_norm),
                )

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)


A
AUTOMATIC 已提交
153
def list_hypernetworks(path):
    """Map each ``.pt`` file's stem to its path, searching ``path`` recursively.

    When two files share a stem, the one yielded later by ``glob.iglob`` wins.
    """
    pattern = os.path.join(path, '**/*.pt')
    return {
        os.path.splitext(os.path.basename(found))[0]: found
        for found in glob.iglob(pattern, recursive=True)
    }
A
AUTOMATIC 已提交
159

A
AUTOMATIC 已提交
160 161 162 163 164

def load_hypernetwork(filename):
    """Activate the hypernetwork registered under ``filename`` in ``shared.hypernetworks``.

    An unregistered name unloads the active hypernetwork. If loading fails,
    the traceback is printed to stderr and no hypernetwork is left active.
    """
    path = shared.hypernetworks.get(filename, None)
    if path is not None:
        print(f"Loading hypernetwork {filename}")
        try:
            # Build into a local first so a failed load never leaves a
            # half-initialised (possibly partially-layered) object installed.
            hypernetwork = Hypernetwork()
            hypernetwork.load(path)
        except Exception:
            print(f"Error loading hypernetwork {path}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            shared.loaded_hypernetwork = None
        else:
            shared.loaded_hypernetwork = hypernetwork
    else:
        if shared.loaded_hypernetwork is not None:
            print("Unloading hypernetwork")

        shared.loaded_hypernetwork = None
A
AUTOMATIC 已提交
177 178


M
Milly 已提交
179 180 181 182 183 184 185 186 187 188 189
def find_closest_hypernetwork_name(search: str):
    """Return the shortest registered hypernetwork name that contains ``search``.

    Matching is case-insensitive; among equally short names the first in
    ``shared.hypernetworks`` iteration order wins. Returns None for an empty
    search or when nothing matches.
    """
    if not search:
        return None
    needle = search.lower()
    candidates = [name for name in shared.hypernetworks if needle in name.lower()]
    return min(candidates, key=len) if candidates else None


A
AUTOMATIC 已提交
190 191
def apply_hypernetwork(hypernetwork, context, layer=None):
    """Pass ``context`` through the hypernetwork's module pair for its width.

    The pair is looked up in ``hypernetwork.layers`` by ``context.shape[2]``.
    Returns ``(context_k, context_v)``: both are ``context`` unchanged when
    there is no hypernetwork or no pair for that width. When ``layer`` is
    given and a pair exists, the two modules are also stored on it as
    ``hyper_k`` and ``hyper_v``.
    """
    available = {} if hypernetwork is None else hypernetwork.layers
    pair = available.get(context.shape[2], None)
    if pair is None:
        return context, context

    module_k = pair[0]
    module_v = pair[1]
    if layer is not None:
        layer.hyper_k = module_k
        layer.hyper_v = module_v

    return module_k(context), module_v(context)
A
AUTOMATIC 已提交
203 204


A
AUTOMATIC 已提交
205 206 207 208 209
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Cross-attention forward pass with the active hypernetwork applied to the context.

    ``self`` is the attention module this function is installed on; it must
    provide ``heads``, ``scale``, ``to_q``, ``to_k``, ``to_v`` and ``to_out``.
    ``context`` is resolved with ldm's ``default(context, x)`` before the key
    and value inputs are produced by ``apply_hypernetwork``.
    """
    heads = self.heads

    query = self.to_q(x)
    context = default(context, x)

    key_input, value_input = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
    key = self.to_k(key_input)
    value = self.to_v(value_input)

    def split_heads(t):
        # Fold the head count into the batch axis.
        return rearrange(t, 'b n (h d) -> (b h) n d', h=heads)

    query = split_heads(query)
    key = split_heads(key)
    value = split_heads(value)

    scores = einsum('b i d, b j d -> b i j', query, key) * self.scale

    if mask is not None:
        mask = rearrange(mask, 'b ... -> b (...)')
        fill_value = -torch.finfo(scores.dtype).max
        mask = repeat(mask, 'b j -> (b h) () j', h=heads)
        scores.masked_fill_(~mask, fill_value)

    # Normalise over the last (key) axis.
    probs = scores.softmax(dim=-1)

    merged = einsum('b i j, b j d -> b i d', probs, value)
    merged = rearrange(merged, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(merged)


233 234 235 236 237 238 239 240 241 242 243 244 245 246
def stack_conds(conds):
    """Stack per-sample conditioning tensors into a single batch tensor.

    Elements with fewer rows than the longest one are padded by repeating
    their last row (same scheme as reconstruct_multicond_batch). The padding
    uses ``repeat([n, 1])``, so elements are presumably 2-D
    (tokens, channels). NOTE(review): confirm against callers.

    Unlike before, the caller's list is left unmodified.
    """
    if len(conds) == 1:
        return torch.stack(conds)

    token_count = max(cond.shape[0] for cond in conds)
    padded = []
    for cond in conds:
        missing = token_count - cond.shape[0]
        if missing:
            cond = torch.vstack([cond, cond[-1:].repeat([missing, 1])])
        padded.append(cond)

    return torch.stack(padded)

247
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    """Train the hypernetwork registered as ``hypernetwork_name`` on ``data_root``.

    The hypernetwork is loaded into ``shared.loaded_hypernetwork`` and its
    weights are optimised with AdamW under a LearnRateScheduler until the
    scheduler finishes, the dataset is exhausted, or ``shared.state`` reports
    an interruption. The final result is saved to
    ``<cmd_opts.hypernetwork_dir>/<hypernetwork_name>.pt``.

    Side effects while training:
      * every ``save_hypernetwork_every`` steps (when > 0) a snapshot is saved
        under ``<log_directory>/<YYYY-MM-DD>/<name>/hypernetworks``;
      * every ``create_image_every`` steps (when > 0) a preview is rendered
        into ``.../images``, using the ``preview_*`` settings when
        ``preview_from_txt2img`` is set, else the batch's first caption;
      * loss and learning rate are passed to ``textual_inversion.write_loss``
        for ``hypernetwork_loss.csv`` each step;
      * ``shared.state`` progress fields are updated for the UI.

    When ``shared.opts.unload_models_when_training`` is set,
    ``cond_stage_model`` and ``first_stage_model`` are moved to the CPU after
    dataset preparation and only brought back for previews; this function
    does not move them back to the device before returning.

    Returns:
        ``(hypernetwork, filename)``. If the stored step already exceeds
        ``steps``, returns right after dataset preparation without training
        or saving.

    Raises:
        AssertionError: if ``hypernetwork_name`` is empty.
        RuntimeError: if the running mean loss becomes NaN.
    """
    assert hypernetwork_name, 'hypernetwork not selected'

    path = shared.hypernetworks.get(hypernetwork_name, None)
    shared.loaded_hypernetwork = Hypernetwork()
    shared.loaded_hypernetwork.load(path)

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    hypernetwork = shared.loaded_hypernetwork
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    # Ring buffer of the last 32 step losses. Unfilled slots stay 0, so the
    # reported mean is biased low during the first 32 steps.
    losses = torch.zeros((32,))

    last_saved_file = "<none>"
    last_saved_image = "<none>"

    initial_step = hypernetwork.step or 0
    if initial_step > steps:
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)

    pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step)
    for i, entries in pbar:
        hypernetwork.step = i + initial_step

        scheduler.apply(optimizer, hypernetwork.step)
        if scheduler.finished:
            break

        if shared.state.interrupted:
            break

        with torch.autocast("cuda"):
            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
            loss = shared.sd_model(x, c)[0]
            del x
            del c

            losses[hypernetwork.step % losses.shape[0]] = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        mean_loss = losses.mean()
        if torch.isnan(mean_loss):
            raise RuntimeError("Loss diverged.")
        pbar.set_description(f"loss: {mean_loss:.7f}")

        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
            hypernetwork.save(last_saved_file)

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{mean_loss:.7f}",
            "learn_rate": scheduler.learn_rate
        })

        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')

            optimizer.zero_grad()
            # Image generation needs both models on the device again.
            shared.sd_model.cond_stage_model.to(devices.device)
            shared.sd_model.first_stage_model.to(devices.device)

            p = processing.StableDiffusionProcessingTxt2Img(
                sd_model=shared.sd_model,
                do_not_save_grid=True,
                do_not_save_samples=True,
            )

            if preview_from_txt2img:
                p.prompt = preview_prompt
                p.negative_prompt = preview_negative_prompt
                p.steps = preview_steps
                p.sampler_index = preview_sampler_index
                p.cfg_scale = preview_cfg_scale
                p.seed = preview_seed
                p.width = preview_width
                p.height = preview_height
            else:
                p.prompt = entries[0].cond_text
                p.steps = 20

            preview_text = p.prompt

            processed = processing.process_images(p)
            image = processed.images[0] if len(processed.images)>0 else None

            if unload:
                shared.sd_model.cond_stage_model.to(devices.cpu)
                shared.sd_model.first_stage_model.to(devices.cpu)

            if image is not None:
                shared.state.current_image = image
                image.save(last_saved_image)
                last_saved_image += f", prompt: {preview_text}"

        shared.state.job_no = hypernetwork.step

        # The saved file is a hypernetwork snapshot, not an embedding; the old
        # label was copied from textual inversion training.
        shared.state.textinfo = f"""
<p>
Loss: {mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""

    checkpoint = sd_models.select_checkpoint()

    hypernetwork.sd_checkpoint = checkpoint.hash
    hypernetwork.sd_checkpoint_name = checkpoint.model_name
    hypernetwork.save(filename)

    return hypernetwork, filename