diff --git a/.eslintrc.js b/.eslintrc.js index 944cc869ecc51ac81e3a394c2e6898c99e1f24ec..f33aca09fa022638e45e8737386402711e464656 100644 --- a/.eslintrc.js +++ b/.eslintrc.js @@ -50,13 +50,14 @@ module.exports = { globals: { //script.js gradioApp: "readonly", + executeCallbacks: "readonly", + onAfterUiUpdate: "readonly", + onOptionsChanged: "readonly", onUiLoaded: "readonly", onUiUpdate: "readonly", - onOptionsChanged: "readonly", uiCurrentTab: "writable", - uiElementIsVisible: "readonly", uiElementInSight: "readonly", - executeCallbacks: "readonly", + uiElementIsVisible: "readonly", //ui.js opts: "writable", all_gallery_buttons: "readonly", @@ -84,5 +85,7 @@ module.exports = { // imageviewer.js modalPrevImage: "readonly", modalNextImage: "readonly", + // token-counters.js + setupTokenCounters: "readonly", } }; diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index 3a8b99535618f803278493f0cd1e38ce8fa1a983..d80b24e2bde995177bf61022701548c32796b0df 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -43,8 +43,8 @@ body: - type: input id: commit attributes: - label: Commit where the problem happens - description: Which commit are you running ? (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) + label: Version or Commit where the problem happens + description: "Which webui version or commit are you running ? (Do not write *Latest Version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Version: v1.2.3** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)" validations: required: true - type: dropdown @@ -80,6 +80,23 @@ body: - AMD GPUs (RX 5000 below) - CPU - Other GPUs + - type: dropdown + id: cross_attention_opt + attributes: + label: Cross attention optimization + description: What cross attention optimization are you using, Settings -> Optimizations -> Cross attention optimization + multiple: false + options: + - Automatic + - xformers + - sdp-no-mem + - sdp + - Doggettx + - V1 + - InvokeAI + - "None " + validations: + required: true - type: dropdown id: browsers attributes: diff --git a/CHANGELOG.md b/CHANGELOG.md index 57f2dde7497eb5544ee3939875f6155fad7a56a3..6a31f35bcd3b5fbc585d1607fde93a59be1deb71 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,60 @@ +## 1.4.0 + +### Features: + * zoom controls for inpainting + * run basic torch calculation at startup in parallel to reduce the performance impact of first generation + * option to pad prompt/neg prompt to be same length + * remove taming_transformers dependency + * custom k-diffusion scheduler settings + * add an option to show selected settings in main txt2img/img2img UI + * sysinfo tab in settings + * infer styles from prompts when pasting params into the UI + * an option to control the behavior of the above + +### Minor: + * bump Gradio to 3.32.0 + * bump xformers to 0.0.20 + * Add option to disable token counters + * tooltip fixes & optimizations + * make it possible to configure filename for the zip download + * `[vae_filename]` pattern for filenames + * Revert discarding penultimate sigma for DPM-Solver++(2M) SDE + * change UI reorder setting to multiselect + * read version info form CHANGELOG.md if git version info is not available + * link footer API to Wiki when API is not active + * persistent conds cache (opt-in optimization) + +### Extensions: + * After installing extensions, webui properly restarts the process rather than reloads the UI + * Added VAE listing to web API. Via: /sdapi/v1/sd-vae + * custom unet support + * Add onAfterUiUpdate callback + * refactor EmbeddingDatabase.register_embedding() to allow unregistering + * add before_process callback for scripts + * add ability for alwayson scripts to specify section and let user reorder those sections + +### Bug Fixes: + * Fix dragging text to prompt + * fix incorrect quoting for infotext values with colon in them + * fix "hires. fix" prompt sharing same labels with txt2img_prompt + * Fix s_min_uncond default type int + * Fix for #10643 (Inpainting mask sometimes not working) + * fix bad styling for thumbs view in extra networks #10639 + * fix for empty list of optimizations #10605 + * small fixes to prepare_tcmalloc for Debian/Ubuntu compatibility + * fix --ui-debug-mode exit + * patch GitPython to not use leaky persistent processes + * fix duplicate Cross attention optimization after UI reload + * torch.cuda.is_available() check for SdOptimizationXformers + * fix hires fix using wrong conds in second pass if using Loras. + * handle exception when parsing generation parameters from png info + * fix upcast attention dtype error + * forcing Torch Version to 1.13.1 for RX 5000 series GPUs + * split mask blur into X and Y components, patch Outpainting MK2 accordingly + * don't die when a LoRA is a broken symlink + * allow activation of Generate Forever during generation + + ## 1.3.2 ### Bug Fixes: diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index c4da79f31f83f4d520b43f9291d964036e8bd4d1..dbd6d331dad9def9c3682d4e95fe13bb2c6647b3 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -1,12 +1,10 @@ import os -import sys -import traceback from basicsr.utils.download_util import load_file_from_url from modules.upscaler import Upscaler, UpscalerData from ldsr_model_arch import LDSR -from modules import shared, script_callbacks +from modules import shared, script_callbacks, errors import sd_hijack_autoencoder # noqa: F401 import sd_hijack_ddpm_v1 # noqa: F401 @@ -51,10 +49,8 @@ class UpscalerLDSR(Upscaler): try: return LDSR(model, yaml) - except Exception: - print("Error importing LDSR:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Error importing LDSR", exc_info=True) return None def do_upscale(self, img, path): diff --git a/extensions-builtin/LDSR/sd_hijack_autoencoder.py b/extensions-builtin/LDSR/sd_hijack_autoencoder.py index 81c5101b7d7f14f65c51ae91e97b63ee3debb086..c29d274da825d2500b77a2022db3421b40b18886 100644 --- a/extensions-builtin/LDSR/sd_hijack_autoencoder.py +++ b/extensions-builtin/LDSR/sd_hijack_autoencoder.py @@ -10,7 +10,7 @@ from contextlib import contextmanager from torch.optim.lr_scheduler import LambdaLR from ldm.modules.ema import LitEma -from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer +from vqvae_quantize import VectorQuantizer2 as VectorQuantizer from ldm.modules.diffusionmodules.model import Encoder, Decoder from ldm.util import instantiate_from_config @@ -91,8 +91,9 @@ class VQModel(pl.LightningModule): del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") + if unexpected: print(f"Unexpected Keys: {unexpected}") def on_train_batch_end(self, *args, **kwargs): diff --git a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py index 631a08ef0902246b3cc4e35e391db3a79af18346..04adc5eb2cfe9aa1d5f75e5653624456c5e37a47 100644 --- a/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py +++ b/extensions-builtin/LDSR/sd_hijack_ddpm_v1.py @@ -195,9 +195,9 @@ class DDPMV1(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/extensions-builtin/LDSR/vqvae_quantize.py b/extensions-builtin/LDSR/vqvae_quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..dd14b8fda5ce25a8cea8b70eb1d387b9c46c80d8 --- /dev/null +++ b/extensions-builtin/LDSR/vqvae_quantize.py @@ -0,0 +1,147 @@ +# Vendored from https://raw.githubusercontent.com/CompVis/taming-transformers/24268930bf1dce879235a7fddd0b2355b84d7ea6/taming/modules/vqvae/quantize.py, +# where the license is as follows: +# +# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, +# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE +# OR OTHER DEALINGS IN THE SOFTWARE./ + +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed + 1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + match = (inds[:, :, None] == used[None, None, ...]).long() + new = match.argmax(-1) + unknown = match.sum(2) < 1 + if self.unknown_index == "random": + new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape) > 1 + inds = inds.reshape(ishape[0], -1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds >= self.used.shape[0]] = 0 # simply set to zero + back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" + assert rescale_logits is False, "Only for interface compatible with Gumbel" + assert return_logits is False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0], -1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index b5fea4d2ec7b75b28953d823771fccf70e99fd38..66ee9c8563919fcda648f58d34278b85b4571f8d 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -9,14 +9,14 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_lora - if additional != "None" and additional in lora.available_loras and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in lora.available_loras and not any(x for x in params_list if x.items[0] == additional): p.all_prompts = [x + f"" for x in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index eec147122244a8f5ba97e6d116c798fc16a36660..34ff57dd0e7550c6587cc2c100dbebb8ff266142 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -219,7 +219,7 @@ def load_lora(name, lora_on_disk): else: raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") - if len(keys_failed_to_match) > 0: + if keys_failed_to_match: print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") return lora @@ -267,7 +267,7 @@ def load_loras(names, multipliers=None): lora.multiplier = multipliers[i] if multipliers else 1.0 loaded_loras.append(lora) - if len(failed_to_load_loras) > 0: + if failed_to_load_loras: sd_hijack.model_hijack.comments.append("Failed to find Loras: " + ", ".join(failed_to_load_loras)) @@ -448,7 +448,11 @@ def list_available_loras(): continue name = os.path.splitext(os.path.basename(filename))[0] - entry = LoraOnDisk(name, filename) + try: + entry = LoraOnDisk(name, filename) + except OSError: # should catch FileNotFoundError and PermissionError etc. + errors.report(f"Failed to load LoRA {name} from {filename}", exc_info=True) + continue available_loras[name] = entry diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index 259e99ac83783f4a7c5aa13d9fe7aa7635e58b43..da49790b56c0d4d037bd41d61104b9a14ec75bb3 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -13,7 +13,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): lora.list_available_loras() def list_items(self): - for name, lora_on_disk in lora.available_loras.items(): + for index, (name, lora_on_disk) in enumerate(lora.available_loras.items()): path, ext = os.path.splitext(lora_on_disk.filename) alias = lora_on_disk.get_alias() @@ -27,6 +27,8 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "prompt": json.dumps(f""), "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": json.dumps(lora_on_disk.metadata, indent=4) if lora_on_disk.metadata else None, + "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, + } def allowed_directories_for_previews(self): diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 45d9297b6a194b5c089f58fbddf148fba5a2291e..85b4505f6a7882276fdaf978c5c4851b2f89ee38 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,6 +1,5 @@ import os.path import sys -import traceback import PIL.Image import numpy as np @@ -10,8 +9,9 @@ from tqdm import tqdm from basicsr.utils.download_util import load_file_from_url import modules.upscaler -from modules import devices, modelloader, script_callbacks +from modules import devices, modelloader, script_callbacks, errors from scunet_model_arch import SCUNet as net + from modules.shared import opts @@ -38,8 +38,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scaler_data = modules.upscaler.UpscalerData(name, file, self, 4) scalers.append(scaler_data) except Exception: - print(f"Error loading ScuNET model: {file}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error loading ScuNET model: {file}", exc_info=True) if add_model2: scaler_data2 = modules.upscaler.UpscalerData(self.model_name2, self.model_url2, self) scalers.append(scaler_data2) diff --git a/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js new file mode 100644 index 0000000000000000000000000000000000000000..4ecb3d3676f51365de35553fd9b84eba8bba0189 --- /dev/null +++ b/extensions-builtin/canvas-zoom-and-pan/javascript/zoom.js @@ -0,0 +1,640 @@ +onUiLoaded(async() => { + const elementIDs = { + img2imgTabs: "#mode_img2img .tab-nav", + inpaint: "#img2maskimg", + inpaintSketch: "#inpaint_sketch", + rangeGroup: "#img2img_column_size", + sketch: "#img2img_sketch", + }; + const tabNameToElementId = { + "Inpaint sketch": elementIDs.inpaintSketch, + "Inpaint": elementIDs.inpaint, + "Sketch": elementIDs.sketch, + }; + + // Helper functions + // Get active tab + function getActiveTab(elements, all = false) { + const tabs = elements.img2imgTabs.querySelectorAll("button"); + + if (all) return tabs; + + for (let tab of tabs) { + if (tab.classList.contains("selected")) { + return tab; + } + } + } + + // Get tab ID + function getTabId(elements) { + const activeTab = getActiveTab(elements); + return tabNameToElementId[activeTab.innerText]; + } + + // Wait until opts loaded + async function waitForOpts() { + for (;;) { + if (window.opts && Object.keys(window.opts).length) { + return window.opts; + } + await new Promise(resolve => setTimeout(resolve, 100)); + } + } + + // Check is hotkey valid + function isSingleLetter(value) { + return ( + typeof value === "string" && value.length === 1 && /[a-z]/i.test(value) + ); + } + + // Create hotkeyConfig from opts + function createHotkeyConfig(defaultHotkeysConfig, hotkeysConfigOpts) { + const result = {}; + const usedKeys = new Set(); + + for (const key in defaultHotkeysConfig) { + if (typeof hotkeysConfigOpts[key] === "boolean") { + result[key] = hotkeysConfigOpts[key]; + continue; + } + if ( + hotkeysConfigOpts[key] && + isSingleLetter(hotkeysConfigOpts[key]) && + !usedKeys.has(hotkeysConfigOpts[key].toUpperCase()) + ) { + // If the property passed the test and has not yet been used, add 'Key' before it and save it + result[key] = "Key" + hotkeysConfigOpts[key].toUpperCase(); + usedKeys.add(hotkeysConfigOpts[key].toUpperCase()); + } else { + // If the property does not pass the test or has already been used, we keep the default value + console.error( + `Hotkey: ${hotkeysConfigOpts[key]} for ${key} is repeated and conflicts with another hotkey or is not 1 letter. The default hotkey is used: ${defaultHotkeysConfig[key][3]}` + ); + result[key] = defaultHotkeysConfig[key]; + } + } + + return result; + } + + /** + * The restoreImgRedMask function displays a red mask around an image to indicate the aspect ratio. + * If the image display property is set to 'none', the mask breaks. To fix this, the function + * temporarily sets the display property to 'block' and then hides the mask again after 300 milliseconds + * to avoid breaking the canvas. Additionally, the function adjusts the mask to work correctly on + * very long images. + */ + function restoreImgRedMask(elements) { + const mainTabId = getTabId(elements); + + if (!mainTabId) return; + + const mainTab = gradioApp().querySelector(mainTabId); + const img = mainTab.querySelector("img"); + const imageARPreview = gradioApp().querySelector("#imageARPreview"); + + if (!img || !imageARPreview) return; + + imageARPreview.style.transform = ""; + if (parseFloat(mainTab.style.width) > 865) { + const transformString = mainTab.style.transform; + const scaleMatch = transformString.match(/scale\(([-+]?[0-9]*\.?[0-9]+)\)/); + let zoom = 1; // default zoom + + if (scaleMatch && scaleMatch[1]) { + zoom = Number(scaleMatch[1]); + } + + imageARPreview.style.transformOrigin = "0 0"; + imageARPreview.style.transform = `scale(${zoom})`; + } + + if (img.style.display !== "none") return; + + img.style.display = "block"; + + setTimeout(() => { + img.style.display = "none"; + }, 400); + } + + const hotkeysConfigOpts = await waitForOpts(); + + // Default config + const defaultHotkeysConfig = { + canvas_hotkey_reset: "KeyR", + canvas_hotkey_fullscreen: "KeyS", + canvas_hotkey_move: "KeyF", + canvas_hotkey_overlap: "KeyO", + canvas_show_tooltip: true, + canvas_swap_controls: false + }; + // swap the actions for ctr + wheel and shift + wheel + const hotkeysConfig = createHotkeyConfig( + defaultHotkeysConfig, + hotkeysConfigOpts + ); + + let isMoving = false; + let mouseX, mouseY; + let activeElement; + + const elements = Object.fromEntries(Object.keys(elementIDs).map((id) => [ + id, + gradioApp().querySelector(elementIDs[id]), + ])); + const elemData = {}; + + // Apply functionality to the range inputs. Restore redmask and correct for long images. + const rangeInputs = elements.rangeGroup ? Array.from(elements.rangeGroup.querySelectorAll("input")) : + [ + gradioApp().querySelector("#img2img_width input[type='range']"), + gradioApp().querySelector("#img2img_height input[type='range']") + ]; + + for (const input of rangeInputs) { + input?.addEventListener("input", () => restoreImgRedMask(elements)); + } + + function applyZoomAndPan(elemId) { + const targetElement = gradioApp().querySelector(elemId); + + if (!targetElement) { + console.log("Element not found"); + return; + } + + targetElement.style.transformOrigin = "0 0"; + + elemData[elemId] = { + zoom: 1, + panX: 0, + panY: 0 + }; + let fullScreenMode = false; + + // Create tooltip + function createTooltip() { + const toolTipElemnt = + targetElement.querySelector(".image-container"); + const tooltip = document.createElement("div"); + tooltip.className = "tooltip"; + + // Creating an item of information + const info = document.createElement("i"); + info.className = "tooltip-info"; + info.textContent = ""; + + // Create a container for the contents of the tooltip + const tooltipContent = document.createElement("div"); + tooltipContent.className = "tooltip-content"; + + // Add info about hotkeys + const zoomKey = hotkeysConfig.canvas_swap_controls ? "Ctrl" : "Shift"; + const adjustKey = hotkeysConfig.canvas_swap_controls ? "Shift" : "Ctrl"; + + const hotkeys = [ + {key: `${zoomKey} + wheel`, action: "Zoom canvas"}, + {key: `${adjustKey} + wheel`, action: "Adjust brush size"}, + { + key: hotkeysConfig.canvas_hotkey_reset.charAt(hotkeysConfig.canvas_hotkey_reset.length - 1), + action: "Reset zoom" + }, + { + key: hotkeysConfig.canvas_hotkey_fullscreen.charAt(hotkeysConfig.canvas_hotkey_fullscreen.length - 1), + action: "Fullscreen mode" + }, + { + key: hotkeysConfig.canvas_hotkey_move.charAt(hotkeysConfig.canvas_hotkey_move.length - 1), + action: "Move canvas" + } + ]; + for (const hotkey of hotkeys) { + const p = document.createElement("p"); + p.innerHTML = `${hotkey.key} - ${hotkey.action}`; + tooltipContent.appendChild(p); + } + + // Add information and content elements to the tooltip element + tooltip.appendChild(info); + tooltip.appendChild(tooltipContent); + + // Add a hint element to the target element + toolTipElemnt.appendChild(tooltip); + } + + //Show tool tip if setting enable + if (hotkeysConfig.canvas_show_tooltip) { + createTooltip(); + } + + // In the course of research, it was found that the tag img is very harmful when zooming and creates white canvases. This hack allows you to almost never think about this problem, it has no effect on webui. + function fixCanvas() { + const activeTab = getActiveTab(elements).textContent.trim(); + + if (activeTab !== "img2img") { + const img = targetElement.querySelector(`${elemId} img`); + + if (img && img.style.display !== "none") { + img.style.display = "none"; + img.style.visibility = "hidden"; + } + } + } + + // Reset the zoom level and pan position of the target element to their initial values + function resetZoom() { + elemData[elemId] = { + zoomLevel: 1, + panX: 0, + panY: 0 + }; + + fixCanvas(); + targetElement.style.transform = `scale(${elemData[elemId].zoomLevel}) translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px)`; + + const canvas = gradioApp().querySelector( + `${elemId} canvas[key="interface"]` + ); + + toggleOverlap("off"); + fullScreenMode = false; + + if ( + canvas && + parseFloat(canvas.style.width) > 865 && + parseFloat(targetElement.style.width) > 865 + ) { + fitToElement(); + return; + } + + targetElement.style.width = ""; + if (canvas) { + targetElement.style.height = canvas.style.height; + } + } + + // Toggle the zIndex of the target element between two values, allowing it to overlap or be overlapped by other elements + function toggleOverlap(forced = "") { + const zIndex1 = "0"; + const zIndex2 = "998"; + + targetElement.style.zIndex = + targetElement.style.zIndex !== zIndex2 ? zIndex2 : zIndex1; + + if (forced === "off") { + targetElement.style.zIndex = zIndex1; + } else if (forced === "on") { + targetElement.style.zIndex = zIndex2; + } + } + + // Adjust the brush size based on the deltaY value from a mouse wheel event + function adjustBrushSize( + elemId, + deltaY, + withoutValue = false, + percentage = 5 + ) { + const input = + gradioApp().querySelector( + `${elemId} input[aria-label='Brush radius']` + ) || + gradioApp().querySelector( + `${elemId} button[aria-label="Use brush"]` + ); + + if (input) { + input.click(); + if (!withoutValue) { + const maxValue = + parseFloat(input.getAttribute("max")) || 100; + const changeAmount = maxValue * (percentage / 100); + const newValue = + parseFloat(input.value) + + (deltaY > 0 ? -changeAmount : changeAmount); + input.value = Math.min(Math.max(newValue, 0), maxValue); + input.dispatchEvent(new Event("change")); + } + } + } + + // Reset zoom when uploading a new image + const fileInput = gradioApp().querySelector( + `${elemId} input[type="file"][accept="image/*"].svelte-116rqfv` + ); + fileInput.addEventListener("click", resetZoom); + + // Update the zoom level and pan position of the target element based on the values of the zoomLevel, panX and panY variables + function updateZoom(newZoomLevel, mouseX, mouseY) { + newZoomLevel = Math.max(0.5, Math.min(newZoomLevel, 15)); + + elemData[elemId].panX += + mouseX - (mouseX * newZoomLevel) / elemData[elemId].zoomLevel; + elemData[elemId].panY += + mouseY - (mouseY * newZoomLevel) / elemData[elemId].zoomLevel; + + targetElement.style.transformOrigin = "0 0"; + targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${newZoomLevel})`; + + toggleOverlap("on"); + return newZoomLevel; + } + + // Change the zoom level based on user interaction + function changeZoomLevel(operation, e) { + if ( + (!hotkeysConfig.canvas_swap_controls && e.shiftKey) || + (hotkeysConfig.canvas_swap_controls && e.ctrlKey) + ) { + e.preventDefault(); + + let zoomPosX, zoomPosY; + let delta = 0.2; + if (elemData[elemId].zoomLevel > 7) { + delta = 0.9; + } else if (elemData[elemId].zoomLevel > 2) { + delta = 0.6; + } + + zoomPosX = e.clientX; + zoomPosY = e.clientY; + + fullScreenMode = false; + elemData[elemId].zoomLevel = updateZoom( + elemData[elemId].zoomLevel + + (operation === "+" ? delta : -delta), + zoomPosX - targetElement.getBoundingClientRect().left, + zoomPosY - targetElement.getBoundingClientRect().top + ); + } + } + + /** + * This function fits the target element to the screen by calculating + * the required scale and offsets. It also updates the global variables + * zoomLevel, panX, and panY to reflect the new state. + */ + + function fitToElement() { + //Reset Zoom + targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + + // Get element and screen dimensions + const elementWidth = targetElement.offsetWidth; + const elementHeight = targetElement.offsetHeight; + const parentElement = targetElement.parentElement; + const screenWidth = parentElement.clientWidth; + const screenHeight = parentElement.clientHeight; + + // Get element's coordinates relative to the parent element + const elementRect = targetElement.getBoundingClientRect(); + const parentRect = parentElement.getBoundingClientRect(); + const elementX = elementRect.x - parentRect.x; + + // Calculate scale and offsets + const scaleX = screenWidth / elementWidth; + const scaleY = screenHeight / elementHeight; + const scale = Math.min(scaleX, scaleY); + + const transformOrigin = + window.getComputedStyle(targetElement).transformOrigin; + const [originX, originY] = transformOrigin.split(" "); + const originXValue = parseFloat(originX); + const originYValue = parseFloat(originY); + + const offsetX = + (screenWidth - elementWidth * scale) / 2 - + originXValue * (1 - scale); + const offsetY = + (screenHeight - elementHeight * scale) / 2.5 - + originYValue * (1 - scale); + + // Apply scale and offsets to the element + targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`; + + // Update global variables + elemData[elemId].zoomLevel = scale; + elemData[elemId].panX = offsetX; + elemData[elemId].panY = offsetY; + + fullScreenMode = false; + toggleOverlap("off"); + } + + /** + * This function fits the target element to the screen by calculating + * the required scale and offsets. It also updates the global variables + * zoomLevel, panX, and panY to reflect the new state. + */ + + // Fullscreen mode + function fitToScreen() { + const canvas = gradioApp().querySelector( + `${elemId} canvas[key="interface"]` + ); + + if (!canvas) return; + + if (canvas.offsetWidth > 862) { + targetElement.style.width = canvas.offsetWidth + "px"; + } + + if (fullScreenMode) { + resetZoom(); + fullScreenMode = false; + return; + } + + //Reset Zoom + targetElement.style.transform = `translate(${0}px, ${0}px) scale(${1})`; + + // Get scrollbar width to right-align the image + const scrollbarWidth = + window.innerWidth - document.documentElement.clientWidth; + + // Get element and screen dimensions + const elementWidth = targetElement.offsetWidth; + const elementHeight = targetElement.offsetHeight; + const screenWidth = window.innerWidth - scrollbarWidth; + const screenHeight = window.innerHeight; + + // Get element's coordinates relative to the page + const elementRect = targetElement.getBoundingClientRect(); + const elementY = elementRect.y; + const elementX = elementRect.x; + + // Calculate scale and offsets + const scaleX = screenWidth / elementWidth; + const scaleY = screenHeight / elementHeight; + const scale = Math.min(scaleX, scaleY); + + // Get the current transformOrigin + const computedStyle = window.getComputedStyle(targetElement); + const transformOrigin = computedStyle.transformOrigin; + const [originX, originY] = transformOrigin.split(" "); + const originXValue = parseFloat(originX); + const originYValue = parseFloat(originY); + + // Calculate offsets with respect to the transformOrigin + const offsetX = + (screenWidth - elementWidth * scale) / 2 - + elementX - + originXValue * (1 - scale); + const offsetY = + (screenHeight - elementHeight * scale) / 2 - + elementY - + originYValue * (1 - scale); + + // Apply scale and offsets to the element + targetElement.style.transform = `translate(${offsetX}px, ${offsetY}px) scale(${scale})`; + + // Update global variables + elemData[elemId].zoomLevel = scale; + elemData[elemId].panX = offsetX; + elemData[elemId].panY = offsetY; + + fullScreenMode = true; + toggleOverlap("on"); + } + + // Handle keydown events + function handleKeyDown(event) { + const hotkeyActions = { + [hotkeysConfig.canvas_hotkey_reset]: resetZoom, + [hotkeysConfig.canvas_hotkey_overlap]: toggleOverlap, + [hotkeysConfig.canvas_hotkey_fullscreen]: fitToScreen + }; + + const action = hotkeyActions[event.code]; + if (action) { + event.preventDefault(); + action(event); + } + } + + // Get Mouse position + function getMousePosition(e) { + mouseX = e.offsetX; + mouseY = e.offsetY; + } + + targetElement.addEventListener("mousemove", getMousePosition); + + // Handle events only inside the targetElement + let isKeyDownHandlerAttached = false; + + function handleMouseMove() { + if (!isKeyDownHandlerAttached) { + document.addEventListener("keydown", handleKeyDown); + isKeyDownHandlerAttached = true; + + activeElement = elemId; + } + } + + function handleMouseLeave() { + if (isKeyDownHandlerAttached) { + document.removeEventListener("keydown", handleKeyDown); + isKeyDownHandlerAttached = false; + + activeElement = null; + } + } + + // Add mouse event handlers + targetElement.addEventListener("mousemove", handleMouseMove); + targetElement.addEventListener("mouseleave", handleMouseLeave); + + // Reset zoom when click on another tab + elements.img2imgTabs.addEventListener("click", resetZoom); + elements.img2imgTabs.addEventListener("click", () => { + // targetElement.style.width = ""; + if (parseInt(targetElement.style.width) > 865) { + setTimeout(fitToElement, 0); + } + }); + + targetElement.addEventListener("wheel", e => { + // change zoom level + const operation = e.deltaY > 0 ? "-" : "+"; + changeZoomLevel(operation, e); + + // Handle brush size adjustment with ctrl key pressed + if ( + (hotkeysConfig.canvas_swap_controls && e.shiftKey) || + (!hotkeysConfig.canvas_swap_controls && + (e.ctrlKey || e.metaKey)) + ) { + e.preventDefault(); + + // Increase or decrease brush size based on scroll direction + adjustBrushSize(elemId, e.deltaY); + } + }); + + // Handle the move event for pan functionality. Updates the panX and panY variables and applies the new transform to the target element. + function handleMoveKeyDown(e) { + if (e.code === hotkeysConfig.canvas_hotkey_move) { + if (!e.ctrlKey && !e.metaKey && isKeyDownHandlerAttached) { + e.preventDefault(); + document.activeElement.blur(); + isMoving = true; + } + } + } + + function handleMoveKeyUp(e) { + if (e.code === hotkeysConfig.canvas_hotkey_move) { + isMoving = false; + } + } + + document.addEventListener("keydown", handleMoveKeyDown); + document.addEventListener("keyup", handleMoveKeyUp); + + // Detect zoom level and update the pan speed. + function updatePanPosition(movementX, movementY) { + let panSpeed = 2; + + if (elemData[elemId].zoomLevel > 8) { + panSpeed = 3.5; + } + + elemData[elemId].panX += movementX * panSpeed; + elemData[elemId].panY += movementY * panSpeed; + + // Delayed redraw of an element + requestAnimationFrame(() => { + targetElement.style.transform = `translate(${elemData[elemId].panX}px, ${elemData[elemId].panY}px) scale(${elemData[elemId].zoomLevel})`; + toggleOverlap("on"); + }); + } + + function handleMoveByKey(e) { + if (isMoving && elemId === activeElement) { + updatePanPosition(e.movementX, e.movementY); + targetElement.style.pointerEvents = "none"; + } else { + targetElement.style.pointerEvents = "auto"; + } + } + + // Prevents sticking to the mouse + window.onblur = function() { + isMoving = false; + }; + + gradioApp().addEventListener("mousemove", handleMoveByKey); + } + + applyZoomAndPan(elementIDs.sketch); + applyZoomAndPan(elementIDs.inpaint); + applyZoomAndPan(elementIDs.inpaintSketch); + + // Make the function global so that other extensions can take advantage of this solution + window.applyZoomAndPan = applyZoomAndPan; +}); diff --git a/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py new file mode 100644 index 0000000000000000000000000000000000000000..d83e14daf9896c3eea310bbb3d7fdf239416d5d1 --- /dev/null +++ b/extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py @@ -0,0 +1,10 @@ +from modules import shared + +shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), { + "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas"), + "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen Mode, maximizes the picture so that it fits into the screen and stretches it to its full width "), + "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas positon"), + "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap ( Technical button, neededs for testing )"), + "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), + "canvas_swap_controls": shared.OptionInfo(False, "Swap hotkey combinations for Zoom and Adjust brush resize"), +})) diff --git a/extensions-builtin/canvas-zoom-and-pan/style.css b/extensions-builtin/canvas-zoom-and-pan/style.css new file mode 100644 index 0000000000000000000000000000000000000000..5b131d50e18347d71805140a33665e02bdd83f84 --- /dev/null +++ b/extensions-builtin/canvas-zoom-and-pan/style.css @@ -0,0 +1,63 @@ +.tooltip-info { + position: absolute; + top: 10px; + left: 10px; + cursor: help; + background-color: rgba(0, 0, 0, 0.3); + width: 20px; + height: 20px; + border-radius: 50%; + display: flex; + align-items: center; + justify-content: center; + flex-direction: column; + + z-index: 100; +} + +.tooltip-info::after { + content: ''; + display: block; + width: 2px; + height: 7px; + background-color: white; + margin-top: 2px; +} + +.tooltip-info::before { + content: ''; + display: block; + width: 2px; + height: 2px; + background-color: white; +} + +.tooltip-content { + display: none; + background-color: #f9f9f9; + color: #333; + border: 1px solid #ddd; + padding: 15px; + position: absolute; + top: 40px; + left: 10px; + width: 250px; + font-size: 16px; + opacity: 0; + border-radius: 8px; + box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); + + z-index: 100; +} + +.tooltip:hover .tooltip-content { + display: block; + animation: fadeIn 0.5s; + opacity: 1; +} + +@keyframes fadeIn { + from {opacity: 0;} + to {opacity: 1;} +} + diff --git a/extensions-builtin/extra-options-section/scripts/extra_options_section.py b/extensions-builtin/extra-options-section/scripts/extra_options_section.py new file mode 100644 index 0000000000000000000000000000000000000000..a05e10d865ab5a5e30dd9db936bad4cb96c4643a --- /dev/null +++ b/extensions-builtin/extra-options-section/scripts/extra_options_section.py @@ -0,0 +1,48 @@ +import gradio as gr +from modules import scripts, shared, ui_components, ui_settings +from modules.ui_components import FormColumn + + +class ExtraOptionsSection(scripts.Script): + section = "extra_options" + + def __init__(self): + self.comps = None + self.setting_names = None + + def title(self): + return "Extra options" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + self.comps = [] + self.setting_names = [] + + with gr.Blocks() as interface: + with gr.Accordion("Options", open=False) if shared.opts.extra_options_accordion and shared.opts.extra_options else gr.Group(), gr.Row(): + for setting_name in shared.opts.extra_options: + with FormColumn(): + comp = ui_settings.create_setting_component(setting_name) + + self.comps.append(comp) + self.setting_names.append(setting_name) + + def get_settings_values(): + return [ui_settings.get_value_for_setting(key) for key in self.setting_names] + + interface.load(fn=get_settings_values, inputs=[], outputs=self.comps, queue=False, show_progress=False) + + return self.comps + + def before_process(self, p, *args): + for name, value in zip(self.setting_names, args): + if name not in p.override_settings: + p.override_settings[name] = value + + +shared.options_templates.update(shared.options_section(('ui', "User interface"), { + "extra_options": shared.OptionInfo([], "Options in main UI", ui_components.DropdownMulti, lambda: {"choices": list(shared.opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that also appear in txt2img/img2img interfaces").needs_restart(), + "extra_options_accordion": shared.OptionInfo(False, "Place options in main UI into an accordion") +})) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index 2b32e71269c0f4691aa1c9a31e20d8a3bcd720d8..68a84c3aac4e591677fa95ed6539d757ee5fbb12 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,4 +1,4 @@ -
+
{background_image} {metadata_button}
diff --git a/html/footer.html b/html/footer.html index bad87ff619df2488b9732017a1e5658e240adb2e..69b2372c755298066a177f2f60a228f4aa11a944 100644 --- a/html/footer.html +++ b/html/footer.html @@ -1,10 +1,12 @@
- API + API  •  Github  •  Gradio  •  + Startup profile +  •  Reload UI

diff --git a/javascript/aspectRatioOverlay.js b/javascript/aspectRatioOverlay.js index 1c08a1a97d6e296ad41cb30fe69240acfeb54dbc..2cf2d571fc02a026b6cdedcf589a217ef0d65d27 100644 --- a/javascript/aspectRatioOverlay.js +++ b/javascript/aspectRatioOverlay.js @@ -81,7 +81,7 @@ function dimensionChange(e, is_width, is_height) { } -onUiUpdate(function() { +onAfterUiUpdate(function() { var arPreviewRect = gradioApp().querySelector('#imageARPreview'); if (arPreviewRect) { arPreviewRect.style.display = 'none'; diff --git a/javascript/contextMenus.js b/javascript/contextMenus.js index f14af1d427cb03b3d59ecc906600a1fa0e7dd766..ccae242f2b6a731e89d8752814aae6b78e143482 100644 --- a/javascript/contextMenus.js +++ b/javascript/contextMenus.js @@ -148,12 +148,18 @@ var addContextMenuEventListener = initResponse[2]; 500); }; - appendContextMenuOption('#txt2img_generate', 'Generate forever', function() { + let generateOnRepeat_txt2img = function() { generateOnRepeat('#txt2img_generate', '#txt2img_interrupt'); - }); - appendContextMenuOption('#img2img_generate', 'Generate forever', function() { + }; + + let generateOnRepeat_img2img = function() { generateOnRepeat('#img2img_generate', '#img2img_interrupt'); - }); + }; + + appendContextMenuOption('#txt2img_generate', 'Generate forever', generateOnRepeat_txt2img); + appendContextMenuOption('#txt2img_interrupt', 'Generate forever', generateOnRepeat_txt2img); + appendContextMenuOption('#img2img_generate', 'Generate forever', generateOnRepeat_img2img); + appendContextMenuOption('#img2img_interrupt', 'Generate forever', generateOnRepeat_img2img); let cancelGenerateForever = function() { clearInterval(window.generateOnRepeatInterval); @@ -167,6 +173,4 @@ var addContextMenuEventListener = initResponse[2]; })(); //End example Context Menu Items -onUiUpdate(function() { - addContextMenuEventListener(); -}); +onAfterUiUpdate(addContextMenuEventListener); diff --git a/javascript/dragdrop.js b/javascript/dragdrop.js index 77a24a070312fd9dd58b5d410e1fbad881e3d4ee..5803daea5ef33341b5307e03a7ebbadc7c324ed7 100644 --- a/javascript/dragdrop.js +++ b/javascript/dragdrop.js @@ -48,12 +48,27 @@ function dropReplaceImage(imgWrap, files) { } } +function eventHasFiles(e) { + if (!e.dataTransfer || !e.dataTransfer.files) return false; + if (e.dataTransfer.files.length > 0) return true; + if (e.dataTransfer.items.length > 0 && e.dataTransfer.items[0].kind == "file") return true; + + return false; +} + +function dragDropTargetIsPrompt(target) { + if (target?.placeholder && target?.placeholder.indexOf("Prompt") >= 0) return true; + if (target?.parentNode?.parentNode?.className?.indexOf("prompt") > 0) return true; + return false; +} + window.document.addEventListener('dragover', e => { const target = e.composedPath()[0]; - const imgWrap = target.closest('[data-testid="image"]'); - if (!imgWrap && target.placeholder && target.placeholder.indexOf("Prompt") == -1) { - return; - } + if (!eventHasFiles(e)) return; + + var targetImage = target.closest('[data-testid="image"]'); + if (!dragDropTargetIsPrompt(target) && !targetImage) return; + e.stopPropagation(); e.preventDefault(); e.dataTransfer.dropEffect = 'copy'; @@ -61,17 +76,31 @@ window.document.addEventListener('dragover', e => { window.document.addEventListener('drop', e => { const target = e.composedPath()[0]; - if (target.placeholder.indexOf("Prompt") == -1) { - return; + if (!eventHasFiles(e)) return; + + if (dragDropTargetIsPrompt(target)) { + e.stopPropagation(); + e.preventDefault(); + + let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; + + const imgParent = gradioApp().getElementById(prompt_target); + const files = e.dataTransfer.files; + const fileInput = imgParent.querySelector('input[type="file"]'); + if (fileInput) { + fileInput.files = files; + fileInput.dispatchEvent(new Event('change')); + } } - const imgWrap = target.closest('[data-testid="image"]'); - if (!imgWrap) { + + var targetImage = target.closest('[data-testid="image"]'); + if (targetImage) { + e.stopPropagation(); + e.preventDefault(); + const files = e.dataTransfer.files; + dropReplaceImage(targetImage, files); return; } - e.stopPropagation(); - e.preventDefault(); - const files = e.dataTransfer.files; - dropReplaceImage(imgWrap, files); }); window.addEventListener('paste', e => { diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index aafe0a005f8b674d2a40b644b8115a3ee9afae11..b87bca3ecb43a3098df3f0c65db9d0767a6c9f1b 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -3,10 +3,17 @@ function setupExtraNetworksForTab(tabname) { var tabs = gradioApp().querySelector('#' + tabname + '_extra_tabs > div'); var search = gradioApp().querySelector('#' + tabname + '_extra_search textarea'); + var sort = gradioApp().getElementById(tabname + '_extra_sort'); + var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); search.classList.add('search'); + sort.classList.add('sort'); + sortOrder.classList.add('sortorder'); + sort.dataset.sortkey = 'sortDefault'; tabs.appendChild(search); + tabs.appendChild(sort); + tabs.appendChild(sortOrder); tabs.appendChild(refresh); var applyFilter = function() { @@ -26,8 +33,51 @@ function setupExtraNetworksForTab(tabname) { }); }; + var applySort = function() { + var reverse = sortOrder.classList.contains("sortReverse"); + var sortKey = sort.querySelector("input").value.toLowerCase().replace("sort", "").replaceAll(" ", "_").replace(/_+$/, "").trim(); + sortKey = sortKey ? "sort" + sortKey.charAt(0).toUpperCase() + sortKey.slice(1) : ""; + var sortKeyStore = sortKey ? sortKey + (reverse ? "Reverse" : "") : ""; + if (!sortKey || sortKeyStore == sort.dataset.sortkey) { + return; + } + + sort.dataset.sortkey = sortKeyStore; + + var cards = gradioApp().querySelectorAll('#' + tabname + '_extra_tabs div.card'); + cards.forEach(function(card) { + card.originalParentElement = card.parentElement; + }); + var sortedCards = Array.from(cards); + sortedCards.sort(function(cardA, cardB) { + var a = cardA.dataset[sortKey]; + var b = cardB.dataset[sortKey]; + if (!isNaN(a) && !isNaN(b)) { + return parseInt(a) - parseInt(b); + } + + return (a < b ? -1 : (a > b ? 1 : 0)); + }); + if (reverse) { + sortedCards.reverse(); + } + cards.forEach(function(card) { + card.remove(); + }); + sortedCards.forEach(function(card) { + card.originalParentElement.appendChild(card); + }); + }; + search.addEventListener("input", applyFilter); applyFilter(); + ["change", "blur", "click"].forEach(function(evt) { + sort.querySelector("input").addEventListener(evt, applySort); + }); + sortOrder.addEventListener("click", function() { + sortOrder.classList.toggle("sortReverse"); + applySort(); + }); extraNetworksApplyFilter[tabname] = applyFilter; } diff --git a/javascript/generationParams.js b/javascript/generationParams.js index a877f8a5474e0bf199cb0082e3f8b971bec7292d..7c0fd221d63313ab063f545570eb0da780b9da3a 100644 --- a/javascript/generationParams.js +++ b/javascript/generationParams.js @@ -1,7 +1,7 @@ // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes let txt2img_gallery, img2img_gallery, modal = undefined; -onUiUpdate(function() { +onAfterUiUpdate(function() { if (!txt2img_gallery) { txt2img_gallery = attachGalleryListeners("txt2img"); } diff --git a/javascript/hints.js b/javascript/hints.js index 46f342cb9cffb77059d0a3f81453806397988aad..05ae5f22cd7164faef6dced2a76824214070e91b 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -116,17 +116,25 @@ var titles = { "Negative Guidance minimum sigma": "Skip negative prompt for steps where image is already mostly denoised; the higher this value, the more skips there will be; provides increased performance in exchange for minor quality reduction." }; -function updateTooltipForSpan(span) { - if (span.title) return; // already has a title +function updateTooltip(element) { + if (element.title) return; // already has a title - let tooltip = localization[titles[span.textContent]] || titles[span.textContent]; + let text = element.textContent; + let tooltip = localization[titles[text]] || titles[text]; if (!tooltip) { - tooltip = localization[titles[span.value]] || titles[span.value]; + let value = element.value; + if (value) tooltip = localization[titles[value]] || titles[value]; } if (!tooltip) { - for (const c of span.classList) { + // Gradio dropdown options have `data-value`. + let dataValue = element.dataset.value; + if (dataValue) tooltip = localization[titles[dataValue]] || titles[dataValue]; + } + + if (!tooltip) { + for (const c of element.classList) { if (c in titles) { tooltip = localization[titles[c]] || titles[c]; break; @@ -135,34 +143,53 @@ function updateTooltipForSpan(span) { } if (tooltip) { - span.title = tooltip; + element.title = tooltip; } } -function updateTooltipForSelect(select) { - if (select.onchange != null) return; +// Nodes to check for adding tooltips. +const tooltipCheckNodes = new Set(); +// Timer for debouncing tooltip check. +let tooltipCheckTimer = null; - select.onchange = function() { - select.title = localization[titles[select.value]] || titles[select.value] || ""; - }; +function processTooltipCheckNodes() { + for (const node of tooltipCheckNodes) { + updateTooltip(node); + } + tooltipCheckNodes.clear(); } -var observedTooltipElements = {SPAN: 1, BUTTON: 1, SELECT: 1, P: 1}; - -onUiUpdate(function(m) { - m.forEach(function(record) { - record.addedNodes.forEach(function(node) { - if (observedTooltipElements[node.tagName]) { - updateTooltipForSpan(node); - } - if (node.tagName == "SELECT") { - updateTooltipForSelect(node); +onUiUpdate(function(mutationRecords) { + for (const record of mutationRecords) { + if (record.type === "childList" && record.target.classList.contains("options")) { + // This smells like a Gradio dropdown menu having changed, + // so let's enqueue an update for the input element that shows the current value. + let wrap = record.target.parentNode; + let input = wrap?.querySelector("input"); + if (input) { + input.title = ""; // So we'll even have a chance to update it. + tooltipCheckNodes.add(input); } - - if (node.querySelectorAll) { - node.querySelectorAll('span, button, select, p').forEach(updateTooltipForSpan); - node.querySelectorAll('select').forEach(updateTooltipForSelect); + } + for (const node of record.addedNodes) { + if (node.nodeType === Node.ELEMENT_NODE && !node.classList.contains("hide")) { + if (!node.title) { + if ( + node.tagName === "SPAN" || + node.tagName === "BUTTON" || + node.tagName === "P" || + node.tagName === "INPUT" || + (node.tagName === "LI" && node.classList.contains("item")) // Gradio dropdown item + ) { + tooltipCheckNodes.add(node); + } + } + node.querySelectorAll('span, button, p').forEach(n => tooltipCheckNodes.add(n)); } - }); - }); + } + } + if (tooltipCheckNodes.size) { + clearTimeout(tooltipCheckTimer); + tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000); + } }); diff --git a/javascript/imageMaskFix.js b/javascript/imageMaskFix.js index 3c9b8a6fd6a563d9480619ffca56e0b174d3a39b..900c56f32fdf7128f0433621df25a0fbd14c4e42 100644 --- a/javascript/imageMaskFix.js +++ b/javascript/imageMaskFix.js @@ -39,5 +39,5 @@ function imageMaskResize() { }); } -onUiUpdate(imageMaskResize); +onAfterUiUpdate(imageMaskResize); window.addEventListener('resize', imageMaskResize); diff --git a/javascript/imageParams.js b/javascript/imageParams.js deleted file mode 100644 index 057e2d3924ee627c33e7adc0c6098a0af382fc85..0000000000000000000000000000000000000000 --- a/javascript/imageParams.js +++ /dev/null @@ -1,18 +0,0 @@ -window.onload = (function() { - window.addEventListener('drop', e => { - const target = e.composedPath()[0]; - if (target.placeholder.indexOf("Prompt") == -1) return; - - let prompt_target = get_tab_index('tabs') == 1 ? "img2img_prompt_image" : "txt2img_prompt_image"; - - e.stopPropagation(); - e.preventDefault(); - const imgParent = gradioApp().getElementById(prompt_target); - const files = e.dataTransfer.files; - const fileInput = imgParent.querySelector('input[type="file"]'); - if (fileInput) { - fileInput.files = files; - fileInput.dispatchEvent(new Event('change')); - } - }); -}); diff --git a/javascript/imageviewer.js b/javascript/imageviewer.js index 78e24eb9e81561710e8092087ef761ffd6d8251a..677e95c1bc7b700e61bd5e6263e980dd703165e2 100644 --- a/javascript/imageviewer.js +++ b/javascript/imageviewer.js @@ -170,7 +170,7 @@ function modalTileImageToggle(event) { event.stopPropagation(); } -onUiUpdate(function() { +onAfterUiUpdate(function() { var fullImg_preview = gradioApp().querySelectorAll('.gradio-gallery > div > img'); if (fullImg_preview != null) { fullImg_preview.forEach(setupImageForLightbox); diff --git a/javascript/imageviewerGamepad.js b/javascript/imageviewerGamepad.js index 31d226dee7487ef55a57b7b8eaabcddeaa856814..a22c7e6e6435f677c7a86dbbae5da86af8fdc9eb 100644 --- a/javascript/imageviewerGamepad.js +++ b/javascript/imageviewerGamepad.js @@ -1,7 +1,9 @@ +let gamepads = []; + window.addEventListener('gamepadconnected', (e) => { const index = e.gamepad.index; let isWaiting = false; - setInterval(async() => { + gamepads[index] = setInterval(async() => { if (!opts.js_modal_lightbox_gamepad || isWaiting) return; const gamepad = navigator.getGamepads()[index]; const xValue = gamepad.axes[0]; @@ -24,6 +26,10 @@ window.addEventListener('gamepadconnected', (e) => { }, 10); }); +window.addEventListener('gamepaddisconnected', (e) => { + clearInterval(gamepads[e.gamepad.index]); +}); + /* Primarily for vr controller type pointer devices. I use the wheel event because there's currently no way to do it properly with web xr. diff --git a/javascript/notification.js b/javascript/notification.js index a68a76f2514d12dc4d32c03ae8f9597e66e511c1..76c5715dab43923736949a7b9cec6d0799dd5672 100644 --- a/javascript/notification.js +++ b/javascript/notification.js @@ -4,7 +4,7 @@ let lastHeadImg = null; let notificationButton = null; -onUiUpdate(function() { +onAfterUiUpdate(function() { if (notificationButton == null) { notificationButton = gradioApp().getElementById('request_notifications'); diff --git a/javascript/profilerVisualization.js b/javascript/profilerVisualization.js new file mode 100644 index 0000000000000000000000000000000000000000..9d8e5f42f327f93db42773ebf0b97ee1e9671806 --- /dev/null +++ b/javascript/profilerVisualization.js @@ -0,0 +1,153 @@ + +function createRow(table, cellName, items) { + var tr = document.createElement('tr'); + var res = []; + + items.forEach(function(x, i) { + if (x === undefined) { + res.push(null); + return; + } + + var td = document.createElement(cellName); + td.textContent = x; + tr.appendChild(td); + res.push(td); + + var colspan = 1; + for (var n = i + 1; n < items.length; n++) { + if (items[n] !== undefined) { + break; + } + + colspan += 1; + } + + if (colspan > 1) { + td.colSpan = colspan; + } + }); + + table.appendChild(tr); + + return res; +} + +function showProfile(path, cutoff = 0.05) { + requestGet(path, {}, function(data) { + var table = document.createElement('table'); + table.className = 'popup-table'; + + data.records['total'] = data.total; + var keys = Object.keys(data.records).sort(function(a, b) { + return data.records[b] - data.records[a]; + }); + var items = keys.map(function(x) { + return {key: x, parts: x.split('/'), time: data.records[x]}; + }); + var maxLength = items.reduce(function(a, b) { + return Math.max(a, b.parts.length); + }, 0); + + var cols = createRow(table, 'th', ['record', 'seconds']); + cols[0].colSpan = maxLength; + + function arraysEqual(a, b) { + return !(a < b || b < a); + } + + var addLevel = function(level, parent, hide) { + var matching = items.filter(function(x) { + return x.parts[level] && !x.parts[level + 1] && arraysEqual(x.parts.slice(0, level), parent); + }); + var sorted = matching.sort(function(a, b) { + return b.time - a.time; + }); + var othersTime = 0; + var othersList = []; + var othersRows = []; + var childrenRows = []; + sorted.forEach(function(x) { + var visible = x.time >= cutoff && !hide; + + var cells = []; + for (var i = 0; i < maxLength; i++) { + cells.push(x.parts[i]); + } + cells.push(x.time.toFixed(3)); + var cols = createRow(table, 'td', cells); + for (i = 0; i < level; i++) { + cols[i].className = 'muted'; + } + + var tr = cols[0].parentNode; + if (!visible) { + tr.classList.add("hidden"); + } + + if (x.time >= cutoff) { + childrenRows.push(tr); + } else { + othersTime += x.time; + othersList.push(x.parts[level]); + othersRows.push(tr); + } + + var children = addLevel(level + 1, parent.concat([x.parts[level]]), true); + if (children.length > 0) { + var cell = cols[level]; + var onclick = function() { + cell.classList.remove("link"); + cell.removeEventListener("click", onclick); + children.forEach(function(x) { + x.classList.remove("hidden"); + }); + }; + cell.classList.add("link"); + cell.addEventListener("click", onclick); + } + }); + + if (othersTime > 0) { + var cells = []; + for (var i = 0; i < maxLength; i++) { + cells.push(parent[i]); + } + cells.push(othersTime.toFixed(3)); + cells[level] = 'others'; + var cols = createRow(table, 'td', cells); + for (i = 0; i < level; i++) { + cols[i].className = 'muted'; + } + + var cell = cols[level]; + var tr = cell.parentNode; + var onclick = function() { + tr.classList.add("hidden"); + cell.classList.remove("link"); + cell.removeEventListener("click", onclick); + othersRows.forEach(function(x) { + x.classList.remove("hidden"); + }); + }; + + cell.title = othersList.join(", "); + cell.classList.add("link"); + cell.addEventListener("click", onclick); + + if (hide) { + tr.classList.add("hidden"); + } + + childrenRows.push(tr); + } + + return childrenRows; + }; + + addLevel(0, []); + + popup(table); + }); +} + diff --git a/javascript/token-counters.js b/javascript/token-counters.js new file mode 100644 index 0000000000000000000000000000000000000000..9d81a723b01f8b6e3c0894b7a5191dc6b1614c2d --- /dev/null +++ b/javascript/token-counters.js @@ -0,0 +1,83 @@ +let promptTokenCountDebounceTime = 800; +let promptTokenCountTimeouts = {}; +var promptTokenCountUpdateFunctions = {}; + +function update_txt2img_tokens(...args) { + // Called from Gradio + update_token_counter("txt2img_token_button"); + if (args.length == 2) { + return args[0]; + } + return args; +} + +function update_img2img_tokens(...args) { + // Called from Gradio + update_token_counter("img2img_token_button"); + if (args.length == 2) { + return args[0]; + } + return args; +} + +function update_token_counter(button_id) { + if (opts.disable_token_counters) { + return; + } + if (promptTokenCountTimeouts[button_id]) { + clearTimeout(promptTokenCountTimeouts[button_id]); + } + promptTokenCountTimeouts[button_id] = setTimeout( + () => gradioApp().getElementById(button_id)?.click(), + promptTokenCountDebounceTime, + ); +} + + +function recalculatePromptTokens(name) { + promptTokenCountUpdateFunctions[name]?.(); +} + +function recalculate_prompts_txt2img() { + // Called from Gradio + recalculatePromptTokens('txt2img_prompt'); + recalculatePromptTokens('txt2img_neg_prompt'); + return Array.from(arguments); +} + +function recalculate_prompts_img2img() { + // Called from Gradio + recalculatePromptTokens('img2img_prompt'); + recalculatePromptTokens('img2img_neg_prompt'); + return Array.from(arguments); +} + +function setupTokenCounting(id, id_counter, id_button) { + var prompt = gradioApp().getElementById(id); + var counter = gradioApp().getElementById(id_counter); + var textarea = gradioApp().querySelector(`#${id} > label > textarea`); + + if (opts.disable_token_counters) { + counter.style.display = "none"; + return; + } + + if (counter.parentElement == prompt.parentElement) { + return; + } + + prompt.parentElement.insertBefore(counter, prompt); + prompt.parentElement.style.position = "relative"; + + promptTokenCountUpdateFunctions[id] = function() { + update_token_counter(id_button); + }; + textarea.addEventListener("input", promptTokenCountUpdateFunctions[id]); +} + +function setupTokenCounters() { + setupTokenCounting('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button'); + setupTokenCounting('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button'); + setupTokenCounting('img2img_prompt', 'img2img_token_counter', 'img2img_token_button'); + setupTokenCounting('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button'); +} diff --git a/javascript/ui.js b/javascript/ui.js index 648a5290ecb5a9ba311e36aa80922048fa0ffa16..d70a681bff7b45fe5711431ee8ec55c444443a5b 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -248,29 +248,8 @@ function confirm_clear_prompt(prompt, negative_prompt) { } -var promptTokecountUpdateFuncs = {}; - -function recalculatePromptTokens(name) { - if (promptTokecountUpdateFuncs[name]) { - promptTokecountUpdateFuncs[name](); - } -} - -function recalculate_prompts_txt2img() { - recalculatePromptTokens('txt2img_prompt'); - recalculatePromptTokens('txt2img_neg_prompt'); - return Array.from(arguments); -} - -function recalculate_prompts_img2img() { - recalculatePromptTokens('img2img_prompt'); - recalculatePromptTokens('img2img_neg_prompt'); - return Array.from(arguments); -} - - var opts = {}; -onUiUpdate(function() { +onAfterUiUpdate(function() { if (Object.keys(opts).length != 0) return; var json_elem = gradioApp().getElementById('settings_json'); @@ -302,28 +281,7 @@ onUiUpdate(function() { json_elem.parentElement.style.display = "none"; - function registerTextarea(id, id_counter, id_button) { - var prompt = gradioApp().getElementById(id); - var counter = gradioApp().getElementById(id_counter); - var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); - - if (counter.parentElement == prompt.parentElement) { - return; - } - - prompt.parentElement.insertBefore(counter, prompt); - prompt.parentElement.style.position = "relative"; - - promptTokecountUpdateFuncs[id] = function() { - update_token_counter(id_button); - }; - textarea.addEventListener("input", promptTokecountUpdateFuncs[id]); - } - - registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button'); - registerTextarea('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button'); - registerTextarea('img2img_prompt', 'img2img_token_counter', 'img2img_token_button'); - registerTextarea('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button'); + setupTokenCounters(); var show_all_pages = gradioApp().getElementById('settings_show_all_pages'); var settings_tabs = gradioApp().querySelector('#settings div'); @@ -354,33 +312,6 @@ onOptionsChanged(function() { }); let txt2img_textarea, img2img_textarea = undefined; -let wait_time = 800; -let token_timeouts = {}; - -function update_txt2img_tokens(...args) { - update_token_counter("txt2img_token_button"); - if (args.length == 2) { - return args[0]; - } - return args; -} - -function update_img2img_tokens(...args) { - update_token_counter( - "img2img_token_button" - ); - if (args.length == 2) { - return args[0]; - } - return args; -} - -function update_token_counter(button_id) { - if (token_timeouts[button_id]) { - clearTimeout(token_timeouts[button_id]); - } - token_timeouts[button_id] = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time); -} function restart_reload() { document.body.innerHTML = '

Reloading...

'; diff --git a/javascript/ui_settings_hints.js b/javascript/ui_settings_hints.js index e216852b59c2b5d8b8d1661638d032d2b46c0119..d088f9494f826d9534dc105ac2f99bda702d22c0 100644 --- a/javascript/ui_settings_hints.js +++ b/javascript/ui_settings_hints.js @@ -42,7 +42,7 @@ onOptionsChanged(function() { function settingsHintsShowQuicksettings() { requestGet("./internal/quicksettings-hint", {}, function(data) { var table = document.createElement('table'); - table.className = 'settings-value-table'; + table.className = 'popup-table'; data.forEach(function(obj) { var tr = document.createElement('tr'); diff --git a/modules/api/api.py b/modules/api/api.py index eee99bbb29cf263fd3390e4192d021fea081bed3..2e49526e3e2faf964d1b20441d2a3d9ad5b68906 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -14,7 +14,7 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images @@ -23,6 +23,7 @@ from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork from PIL import PngImagePlugin,Image from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights +from modules.sd_vae import vae_dict from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -108,7 +109,6 @@ def api_middleware(app: FastAPI): from rich.console import Console console = Console() except Exception: - import traceback rich_available = False @app.middleware("http") @@ -139,11 +139,12 @@ def api_middleware(app: FastAPI): "errors": str(e), } if not isinstance(e, HTTPException): # do not print backtrace on known httpexceptions - print(f"API error: {request.method}: {request.url} {err}") + message = f"API error: {request.method}: {request.url} {err}" if rich_available: + print(message) console.print_exception(show_locals=True, max_frames=2, extra_lines=1, suppress=[anyio, starlette], word_wrap=False, width=min([console.width, 200])) else: - traceback.print_exc() + errors.report(message, exc_info=True) return JSONResponse(status_code=vars(e).get('status_code', 500), content=jsonable_encoder(err)) @app.middleware("http") @@ -188,7 +189,9 @@ class Api: self.add_api_route("/sdapi/v1/cmd-flags", self.get_cmd_flags, methods=["GET"], response_model=models.FlagsModel) self.add_api_route("/sdapi/v1/samplers", self.get_samplers, methods=["GET"], response_model=List[models.SamplerItem]) self.add_api_route("/sdapi/v1/upscalers", self.get_upscalers, methods=["GET"], response_model=List[models.UpscalerItem]) + self.add_api_route("/sdapi/v1/latent-upscale-modes", self.get_latent_upscale_modes, methods=["GET"], response_model=List[models.LatentUpscalerModeItem]) self.add_api_route("/sdapi/v1/sd-models", self.get_sd_models, methods=["GET"], response_model=List[models.SDModelItem]) + self.add_api_route("/sdapi/v1/sd-vae", self.get_sd_vaes, methods=["GET"], response_model=List[models.SDVaeItem]) self.add_api_route("/sdapi/v1/hypernetworks", self.get_hypernetworks, methods=["GET"], response_model=List[models.HypernetworkItem]) self.add_api_route("/sdapi/v1/face-restorers", self.get_face_restorers, methods=["GET"], response_model=List[models.FaceRestorerItem]) self.add_api_route("/sdapi/v1/realesrgan-models", self.get_realesrgan_models, methods=["GET"], response_model=List[models.RealesrganItem]) @@ -278,7 +281,7 @@ class Api: script_args[0] = selectable_idx + 1 # Now check for always on scripts - if request.alwayson_scripts and (len(request.alwayson_scripts) > 0): + if request.alwayson_scripts: for alwayson_script_name in request.alwayson_scripts.keys(): alwayson_script = self.get_script(alwayson_script_name, script_runner) if alwayson_script is None: @@ -538,9 +541,20 @@ class Api: for upscaler in shared.sd_upscalers ] + def get_latent_upscale_modes(self): + return [ + { + "name": upscale_mode, + } + for upscale_mode in [*(shared.latent_upscale_modes or {})] + ] + def get_sd_models(self): return [{"title": x.title, "model_name": x.model_name, "hash": x.shorthash, "sha256": x.sha256, "filename": x.filename, "config": find_checkpoint_config_near_filename(x)} for x in checkpoints_list.values()] + def get_sd_vaes(self): + return [{"model_name": x, "filename": vae_dict[x]} for x in vae_dict.keys()] + def get_hypernetworks(self): return [{"name": name, "path": shared.hypernetworks[name]} for name in shared.hypernetworks] @@ -700,4 +714,4 @@ class Api: def launch(self, server_name, port): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port) + uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=0) diff --git a/modules/api/models.py b/modules/api/models.py index 1ff2fb338cebac1cd3d996599e0a0fce3712dd43..b3a745f0e8c034732b88b2fc041ed3b02d581efe 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -241,6 +241,9 @@ class UpscalerItem(BaseModel): model_url: Optional[str] = Field(title="URL") scale: Optional[float] = Field(title="Scale") +class LatentUpscalerModeItem(BaseModel): + name: str = Field(title="Name") + class SDModelItem(BaseModel): title: str = Field(title="Title") model_name: str = Field(title="Model Name") @@ -249,6 +252,10 @@ class SDModelItem(BaseModel): filename: str = Field(title="Filename") config: Optional[str] = Field(title="Config file") +class SDVaeItem(BaseModel): + model_name: str = Field(title="Model Name") + filename: str = Field(title="Filename") + class HypernetworkItem(BaseModel): name: str = Field(title="Name") path: Optional[str] = Field(title="Path") diff --git a/modules/call_queue.py b/modules/call_queue.py index 447bb7644d6d230bc7da844ca764260c7fa955b1..1b5e5273856caddec0e01d4c047edb8c412d5c8b 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -1,10 +1,8 @@ import html -import sys import threading -import traceback import time -from modules import shared, progress +from modules import shared, progress, errors queue_lock = threading.Lock() @@ -23,7 +21,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): # if the first argument is a string that says "task(...)", it is treated as a job id - if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + if args and type(args[0]) == str and args[0].startswith("task(") and args[0].endswith(")"): id_task = args[0] progress.add_task_to_queue(id_task) else: @@ -56,16 +54,14 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): try: res = list(func(*args, **kwargs)) except Exception as e: - # When printing out our debug argument list, do not print out more than a MB of text - max_debug_str_len = 131072 # (1024*1024)/8 - - print("Error completing request", file=sys.stderr) - argStr = f"Arguments: {args} {kwargs}" - print(argStr[:max_debug_str_len], file=sys.stderr) - if len(argStr) > max_debug_str_len: - print(f"(Argument list truncated at {max_debug_str_len}/{len(argStr)} characters)", file=sys.stderr) - - print(traceback.format_exc(), file=sys.stderr) + # When printing out our debug argument list, + # do not print out more than a 100 KB of text + max_debug_str_len = 131072 + message = "Error completing request" + arg_str = f"Arguments: {args} {kwargs}"[:max_debug_str_len] + if len(arg_str) > max_debug_str_len: + arg_str += f" (Argument list truncated at {max_debug_str_len}/{len(arg_str)} characters)" + errors.report(f"{message}\n{arg_str}", exc_info=True) shared.state.job = "" shared.state.job_count = 0 @@ -108,4 +104,3 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): return tuple(res) return f - diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 71a21b1a54a064973805b61c321169daa24e4724..de905caa14e1bc5cb0667f82c9ad2b1734072bf3 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -11,7 +11,7 @@ parser.add_argument("--skip-python-version-check", action='store_true', help="la parser.add_argument("--skip-torch-cuda-test", action='store_true', help="launch.py argument: do not check if CUDA is able to work properly") parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py argument: install the appropriate version of xformers even if you have some version already installed") parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed") -parser.add_argument("--update-check", action='store_true', help="launch.py argument: chck for updates at startup") +parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup") parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation") parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ececdbae4c4002c1465519b46565950c78fcb1c6..4260b016caff197be1363706a09504da5d1149c9 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -1,13 +1,11 @@ import os -import sys -import traceback import cv2 import torch import modules.face_restoration import modules.shared -from modules import shared, devices, modelloader +from modules import shared, devices, modelloader, errors from modules.paths import models_path # codeformer people made a choice to include modified basicsr library to their project which makes @@ -105,8 +103,8 @@ def setup_model(dirname): restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) del output torch.cuda.empty_cache() - except Exception as error: - print(f'\tFailed inference for CodeFormer: {error}', file=sys.stderr) + except Exception: + errors.report('Failed inference for CodeFormer', exc_info=True) restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) restored_face = restored_face.astype('uint8') @@ -135,7 +133,6 @@ def setup_model(dirname): shared.face_restorers.append(codeformer) except Exception: - print("Error setting up CodeFormer:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Error setting up CodeFormer", exc_info=True) # sys.path = stored_sys_path diff --git a/modules/config_states.py b/modules/config_states.py index db65bcdbf7a8f484dddde20ba7b54ef9b8e0ecd4..6f1ab53fc5909888413d42ede0f09c28f90b90cc 100644 --- a/modules/config_states.py +++ b/modules/config_states.py @@ -3,8 +3,6 @@ Supports saving and restoring webui and extensions from a known working set of c """ import os -import sys -import traceback import json import time import tqdm @@ -13,7 +11,7 @@ from datetime import datetime from collections import OrderedDict import git -from modules import shared, extensions +from modules import shared, extensions, errors from modules.paths_internal import script_path, config_states_dir @@ -53,8 +51,7 @@ def get_webui_config(): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error reading webui git info from {script_path}", exc_info=True) webui_remote = None webui_commit_hash = None @@ -134,8 +131,7 @@ def restore_webui_config(config): if os.path.exists(os.path.join(script_path, ".git")): webui_repo = git.Repo(script_path) except Exception: - print(f"Error reading webui git info from {script_path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error reading webui git info from {script_path}", exc_info=True) return try: @@ -143,8 +139,7 @@ def restore_webui_config(config): webui_repo.git.reset(webui_commit_hash, hard=True) print(f"* Restored webui to commit {webui_commit_hash}.") except Exception: - print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error restoring webui to commit{webui_commit_hash}") def restore_extension_config(config): diff --git a/modules/devices.py b/modules/devices.py index d8a34a0fdde27402c45c9d22ea316014b3d6db73..1ed6ffdc1143892bb1fb2aaff719b4464cf3a667 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,5 +1,7 @@ import sys import contextlib +from functools import lru_cache + import torch from modules import errors @@ -154,3 +156,19 @@ def test_for_nans(x, where): message += " Use --disable-nan-check commandline argument to disable this check." raise NansException(message) + + +@lru_cache +def first_time_calculation(): + """ + just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and + spends about 2.7 seconds doing that, at least wih NVidia. + """ + + x = torch.zeros((1, 1)).to(device, dtype) + linear = torch.nn.Linear(1, 1).to(device, dtype) + linear(x) + + x = torch.zeros((1, 1, 3, 3)).to(device, dtype) + conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) + conv2d(x) diff --git a/modules/errors.py b/modules/errors.py index f6b80dbbde7947511b58eae309f3b077c8c09fb5..5271a9fe1de9923ca31dbcfc78f7e52472fc75b0 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -1,8 +1,42 @@ import sys +import textwrap import traceback +exception_records = [] + + +def record_exception(): + _, e, tb = sys.exc_info() + if e is None: + return + + if exception_records and exception_records[-1] == e: + return + + exception_records.append((e, tb)) + + if len(exception_records) > 5: + exception_records.pop(0) + + +def report(message: str, *, exc_info: bool = False) -> None: + """ + Print an error message to stderr, with optional traceback. + """ + + record_exception() + + for line in message.splitlines(): + print("***", line, file=sys.stderr) + if exc_info: + print(textwrap.indent(traceback.format_exc(), " "), file=sys.stderr) + print("---", file=sys.stderr) + + def print_error_explanation(message): + record_exception() + lines = message.strip().split("\n") max_len = max([len(x) for x in lines]) @@ -12,9 +46,15 @@ def print_error_explanation(message): print('=' * max_len, file=sys.stderr) -def display(e: Exception, task): +def display(e: Exception, task, *, full_traceback=False): + record_exception() + print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + te = traceback.TracebackException.from_exception(e) + if full_traceback: + # include frames leading up to the try-catch block + te.stack = traceback.StackSummary(traceback.extract_stack()[:-2] + te.stack) + print(*te.format(), sep="", file=sys.stderr) message = str(e) if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: @@ -28,6 +68,8 @@ already_displayed = {} def display_once(e: Exception, task): + record_exception() + if task in already_displayed: return diff --git a/modules/extensions.py b/modules/extensions.py index 624832a00e926bf73192f797cd2e2d6eb413fe66..8608584b125e41d75595e7c41e960cd2210185c5 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,11 +1,8 @@ import os -import sys import threading -import traceback -import git - -from modules import shared +from modules import shared, errors +from modules.gitpython_hack import Repo from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401 extensions = [] @@ -54,10 +51,9 @@ class Extension: repo = None try: if os.path.exists(os.path.join(self.path, ".git")): - repo = git.Repo(self.path) + repo = Repo(self.path) except Exception: - print(f"Error reading github repository info from {self.path}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error reading github repository info from {self.path}", exc_info=True) if repo is None or repo.bare: self.remote = None @@ -72,8 +68,8 @@ class Extension: self.commit_hash = commit.hexsha self.version = self.commit_hash[:8] - except Exception as ex: - print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) + except Exception: + errors.report(f"Failed reading extension data from Git repository ({self.name})", exc_info=True) self.remote = None self.have_info_from_repo = True @@ -94,7 +90,7 @@ class Extension: return res def check_updates(self): - repo = git.Repo(self.path) + repo = Repo(self.path) for fetch in repo.remote().fetch(dry_run=True): if fetch.flags != fetch.HEAD_UPTODATE: self.can_update = True @@ -116,7 +112,7 @@ class Extension: self.status = "latest" def fetch_and_reset_hard(self, commit='origin'): - repo = git.Repo(self.path) + repo = Repo(self.path) # Fix: `error: Your local changes to the following files would be overwritten by merge`, # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. repo.git.fetch(all=True) diff --git a/modules/extra_networks.py b/modules/extra_networks.py index f4743928c7baf22efbca1b7d92485fc753a783d6..1f093df27859e057045ce08bf19f77e8def60894 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -32,6 +32,9 @@ class ExtraNetworkParams: else: self.positional.append(item) + def __eq__(self, other): + return self.items == other.items + class ExtraNetwork: def __init__(self, name): diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index aa2a14efd03c760276cb3c1fa8ddd35bdcca4333..b6a6dc0e1d0d169f292b6ad72a9d236baef0ed23 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -9,7 +9,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): def activate(self, p, params_list): additional = shared.opts.sd_hypernetwork - if additional != "None" and additional in shared.hypernetworks and len([x for x in params_list if x.items[0] == additional]) == 0: + if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional): hypernet_prompt_text = f"" p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) @@ -17,7 +17,7 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): names = [] multipliers = [] for params in params_list: - assert len(params.items) > 0 + assert params.items names.append(params.items[0]) multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index cfc04114a740d3ddd784c9572cf3e5584b46f26e..a638f9121869b8abbbf2b84b16f96969330e22a9 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -55,7 +55,7 @@ def image_from_url_text(filedata): if filedata is None: return None - if type(filedata) == list and len(filedata) > 0 and type(filedata[0]) == dict and filedata[0].get("is_file", False): + if type(filedata) == list and filedata and type(filedata[0]) == dict and filedata[0].get("is_file", False): filedata = filedata[0] if type(filedata) == dict and filedata.get("is_file", False): @@ -265,19 +265,30 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model else: prompt += ("" if prompt == "" else "\n") + line + if shared.opts.infotext_styles != "Ignore": + found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt) + + if shared.opts.infotext_styles == "Apply": + res["Styles array"] = found_styles + elif shared.opts.infotext_styles == "Apply if any" and found_styles: + res["Styles array"] = found_styles + res["Prompt"] = prompt res["Negative prompt"] = negative_prompt for k, v in re_param.findall(lastline): - if v[0] == '"' and v[-1] == '"': - v = unquote(v) - - m = re_imagesize.match(v) - if m is not None: - res[f"{k}-1"] = m.group(1) - res[f"{k}-2"] = m.group(2) - else: - res[k] = v + try: + if v[0] == '"' and v[-1] == '"': + v = unquote(v) + + m = re_imagesize.match(v) + if m is not None: + res[f"{k}-1"] = m.group(1) + res[f"{k}-2"] = m.group(2) + else: + res[k] = v + except Exception: + print(f"Error parsing \"{k}: {v}\"") # Missing CLIP skip means it was set to 1 (the default) if "Clip skip" not in res: @@ -306,6 +317,18 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "RNG" not in res: res["RNG"] = "GPU" + if "Schedule type" not in res: + res["Schedule type"] = "Automatic" + + if "Schedule max sigma" not in res: + res["Schedule max sigma"] = 0 + + if "Schedule min sigma" not in res: + res["Schedule min sigma"] = 0 + + if "Schedule rho" not in res: + res["Schedule rho"] = 0 + return res @@ -318,6 +341,10 @@ infotext_to_setting_name_mapping = [ ('Conditional mask weight', 'inpainting_mask_weight'), ('Model hash', 'sd_model_checkpoint'), ('ENSD', 'eta_noise_seed_delta'), + ('Schedule type', 'k_sched_type'), + ('Schedule max sigma', 'sigma_max'), + ('Schedule min sigma', 'sigma_min'), + ('Schedule rho', 'rho'), ('Noise multiplier', 'initial_noise_multiplier'), ('Eta', 'eta_ancestral'), ('Eta DDIM', 'eta_ddim'), @@ -421,7 +448,7 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, vals_pairs = [f"{k}: {v}" for k, v in vals.items()] - return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=len(vals_pairs) > 0) + return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs)) paste_fields = paste_fields + [(override_settings_component, paste_settings)] @@ -438,5 +465,3 @@ def connect_paste(button, paste_fields, input_comp, override_settings_component, outputs=[], show_progress=False, ) - - diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 0131dea4265aba9cdfe243e978a08229ab4ebdd8..e239a09d667a7a2e774955e33e5158741c71fe03 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -1,12 +1,10 @@ import os -import sys -import traceback import facexlib import gfpgan import modules.face_restoration -from modules import paths, shared, devices, modelloader +from modules import paths, shared, devices, modelloader, errors model_dir = "GFPGAN" user_path = None @@ -112,5 +110,4 @@ def setup_model(dirname): shared.face_restorers.append(FaceRestorerGFPGAN()) except Exception: - print("Error setting up GFPGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Error setting up GFPGAN", exc_info=True) diff --git a/modules/gitpython_hack.py b/modules/gitpython_hack.py new file mode 100644 index 0000000000000000000000000000000000000000..e537c1df93e15679d90e9eea3337035a8d50da89 --- /dev/null +++ b/modules/gitpython_hack.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import io +import subprocess + +import git + + +class Git(git.Git): + """ + Git subclassed to never use persistent processes. + """ + + def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs): + raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})") + + def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]: + ret = subprocess.check_output( + [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"], + input=self._prepare_ref(ref), + cwd=self._working_dir, + timeout=2, + ) + return self._parse_object_header(ret) + + def stream_object_data(self, ref: str) -> tuple[str, str, int, "Git.CatFileContentStream"]: + # Not really streaming, per se; this buffers the entire object in memory. + # Shouldn't be a problem for our use case, since we're only using this for + # object headers (commit objects). + ret = subprocess.check_output( + [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"], + input=self._prepare_ref(ref), + cwd=self._working_dir, + timeout=30, + ) + bio = io.BytesIO(ret) + hexsha, typename, size = self._parse_object_header(bio.readline()) + return (hexsha, typename, size, self.CatFileContentStream(size, bio)) + + +class Repo(git.Repo): + GitCommandWrapperType = Git diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 570b5603d1a96c71b3c6aad6aba3e74570759826..5d12b4490871c0d373bfc5eca2639565e99af6c8 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -2,8 +2,6 @@ import datetime import glob import html import os -import sys -import traceback import inspect import modules.textual_inversion.dataset @@ -11,7 +9,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint +from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -325,17 +323,14 @@ def load_hypernetwork(name): if path is None: return None - hypernetwork = Hypernetwork() - try: + hypernetwork = Hypernetwork() hypernetwork.load(path) + return hypernetwork except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error loading hypernetwork {path}", exc_info=True) return None - return hypernetwork - def load_hypernetworks(names, multipliers=None): already_loaded = {} @@ -770,7 +765,7 @@ Last saved image: {html.escape(last_saved_image)}

""" except Exception: - print(traceback.format_exc(), file=sys.stderr) + errors.report("Exception in training hypernetwork", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/images.py b/modules/images.py index 40efc96cc3abfbda08ddc44c9184aa31e161fee9..7bbfc3e0b10da807166a4e804e719cd638580d07 100644 --- a/modules/images.py +++ b/modules/images.py @@ -1,6 +1,4 @@ import datetime -import sys -import traceback import pytz import io @@ -21,6 +19,8 @@ from modules import sd_samplers, shared, script_callbacks, errors from modules.paths_internal import roboto_ttf_file from modules.shared import opts +import modules.sd_vae as sd_vae + LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) @@ -336,8 +336,20 @@ def sanitize_filename_part(text, replace_spaces=True): class FilenameGenerator: + def get_vae_filename(self): #get the name of the VAE file. + if sd_vae.loaded_vae_file is None: + return "NoneType" + file_name = os.path.basename(sd_vae.loaded_vae_file) + split_file_name = file_name.split('.') + if len(split_file_name) > 1 and split_file_name[0] == '': + return split_file_name[1] # if the first character of the filename is "." then [1] is obtained. + else: + return split_file_name[0] + replacements = { 'seed': lambda self: self.seed if self.seed is not None else '', + 'seed_first': lambda self: self.seed if self.p.batch_size == 1 else self.p.all_seeds[0], + 'seed_last': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.all_seeds[-1], 'steps': lambda self: self.p and self.p.steps, 'cfg': lambda self: self.p and self.p.cfg_scale, 'width': lambda self: self.image.width, @@ -354,19 +366,23 @@ class FilenameGenerator: 'prompt_no_styles': lambda self: self.prompt_no_style(), 'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False), 'prompt_words': lambda self: self.prompt_words(), - 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 else self.p.batch_index + 1, - 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.n_iter == 1 and self.p.batch_size == 1 else self.p.iteration * self.p.batch_size + self.p.batch_index + 1, + 'batch_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if self.p.batch_size == 1 or self.zip else self.p.batch_index + 1, + 'batch_size': lambda self: self.p.batch_size, + 'generation_number': lambda self: NOTHING_AND_SKIP_PREVIOUS_TEXT if (self.p.n_iter == 1 and self.p.batch_size == 1) or self.zip else self.p.iteration * self.p.batch_size + self.p.batch_index + 1, 'hasprompt': lambda self, *args: self.hasprompt(*args), # accepts formats:[hasprompt..] 'clip_skip': lambda self: opts.data["CLIP_stop_at_last_layers"], 'denoising': lambda self: self.p.denoising_strength if self.p and self.p.denoising_strength else NOTHING_AND_SKIP_PREVIOUS_TEXT, + 'vae_filename': lambda self: self.get_vae_filename(), + } default_time_format = '%Y%m%d%H%M%S' - def __init__(self, p, seed, prompt, image): + def __init__(self, p, seed, prompt, image, zip=False): self.p = p self.seed = seed self.prompt = prompt self.image = image + self.zip = zip def hasprompt(self, *args): lower = self.prompt.lower() @@ -390,7 +406,7 @@ class FilenameGenerator: prompt_no_style = self.prompt for style in shared.prompt_styles.get_style_prompts(self.p.styles): - if len(style) > 0: + if style: for part in style.split("{prompt}"): prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',') @@ -399,7 +415,7 @@ class FilenameGenerator: return sanitize_filename_part(prompt_no_style, replace_spaces=False) def prompt_words(self): - words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0] + words = [x for x in re_nonletters.split(self.prompt or "") if x] if len(words) == 0: words = ["empty"] return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False) @@ -407,7 +423,7 @@ class FilenameGenerator: def datetime(self, *args): time_datetime = datetime.datetime.now() - time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format + time_format = args[0] if (args and args[0] != "") else self.default_time_format try: time_zone = pytz.timezone(args[1]) if len(args) > 1 else None except pytz.exceptions.UnknownTimeZoneError: @@ -446,8 +462,7 @@ class FilenameGenerator: replacement = fun(self, *pattern_args) except Exception: replacement = None - print(f"Error adding [{pattern}] to filename", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error adding [{pattern}] to filename", exc_info=True) if replacement == NOTHING_AND_SKIP_PREVIOUS_TEXT: continue @@ -664,9 +679,10 @@ def read_info_from_image(image): items['exif comment'] = exif_comment geninfo = exif_comment - for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', - 'loop', 'background', 'timestamp', 'duration']: - items.pop(field, None) + for field in ['jfif', 'jfif_version', 'jfif_unit', 'jfif_density', 'dpi', 'exif', + 'loop', 'background', 'timestamp', 'duration', 'progressive', 'progression', + 'icc_profile', 'chromaticity']: + items.pop(field, None) if items.get("Software", None) == "NovelAI": try: @@ -677,8 +693,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print("Error parsing NovelAI image generation parameters:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Error parsing NovelAI image generation parameters", exc_info=True) return geninfo, items diff --git a/modules/img2img.py b/modules/img2img.py index d704bf9006d647f10cd0dab6c7c341c7fd54dc29..2c4970207fe2c298c0424eb13ec9af564bf4e3d9 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError @@ -13,7 +14,7 @@ from modules.ui import plaintext_to_html import modules.scripts -def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): +def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=False, scale_by=1.0): processing.fix_seed(p) images = shared.listfiles(input_dir) @@ -21,9 +22,10 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): is_inpaint_batch = False if inpaint_mask_dir: inpaint_masks = shared.listfiles(inpaint_mask_dir) - is_inpaint_batch = len(inpaint_masks) > 0 - if is_inpaint_batch: - print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") + is_inpaint_batch = bool(inpaint_masks) + + if is_inpaint_batch: + print(f"\nInpaint batch is enabled. {len(inpaint_masks)} masks found.") print(f"Will process {len(images)} images, creating {p.n_iter * p.batch_size} new images for each.") @@ -49,14 +51,31 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): continue # Use the EXIF orientation of photos taken by smartphones. img = ImageOps.exif_transpose(img) + + if to_scale: + p.width = int(img.width * scale_by) + p.height = int(img.height * scale_by) + p.init_images = [img] * p.batch_size + image_path = Path(image) if is_inpaint_batch: # try to find corresponding mask for an image using simple filename matching - mask_image_path = os.path.join(inpaint_mask_dir, os.path.basename(image)) - # if not found use first one ("same mask for all images" use-case) - if mask_image_path not in inpaint_masks: + if len(inpaint_masks) == 1: mask_image_path = inpaint_masks[0] + else: + # try to find corresponding mask for an image using simple filename matching + mask_image_dir = Path(inpaint_mask_dir) + masks_found = list(mask_image_dir.glob(f"{image_path.stem}.*")) + + if len(masks_found) == 0: + print(f"Warning: mask is not found for {image_path} in {mask_image_dir}. Skipping it.") + continue + + # it should contain only 1 matching mask + # otherwise user has many masks with the same name but different extensions + mask_image_path = masks_found[0] + mask_image = Image.open(mask_image_path) p.image_mask = mask_image @@ -65,7 +84,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args): proc = process_images(p) for n, processed_image in enumerate(proc.images): - filename = os.path.basename(image) + filename = image_path.name if n > 0: left, right = os.path.splitext(filename) @@ -92,7 +111,8 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s elif mode == 2: # inpaint image, mask = init_img_with_mask["image"], init_img_with_mask["mask"] alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1') - mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L') + mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1') + mask = ImageChops.lighter(alpha_mask, mask).convert('L') image = image.convert("RGB") elif mode == 3: # inpaint sketch image = inpaint_color_sketch @@ -114,7 +134,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if image is not None: image = ImageOps.exif_transpose(image) - if selected_scale_tab == 1: + if selected_scale_tab == 1 and not is_batch: assert image, "Can't scale by because no image is selected" width = int(image.width * scale_by) @@ -169,7 +189,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s if is_batch: assert not shared.cmd_opts.hide_ui_dir_config, "Launched with --hide-ui-dir-config, batch img2img disabled" - process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args) + process_batch(p, img2img_batch_input_dir, img2img_batch_output_dir, img2img_batch_inpaint_mask_dir, args, to_scale=selected_scale_tab == 1, scale_by=scale_by) processed = Processed(p, [], p.seed, "") else: diff --git a/modules/interrogate.py b/modules/interrogate.py index 111b1322c3938f602503e6b10ef8fc57095c905e..9b2c5b60efbd10bd6b652c34ad3b62d9e1ca9f3e 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -1,6 +1,5 @@ import os import sys -import traceback from collections import namedtuple from pathlib import Path import re @@ -216,8 +215,7 @@ class InterrogateModels: res += f", {match}" except Exception: - print("Error interrogating", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Error interrogating", exc_info=True) res += "" self.unload() diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 35a52310b07119195558c218bee97a971b217df5..609a181e013731dc43b66d2788efe0e469a818f7 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -7,7 +7,7 @@ import platform import json from functools import lru_cache -from modules import cmd_args +from modules import cmd_args, errors from modules.paths_internal import script_path, extensions_dir args, _ = cmd_args.parser.parse_known_args() @@ -68,7 +68,13 @@ def git_tag(): try: return subprocess.check_output([git, "describe", "--tags"], shell=False, encoding='utf8').strip() except Exception: - return "" + try: + from pathlib import Path + changelog_md = Path(__file__).parent.parent / "CHANGELOG.md" + with changelog_md.open(encoding="utf-8") as file: + return next((line.strip() for line in file if line.strip()), "") + except Exception: + return "" def run(command, desc=None, errdesc=None, custom_env=None, live: bool = default_command_live) -> str: @@ -188,7 +194,7 @@ def run_extension_installer(extension_dir): print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env)) except Exception as e: - print(e, file=sys.stderr) + errors.report(str(e)) def list_extensions(settings_file): @@ -198,8 +204,8 @@ def list_extensions(settings_file): if os.path.isfile(settings_file): with open(settings_file, "r", encoding="utf8") as file: settings = json.load(file) - except Exception as e: - print(e, file=sys.stderr) + except Exception: + errors.report("Could not load settings", exc_info=True) disabled_extensions = set(settings.get('disabled_extensions', [])) disable_all_extensions = settings.get('disable_all_extensions', 'none') @@ -223,23 +229,28 @@ def prepare_environment(): torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.20') gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "https://github.com/TencentARC/GFPGAN/archive/8d2447a2d918f8eba5a4a01463fd48e45126a379.zip") clip_package = os.environ.get('CLIP_PACKAGE', "https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "https://github.com/mlfoundations/open_clip/archive/bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b.zip") stable_diffusion_repo = os.environ.get('STABLE_DIFFUSION_REPO', "https://github.com/Stability-AI/stablediffusion.git") - taming_transformers_repo = os.environ.get('TAMING_TRANSFORMERS_REPO', "https://github.com/CompVis/taming-transformers.git") k_diffusion_repo = os.environ.get('K_DIFFUSION_REPO', 'https://github.com/crowsonkb/k-diffusion.git') codeformer_repo = os.environ.get('CODEFORMER_REPO', 'https://github.com/sczhou/CodeFormer.git') blip_repo = os.environ.get('BLIP_REPO', 'https://github.com/salesforce/BLIP.git') stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "cf1d67a6fd5ea1aa600c4df58e5b47da45f6bdbf") - taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6") k_diffusion_commit_hash = os.environ.get('K_DIFFUSION_COMMIT_HASH', "c9fe758757e022f05ca5a53fa8fac28889e4f1cf") codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af") blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9") + try: + # the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution + os.remove(os.path.join(script_path, "tmp", "restart")) + os.environ.setdefault('SD_WEBUI_RESTARTING ', '1') + except OSError: + pass + if not args.skip_python_version_check: check_python_version() @@ -286,7 +297,6 @@ def prepare_environment(): os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True) git_clone(stable_diffusion_repo, repo_dir('stable-diffusion-stability-ai'), "Stable Diffusion", stable_diffusion_commit_hash) - git_clone(taming_transformers_repo, repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash) git_clone(k_diffusion_repo, repo_dir('k-diffusion'), "K-diffusion", k_diffusion_commit_hash) git_clone(codeformer_repo, repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash) git_clone(blip_repo, repo_dir('BLIP'), "BLIP", blip_commit_hash) diff --git a/modules/localization.py b/modules/localization.py index ee9c65e7dce54dba588107af53b8273acbd80aa2..e8f585dab2e285fd477222ce90173aaefa88a3d4 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -1,8 +1,7 @@ import json import os -import sys -import traceback +from modules import errors localizations = {} @@ -31,7 +30,6 @@ def localization_js(current_localization_name: str) -> str: with open(fn, "r", encoding="utf8") as file: data = json.load(file) except Exception: - print(f"Error loading localization from {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error loading localization from {fn}", exc_info=True) return f"window.localization = {json.dumps(data)}" diff --git a/modules/lowvram.py b/modules/lowvram.py index e254cc1316f96634604ebfdb8b01e73a256589aa..d95bcfbf0f3b8cd6adff29ed094b834eb0d41b8e 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -15,6 +15,8 @@ def send_everything_to_cpu(): def setup_for_low_vram(sd_model, use_medvram): + sd_model.lowvram = True + parents = {} def send_me_to_gpu(module, _): @@ -96,3 +98,7 @@ def setup_for_low_vram(sd_model, use_medvram): diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu) for block in diff_model.output_blocks: block.register_forward_pre_hook(send_me_to_gpu) + + +def is_enabled(sd_model): + return getattr(sd_model, 'lowvram', False) diff --git a/modules/models/diffusion/ddpm_edit.py b/modules/models/diffusion/ddpm_edit.py index 3fb76b651e1edc547a72905864e930191a4e7107..b892d5fc7b04755a3de6c17cf8787605df0faed3 100644 --- a/modules/models/diffusion/ddpm_edit.py +++ b/modules/models/diffusion/ddpm_edit.py @@ -230,9 +230,9 @@ class DDPM(pl.LightningModule): missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( sd, strict=False) print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") - if len(missing) > 0: + if missing: print(f"Missing Keys: {missing}") - if len(unexpected) > 0: + if unexpected: print(f"Unexpected Keys: {unexpected}") def q_mean_variance(self, x_start, t): diff --git a/modules/paths.py b/modules/paths.py index 5f6474c032afbbb4ef8081aad4fd4da81eddda9a..5171df4f85e7ee8a5b6f4c340f8d53673045934a 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,7 +20,6 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), - (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers', []), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), diff --git a/modules/processing.py b/modules/processing.py index d22b353f4983cf4a7a88a24802cb48e756713637..8da738842cf421967367c3bb41da2c255fcb3569 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,4 +1,5 @@ import json +import logging import math import os import sys @@ -6,14 +7,14 @@ import hashlib import torch import numpy as np -from PIL import Image, ImageFilter, ImageOps +from PIL import Image, ImageOps import random import cv2 from skimage import exposure from typing import Any, Dict, List import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -23,7 +24,6 @@ import modules.images as images import modules.styles import modules.sd_models as sd_models import modules.sd_vae as sd_vae -import logging from ldm.data.util import AddMiDaS from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion @@ -106,6 +106,9 @@ class StableDiffusionProcessing: """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ + cached_uc = [None, None] + cached_c = [None, None] + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -171,12 +174,13 @@ class StableDiffusionProcessing: self.prompts = None self.negative_prompts = None + self.extra_network_data = None self.seeds = None self.subseeds = None self.step_multiplier = 1 - self.cached_uc = [None, None] - self.cached_c = [None, None] + self.cached_uc = StableDiffusionProcessing.cached_uc + self.cached_c = StableDiffusionProcessing.cached_c self.uc = None self.c = None @@ -288,8 +292,9 @@ class StableDiffusionProcessing: self.sampler = None self.c = None self.uc = None - self.cached_c = [None, None] - self.cached_uc = [None, None] + if not opts.experimental_persistent_cond_cache: + StableDiffusionProcessing.cached_c = [None, None] + StableDiffusionProcessing.cached_uc = [None, None] def get_token_merging_ratio(self, for_hr=False): if for_hr: @@ -311,7 +316,7 @@ class StableDiffusionProcessing: self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] - def get_conds_with_caching(self, function, required_prompts, steps, cache): + def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) using a cache to store the result if the same arguments have been used before. @@ -320,27 +325,29 @@ class StableDiffusionProcessing: representing the previously used arguments, or None if no arguments have been used before. The second element is where the previously computed result is stored. + + caches is a list with items described above. """ - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]: - return cache[1] + for cache in caches: + if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: + return cache[1] + + cache = caches[0] with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) - cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) + cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) return cache[1] def setup_conds(self): sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) def parse_extra_network_prompts(self): - self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts) - - return extra_network_data + self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) class Processed: @@ -588,6 +595,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter def process_images(p: StableDiffusionProcessing) -> Processed: + if p.scripts is not None: + p.scripts.before_process(p) + stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} try: @@ -673,10 +683,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN": sd_vae_approx.model() + sd_unet.apply_unet() + if state.job_count == -1: state.job_count = p.n_iter - extra_network_data = None for n in range(p.n_iter): p.iteration = n @@ -697,11 +708,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(p.prompts) == 0: break - extra_network_data = p.parse_extra_network_prompts() + p.parse_extra_network_prompts() if not p.disable_extra_networks: with devices.autocast(): - extra_networks.activate(p, extra_network_data) + extra_networks.activate(p, p.extra_network_data) if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) @@ -736,7 +747,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: del samples_ddim - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + if lowvram.is_enabled(shared.sd_model): lowvram.send_everything_to_cpu() devices.torch_gc() @@ -823,8 +834,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) - if not p.disable_extra_networks and extra_network_data: - extra_networks.deactivate(p, extra_network_data) + if not p.disable_extra_networks and p.extra_network_data: + extra_networks.deactivate(p, p.extra_network_data) devices.torch_gc() @@ -859,6 +870,8 @@ def old_hires_fix_first_pass_dimensions(width, height): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None + cached_hr_uc = [None, None] + cached_hr_c = [None, None] def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): super().__init__(**kwargs) @@ -891,6 +904,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_negative_prompts = None self.hr_extra_network_data = None + self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc + self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c self.hr_c = None self.hr_uc = None @@ -970,7 +985,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") if self.enable_hr and latent_scale_mode is None: - assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) @@ -1053,6 +1069,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast(): extra_networks.activate(self, self.hr_extra_network_data) + with devices.autocast(): + self.calculate_hr_conds() + sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True)) samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) @@ -1064,8 +1083,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples def close(self): + super().close() self.hr_c = None self.hr_uc = None + if not opts.experimental_persistent_cond_cache: + StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None] + StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None] def setup_prompts(self): super().setup_prompts() @@ -1092,12 +1115,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts] self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts] + def calculate_hr_conds(self): + if self.hr_c is not None: + return + + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) + def setup_conds(self): super().setup_conds() + self.hr_uc = None + self.hr_c = None + if self.enable_hr: - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c) + if shared.opts.hires_fix_use_firstpass_conds: + self.calculate_hr_conds() + + elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded + with devices.autocast(): + extra_networks.activate(self, self.hr_extra_network_data) + + self.calculate_hr_conds() + + with devices.autocast(): + extra_networks.activate(self, self.extra_network_data) def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() @@ -1114,7 +1156,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): sampler = None - def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): + def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): super().__init__(**kwargs) self.init_images = init_images @@ -1125,7 +1167,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_mask = mask self.latent_mask = None self.mask_for_overlay = None - self.mask_blur = mask_blur + if mask_blur is not None: + mask_blur_x = mask_blur + mask_blur_y = mask_blur + self.mask_blur_x = mask_blur_x + self.mask_blur_y = mask_blur_y self.inpainting_fill = inpainting_fill self.inpaint_full_res = inpaint_full_res self.inpaint_full_res_padding = inpaint_full_res_padding @@ -1147,8 +1193,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpainting_mask_invert: image_mask = ImageOps.invert(image_mask) - if self.mask_blur > 0: - image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)) + if self.mask_blur_x > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) + image_mask = Image.fromarray(np_mask) + + if self.mask_blur_y > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) + image_mask = Image.fromarray(np_mask) if self.inpaint_full_res: self.mask_for_overlay = image_mask diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index b4aff704af5fdeabc509dc44d3b648b05caa58be..0069d8b0e1a5376dc663a249428ef77421d74dca 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -336,11 +336,11 @@ def parse_prompt_attention(text): round_brackets.append(len(res)) elif text == '[': square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: + elif weight is not None and round_brackets: multiply_range(round_brackets.pop(), float(weight)) - elif text == ')' and len(round_brackets) > 0: + elif text == ')' and round_brackets: multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == ']' and len(square_brackets) > 0: + elif text == ']' and square_brackets: multiply_range(square_brackets.pop(), square_bracket_multiplier) else: parts = re.split(re_break, text) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 99983678d0ec327066f3c77c0e31f95dbc4935ee..2d27b321c16dec81c9c4dff1d78ece3726d79d6f 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,6 +1,4 @@ import os -import sys -import traceback import numpy as np from PIL import Image @@ -9,7 +7,8 @@ from realesrgan import RealESRGANer from modules.upscaler import Upscaler, UpscalerData from modules.shared import cmd_opts, opts -from modules import modelloader +from modules import modelloader, errors + class UpscalerRealESRGAN(Upscaler): def __init__(self, path): @@ -36,8 +35,7 @@ class UpscalerRealESRGAN(Upscaler): self.scalers.append(scaler) except Exception: - print("Error importing Real-ESRGAN:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Error importing Real-ESRGAN", exc_info=True) self.enable = False self.scalers = [] @@ -76,9 +74,8 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) return info - except Exception as e: - print(f"Error making Real-ESRGAN models list: {e}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + except Exception: + errors.report("Error making Real-ESRGAN models list", exc_info=True) return None def load_models(self, _): @@ -135,5 +132,4 @@ def get_realesrgan_models(scaler): ] return models except Exception: - print("Error making Real-ESRGAN models list:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Error making Real-ESRGAN models list", exc_info=True) diff --git a/modules/restart.py b/modules/restart.py new file mode 100644 index 0000000000000000000000000000000000000000..18eacaf377ee4f13c5dc3e12d7a2d818143cc69d --- /dev/null +++ b/modules/restart.py @@ -0,0 +1,23 @@ +import os +from pathlib import Path + +from modules.paths_internal import script_path + + +def is_restartable() -> bool: + """ + Return True if the webui is restartable (i.e. there is something watching to restart it with) + """ + return bool(os.environ.get('SD_WEBUI_RESTART')) + + +def restart_program() -> None: + """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again""" + + (Path(script_path) / "tmp" / "restart").touch() + + stop_program() + + +def stop_program() -> None: + os._exit(0) diff --git a/modules/safe.py b/modules/safe.py index e8f50774374dc83ff9b5b1a5a8cdd2f21fe8ea59..b1d08a7928ed338ab24da67d785b7ac4f41aafa4 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -2,8 +2,6 @@ import pickle import collections -import sys -import traceback import torch import numpy @@ -11,7 +9,10 @@ import _codecs import zipfile import re + # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage +from modules import errors + TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage def encode(*args): @@ -136,17 +137,20 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): check_pt(filename, extra_handler) except pickle.UnpicklingError: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + errors.report( + f"Error verifying pickled file from {filename}\n" + "-----> !!!! The file is most likely corrupted !!!! <-----\n" + "You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", + exc_info=True, + ) return None - except Exception: - print(f"Error verifying pickled file from {filename}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) - print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) + errors.report( + f"Error verifying pickled file from {filename}\n" + f"The file may be malicious, so the program is not going to read it.\n" + f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", + exc_info=True, + ) return None return unsafe_torch_load(filename, *args, **kwargs) @@ -190,4 +194,3 @@ with safe.Extra(handler): unsafe_torch_load = torch.load torch.load = load global_extra_handler = None - diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 40f388a59bbe402565787a51c42451b7d5497caa..77ee55ee3f41598e582ec783d4604d90d3d323cd 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,16 +1,16 @@ -import sys -import traceback -from collections import namedtuple import inspect +import os +from collections import namedtuple from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks +from modules import errors, timer + def report_exception(c, job): - print(f"Error executing callback {job} for {c.script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error executing callback {job} for {c.script}", exc_info=True) class ImageSaveParams: @@ -111,6 +111,7 @@ callback_map = dict( callbacks_before_ui=[], callbacks_on_reload=[], callbacks_list_optimizers=[], + callbacks_list_unets=[], ) @@ -123,6 +124,7 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: c.callback(demo, app) + timer.startup_timer.record(os.path.basename(c.script)) except Exception: report_exception(c, 'app_started_callback') @@ -271,16 +273,28 @@ def list_optimizers_callback(): return res +def list_unets_callback(): + res = [] + + for c in callback_map['callbacks_list_unets']: + try: + c.callback(res) + except Exception: + report_exception(c, 'list_unets') + + return res + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' callbacks.append(ScriptCallback(filename, fun)) def remove_current_script_callbacks(): stack = [x for x in inspect.stack() if x.filename != __file__] - filename = stack[0].filename if len(stack) > 0 else 'unknown file' + filename = stack[0].filename if stack else 'unknown file' if filename == 'unknown file': return for callback_list in callback_map.values(): @@ -430,3 +444,10 @@ def on_list_optimizers(callback): to it.""" add_callback(callback_map['callbacks_list_optimizers'], callback) + + +def on_list_unets(callback): + """register a function to be called when UI is making a list of alternative options for unet. + The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" + + add_callback(callback_map['callbacks_list_unets'], callback) diff --git a/modules/script_loading.py b/modules/script_loading.py index 57b15862466248a31f3bb3ed5ad17001308c63f0..306a1f35f778ee06b12910b92a38c0cc2be75a31 100644 --- a/modules/script_loading.py +++ b/modules/script_loading.py @@ -1,8 +1,8 @@ import os -import sys -import traceback import importlib.util +from modules import errors + def load_module(path): module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) @@ -27,5 +27,4 @@ def preload_extensions(extensions_dir, parser): module.preload(parser) except Exception: - print(f"Error running preload() for {preload_script}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running preload() for {preload_script}", exc_info=True) diff --git a/modules/scripts.py b/modules/scripts.py index c902804b6e4e9a584fb8e6a81b0a91ecaea1f149..99bf836af7b81fd7e0ed38486a6fbc144ca52cad 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,12 +1,11 @@ import os import re import sys -import traceback from collections import namedtuple import gradio as gr -from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing +from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing, errors, timer AlwaysVisible = object() @@ -20,6 +19,9 @@ class Script: name = None """script's internal name derived from title""" + section = None + """name of UI section that the script's controls will be placed into""" + filename = None args_from = None args_to = None @@ -82,6 +84,15 @@ class Script: pass + def before_process(self, p, *args): + """ + This function is called very early before processing begins for AlwaysVisible scripts. + You can modify the processing object (p) here, inject hooks, etc. + args contains all values returned by components from ui() + """ + + pass + def process(self, p, *args): """ This function is called before processing begins for AlwaysVisible scripts. @@ -264,12 +275,12 @@ def load_scripts(): register_scripts_from_module(script_module) except Exception: - print(f"Error loading script: {scriptfile.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error loading script: {scriptfile.filename}", exc_info=True) finally: sys.path = syspath current_basedir = paths.script_path + timer.startup_timer.record(scriptfile.filename) global scripts_txt2img, scripts_img2img, scripts_postproc @@ -280,11 +291,9 @@ def load_scripts(): def wrap_call(func, filename, funcname, *args, default=None, **kwargs): try: - res = func(*args, **kwargs) - return res + return func(*args, **kwargs) except Exception: - print(f"Error calling: {filename}/{funcname}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error calling: {filename}/{funcname}", exc_info=True) return default @@ -297,6 +306,7 @@ class ScriptRunner: self.titles = [] self.infotext_fields = [] self.paste_field_names = [] + self.inputs = [None] def initialize_scripts(self, is_img2img): from modules import scripts_auto_postprocessing @@ -324,69 +334,73 @@ class ScriptRunner: self.scripts.append(script) self.selectable_scripts.append(script) - def setup_ui(self): + def create_script_ui(self, script): import modules.api.models as api_models - self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] + script.args_from = len(self.inputs) + script.args_to = len(self.inputs) - inputs = [None] - inputs_alwayson = [True] + controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) - def create_script_ui(script, inputs, inputs_alwayson): - script.args_from = len(inputs) - script.args_to = len(inputs) + if controls is None: + return - controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) + script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() + api_args = [] - if controls is None: - return + for control in controls: + control.custom_script_source = os.path.basename(script.filename) - script.name = wrap_call(script.title, script.filename, "title", default=script.filename).lower() - api_args = [] + arg_info = api_models.ScriptArg(label=control.label or "") - for control in controls: - control.custom_script_source = os.path.basename(script.filename) + for field in ("value", "minimum", "maximum", "step", "choices"): + v = getattr(control, field, None) + if v is not None: + setattr(arg_info, field, v) - arg_info = api_models.ScriptArg(label=control.label or "") + api_args.append(arg_info) - for field in ("value", "minimum", "maximum", "step", "choices"): - v = getattr(control, field, None) - if v is not None: - setattr(arg_info, field, v) + script.api_info = api_models.ScriptInfo( + name=script.name, + is_img2img=script.is_img2img, + is_alwayson=script.alwayson, + args=api_args, + ) - api_args.append(arg_info) + if script.infotext_fields is not None: + self.infotext_fields += script.infotext_fields - script.api_info = api_models.ScriptInfo( - name=script.name, - is_img2img=script.is_img2img, - is_alwayson=script.alwayson, - args=api_args, - ) + if script.paste_field_names is not None: + self.paste_field_names += script.paste_field_names - if script.infotext_fields is not None: - self.infotext_fields += script.infotext_fields + self.inputs += controls + script.args_to = len(self.inputs) - if script.paste_field_names is not None: - self.paste_field_names += script.paste_field_names + def setup_ui_for_section(self, section, scriptlist=None): + if scriptlist is None: + scriptlist = self.alwayson_scripts - inputs += controls - inputs_alwayson += [script.alwayson for _ in controls] - script.args_to = len(inputs) + for script in scriptlist: + if script.alwayson and script.section != section: + continue - for script in self.alwayson_scripts: - with gr.Group() as group: - create_script_ui(script, inputs, inputs_alwayson) + with gr.Group(visible=script.alwayson) as group: + self.create_script_ui(script) script.group = group - dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") - inputs[0] = dropdown + def prepare_ui(self): + self.inputs = [None] - for script in self.selectable_scripts: - with gr.Group(visible=False) as group: - create_script_ui(script, inputs, inputs_alwayson) + def setup_ui(self): + self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] - script.group = group + self.setup_ui_for_section(None) + + dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index") + self.inputs[0] = dropdown + + self.setup_ui_for_section(None, self.selectable_scripts) def select_script(script_index): selected_script = self.selectable_scripts[script_index - 1] if script_index>0 else None @@ -411,6 +425,7 @@ class ScriptRunner: ) self.script_load_ctr = 0 + def onload_script_visibility(params): title = params.get('Script', None) if title: @@ -421,10 +436,10 @@ class ScriptRunner: else: return gr.update(visible=False) - self.infotext_fields.append( (dropdown, lambda x: gr.update(value=x.get('Script', 'None'))) ) - self.infotext_fields.extend( [(script.group, onload_script_visibility) for script in self.selectable_scripts] ) + self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None')))) + self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts]) - return inputs + return self.inputs def run(self, p, *args): script_index = args[0] @@ -444,14 +459,21 @@ class ScriptRunner: return processed + def before_process(self, p): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.before_process(p, *script_args) + except Exception: + errors.report(f"Error running before_process: {script.filename}", exc_info=True) + def process(self, p): for script in self.alwayson_scripts: try: script_args = p.script_args[script.args_from:script.args_to] script.process(p, *script_args) except Exception: - print(f"Error running process: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running process: {script.filename}", exc_info=True) def before_process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -459,8 +481,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.before_process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running before_process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: @@ -468,8 +489,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.process_batch(p, *script_args, **kwargs) except Exception: - print(f"Error running process_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running process_batch: {script.filename}", exc_info=True) def postprocess(self, p, processed): for script in self.alwayson_scripts: @@ -477,8 +497,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess(p, processed, *script_args) except Exception: - print(f"Error running postprocess: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running postprocess: {script.filename}", exc_info=True) def postprocess_batch(self, p, images, **kwargs): for script in self.alwayson_scripts: @@ -486,8 +505,7 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_batch(p, *script_args, images=images, **kwargs) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: @@ -495,24 +513,21 @@ class ScriptRunner: script_args = p.script_args[script.args_from:script.args_to] script.postprocess_image(p, pp, *script_args) except Exception: - print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True) def before_component(self, component, **kwargs): for script in self.scripts: try: script.before_component(component, **kwargs) except Exception: - print(f"Error running before_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running before_component: {script.filename}", exc_info=True) def after_component(self, component, **kwargs): for script in self.scripts: try: script.after_component(component, **kwargs) except Exception: - print(f"Error running after_component: {script.filename}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error running after_component: {script.filename}", exc_info=True) def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 221771da4719e2e5cf4ce85874c189395690df48..3b6f95ce2ea2ed85a1392d438c5cda2e597a0320 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -3,7 +3,7 @@ from torch.nn.functional import silu from types import MethodType import modules.textual_inversion.textual_inversion -from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors +from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr @@ -43,7 +43,7 @@ def list_optimizers(): optimizers.extend(new_optimizers) -def apply_optimizations(): +def apply_optimizations(option=None): global current_optimizer undo_optimizations() @@ -60,7 +60,7 @@ def apply_optimizations(): current_optimizer.undo() current_optimizer = None - selection = shared.opts.cross_attention_optimization + selection = option or shared.opts.cross_attention_optimization if selection == "Automatic" and len(optimizers) > 0: matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0]) else: @@ -74,12 +74,13 @@ def apply_optimizations(): matching_optimizer = optimizers[0] if matching_optimizer is not None: - print(f"Applying optimization: {matching_optimizer.name}... ", end='') + print(f"Applying attention optimization: {matching_optimizer.name}... ", end='') matching_optimizer.apply() print("done.") current_optimizer = matching_optimizer return current_optimizer.name else: + print("Disabling attention optimization") return '' @@ -157,9 +158,9 @@ class StableDiffusionModelHijack: def __init__(self): self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) - def apply_optimizations(self): + def apply_optimizations(self, option=None): try: - self.optimization_method = apply_optimizations() + self.optimization_method = apply_optimizations(option) except Exception as e: errors.display(e, "applying cross attention optimization") undo_optimizations() @@ -196,6 +197,11 @@ class StableDiffusionModelHijack: self.layers = flatten(m) + if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'): + ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward + + ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward + def undo_hijack(self, m): if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: m.cond_stage_model = m.cond_stage_model.wrapped @@ -217,6 +223,8 @@ class StableDiffusionModelHijack: self.layers = None self.clip = None + ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui + def apply_circular(self, enable): if self.circular_enabled == enable: return diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index cc6e8c21ec6f2469a03c1d3e2e89e2a91d6a255b..3b5a766671721fb00a321f393a1fff34c0633c4b 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -167,7 +167,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): chunk.multipliers += [weight] * emb_len position += embedding_length_in_tokens - if len(chunk.tokens) > 0 or len(chunks) == 0: + if chunk.tokens or not chunks: next_chunk(is_last=True) return chunks, token_count diff --git a/modules/sd_hijack_clip_old.py b/modules/sd_hijack_clip_old.py index a3476e9563315ead1dfda05bc96505b8caf642ad..c5c6270b9c44dee0ca45038c028e1d726ef2d6d3 100644 --- a/modules/sd_hijack_clip_old.py +++ b/modules/sd_hijack_clip_old.py @@ -74,7 +74,7 @@ def forward_old(self: sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase, text self.hijack.comments += hijack_comments - if len(used_custom_terms) > 0: + if used_custom_terms: embedding_names = ", ".join(f"{word} [{checksum}]" for word, checksum in used_custom_terms) self.hijack.comments.append(f"Used embeddings: {embedding_names}") diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 80e48a422e0201947b6039f1c44d2a27a76e6e23..53e27adea2caaf39f3d58a9a015bb36eb0395bba 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -1,7 +1,5 @@ from __future__ import annotations import math -import sys -import traceback import psutil import torch @@ -48,7 +46,7 @@ class SdOptimizationXformers(SdOptimization): priority = 100 def is_available(self): - return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) + return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)) def apply(self): ldm.modules.attention.CrossAttention.forward = xformers_attention_forward @@ -140,8 +138,7 @@ if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers: import xformers.ops shared.xformers_available = True except Exception: - print("Cannot import xformers", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Cannot import xformers", exc_info=True) def get_available_vram(): @@ -605,7 +602,7 @@ def sdp_attnblock_forward(self, x): q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v)) dtype = q.dtype if shared.opts.upcast_attn: - q, k = q.float(), k.float() + q, k, v = q.float(), k.float(), v.float() q = q.contiguous() k = k.contiguous() v = v.contiguous() diff --git a/modules/sd_models.py b/modules/sd_models.py index 7742998a575f22146811dd1dc765f5cd63dba60e..918f6fd6489e576840c348d2ef84fec451587a9d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -164,6 +164,7 @@ def model_hash(filename): def select_checkpoint(): + """Raises `FileNotFoundError` if no checkpoints are found.""" model_checkpoint = shared.opts.sd_model_checkpoint checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) @@ -171,14 +172,14 @@ def select_checkpoint(): return checkpoint_info if len(checkpoints_list) == 0: - print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) + error_message = "No checkpoints found. When searching for checkpoints, looked at:" if shared.cmd_opts.ckpt is not None: - print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) - print(f" - directory {model_path}", file=sys.stderr) + error_message += f"\n - file {os.path.abspath(shared.cmd_opts.ckpt)}" + error_message += f"\n - directory {model_path}" if shared.cmd_opts.ckpt_dir is not None: - print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) - print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr) - exit(1) + error_message += f"\n - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}" + error_message += "Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations." + raise FileNotFoundError(error_message) checkpoint_info = next(iter(checkpoints_list.values())) if model_checkpoint is not None: @@ -421,7 +422,7 @@ class SdModelData: try: load_model() except Exception as e: - errors.display(e, "loading stable diffusion model") + errors.display(e, "loading stable diffusion model", full_traceback=True) print("", file=sys.stderr) print("Stable diffusion model failed to load", file=sys.stderr) self.sd_model = None @@ -506,6 +507,11 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks") + with devices.autocast(), torch.no_grad(): + sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""]) + + timer.record("calculate empty prompt") + print(f"Model loaded in {timer.summary()}.") return sd_model @@ -525,6 +531,8 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return + sd_unet.apply_unet("None") + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 59982fc9ba5d9aec9fd9db2cbd50cca158cd6547..f8a0c7ba820ba9ab787779ed79f3df26c7b5926d 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -20,7 +20,7 @@ samplers_k_diffusion = [ ('DPM++ 2S a', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a'], {"uses_ensd": True, "second_order": True}), ('DPM++ 2M', 'sample_dpmpp_2m', ['k_dpmpp_2m'], {}), ('DPM++ SDE', 'sample_dpmpp_sde', ['k_dpmpp_sde'], {"second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True, 'discard_next_to_last_sigma': True}), + ('DPM++ 2M SDE', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {"brownian_noise": True}), ('DPM fast', 'sample_dpm_fast', ['k_dpm_fast'], {"uses_ensd": True}), ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad'], {"uses_ensd": True}), ('LMS Karras', 'sample_lms', ['k_lms_ka'], {'scheduler': 'karras'}), @@ -29,7 +29,7 @@ samplers_k_diffusion = [ ('DPM++ 2S a Karras', 'sample_dpmpp_2s_ancestral', ['k_dpmpp_2s_a_ka'], {'scheduler': 'karras', "uses_ensd": True, "second_order": True}), ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), - ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True, 'discard_next_to_last_sigma': True}), + ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), ] samplers_data_k_diffusion = [ @@ -44,6 +44,14 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } +k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} +k_diffusion_scheduler = { + 'Automatic': None, + 'karras': k_diffusion.sampling.get_sigmas_karras, + 'exponential': k_diffusion.sampling.get_sigmas_exponential, + 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential +} + class CFGDenoiser(torch.nn.Module): """ @@ -125,6 +133,16 @@ class CFGDenoiser(torch.nn.Module): x_in = x_in[:-batch_size] sigma_in = sigma_in[:-batch_size] + # TODO add infotext entry + if shared.opts.pad_cond_uncond and tensor.shape[1] != uncond.shape[1]: + empty = shared.sd_model.cond_stage_model_empty_prompt + num_repeats = (tensor.shape[1] - uncond.shape[1]) // empty.shape[1] + + if num_repeats < 0: + tensor = torch.cat([tensor, empty.repeat((tensor.shape[0], -num_repeats, 1))], axis=1) + elif num_repeats > 0: + uncond = torch.cat([uncond, empty.repeat((uncond.shape[0], num_repeats, 1))], axis=1) + if tensor.shape[1] == uncond.shape[1] or skip_uncond: if is_edit_model: cond_in = torch.cat([tensor, uncond, uncond]) @@ -255,6 +273,13 @@ class KDiffusionSampler: try: return func() + except RecursionError: + print( + 'Encountered RecursionError during sampling, returning last latent. ' + 'rho >5 with a polyexponential scheduler may cause this error. ' + 'You should try to use a smaller rho value instead.' + ) + return self.last_latent except sd_samplers_common.InterruptedException: return self.last_latent @@ -294,6 +319,31 @@ class KDiffusionSampler: if p.sampler_noise_scheduler_override: sigmas = p.sampler_noise_scheduler_override(steps) + elif opts.k_sched_type != "Automatic": + m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) + sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (m_sigma_min, m_sigma_max) + sigmas_kwargs = { + 'sigma_min': sigma_min, + 'sigma_max': sigma_max, + } + + sigmas_func = k_diffusion_scheduler[opts.k_sched_type] + p.extra_generation_params["Schedule type"] = opts.k_sched_type + + if opts.sigma_min != m_sigma_min and opts.sigma_min != 0: + sigmas_kwargs['sigma_min'] = opts.sigma_min + p.extra_generation_params["Schedule min sigma"] = opts.sigma_min + if opts.sigma_max != m_sigma_max and opts.sigma_max != 0: + sigmas_kwargs['sigma_max'] = opts.sigma_max + p.extra_generation_params["Schedule max sigma"] = opts.sigma_max + + default_rho = 1. if opts.k_sched_type == "polyexponential" else 7. + + if opts.k_sched_type != 'exponential' and opts.rho != 0 and opts.rho != default_rho: + sigmas_kwargs['rho'] = opts.rho + p.extra_generation_params["Schedule rho"] = opts.rho + + sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device) elif self.config is not None and self.config.options.get('scheduler', None) == 'karras': sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item()) diff --git a/modules/sd_unet.py b/modules/sd_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..6d708ad296bbb5a46a4925db350fb6f319572ddb --- /dev/null +++ b/modules/sd_unet.py @@ -0,0 +1,92 @@ +import torch.nn +import ldm.modules.diffusionmodules.openaimodel + +from modules import script_callbacks, shared, devices + +unet_options = [] +current_unet_option = None +current_unet = None + + +def list_unets(): + new_unets = script_callbacks.list_unets_callback() + + unet_options.clear() + unet_options.extend(new_unets) + + +def get_unet_option(option=None): + option = option or shared.opts.sd_unet + + if option == "None": + return None + + if option == "Automatic": + name = shared.sd_model.sd_checkpoint_info.model_name + + options = [x for x in unet_options if x.model_name == name] + + option = options[0].label if options else "None" + + return next(iter([x for x in unet_options if x.label == option]), None) + + +def apply_unet(option=None): + global current_unet_option + global current_unet + + new_option = get_unet_option(option) + if new_option == current_unet_option: + return + + if current_unet is not None: + print(f"Dectivating unet: {current_unet.option.label}") + current_unet.deactivate() + + current_unet_option = new_option + if current_unet_option is None: + current_unet = None + + if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram): + shared.sd_model.model.diffusion_model.to(devices.device) + + return + + shared.sd_model.model.diffusion_model.to(devices.cpu) + devices.torch_gc() + + current_unet = current_unet_option.create_unet() + current_unet.option = current_unet_option + print(f"Activating unet: {current_unet.option.label}") + current_unet.activate() + + +class SdUnetOption: + model_name = None + """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this""" + + label = None + """name of the unet in UI""" + + def create_unet(self): + """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures""" + raise NotImplementedError() + + +class SdUnet(torch.nn.Module): + def forward(self, x, timesteps, context, *args, **kwargs): + raise NotImplementedError() + + def activate(self): + pass + + def deactivate(self): + pass + + +def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): + if current_unet is not None: + return current_unet.forward(x, timesteps, context, *args, **kwargs) + + return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs) + diff --git a/modules/shared.py b/modules/shared.py index 271a062d9c73b8ca3f9e012341d9da05dcb94a65..91c31d5564ea7e4a4ee211ff725d0389dafaf1da 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -44,19 +44,6 @@ restricted_opts = { "outdir_init_images" } -ui_reorder_categories = [ - "inpaint", - "sampler", - "checkboxes", - "hires_fix", - "dimensions", - "cfg", - "seed", - "batch", - "override_settings", - "scripts", -] - # https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json gradio_hf_hub_themes = [ "gradio/glass", @@ -273,6 +260,10 @@ class OptionInfo: self.comment_after += f"({info})" return self + def html(self, html): + self.comment_after += html + return self + def needs_restart(self): self.comment_after += " (requires restart)" return self @@ -318,6 +309,7 @@ options_templates.update(options_section(('saving-images', "Saving images/grids" "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"), "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"), "grid_prevent_empty_spots": OptionInfo(False, "Prevent empty spots in grid (when set to autodetect)"), + "grid_zip_filename_pattern": OptionInfo("", "Archive filename pattern", component_args=hide_dirs).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Custom-Images-Filename-Name-and-Subdirectory"), "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}), "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"), @@ -407,6 +399,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), + "sd_unet": OptionInfo("Automatic", "SD Unet", gr.Dropdown, lambda: {"choices": shared_items.sd_unet_items()}, refresh=shared_items.refresh_unet_list).info("choose Unet model: Automatic = use one with same filename as checkpoint; None = use Unet from checkpoint"), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01}), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), @@ -418,15 +411,17 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Prompt word wrap length limit", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1}).info("in tokens - for texts shorter than specified, if they don't fit into 75 token limit, move them to the next 75 token chunk"), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP nrtwork; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), - "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different vidocard vendors"), + "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}), - "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), + "s_min_uncond": OptionInfo(0.0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"), "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"), "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), "token_merging_ratio_hr": OptionInfo(0.0, "Token merging ratio for high-res pass", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"), + "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length").info("improves performance when prompt and negative prompt have different lengths; changes seeds"), + "experimental_persistent_cond_cache": OptionInfo(False, "persistent cond cache").info("Experimental, keep cond caches across jobs, reduce overhead."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { @@ -435,6 +430,7 @@ options_templates.update(options_section(('compatibility', "Compatibility"), { "no_dpmpp_sde_batch_determinism": OptionInfo(False, "Do not make DPM++ SDE deterministic across different batch sizes."), "use_old_hires_fix_width_height": OptionInfo(False, "For hires fix, use width/height sliders to set final resolution rather than first pass (disables Upscale by, Resize width/height to)."), "dont_fix_second_order_samplers_schedule": OptionInfo(False, "Do not fix prompt schedule for second order samplers."), + "hires_fix_use_firstpass_conds": OptionInfo(False, "For hires fix, calculate conds of second pass using extra networks of first pass."), })) options_templates.update(options_section(('interrogate', "Interrogate Options"), { @@ -488,16 +484,24 @@ options_templates.update(options_section(('ui', "User interface"), { "quicksettings_list": OptionInfo(["sd_model_checkpoint"], "Quicksettings list", ui_components.DropdownMulti, lambda: {"choices": list(opts.data_labels.keys())}).js("info", "settingsHintsShowQuicksettings").info("setting entries that appear at the top of page rather than in settings tab").needs_restart(), "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), - "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order").needs_restart(), + "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(), "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(), + "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(), })) options_templates.update(options_section(('infotext', "Infotext"), { "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), "add_model_name_to_info": OptionInfo(True, "Add model name to generation information"), "add_version_to_infotext": OptionInfo(True, "Add program version to generation information"), - "disable_weights_auto_swap": OptionInfo(True, "When reading generation parameters from text into UI (from PNG info or pasted text), do not change the selected model/checkpoint."), + "disable_weights_auto_swap": OptionInfo(True, "Disregard checkpoint information from pasted infotext").info("when reading generation parameters from text into UI"), + "infotext_styles": OptionInfo("Apply if any", "Infer styles from prompts of pasted infotext", gr.Radio, {"choices": ["Ignore", "Apply", "Discard", "Apply if any"]}).info("when reading generation parameters from text into UI)").html("""
    +
  • Ignore: keep prompt and styles dropdown as it is.
  • +
  • Apply: remove style text from prompt, always replace styles dropdown value with found styles (even if none are found).
  • +
  • Discard: remove style text from prompt, keep styles dropdown as it is.
  • +
  • Apply if any: remove style text from prompt; if any styles are found in prompt, put them into styles dropdown, otherwise keep it as it is.
  • +
"""), + })) options_templates.update(options_section(('ui', "Live previews"), { @@ -519,6 +523,10 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_tmin': OptionInfo(0.0, "sigma tmin", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + 'k_sched_type': OptionInfo("Automatic", "scheduler type", gr.Dropdown, {"choices": ["Automatic", "karras", "exponential", "polyexponential"]}).info("lets you override the noise schedule for k-diffusion samplers; choosing Automatic disables the three parameters below"), + 'sigma_min': OptionInfo(0.0, "sigma min", gr.Number).info("0 = default (~0.03); minimum noise strength for k-diffusion noise scheduler"), + 'sigma_max': OptionInfo(0.0, "sigma max", gr.Number).info("0 = default (~14.6); maximum noise strength for k-diffusion noise schedule"), + 'rho': OptionInfo(0.0, "rho", gr.Number).info("0 = default (7 for karras, 1 for polyexponential); higher values result in a more steep noise schedule (decreases faster)"), 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}).info("ENSD; does not improve anything, just produces different results for ancestral samplers - only useful for reproducing images"), 'always_discard_next_to_last_sigma': OptionInfo(False, "Always discard next-to-last sigma").link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/6044"), 'uni_pc_variant': OptionInfo("bh1", "UniPC variant", gr.Radio, {"choices": ["bh1", "bh2", "vary_coeff"]}), @@ -634,6 +642,10 @@ class Options: if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None: self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')] + # 1.4.0 ui_reorder + if isinstance(self.data.get('ui_reorder'), str) and self.data.get('ui_reorder') and "ui_reorder_list" not in self.data: + self.data['ui_reorder_list'] = [i.strip() for i in self.data.get('ui_reorder').split(',')] + bad_settings = 0 for k, v in self.data.items(): info = self.data_labels.get(k, None) diff --git a/modules/shared_items.py b/modules/shared_items.py index 2a8713c8716f436f99b354fb8a22dc437a4573fb..89792e88aec9a0f66ce44e9983cbae4e41ffe897 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -29,3 +29,41 @@ def cross_attention_optimizations(): return ["Automatic"] + [x.title() for x in modules.sd_hijack.optimizers] + ["None"] +def sd_unet_items(): + import modules.sd_unet + + return ["Automatic"] + [x.label for x in modules.sd_unet.unet_options] + ["None"] + + +def refresh_unet_list(): + import modules.sd_unet + + modules.sd_unet.list_unets() + + +ui_reorder_categories_builtin_items = [ + "inpaint", + "sampler", + "checkboxes", + "hires_fix", + "dimensions", + "cfg", + "seed", + "batch", + "override_settings", +] + + +def ui_reorder_categories(): + from modules import scripts + + yield from ui_reorder_categories_builtin_items + + sections = {} + for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts: + if isinstance(script.section, str): + sections[script.section] = 1 + + yield from sections + + yield "scripts" diff --git a/modules/styles.py b/modules/styles.py index 34e1b5e15e3708ff4eb9f46cc45ef94d46ccee62..ec0e1bc51dbfb6415da3c953c96fd524a554ac0d 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,6 +1,7 @@ import csv import os import os.path +import re import typing import shutil @@ -28,6 +29,44 @@ def apply_styles_to_prompt(prompt, styles): return prompt +re_spaces = re.compile(" +") + + +def extract_style_text_from_prompt(style_text, prompt): + stripped_prompt = re.sub(re_spaces, " ", prompt.strip()) + stripped_style_text = re.sub(re_spaces, " ", style_text.strip()) + if "{prompt}" in stripped_style_text: + left, right = stripped_style_text.split("{prompt}", 2) + if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): + prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)] + return True, prompt + else: + if stripped_prompt.endswith(stripped_style_text): + prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)] + + if prompt.endswith(', '): + prompt = prompt[:-2] + + return True, prompt + + return False, prompt + + +def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): + if not style.prompt and not style.negative_prompt: + return False, prompt, negative_prompt + + match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) + if not match_positive: + return False, prompt, negative_prompt + + match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) + if not match_negative: + return False, prompt, negative_prompt + + return True, extracted_positive, extracted_negative + + class StyleDatabase: def __init__(self, path: str): self.no_style = PromptStyle("None", "", "") @@ -67,10 +106,34 @@ class StyleDatabase: if os.path.exists(path): shutil.copy(path, f"{path}.bak") - fd = os.open(path, os.O_RDWR|os.O_CREAT) + fd = os.open(path, os.O_RDWR | os.O_CREAT) with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer.writeheader() - writer.writerows(style._asdict() for k, style in self.styles.items()) + writer.writerows(style._asdict() for k, style in self.styles.items()) + + def extract_styles_from_prompt(self, prompt, negative_prompt): + extracted = [] + + applicable_styles = list(self.styles.values()) + + while True: + found_style = None + + for style in applicable_styles: + is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt) + if is_match: + found_style = style + prompt = new_prompt + negative_prompt = new_neg_prompt + break + + if not found_style: + break + + applicable_styles.remove(found_style) + extracted.append(found_style.name) + + return list(reversed(extracted)), prompt, negative_prompt diff --git a/modules/sysinfo.py b/modules/sysinfo.py new file mode 100644 index 0000000000000000000000000000000000000000..5f15ac4fa94ca0eb1a41bac92623963bd6140189 --- /dev/null +++ b/modules/sysinfo.py @@ -0,0 +1,162 @@ +import json +import os +import sys +import traceback + +import platform +import hashlib +import pkg_resources +import psutil +import re + +import launch +from modules import paths_internal, timer + +checksum_token = "DontStealMyGamePlz__WINNERS_DONT_USE_DRUGS__DONT_COPY_THAT_FLOPPY" +environment_whitelist = { + "GIT", + "INDEX_URL", + "WEBUI_LAUNCH_LIVE_OUTPUT", + "GRADIO_ANALYTICS_ENABLED", + "PYTHONPATH", + "TORCH_INDEX_URL", + "TORCH_COMMAND", + "REQS_FILE", + "XFORMERS_PACKAGE", + "GFPGAN_PACKAGE", + "CLIP_PACKAGE", + "OPENCLIP_PACKAGE", + "STABLE_DIFFUSION_REPO", + "K_DIFFUSION_REPO", + "CODEFORMER_REPO", + "BLIP_REPO", + "STABLE_DIFFUSION_COMMIT_HASH", + "K_DIFFUSION_COMMIT_HASH", + "CODEFORMER_COMMIT_HASH", + "BLIP_COMMIT_HASH", + "COMMANDLINE_ARGS", + "IGNORE_CMD_ARGS_ERRORS", +} + + +def pretty_bytes(num, suffix="B"): + for unit in ["", "K", "M", "G", "T", "P", "E", "Z", "Y"]: + if abs(num) < 1024 or unit == 'Y': + return f"{num:.0f}{unit}{suffix}" + num /= 1024 + + +def get(): + res = get_dict() + + text = json.dumps(res, ensure_ascii=False, indent=4) + + h = hashlib.sha256(text.encode("utf8")) + text = text.replace(checksum_token, h.hexdigest()) + + return text + + +re_checksum = re.compile(r'"Checksum": "([0-9a-fA-F]{64})"') + + +def check(x): + m = re.search(re_checksum, x) + if not m: + return False + + replaced = re.sub(re_checksum, f'"Checksum": "{checksum_token}"', x) + + h = hashlib.sha256(replaced.encode("utf8")) + return h.hexdigest() == m.group(1) + + +def get_dict(): + ram = psutil.virtual_memory() + + res = { + "Platform": platform.platform(), + "Python": platform.python_version(), + "Version": launch.git_tag(), + "Commit": launch.commit_hash(), + "Script path": paths_internal.script_path, + "Data path": paths_internal.data_path, + "Extensions dir": paths_internal.extensions_dir, + "Checksum": checksum_token, + "Commandline": sys.argv, + "Torch env info": get_torch_sysinfo(), + "Exceptions": get_exceptions(), + "CPU": { + "model": platform.processor(), + "count logical": psutil.cpu_count(logical=True), + "count physical": psutil.cpu_count(logical=False), + }, + "RAM": { + x: pretty_bytes(getattr(ram, x, 0)) for x in ["total", "used", "free", "active", "inactive", "buffers", "cached", "shared"] if getattr(ram, x, 0) != 0 + }, + "Extensions": get_extensions(enabled=True), + "Inactive extensions": get_extensions(enabled=False), + "Environment": get_environment(), + "Config": get_config(), + "Startup": timer.startup_record, + "Packages": sorted([f"{pkg.key}=={pkg.version}" for pkg in pkg_resources.working_set]), + } + + return res + + +def format_traceback(tb): + return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)] + + +def get_exceptions(): + try: + from modules import errors + + return [{"exception": str(e), "traceback": format_traceback(tb)} for e, tb in reversed(errors.exception_records)] + except Exception as e: + return str(e) + + +def get_environment(): + return {k: os.environ[k] for k in sorted(os.environ) if k in environment_whitelist} + + +re_newline = re.compile(r"\r*\n") + + +def get_torch_sysinfo(): + try: + import torch.utils.collect_env + info = torch.utils.collect_env.get_env_info()._asdict() + + return {k: re.split(re_newline, str(v)) if "\n" in str(v) else v for k, v in info.items()} + except Exception as e: + return str(e) + + +def get_extensions(*, enabled): + + try: + from modules import extensions + + def to_json(x: extensions.Extension): + return { + "name": x.name, + "path": x.path, + "version": x.version, + "branch": x.branch, + "remote": x.remote, + } + + return [to_json(x) for x in extensions.extensions if not x.is_builtin and x.enabled == enabled] + except Exception as e: + return str(e) + + +def get_config(): + try: + from modules import shared + return shared.opts.data + except Exception as e: + return str(e) diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 8e667a4d3ee4a6803d63b70348bc9d42d07e3057..75705459b564531114e895382409d1db394b115d 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -77,27 +77,27 @@ def focal_point(im, settings): pois = [] weight_pref_total = 0 - if len(corner_points) > 0: + if corner_points: weight_pref_total += settings.corner_points_weight - if len(entropy_points) > 0: + if entropy_points: weight_pref_total += settings.entropy_points_weight - if len(face_points) > 0: + if face_points: weight_pref_total += settings.face_points_weight corner_centroid = None - if len(corner_points) > 0: + if corner_points: corner_centroid = centroid(corner_points) corner_centroid.weight = settings.corner_points_weight / weight_pref_total pois.append(corner_centroid) entropy_centroid = None - if len(entropy_points) > 0: + if entropy_points: entropy_centroid = centroid(entropy_points) entropy_centroid.weight = settings.entropy_points_weight / weight_pref_total pois.append(entropy_centroid) face_centroid = None - if len(face_points) > 0: + if face_points: face_centroid = centroid(face_points) face_centroid.weight = settings.face_points_weight / weight_pref_total pois.append(face_centroid) @@ -187,7 +187,7 @@ def image_face_points(im, settings): except Exception: continue - if len(faces) > 0: + if faces: rects = [[f[0], f[1], f[0] + f[2], f[1] + f[3]] for f in faces] return [PointOfInterest((r[0] +r[2]) // 2, (r[1] + r[3]) // 2, size=abs(r[0]-r[2]), weight=1/len(rects)) for r in rects] return [] diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b9621fc952972a89ba5e3d21082f94dfde625c5e..7ee050615459a4c597c37f104644e2b76d7a6987 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -32,7 +32,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False, use_weight=False): - re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None + re_word = re.compile(shared.opts.dataset_filename_word_regex) if shared.opts.dataset_filename_word_regex else None self.placeholder_token = placeholder_token diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index a009d8e80067c0b8a994b8b2c351f266bfbd2254..0d4c3f84b0b2029a462b49029153d75038ffdee5 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -47,7 +47,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption += shared.interrogator.generate_caption(image) if params.process_caption_deepbooru: - if len(caption) > 0: + if caption: caption += ", " caption += deepbooru.model.tag_multi(image) @@ -67,7 +67,7 @@ def save_pic_with_caption(image, index, params: PreprocessParams, existing_capti caption = caption.strip() - if len(caption) > 0: + if caption: with open(os.path.join(params.dstdir, f"{basename}.txt"), "w", encoding="utf8") as file: file.write(caption) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d489ed1e0c93cb0a55d753ab6db029e465bb295d..bb6f211caa24cc55cf2b79ba762de446af9d75d8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -1,6 +1,4 @@ import os -import sys -import traceback from collections import namedtuple import torch @@ -14,7 +12,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -120,16 +118,29 @@ class EmbeddingDatabase: self.embedding_dirs.clear() def register_embedding(self, embedding, model): - self.word_embeddings[embedding.name] = embedding - - ids = model.cond_stage_model.tokenize([embedding.name])[0] + return self.register_embedding_by_name(embedding, model, embedding.name) + def register_embedding_by_name(self, embedding, model, name): + ids = model.cond_stage_model.tokenize([name])[0] first_id = ids[0] if first_id not in self.ids_lookup: self.ids_lookup[first_id] = [] - - self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True) - + if name in self.word_embeddings: + # remove old one from the lookup list + lookup = [x for x in self.ids_lookup[first_id] if x[1].name!=name] + else: + lookup = self.ids_lookup[first_id] + if embedding is not None: + lookup += [(ids, embedding)] + self.ids_lookup[first_id] = sorted(lookup, key=lambda x: len(x[0]), reverse=True) + if embedding is None: + # unregister embedding with specified name + if name in self.word_embeddings: + del self.word_embeddings[name] + if len(self.ids_lookup[first_id])==0: + del self.ids_lookup[first_id] + return None + self.word_embeddings[name] = embedding return embedding def get_expected_shape(self): @@ -207,8 +218,7 @@ class EmbeddingDatabase: self.load_from_file(fullfn, fn) except Exception: - print(f"Error loading embedding {fn}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error loading embedding {fn}", exc_info=True) continue def load_textual_inversion_embeddings(self, force_reload=False): @@ -241,7 +251,7 @@ class EmbeddingDatabase: if self.previously_displayed_embeddings != displayed_embeddings: self.previously_displayed_embeddings = displayed_embeddings print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}") - if len(self.skipped_embeddings) > 0: + if self.skipped_embeddings: print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}") def find_embedding_at_position(self, tokens, offset): @@ -632,8 +642,7 @@ Last saved image: {html.escape(last_saved_image)}
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt') save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True) except Exception: - print(traceback.format_exc(), file=sys.stderr) - pass + errors.report("Error training embedding", exc_info=True) finally: pbar.leave = False pbar.close() diff --git a/modules/timer.py b/modules/timer.py index ba92be336cdb21eef875b8b82cb30b788ceac32a..da99e49f82df4260676cf8205cf6bcc93b9d3f01 100644 --- a/modules/timer.py +++ b/modules/timer.py @@ -1,11 +1,30 @@ import time +class TimerSubcategory: + def __init__(self, timer, category): + self.timer = timer + self.category = category + self.start = None + self.original_base_category = timer.base_category + + def __enter__(self): + self.start = time.time() + self.timer.base_category = self.original_base_category + self.category + "/" + + def __exit__(self, exc_type, exc_val, exc_tb): + elapsed_for_subcategroy = time.time() - self.start + self.timer.base_category = self.original_base_category + self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy) + self.timer.record(self.category) + + class Timer: def __init__(self): self.start = time.time() self.records = {} self.total = 0 + self.base_category = '' def elapsed(self): end = time.time() @@ -13,18 +32,29 @@ class Timer: self.start = end return res - def record(self, category, extra_time=0): - e = self.elapsed() + def add_time_to_record(self, category, amount): if category not in self.records: self.records[category] = 0 - self.records[category] += e + extra_time + self.records[category] += amount + + def record(self, category, extra_time=0): + e = self.elapsed() + + self.add_time_to_record(self.base_category + category, e + extra_time) + self.total += e + extra_time + def subcategory(self, name): + self.elapsed() + + subcat = TimerSubcategory(self, name) + return subcat + def summary(self): res = f"{self.total:.1f}s" - additions = [x for x in self.records.items() if x[1] >= 0.1] + additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category] if not additions: return res @@ -34,5 +64,13 @@ class Timer: return res + def dump(self): + return {'total': self.total, 'records': self.records} + def reset(self): self.__init__() + + +startup_timer = Timer() + +startup_record = None diff --git a/modules/ui.py b/modules/ui.py index 361f596eb67a7f6c0216c3bd78a28301932d3be0..e2e3b6da3122610387b2dfae14e80ca8f098ea44 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1,21 +1,23 @@ +import datetime import json import mimetypes import os import sys -import traceback from functools import reduce import warnings import gradio as gr -import gradio.routes import gradio.utils import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML -from modules.paths import script_path, data_path +from modules.paths import script_path +from modules.ui_common import create_refresh_button +from modules.ui_gradio_extensions import reload_javascript + from modules.shared import opts, cmd_opts @@ -35,6 +37,8 @@ import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text import modules.extras +create_setting_component = ui_settings.create_setting_component + warnings.filterwarnings("default" if opts.show_warnings else "ignore", category=UserWarning) # this is a fix for Windows users. Without it, javascript files will be served with text/html content-type and the browser will not show any UI @@ -76,6 +80,7 @@ extra_networks_symbol = '\U0001F3B4' # 🎴 switch_values_symbol = '\U000021C5' # ⇅ restore_progress_symbol = '\U0001F300' # 🌀 detect_image_size_symbol = '\U0001F4D0' # 📐 +up_down_symbol = '\u2195\ufe0f' # ↕️ def plaintext_to_html(text): @@ -231,9 +236,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: res = all_seeds[index if 0 <= index < len(all_seeds) else 0] except json.decoder.JSONDecodeError: - if gen_info_string != '': - print("Error parsing JSON generation info:", file=sys.stderr) - print(gen_info_string, file=sys.stderr) + if gen_info_string: + errors.report(f"Error parsing JSON generation info: {gen_info_string}") return [res, gr_show(False)] @@ -368,25 +372,6 @@ def apply_setting(key, value): return getattr(opts, key) -def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args - - for k, v in args.items(): - setattr(refresh_component, k, v) - - return gr.update(**(args or {})) - - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn=refresh, - inputs=[], - outputs=[refresh_component] - ) - return refresh_button - - def create_output_panel(tabname, outdir): return ui_common.create_output_panel(tabname, outdir) @@ -405,27 +390,17 @@ def create_sampler_and_steps_selection(choices, tabname): def ordered_ui_categories(): - user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder.split(","))} + user_order = {x.strip(): i * 2 + 1 for i, x in enumerate(shared.opts.ui_reorder_list)} - for _, category in sorted(enumerate(shared.ui_reorder_categories), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)): + for _, category in sorted(enumerate(shared_items.ui_reorder_categories()), key=lambda x: user_order.get(x[1], x[0] * 2 + 0)): yield category -def get_value_for_setting(key): - value = getattr(opts, key) - - info = opts.data_labels[key] - args = info.component_args() if callable(info.component_args) else info.component_args or {} - args = {k: v for k, v in args.items() if k not in {'precision'}} - - return gr.update(value=value, **args) - - def create_override_settings_dropdown(tabname, row): dropdown = gr.Dropdown([], label="Override settings", visible=False, elem_id=f"{tabname}_override_settings", multiselect=True) dropdown.change( - fn=lambda x: gr.Dropdown.update(visible=len(x) > 0), + fn=lambda x: gr.Dropdown.update(visible=bool(x)), inputs=[dropdown], outputs=[dropdown], ) @@ -456,6 +431,8 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): + modules.scripts.scripts_txt2img.prepare_ui() + for category in ordered_ui_categories(): if category == "sampler": steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") @@ -524,6 +501,9 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + else: + modules.scripts.scripts_txt2img.setup_ui_for_section(category) + hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] for component in hr_resolution_preview_inputs: @@ -616,7 +596,8 @@ def create_ui(): outputs=[ txt2img_prompt, txt_prompt_img - ] + ], + show_progress=False, ) enable_hr.change( @@ -641,6 +622,7 @@ def create_ui(): (subseed_strength, "Variation seed strength"), (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), + (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), @@ -779,6 +761,8 @@ def create_ui(): with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") + modules.scripts.scripts_img2img.prepare_ui() + for category in ordered_ui_categories(): if category == "sampler": steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") @@ -888,6 +872,8 @@ def create_ui(): inputs=[], outputs=[inpaint_controls, mask_alpha], ) + else: + modules.scripts.scripts_img2img.setup_ui_for_section(category) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) @@ -902,7 +888,8 @@ def create_ui(): outputs=[ img2img_prompt, img2img_prompt_img - ] + ], + show_progress=False, ) img2img_args = dict( @@ -1051,6 +1038,7 @@ def create_ui(): (subseed_strength, "Variation seed strength"), (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), + (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), *modules.scripts.scripts_img2img.infotext_fields @@ -1460,195 +1448,10 @@ def create_ui(): outputs=[], ) - def create_setting_component(key, is_quicksettings=False): - def fun(): - return opts.data[key] if key in opts.data else opts.data_labels[key].default - - info = opts.data_labels[key] - t = type(info.default) - - args = info.component_args() if callable(info.component_args) else info.component_args - - if info.component is not None: - comp = info.component - elif t == str: - comp = gr.Textbox - elif t == int: - comp = gr.Number - elif t == bool: - comp = gr.Checkbox - else: - raise Exception(f'bad options item type: {t} for key {key}') - - elem_id = f"setting_{key}" - - if info.refresh is not None: - if is_quicksettings: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") - else: - with FormRow(): - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") - else: - res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) - - return res - loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file) - components = [] - component_dict = {} - shared.settings_components = component_dict - - script_callbacks.ui_settings_callback() - opts.reorder() - - def run_settings(*args): - changed = [] - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" - - for key, value, comp in zip(opts.data_labels.keys(), args, components): - if comp == dummy_component: - continue - - if opts.set(key, value): - changed.append(key) - - try: - opts.save(shared.config_filename) - except RuntimeError: - return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' - return opts.dumpjson(), f'{len(changed)} settings changed{": " if len(changed) > 0 else ""}{", ".join(changed)}.' - - def run_settings_single(value, key): - if not opts.same_type(value, opts.data_labels[key].default): - return gr.update(visible=True), opts.dumpjson() - - if not opts.set(key, value): - return gr.update(value=getattr(opts, key)), opts.dumpjson() - - opts.save(shared.config_filename) - - return get_value_for_setting(key), opts.dumpjson() - - with gr.Blocks(analytics_enabled=False) as settings_interface: - with gr.Row(): - with gr.Column(scale=6): - settings_submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") - with gr.Column(): - restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") - - result = gr.HTML(elem_id="settings_result") - - quicksettings_names = opts.quicksettings_list - quicksettings_names = {x: i for i, x in enumerate(quicksettings_names) if x != 'quicksettings'} - - quicksettings_list = [] - - previous_section = None - current_tab = None - current_row = None - with gr.Tabs(elem_id="settings"): - for i, (k, item) in enumerate(opts.data_labels.items()): - section_must_be_skipped = item.section[0] is None - - if previous_section != item.section and not section_must_be_skipped: - elem_id, text = item.section - - if current_tab is not None: - current_row.__exit__() - current_tab.__exit__() - - gr.Group() - current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) - current_tab.__enter__() - current_row = gr.Column(variant='compact') - current_row.__enter__() - - previous_section = item.section - - if k in quicksettings_names and not shared.cmd_opts.freeze_settings: - quicksettings_list.append((i, k, item)) - components.append(dummy_component) - elif section_must_be_skipped: - components.append(dummy_component) - else: - component = create_setting_component(k) - component_dict[k] = component - components.append(component) - - if current_tab is not None: - current_row.__exit__() - current_tab.__exit__() - - with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): - loadsave.create_ui() - - with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"): - request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") - download_localization = gr.Button(value='Download localization template', elem_id="download_localization") - reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - with gr.Row(): - unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") - reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") - - with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): - gr.HTML(shared.html("licenses.html"), elem_id="licenses") - - gr.Button(value="Show all pages", elem_id="settings_show_all_pages") - - - def unload_sd_weights(): - modules.sd_models.unload_model_weights() - - def reload_sd_weights(): - modules.sd_models.reload_model_weights() - - unload_sd_model.click( - fn=unload_sd_weights, - inputs=[], - outputs=[] - ) - - reload_sd_model.click( - fn=reload_sd_weights, - inputs=[], - outputs=[] - ) - - request_notifications.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='function(){}' - ) - - download_localization.click( - fn=lambda: None, - inputs=[], - outputs=[], - _js='download_localization' - ) - - def reload_scripts(): - modules.scripts.reload_script_body_only() - reload_javascript() # need to refresh the html page - - reload_script_bodies.click( - fn=reload_scripts, - inputs=[], - outputs=[] - ) - - restart_gradio.click( - fn=shared.state.request_restart, - _js='restart_reload', - inputs=[], - outputs=[], - ) + settings = ui_settings.UiSettings() + settings.create_ui(loadsave, dummy_component) interfaces = [ (txt2img_interface, "txt2img", "txt2img"), @@ -1660,7 +1463,7 @@ def create_ui(): ] interfaces += script_callbacks.ui_tabs_callback() - interfaces += [(settings_interface, "Settings", "settings")] + interfaces += [(settings.interface, "Settings", "settings")] extensions_interface = ui_extensions.create_ui() interfaces += [(extensions_interface, "Extensions", "extensions")] @@ -1670,10 +1473,7 @@ def create_ui(): shared.tab_names.append(label) with gr.Blocks(theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo: - with gr.Row(elem_id="quicksettings", variant="compact"): - for _i, k, _item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): - component = create_setting_component(k, is_quicksettings=True) - component_dict[k] = component + settings.add_quicksettings() parameters_copypaste.connect_paste_params_buttons() @@ -1701,58 +1501,20 @@ def create_ui(): gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) footer = shared.html("footer.html") - footer = footer.format(versions=versions_html()) + footer = footer.format(versions=versions_html(), api_docs="/docs" if shared.cmd_opts.api else "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/API") gr.HTML(footer, elem_id="footer") - text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) - settings_submit.click( - fn=wrap_gradio_call(run_settings, extra_outputs=[gr.update()]), - inputs=components, - outputs=[text_settings, result], - ) - - for _i, k, _item in quicksettings_list: - component = component_dict[k] - info = opts.data_labels[k] - - change_handler = component.release if hasattr(component, 'release') else component.change - change_handler( - fn=lambda value, k=k: run_settings_single(value, key=k), - inputs=[component], - outputs=[component, text_settings], - show_progress=info.refresh is not None, - ) + settings.add_functionality(demo) update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit") - text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) + settings.text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale]) - button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) - button_set_checkpoint.click( - fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'), - _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", - inputs=[component_dict['sd_model_checkpoint'], dummy_component], - outputs=[component_dict['sd_model_checkpoint'], text_settings], - ) - - component_keys = [k for k in opts.data_labels.keys() if k in component_dict] - - def get_settings_values(): - return [get_value_for_setting(key) for key in component_keys] - - demo.load( - fn=get_settings_values, - inputs=[], - outputs=[component_dict[k] for k in component_keys], - queue=False, - ) - def modelmerger(*args): try: results = modules.extras.run_modelmerger(*args) except Exception as e: - print("Error loading/saving model file:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report("Error loading/saving model file", exc_info=True) modules.sd_models.list_models() # to remove the potentially missing models from the list return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"] return results @@ -1780,7 +1542,7 @@ def create_ui(): primary_model_name, secondary_model_name, tertiary_model_name, - component_dict['sd_model_checkpoint'], + settings.component_dict['sd_model_checkpoint'], modelmerger_result, ] ) @@ -1794,70 +1556,6 @@ def create_ui(): return demo -def webpath(fn): - if fn.startswith(script_path): - web_path = os.path.relpath(fn, script_path).replace('\\', '/') - else: - web_path = os.path.abspath(fn) - - return f'file={web_path}?{os.path.getmtime(fn)}' - - -def javascript_html(): - # Ensure localization is in `window` before scripts - head = f'\n' - - script_js = os.path.join(script_path, "script.js") - head += f'\n' - - for script in modules.scripts.list_scripts("javascript", ".js"): - head += f'\n' - - for script in modules.scripts.list_scripts("javascript", ".mjs"): - head += f'\n' - - if cmd_opts.theme: - head += f'\n' - - return head - - -def css_html(): - head = "" - - def stylesheet(fn): - return f'' - - for cssfile in modules.scripts.list_files_with_name("style.css"): - if not os.path.isfile(cssfile): - continue - - head += stylesheet(cssfile) - - if os.path.exists(os.path.join(data_path, "user.css")): - head += stylesheet(os.path.join(data_path, "user.css")) - - return head - - -def reload_javascript(): - js = javascript_html() - css = css_html() - - def template_response(*args, **kwargs): - res = shared.GradioTemplateResponseOriginal(*args, **kwargs) - res.body = res.body.replace(b'', f'{js}'.encode("utf8")) - res.body = res.body.replace(b'', f'{css}'.encode("utf8")) - res.init_headers() - return res - - gradio.routes.templates.TemplateResponse = template_response - - -if not hasattr(shared, 'GradioTemplateResponseOriginal'): - shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse - - def versions_html(): import torch import launch @@ -1901,3 +1599,17 @@ def setup_ui_api(app): app.add_api_route("/internal/quicksettings-hint", quicksettings_hint, methods=["GET"], response_model=List[QuicksettingsHint]) app.add_api_route("/internal/ping", lambda: {}, methods=["GET"]) + + app.add_api_route("/internal/profile-startup", lambda: timer.startup_record, methods=["GET"]) + + def download_sysinfo(attachment=False): + from fastapi.responses import PlainTextResponse + + text = sysinfo.get() + filename = f"sysinfo-{datetime.datetime.utcnow().strftime('%Y-%m-%d-%H-%M')}.txt" + + return PlainTextResponse(text, headers={'Content-Disposition': f'{"attachment" if attachment else "inline"}; filename="{filename}"'}) + + app.add_api_route("/internal/sysinfo", download_sysinfo, methods=["GET"]) + app.add_api_route("/internal/sysinfo-download", lambda: download_sysinfo(attachment=True), methods=["GET"]) + diff --git a/modules/ui_common.py b/modules/ui_common.py index 27ab3ebb6a822d5a70a1b03eec7933c34b93e237..57c2d0ade525b6e01e8c6a8c2f428ea4c939d699 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -10,8 +10,11 @@ import subprocess as sp from modules import call_queue, shared from modules.generation_parameters_copypaste import image_from_url_text import modules.images +from modules.ui_components import ToolButton + folder_symbol = '\U0001f4c2' # 📂 +refresh_symbol = '\U0001f504' # 🔄 def update_generation_info(generation_info, html_info, img_index): @@ -50,9 +53,10 @@ def save_files(js_data, images, do_make_zip, index): save_to_dirs = shared.opts.use_save_to_dirs_for_ui extension: str = shared.opts.samples_format start_index = 0 + only_one = False if index > -1 and shared.opts.save_selected_only and (index >= data["index_of_first_image"]): # ensures we are looking at a specific non-grid picture, and we have save_selected_only - + only_one = True images = [images[index]] start_index = index @@ -70,6 +74,7 @@ def save_files(js_data, images, do_make_zip, index): is_grid = image_index < p.index_of_first_image i = 0 if is_grid else (image_index - p.index_of_first_image) + p.batch_index = image_index-1 fullfn, txt_fullfn = modules.images.save_image(image, path, "", seed=p.all_seeds[i], prompt=p.all_prompts[i], extension=extension, info=p.infotexts[image_index], grid=is_grid, p=p, save_to_dirs=save_to_dirs) filename = os.path.relpath(fullfn, path) @@ -83,7 +88,10 @@ def save_files(js_data, images, do_make_zip, index): # Make Zip if do_make_zip: - zip_filepath = os.path.join(path, "images.zip") + zip_fileseed = p.all_seeds[index-1] if only_one else p.all_seeds[0] + namegen = modules.images.FilenameGenerator(p, zip_fileseed, p.all_prompts[0], image, True) + zip_filename = namegen.apply(shared.opts.grid_zip_filename_pattern or "[datetime]_[[model_name]]_[seed]-[seed_last]") + zip_filepath = os.path.join(path, f"{zip_filename}.zip") from zipfile import ZipFile with ZipFile(zip_filepath, "w") as zip_file: @@ -211,3 +219,23 @@ Requested path was: {f} )) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log + + +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args + + for k, v in args.items(): + setattr(refresh_component, k, v) + + return gr.update(**(args or {})) + + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 515ec262244cd9ca05546984befab535a90a970b..4379a6418afe729e7901770f7e9702a7aa7c1e3a 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -1,10 +1,8 @@ import json import os.path -import sys import threading import time from datetime import datetime -import traceback import git @@ -13,7 +11,7 @@ import html import shutil import errno -from modules import extensions, shared, paths, config_states +from modules import extensions, shared, paths, config_states, errors, restart from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call @@ -46,13 +44,16 @@ def apply_and_restart(disable_list, update_list, disable_all): try: ext.fetch_and_reset_hard() except Exception: - print(f"Error getting updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error getting updates for {ext.name}", exc_info=True) shared.opts.disabled_extensions = disabled shared.opts.disable_all_extensions = disable_all shared.opts.save(shared.config_filename) - shared.state.request_restart() + + if restart.is_restartable(): + restart.restart_program() + else: + restart.stop_program() def save_config_state(name): @@ -113,8 +114,7 @@ def check_updates(id_task, disable_list): if 'FETCH_HEAD' not in str(e): raise except Exception: - print(f"Error checking updates for {ext.name}:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error checking updates for {ext.name}", exc_info=True) shared.state.nextjob() @@ -337,7 +337,8 @@ def install_extension_from_url(dirname, url, branch_name=None): assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' normalized_url = normalize_git_url(url) - assert len([x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url]) == 0, 'Extension with this URL is already installed' + if any(x for x in extensions.extensions if normalize_git_url(x.remote) == normalized_url): + raise Exception(f'Extension with this URL is already installed: {url}') tmpdir = os.path.join(paths.data_path, "tmp", dirname) @@ -453,7 +454,7 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" existing = installed_extension_urls.get(normalize_git_url(url), None) extension_tags = extension_tags + ["installed"] if existing else extension_tags - if len([x for x in extension_tags if x in tags_to_hide]) > 0: + if any(x for x in extension_tags if x in tags_to_hide): hidden += 1 continue @@ -490,8 +491,14 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" def preload_extensions_git_metadata(): + t0 = time.time() for extension in extensions.extensions: extension.read_info_from_repo() + print( + f"preload_extensions_git_metadata for " + f"{len(extensions.extensions)} extensions took " + f"{time.time() - t0:.2f}s" + ) def create_ui(): @@ -506,7 +513,8 @@ def create_ui(): with gr.TabItem("Installed", id="installed"): with gr.Row(elem_id="extensions_installed_top"): - apply = gr.Button(value="Apply and restart UI", variant="primary") + apply_label = ("Apply and restart UI" if restart.is_restartable() else "Apply and quit") + apply = gr.Button(value=apply_label, variant="primary") check = gr.Button(value="Check for updates") extensions_disable_all = gr.Radio(label="Disable all extensions", choices=["none", "extra", "all"], value=shared.opts.disable_all_extensions, elem_id="extensions_disable_all") extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False).style(container=False) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 19fbaae5a84a7ae4935296e5452d34369e3a78b2..a7d3bc792f96a76d03ce8a6be7181a8f1aea06f7 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -4,6 +4,7 @@ from pathlib import Path from modules import shared from modules.images import read_info_from_image, save_image_with_geninfo +from modules.ui import up_down_symbol import gradio as gr import json import html @@ -185,6 +186,8 @@ class ExtraNetworksPage: if search_only and shared.opts.extra_networks_hidden_models == "Never": return "" + sort_keys = " ".join([html.escape(f'data-sort-{k}={v}') for k, v in item.get("sort_keys", {}).items()]).strip() + args = { "background_image": background_image, "style": f"'display: none; {height}{width}'", @@ -198,10 +201,23 @@ class ExtraNetworksPage: "search_term": item.get("search_term", ""), "metadata_button": metadata_button, "search_only": " search_only" if search_only else "", + "sort_keys": sort_keys, } return self.card_page.format(**args) + def get_sort_keys(self, path): + """ + List of default keys used for sorting in the UI. + """ + pth = Path(path) + stat = pth.stat() + return { + "date_created": int(stat.st_ctime or 0), + "date_modified": int(stat.st_mtime or 0), + "name": pth.name.lower(), + } + def find_preview(self, path): """ Find a preview PNG for a given path (without extension) and call link_preview on it. @@ -296,6 +312,8 @@ def create_ui(container, button, tabname): page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + json.dumps(tabname) + '); return []}', inputs=[], outputs=[]) gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) + gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) + gr.Button(up_down_symbol, elem_id=tabname+"_extra_sortorder") button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index a17aa9c9c122299be27dca41631917c5766ab224..8b9ab71b2448a4bfb5c457bb757775b02e8662f5 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -14,7 +14,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def list_items(self): checkpoint: sd_models.CheckpointInfo - for name, checkpoint in sd_models.checkpoints_list.items(): + for index, (name, checkpoint) in enumerate(sd_models.checkpoints_list.items()): path, ext = os.path.splitext(checkpoint.filename) yield { "name": checkpoint.name_for_extra, @@ -24,6 +24,8 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, + } def allowed_directories_for_previews(self): diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index 6187e00071845941753f6f3556a69f932cf808a4..7c19b53257a5ddc8070ea56f23fcc5817a10cdf8 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -12,7 +12,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): shared.reload_hypernetworks() def list_items(self): - for name, path in shared.hypernetworks.items(): + for index, (name, path) in enumerate(shared.hypernetworks.items()): path, ext = os.path.splitext(path) yield { @@ -23,6 +23,8 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): "search_term": self.search_terms_from_path(path), "prompt": json.dumps(f""), "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, + } def allowed_directories_for_previews(self): diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index 6944d5593e06b45b82ae5a7b054d83a3482466e3..58a61c55b3df438241d030d06854b14d354bd86b 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -13,7 +13,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) def list_items(self): - for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): + for index, embedding in enumerate(sd_hijack.model_hijack.embedding_db.word_embeddings.values()): path, ext = os.path.splitext(embedding.filename) yield { "name": embedding.name, @@ -23,6 +23,8 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): "search_term": self.search_terms_from_path(embedding.filename), "prompt": json.dumps(embedding.name), "local_preview": f"{path}.preview.{shared.opts.samples_format}", + "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, + } def allowed_directories_for_previews(self): diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..b824b113732d4fdfcd4d4b191e91fd300eb2d1a6 --- /dev/null +++ b/modules/ui_gradio_extensions.py @@ -0,0 +1,69 @@ +import os +import gradio as gr + +from modules import localization, shared, scripts +from modules.paths import script_path, data_path + + +def webpath(fn): + if fn.startswith(script_path): + web_path = os.path.relpath(fn, script_path).replace('\\', '/') + else: + web_path = os.path.abspath(fn) + + return f'file={web_path}?{os.path.getmtime(fn)}' + + +def javascript_html(): + # Ensure localization is in `window` before scripts + head = f'\n' + + script_js = os.path.join(script_path, "script.js") + head += f'\n' + + for script in scripts.list_scripts("javascript", ".js"): + head += f'\n' + + for script in scripts.list_scripts("javascript", ".mjs"): + head += f'\n' + + if shared.cmd_opts.theme: + head += f'\n' + + return head + + +def css_html(): + head = "" + + def stylesheet(fn): + return f'' + + for cssfile in scripts.list_files_with_name("style.css"): + if not os.path.isfile(cssfile): + continue + + head += stylesheet(cssfile) + + if os.path.exists(os.path.join(data_path, "user.css")): + head += stylesheet(os.path.join(data_path, "user.css")) + + return head + + +def reload_javascript(): + js = javascript_html() + css = css_html() + + def template_response(*args, **kwargs): + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) + res.body = res.body.replace(b'', f'{js}'.encode("utf8")) + res.body = res.body.replace(b'', f'{css}'.encode("utf8")) + res.init_headers() + return res + + gr.routes.templates.TemplateResponse = template_response + + +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse diff --git a/modules/ui_settings.py b/modules/ui_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..0c560b30f9fcd3686d0ef9482c9c93b69f0b2b19 --- /dev/null +++ b/modules/ui_settings.py @@ -0,0 +1,289 @@ +import gradio as gr + +from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo +from modules.call_queue import wrap_gradio_call +from modules.shared import opts +from modules.ui_components import FormRow +from modules.ui_gradio_extensions import reload_javascript + + +def get_value_for_setting(key): + value = getattr(opts, key) + + info = opts.data_labels[key] + args = info.component_args() if callable(info.component_args) else info.component_args or {} + args = {k: v for k, v in args.items() if k not in {'precision'}} + + return gr.update(value=value, **args) + + +def create_setting_component(key, is_quicksettings=False): + def fun(): + return opts.data[key] if key in opts.data else opts.data_labels[key].default + + info = opts.data_labels[key] + t = type(info.default) + + args = info.component_args() if callable(info.component_args) else info.component_args + + if info.component is not None: + comp = info.component + elif t == str: + comp = gr.Textbox + elif t == int: + comp = gr.Number + elif t == bool: + comp = gr.Checkbox + else: + raise Exception(f'bad options item type: {t} for key {key}') + + elem_id = f"setting_{key}" + + if info.refresh is not None: + if is_quicksettings: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") + else: + with FormRow(): + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + ui_common.create_refresh_button(res, info.refresh, info.component_args, f"refresh_{key}") + else: + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) + + return res + + +class UiSettings: + submit = None + result = None + interface = None + components = None + component_dict = None + dummy_component = None + quicksettings_list = None + quicksettings_names = None + text_settings = None + + def run_settings(self, *args): + changed = [] + + for key, value, comp in zip(opts.data_labels.keys(), args, self.components): + assert comp == self.dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" + + for key, value, comp in zip(opts.data_labels.keys(), args, self.components): + if comp == self.dummy_component: + continue + + if opts.set(key, value): + changed.append(key) + + try: + opts.save(shared.config_filename) + except RuntimeError: + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed{": " if changed else ""}{", ".join(changed)}.' + + def run_settings_single(self, value, key): + if not opts.same_type(value, opts.data_labels[key].default): + return gr.update(visible=True), opts.dumpjson() + + if not opts.set(key, value): + return gr.update(value=getattr(opts, key)), opts.dumpjson() + + opts.save(shared.config_filename) + + return get_value_for_setting(key), opts.dumpjson() + + def create_ui(self, loadsave, dummy_component): + self.components = [] + self.component_dict = {} + self.dummy_component = dummy_component + + shared.settings_components = self.component_dict + + script_callbacks.ui_settings_callback() + opts.reorder() + + with gr.Blocks(analytics_enabled=False) as settings_interface: + with gr.Row(): + with gr.Column(scale=6): + self.submit = gr.Button(value="Apply settings", variant='primary', elem_id="settings_submit") + with gr.Column(): + restart_gradio = gr.Button(value='Reload UI', variant='primary', elem_id="settings_restart_gradio") + + self.result = gr.HTML(elem_id="settings_result") + + self.quicksettings_names = opts.quicksettings_list + self.quicksettings_names = {x: i for i, x in enumerate(self.quicksettings_names) if x != 'quicksettings'} + + self.quicksettings_list = [] + + previous_section = None + current_tab = None + current_row = None + with gr.Tabs(elem_id="settings"): + for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None + + if previous_section != item.section and not section_must_be_skipped: + elem_id, text = item.section + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + gr.Group() + current_tab = gr.TabItem(elem_id=f"settings_{elem_id}", label=text) + current_tab.__enter__() + current_row = gr.Column(variant='compact') + current_row.__enter__() + + previous_section = item.section + + if k in self.quicksettings_names and not shared.cmd_opts.freeze_settings: + self.quicksettings_list.append((i, k, item)) + self.components.append(dummy_component) + elif section_must_be_skipped: + self.components.append(dummy_component) + else: + component = create_setting_component(k) + self.component_dict[k] = component + self.components.append(component) + + if current_tab is not None: + current_row.__exit__() + current_tab.__exit__() + + with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"): + loadsave.create_ui() + + with gr.TabItem("Sysinfo", id="sysinfo", elem_id="settings_tab_sysinfo"): + gr.HTML('Download system info
(or open as text in a new page)', elem_id="sysinfo_download") + + with gr.Row(): + with gr.Column(scale=1): + sysinfo_check_file = gr.File(label="Check system info for validity", type='binary') + with gr.Column(scale=1): + sysinfo_check_output = gr.HTML("", elem_id="sysinfo_validity") + with gr.Column(scale=100): + pass + + with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"): + request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications") + download_localization = gr.Button(value='Download localization template', elem_id="download_localization") + reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") + with gr.Row(): + unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model") + reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model") + + with gr.TabItem("Licenses", id="licenses", elem_id="settings_tab_licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") + + gr.Button(value="Show all pages", elem_id="settings_show_all_pages") + + self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) + + unload_sd_model.click( + fn=sd_models.unload_model_weights, + inputs=[], + outputs=[] + ) + + reload_sd_model.click( + fn=sd_models.reload_model_weights, + inputs=[], + outputs=[] + ) + + request_notifications.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='function(){}' + ) + + download_localization.click( + fn=lambda: None, + inputs=[], + outputs=[], + _js='download_localization' + ) + + def reload_scripts(): + scripts.reload_script_body_only() + reload_javascript() # need to refresh the html page + + reload_script_bodies.click( + fn=reload_scripts, + inputs=[], + outputs=[] + ) + + restart_gradio.click( + fn=shared.state.request_restart, + _js='restart_reload', + inputs=[], + outputs=[], + ) + + def check_file(x): + if x is None: + return '' + + if sysinfo.check(x.decode('utf8', errors='ignore')): + return 'Valid' + + return 'Invalid' + + sysinfo_check_file.change( + fn=check_file, + inputs=[sysinfo_check_file], + outputs=[sysinfo_check_output], + ) + + self.interface = settings_interface + + def add_quicksettings(self): + with gr.Row(elem_id="quicksettings", variant="compact"): + for _i, k, _item in sorted(self.quicksettings_list, key=lambda x: self.quicksettings_names.get(x[1], x[0])): + component = create_setting_component(k, is_quicksettings=True) + self.component_dict[k] = component + + def add_functionality(self, demo): + self.submit.click( + fn=wrap_gradio_call(lambda *args: self.run_settings(*args), extra_outputs=[gr.update()]), + inputs=self.components, + outputs=[self.text_settings, self.result], + ) + + for _i, k, _item in self.quicksettings_list: + component = self.component_dict[k] + info = opts.data_labels[k] + + change_handler = component.release if hasattr(component, 'release') else component.change + change_handler( + fn=lambda value, k=k: self.run_settings_single(value, key=k), + inputs=[component], + outputs=[component, self.text_settings], + show_progress=info.refresh is not None, + ) + + button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False) + button_set_checkpoint.click( + fn=lambda value, _: self.run_settings_single(value, key='sd_model_checkpoint'), + _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }", + inputs=[self.component_dict['sd_model_checkpoint'], self.dummy_component], + outputs=[self.component_dict['sd_model_checkpoint'], self.text_settings], + ) + + component_keys = [k for k in opts.data_labels.keys() if k in self.component_dict] + + def get_settings_values(): + return [get_value_for_setting(key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[self.component_dict[k] for k in component_keys], + queue=False, + ) diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index 9fc7d76438e455d1e43d2ab3c2c79ed18bce70c7..fb75137e6596107e899fd28e93bca7339306697a 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -31,7 +31,7 @@ def check_tmp_file(gradio, filename): return False -def save_pil_to_file(self, pil_image, dir=None): +def save_pil_to_file(self, pil_image, dir=None, format="png"): already_saved_as = getattr(pil_image, 'already_saved_as', None) if already_saved_as and os.path.isfile(already_saved_as): register_tmp_file(shared.demo, already_saved_as) diff --git a/modules/upscaler.py b/modules/upscaler.py index 7b1046d64e9c7a7e490e1aee6dcc0bc951f991ca..e682bbaa26cd05fa8f00f6e6ca438a8c53f7d47b 100644 --- a/modules/upscaler.py +++ b/modules/upscaler.py @@ -53,8 +53,8 @@ class Upscaler: def upscale(self, img: PIL.Image, scale, selected_model: str = None): self.scale = scale - dest_w = int(img.width * scale) - dest_h = int(img.height * scale) + dest_w = int((img.width * scale) // 8 * 8) + dest_h = int((img.height * scale) // 8 * 8) for _ in range(3): shape = (img.width, img.height) @@ -77,7 +77,7 @@ class Upscaler: pass def find_models(self, ext_filter=None) -> list: - return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path) + return modelloader.load_models(model_path=self.model_path, model_url=self.model_url, command_path=self.user_path, ext_filter=ext_filter) def update_status(self, prompt): print(f"\nextras: {prompt}", file=shared.progress_print_out) diff --git a/requirements.txt b/requirements.txt index 34e4520d66a6462aaa6810b0a2b821f6c42e5f57..3142085eaf43d16b660ee4fc121ad3c7014ddb9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,32 +1,32 @@ -astunparse -blendmodes +GitPython +Pillow accelerate + basicsr +blendmodes +clean-fid +einops gfpgan -gradio==3.31.0 +gradio==3.32.0 +inflection +jsonmerge +kornia +lark numpy omegaconf -opencv-contrib-python -requests + piexif -Pillow -pytorch_lightning==1.7.7 +psutil +pytorch_lightning realesrgan +requests +resize-right + +safetensors scikit-image>=0.19 -timm==0.4.12 -transformers==4.25.1 +timm +tomesd torch -einops -jsonmerge -clean-fid -resize-right torchdiffeq -kornia -lark -inflection -GitPython torchsde -safetensors -psutil -rich -tomesd +transformers==4.25.1 diff --git a/requirements_versions.txt b/requirements_versions.txt index 31b179a9e804b943df40959de002053e01265e1f..f71b9d6c555280ebbe991e12fd117b51dc0410ce 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,29 +1,30 @@ -blendmodes==2022 -transformers==4.25.1 +GitPython==3.1.30 +Pillow==9.5.0 accelerate==0.18.0 basicsr==1.4.2 +blendmodes==2022 +clean-fid==0.1.35 +einops==0.4.1 +fastapi==0.94.0 gfpgan==1.3.8 gradio==3.32.0 +httpcore<=0.15 +inflection==0.5.1 +jsonmerge==1.8.0 +kornia==0.6.7 +lark==1.1.2 numpy==1.23.5 -Pillow==9.5.0 -realesrgan==0.3.0 -torch omegaconf==2.2.3 +piexif==1.1.3 +psutil~=5.9.5 pytorch_lightning==1.9.4 +realesrgan==0.3.0 +resize-right==0.0.2 +safetensors==0.3.1 scikit-image==0.20.0 timm==0.6.7 -piexif==1.1.3 -einops==0.4.1 -jsonmerge==1.8.0 -clean-fid==0.1.35 -resize-right==0.0.2 +tomesd==0.1.2 +torch torchdiffeq==0.2.3 -kornia==0.6.7 -lark==1.1.2 -inflection==0.5.1 -GitPython==3.1.30 torchsde==0.2.5 -safetensors==0.3.1 -httpcore<=0.15 -fastapi==0.94.0 -tomesd==0.1.2 +transformers==4.25.1 diff --git a/script.js b/script.js index f76127799c1f89d592ffe5b8294f5fb3c955edad..34cca7651dd31a79fe68b27cb0febb58d3e4a237 100644 --- a/script.js +++ b/script.js @@ -10,44 +10,94 @@ function gradioApp() { return elem.shadowRoot ? elem.shadowRoot : elem; } +/** + * Get the currently selected top-level UI tab button (e.g. the button that says "Extras"). + */ function get_uiCurrentTab() { - return gradioApp().querySelector('#tabs button.selected'); + return gradioApp().querySelector('#tabs > .tab-nav > button.selected'); } +/** + * Get the first currently visible top-level UI tab content (e.g. the div hosting the "txt2img" UI). + */ function get_uiCurrentTabContent() { - return gradioApp().querySelector('.tabitem[id^=tab_]:not([style*="display: none"])'); + return gradioApp().querySelector('#tabs > .tabitem[id^=tab_]:not([style*="display: none"])'); } var uiUpdateCallbacks = []; +var uiAfterUpdateCallbacks = []; var uiLoadedCallbacks = []; var uiTabChangeCallbacks = []; var optionsChangedCallbacks = []; +var uiAfterUpdateTimeout = null; var uiCurrentTab = null; +/** + * Register callback to be called at each UI update. + * The callback receives an array of MutationRecords as an argument. + */ function onUiUpdate(callback) { uiUpdateCallbacks.push(callback); } + +/** + * Register callback to be called soon after UI updates. + * The callback receives no arguments. + * + * This is preferred over `onUiUpdate` if you don't need + * access to the MutationRecords, as your function will + * not be called quite as often. + */ +function onAfterUiUpdate(callback) { + uiAfterUpdateCallbacks.push(callback); +} + +/** + * Register callback to be called when the UI is loaded. + * The callback receives no arguments. + */ function onUiLoaded(callback) { uiLoadedCallbacks.push(callback); } + +/** + * Register callback to be called when the UI tab is changed. + * The callback receives no arguments. + */ function onUiTabChange(callback) { uiTabChangeCallbacks.push(callback); } + +/** + * Register callback to be called when the options are changed. + * The callback receives no arguments. + * @param callback + */ function onOptionsChanged(callback) { optionsChangedCallbacks.push(callback); } -function runCallback(x, m) { - try { - x(m); - } catch (e) { - (console.error || console.log).call(console, e.message, e); +function executeCallbacks(queue, arg) { + for (const callback of queue) { + try { + callback(arg); + } catch (e) { + console.error("error running callback", callback, ":", e); + } } } -function executeCallbacks(queue, m) { - queue.forEach(function(x) { - runCallback(x, m); - }); + +/** + * Schedule the execution of the callbacks registered with onAfterUiUpdate. + * The callbacks are executed after a short while, unless another call to this function + * is made before that time. IOW, the callbacks are executed only once, even + * when there are multiple mutations observed. + */ +function scheduleAfterUiUpdateCallbacks() { + clearTimeout(uiAfterUpdateTimeout); + uiAfterUpdateTimeout = setTimeout(function() { + executeCallbacks(uiAfterUpdateCallbacks); + }, 200); } var executedOnLoaded = false; @@ -60,6 +110,7 @@ document.addEventListener("DOMContentLoaded", function() { } executeCallbacks(uiUpdateCallbacks, m); + scheduleAfterUiUpdateCallbacks(); const newTab = get_uiCurrentTab(); if (newTab && (newTab !== uiCurrentTab)) { uiCurrentTab = newTab; diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 665dbe89706da04325ee1a90916f132dd74c9c85..c98ab48098ec553fec85607ff53ab1a1c956e71d 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -145,7 +145,6 @@ class Script(scripts.Script): process_width = p.width process_height = p.height - p.mask_blur = mask_blur*4 p.inpaint_full_res = False p.inpainting_fill = 1 p.do_not_save_samples = True @@ -156,6 +155,19 @@ class Script(scripts.Script): up = pixels if "up" in direction else 0 down = pixels if "down" in direction else 0 + if left > 0 or right > 0: + mask_blur_x = mask_blur + else: + mask_blur_x = 0 + + if up > 0 or down > 0: + mask_blur_y = mask_blur + else: + mask_blur_y = 0 + + p.mask_blur_x = mask_blur_x*4 + p.mask_blur_y = mask_blur_y*4 + init_img = p.init_images[0] target_w = math.ceil((init_img.width + left + right) / 64) * 64 target_h = math.ceil((init_img.height + up + down) / 64) * 64 @@ -191,10 +203,10 @@ class Script(scripts.Script): mask = Image.new("RGB", (process_res_w, process_res_h), "white") draw = ImageDraw.Draw(mask) draw.rectangle(( - expand_pixels + mask_blur if is_left else 0, - expand_pixels + mask_blur if is_top else 0, - mask.width - expand_pixels - mask_blur if is_right else res_w, - mask.height - expand_pixels - mask_blur if is_bottom else res_h, + expand_pixels + mask_blur_x if is_left else 0, + expand_pixels + mask_blur_y if is_top else 0, + mask.width - expand_pixels - mask_blur_x if is_right else res_w, + mask.height - expand_pixels - mask_blur_y if is_bottom else res_h, ), fill="black") np_image = (np.asarray(img) / 255.0).astype(np.float64) @@ -224,10 +236,10 @@ class Script(scripts.Script): latent_mask = Image.new("RGB", (p.width, p.height), "white") draw = ImageDraw.Draw(latent_mask) draw.rectangle(( - expand_pixels + mask_blur * 2 if is_left else 0, - expand_pixels + mask_blur * 2 if is_top else 0, - mask.width - expand_pixels - mask_blur * 2 if is_right else res_w, - mask.height - expand_pixels - mask_blur * 2 if is_bottom else res_h, + expand_pixels + mask_blur_x * 2 if is_left else 0, + expand_pixels + mask_blur_y * 2 if is_top else 0, + mask.width - expand_pixels - mask_blur_x * 2 if is_right else res_w, + mask.height - expand_pixels - mask_blur_y * 2 if is_bottom else res_h, ), fill="black") p.latent_mask = latent_mask diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index b918a764e9b0290e6df22dba55a7ad02f24707d2..50320d553bd2bbd2a599194fafa914e1aa74945b 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,13 +1,11 @@ import copy import random -import sys -import traceback import shlex import modules.scripts as scripts import gradio as gr -from modules import sd_samplers +from modules import sd_samplers, errors from modules.processing import Processed, process_images from modules.shared import state @@ -123,8 +121,7 @@ class Script(scripts.Script): return [checkbox_iterate, checkbox_iterate_batch, prompt_txt] def run(self, p, checkbox_iterate, checkbox_iterate_batch, prompt_txt: str): - lines = [x.strip() for x in prompt_txt.splitlines()] - lines = [x for x in lines if len(x) > 0] + lines = [x for x in (x.strip() for x in prompt_txt.splitlines()) if x] p.do_not_save_grid = True @@ -136,8 +133,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print(f"Error parsing line {line} as commandline:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + errors.report(f"Error parsing line {line} as commandline", exc_info=True) args = {"prompt": line} else: args = {"prompt": line} diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index da820b394d4adc918a3509fdaac1ebc0b4de197d..7821cc655cd658f4ffed646df8d7a91891bd7a4b 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, sd_samplers, processing, sd_models, sd_vae +from modules import images, sd_samplers, processing, sd_models, sd_vae, sd_samplers_kdiffusion from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, state import modules.shared as shared @@ -220,6 +220,10 @@ axis_options = [ AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma max", float, apply_field("s_tmax")), AxisOption("Sigma noise", float, apply_field("s_noise")), + AxisOption("Schedule type", str, apply_override("k_sched_type"), choices=lambda: list(sd_samplers_kdiffusion.k_diffusion_scheduler)), + AxisOption("Schedule min sigma", float, apply_override("sigma_min")), + AxisOption("Schedule max sigma", float, apply_override("sigma_max")), + AxisOption("Schedule rho", float, apply_override("rho")), AxisOption("Eta", float, apply_field("eta")), AxisOption("Clip skip", int, apply_clip_skip), AxisOption("Denoising", float, apply_field("denoising_strength")), diff --git a/style.css b/style.css index 571f4cf45174e80af310a32fd11303e6d2505f03..e1df716f13533201f0c6406e856371e4201caa54 100644 --- a/style.css +++ b/style.css @@ -403,19 +403,29 @@ div#extras_scale_to_tab div.form{ margin: 0 1.2em; } -table.settings-value-table{ +table.popup-table{ background: white; border-collapse: collapse; margin: 1em; border: 4px solid white; } -table.settings-value-table td{ +table.popup-table td{ padding: 0.4em; border: 1px solid #ccc; max-width: 36em; } +table.popup-table .muted{ + color: #aaa; +} + +table.popup-table .link{ + text-decoration: underline; + cursor: pointer; + font-weight: bold; +} + .ui-defaults-none{ color: #aaa !important; } @@ -440,6 +450,19 @@ table.settings-value-table td{ opacity: 0.75; } +#sysinfo_download a.sysinfo_big_link{ + font-size: 24pt; +} + +#sysinfo_download a{ + text-decoration: underline; +} + +#sysinfo_validity{ + font-size: 18pt; +} + + /* live preview */ .progressDiv{ position: relative; @@ -724,12 +747,22 @@ footer { .extra-network-subdirs button{ margin: 0 0.15em; } -.extra-networks .tab-nav .search{ +.extra-networks .tab-nav .search, +.extra-networks .tab-nav .sort, +.extra-networks .tab-nav .sortorder{ display: inline-block; - max-width: 16em; margin: 0.3em; align-self: center; +} + +.extra-networks .tab-nav .search { width: 16em; + max-width: 16em; +} + +.extra-networks .tab-nav .sort { + width: 12em; + max-width: 12em; } #txt2img_extra_view, #img2img_extra_view { diff --git a/webui-user.sh b/webui-user.sh index 49a426ff98ed5c87607a491ed10cee0787019254..70306c60d5b495bebd87da8f06da58fb72706553 100644 --- a/webui-user.sh +++ b/webui-user.sh @@ -36,7 +36,6 @@ # Fixed git commits #export STABLE_DIFFUSION_COMMIT_HASH="" -#export TAMING_TRANSFORMERS_COMMIT_HASH="" #export CODEFORMER_COMMIT_HASH="" #export BLIP_COMMIT_HASH="" diff --git a/webui.bat b/webui.bat index 209d972bd755293bc16b4c8b6fe0e5108b3244b3..42e7d517d18631d3826d48ee338e87cb41771019 100644 --- a/webui.bat +++ b/webui.bat @@ -3,7 +3,7 @@ if not defined PYTHON (set PYTHON=python) if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") - +set SD_WEBUI_RESTART=tmp/restart set ERROR_REPORTING=FALSE mkdir tmp 2>NUL @@ -51,12 +51,14 @@ if EXIST %ACCELERATE% goto :accelerate_launch :launch %PYTHON% launch.py %* +if EXIST tmp/restart goto :skip_venv pause exit /b :accelerate_launch echo Accelerating %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py +if EXIST tmp/restart goto :skip_venv pause exit /b diff --git a/webui.py b/webui.py index f0ffbbbf7f544e2159ee7499b3c6176c46b403dc..136d036d53e5344db9a0ed80db6c2f48822421ae 100644 --- a/webui.py +++ b/webui.py @@ -20,9 +20,9 @@ import logging logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) -from modules import paths, timer, import_hook, errors # noqa: F401 +from modules import paths, timer, import_hook, errors, devices # noqa: F401 -startup_timer = timer.Timer() +startup_timer = timer.startup_timer import torch import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them @@ -58,6 +58,7 @@ import modules.sd_hijack import modules.sd_hijack_optimizations import modules.sd_models import modules.sd_vae +import modules.sd_unet import modules.txt2img import modules.script_callbacks import modules.textual_inversion.textual_inversion @@ -134,7 +135,7 @@ there are reports of issues with training tab on the latest version. Use --skip-version-check commandline argument to disable this check. """.strip()) - expected_xformers_version = "0.0.17" + expected_xformers_version = "0.0.20" if shared.xformers_available: import xformers @@ -271,8 +272,8 @@ def initialize_rest(*, reload_script_modules=False): localization.list_localizations(cmd_opts.localizations_dir) - modules.scripts.load_scripts() - startup_timer.record("load scripts") + with startup_timer.subcategory("load scripts"): + modules.scripts.load_scripts() if reload_script_modules: for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: @@ -291,6 +292,9 @@ def initialize_rest(*, reload_script_modules=False): modules.sd_hijack.list_optimizers() startup_timer.record("scripts list_optimizers") + modules.sd_unet.list_unets() + startup_timer.record("scripts list_unets") + def load_model(): """ Accesses shared.sd_model property to load model. @@ -306,6 +310,8 @@ def initialize_rest(*, reload_script_modules=False): Thread(target=load_model).start() + Thread(target=devices.first_time_calculation).start() + shared.reload_hypernetworks() startup_timer.record("reload hypernetworks") @@ -381,17 +387,6 @@ def webui(): gradio_auth_creds = list(get_gradio_auth_creds()) or None - # this restores the missing /docs endpoint - if launch_api and not hasattr(FastAPI, 'original_setup'): - # TODO: replace this with `launch(app_kwargs=...)` if https://github.com/gradio-app/gradio/pull/4282 gets merged - def fastapi_setup(self): - self.docs_url = "/docs" - self.redoc_url = "/redoc" - self.original_setup() - - FastAPI.original_setup = FastAPI.setup - FastAPI.setup = fastapi_setup - app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, server_name=server_name, @@ -401,9 +396,13 @@ def webui(): ssl_verify=cmd_opts.disable_tls_verify, debug=cmd_opts.gradio_debug, auth=gradio_auth_creds, - inbrowser=cmd_opts.autolaunch, + inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING ') != '1', prevent_thread_lock=True, allowed_paths=cmd_opts.gradio_allowed_path, + app_kwargs={ + "docs_url": "/docs", + "redoc_url": "/redoc", + }, ) if cmd_opts.add_stop_route: app.add_route("/_stop", stop_route, methods=["POST"]) @@ -429,9 +428,12 @@ def webui(): ui_extra_networks.add_pages_to_demo(app) - modules.script_callbacks.app_started_callback(shared.demo, app) - startup_timer.record("scripts app_started_callback") + startup_timer.record("add APIs") + + with startup_timer.subcategory("app_started_callback"): + modules.script_callbacks.app_started_callback(shared.demo, app) + timer.startup_record = startup_timer.dump() print(f"Startup time: {startup_timer.summary()}.") if cmd_opts.subpath: @@ -456,6 +458,7 @@ def webui(): # If we catch a keyboard interrupt, we want to stop the server and exit. shared.demo.close() break + print('Restarting UI...') shared.demo.close() time.sleep(0.5) @@ -466,10 +469,6 @@ def webui(): startup_timer.record("scripts unloaded callback") initialize_rest(reload_script_modules=True) - modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers) - modules.sd_hijack.list_optimizers() - startup_timer.record("scripts list_optimizers") - if __name__ == "__main__": if cmd_opts.nowebui: diff --git a/webui.sh b/webui.sh index ab52ac3b5d46d94a23365acbe365669037139fe6..5c8d977cb2db058c37e8270beb80b7587d1fa9c0 100755 --- a/webui.sh +++ b/webui.sh @@ -112,9 +112,24 @@ then fi # Check prerequisites -gpu_info=$(lspci 2>/dev/null | grep VGA) +gpu_info=$(lspci 2>/dev/null | grep -E "VGA|Display") case "$gpu_info" in - *"Navi 1"*|*"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 + *"Navi 1"*) + export HSA_OVERRIDE_GFX_VERSION=10.3.0 + if [[ -z "${TORCH_COMMAND}" ]] + then + pyv="$(${python_cmd} -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')" + if [[ $(bc <<< "$pyv <= 3.10") -eq 1 ]] + then + # Navi users will still use torch 1.13 because 2.0 does not seem to work. + export TORCH_COMMAND="pip install torch==1.13.1+rocm5.2 torchvision==0.14.1+rocm5.2 --index-url https://download.pytorch.org/whl/rocm5.2" + else + printf "\e[1m\e[31mERROR: RX 5000 series GPUs must be using at max python 3.10, aborting...\e[0m" + exit 1 + fi + fi + ;; + *"Navi 2"*) export HSA_OVERRIDE_GFX_VERSION=10.3.0 ;; *"Renoir"*) export HSA_OVERRIDE_GFX_VERSION=9.0.0 printf "\n%s\n" "${delimiter}" @@ -124,9 +139,12 @@ case "$gpu_info" in *) ;; esac -if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] +if ! echo "$gpu_info" | grep -q "NVIDIA"; then - export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2" + if echo "$gpu_info" | grep -q "AMD" && [[ -z "${TORCH_COMMAND}" ]] + then + export TORCH_COMMAND="pip install torch==2.0.1+rocm5.4.2 torchvision==0.15.2+rocm5.4.2 --index-url https://download.pytorch.org/whl/rocm5.4.2" + fi fi for preq in "${GIT}" "${python_cmd}" @@ -190,7 +208,7 @@ fi # Try using TCMalloc on Linux prepare_tcmalloc() { if [[ "${OSTYPE}" == "linux"* ]] && [[ -z "${NO_TCMALLOC}" ]] && [[ -z "${LD_PRELOAD}" ]]; then - TCMALLOC="$(ldconfig -p | grep -Po "libtcmalloc.so.\d" | head -n 1)" + TCMALLOC="$(PATH=/usr/sbin:$PATH ldconfig -p | grep -Po "libtcmalloc(_minimal|)\.so\.\d" | head -n 1)" if [[ ! -z "${TCMALLOC}" ]]; then echo "Using TCMalloc: ${TCMALLOC}" export LD_PRELOAD="${TCMALLOC}" @@ -200,17 +218,24 @@ prepare_tcmalloc() { fi } -if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ] -then - printf "\n%s\n" "${delimiter}" - printf "Accelerating launch.py..." - printf "\n%s\n" "${delimiter}" - prepare_tcmalloc - exec accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" -else - printf "\n%s\n" "${delimiter}" - printf "Launching launch.py..." - printf "\n%s\n" "${delimiter}" - prepare_tcmalloc - exec "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" -fi +KEEP_GOING=1 +export SD_WEBUI_RESTART=tmp/restart +while [[ "$KEEP_GOING" -eq "1" ]]; do + if [[ ! -z "${ACCELERATE}" ]] && [ ${ACCELERATE}="True" ] && [ -x "$(command -v accelerate)" ]; then + printf "\n%s\n" "${delimiter}" + printf "Accelerating launch.py..." + printf "\n%s\n" "${delimiter}" + prepare_tcmalloc + accelerate launch --num_cpu_threads_per_process=6 "${LAUNCH_SCRIPT}" "$@" + else + printf "\n%s\n" "${delimiter}" + printf "Launching launch.py..." + printf "\n%s\n" "${delimiter}" + prepare_tcmalloc + "${python_cmd}" "${LAUNCH_SCRIPT}" "$@" + fi + + if [[ ! -f tmp/restart ]]; then + KEEP_GOING=0 + fi +done