infotext_utils.py 17.1 KB
Newer Older
1
from __future__ import annotations
2 3
import base64
import io
4
import json
T
Trung Ngo 已提交
5
import os
6
import re
7
import sys
8

9
import gradio as gr
M
Max Audron 已提交
10
from modules.paths import data_path
A
linter  
AUTOMATIC1111 已提交
11
from modules import shared, ui_tempdir, script_callbacks, processing, infotext_versions
12
from PIL import Image
13

14 15
# Register this module object under its former dotted name as well, so that
# `import modules.generation_parameters_copypaste` resolves to this same module.
sys.modules['modules.generation_parameters_copypaste'] = sys.modules[__name__]  # alias for old name

16
# Pattern source for one "Key: value" pair of an infotext parameter line.
# The key starts with a word character and may contain spaces, '-' and '/'.
# The value is either a double-quoted string (backslash escapes allowed) or
# everything up to the next comma; the pair ends at a comma or end of string.
re_param_code = r'\s*(\w[\w \-/]+):\s*("(?:\\.|[^\\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
# Whole-string "<digits>x<digits>" values such as "512x512".
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
# Trailing parenthesised lowercase-hex token, e.g. "name(1a2b3c)".
# Written as a raw string so that "\(" is a regex escape rather than an invalid string escape.
re_hypernet_hash = re.compile(r"\(([0-9a-f]+)\)$")
20
# Concrete type of the object returned by gr.update(), captured once so that
# values can later be isinstance-checked against it.
type_of_gr_update = type(gr.update())
21 22 23


class ParamBinding:
    """Plain record describing one paste/send button and where its data comes from.

    Instances are queued with register_paste_params_button() and wired up by
    connect_paste_params_buttons().

    Attributes:
        paste_button: the button whose click events get connected.
        tabname: key of the destination entry in `paste_fields`.
        source_text_component: component holding infotext to parse, or None.
        source_image_component: component holding the image to send, or None.
        source_tabname: key of a source entry in `paste_fields` whose field values are copied, or None.
        override_settings_component: takes precedence over the destination tab's own override settings component when set.
        paste_field_names: extra field names to copy in addition to the default ones; always a list.
    """

    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=None):
        self.paste_button = paste_button
        self.tabname = tabname
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component
        # A fresh list per instance when nothing (or an empty value) is passed.
        self.paste_field_names = paste_field_names or []
32

33

A
a  
AUTOMATIC1111 已提交
34 35 36 37 38 39 40 41 42 43 44 45 46
class PasteField(tuple):
    """A `(component, target)` pair that also carries named attributes.

    The object unpacks and compares like the plain 2-tuple, and additionally
    exposes:
        component: the first element of the pair.
        label: `target` when it is a string, otherwise None.
        function: `target` when it is callable, otherwise None.
        api: the optional keyword-only `api` value, stored as given.
    """

    def __new__(cls, component, target, *, api=None):
        # The tuple payload has to be fixed at construction time; `api` is not part of it.
        pair = (component, target)
        return super().__new__(cls, pair)

    def __init__(self, component, target, *, api=None):
        super().__init__()

        self.component = component
        self.api = api

        self.label = None
        self.function = None
        if isinstance(target, str):
            self.label = target
        if callable(target):
            self.function = target


47 48 49 50
# Per-tab paste configuration keyed by tab name. add_paste_fields() stores a dict with the
# keys "init_img", "fields" and "override_settings_component" for each tab.
paste_fields: dict[str, dict] = {}
# Bindings queued by register_paste_params_button() and iterated by connect_paste_params_buttons().
registered_param_bindings: list[ParamBinding] = []


51 52
def reset():
    """Empty the module-level `paste_fields` and `registered_param_bindings` registries.

    Both containers are cleared in place rather than rebound, so any reference
    to them held elsewhere sees the emptied container.
    """
    paste_fields.clear()
    registered_param_bindings.clear()
54 55


56
def quote(text):
    """Return `text` in a form that is safe to embed as an infotext value.

    If the string form of `text` contains none of ',', newline or ':', the
    original object is returned untouched (not converted to str). Otherwise
    the value is JSON-encoded, keeping non-ASCII characters as they are.
    """
    as_string = str(text)
    if not any(special in as_string for special in (',', '\n', ':')):
        return text

    return json.dumps(text, ensure_ascii=False)


def unquote(text):
    """Decode a double-quoted, JSON-encoded value; leave anything else alone.

    Input that is empty or not wrapped in double quotes is returned as is.
    Quoted input that fails to decode as JSON is also returned unchanged.
    """
    looks_quoted = len(text) != 0 and text[0] == '"' and text[-1] == '"'
    if not looks_quoted:
        return text

    try:
        decoded = json.loads(text)
    except Exception:
        # Best effort: malformed quoting falls back to the raw text.
        return text

    return decoded
71

72

Y
yfszzx 已提交
73
def image_from_url_text(filedata):
    """Turn the image payload received from the UI into a PIL image.

    Accepted inputs:
      - None: returns None.
      - a dict with a truthy "is_file" entry: the file named by "name" is opened
        after ui_tempdir.check_tmp_file() accepts it.
      - a list: only the first element is used; an empty list returns None.
      - a string with base64-encoded image data, optionally prefixed by a
        "data:image/<type>;base64," header.

    Raises:
        AssertionError: if the referenced file is rejected by ui_tempdir.check_tmp_file().
    """
    if filedata is None:
        return None

    if isinstance(filedata, list) and filedata and isinstance(filedata[0], dict) and filedata[0].get("is_file", False):
        filedata = filedata[0]

    if isinstance(filedata, dict) and filedata.get("is_file", False):
        filename = filedata["name"]
        is_in_right_dir = ui_tempdir.check_tmp_file(shared.demo, filename)
        if not is_in_right_dir:
            # Raised explicitly instead of via `assert` so the check still runs under `python -O`;
            # the exception type is kept for backwards compatibility.
            raise AssertionError('trying to open image file outside of allowed directories')

        # Drop anything after the last '?' in the name before opening the file.
        filename = filename.rsplit('?', 1)[0]
        return Image.open(filename)

    if isinstance(filedata, list):
        if len(filedata) == 0:
            return None

        filedata = filedata[0]

    # Strip a data-URL header of any image type, not only PNG.
    if filedata.startswith("data:image/") and ";base64," in filedata:
        filedata = filedata.split(";base64,", 1)[1]

    filedata = base64.decodebytes(filedata.encode('utf-8'))
    image = Image.open(io.BytesIO(filedata))
    return image

101

102
def add_paste_fields(tabname, init_img, fields, override_settings_component=None):
    """Register the paste configuration of a tab in `paste_fields`.

    Args:
        tabname: key under which the configuration is stored.
        init_img: stored as the tab's "init_img" entry.
        fields: list of PasteField objects or argument tuples for PasteField; may be
            None or empty. Non-PasteField entries are converted in place.
        override_settings_component: stored as the tab's "override_settings_component" entry.
    """

    if fields:
        # Convert in place: the same list object is stored below and exposed through modules.ui.
        for index, field in enumerate(fields):
            if not isinstance(field, PasteField):
                fields[index] = PasteField(*field)

    paste_fields[tabname] = {"init_img": init_img, "fields": fields, "override_settings_component": override_settings_component}

    # backwards compatibility for existing extensions
    import modules.ui
    if tabname == 'txt2img':
        modules.ui.txt2img_paste_fields = fields
    elif tabname == 'img2img':
        modules.ui.img2img_paste_fields = fields
Y
yfszzx 已提交
117

118

Y
yfszzx 已提交
119 120 121
def create_buttons(tabs_list):
    """Create one "Send to <tab>" button per entry of `tabs_list`.

    Returns a dict mapping each tab name to its gr.Button; the button's
    elem_id is "<tab>_tab".
    """
    return {tab: gr.Button(f"Send to {tab}", elem_id=f"{tab}_tab") for tab in tabs_list}

125

Y
yfszzx 已提交
126
def bind_buttons(buttons, send_image, send_generate_info):
    """old function for backwards compatibility; do not use this, use register_paste_params_button

    Registers one ParamBinding per `(tabname, button)` item of `buttons`.
    `send_image` becomes the source image component. `send_generate_info` is
    used as the source text component when it is a gradio component, or as the
    source tab name when it is a string.
    """
    # Both values depend only on send_generate_info, so work them out once.
    source_text_component = send_generate_info if isinstance(send_generate_info, gr.components.Component) else None
    source_tabname = send_generate_info if isinstance(send_generate_info, str) else None

    for tabname, button in buttons.items():
        register_paste_params_button(ParamBinding(paste_button=button, tabname=tabname, source_text_component=source_text_component, source_image_component=send_image, source_tabname=source_tabname))


def register_paste_params_button(binding: ParamBinding) -> None:
    """Queue `binding` in `registered_param_bindings`; connect_paste_params_buttons() iterates that list to wire the buttons."""
    registered_param_bindings.append(binding)


def connect_paste_params_buttons():
    """Attach click handlers for every binding in `registered_param_bindings`.

    For each binding, up to four handlers are added to its paste button:
      1. send the source image (and, when the destination has a "Size-1"
         field, its dimensions) to the destination tab's image component;
      2. parse infotext from the source text component via connect_paste();
      3. copy selected field values from the source tab's fields;
      4. always: run the `switch_to_<tabname>` JavaScript function.
    """
    for binding in registered_param_bindings:
        destination = paste_fields[binding.tabname]
        destination_image_component = destination["init_img"]
        fields = destination["fields"]
        override_settings_component = binding.override_settings_component or destination["override_settings_component"]

        # First field registered under each size name, or None when the tab has no such field.
        destination_width_component = next((field for field, name in (fields or []) if name == "Size-1"), None)
        destination_height_component = next((field for field, name in (fields or []) if name == "Size-2"), None)

        if binding.source_image_component and destination_image_component:
            if isinstance(binding.source_image_component, gr.Gallery):
                func = send_image_and_dimensions if destination_width_component else image_from_url_text
                jsfunc = "extract_image_from_gallery"
            else:
                func = send_image_and_dimensions if destination_width_component else lambda x: x
                jsfunc = None

            binding.paste_button.click(
                fn=func,
                _js=jsfunc,
                inputs=[binding.source_image_component],
                outputs=[destination_image_component, destination_width_component, destination_height_component] if destination_width_component else [destination_image_component],
                show_progress=False,
            )

        if binding.source_text_component is not None and fields is not None:
            connect_paste(binding.paste_button, fields, binding.source_text_component, override_settings_component, binding.tabname)

        if binding.source_tabname is not None and fields is not None:
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
            # Values are passed through unchanged from the source tab's fields to the same-named destination fields.
            binding.paste_button.click(
                fn=lambda *x: x,
                inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
                outputs=[field for field, name in fields if name in paste_field_names],
                show_progress=False,
            )

        binding.paste_button.click(
            fn=None,
            _js=f"switch_to_{binding.tabname}",
            inputs=None,
            outputs=None,
            show_progress=False,
        )
Y
yfszzx 已提交
183

184

185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
def send_image_and_dimensions(x):
    """Return `(image, width, height)` for the image payload `x`.

    `x` is used directly when it already is a PIL image; otherwise it is
    converted with image_from_url_text(). Width and height are the image's own
    dimensions when the `send_size` option is on and an image was obtained;
    otherwise both are empty `gr.update()` objects.
    """
    img = x if isinstance(x, Image.Image) else image_from_url_text(x)

    if shared.opts.send_size and isinstance(img, Image.Image):
        return img, img.width, img.height

    return img, gr.update(), gr.update()


A
AUTOMATIC 已提交
201 202 203 204 205 206 207
def restore_old_hires_fix_params(res):
    """Convert the old "First pass size" infotext parameter into the current size fields.

    `res` is the parsed infotext dict; it is modified in place and nothing is returned.

    - When the `use_old_hires_fix_width_height` option is on and both
      "Hires resize-1" and "Hires resize-2" are non-zero, they are copied into
      "Size-1"/"Size-2" and nothing else is done.
    - Otherwise, if both "First pass size-1" and "First pass size-2" are present,
      the first pass size becomes "Size-1"/"Size-2" and the previous size
      (default 512) becomes "Hires resize-1"/"Hires resize-2".
    """

    firstpass_width = res.get('First pass size-1', None)
    firstpass_height = res.get('First pass size-2', None)

    if shared.opts.use_old_hires_fix_width_height:
        hires_width = int(res.get("Hires resize-1", 0))
        hires_height = int(res.get("Hires resize-2", 0))

        if hires_width and hires_height:
            res['Size-1'] = hires_width
            res['Size-2'] = hires_height
            return

    if firstpass_width is None or firstpass_height is None:
        return

    firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height)
    width = int(res.get("Size-1", 512))
    height = int(res.get("Size-2", 512))

    # A zero in either dimension means the first pass size has to be derived from the final size.
    if firstpass_width == 0 or firstpass_height == 0:
        firstpass_width, firstpass_height = processing.old_hires_fix_first_pass_dimensions(width, height)

    res['Size-1'] = firstpass_width
    res['Size-2'] = firstpass_height
    res['Hires resize-1'] = width
    res['Hires resize-2'] = height
A
AUTOMATIC 已提交
231 232


233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
girl with an artist's beret, determined, blue eyes, desert scene, computer monitors, heavy makeup, by Alphonse Mucha and Charlie Bowater, ((eyeshadow)), (coquettish), detailed, intricate
Negative prompt: ugly, fat, obese, chubby, (((deformed))), [blurry], bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), messy drawing
Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hash: 45dee52b
```

    returns a dict with field values

    Values taken from the text are kept as strings. "<w>x<h>" values are split
    into "<key>-1" and "<key>-2" entries. Defaults are filled in for a number of
    keys missing from the text, and keys listed in the `infotext_skip_pasting`
    option are removed from the result.
    """

    res = {}

    prompt = ""
    negative_prompt = ""

    done_with_prompt = False

    *lines, lastline = x.strip().split("\n")
    # The last line only counts as the parameter line when it holds at least three
    # "Key: value" pairs; otherwise it is treated as part of the prompt text.
    if len(re_param.findall(lastline)) < 3:
        lines.append(lastline)
        lastline = ''

    negative_prefix = "Negative prompt:"
    for line in lines:
        line = line.strip()
        if line.startswith(negative_prefix):
            done_with_prompt = True
            line = line[len(negative_prefix):].strip()
        if done_with_prompt:
            negative_prompt += ("" if negative_prompt == "" else "\n") + line
        else:
            prompt += ("" if prompt == "" else "\n") + line

    if shared.opts.infotext_styles != "Ignore":
        found_styles, prompt, negative_prompt = shared.prompt_styles.extract_styles_from_prompt(prompt, negative_prompt)

        if shared.opts.infotext_styles == "Apply":
            res["Styles array"] = found_styles
        elif shared.opts.infotext_styles == "Apply if any" and found_styles:
            res["Styles array"] = found_styles

    res["Prompt"] = prompt
    res["Negative prompt"] = negative_prompt

    for k, v in re_param.findall(lastline):
        try:
            # An empty value raises IndexError here; it is reported and skipped by the handler below.
            if v[0] == '"' and v[-1] == '"':
                v = unquote(v)

            m = re_imagesize.match(v)
            if m is not None:
                res[f"{k}-1"] = m.group(1)
                res[f"{k}-2"] = m.group(2)
            else:
                res[k] = v
        except Exception:
            print(f"Error parsing \"{k}: {v}\"")

    # Missing CLIP skip means it was set to 1 (the default)
    res.setdefault("Clip skip", "1")

    hypernet = res.get("Hypernet", None)
    if hypernet is not None:
        res["Prompt"] += f"""<hypernet:{hypernet}:{res.get("Hypernet strength", "1.0")}>"""

    # Both dimensions are reset together whenever the first one is absent.
    if "Hires resize-1" not in res:
        res["Hires resize-1"] = 0
        res["Hires resize-2"] = 0

    hires_defaults = (
        ("Hires sampler", "Use same sampler"),
        ("Hires checkpoint", "Use same checkpoint"),
        ("Hires prompt", ""),
        ("Hires negative prompt", ""),
    )
    for key, default in hires_defaults:
        res.setdefault(key, default)

    restore_old_hires_fix_params(res)

    late_defaults = (
        # Missing RNG means the default was set, which is GPU RNG
        ("RNG", "GPU"),
        ("Schedule type", "Automatic"),
        ("Schedule max sigma", 0),
        ("Schedule min sigma", 0),
        ("Schedule rho", 0),
        ("VAE Encoder", "Full"),
        ("VAE Decoder", "Full"),
        ("FP8 weight", "Disable"),
    )
    for key, default in late_defaults:
        res.setdefault(key, default)

    if "Cache FP16 weight for LoRA" not in res and res["FP8 weight"] != "Disable":
        res["Cache FP16 weight for LoRA"] = False

    infotext_versions.backcompat(res)

    skip = set(shared.opts.infotext_skip_pasting)
    res = {k: v for k, v in res.items() if k not in skip}

    return res

352

353
infotext_to_setting_name_mapping = [

]
"""Mapping of infotext labels to setting names. Only left for backwards compatibility - use OptionInfo(..., infotext='...') instead.
Example content:

infotext_to_setting_name_mapping = [
    ('Conditional mask weight', 'inpainting_mask_weight'),
    ('Model hash', 'sd_model_checkpoint'),
    ('ENSD', 'eta_noise_seed_delta'),
    ('Schedule type', 'k_sched_type'),
]
"""
366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385


def create_override_settings_dict(text_pairs):
    """creates processing's override_settings parameters from gradio's multiselect

    Example input:
        ['Clip skip: 2', 'Model hash: e6e99610c4', 'ENSD: 31337']

    Example output:
        {'CLIP_stop_at_last_layers': 2, 'sd_model_checkpoint': 'e6e99610c4', 'eta_noise_seed_delta': 31337}
    """

    res = {}

    params = {}
    for pair in text_pairs:
        k, v = pair.split(":", maxsplit=1)

        params[k] = v.strip()

386 387
    mapping = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]
    for param_name, setting_name in mapping + infotext_to_setting_name_mapping:
388 389 390 391 392 393 394 395 396 397
        value = params.get(param_name, None)

        if value is None:
            continue

        res[setting_name] = shared.opts.cast_value(setting_name, value)

    return res


398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439
def get_override_settings(params, *, skip_fields=None):
    """Collect the settings overrides implied by a parsed infotext dictionary.

    Every infotext label known to `shared.opts` (plus the legacy
    `infotext_to_setting_name_mapping` entries) is looked up in `params`. A
    label produces an override unless:
    - it is contained in `skip_fields`;
    - `params` has no value for it;
    - it maps to "sd_model_checkpoint" while `disable_weights_auto_swap` is on;
    - its value, cast to the setting's type, equals the current setting value.

    Returns a list of `(infotext label, setting name, cast value)` tuples.

    Example input:
        {"Clip skip": "2"}

    Example output:
        [("Clip skip", "CLIP_stop_at_last_layers", 2)]
    """

    ignored_labels = skip_fields or {}
    known_pairs = [(info.infotext, k) for k, info in shared.opts.data_labels.items() if info.infotext]

    overrides = []
    for infotext_label, option_key in known_pairs + infotext_to_setting_name_mapping:
        if infotext_label in ignored_labels:
            continue

        raw_value = params.get(infotext_label, None)
        if raw_value is None:
            continue

        if option_key == "sd_model_checkpoint" and shared.opts.disable_weights_auto_swap:
            continue

        typed_value = shared.opts.cast_value(option_key, raw_value)
        if typed_value == getattr(shared.opts, option_key, None):
            continue

        overrides.append((infotext_label, option_key, typed_value))

    return overrides


440
def connect_paste(button, paste_fields, input_comp, override_settings_component, tabname):
    """Wire `button` so that a click parses infotext from `input_comp` and fills the given fields.

    Args:
        button: component whose click events are connected.
        paste_fields: list of `(component, key)` pairs; `key` is either a name looked
            up in the parsed parameters or a callable that receives them.
        input_comp: component whose value is the infotext to parse.
        override_settings_component: when not None, additionally receives the
            settings overrides that no paste field already handles.
        tabname: used to build the name of the `recalculate_prompts_<tabname>`
            JavaScript function run on click.
    """

    def paste_func(prompt):
        # With nothing to paste, fall back to the contents of params.txt in data_path,
        # unless the hide_ui_dir_config command line option is set.
        if not prompt and not shared.cmd_opts.hide_ui_dir_config:
            filename = os.path.join(data_path, "params.txt")
            if os.path.exists(filename):
                with open(filename, "r", encoding="utf8") as file:
                    prompt = file.read()

        # The parser calls str methods on its input, so a missing value is treated as empty text.
        if prompt is None:
            prompt = ""

        params = parse_generation_parameters(prompt)
        script_callbacks.infotext_pasted_callback(prompt, params)
        res = []

        for output, key in paste_fields:
            v = key(params) if callable(key) else params.get(key, None)

            if v is None:
                res.append(gr.update())
            elif isinstance(v, type_of_gr_update):
                res.append(v)
            else:
                try:
                    # Convert the parsed text to the type of the component's current value.
                    valtype = type(output.value)

                    # bool("False") would be True, so that text is handled explicitly.
                    if valtype == bool and v == "False":
                        val = False
                    else:
                        val = valtype(v)

                    res.append(gr.update(value=val))
                except Exception:
                    # Values that cannot be converted leave the component unchanged.
                    res.append(gr.update())

        return res

    if override_settings_component is not None:
        already_handled_fields = {key for _, key in paste_fields}

        def paste_settings(params):
            vals = get_override_settings(params, skip_fields=already_handled_fields)

            vals_pairs = [f"{infotext_text}: {value}" for infotext_text, setting_name, value in vals]

            return gr.Dropdown.update(value=vals_pairs, choices=vals_pairs, visible=bool(vals_pairs))

        # Rebinding the name is deliberate: paste_func reads it at click time and so also fills this component.
        paste_fields = paste_fields + [(override_settings_component, paste_settings)]

    button.click(
        fn=paste_func,
        inputs=[input_comp],
        outputs=[x[0] for x in paste_fields],
        show_progress=False,
    )
    button.click(
        fn=None,
        _js=f"recalculate_prompts_{tabname}",
        inputs=[],
        outputs=[],
        show_progress=False,
    )
502