hypernetwork.py 12.0 KB
Newer Older
A
AUTOMATIC 已提交
1 2 3 4 5 6 7
import csv
import datetime
import glob
import html
import os
import sys
import traceback

import torch
import tqdm
from einops import rearrange, repeat
from torch import einsum

from ldm.util import default
from modules import devices, shared, processing, sd_models
import modules.textual_inversion.dataset
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
A
AUTOMATIC 已提交
20 21 22


class HypernetworkModule(torch.nn.Module):
    """A small residual two-layer MLP applied to attention context vectors.

    The class-level ``multiplier`` scales the learned residual for every
    module at once; it is set globally via ``apply_strength``.
    """

    # Global strength factor shared by all modules.
    multiplier = 1.0

    def __init__(self, dim, state_dict=None):
        """Build a dim -> 2*dim -> dim MLP; load *state_dict* if given,
        otherwise initialize to a near-identity transform."""
        super().__init__()

        self.linear1 = torch.nn.Linear(dim, dim * 2)
        self.linear2 = torch.nn.Linear(dim * 2, dim)

        if state_dict is None:
            # Fresh module: tiny random weights and zero biases so the
            # residual starts near zero and the module is ~identity.
            for layer in (self.linear1, self.linear2):
                layer.weight.data.normal_(mean=0.0, std=0.01)
                layer.bias.data.zero_()
        else:
            self.load_state_dict(state_dict, strict=True)

        self.to(devices.device)

    def forward(self, x):
        # Residual connection, scaled by the shared strength multiplier.
        residual = self.linear2(self.linear1(x))
        return x + residual * self.multiplier


def apply_strength(value=None):
    """Set the global hypernetwork strength on HypernetworkModule.

    Uses *value* when given; otherwise falls back to the user-configured
    ``sd_hypernetwork_strength`` option.
    """
    if value is None:
        value = shared.opts.sd_hypernetwork_strength
    HypernetworkModule.multiplier = value
A
AUTOMATIC 已提交
48 49 50 51 52 53


class Hypernetwork:
    """A set of HypernetworkModule pairs — one (key, value) module pair per
    attention context dimension — plus training bookkeeping state."""

    filename = None
    name = None

    def __init__(self, name=None, enable_sizes=None):
        """Create an empty hypernetwork.

        name: display name; defaults to the file's basename on load().
        enable_sizes: iterable of context dims to create fresh module pairs for.
        """
        self.filename = None
        self.name = name
        self.layers = {}
        self.step = 0
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None

        for size in enable_sizes or []:
            # First module transforms attention keys, second transforms values.
            self.layers[size] = (HypernetworkModule(size), HypernetworkModule(size))

    def weights(self):
        """Return all trainable parameters, switching every layer to train mode."""
        res = []

        for k, layers in self.layers.items():
            for layer in layers:
                layer.train()
                res += [layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias]

        return res

    def save(self, filename):
        """Serialize all layer state dicts plus metadata to *filename*."""
        state_dict = {}

        for k, v in self.layers.items():
            state_dict[k] = (v[0].state_dict(), v[1].state_dict())

        state_dict['step'] = self.step
        state_dict['name'] = self.name
        state_dict['sd_checkpoint'] = self.sd_checkpoint
        state_dict['sd_checkpoint_name'] = self.sd_checkpoint_name

        torch.save(state_dict, filename)

    def load(self, filename):
        """Load a file written by save().

        Integer keys in the saved dict are layer sizes holding
        (key_module, value_module) state-dict pairs; string keys are metadata.
        """
        self.filename = filename
        if self.name is None:
            self.name = os.path.splitext(os.path.basename(filename))[0]

        state_dict = torch.load(filename, map_location='cpu')

        for size, sd in state_dict.items():
            # isinstance instead of type() == int; saved layer keys are plain ints.
            if isinstance(size, int):
                self.layers[size] = (HypernetworkModule(size, sd[0]), HypernetworkModule(size, sd[1]))

        self.name = state_dict.get('name', self.name)
        self.step = state_dict.get('step', 0)
        self.sd_checkpoint = state_dict.get('sd_checkpoint', None)
        self.sd_checkpoint_name = state_dict.get('sd_checkpoint_name', None)


A
AUTOMATIC 已提交
105
def list_hypernetworks(path):
    """Map hypernetwork names (file basenames without extension) to their
    .pt file paths, searching *path* recursively; later duplicates win."""
    pattern = os.path.join(path, '**/*.pt')
    return {
        os.path.splitext(os.path.basename(filename))[0]: filename
        for filename in glob.iglob(pattern, recursive=True)
    }
A
AUTOMATIC 已提交
111

A
AUTOMATIC 已提交
112 113 114 115 116

def load_hypernetwork(filename):
    """Load the named hypernetwork into shared.loaded_hypernetwork.

    filename: key into shared.hypernetworks. An unknown/empty name unloads
    the currently active hypernetwork instead. Load errors are reported to
    stderr and leave whatever was assigned so far in place.
    """
    path = shared.hypernetworks.get(filename, None)
    if path is not None:
        # Restore the name placeholder lost in the log message.
        print(f"Loading hypernetwork {filename}")
        try:
            shared.loaded_hypernetwork = Hypernetwork()
            shared.loaded_hypernetwork.load(path)
        except Exception:
            print(f"Error loading hypernetwork {path}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
    else:
        if shared.loaded_hypernetwork is not None:
            print("Unloading hypernetwork")

        shared.loaded_hypernetwork = None
A
AUTOMATIC 已提交
129 130


M
Milly 已提交
131 132 133 134 135 136 137 138 139 140 141
def find_closest_hypernetwork_name(search: str):
    """Return the shortest hypernetwork name containing *search*
    (case-insensitive), or None when nothing matches."""
    if not search:
        return None
    needle = search.lower()
    matches = [name for name in shared.hypernetworks if needle in name.lower()]
    if not matches:
        return None
    # min() with key=len returns the first shortest match, same tie-break
    # as a stable sort by length.
    return min(matches, key=len)


A
AUTOMATIC 已提交
142 143
def apply_hypernetwork(hypernetwork, context, layer=None):
    """Return (context_for_keys, context_for_values) for cross-attention.

    When no hypernetwork is loaded, or it has no module pair for this
    context dimension, the untouched context is returned for both.
    """
    layers = {} if hypernetwork is None else hypernetwork.layers
    pair = layers.get(context.shape[2], None)

    if pair is None:
        return context, context

    if layer is not None:
        # Expose the modules on the attention layer for outside inspection.
        layer.hyper_k, layer.hyper_v = pair

    return pair[0](context), pair[1](context)
A
AUTOMATIC 已提交
155 156


A
AUTOMATIC 已提交
157 158 159 160 161
def attention_CrossAttention_forward(self, x, context=None, mask=None):
    """Replacement for CrossAttention.forward that routes keys/values
    through the currently loaded hypernetwork before attention."""
    heads = self.heads

    q = self.to_q(x)
    context = default(context, x)

    # The hypernetwork transforms the context separately for keys and values.
    context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self)
    k = self.to_k(context_k)
    v = self.to_v(context_v)

    # Fold the head dimension into the batch dimension.
    q = rearrange(q, 'b n (h d) -> (b h) n d', h=heads)
    k = rearrange(k, 'b n (h d) -> (b h) n d', h=heads)
    v = rearrange(v, 'b n (h d) -> (b h) n d', h=heads)

    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

    if mask is not None:
        # Flatten the mask, broadcast it over heads, and blank out the
        # disallowed positions with the most negative representable value.
        flat_mask = rearrange(mask, 'b ... -> b (...)')
        flat_mask = repeat(flat_mask, 'b j -> (b h) () j', h=heads)
        sim.masked_fill_(~flat_mask, -torch.finfo(sim.dtype).max)

    # attention, what we cannot get enough of
    attn = sim.softmax(dim=-1)

    out = einsum('b i j, b j d -> b i d', attn, v)
    out = rearrange(out, '(b h) n d -> b n (h d)', h=heads)
    return self.to_out(out)


185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
def stack_conds(conds):
    """Stack a list of conditioning tensors into one batch tensor.

    Tensors with fewer token rows than the longest one are padded by
    repeating their last row (same scheme as reconstruct_multicond_batch),
    so ragged first dimensions are tolerated. Unlike the previous version,
    the caller's list is NOT mutated in place.
    """
    if len(conds) == 1:
        return torch.stack(conds)

    # same as in reconstruct_multicond_batch
    token_count = max(c.shape[0] for c in conds)

    padded = []
    for cond in conds:
        if cond.shape[0] != token_count:
            last_vector = cond[-1:]
            last_vector_repeated = last_vector.repeat([token_count - cond.shape[0], 1])
            cond = torch.vstack([cond, last_vector_repeated])
        padded.append(cond)

    return torch.stack(padded)

def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    """Train the named hypernetwork on images under *data_root*.

    Loads the hypernetwork into shared.loaded_hypernetwork, runs an AdamW
    training loop for up to *steps* steps, periodically saving checkpoints
    and preview images into *log_directory*, and finally saves the trained
    weights back to the main hypernetwork directory.

    Returns (hypernetwork, filename) where filename is the final save path.
    """
    assert hypernetwork_name, 'hypernetwork not selected'

    # Make the network being trained the globally active one, so the
    # patched cross-attention forward uses it during loss computation.
    path = shared.hypernetworks.get(hypernetwork_name, None)
    shared.loaded_hypernetwork = Hypernetwork()
    shared.loaded_hypernetwork.load(path)

    shared.state.textinfo = "Initializing hypernetwork training..."
    shared.state.job_count = steps

    # Final destination for the trained weights.
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')

    # Logs go into a per-day, per-network subdirectory.
    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), hypernetwork_name)
    unload = shared.opts.unload_models_when_training

    if save_hypernetwork_every > 0:
        hypernetwork_dir = os.path.join(log_directory, "hypernetworks")
        os.makedirs(hypernetwork_dir, exist_ok=True)
    else:
        hypernetwork_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)

    # Free VRAM while training: move the text/VAE stages to CPU; they are
    # moved back temporarily when generating preview images below.
    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)

    hypernetwork = shared.loaded_hypernetwork
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    # Ring buffer of the last 32 step losses; mean is shown in the UI/pbar.
    losses = torch.zeros((32,))

    last_saved_file = "<none>"
    last_saved_image = "<none>"

    # NOTE(review): 'ititial_step' is a pre-existing typo for 'initial_step',
    # preserved here. Resuming past the requested step count is a no-op.
    ititial_step = hypernetwork.step or 0
    if ititial_step > steps:
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)

    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
    for i, entries in pbar:
        hypernetwork.step = i + ititial_step

        # Update the learning rate; the scheduler signals when to stop.
        scheduler.apply(optimizer, hypernetwork.step)
        if scheduler.finished:
            break

        if shared.state.interrupted:
            break

        with torch.autocast("cuda"):
            # Conditioning may be ragged across the batch; stack_conds pads it.
            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
#            c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
            loss = shared.sd_model(x, c)[0]
            # Release the batch tensors before backward to reduce peak VRAM.
            del x
            del c

            losses[hypernetwork.step % losses.shape[0]] = loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        pbar.set_description(f"loss: {losses.mean():.7f}")

        # Periodic checkpoint of the hypernetwork weights.
        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
            hypernetwork.save(last_saved_file)

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{losses.mean():.7f}",
            "learn_rate": scheduler.learn_rate
        })

        # Periodic preview image generated with the current weights.
        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
            last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')

            # Bring the unloaded model stages back for image generation.
            optimizer.zero_grad()
            shared.sd_model.cond_stage_model.to(devices.device)
            shared.sd_model.first_stage_model.to(devices.device)

            p = processing.StableDiffusionProcessingTxt2Img(
                sd_model=shared.sd_model,
                do_not_save_grid=True,
                do_not_save_samples=True,
            )

            if preview_from_txt2img:
                # User-specified preview parameters from the txt2img tab.
                p.prompt = preview_prompt
                p.negative_prompt = preview_negative_prompt
                p.steps = preview_steps
                p.sampler_index = preview_sampler_index
                p.cfg_scale = preview_cfg_scale
                p.seed = preview_seed
                p.width = preview_width
                p.height = preview_height
            else:
                # Fall back to the last training prompt.
                p.prompt = entries[0].cond_text
                p.steps = 20

            preview_text = p.prompt

            processed = processing.process_images(p)
            image = processed.images[0] if len(processed.images)>0 else None

            # Move the stages back off the GPU after the preview.
            if unload:
                shared.sd_model.cond_stage_model.to(devices.cpu)
                shared.sd_model.first_stage_model.to(devices.cpu)

            if image is not None:
                shared.state.current_image = image
                image.save(last_saved_image)
                last_saved_image += f", prompt: {preview_text}"

        shared.state.job_no = hypernetwork.step

        shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""

    # Record which SD checkpoint the hypernetwork was trained against.
    checkpoint = sd_models.select_checkpoint()

    hypernetwork.sd_checkpoint = checkpoint.hash
    hypernetwork.sd_checkpoint_name = checkpoint.model_name
    hypernetwork.save(filename)

    return hypernetwork, filename