import os
import sys
import traceback

import torch
import tqdm
import html
import datetime
import csv

from PIL import Image, PngImagePlugin

from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler

from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
                                                       insert_image_data_embed, extract_image_data_embed,
                                                       caption_image_overlay)


class Embedding:
    """A single textual-inversion embedding: a tensor plus bookkeeping metadata."""

    def __init__(self, vec, name, step=None):
        """Store the embedding tensor `vec` under `name`; `step` is the training step reached so far."""
        self.vec = vec
        self.name = name
        self.step = step
        self.shape = None  # set by the loader to vec.shape[-1]
        self.vectors = 0  # set by the loader to vec.shape[0]
        self.cached_checksum = None  # memoized result of checksum()
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None

    def save(self, filename):
        """Write the embedding to `filename` with torch.save.

        When `shared.opts.save_optimizer_state` is enabled and an optimizer state
        dict is attached, it is additionally written to `filename + '.optim'`
        together with the embedding checksum, so it can be matched on reload.
        """
        embedding_data = {
            # NOTE(review): 265 is hard-coded; presumably the token id of "*" - confirm.
            "string_to_token": {"*": 265},
            "string_to_param": {"*": self.vec},
            "name": self.name,
            "step": self.step,
            "sd_checkpoint": self.sd_checkpoint,
            "sd_checkpoint_name": self.sd_checkpoint_name,
        }

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, filename + '.optim')

    def checksum(self):
        """Return a 4-hex-digit checksum of the embedding values, cached after the first call."""
        if self.cached_checksum is not None:
            return self.cached_checksum

        def const_hash(a):
            # Rolling hash over the integer part of each element, kept to 32 bits.
            r = 0
            for v in a:
                r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
            return r

        self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
        return self.cached_checksum


class EmbeddingDatabase:
    """Registry of embeddings loaded from a directory, indexed by name and by first token id."""

    def __init__(self, embeddings_dir):
        """Create an empty database that will load embeddings from `embeddings_dir`."""
        # first token id -> list of (token ids, embedding), longest token sequence first
        self.ids_lookup = {}
        # embedding name -> Embedding, for embeddings matching the expected shape
        self.word_embeddings = {}
        # embedding name -> Embedding, for embeddings whose shape did not match
        self.skipped_embeddings = {}
        self.dir_mtime = None
        self.embeddings_dir = embeddings_dir
        self.expected_shape = -1  # -1 means "accept any shape"

    def register_embedding(self, embedding, model):
        """Register `embedding` under its name and under the token ids of that name.

        Returns the embedding that was passed in.
        """
        self.word_embeddings[embedding.name] = embedding

        # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
        ids = model.cond_stage_model.tokenize([embedding.name])[0]

        first_id = ids[0]
        candidates = self.ids_lookup.get(first_id, [])

        # Longest token sequences first, so lookups try the most specific match first.
        self.ids_lookup[first_id] = sorted(candidates + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)

        return embedding

    def get_expected_shape(self):
        """Return the second dimension of the vector the current model produces for a one-token init text."""
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_textual_inversion_embeddings(self, force_reload = False):
        """(Re)load every embedding file found in `self.embeddings_dir`.

        Nothing happens when the directory's mtime has not advanced since the last
        load, unless `force_reload` is true. Files that fail to load are reported
        on stderr and skipped; embeddings whose shape differs from the expected
        shape are collected in `self.skipped_embeddings` instead of being registered.
        """
        mt = os.path.getmtime(self.embeddings_dir)
        if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        def process_file(path, filename):
            name, ext = os.path.splitext(filename)
            ext = ext.upper()

            if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                # Close the image file handle as soon as the payload has been extracted.
                with Image.open(path) as embed_image:
                    if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                        data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                    else:
                        data = extract_image_data_embed(embed_image)
                name = data.get('name', name)
            elif ext in ['.BIN', '.PT']:
                # SECURITY NOTE: torch.load unpickles the file; only load embeddings from trusted sources.
                data = torch.load(path, map_location="cpu")
            else:
                return

            # textual inversion embeddings
            if 'string_to_param' in data:
                param_dict = data['string_to_param']
                if hasattr(param_dict, '_parameters'):
                    param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
                assert len(param_dict) == 1, 'embedding file has multiple terms in it'
                emb = next(iter(param_dict.items()))[1]
            # diffuser concepts
            elif isinstance(data, dict) and isinstance(next(iter(data.values()), None), torch.Tensor):
                assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

                emb = next(iter(data.values()))
                if len(emb.shape) == 1:
                    emb = emb.unsqueeze(0)
            else:
                raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

            vec = emb.detach().to(devices.device, dtype=torch.float32)
            embedding = Embedding(vec, name)
            embedding.step = data.get('step', None)
            embedding.sd_checkpoint = data.get('sd_checkpoint', None)
            embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
            embedding.vectors = vec.shape[0]
            embedding.shape = vec.shape[-1]

            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
                self.register_embedding(embedding, shared.sd_model)
            else:
                self.skipped_embeddings[name] = embedding

        for fn in os.listdir(self.embeddings_dir):
            try:
                fullfn = os.path.join(self.embeddings_dir, fn)

                if os.stat(fullfn).st_size == 0:
                    continue

                process_file(fullfn, fn)
            except Exception:
                # One broken file must not prevent the remaining embeddings from loading.
                print(f"Error loading embedding {fn}:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
        if len(self.skipped_embeddings) > 0:
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

    def find_embedding_at_position(self, tokens, offset):
        """Return `(embedding, token_count)` for the embedding whose ids start at `tokens[offset]`.

        Returns `(None, None)` when no registered embedding matches at that position.
        """
        token = tokens[offset]
        possible_matches = self.ids_lookup.get(token, None)

        if possible_matches is None:
            return None, None

        for ids, embedding in possible_matches:
            if tokens[offset:offset + len(ids)] == ids:
                return embedding, len(ids)

        return None, None


def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    """Create a new embedding initialised from `init_text`, save it, and return its file path.

    The name is stripped of every character that is not alphanumeric or one of
    "._- ". Unless `overwrite_old` is true, an AssertionError is raised when the
    target file already exists.
    """
    cond_model = shared.sd_model.cond_stage_model

    with devices.autocast():
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active

    embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

    # Spread the init-text vectors evenly across the requested number of vectors.
    for i in range(num_vectors_per_token):
        vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]

    # Remove illegal characters from name.
    name = "".join(x for x in name if (x.isalnum() or x in "._- "))
    fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
    if not overwrite_old:
        assert not os.path.exists(fn), f"file {fn} already exists"

    embedding = Embedding(vec, name)
    embedding.step = 0
    embedding.save(fn)

    return fn


def write_loss(log_directory, filename, step, epoch_len, values):
    """Append one row of training statistics to `log_directory/filename` as CSV.

    Nothing is written when `shared.opts.training_write_csv_every` is 0 or when
    `step` is not a multiple of it. A header row is written when the file does
    not exist yet. `values` is a mapping of extra column name -> value.
    """
    if shared.opts.training_write_csv_every == 0:
        return

    if step % shared.opts.training_write_csv_every != 0:
        return

    csv_path = os.path.join(log_directory, filename)
    write_csv_header = not os.path.exists(csv_path)

    # The log directory is not guaranteed to exist (callers only create its
    # subdirectories when saving embeddings/images is enabled).
    if log_directory:
        os.makedirs(log_directory, exist_ok=True)

    with open(csv_path, "a+", newline='') as fout:
        csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])

        if write_csv_header:
            csv_writer.writeheader()

        epoch = (step - 1) // epoch_len
        epoch_step = (step - 1) % epoch_len

        csv_writer.writerow({
            "step": step,
            "epoch": epoch,
            "epoch_step": epoch_step,
            **values,
        })


def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    """Check training parameters, raising AssertionError with a readable message on the first problem.

    `name` is the kind of model being trained and is only used in messages.
    """
    assert model_name, f"{name} not selected"
    assert learn_rate, "Learning rate is empty or 0"
    assert isinstance(batch_size, int), "Batch size must be integer"
    assert batch_size > 0, "Batch size must be positive"
    assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
    assert gradient_step > 0, "Gradient accumulation step must be positive"
    assert data_root, "Dataset directory is empty"
    assert os.path.isdir(data_root), "Dataset directory doesn't exist"
    assert os.listdir(data_root), "Dataset directory is empty"
    assert template_file, "Prompt template file is empty"
    assert os.path.isfile(template_file), "Prompt template file doesn't exist"
    assert steps, "Max steps is empty or 0"
    assert isinstance(steps, int), "Max steps must be integer"
    assert steps > 0, "Max steps must be positive"
    # These two messages interpolate `name`, so they must be f-strings.
    assert isinstance(save_model_every, int), f"Save {name} must be integer"
    assert save_model_every >= 0, f"Save {name} must be positive or 0"
    assert isinstance(create_image_every, int), "Create image must be integer"
    assert create_image_every >= 0, "Create image must be positive or 0"
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"


def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    """Train the already-registered embedding `embedding_name` and return `(embedding, filename)`.

    Training runs until `steps` optimizer steps have been reached, the learn-rate
    scheduler reports it is finished, or `shared.state.interrupted` is set.
    Every `save_embedding_every` steps an intermediate embedding is saved and
    every `create_image_every` steps a preview image is generated (0 disables
    either); both go below `log_directory/<date>/<embedding_name>/`. The
    `preview_*` arguments are used for previews only when `preview_from_txt2img`
    is true; otherwise the prompt of the current batch is used.

    Exceptions raised during the training loop are printed to stderr and not
    propagated; validation errors raised before the loop do propagate.
    """
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")

    shared.state.job = "train-embedding"
    shared.state.textinfo = "Initializing textual inversion training..."
    shared.state.job_count = steps

    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
    unload = shared.opts.unload_models_when_training

    if save_embedding_every > 0:
        embedding_dir = os.path.join(log_directory, "embeddings")
        os.makedirs(embedding_dir, exist_ok=True)
    else:
        embedding_dir = None

    if create_image_every > 0:
        images_dir = os.path.join(log_directory, "images")
        os.makedirs(images_dir, exist_ok=True)
    else:
        images_dir = None

    if create_image_every > 0 and save_image_with_stored_embedding:
        images_embeds_dir = os.path.join(log_directory, "image_embeddings")
        os.makedirs(images_embeds_dir, exist_ok=True)
    else:
        images_embeds_dir = None

    hijack = sd_hijack.model_hijack

    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    checkpoint = sd_models.select_checkpoint()

    initial_step = embedding.step or 0
    if initial_step >= steps:
        shared.state.textinfo = "Model has already been trained beyond specified max steps"
        return embedding, filename

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    if clip_grad_mode == "value":
        clip_grad = torch.nn.utils.clip_grad_value_
    elif clip_grad_mode == "norm":
        clip_grad = torch.nn.utils.clip_grad_norm_
    else:
        clip_grad = None

    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)

    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.parallel_processing_allowed = False
        shared.sd_model.first_stage_model.to(devices.cpu)

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    if shared.opts.save_optimizer_state:
        optimizer_state_dict = None
        if os.path.exists(filename + '.optim'):
            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
            # Only reuse the optimizer state if it was saved for exactly this embedding content.
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)

        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")

    scaler = torch.cuda.amp.GradScaler()

    # The dataset may have adjusted these values, so use its versions from here on.
    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0  # internal: loss accumulated over the current gradient-accumulation window

    last_saved_file = ""
    last_saved_image = ""
    forced_filename = ""
    embedding_yet_to_be_embedded = False

    is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
    img_c = None

    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        for _ in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, embedding.step)
                if scheduler.finished:
                    break
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(embedding.step)

                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)

                    if is_training_inpainting_model:
                        # The image conditioning is computed once and reused for every batch.
                        if img_c is None:
                            img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)

                        cond = {"c_concat": [img_c], "c_crossattn": [c]}
                    else:
                        cond = c

                    loss = shared.sd_model(x, cond)[0] / gradient_step
                    del x

                    _loss_step += loss.item()
                scaler.scale(loss).backward()

                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue

                if clip_grad:
                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                embedding.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = embedding.step + 1

                epoch_num = embedding.step // steps_per_epoch
                epoch_step = embedding.step % steps_per_epoch

                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
                if embedding_dir is not None and steps_done % save_embedding_every == 0:
                    # Before saving, change name to match current checkpoint.
                    embedding_name_every = f'{embedding_name}-{steps_done}'
                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                    embedding_yet_to_be_embedded = True

                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{embedding_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

                    shared.sd_model.first_stage_model.to(devices.device)

                    p = processing.StableDiffusionProcessingTxt2Img(
                        sd_model=shared.sd_model,
                        do_not_save_grid=True,
                        do_not_save_samples=True,
                        do_not_reload_embeddings=True,
                    )

                    if preview_from_txt2img:
                        p.prompt = preview_prompt
                        p.negative_prompt = preview_negative_prompt
                        p.steps = preview_steps
                        p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
                        p.cfg_scale = preview_cfg_scale
                        p.seed = preview_seed
                        p.width = preview_width
                        p.height = preview_height
                    else:
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None

                    if unload:
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    # Everything below needs an actual image; without one there is
                    # nothing to save or caption, and training simply continues.
                    if image is not None:
                        shared.state.current_image = image
                        last_saved_image, _ = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                        if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:

                            last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')

                            info = PngImagePlugin.PngInfo()
                            data = torch.load(last_saved_file)
                            info.add_text("sd-ti-embedding", embedding_to_b64(data))

                            title = "<{}>".format(data.get('name', '???'))

                            try:
                                vectorSize = list(data['string_to_param'].values())[0].shape[0]
                            except Exception:
                                vectorSize = '?'

                            # Re-read the selected checkpoint; this value is also used by later saves.
                            checkpoint = sd_models.select_checkpoint()
                            footer_left = checkpoint.model_name
                            footer_mid = '[{}]'.format(checkpoint.hash)
                            footer_right = '{}v {}s'.format(vectorSize, steps_done)

                            captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                            captioned_image = insert_image_data_embed(captioned_image, data)

                            captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
                            embedding_yet_to_be_embedded = False

                shared.state.job_no = embedding.step

                # NOTE(review): values are HTML-escaped but the text carries no markup of
                # its own; confirm the consumer renders the line breaks as intended.
                shared.state.textinfo = f"""

Loss: {loss_step:.7f}
Step: {steps_done}
Last prompt: {html.escape(batch.cond_text[0])}
Last saved embedding: {html.escape(last_saved_file)}
Last saved image: {html.escape(last_saved_image)}

"""
        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        # Best effort: report the failure but still restore model/device state below.
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        shared.sd_model.first_stage_model.to(devices.device)
        shared.parallel_processing_allowed = old_parallel_processing_allowed

    return embedding, filename


def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    """Stamp `embedding` with checkpoint info and `embedding_name`, then save it to `filename`.

    The optimizer's state dict is attached to the embedding before saving. If
    anything fails, the embedding's previous name, checkpoint fields, cached
    checksum and optimizer state are restored and the exception is re-raised.
    """
    old_embedding_name = embedding.name
    old_sd_checkpoint = getattr(embedding, "sd_checkpoint", None)
    old_sd_checkpoint_name = getattr(embedding, "sd_checkpoint_name", None)
    old_cached_checksum = getattr(embedding, "cached_checksum", None)
    old_optimizer_state_dict = getattr(embedding, "optimizer_state_dict", None)
    try:
        embedding.sd_checkpoint = checkpoint.hash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            # Force the checksum to be recomputed from the current vector on save.
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.optimizer_state_dict = optimizer.state_dict()
        embedding.save(filename)
    except BaseException:
        # Roll back every field mutated above, then let the caller see the error.
        embedding.sd_checkpoint = old_sd_checkpoint
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
        embedding.name = old_embedding_name
        embedding.cached_checksum = old_cached_checksum
        embedding.optimizer_state_dict = old_optimizer_state_dict
        raise