import datetime
import math
import os
import re
import string
from collections import namedtuple

import numpy as np
import piexif
import piexif.helper
from fonts.ttf import Roboto
from PIL import Image, ImageFont, ImageDraw, PngImagePlugin

import modules.shared
from modules import sd_samplers, shared
from modules.shared import opts, cmd_opts

# Use the Lanczos filter from Image.Resampling when that namespace exists, otherwise from Image itself.
LANCZOS = getattr(Image, 'Resampling', Image).LANCZOS


def image_grid(imgs, batch_size=1, rows=None):
    """Paste a sequence of images into a single grid image.

    Args:
        imgs: non-empty sequence of images; the size of the first one is used
            as the cell size for the whole grid.
        batch_size: row count used when ``opts.n_rows`` is 0.
        rows: explicit row count. When None it is derived from ``opts.n_rows``:
            a positive value is used as is, zero selects ``batch_size``, and a
            negative value selects the rounded square root of the image count.

    Returns:
        A new RGB image with the inputs laid out left to right, top to bottom,
        on a black background.

    Raises:
        ValueError: if ``imgs`` is empty.
    """
    if len(imgs) == 0:
        raise ValueError("image_grid requires at least one image")

    if rows is None:
        if opts.n_rows > 0:
            rows = opts.n_rows
        elif opts.n_rows == 0:
            rows = batch_size
        else:
            rows = round(math.sqrt(len(imgs)))

    # a row count below 1 would make the column computation divide by zero
    rows = max(1, rows)

    cols = math.ceil(len(imgs) / rows)

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h), color='black')

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    return grid


# Result of split_grid: ``tiles`` is a list of rows, each ``[y, tile_h, row_tiles]``,
# where ``row_tiles`` is a list of ``[x, tile_w, tile_image]``.
Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])


def split_grid(image, tile_w=512, tile_h=512, overlap=64):
    """Cut an image into overlapping tiles of a fixed size.

    Args:
        image: object exposing ``width``, ``height`` and ``crop(box)``.
        tile_w: width of every tile in pixels.
        tile_h: height of every tile in pixels.
        overlap: minimum number of pixels neighbouring tiles share; must be
            smaller than both tile dimensions.

    Returns:
        A ``Grid`` whose ``tiles`` holds one ``[y, tile_h, row_tiles]`` entry
        per row, with ``row_tiles`` a list of ``[x, tile_w, tile]``.

    Raises:
        ValueError: if ``overlap`` is not smaller than ``tile_w`` and ``tile_h``.
    """
    if overlap >= tile_w or overlap >= tile_h:
        raise ValueError(f"overlap ({overlap}) must be smaller than the tile size ({tile_w}x{tile_h})")

    w = image.width
    h = image.height

    non_overlap_width = tile_w - overlap
    non_overlap_height = tile_h - overlap

    cols = math.ceil((w - overlap) / non_overlap_width)
    rows = math.ceil((h - overlap) / non_overlap_height)

    # distance between tile origins, chosen so first and last tile touch the image edges
    dx = (w - tile_w) / (cols - 1) if cols > 1 else 0
    dy = (h - tile_h) / (rows - 1) if rows > 1 else 0

    grid = Grid([], tile_w, tile_h, w, h, overlap)
    for row in range(rows):
        row_images = []

        y = int(row * dy)

        # a tile reaching the bottom edge is shifted so it ends exactly on it
        # (this yields a negative offset when the image is smaller than a tile)
        if y + tile_h >= h:
            y = h - tile_h

        for col in range(cols):
            x = int(col * dx)

            # same adjustment for the right edge
            if x + tile_w >= w:
                x = w - tile_w

            tile = image.crop((x, y, x + tile_w, y + tile_h))

            row_images.append([x, tile_w, tile])

        grid.tiles.append([y, tile_h, row_images])

    return grid


def combine_grid(grid):
    """Reassemble the tiles of a Grid into one image.

    The first tile of each row and the first row are pasted as they are; every
    other tile or row has its leading ``grid.overlap`` pixels pasted through a
    0-to-255 ramp mask and the remainder pasted directly.

    Returns:
        A new RGB image of size ``(grid.image_w, grid.image_h)``.
    """
    overlap = grid.overlap

    def ramp_mask(values):
        # scale the 0..overlap-1 ramp to 8-bit mask values
        scaled = values * 255 / overlap
        return Image.fromarray(scaled.astype(np.uint8), 'L')

    ramp = np.arange(overlap, dtype=np.float32)
    horizontal_fade = ramp_mask(ramp.reshape((1, overlap)).repeat(grid.tile_h, axis=0))
    vertical_fade = ramp_mask(ramp.reshape((overlap, 1)).repeat(grid.image_w, axis=1))

    canvas = Image.new("RGB", (grid.image_w, grid.image_h))
    for top, row_h, row_tiles in grid.tiles:
        strip = Image.new("RGB", (grid.image_w, row_h))
        for left, tile_w, tile in row_tiles:
            if left == 0:
                strip.paste(tile, (0, 0))
            else:
                strip.paste(tile.crop((0, 0, overlap, row_h)), (left, 0), mask=horizontal_fade)
                strip.paste(tile.crop((overlap, 0, tile_w, row_h)), (left + overlap, 0))

        if top == 0:
            canvas.paste(strip, (0, 0))
        else:
            canvas.paste(strip.crop((0, 0, strip.width, overlap)), (0, top), mask=vertical_fade)
            canvas.paste(strip.crop((0, overlap, strip.width, row_h)), (0, top + overlap))

    return canvas


class GridAnnotation:
    """A piece of grid label text together with its active flag and size."""

    def __init__(self, text='', is_active=True):
        # size starts out unset (None)
        self.text, self.is_active, self.size = text, is_active, None


def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
    """Add column headers and row labels around a grid image.

    Args:
        im: the grid image; it is treated as ``im.width // width`` columns by
            ``im.height // height`` rows.
        width: width of one grid cell in pixels.
        height: height of one grid cell in pixels.
        hor_texts: one list of GridAnnotation per column.
        ver_texts: one list of GridAnnotation per row.

    Both text lists are modified in place: each annotation is replaced by one
    annotation per wrapped line, with ``size`` filled in.

    Returns:
        A new RGB image on a white background with ``im`` pasted below the
        column headers and to the right of the row labels.

    Raises:
        AssertionError: if the number of text lists does not match the number
            of columns or rows.
    """
    def wrap(drawing, text, font, line_length):
        # greedy word wrap: keep appending words to the last line while it fits
        lines = ['']
        for word in text.split():
            line = f'{lines[-1]} {word}'.strip()
            if drawing.textlength(line, font=font) <= line_length:
                lines[-1] = line
            else:
                lines.append(word)
        return lines

    def draw_texts(drawing, draw_x, draw_y, lines):
        for line in lines:
            drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")

            if not line.is_active:
                # inactive text gets a horizontal line drawn through its middle
                drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)

            draw_y += line.size[1] + line_spacing

    fontsize = (width + height) // 25
    line_spacing = fontsize // 2

    try:
        fnt = ImageFont.truetype(opts.font or Roboto, fontsize)
    except Exception:
        # the configured font could not be loaded; fall back to the bundled one
        fnt = ImageFont.truetype(Roboto, fontsize)

    color_active = (0, 0, 0)
    color_inactive = (153, 153, 153)

    # no left margin when every row label is empty
    has_ver_text = any(len(line.text) > 0 for lines in ver_texts for line in lines)
    pad_left = width * 3 // 4 if has_ver_text else 0

    cols = im.width // width
    rows = im.height // height

    assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
    assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'

    # scratch drawing context used only for measuring text
    calc_img = Image.new("RGB", (1, 1), "white")
    calc_d = ImageDraw.Draw(calc_img)

    for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
        items = list(texts)
        texts.clear()

        for line in items:
            wrapped = wrap(calc_d, line.text, fnt, allowed_width)
            texts += [GridAnnotation(x, line.is_active) for x in wrapped]

        for line in texts:
            bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
            line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])

    hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
    ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]

    # default=0 keeps this from raising when there are no columns at all
    pad_top = max(hor_text_heights, default=0) + line_spacing * 2

    result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
    result.paste(im, (pad_left, pad_top))

    d = ImageDraw.Draw(result)

    for col in range(cols):
        x = pad_left + width * col + width / 2
        y = pad_top / 2 - hor_text_heights[col] / 2

        draw_texts(d, x, y, hor_texts[col])

    for row in range(rows):
        x = pad_left / 2
        y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2

        draw_texts(d, x, y, ver_texts[row])

    return result


def draw_prompt_matrix(im, width, height, all_prompts):
    """Annotate a prompt-matrix grid with the prompt parts used in each cell.

    The first entry of ``all_prompts`` is skipped. The remaining parts are
    split in two: the first half (rounded up) labels the columns, the rest
    labels the rows. Returns the result of ``draw_grid_annotations``.
    """
    def annotations_for(parts):
        # one list per bitmask value; bit i set means parts[i] is active
        return [
            [GridAnnotation(text, is_active=mask & (1 << bit) != 0) for bit, text in enumerate(parts)]
            for mask in range(1 << len(parts))
        ]

    optional_parts = all_prompts[1:]
    split_at = math.ceil(len(optional_parts) / 2)

    column_labels = annotations_for(optional_parts[:split_at])
    row_labels = annotations_for(optional_parts[split_at:])

    return draw_grid_annotations(im, width, height, column_labels, row_labels)


def resize_image(resize_mode, im, width, height):
    """Resize ``im`` to exactly ``width`` x ``height`` pixels.

    Args:
        resize_mode: 0 stretches the image to the target size ignoring its
            aspect ratio; 1 keeps the aspect ratio, scales the image to cover
            the target and centers it so the excess is cut off; any other
            value keeps the aspect ratio, fits the image inside the target and
            fills the remaining bands from the image's edges.
        im: source image.
        width: target width in pixels.
        height: target height in pixels.

    Returns:
        The resized image. For modes other than 0 this is a new RGB image.
    """
    def resize(im, w, h):
        # plain Lanczos when no upscaler is configured or for single-channel images
        if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L':
            return im.resize((w, h), resample=LANCZOS)

        upscaler = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img][0]
        return upscaler.upscale(im, w, h)

    if resize_mode == 0:
        res = resize(im, width, height)

    elif resize_mode == 1:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio > src_ratio else im.width * height // im.height
        src_h = height if ratio <= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

    else:
        ratio = width / height
        src_ratio = im.width / im.height

        src_w = width if ratio < src_ratio else im.width * height // im.height
        src_h = height if ratio >= src_ratio else im.height * width // im.width

        resized = resize(im, src_w, src_h)
        res = Image.new("RGB", (width, height))
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            fill_height = height // 2 - src_h // 2
            # integer rounding can leave nothing to fill; a zero-sized resize is an error
            if fill_height > 0:
                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
                res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
        elif ratio > src_ratio:
            fill_width = width // 2 - src_w // 2
            if fill_width > 0:
                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
                res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))

    return res


# characters replaced with "_" inside a filename part
invalid_filename_chars = '<>:"/\\|?*\n'
# characters stripped from the start / end of a filename part
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
# runs of whitespace and punctuation; the punctuation is escaped so that the
# backslash and "]" it contains are taken literally inside the character class
re_nonletters = re.compile(r'[\s' + re.escape(string.punctuation) + ']+')
max_filename_part_length = 128


def sanitize_filename_part(text, replace_spaces=True):
    """Make ``text`` usable as one component of a file name.

    Args:
        text: the string to clean up.
        replace_spaces: when True, spaces are turned into underscores first.

    Returns:
        ``text`` with every character of ``invalid_filename_chars`` replaced by
        "_", leading spaces removed, the result cut to
        ``max_filename_part_length`` characters, and trailing spaces and dots
        removed.
    """
    if replace_spaces:
        text = text.replace(' ', '_')

    text = text.translate({ord(x): '_' for x in invalid_filename_chars})
    text = text.lstrip(invalid_filename_prefix)[:max_filename_part_length]
    text = text.rstrip(invalid_filename_postfix)
    return text


def apply_filename_pattern(x, p, seed, prompt):
    """Expand the ``[tag]`` placeholders in a file or directory name pattern.

    Args:
        x: the pattern string.
        p: object providing ``steps``, ``cfg_scale``, ``width``, ``height``,
            ``styles`` and ``sampler_index``; when None the tags that depend on
            it are left untouched.
        seed: value for ``[seed]``; when None the tag is left untouched.
        prompt: text for ``[prompt]``, ``[prompt_spaces]`` and
            ``[prompt_words]``; when None those tags are left untouched.

    Returns:
        The pattern with the known tags replaced. ``[model_hash]``, ``[date]``
        and ``[job_timestamp]`` are always replaced.
    """
    max_prompt_words = opts.directories_max_prompt_words

    if seed is not None:
        x = x.replace("[seed]", str(seed))

    if prompt is not None:
        x = x.replace("[prompt]", sanitize_filename_part(prompt))
        x = x.replace("[prompt_spaces]", sanitize_filename_part(prompt, replace_spaces=False))
        if "[prompt_words]" in x:
            words = [word for word in re_nonletters.split(prompt) if len(word) > 0]
            if len(words) == 0:
                words = ["empty"]
            x = x.replace("[prompt_words]", sanitize_filename_part(" ".join(words[0:max_prompt_words]), replace_spaces=False))

    if p is not None:
        x = x.replace("[steps]", str(p.steps))
        x = x.replace("[cfg]", str(p.cfg_scale))
        x = x.replace("[width]", str(p.width))
        x = x.replace("[height]", str(p.height))
        x = x.replace("[styles]", sanitize_filename_part(", ".join(p.styles), replace_spaces=False))
        x = x.replace("[sampler]", sanitize_filename_part(sd_samplers.samplers[p.sampler_index].name, replace_spaces=False))

    x = x.replace("[model_hash]", shared.sd_model.sd_model_hash)
    x = x.replace("[date]", datetime.date.today().isoformat())
    x = x.replace("[job_timestamp]", shared.state.job_timestamp)

    if cmd_opts.hide_ui_dir_config:
        # strip leading path separators and ".." components next to a separator
        x = re.sub(r'^[\\/]+|\.{2,}[\\/]+|[\\/]+\.{2,}', '', x)

    return x


def get_next_sequence_number(path, basename):
    """
    Determines and returns the next sequence number to use when saving an image in the specified directory.

    The sequence starts at 0.

    Args:
        path: directory whose entries are scanned; it must exist.
        basename: filename prefix; when not empty only entries starting with
            ``basename + "-"`` are considered and that prefix is dropped
            before the number is read.

    Returns:
        One more than the highest number found at the start of a matching
        entry name (up to the first "-" or the extension), or 0 if none is found.
    """
    prefix = basename + "-" if basename != '' else ''
    prefix_length = len(prefix)

    highest = -1
    for filename in os.listdir(path):
        if not filename.startswith(prefix):
            continue

        # with the prefix and extension removed, the sequence number is the first "-"-separated field
        stem = os.path.splitext(filename[prefix_length:])[0]
        try:
            highest = max(int(stem.split('-')[0]), highest)
        except ValueError:
            # the entry does not start with a number; it is not part of the sequence
            continue

    return highest + 1


def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False, grid=False, pnginfo_section_name='parameters', p=None, existing_info=None, forced_filename=None, suffix=""):
    """Save ``image`` under ``path`` with a generated or forced file name.

    Args:
        image: the image to save.
        path: base directory; it (or the pattern subdirectory below it) is created if missing.
        basename: prefix of the sequence-numbered file name; may be ''.
        seed: value for the file name pattern; when None no decoration is added.
        prompt: text for the file name pattern; when None no decoration is added.
        extension: file extension, compared case-insensitively when choosing
            how metadata is embedded.
        info: generation info text. When set (and enabled in ``opts``) it is
            embedded as a PNG text chunk or as an EXIF user comment for
            jpg/jpeg/webp, and written to a ``.txt`` file if ``opts.save_txt``.
        short_filename: when True no decoration is added to the file name.
        no_prompt: when True a non-grid image is not placed in a pattern subdirectory.
        grid: selects ``opts.grid_save_to_dirs`` instead of ``opts.save_to_dirs``
            for the subdirectory decision.
        pnginfo_section_name: key of the PNG text chunk holding ``info``.
        p: object passed to ``apply_filename_pattern`` for the parameter tags.
        existing_info: mapping of extra PNG text chunks; values are stringified.
        forced_filename: when not None, used as the file name (without
            extension) instead of a sequence-numbered one.
        suffix: text appended after the file name decoration.

    Returns:
        None.
    """
    if short_filename or prompt is None or seed is None:
        file_decoration = ""
    elif opts.save_to_dirs:
        file_decoration = opts.samples_filename_pattern or "[seed]"
    else:
        file_decoration = opts.samples_filename_pattern or "[seed]-[prompt_spaces]"

    if file_decoration != "":
        file_decoration = "-" + file_decoration.lower()

    file_decoration = apply_filename_pattern(file_decoration, p, seed, prompt) + suffix

    # compare the extension case-insensitively everywhere, as the jpg/webp branch below does
    file_format = extension.lower()

    if file_format == 'png' and opts.enable_pnginfo and info is not None:
        pnginfo = PngImagePlugin.PngInfo()

        if existing_info is not None:
            for k, v in existing_info.items():
                pnginfo.add_text(k, str(v))

        pnginfo.add_text(pnginfo_section_name, info)
    else:
        pnginfo = None

    save_to_dirs = (grid and opts.grid_save_to_dirs) or (not grid and opts.save_to_dirs and not no_prompt)

    if save_to_dirs:
        dirname = apply_filename_pattern(opts.directories_filename_pattern or "[prompt_words]", p, seed, prompt)
        path = os.path.join(path, dirname)

    os.makedirs(path, exist_ok=True)

    if forced_filename is None:
        basecount = get_next_sequence_number(path, basename)
        # try up to 500 consecutive numbers; if all are taken the last candidate is used
        for i in range(500):
            fn = f"{basecount+i:05}" if basename == '' else f"{basename}-{basecount+i:04}"
            fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
            fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
            if not os.path.exists(fullfn):
                break
    else:
        fullfn = os.path.join(path, f"{forced_filename}.{extension}")
        fullfn_without_extension = os.path.join(path, forced_filename)

    def exif_bytes():
        return piexif.dump({
            "Exif": {
                piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(info or "", encoding="unicode")
            },
        })

    if file_format in ("jpg", "jpeg", "webp"):
        image.save(fullfn, quality=opts.jpeg_quality)
        if opts.enable_pnginfo and info is not None:
            piexif.insert(exif_bytes(), fullfn)
    else:
        image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)

    target_side_length = 4000
    oversize = image.width > target_side_length or image.height > target_side_length
    # optional extra JPEG copy for images that are too large in pixels or exceed 4 MiB on disk
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
        ratio = image.width / image.height

        if oversize and ratio > 1:
            image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
        elif oversize:
            image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)

        image.save(fullfn_without_extension + ".jpg", quality=opts.jpeg_quality)
        if opts.enable_pnginfo and info is not None:
            piexif.insert(exif_bytes(), fullfn_without_extension + ".jpg")

    if opts.save_txt and info is not None:
        with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
            file.write(info + "\n")


class Upscaler:
    """Base upscaler: applies ``do_upscale`` repeatedly, then resizes to the exact target with Lanczos.

    ``do_upscale`` returns the image unchanged here, so this class on its own
    performs a plain Lanczos resize.
    """

    name = "Lanczos"

    def do_upscale(self, img):
        """Perform one upscaling pass; this implementation returns ``img`` unchanged."""
        return img

    def upscale(self, img, w, h):
        """Return ``img`` brought to ``w`` x ``h`` pixels.

        ``do_upscale`` is applied at most three times, stopping as soon as the
        image is at least as large as the target in both dimensions. If the
        size still differs from the target, a Lanczos resize produces the
        exact size.
        """
        for _ in range(3):
            if img.width >= w and img.height >= h:
                break

            img = self.do_upscale(img)

        if img.width != w or img.height != h:
            img = img.resize((int(w), int(h)), resample=LANCZOS)

        return img


class UpscalerNone(Upscaler):
    """Upscaler entry named "None" that hands the image back untouched."""

    name = "None"

    def upscale(self, img, w, h):
        """Return ``img`` as is; ``w`` and ``h`` are not used."""
        return img


# Register the two upscalers defined in this module in the shared list at import time.
modules.shared.sd_upscalers.append(UpscalerNone())
modules.shared.sd_upscalers.append(Upscaler())