infotext_utils.py 17.4 KB
Newer Older
1
from __future__ import annotations
2 3
import base64
import io
4
import json
T
Trung Ngo 已提交
5
import os
6
import re
7
import sys
8

9
import gradio as gr
M
Max Audron 已提交
10
from modules.paths import data_path
A
linter  
AUTOMATIC1111 已提交
11
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
12
from PIL import Image
13

14 15
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__]  # alias for old name

# One "Key: value" pair of the infotext parameter line; the value is either a JSON-style
# double-quoted string (may contain commas) or anything up to the next comma.
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
# "WxH" values such as "512x768"
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
# Raw string: "\(" in a plain literal is an invalid escape sequence (SyntaxWarning on 3.12+).
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
20
# Class of the object gr.update() returns; connect_paste() uses it to pass ready-made updates through.
type_of_gr_update = gr.update().__class__
21 22 23


class ParamBinding:
    """Configuration of one "send to tab" / paste button.

    connect_paste_params_buttons() wires every registered binding so that clicking
    ``paste_button`` copies data into tab ``tabname`` from whichever sources are set:

    - ``source_image_component``: an image (or gallery) component whose image is sent to
      the target tab's init image;
    - ``source_text_component``: a component holding infotext that is parsed and pasted
      into the target tab's fields;
    - ``source_tabname``: another tab whose Prompt / Negative prompt / Steps / Face
      restoration fields (plus Seed when ``shared.opts.send_seed`` is on, plus
      ``paste_field_names``) are copied over directly.

    ``override_settings_component``, when given, is used instead of the target tab's own
    override settings component.
    """

    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
        self.paste_button = paste_button
        self.tabname = tabname
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component
        # `or` rather than an `is None` check: any falsy value (e.g. an empty tuple) becomes a fresh list
        self.paste_field_names = paste_field_names or []
32

33

A
a  
AUTOMATIC1111 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46
class PasteField(tuple):
    """A ``(component, target)`` pair naming one field that pasted infotext can fill.

    ``target`` is either an infotext key (a ``str``, exposed as ``label``) or a callable
    taking the parsed parameters dict (exposed as ``function``). Since this is a tuple,
    it still unpacks as ``component, target`` wherever plain pairs were used before.
    """

    def __new__(cls, component, target, *, api=None):
        pair = (component, target)
        return super().__new__(cls, pair)

    def __init__(self, component, target, *, api=None):
        super().__init__()

        self.component = component
        self.api = api

        is_label = isinstance(target, str)
        self.label = target if is_label else None
        self.function = target if callable(target) else None


47 48 49 50
# tabname -> {"init_img": ..., "fields": ..., "override_settings_component": ...}; filled by add_paste_fields()
paste_fields: dict[str, dict] = dict()
# bindings queued by register_paste_params_button(), wired up by connect_paste_params_buttons()
registered_param_bindings: list[ParamBinding] = list()


51 52
def reset():
    """Clear all registered paste fields and all registered paste-parameter button bindings.

    The module-level containers are emptied in place, so existing references to them stay valid.
    """
    paste_fields.clear()
    registered_param_bindings.clear()
54 55


56
def quote(text):
    """Prepare *text* for use as a value in the infotext "Key: value, ..." line.

    If the string form of *text* contains a comma, a colon or a newline (characters the
    infotext parser treats as separators), *text* is JSON-encoded into a double-quoted
    string (non-ASCII characters kept as-is). Otherwise *text* is returned unchanged,
    without even converting it to ``str``.
    """
    as_text = str(text)  # computed once instead of once per separator check
    if ',' not in as_text and '\n' not in as_text and ':' not in as_text:
        return text

    return json.dumps(text, ensure_ascii=False)


def unquote(text):
    """Undo quote(): decode *text* as a JSON string when it is wrapped in double quotes.

    Text that is not wrapped in double quotes, or that does not decode as JSON, is
    returned unchanged.
    """
    wrapped = len(text) > 0 and text[0] == '"' and text[-1] == '"'
    if not wrapped:
        return text

    try:
        decoded = json.loads(text)
    except Exception:
        # looked quoted but is not valid JSON (e.g. '"a" "b"'); keep it as an ordinary value
        return text
    return decoded
71

72

Y
yfszzx 已提交
73
def image_from_url_text(filedata):
    """Turn the value sent by a gradio image/gallery component into a PIL image.

    *filedata* may be:

    - ``None`` -> returns ``None``;
    - a dict with a truthy ``"is_file"`` (or a list whose first item is such a dict): the
      file named by ``"name"`` is opened, after ``ui_tempdir.check_tmp_file`` confirms it
      is in an allowed directory (a trailing ``?query`` is stripped from the name);
    - any other list: ``None`` if empty, otherwise its first item is used;
    - a base64 string, optionally prefixed with a ``data:<mime>;base64,`` URL header.

    Raises:
        AssertionError: if a file reference points outside the allowed directories.
    """
    if filedata is None:
        return None

    if isinstance(filedata, list) and filedata and isinstance(filedata[0], dict) and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if isinstance(filedata, dict) and filedata.get("is_file", False):
        filename = filedata["name"]
        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
        # Explicit raise instead of `assert`: assert statements are removed under
        # `python -O`, which would silently disable this directory check.
        if not is_in_right_dir:
            raise AssertionError('trying to open image file outside of allowed directories')

        filename = filename.rsplit('?', 1)[0]
        return Image.open(filename)

    if isinstance(filedata, list):
        if not filedata:
            return None

        filedata = filedata[0]

    # Strip a base64 data-URL header of any image type (not only PNG); otherwise the
    # header's letters would be decoded as part of the image bytes.
    if filedata.startswith("data:"):
        header, sep, payload = filedata.partition(",")
        if sep and header.endswith(";base64"):
            filedata = payload

    image_bytes = base64.decodebytes(filedata.encode('utf-8'))
    return Image.open(io.BytesIO(image_bytes))

101

102
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
    """Register the components of tab *tabname* that pasted infotext can fill in.

    Args:
        tabname: key under which the entry is stored in ``paste_fields``.
        init_img: the tab's destination image component (may be None).
        fields: list of PasteField or ``(component, target)`` pairs, or None. Plain pairs
            are converted to PasteField **in place**, so the caller's list changes too.
        override_settings_component: optional component that receives settings overrides.
    """
    if fields:
        # Convert in place rather than building a new list: the same list object is
        # exposed below as modules.ui.*_paste_fields for existing extensions.
        for i, field in enumerate(fields):
            if not isinstance(field, PasteField):
                fields[i] = PasteField(*field)

    paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}

    # backwards compatibility for existing extensions
    import modules.ui
    if tabname == 'txt2img':
        modules.ui.txt2img_paste_fields = fields
    elif tabname == 'img2img':
        modules.ui.img2img_paste_fields = fields
Y
yfszzx 已提交
117

118

Y
yfszzx 已提交
119 120 121
def create_buttons(tabs_list):
    """Create a ``Send to <tab>`` button (elem_id ``<tab>_tab``) for every tab name in *tabs_list*.

    Returns a dict mapping each tab name to its button, in the form bind_buttons() expects.
    """
    return {tab: gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") for tab in tabs_list}

125

Y
yfszzx 已提交
126
def bind_buttons(buttons, send_image, send_generate_info):
    """old function for backwards compatibility; do not use this, use register_paste_params_button

    Registers one ParamBinding per ``tabname -> button`` entry of *buttons*. *send_image* is
    the image source. *send_generate_info* is treated as the infotext text component when it
    is a gradio component, or as the name of the tab to copy fields from when it is a str.
    """
    for tabname, button in buttons.items():
        source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
        source_tabname = send_generate_info if isinstance(send_generate_info, str) else None

        register_paste_params_button(ParamBinding(
            paste_button=button,
            tabname=tabname,
            source_text_component=source_text_component,
            source_image_component=send_image,
            source_tabname=source_tabname,
        ))


def register_paste_params_button(binding: ParamBinding):
    """Queue *binding*; its click handlers are attached later by connect_paste_params_buttons()."""
    registered_param_bindings.extend([binding])


def connect_paste_params_buttons():
    """Attach click handlers to every button registered via register_paste_params_button().

    For each binding, handlers are registered on its ``paste_button`` in this order:

    1. if both a source image and the target tab's init image exist: send the image
       (gallery sources go through the ``extract_image_from_gallery`` JS function); when
       the target tab has a "Size-1" field, send_image_and_dimensions also fills the
       "Size-1"/"Size-2" fields;
    2. if a source text component is set: paste its infotext via connect_paste();
    3. if a source tab is set: copy Prompt / Negative prompt / Steps / Face restoration
       (plus Seed when ``shared.opts.send_seed`` is on, plus the binding's
       ``paste_field_names``) straight from that tab's fields;
    4. always: run the ``switch_to_<tabname>`` JS function.

    Steps 2 and 3 are skipped when the target tab registered no fields (``None``).
    """

    def first_component_for(field_list, label):
        # component of the first (component, target) pair whose target equals `label`, else None
        for component, target in field_list or []:
            if target == label:
                return component
        return None

    def passthrough(value):
        # non-gallery image sources are forwarded unchanged
        return value

    for binding in registered_param_bindings:
        target_tab = paste_fields[binding.tabname]
        destination_image_component = target_tab["init_img"]
        fields = target_tab["fields"]
        override_settings_component = binding.override_settings_component or target_tab["override_settings_component"]

        destination_width_component = first_component_for(fields, "Size-1")
        destination_height_component = first_component_for(fields, "Size-2")

        if binding.source_image_component and destination_image_component:
            from_gallery = isinstance(binding.source_image_component, gr.Gallery)
            if destination_width_component:
                func = send_image_and_dimensions
                outputs = [destination_image_component, destination_width_component, destination_height_component]
            else:
                func = image_from_url_text if from_gallery else passthrough
                outputs = [destination_image_component]

            binding.paste_button.click(
                fn=func,
                _js="extract_image_from_gallery" if from_gallery else None,
                inputs=[binding.source_image_component],
                outputs=outputs,
                show_progress=False,
            )

        if binding.source_text_component is not None and fields is not None:
            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)

        if binding.source_tabname is not None and fields is not None:
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
            source_fields = paste_fields[binding.source_tabname]["fields"]
            binding.paste_button.click(
                fn=lambda *x: x,
                inputs=[field for field, name in source_fields if name in paste_field_names],
                outputs=[field for field, name in fields if name in paste_field_names],
                show_progress=False,
            )

        binding.paste_button.click(
            fn=None,
            _js=f"switch_to_{binding.tabname}",
            inputs=None,
            outputs=None,
            show_progress=False,
        )
Y
yfszzx 已提交
183

184

185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
def send_image_and_dimensions(x):
    """Return ``(image, width, height)`` for sending an image to another tab.

    *x* is either a PIL image or any value image_from_url_text() accepts. Width and height
    are the image's size when ``shared.opts.send_size`` is enabled and an image was
    obtained; otherwise both are an empty ``gr.update()``.
    """
    img = x if isinstance(x, Image.Image) else image_from_url_text(x)

    if not (shared.opts.send_size and isinstance(img, Image.Image)):
        return img, gr.update(), gr.update()

    return img, img.width, img.height


A
AUTOMATIC 已提交
201 202 203 204 205 206 207
def restore_old_hires_fix_params(res):
    """Convert old-style hires-fix size parameters in *res* in place.

    - If ``shared.opts.use_old_hires_fix_width_height`` is on and both "Hires resize-1"
      and "Hires resize-2" are non-zero, they are copied into "Size-1"/"Size-2" and
      nothing else changes.
    - Otherwise, if "First pass size-1"/"First pass size-2" are present, the first-pass
      size becomes "Size-1"/"Size-2". The previous "Size-1"/"Size-2" (default 512) becomes
      "Hires resize-1"/"Hires resize-2". A first-pass size of 0 in either dimension is
      replaced by ``processing.old_hires_fix_first_pass_dimensions(width, height)``.
    """
    firstpass_width = res.get('First pass size-1', None)
    firstpass_height = res.get('First pass size-2', None)

    if shared.opts.use_old_hires_fix_width_height:
        hires_width = int(res.get("Hires resize-1", 0))
        hires_height = int(res.get("Hires resize-2", 0))

        if hires_width and hires_height:
            res['Size-1'] = hires_width
            res['Size-2'] = hires_height
            return

    if firstpass_width is None or firstpass_height is None:
        return

    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
    width = int(res.get("Size-1", 512))
    height = int(res.get("Size-2", 512))

    if firstpass_width == 0 or firstpass_height == 0:
        firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires resize-1'] = width
    res['Hires resize-2'] = height
A
AUTOMATIC 已提交
231 232


233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
def _split_prompt_and_negative(lines):
    """Split infotext prompt lines into ``(prompt, negative_prompt)``.

    Everything from the first line starting with "Negative prompt:" onwards belongs to the
    negative prompt (the marker itself is removed). Lines are stripped and joined with
    "\\n". Nothing is prepended while the accumulated text is still empty, so empty lines
    at the start of either part are dropped.
    """
    negative_marker = "Negative prompt:"
    prompt = ""
    negative_prompt = ""
    done_with_prompt = False

    for line in lines:
        line = line.strip()
        if line.startswith(negative_marker):
            done_with_prompt = True
            line = line[len(negative_marker):].strip()
        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    return prompt, negative_prompt


def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
```

    The last line is treated as the "Key: value, ..." parameter line only if re_param finds
    at least 3 parameters in it; otherwise it is part of the prompt. Quoted values are
    unquoted, and "WxH" values are split into "<Key>-1" and "<Key>-2".

    Returns a dict with field values. Values parsed from the text are strings. Missing
    keys get defaults, old hires-fix parameters are converted by
    restore_old_hires_fix_params(), infotext_versions.backcompat() is applied, and keys
    listed in shared.opts.infotext_skip_pasting are removed.
    """

    res = {}

    *lines, lastline = x.strip().split("\n")
    if len(re_param.findall(lastline)) < 3:
        lines.append(lastline)
        lastline = ''

    prompt, negative_prompt = _split_prompt_and_negative(lines)

    if shared.opts.infotext_styles != "Ignore":
        found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)

        if shared.opts.infotext_styles == "Apply":
            res["Styles array"] = found_styles
        elif shared.opts.infotext_styles == "Apply if any" and found_styles:
            res["Styles array"] = found_styles

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        if not v:
            # re_param's unquoted alternative ([^,]*) can match an empty value ("Key: ,").
            # There is nothing to store, so skip it quietly instead of reporting a parse error.
            continue

        try:
            v = unquote(v)  # no-op unless v is wrapped in double quotes

            m = re_imagesize.match(v)
            if m is not None:
                res[f"{k}-1"] = m.group(1)
                res[f"{k}-2"] = m.group(2)
            else:
                res[k] = v
        except Exception:
            print(f"Error parsing \"{k}: {v}\"")

    # Missing CLIP skip means it was set to 1 (the default)
    res.setdefault("Clip skip", "1")

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    # Both halves are reset together whenever the first one is missing.
    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    # Defaults are inserted in this fixed order; restore_old_hires_fix_params() below
    # relies on "Hires resize-1/-2" already being set.
    for key, default in (
        ("Hires sampler", "Use same sampler"),
        ("Hires checkpoint", "Use same checkpoint"),
        ("Hires prompt", ""),
        ("Hires negative prompt", ""),
        ("Mask mode", "Inpaint masked"),
        ("Masked content", 'original'),
        ("Inpaint area", "Whole picture"),
        ("Masked area padding", 32),
    ):
        res.setdefault(key, default)

    restore_old_hires_fix_params(res)

    for key, default in (
        ("RNG", "GPU"),  # Missing RNG means the default was set, which is GPU RNG
        ("Schedule type", "Automatic"),
        ("Schedule max sigma", 0),
        ("Schedule min sigma", 0),
        ("Schedule rho", 0),
        ("VAE Encoder", "Full"),
        ("VAE Decoder", "Full"),
        ("FP8 weight", "Disable"),
    ):
        res.setdefault(key, default)

    if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
        res["Cache FP16 weight for LoRA"] = False

    infotext_versions.backcompat(res)

    skip = set(shared.opts.infotext_skip_pasting)
    res = {k: v for k, v in res.items() if k not in skip}

    return res

364

365
# Mapping of infotext labels to setting names, as (infotext label, setting name) pairs.
# Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
# Example content:
#
# infotext_to_setting_name_mapping = [
#     ('Conditional mask weight', 'inpainting_mask_weight'),
#     ('Model hash', 'sd_model_checkpoint'),
#     ('ENSD', 'eta_noise_seed_delta'),
#     ('Schedule type', 'k_sched_type'),
# ]
infotext_to_setting_name_mapping = []
378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397


def create_override_settings_dict(text_pairs):
    """creates processing's override_settings parameters from gradio's multiselect

    Each entry is a "Infotext label: value" string. Labels are matched against the
    ``infotext`` of options in ``shared.opts.data_labels`` and against
    ``infotext_to_setting_name_mapping``. Matched values are cast with
    ``shared.opts.cast_value``. Entries without a ":" are ignored, and ``None`` is treated
    as an empty selection.

    Example input:
        ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']

    Example output:
        {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
    """

    res = {}

    params = {}
    for pair in text_pairs or []:
        k, sep, v = pair.partition(":")  # split at the first colon only; values may contain ":"
        if not sep:
            # not a "label: value" entry; previously this raised a ValueError on unpacking
            continue

        params[k.strip()] = v.strip()

    mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
    for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
        value = params.get(param_name, None)

        if value is None:
            continue

        res[setting_name] = shared.opts.cast_value(setting_name, value)

    return res


410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451
def get_override_settings(params, *, skip_fields=None):
    """Collect the settings overrides implied by a parsed infotext dictionary.

    Every (infotext label, setting name) pair, taken from options in
    ``shared.opts.data_labels`` that declare an ``infotext`` and from
    ``infotext_to_setting_name_mapping``, is checked against *params*. A pair yields
    ``(label, setting name, value)``, with the value cast via ``shared.opts.cast_value``,
    unless one of the following applies:
    - the label is in *skip_fields*;
    - the label is absent from *params* (or maps to None);
    - it is the checkpoint while ``shared.opts.disable_weights_auto_swap`` is set;
    - the cast value equals the current setting.

    Example input:
        {"Clip skip": "2"}

    Example output:
        [("Clip skip", "CLIP_stop_at_last_layers", 2)]
    """
    skipped = skip_fields or {}
    declared = [(info.infotext, name) for name, info in shared.opts.data_labels.items() if info.infotext]

    overrides = []
    for label, setting_name in declared + infotext_to_setting_name_mapping:
        if label in skipped:
            continue

        raw_value = params.get(label, None)
        if raw_value is None:
            continue

        if setting_name == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
            continue

        new_value = shared.opts.cast_value(setting_name, raw_value)
        if new_value == getattr(shared.opts, setting_name, None):
            continue

        overrides.append((label, setting_name, new_value))

    return overrides


452
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
    """Make *button* paste the infotext held by *input_comp* into a tab's components.

    Args:
        button: gradio button whose click triggers the paste.
        paste_fields: list of ``(component, target)`` pairs. *target* is an infotext key
            looked up in the parsed parameters, or a callable that receives the parsed
            parameters dict and returns the value (or a ready-made update).
        input_comp: component holding the infotext text.
        override_settings_component: optional dropdown that receives, as "label: value"
            choices, the settings overrides from the infotext that no entry of
            *paste_fields* already handles.
        tabname: used to name the ``recalculate_prompts_<tabname>`` JS function run afterwards.
    """
    def paste_func(prompt):
        # An empty infotext falls back to params.txt in data_path (unless the UI hides directory config).
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            filename = os.path.join(data_path, "params.txt")
            if os.path.exists(filename):
                with open(filename, "r", encoding="utf8") as file:
                    prompt = file.read()

        params = parse_generation_parameters(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)
        res = []

        # Reads `paste_fields` from the enclosing scope at call time, so it also sees the
        # override-settings entry appended below before the click handler is registered.
        for output, key in paste_fields:
            v = key(params) if callable(key) else params.get(key, None)

            if v is None:
                res.append(gr.update())
            elif isinstance(v, type_of_gr_update):
                res.append(v)
            else:
                try:
                    valtype = type(output.value)
                    # bool("False") would be True, so the string "False" is special-cased
                    val = False if valtype == bool and v == "False" else valtype(v)
                    res.append(gr.update(value=val))
                except Exception:
                    # value can't be converted to the component's current type; leave the component as is
                    res.append(gr.update())

        return res

    if override_settings_component is not None:
        # infotext keys already pasted into dedicated fields must not show up again as overrides
        already_handled_fields = {key for _, key in paste_fields}

        def paste_settings(params):
            vals = get_override_settings(params, skip_fields=already_handled_fields)

            vals_pairs = [f"{infotext_text}: {value}" for infotext_text, _setting_name, value in vals]

            return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))

        # Rebinding (not mutating) keeps the caller's list intact while paste_func sees the extra entry.
        paste_fields = paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        inputs=[input_comp],
        outputs=[x[0] for x in paste_fields],
        show_progress=False,
    )
    button.click(
        fn=None,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[],
        outputs=[],
        show_progress=False,
    )
514