img2img.py 5.3 KB
Newer Older
1
import math
2
import numpy as np
3
from PIL import Image, ImageOps, ImageChops
4

5
from modules import devices
6 7 8 9 10 11
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
import modules.shared as shared
import modules.processing as processing
from modules.ui import plaintext_to_html
import modules.images as images
A
AUTOMATIC 已提交
12
import modules.scripts
13

A
AUTOMATIC 已提交
14
def img2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_mask, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, height: int, width: int, resize_mode: int, upscaler_index: str, upscale_overlap: int, inpaint_full_res: bool, inpainting_mask_invert: int, *args):
    """UI entry point for the img2img tab.

    Dispatches on `mode`:
      mode == 1 -> inpainting (uses a mask derived from the sketch canvas or an uploaded mask)
      mode == 2 -> "SD upscale" (upscale init_img 2x, then re-diffuse it tile by tile)
      otherwise -> plain img2img on `init_img`

    `*args` is forwarded to the custom-scripts runner for the non-upscale path.

    Returns a 3-tuple consumed by the Gradio UI:
        (list of output PIL images, Processed JS/JSON info string, HTML-escaped info text)
    """
    is_inpaint = mode == 1
    is_upscale = mode == 2

    if is_inpaint:
        # mask_mode == 0: mask was drawn on the image editor canvas.
        if mask_mode == 0:
            image = init_img_with_mask['image']
            mask = init_img_with_mask['mask']
            # Treat transparent (alpha==0) regions of the edited image as masked too:
            # invert alpha, binarize it, and merge it with the drawn mask (lighter == union).
            alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
            mask = ImageChops.lighter(alpha_mask, mask.convert('L')).convert('L')
            image = image.convert('RGB')
        else:
            # mask_mode != 0: a separate mask image was uploaded.
            image = init_img
            mask = init_mask
    else:
        # Plain img2img / SD upscale: no mask.
        image = init_img
        mask = None

    # NOTE(review): assert is stripped under `python -O`; an explicit raise would be
    # more robust, but kept as-is to preserve behavior.
    assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'

    p = StableDiffusionProcessingImg2Img(
        sd_model=shared.sd_model,
        # opts.outdir_samples/outdir_grids act as global overrides when non-empty.
        outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
        outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
        prompt=prompt,
        negative_prompt=negative_prompt,
        styles=[prompt_style, prompt_style2],
        seed=seed,
        subseed=subseed,
        subseed_strength=subseed_strength,
        seed_resize_from_h=seed_resize_from_h,
        seed_resize_from_w=seed_resize_from_w,
        sampler_index=sampler_index,
        batch_size=batch_size,
        n_iter=n_iter,
        steps=steps,
        cfg_scale=cfg_scale,
        width=width,
        height=height,
        restore_faces=restore_faces,
        tiling=tiling,
        init_images=[image],
        mask=mask,
        mask_blur=mask_blur,
        inpainting_fill=inpainting_fill,
        resize_mode=resize_mode,
        denoising_strength=denoising_strength,
        inpaint_full_res=inpaint_full_res,
        inpainting_mask_invert=inpainting_mask_invert,
    )
    print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

    # Recorded so the mask blur value shows up in the generation-parameters text.
    p.extra_generation_params["Mask blur"] = mask_blur

    if is_upscale:
        # --- SD upscale: upscale 2x, split into tiles, re-diffuse each tile, recombine ---
        initial_info = None

        # Resolve a concrete seed now so every upscale iteration n uses seed + n.
        processing.fix_seed(p)
        seed = p.seed

        upscaler = shared.sd_upscalers[upscaler_index]
        img = upscaler.upscale(init_img, init_img.width * 2, init_img.height * 2)

        devices.torch_gc()

        # Tiles are width x height with `upscale_overlap` pixels of overlap for blending.
        grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)

        # Repurpose batch_size/n_iter: batches group tiles; n_iter becomes the number
        # of full upscale passes (each with a different start seed).
        batch_size = p.batch_size
        upscale_count = p.n_iter
        p.n_iter = 1
        p.do_not_save_grid = True
        p.do_not_save_samples = True

        # Flatten all tile images into one work list.
        work = []

        for y, h, row in grid.tiles:
            for tiledata in row:
                work.append(tiledata[2])

        batch_count = math.ceil(len(work) / batch_size)
        state.job_count = batch_count * upscale_count

        print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} per upscale in a total of {state.job_count} batches.")

        result_images = []
        for n in range(upscale_count):
            start_seed = seed + n
            p.seed = start_seed

            work_results = []
            for i in range(batch_count):
                p.batch_size = batch_size
                p.init_images = work[i*batch_size:(i+1)*batch_size]

                state.job = f"Batch {i + 1 + n * batch_count} out of {state.job_count}"
                processed = process_images(p)

                # Keep the info text from the very first batch for saving/reporting.
                if initial_info is None:
                    initial_info = processed.info

                # Advance the seed past the batch just processed so tiles differ.
                p.seed = processed.seed + 1
                work_results += processed.images

            # Write the processed tiles back into the grid in the same order they
            # were collected; pad with black tiles if fewer results came back.
            image_index = 0
            for y, h, row in grid.tiles:
                for tiledata in row:
                    tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
                    image_index += 1

            combined_image = images.combine_grid(grid)
            result_images.append(combined_image)

            if opts.samples_save:
                images.save_image(combined_image, p.outpath_samples, "", start_seed, prompt, opts.samples_format, info=initial_info, p=p)

        processed = Processed(p, result_images, seed, initial_info)

    else:

        # Give custom scripts first crack at the job; fall back to normal processing
        # when no script handles it (run() returns None).
        processed = modules.scripts.scripts_img2img.run(p, *args)

        if processed is None:
            processed = process_images(p)

    shared.total_tqdm.clear()

    return processed.images, processed.js(), plaintext_to_html(processed.info)