import datetime
import io
import math
import os
import re
import string
import sys
import traceback
from collections import namedtuple

import numpy as np
import piexif
import piexif.helper
import pytz
from fonts.ttf import Roboto
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin

from modules import sd_samplers, shared, script_callbacks
from modules.shared import opts, cmd_opts

# Lanczos resampling filter: taken from Image.Resampling when that namespace exists,
# otherwise from the Image module itself.
LANCZOS = getattr(Image, 'Resampling', Image).LANCZOS


def image_grid(imgs, batch_size=1, rows=None):
    """Paste a sequence of images into one grid image, filled row by row.

    Args:
        imgs: Non-empty sequence of PIL images. The size of the first image is
            used as the cell size for every image.
        batch_size: Number of rows used when ``opts.n_rows`` is 0.
        rows: Explicit number of rows. When ``None`` it is derived from
            ``opts.n_rows`` (positive: use as is; 0: use ``batch_size``;
            negative: roughly square, optionally avoiding empty cells when
            ``opts.grid_prevent_empty_spots`` is set).

    Returns:
        A new RGB image with a black background containing all images.
    """
    if rows is None:
        if opts.n_rows > 0:
            rows = opts.n_rows
        elif opts.n_rows == 0:
            rows = batch_size
        elif opts.grid_prevent_empty_spots:
            # largest divisor of len(imgs) not exceeding its square root
            rows = math.floor(math.sqrt(len(imgs)))
            while len(imgs) % rows != 0:
                rows -= 1
        else:
            rows = math.sqrt(len(imgs))
            rows = round(rows)

    # more rows than images would only append empty black rows to the grid
    if rows > len(imgs):
        rows = len(imgs)

    cols = math.ceil(len(imgs) / rows)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid


# Tiled view of an image: the tile rows plus the tile size, source image size and tile overlap.
Grid = namedtuple("Grid", "tiles tile_w tile_h image_w image_h overlap")


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut an image into overlapping tiles of a fixed size.

    Args:
        image: Source PIL image.
        tile_w: Width of every tile in pixels.
        tile_h: Height of every tile in pixels.
        overlap: Minimum number of pixels neighbouring tiles share.

    Returns:
        A ``Grid`` whose ``tiles`` is a list of ``[y, tile_h, row]`` entries,
        where ``row`` is a list of ``[x, tile_w, tile_image]`` entries.
    """
    w = image.width
    h = image.height

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap

    cols = math.ceil((w - overlap) / non_overlap_width)
    rows = math.ceil((h - overlap) / non_overlap_height)

    # step between tile origins so that the tiles span the image exactly
    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = int(row * dy)

        # keep the last row inside the image
        # NOTE(review): when the image is smaller than one tile this makes y (and x below) negative - confirm intended
        if y + tile_h >= h:
            y = h - tile_h

        for col in range(cols):
            x = int(col * dx)

            if x + tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid


def combine_grid(grid):
    """Reassemble a ``Grid`` into one image, blending tiles across their overlap.

    Args:
        grid: A ``Grid`` with the layout produced by ``split_grid``.

    Returns:
        A new RGB image of size ``(grid.image_w, grid.image_h)``.
    """
    def make_mask_image(r):
        # scale a 0..overlap-1 ramp to 8-bit values and wrap it as a greyscale image
        r = r * 255 / grid.overlap
        r = r.astype(np.uint8)
        return Image.fromarray(r, 'L')

    # left-to-right ramp for blending tiles within a row, top-to-bottom ramp for blending rows
    mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
    mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))

    combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
    for y, h, row in grid.tiles:
        combined_row = Image.new("RGB", (grid.image_w, h))
        for x, w, tile in row:
            if x == 0:
                combined_row.paste(tile, (0, 0))
                continue

            # blend the overlapping strip, then paste the remainder of the tile unmasked
            combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
            combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))

        if y == 0:
            combined_image.paste(combined_row, (0, 0))
            continue

        combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
        combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))

    return combined_image


class GridAnnotation:
    """A piece of label text for a grid, flagged as active or inactive."""

    def __init__(self, text='', is_active=True):
        # size starts unset; it is assigned by code that measures the text
        self.text, self.is_active, self.size = text, is_active, None


def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
    """Add column headers and row labels around a grid image.

    Args:
        im: Grid image made of cells of size ``width`` x ``height``.
        width: Width of one grid cell in pixels.
        height: Height of one grid cell in pixels.
        hor_texts: One list of ``GridAnnotation`` per column. The lists are
            modified in place: entries are replaced by word-wrapped lines.
        ver_texts: One list of ``GridAnnotation`` per row; modified in place
            in the same way.

    Returns:
        A new white RGB image with ``im`` pasted below/right of the labels.

    Raises:
        AssertionError: if the number of text lists does not match the number
            of columns/rows computed from the image and cell size.
    """
    def wrap(drawing, text, font, line_length):
        # greedy word wrap: keep appending words while the rendered line fits
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def draw_texts(drawing, draw_x, draw_y, lines):
        for line in lines:
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")

            # inactive lines are additionally struck through
            if not line.is_active:
                drawing.line((draw_x - line.size[0] // 2, draw_y + line.size[1] // 2, draw_x + line.size[0] // 2, draw_y + line.size[1] // 2), fill=color_inactive, width=4)

            draw_y += line.size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2

    # fall back to the bundled Roboto font when the configured font cannot be loaded
    try:
        fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
    except Exception:
        fnt = ImageFont.truetype(Roboto, fontsize)

    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    # no left margin when every vertical label is empty
    pad_left = 0 if sum([sum([len(line.text) for line in lines]) for lines in ver_texts]) == 0 else width * 3 // 4

    cols = im.width // width
    rows = im.height // height

    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    # scratch canvas used only to measure text
    calc_img = Image.new("RGB", (1, 1), "white")
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
        items = [] + texts
        texts.clear()

        for line in items:
            wrapped = wrap(calc_d, line.text, fnt, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]

        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])

    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in
                        ver_texts]

    pad_top = max(hor_text_heights) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))

    d = ImageDraw.Draw(result)

    for col in range(cols):
        x = pad_left + width * col + width / 2
        y = pad_top / 2 - hor_text_heights[col] / 2

        draw_texts(d, x, y, hor_texts[col])

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2

        draw_texts(d, x, y, ver_texts[row])

    return result


def draw_prompt_matrix(im, width, height, all_prompts):
    """Annotate a prompt-matrix grid with the prompt parts used in each cell.

    The first entry of ``all_prompts`` is skipped; the remaining entries are
    split in half (the first half gets the extra entry when the count is odd)
    into column labels and row labels. For column/row index ``pos``, the
    i-th label is marked active when bit ``i`` of ``pos`` is set.

    Args:
        im: Grid image made of cells of size ``width`` x ``height``.
        width: Width of one grid cell in pixels.
        height: Height of one grid cell in pixels.
        all_prompts: Sequence of prompt strings.

    Returns:
        The image returned by ``draw_grid_annotations``.
    """
    prompts = all_prompts[1:]
    boundary = math.ceil(len(prompts) / 2)

    prompts_horiz = prompts[:boundary]
    prompts_vert = prompts[boundary:]

    hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
    ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]

    return draw_grid_annotations(im, width, height, hor_texts, ver_texts)


def resize_image(resize_mode, im, width, height):
    """Resize an image to ``width`` x ``height`` using one of three strategies.

    Args:
        resize_mode: 0 - stretch the image to the target size, ignoring aspect
            ratio; 1 - scale the image (keeping aspect ratio) so it covers the
            whole target and paste it centered, cropping the excess; any other
            value - scale the image (keeping aspect ratio) so it fits inside
            the target, paste it centered and fill the remaining borders by
            stretching the image's outermost rows/columns.
        im: Source PIL image.
        width: Target width in pixels.
        height: Target height in pixels.

    Returns:
        The resized image. For modes other than 0 this is a new RGB image.
    """
    def resize(im, w, h):
        # plain Lanczos resize when no upscaler is configured or the image is greyscale
        if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
            return im.resize((w, h), resample=LANCZOS)

        scale = max(w / im.width, h / im.height)

        # the configured upscaler is only used when the image actually has to grow
        if scale > 1.0:
            upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img]
            assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}"

            upscaler = upscalers[0]
            im = upscaler.scaler.upscale(im, scale, upscaler.data_path)

        if im.width != w or im.height != h:
            im = im.resize((w, h), resample=LANCZOS)

        return im

    if resize_mode == 0:
        res = resize(im, width, height)

    elif resize_mode == 1:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    else:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            fill_height = height // 2 - src_h // 2
            # integer rounding can leave no border at all; a zero-sized resize target is invalid
            if fill_height > 0:
                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            fill_width = width // 2 - src_w // 2
            if fill_width > 0:
                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res


# characters that are replaced with "_" inside a filename part
invalid_filename_chars = '<>:"/\\|?*\n'
# characters stripped from the start / from the end of a filename part
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
# a run of whitespace and/or ASCII punctuation; used to split a prompt into words
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
# literal text followed by either a "[pattern]" token or the end of the string
re_pattern = re.compile(r"(.*?)(?:\[([^\[\]]+)\]|$)")
# a trailing "<argument>" attached to a pattern name, e.g. "datetime<%Y>"
re_pattern_arg = re.compile(r"(.*)<([^>]*)>$")
# filename parts are truncated to this many characters
max_filename_part_length = 128


def sanitize_filename_part(text, replace_spaces=True):
    """Make a string safe to use as one component of a file name.

    Args:
        text: The text to sanitize, or ``None``.
        replace_spaces: When true, spaces are replaced with underscores first.

    Returns:
        ``None`` if ``text`` is ``None``; otherwise the text with every
        character from ``invalid_filename_chars`` replaced by ``_``, leading
        ``invalid_filename_prefix`` characters removed, truncated to
        ``max_filename_part_length`` characters, and trailing
        ``invalid_filename_postfix`` characters removed.
    """
    if text is None:
        return None

    if replace_spaces:
        text = text.replace(' ', '_')

    text = text.translate({ord(x): '_' for x in invalid_filename_chars})
    # truncate before stripping the end so the result never ends in a space or dot
    text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
    text = text.rstrip(invalid_filename_postfix)
    return text


class FilenameGenerator:
    """Expands ``[pattern]`` tokens in a file or directory name template.

    Each key of ``replacements`` is a pattern name; its value is a callable
    that receives the generator (plus any ``<argument>`` values attached to
    the token) and returns the replacement, or ``None`` to leave the token
    in the output as ``[pattern]``.
    """

    replacements = {
        'seed': lambda self: self.seed if self.seed is not None else '',
        'steps': lambda self:  self.p and self.p.steps,
        'cfg': lambda self: self.p and self.p.cfg_scale,
        'width': lambda self: self.image.width,
        'height': lambda self: self.image.height,
        'styles': lambda self: self.p and sanitize_filename_part(", ".join([style for style in self.p.styles if not style == "None"]) or "None", replace_spaces=False),
        'sampler': lambda self: self.p and sanitize_filename_part(sd_samplers.samplers[self.p.sampler_index].name, replace_spaces=False),
        'model_hash': lambda self: getattr(self.p, "sd_model_hash", shared.sd_model.sd_model_hash),
        'date': lambda self: datetime.datetime.now().strftime('%Y-%m-%d'),
        'datetime': lambda self, *args: self.datetime(*args),  # accepts formats: [datetime], [datetime<Format>], [datetime<Format><Time Zone>]
        'job_timestamp': lambda self: getattr(self.p, "job_timestamp", shared.state.job_timestamp),
        'prompt': lambda self: sanitize_filename_part(self.prompt),
        'prompt_no_styles': lambda self: self.prompt_no_style(),
        'prompt_spaces': lambda self: sanitize_filename_part(self.prompt, replace_spaces=False),
        'prompt_words': lambda self: self.prompt_words(),
    }
    default_time_format = '%Y%m%d%H%M%S'

    def __init__(self, p, seed, prompt, image):
        """Store the processing object, seed, prompt and image the patterns read from."""
        self.p = p
        self.seed = seed
        self.prompt = prompt
        self.image = image

    def prompt_no_style(self):
        """Return the sanitized prompt with the text of the applied styles removed.

        Returns ``None`` when there is no processing object or no prompt.
        """
        if self.p is None or self.prompt is None:
            return None

        prompt_no_style = self.prompt
        for style in shared.prompt_styles.get_style_prompts(self.p.styles):
            if len(style) > 0:
                # a style may wrap the prompt via a "{prompt}" placeholder; remove each side separately
                for part in style.split("{prompt}"):
                    prompt_no_style = prompt_no_style.replace(part, "").replace(", ,", ",").strip().strip(',')

                prompt_no_style = prompt_no_style.replace(style, "").strip().strip(',').strip()

        return sanitize_filename_part(prompt_no_style, replace_spaces=False)

    def prompt_words(self):
        """Return the first ``opts.directories_max_prompt_words`` words of the prompt, space-joined.

        Uses the single word ``empty`` when the prompt contains no words.
        """
        words = [x for x in re_nonletters.split(self.prompt or "") if len(x) > 0]
        if len(words) == 0:
            words = ["empty"]
        return sanitize_filename_part(" ".join(words[0:opts.directories_max_prompt_words]), replace_spaces=False)

    def datetime(self, *args):
        """Return the current time formatted for use in a file name.

        Args:
            *args: Optional ``strftime`` format (empty string means default),
                optionally followed by a time zone name understood by ``pytz``.
                An unknown time zone is ignored; a format that ``strftime``
                rejects falls back to ``default_time_format``.
        """
        time_datetime = datetime.datetime.now()

        time_format = args[0] if len(args) > 0 and args[0] != "" else self.default_time_format
        try:
            time_zone = pytz.timezone(args[1]) if len(args) > 1 else None
        except pytz.exceptions.UnknownTimeZoneError:
            time_zone = None

        time_zone_time = time_datetime.astimezone(time_zone)
        try:
            formatted_time = time_zone_time.strftime(time_format)
        except (ValueError, TypeError):
            formatted_time = time_zone_time.strftime(self.default_time_format)

        return sanitize_filename_part(formatted_time, replace_spaces=False)

    def apply(self, x):
        """Expand every ``[pattern]`` token in ``x`` and return the resulting string.

        Unknown patterns, patterns whose replacement is ``None`` and patterns
        whose replacement function raises are written back as ``[pattern]``
        (without any ``<argument>`` parts); errors are reported on stderr.
        """
        res = ''

        for m in re_pattern.finditer(x):
            text, pattern = m.groups()
            res += text

            if pattern is None:
                continue

            # peel trailing "<arg>" groups off the pattern, keeping their original order
            pattern_args = []
            while True:
                arg_match = re_pattern_arg.match(pattern)
                if arg_match is None:
                    break

                pattern, arg = arg_match.groups()
                pattern_args.insert(0, arg)

            fun = self.replacements.get(pattern.lower())
            if fun is not None:
                try:
                    replacement = fun(self, *pattern_args)
                except Exception:
                    replacement = None
                    print(f"Error adding [{pattern}] to filename", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)

                if replacement is not None:
                    res += str(replacement)
                    continue

            res += f'[{pattern}]'

        return res


def get_next_sequence_number(path, basename):
    """
    Determines and returns the next sequence number to use when saving an image in the specified directory.

    Only entries whose name starts with ``basename`` followed by ``-`` are
    considered (every entry when ``basename`` is empty). The number is read
    from the text between that prefix and the next ``-`` or the extension;
    entries where this is not an integer are ignored.

    The sequence starts at 0.
    """
    result = -1
    if basename != '':
        basename = basename + "-"

    prefix_length = len(basename)
    for p in os.listdir(path):
        if p.startswith(basename):
            # drop the prefix and the extension so the sequence number is always the first "-"-separated element
            parts = os.path.splitext(p[prefix_length:])[0].split('-')
            try:
                result = max(int(parts[0]), result)
            except ValueError:
                pass

    return result + 1


def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix="", save_to_dirs=None):
    """Save an image.

    Args:
        image (`PIL.Image`):
            The image to be saved.
        path (`str`):
            The directory to save the image. Note, the option `save_to_dirs` will make the image to be saved into a sub directory.
        basename (`str`):
            The base filename which will be applied to `filename pattern`.
        seed, prompt:
            Values made available to the filename pattern. When `seed` is None no pattern is appended to the file name.
        short_filename (`bool`):
            If true, no filename pattern is appended to the file name.
        extension (`str`):
            Image file extension, default is `png`.
        pnginfo_section_name (`str`):
            Specify the name of the section which `info` will be saved in.
        info (`str` or `PngImagePlugin.iTXt`):
            PNG info chunks.
        existing_info (`dict`):
            Additional PNG info. `existing_info == {pnginfo_section_name: info, ...}`. The dict is copied, not modified.
        no_prompt (`bool`):
            Only used when `save_to_dirs` is None: if true, a non-grid image is not saved into a subdirectory.
        grid (`bool`):
            Only used when `save_to_dirs` is None: selects `opts.grid_save_to_dirs` instead of `opts.save_to_dirs`.
        p (`StableDiffusionProcessing`)
        forced_filename (`str`):
            If specified, `basename` and filename pattern will be ignored.
        suffix (`str`):
            Text appended after the expanded filename pattern.
        save_to_dirs (bool):
            If true, the image will be saved into a subdirectory of `path`.

    Returns: (fullfn, txt_fullfn)
        fullfn (`str`):
            The full path of the saved imaged.
        txt_fullfn (`str` or None):
            If a text file is saved for this image, this will be its full path. Otherwise None.
    """
    namegen = FilenameGenerator(p, seed, prompt, image)

    if save_to_dirs is None:
        save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

    if save_to_dirs:
        dirname = namegen.apply(opts.directories_filename_pattern or "[prompt_words]").lstrip(' ').rstrip('\\ /')
        path = os.path.join(path, dirname)

    os.makedirs(path, exist_ok=True)

    if forced_filename is None:
        if short_filename or seed is None:
            file_decoration = ""
        elif opts.save_to_dirs:
            file_decoration = opts.samples_filename_pattern or "[seed]"
        else:
            file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"

        # without any decoration the sequence number is the only thing keeping names unique
        add_number = opts.save_images_add_number or file_decoration == ''

        if file_decoration != "" and add_number:
            file_decoration = "-" + file_decoration

        file_decoration = namegen.apply(file_decoration) + suffix

        if add_number:
            basecount = get_next_sequence_number(path, basename)
            fullfn = None
            # probe successive numbers until an unused file name is found
            for i in range(500):
                fn = f"{basecount + i:05}" if basename == '' else f"{basename}-{basecount + i:04}"
                fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
                if not os.path.exists(fullfn):
                    break
        else:
            fullfn = os.path.join(path, f"{file_decoration}.{extension}")
    else:
        fullfn = os.path.join(path, f"{forced_filename}.{extension}")

    # work on a copy so the caller's dict is not modified
    pnginfo = dict(existing_info) if existing_info else {}
    if info is not None:
        pnginfo[pnginfo_section_name] = info

    # callbacks may replace the image, the file name and the info
    params = script_callbacks.ImageSaveParams(image, p, fullfn, pnginfo)
    script_callbacks.before_image_saved_callback(params)

    image = params.image
    fullfn = params.filename
    info = params.pnginfo.get(pnginfo_section_name, None)
    fullfn_without_extension, extension = os.path.splitext(params.filename)

    def exif_bytes():
        return piexif.dump({
            "Exif": {
                piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
            },
        })

    if extension.lower() == '.png':
        pnginfo_data = PngImagePlugin.PngInfo()
        if opts.enable_pnginfo:
            for k, v in params.pnginfo.items():
                pnginfo_data.add_text(k, str(v))

        image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo_data)

    elif extension.lower() in (".jpg", ".jpeg", ".webp"):
        image.save(fullfn, quality=opts.jpeg_quality)

        if opts.enable_pnginfo and info is not None:
            piexif.insert(exif_bytes(), fullfn)
    else:
        image.save(fullfn, quality=opts.jpeg_quality)

    # optionally write an extra JPG copy when the image is too large in pixels or bytes
    target_side_length = 4000
    oversize = image.width > target_side_length or image.height > target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
        ratio = image.width / image.height

        if oversize and ratio > 1:
            image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
        elif oversize:
            image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)

        image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
        if opts.enable_pnginfo and info is not None:
            piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")

    if opts.save_txt and info is not None:
        txt_fullfn = f"{fullfn_without_extension}.txt"
        with open(txt_fullfn, "w", encoding="utf8") as file:
            file.write(info + "\n")
    else:
        txt_fullfn = None

    script_callbacks.image_saved_callback(params)

    return fullfn, txt_fullfn


def image_data(data):
    """Extract generation parameters text from raw uploaded bytes.

    The data is first tried as an image carrying a ``parameters`` text entry;
    if that fails it is tried as UTF-8 text shorter than 10000 characters.
    Both attempts are best effort: any failure falls through to the next step.

    Args:
        data: Raw bytes of an image or text file.

    Returns:
        A tuple ``(text, None)``; ``text`` is an empty string when nothing
        could be extracted.
    """
    try:
        image = Image.open(io.BytesIO(data))
        textinfo = image.text["parameters"]
        return textinfo, None
    except Exception:
        pass

    try:
        text = data.decode('utf8')
        # explicit check instead of assert so the limit still applies under python -O
        if len(text) < 10000:
            return text, None

    except Exception:
        pass

    return '', None