infotext_utils.py 18.1 KB
Newer Older
1
from __future__ import annotations
2 3
import base64
import io
4
import json
T
Trung Ngo 已提交
5
import os
6
import re
7
import sys
8

9
import gradio as gr
M
Max Audron 已提交
10
from modules.paths import data_path
11
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions, images, prompt_parser
12
from PIL import Image
13

14 15
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__]  # alias for old name

# Matches one "Key: value" pair of the infotext parameter line. The value is either a
# double-quoted string (backslash escapes allowed) or everything up to the next comma.
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
# Matches a whole "<digits>x<digits>" value such as "512x512".
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
# Matches a trailing parenthesised lowercase-hex string. Written as a raw string so that
# "\(" is not treated as an (invalid) string escape sequence.
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
# Concrete type returned by gr.update(); used to recognise ready-made update objects.
type_of_gr_update = type(gr.update())
21 22 23


class ParamBinding:
    """Plain record describing one paste/"send to" button and its data sources.

    Attributes:
        paste_button: the button whose click handlers are wired up.
        tabname: key of the destination entry in ``paste_fields``.
        source_text_component: component holding infotext to parse, or None.
        source_image_component: component holding an image to send, or None.
        source_tabname: key of a ``paste_fields`` entry to copy values from, or None.
        override_settings_component: overrides the destination's own component when set.
        paste_field_names: extra field names to copy; always a list (never None).
    """

    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
        self.paste_button = paste_button
        self.tabname = tabname
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component
        # Normalise a missing/empty value to a fresh list so callers can concatenate it.
        self.paste_field_names = paste_field_names or []
32

33

A
a  
AUTOMATIC1111 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46
class PasteField(tuple):
    """A ``(component, target)`` pair that also carries named attributes.

    The instance still unpacks like the two-element tuple it is built from.
    ``label`` is the target when it is a string, ``function`` is the target
    when it is callable; each is None otherwise. ``api`` is stored as given.
    """

    def __new__(cls, component, target, *, api=None):
        pair = (component, target)
        return super().__new__(cls, pair)

    def __init__(self, component, target, *, api=None):
        super().__init__()

        self.component = component
        self.api = api

        self.label = None
        self.function = None
        if isinstance(target, str):
            self.label = target
        if callable(target):
            self.function = target


47 48 49 50
# Per-tab paste configuration, keyed by tab name; filled by add_paste_fields().
paste_fields: dict[str, dict] = dict()
# Bindings queued by register_paste_params_button() until the buttons are connected.
registered_param_bindings: list[ParamBinding] = list()


51 52
def reset():
    """Empty the module-level paste-field registry and the list of registered bindings.

    Both containers are cleared in place, so existing references to them stay valid.
    """
    paste_fields.clear()
    registered_param_bindings.clear()
54 55


56
def quote(text):
    """Return *text* JSON-quoted if it contains a comma, newline or colon.

    Values without any of those characters are returned unchanged (same object,
    same type). Otherwise the value is serialised with ``json.dumps`` without
    escaping non-ASCII characters.
    """
    as_string = str(text)  # convert once; non-string values are checked by their str() form
    if not any(special in as_string for special in (',', '\n', ':')):
        return text

    return json.dumps(text, ensure_ascii=False)


def unquote(text):
    """Decode a double-quoted JSON string; return anything else unchanged.

    Text that is empty, does not both start and end with ``"``, or fails to
    parse as JSON is returned as-is.
    """
    looks_quoted = len(text) != 0 and text[0] == '"' and text[-1] == '"'
    if not looks_quoted:
        return text

    try:
        decoded = json.loads(text)
    except Exception:
        return text

    return decoded
71

72

Y
yfszzx 已提交
73
def image_from_url_text(filedata):
    """Load an image from the value a UI image/gallery component hands back.

    Accepted forms of *filedata*:
      - None or an empty list: returns None;
      - a dict (or a list whose first item is such a dict) with a truthy
        "is_file" entry: the file named by "name" is read, after checking that
        it lies in an allowed temporary directory;
      - a string (or a list whose first item is a string) of base64 data,
        optionally prefixed with a "data:image/...;base64," header.

    Raises:
        AssertionError: if the referenced file is outside the allowed directories.
    """
    if filedata is None:
        return None

    if isinstance(filedata, list) and filedata and isinstance(filedata[0], dict) and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if isinstance(filedata, dict) and filedata.get("is_file", False):
        filename = filedata["name"]
        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
        # Raised explicitly instead of using `assert`, so the check is not removed when
        # Python runs with -O; the exception type stays the same for existing callers.
        if not is_in_right_dir:
            raise AssertionError('trying to open image file outside of allowed directories')

        # drop anything after the last "?" in the name before reading the file
        filename = filename.rsplit('?', 1)[0]
        return images.read(filename)

    if isinstance(filedata, list):
        if not filedata:
            return None

        filedata = filedata[0]

    # Strip a data-URL header for any image type; a base64 payload cannot contain ";",
    # so the first ";base64," always belongs to the header.
    if filedata.startswith("data:image/"):
        _header, separator, payload = filedata.partition(";base64,")
        if separator:
            filedata = payload

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = images.read(io.BytesIO(filedata))
    return image

101

102
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
    """Register the paste configuration of one tab in ``paste_fields``.

    Args:
        tabname: key under which the configuration is stored.
        init_img: component stored as the tab's "init_img" entry.
        fields: list of PasteField objects or plain tuples; plain tuples are
            converted to PasteField in place. May be None/empty.
        override_settings_component: stored as the tab's "override_settings_component" entry.
    """

    if fields:
        # Replace items in place so the caller's list (and the aliases below) see PasteField objects.
        for index, field in enumerate(fields):
            if not isinstance(field, PasteField):
                fields[index] = PasteField(*field)

    paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}

    # backwards compatibility for existing extensions
    import modules.ui
    if tabname == 'txt2img':
        modules.ui.txt2img_paste_fields = fields
    elif tabname == 'img2img':
        modules.ui.img2img_paste_fields = fields
Y
yfszzx 已提交
117

118

Y
yfszzx 已提交
119 120 121
def create_buttons(tabs_list):
    """Create one "Send to <tab>" button per entry of *tabs_list*.

    Returns a dict mapping each tab name to its ``gr.Button``; the button's
    element id is "<tab>_tab".
    """
    return {tab: gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") for tab in tabs_list}

125

Y
yfszzx 已提交
126
def bind_buttons(buttons, send_image, send_generate_info):
    """old function for backwards compatibility; do not use this, use register_paste_params_button

    Registers one ParamBinding per ``(tabname, button)`` item of *buttons*.
    *send_generate_info* is used as the source text component when it is a
    gradio component, or as the source tab name when it is a string.
    """
    # These do not depend on the individual button, so work them out once.
    source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
    source_tabname = send_generate_info if isinstance(send_generate_info, str) else None

    for tabname, button in buttons.items():
        register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))


def register_paste_params_button(binding: ParamBinding):
    """Queue *binding* in ``registered_param_bindings``; nothing is wired up here."""
    registered_param_bindings.append(binding)


def connect_paste_params_buttons():
    """Attach click handlers to every button in ``registered_param_bindings``.

    For each binding, depending on which sources it declares:
      - image source: sends the image (and, when the destination has a "Size-1"
        field, width/height) to the destination tab's "init_img" component;
      - text source: wires infotext parsing via ``connect_paste``;
      - source tab name: copies a fixed set of named fields between tabs.
    Finally every button gets a handler that runs the ``switch_to_<tabname>`` JS function.
    """
    for binding in registered_param_bindings:
        destination = paste_fields[binding.tabname]
        destination_image_component = destination["init_img"]
        fields = destination["fields"]
        override_settings_component = binding.override_settings_component or destination["override_settings_component"]

        # first field registered under the given name, or None when absent / no fields at all
        destination_width_component = next((field for field, name in (fields or []) if name == "Size-1"), None)
        destination_height_component = next((field for field, name in (fields or []) if name == "Size-2"), None)

        if binding.source_image_component and destination_image_component:
            if isinstance(binding.source_image_component, gr.Gallery):
                func = send_image_and_dimensions if destination_width_component else image_from_url_text
                jsfunc = "extract_image_from_gallery"
            else:
                func = send_image_and_dimensions if destination_width_component else lambda x: x
                jsfunc = None

            binding.paste_button.click(
                fn=func,
                _js=jsfunc,
                inputs=[binding.source_image_component],
                outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
                show_progress=False,
            )

        if binding.source_text_component is not None and fields is not None:
            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)

        if binding.source_tabname is not None and fields is not None:
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
            binding.paste_button.click(
                fn=lambda *x: x,
                inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
                outputs=[field for field, name in fields if name in paste_field_names],
                show_progress=False,
            )

        binding.paste_button.click(
            fn=None,
            _js=f"switch_to_{binding.tabname}",
            inputs=None,
            outputs=None,
            show_progress=False,
        )
Y
yfszzx 已提交
183

184

185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
def send_image_and_dimensions(x):
    """Return ``(image, width, height)`` for an image-sending click handler.

    *x* is used directly when it already is a PIL image, otherwise it is decoded
    with ``image_from_url_text``. Width and height are the image's own size only
    when ``shared.opts.send_size`` is set and an image was obtained; otherwise
    both are no-op ``gr.update()`` objects.
    """
    img = x if isinstance(x, Image.Image) else image_from_url_text(x)

    if shared.opts.send_size and isinstance(img, Image.Image):
        return img, img.width, img.height

    return img, gr.update(), gr.update()


A
AUTOMATIC 已提交
201 202 203 204 205 206 207
def restore_old_hires_fix_params(res):
    """for infotexts that specify old First pass size parameter, convert it into
    width, height, and hr scale

    Mutates *res* in place and returns None:
      - when ``shared.opts.use_old_hires_fix_width_height`` is set and both
        "Hires resize" values are non-zero, they become the "Size" values;
      - otherwise, when both "First pass size" values are present, "Size" is
        moved to "Hires resize" and the first pass size becomes "Size".
    """

    firstpass_width = res.get('First pass size-1', None)
    firstpass_height = res.get('First pass size-2', None)

    if shared.opts.use_old_hires_fix_width_height:
        hires_width = int(res.get("Hires resize-1", 0))
        hires_height = int(res.get("Hires resize-2", 0))

        if hires_width and hires_height:
            res['Size-1'] = hires_width
            res['Size-2'] = hires_height
            return

    if firstpass_width is None or firstpass_height is None:
        return

    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
    width = int(res.get("Size-1", 512))
    height = int(res.get("Size-2", 512))

    # a zero in either dimension means the first pass size has to be derived from the final size
    if firstpass_width == 0 or firstpass_height == 0:
        firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires resize-1'] = width
    res['Hires resize-2'] = height
A
AUTOMATIC 已提交
231 232


W
w-e-w 已提交
233
def parse_generation_parameters(x: str, skip_fields: list[str] | None = None):
    """parses generation parameters string, the one you see in text field under the picture in UI:
    ```
    girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
    Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
    Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
    ```

    Args:
        x: the infotext; None is treated like an empty string.
        skip_fields: keys removed from the result; defaults to
            ``shared.opts.infotext_skip_pasting`` when None.

    returns a dict with field values
    """
    if skip_fields is None:
        skip_fields = shared.opts.infotext_skip_pasting

    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = (x or "").strip().split("\n")
    # The last line only counts as the parameter line if it holds at least three "Key: value" pairs;
    # otherwise it is treated as part of the prompt text.
    if len(re_param.findall(lastline)) < 3:
        lines.append(lastline)
        lastline = ''

    negative_prefix = "Negative prompt:"
    for line in lines:
        line = line.strip()
        if line.startswith(negative_prefix):
            done_with_prompt = True
            line = line[len(negative_prefix):].strip()
        # Lines are joined with "\n", but nothing is prepended while the accumulated text is still empty.
        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    if shared.opts.infotext_styles != "Ignore":
        found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)

        if shared.opts.infotext_styles == "Apply":
            res["Styles array"] = found_styles
        elif shared.opts.infotext_styles == "Apply if any" and found_styles:
            res["Styles array"] = found_styles

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        try:
            if v[0] == '"' and v[-1] == '"':
                v = unquote(v)

            # "<w>x<h>" values are split into two entries: "<key>-1" and "<key>-2"
            m = re_imagesize.match(v)
            if m is not None:
                res[f"{k}-1"] = m.group(1)
                res[f"{k}-2"] = m.group(2)
            else:
                res[k] = v
        except Exception:
            print(f"Error parsing \"{k}: {v}\"")

    # Missing CLIP skip means it was set to 1 (the default)
    res.setdefault("Clip skip", "1")

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    # both dimensions are reset together when the first one is missing
    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    res.setdefault("Hires sampler", "Use same sampler")
    res.setdefault("Hires checkpoint", "Use same checkpoint")
    res.setdefault("Hires prompt", "")
    res.setdefault("Hires negative prompt", "")
    res.setdefault("Mask mode", "Inpaint masked")
    res.setdefault("Masked content", 'original')
    res.setdefault("Inpaint area", "Whole picture")
    res.setdefault("Masked area padding", 32)

    restore_old_hires_fix_params(res)

    # Missing RNG means the default was set, which is GPU RNG
    res.setdefault("RNG", "GPU")
    res.setdefault("Schedule type", "Automatic")
    res.setdefault("Schedule max sigma", 0)
    res.setdefault("Schedule min sigma", 0)
    res.setdefault("Schedule rho", 0)
    res.setdefault("VAE Encoder", "Full")
    res.setdefault("VAE Decoder", "Full")
    res.setdefault("FP8 weight", "Disable")

    if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
        res["Cache FP16 weight for LoRA"] = False

    # The prompt counts as using emphasis if any parsed chunk has a weight other than 1.0 and is not a BREAK marker.
    prompt_attention = prompt_parser.parse_prompt_attention(prompt)
    prompt_attention += prompt_parser.parse_prompt_attention(negative_prompt)
    prompt_uses_emphasis = any(p[1] != 1.0 and p[0] != 'BREAK' for p in prompt_attention)
    if "Emphasis" not in res and prompt_uses_emphasis:
        res["Emphasis"] = "Original"

    res.setdefault("Refiner switch by sampling steps", False)

    infotext_versions.backcompat(res)

    for key in skip_fields:
        res.pop(key, None)

    return res

375

376
infotext_to_setting_name_mapping: list[tuple[str, str]] = [

]
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
Example content:

infotext_to_setting_name_mapping = [
    ('Conditional mask weight', 'inpainting_mask_weight'),
    ('Model hash', 'sd_model_checkpoint'),
    ('ENSD', 'eta_noise_seed_delta'),
    ('Schedule type', 'k_sched_type'),
]
"""
389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408


def create_override_settings_dict(text_pairs):
    """creates processing's override_settings parameters from gradio's multiselect

    Each item of *text_pairs* is split at its first colon into an infotext
    label and a value; labels known to ``shared.opts`` (or listed in
    ``infotext_to_setting_name_mapping``) are mapped to their setting name and
    the value is cast with ``shared.opts.cast_value``. None is treated as an
    empty selection.

    Example input:
        ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']

    Example output:
        {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}

    Raises:
        ValueError: if an item contains no colon.
    """

    params = {}
    for pair in text_pairs or []:
        k, v = pair.split(":", maxsplit=1)

        params[k] = v.strip()

    res = {}

    mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
    for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
        value = params.get(param_name, None)

        if value is None:
            continue

        res[setting_name] = shared.opts.cast_value(setting_name, value)

    return res


421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462
def get_override_settings(params, *, skip_fields=None):
    """Collect the settings overrides implied by an infotext parameters dictionary.

    Every infotext label that maps to a setting in `shared.opts` (plus the legacy
    `infotext_to_setting_name_mapping` entries) is looked up in `params`. A tuple of
    (parameter name, setting name, value cast to the setting's type) is produced unless:
    - the label is listed in `skip_fields`;
    - `params` has no value for the label;
    - the setting is "sd_model_checkpoint" and `shared.opts.disable_weights_auto_swap` is set;
    - the cast value equals the setting's current value.

    Example input:
        {"Clip skip": "2"}

    Example output:
        [("Clip skip", "CLIP_stop_at_last_layers", 2)]
    """

    overrides = []
    ignored_labels = skip_fields or {}

    option_pairs = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
    for infotext_label, option_key in option_pairs + infotext_to_setting_name_mapping:
        # skipped labels are never looked up in params
        raw_value = None if infotext_label in ignored_labels else params.get(infotext_label, None)
        if raw_value is None:
            continue

        if option_key == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
            continue

        new_value = shared.opts.cast_value(option_key, raw_value)
        if new_value == getattr(shared.opts, option_key, None):
            continue

        overrides.append((infotext_label, option_key, new_value))

    return overrides


463
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
    """Wire *button* so that clicking it parses infotext and fills the given fields.

    Args:
        button: component whose ``click`` handlers are registered.
        paste_fields: sequence of ``(component, key)`` pairs; *key* is either a
            name looked up in the parsed parameters or a callable receiving them.
            (This parameter shadows the module-level ``paste_fields`` registry.)
        input_comp: component providing the infotext.
        override_settings_component: when not None, an extra output that receives
            the settings overrides not already covered by *paste_fields*.
        tabname: used to build the name of the ``recalculate_prompts_<tabname>`` JS function.
    """
    def paste_func(prompt):
        # With no text given, fall back to params.txt in the data directory (unless UI dir config is hidden).
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            filename = os.path.join(data_path, "params.txt")
            try:
                with open(filename, "r", encoding="utf8") as file:
                    prompt = file.read()
            except OSError:
                pass

        params = parse_generation_parameters(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)
        res = []

        for output, key in paste_fields:
            if callable(key):
                v = key(params)
            else:
                v = params.get(key, None)

            if v is None:
                res.append(gr.update())
            elif isinstance(v, type_of_gr_update):
                res.append(v)
            else:
                # Convert the value to the type of the component's current value;
                # any failure leaves the component untouched.
                try:
                    valtype = type(output.value)

                    if valtype == bool and v == "False":
                        val = False
                    elif valtype == int:
                        val = float(v)
                    else:
                        val = valtype(v)

                    res.append(gr.update(value=val))
                except Exception:
                    res.append(gr.update())

        return res

    if override_settings_component is not None:
        # keys already pasted into a regular field must not also appear as overrides
        already_handled_fields = {key for _, key in paste_fields}

        def paste_settings(params):
            vals = get_override_settings(params, skip_fields=already_handled_fields)

            vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]

            return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))

        paste_fields = paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        inputs=[input_comp],
        outputs=[x[0] for x in paste_fields],
        show_progress=False,
    )
    button.click(
        fn=None,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[],
        outputs=[],
        show_progress=False,
    )
529