textual_inversion.py 17.9 KB
Newer Older
1 2 3 4 5 6 7 8
import os
import sys
import traceback

import torch
import tqdm
import html
import datetime
9
import csv
10

D
DepFA 已提交
11
from PIL import Image, PngImagePlugin
12

13
from modules import shared, devices, sd_hijack, processing, sd_models, images
14
import modules.textual_inversion.dataset
15
from modules.textual_inversion.learn_schedule import LearnRateScheduler
16

D
DepFA 已提交
17 18 19
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
                                                       insert_image_data_embed, extract_image_data_embed,
                                                       caption_image_overlay)
20 21 22 23 24 25 26

class Embedding:
    """A textual-inversion embedding: learned vector(s) plus training metadata."""

    def __init__(self, vec, name, step=None):
        # Tensor of embedding vectors; this module builds it 2-D as (vectors, dim).
        self.vec = vec
        self.name = name
        # Training step the vectors correspond to, or None when unknown.
        self.step = step
        # Memoised result of checksum(); must be cleared after vec changes
        # (save_embedding does this when remove_cached_checksum is true).
        self.cached_checksum = None
        # Hash and model name of the checkpoint the embedding was last saved against.
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None

    def save(self, filename):
        """Write the embedding to *filename* with ``torch.save`` as a single-term dict."""
        embedding_data = {
            # NOTE(review): 265 is presumably the tokenizer id of '*'; kept for file-format compatibility.
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

    def checksum(self):
        """Return a short 4-hex-digit fingerprint of ``vec``, cached after the first call."""
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(values):
            r = 0
            for v in values:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        # Copy the scaled values to host memory in one transfer. Iterating a
        # (possibly GPU-resident) tensor element by element costs a device
        # round-trip per element. int() truncates the same way either way,
        # so the hash is unchanged.
        scaled = (self.vec.reshape(-1) * 100).detach().cpu().tolist()
        self.cached_checksum = f'{const_hash(scaled) & 0xffff:04x}'
        return self.cached_checksum

55

56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
class EmbeddingDatabase:
    """Registry of textual-inversion embeddings loaded from ``embeddings_dir``.

    ``word_embeddings`` maps an embedding name to its embedding object.
    ``ids_lookup`` maps the first token id of a name to ``(token_ids, embedding)``
    pairs sorted longest-first, so ``find_embedding_at_position`` prefers the
    longest matching name.
    """

    def __init__(self, embeddings_dir):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.dir_mtime = None  # mtime of embeddings_dir at the last (re)load
        self.embeddings_dir = embeddings_dir

    def register_embedding(self, embedding, model):
        """Register *embedding* under its name and return it.

        The name is tokenized with ``model.cond_stage_model.tokenizer``. An
        embedding registered later with the same token ids replaces the earlier
        one.

        Raises:
            ValueError: if the name produces no tokens. Nothing is registered
                in that case.
        """
        ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
        if not ids:
            raise ValueError(f"embedding name {embedding.name!r} does not produce any tokens")

        self.word_embeddings[embedding.name] = embedding

        first_id = ids[0]
        # Drop a stale entry with identical ids. Otherwise the stable sort
        # below keeps it ahead of the new one, and lookups would keep
        # returning the old embedding.
        bucket = [(other_ids, other) for other_ids, other in self.ids_lookup.get(first_id, []) if other_ids != ids]
        bucket.append((ids, embedding))
        bucket.sort(key=lambda entry: len(entry[0]), reverse=True)
        self.ids_lookup[first_id] = bucket

        return embedding

    def load_textual_inversion_embeddings(self):
        """(Re)load every embedding file in ``embeddings_dir`` if the directory changed.

        The reload is skipped when the directory's mtime has not advanced since
        the previous load. Otherwise both lookup tables are rebuilt. Empty files
        and subdirectories are ignored. Files that fail to load are reported on
        stderr and skipped.
        """
        mt = os.path.getmtime(self.embeddings_dir)
        if self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()

        def process_file(path, filename):
            name, ext = os.path.splitext(filename)

            if ext.upper() in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                # The embedding is stored either as a base64 text chunk or
                # encoded in the pixel data. Close the image file once decoded.
                with Image.open(path) as embed_image:
                    if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                        data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                    else:
                        data = extract_image_data_embed(embed_image)

                if data is None:
                    # NOTE(review): assumes extract_image_data_embed returns None
                    # for images that carry no embedding (e.g. plain preview
                    # images). Skip them instead of failing on data.get() below.
                    return

                name = data.get('name', name)
            else:
                # NOTE: torch.load unpickles the file, which can execute
                # arbitrary code. Only trusted files belong in this directory.
                data = torch.load(path, map_location="cpu")

            # textual inversion embeddings
            if 'string_to_param' in data:
                param_dict = data['string_to_param']
                if hasattr(param_dict, '_parameters'):
                    param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
                if len(param_dict) != 1:
                    raise ValueError('embedding file has multiple terms in it')
                emb = next(iter(param_dict.values()))
            # diffuser concepts
            elif isinstance(data, dict) and data and isinstance(next(iter(data.values())), torch.Tensor):
                if len(data) != 1:
                    raise ValueError('embedding file has multiple terms in it')

                emb = next(iter(data.values()))
                if len(emb.shape) == 1:
                    emb = emb.unsqueeze(0)
            else:
                raise Exception(f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")

            vec = emb.detach().to(devices.device, dtype=torch.float32)
            embedding = Embedding(vec, name)
            embedding.step = data.get('step', None)
            embedding.sd_checkpoint = data.get('sd_checkpoint', None)
            embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
            self.register_embedding(embedding, shared.sd_model)

        for fn in os.listdir(self.embeddings_dir):
            fullfn = os.path.join(self.embeddings_dir, fn)
            try:
                if not os.path.isfile(fullfn) or os.stat(fullfn).st_size == 0:
                    continue

                process_file(fullfn, fn)
            except Exception:
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
        print("Embeddings:", ', '.join(self.word_embeddings.keys()))

    def find_embedding_at_position(self, tokens, offset):
        """Return ``(embedding, n_tokens)`` for the longest embedding at ``tokens[offset]``.

        Returns ``(None, None)`` when no registered embedding matches there.
        """
        possible_matches = self.ids_lookup.get(tokens[offset])
        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None
154 155


D
DepFA 已提交
156
def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    """Create a new embedding file and return its path.

    The file is ``<name>.pt`` in ``shared.cmd_opts.embeddings_dir`` and is
    saved with step 0. Illegal characters are stripped from *name* first.
    Each of the ``num_vectors_per_token`` vectors is copied from an evenly
    spaced token of *init_text*'s token embeddings. An empty or
    whitespace-only *init_text* falls back to ``'*'``.

    Raises:
        AssertionError: if *name* is empty after removing illegal characters,
            or if the file already exists and *overwrite_old* is false.
    """
    # Remove illegal characters from name.
    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
    assert name, "embedding name is empty after removing illegal characters"

    # Validate the destination before doing any model work.
    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    cond_model = shared.sd_model.cond_stage_model
    embedding_layer = cond_model.wrapped.transformer.text_model.embeddings

    with devices.autocast():
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active

    # An empty init text yields no tokens. Indexing the resulting empty
    # tensor below would then fail.
    if not (init_text and init_text.strip()):
        init_text = '*'

    ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]

    # Only the values are needed; avoid recording an autograd graph.
    with torch.no_grad():
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
        vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

        for i in range(num_vectors_per_token):
            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

    embedding = Embedding(vec, name)
    embedding.step = 0
    embedding.save(fn)

    return fn


183 184 185 186
def write_loss(log_directory, filename, step, epoch_len, values):
    """Append a row of training metrics to the CSV file ``log_directory/filename``.

    A row is written only every ``shared.opts.training_write_csv_every`` steps,
    and never when that option is 0. *step* is 0-based; the row records it as
    1-based, together with the epoch and the 1-based step within the epoch
    derived from *epoch_len*. A header row is written when the file is new.
    """
    write_every = shared.opts.training_write_csv_every
    if write_every == 0:
        return

    if (step + 1) % write_every != 0:
        return

    # train_embedding only creates its log directory when embeddings or images
    # are saved, so it may not exist yet when only the loss CSV is written.
    if log_directory:
        os.makedirs(log_directory, exist_ok=True)

    path = os.path.join(log_directory, filename)
    write_csv_header = not os.path.exists(path)

    with open(path, "a+", newline='') as fout:
        csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])

        if write_csv_header:
            csv_writer.writeheader()

        epoch, epoch_step = divmod(step, epoch_len)

        csv_writer.writerow({
            "step": step + 1,
            "epoch": epoch,
            "epoch_step": epoch_step + 1,
            **values,
        })

207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225
def validate_train_inputs(model_name, learn_rate, batch_size, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    """Check training arguments, raising AssertionError with a user-facing message.

    *name* is the kind of model being trained (e.g. "embedding") and appears
    in the messages. A log directory is required only when models or images
    are saved periodically.
    """
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int), "Batch size must be integer"
    assert batch_size > 0, "Batch size must be positive"
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_file, "Prompt template file is empty"
    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
    # These two messages were missing the f-prefix and showed a literal "{name}".
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"
226

227
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    """Train the loaded embedding *embedding_name* and save the result.

    Training resumes from ``embedding.step``. It stops when the learn-rate
    scheduler reports it is finished, when the job is interrupted, or when the
    dataset iterator ends. Depending on the arguments, intermediate ``.pt``
    files, preview images, and preview images with the embedding stored inside
    are written under ``<log_directory>/<date>/<embedding_name>``. The loss is
    logged through ``write_loss``.

    Returns:
        ``(embedding, filename)`` where *filename* is the embedding's path in
        ``shared.cmd_opts.embeddings_dir``.
    """
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")

    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
    unload = shared.opts.unload_models_when_training

    if save_embedding_every > 0:
        embedding_dir = os.path.join(log_directory, "embeddings")
        os.makedirs(embedding_dir, exist_ok=True)
    else:
        embedding_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    if create_image_every > 0 and save_image_with_stored_embedding:
        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
        os.makedirs(images_embeds_dir, exist_ok=True)
    else:
        images_embeds_dir = None

    cond_model = shared.sd_model.cond_stage_model

    embedding = sd_hijack.model_hijack.embedding_db.word_embeddings[embedding_name]
    checkpoint = sd_models.select_checkpoint()

    initial_step = embedding.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return embedding, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
    if unload:
        shared.sd_model.first_stage_model.to(devices.cpu)

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)

    # Ring buffer of the most recent losses; its mean is what gets reported.
    losses = torch.zeros((32,))

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"
    # True once an intermediate .pt has been saved that is not yet stored in a preview image.
    embedding_yet_to_be_embedded = False

    pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step)
    try:
        for i, entries in pbar:
            embedding.step = i + initial_step

            scheduler.apply(optimizer, embedding.step)
            if scheduler.finished:
                break

            if shared.state.interrupted:
                break

            with torch.autocast("cuda"):
                c = cond_model([entry.cond_text for entry in entries])
                x = torch.stack([entry.latent for entry in entries]).to(devices.device)
                loss = shared.sd_model(x, c)[0]
                del x

                losses[embedding.step % losses.shape[0]] = loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            steps_done = embedding.step + 1

            epoch_num = embedding.step // len(ds)
            epoch_step = embedding.step % len(ds)

            pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{len(ds)}]loss: {losses.mean():.7f}")

            if embedding_dir is not None and steps_done % save_embedding_every == 0:
                # Before saving, change name to match current checkpoint.
                embedding_name_every = f'{embedding_name}-{steps_done}'
                last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                embedding_yet_to_be_embedded = True

            write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
                "loss": f"{losses.mean():.7f}",
                "learn_rate": scheduler.learn_rate
            })

            if images_dir is not None and steps_done % create_image_every == 0:
                forced_filename = f'{embedding_name}-{steps_done}'

                # Preview generation needs first_stage_model on the compute device.
                shared.sd_model.first_stage_model.to(devices.device)

                p = processing.StableDiffusionProcessingTxt2Img(
                    sd_model=shared.sd_model,
                    do_not_save_grid=True,
                    do_not_save_samples=True,
                    do_not_reload_embeddings=True,
                )

                if preview_from_txt2img:
                    p.prompt = preview_prompt
                    p.negative_prompt = preview_negative_prompt
                    p.steps = preview_steps
                    p.sampler_index = preview_sampler_index
                    p.cfg_scale = preview_cfg_scale
                    p.seed = preview_seed
                    p.width = preview_width
                    p.height = preview_height
                else:
                    p.prompt = entries[0].cond_text
                    p.steps = 20
                    p.width = training_width
                    p.height = training_height

                preview_text = p.prompt

                processed = processing.process_images(p)
                image = processed.images[0]

                if unload:
                    shared.sd_model.first_stage_model.to(devices.cpu)

                shared.state.current_image = image

                if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
                    _save_preview_with_embedding(
                        image,
                        last_saved_file,
                        os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png'),
                        steps_done,
                        checkpoint,
                    )
                    embedding_yet_to_be_embedded = False

                last_saved_image, _ = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                last_saved_image += f", prompt: {preview_text}"

            shared.state.job_no = embedding.step

            shared.state.textinfo = f"""
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""

        save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    finally:
        pbar.close()
        # Always return first_stage_model to the compute device, even if
        # training failed part-way while it was unloaded to the CPU.
        shared.sd_model.first_stage_model.to(devices.device)

    return embedding, filename


def _save_preview_with_embedding(image, embedding_file, output_path, steps_done, checkpoint):
    """Save a captioned copy of *image* as PNG with the embedding from *embedding_file* stored in it."""
    info = PngImagePlugin.PngInfo()
    data = torch.load(embedding_file)
    info.add_text("sd-ti-embedding", embedding_to_b64(data))

    title = "<{}>".format(data.get('name', '???'))

    try:
        vector_size = list(data['string_to_param'].values())[0].shape[0]
    except Exception:
        vector_size = '?'

    footer_left = checkpoint.model_name
    footer_mid = '[{}]'.format(checkpoint.hash)
    footer_right = '{}v {}s'.format(vector_size, steps_done)

    captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
    captioned_image = insert_image_data_embed(captioned_image, data)

    captioned_image.save(output_path, "PNG", pnginfo=info)
415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433

def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    """Save *embedding* to *filename* under *embedding_name*, stamped with *checkpoint*.

    On success the in-memory embedding keeps the new name and checkpoint
    metadata. If anything raises, the previous values are put back and the
    exception propagates.
    """
    previous_name = embedding.name
    # Snapshot in the order the fields are restored on failure.
    snapshot = (
        ("sd_checkpoint", getattr(embedding, "sd_checkpoint", None)),
        ("sd_checkpoint_name", getattr(embedding, "sd_checkpoint_name", None)),
        ("name", previous_name),
        ("cached_checksum", getattr(embedding, "cached_checksum", None)),
    )

    try:
        embedding.sd_checkpoint = checkpoint.hash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.save(filename)
    except BaseException:
        for attribute, value in snapshot:
            setattr(embedding, attribute, value)
        raise