import argparse
import json
import os
import sys

import gradio as gr
import tqdm

import modules.artists
from modules.paths import script_path, sd_path
from modules.devices import get_optimal_device
import modules.styles
import modules.interrogate
import modules.memmon
import modules.sd_models

# Default checkpoint path inside the script directory; --ckpt falls back to it.
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(sd_path, "configs/stable-diffusion/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; this checkpoint will be added to the list of checkpoints and loaded by default if you don't have a checkpoint selected in settings",)
parser.add_argument("--ckpt-dir", type=str, default=os.path.join(script_path, 'models'), help="path to directory with stable diffusion checkpoints",)
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default=('./src/gfpgan' if os.path.exists('./src/gfpgan') else './GFPGAN'))
parser.add_argument("--gfpgan-model", type=str, help="GFPGAN model file name", default=None)
parser.add_argument("--no-half", action='store_true', help="do not switch the model to 16-bit floats")
parser.add_argument("--no-progressbar-hiding", action='store_true', help="do not hide progressbar in gradio UI (we hide it because it slows down ML if you have hardware acceleration in browser)")
parser.add_argument("--max-batch-count", type=int, default=16, help="maximum batch count value for the UI")
parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_path, 'embeddings'), help="embeddings directory for textual inversion (default: embeddings)")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRAM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRAM usage")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="keep cond/uncond batching enabled even with --medvram or --lowvram, which otherwise disable it to save memory")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
parser.add_argument("--esrgan-models-path", type=str, help="path to directory with ESRGAN models", default=os.path.join(script_path, 'ESRGAN'))
parser.add_argument("--swinir-models-path", type=str, help="path to directory with SwinIR models", default=os.path.join(script_path, 'SwinIR'))
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables cross-attention layer optimization. By default, it's on for torch.cuda and off for other torch devices.")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None)
parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False)
parser.add_argument("--ui-config-file", type=str, help="filename to use for ui configuration", default=os.path.join(script_path, 'ui-config.json'))
parser.add_argument("--hide-ui-dir-config", action='store_true', help="hide directory configuration from webui", default=False)
parser.add_argument("--ui-settings-file", type=str, help="filename to use for ui settings", default=os.path.join(script_path, 'config.json'))
parser.add_argument("--gradio-debug", action='store_true', help="launch gradio with --debug option")
parser.add_argument("--gradio-auth", type=str, help='set gradio authentication like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(script_path, 'styles.csv'))
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)

cmd_opts = parser.parse_args()

device = get_optimal_device()

# cond/uncond batching is turned off by the low-memory modes (--medvram, --lowvram);
# --always-batch-cond-uncond turns it back on regardless.
batch_cond_uncond = cmd_opts.always_batch_cond_uncond or not (cmd_opts.lowvram or cmd_opts.medvram)
# Parallel processing is only allowed when neither low-memory mode is active.
parallel_processing_allowed = not cmd_opts.lowvram and not cmd_opts.medvram

# JSON settings file that is loaded into `opts` at startup (if it exists).
config_filename = cmd_opts.ui_settings_file


class State:
    """Mutable status of the current generation job.

    A single instance, ``state``, is created in this module. The fields are
    presumably updated/polled by the processing code elsewhere (not visible here).
    """

    # Set to True by interrupt(); never reset inside this class.
    interrupted = False
    # Free-form description of the current job.
    job = ""
    # Index of the current job; incremented by nextjob().
    job_no = 0
    # Total number of jobs (TotalTQDM multiplies it by sampling_steps).
    job_count = 0
    # Current sampling step within the job; reset to 0 by nextjob().
    sampling_step = 0
    # Number of sampling steps per job.
    sampling_steps = 0
    current_latent = None
    current_image = None
    # Sampling step at which current_image was produced; reset by nextjob().
    current_image_sampling_step = 0

    def interrupt(self):
        """Flag the current job as interrupted."""
        self.interrupted = True

    def nextjob(self):
        """Advance to the next job and reset per-job step counters."""
        self.job_no += 1
        self.sampling_step = 0
        self.current_image_sampling_step = 0


state = State()

artist_db = modules.artists.ArtistsDatabase(os.path.join(script_path, 'artists.csv'))

styles_filename = cmd_opts.styles_file
prompt_styles = modules.styles.StyleDatabase(styles_filename)

interrogator = modules.interrogate.InterrogateModels("interrogate")

# Starts empty here; the "face_restoration_model" setting reads it lazily through
# a lambda, so entries added later still show up as choices.
face_restorers = []

modules.sd_models.list_models()


def realesrgan_models_names():
    """Return the ``name`` of every model from ``modules.realesrgan_model.get_realesrgan_models()``.

    The module is imported at call time rather than at the top of the file.
    NOTE(review): presumably to defer loading / avoid an import cycle — confirm.
    """
    import modules.realesrgan_model
    return [x.name for x in modules.realesrgan_model.get_realesrgan_models()]


class OptionInfo:
    """Template for one setting: default value, UI label and optional UI component.

    Args:
        default: value used when the setting is not present in the loaded data.
        label: human-readable text shown for the setting.
        component: UI component class (e.g. ``gr.Slider``); None for the default widget.
        component_args: dict of component arguments, or a callable returning one.
        onchange: callback registered for changes to this setting (see Options.onchange).
    """

    def __init__(self, default=None, label="", component=None, component_args=None, onchange=None):
        self.default = default
        self.label = label
        self.component = component
        self.component_args = component_args
        self.onchange = onchange
        # (identifier, title) tuple assigned by options_section().
        self.section = None


def options_section(section_identifer, options_dict):
    """Tag every OptionInfo in ``options_dict`` with ``section_identifer``.

    Args:
        section_identifer: section tag, e.g. ``('saving-images', "Saving images/grids")``.
        options_dict: mapping of setting key -> OptionInfo; mutated in place.

    Returns:
        The same ``options_dict`` object, so calls can feed ``dict.update`` directly.
    """
    # Only the values are needed; the keys are left untouched.
    for info in options_dict.values():
        info.section = section_identifer

    return options_dict


# With --hide-ui-dir-config, directory settings get component_args={"visible": False}.
hide_dirs = {"visible": False} if cmd_opts.hide_ui_dir_config else None

# Registry of every setting: key -> OptionInfo, filled section by section below.
options_templates = {}

options_templates.update(options_section(('saving-images', "Saving images/grids"), {
    "samples_save": OptionInfo(True, "Always save all generated images"),
    "samples_format": OptionInfo('png', 'File format for images'),
    "samples_filename_pattern": OptionInfo("", "Images filename pattern"),

    "grid_save": OptionInfo(True, "Always save all generated image grids"),
    "grid_format": OptionInfo('png', 'File format for grids'),
    "grid_extended_filename": OptionInfo(False, "Add extended info (seed, prompt) to filename when saving grid"),
    "grid_only_if_multiple": OptionInfo(True, "Do not save grids consisting of one picture"),
    "n_rows": OptionInfo(-1, "Grid row count; use -1 for autodetect and 0 for it to be same as batch size", gr.Slider, {"minimum": -1, "maximum": 16, "step": 1}),

    "enable_pnginfo": OptionInfo(True, "Save text information about generation parameters as chunks to png files"),
    "save_txt": OptionInfo(False, "Create a text file next to every image with generation parameters."),
    "save_images_before_face_restoration": OptionInfo(False, "Save a copy of image before doing face restoration."),
    "jpeg_quality": OptionInfo(80, "Quality for saved jpeg images", gr.Slider, {"minimum": 1, "maximum": 100, "step": 1}),
    "export_for_4chan": OptionInfo(True, "If PNG image is larger than 4MB or any dimension is larger than 4000, downscale and save copy as JPG"),

    "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
}))

options_templates.update(options_section(('saving-paths', "Paths for saving"), {
    "outdir_samples": OptionInfo("", "Output directory for images; if empty, defaults to three directories below", component_args=hide_dirs),
    "outdir_txt2img_samples": OptionInfo("outputs/txt2img-images", 'Output directory for txt2img images', component_args=hide_dirs),
    "outdir_img2img_samples": OptionInfo("outputs/img2img-images", 'Output directory for img2img images', component_args=hide_dirs),
    "outdir_extras_samples": OptionInfo("outputs/extras-images", 'Output directory for images from extras tab', component_args=hide_dirs),
    "outdir_grids": OptionInfo("", "Output directory for grids; if empty, defaults to two directories below", component_args=hide_dirs),
    "outdir_txt2img_grids": OptionInfo("outputs/txt2img-grids", 'Output directory for txt2img grids', component_args=hide_dirs),
    "outdir_img2img_grids": OptionInfo("outputs/img2img-grids", 'Output directory for img2img grids', component_args=hide_dirs),
    "outdir_save": OptionInfo("log/images", "Directory for saving images using the Save button", component_args=hide_dirs),
}))

options_templates.update(options_section(('saving-to-dirs', "Saving to a directory"), {
    "save_to_dirs": OptionInfo(False, "Save images to a subdirectory"),
    "grid_save_to_dirs": OptionInfo(False, "Save grids to subdirectory"),
    "directories_filename_pattern": OptionInfo("", "Directory name pattern"),
    "directories_max_prompt_words": OptionInfo(8, "Max prompt words", gr.Slider, {"minimum": 1, "maximum": 20, "step": 1}),
}))

options_templates.update(options_section(('upscaling', "Upscaling"), {
    "ESRGAN_tile": OptionInfo(192, "Tile size for ESRGAN upscalers. 0 = no tiling.", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
    "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    # Choices come from a lambda so the model list is only built when the UI asks for it.
    "realesrgan_enabled_models": OptionInfo(["Real-ESRGAN 4x plus", "Real-ESRGAN 4x plus anime 6B"], "Select which RealESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}),
    "SWIN_tile": OptionInfo(192, "Tile size for all SwinIR.", gr.Slider, {"minimum": 16, "maximum": 512, "step": 16}),
    "SWIN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for SwinIR. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}),
    "ldsr_steps": OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}),
    "ldsr_pre_down": OptionInfo(1, "LDSR Pre-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),
    "ldsr_post_down": OptionInfo(1, "LDSR Post-process down-sample scale. 1 = no down-sampling, 4 = 1/4 scale.", gr.Slider, {"minimum": 1, "maximum": 4, "step": 1}),

    # Lazy: sd_upscalers is defined further down in this module.
    "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Radio, lambda: {"choices": [x.name for x in sd_upscalers]}),
}))

options_templates.update(options_section(('face-restoration', "Face restoration"), {
    # Lazy: face_restorers starts empty and is read when the UI is built.
    "face_restoration_model": OptionInfo(None, "Face restoration model", gr.Radio, lambda: {"choices": [x.name() for x in face_restorers]}),
    "code_former_weight": OptionInfo(0.5, "CodeFormer weight parameter; 0 = maximum effect; 1 = minimum effect", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
    "face_restoration_unload": OptionInfo(False, "Move face restoration model from VRAM into RAM after processing"),
    "save_selected_only": OptionInfo(False, "When using 'Save' button, only save a single selected image"),
}))

options_templates.update(options_section(('system', "System"), {
    "memmon_poll_rate": OptionInfo(8, "VRAM usage polls per second during generation. Set to 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 40, "step": 1}),
    "samples_log_stdout": OptionInfo(False, "Always print all generation info to standard output"),
    "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job. Broken in PyCharm console."),
}))

options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Radio, lambda: {"choices": [x.title for x in modules.sd_models.checkpoints_list.values()]}),
    "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."),
    "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
    "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
    "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
    "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
    "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
    "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
    # Evaluated eagerly: the artist categories are read once, at import time.
    "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
}))

options_templates.update(options_section(('interrogate', "Interrogate Options"), {
    "interrogate_keep_models_in_memory": OptionInfo(False, "Interrogate: keep models in VRAM"),
    "interrogate_use_builtin_artists": OptionInfo(True, "Interrogate: use artists from artists.csv"),
    "interrogate_clip_num_beams": OptionInfo(1, "Interrogate: num_beams for BLIP", gr.Slider, {"minimum": 1, "maximum": 16, "step": 1}),
    "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
    "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
    "interrogate_clip_dict_limit": OptionInfo(1500, "Interrogate: maximum number of lines in text file (0 = No limit)"),
}))

options_templates.update(options_section(('ui', "User interface"), {
    "show_progressbar": OptionInfo(True, "Show progressbar"),
    "show_progress_every_n_steps": OptionInfo(0, "Show image creation progress every N sampling steps. Set 0 to disable.", gr.Slider, {"minimum": 0, "maximum": 32, "step": 1}),
    "return_grid": OptionInfo(True, "Show grid in results for web"),
    "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"),
    "font": OptionInfo("", "Font for image grids that have text"),
    "js_modal_lightbox": OptionInfo(True, "Enable full page image viewer"),
    "js_modal_lightbox_initialy_zoomed": OptionInfo(True, "Show images zoomed in by default in full page image viewer"),
}))


class Options:
    """Attribute-style access to the user settings.

    Reading ``opts.key`` returns ``data[key]``, falling back to the template
    default in ``data_labels``. Assigning ``opts.key = value`` for a registered
    setting writes into ``data``, so ``save()``/``dumpjson()`` see the change.
    """

    # Current setting values (key -> value); None until __init__ runs.
    data = None
    # Registered settings: key -> OptionInfo.
    data_labels = options_templates
    # Types treated as equivalent by same_type(): an int default accepts a float value.
    typemap = {int: float}

    def __init__(self):
        self.data = {k: v.default for k, v in self.data_labels.items()}

    def __setattr__(self, key, value):
        # Settings live only in self.data. Also storing them as instance attributes
        # would shadow __getattr__ and return stale values after load() replaces
        # self.data; registered keys missing from a loaded file must land in data too.
        if self.data is not None and (key in self.data or key in self.data_labels):
            self.data[key] = value
            return

        super(Options, self).__setattr__(key, value)

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails, i.e. for setting keys.
        if self.data is not None:
            if item in self.data:
                return self.data[item]

        if item in self.data_labels:
            return self.data_labels[item].default

        return super(Options, self).__getattribute__(item)

    def save(self, filename):
        """Write the current settings to ``filename`` as JSON (UTF-8)."""
        with open(filename, "w", encoding="utf8") as file:
            json.dump(self.data, file)

    def same_type(self, x, y):
        """Return True if ``x`` and ``y`` count as the same setting type.

        None matches anything; ``typemap`` makes int and float interchangeable.
        """
        if x is None or y is None:
            return True

        type_x = self.typemap.get(type(x), type(x))
        type_y = self.typemap.get(type(y), type(y))

        return type_x == type_y

    def load(self, filename):
        """Replace ``data`` with the JSON contents of ``filename``.

        Values whose type differs from the template default are reported on
        stderr but kept; keys without a template are kept unchecked.
        """
        with open(filename, "r", encoding="utf8") as file:
            self.data = json.load(file)

        bad_settings = 0
        for k, v in self.data.items():
            info = self.data_labels.get(k, None)
            if info is not None and not self.same_type(info.default, v):
                print(f"Warning: bad setting value: {k}: {v} ({type(v).__name__}; expected {type(info.default).__name__})", file=sys.stderr)
                bad_settings += 1

        if bad_settings > 0:
            print(f"The program is likely to not work with bad settings.\nSettings file: {filename}\nEither fix the file, or delete it and restart.", file=sys.stderr)

    def onchange(self, key, func):
        """Register ``func`` as the change callback of setting ``key``.

        Raises AttributeError if ``key`` is not a registered setting.
        """
        item = self.data_labels.get(key)
        item.onchange = func

    def dumpjson(self):
        """Return every registered setting as a JSON string, using defaults for missing keys."""
        d = {k: self.data.get(k, self.data_labels.get(k).default) for k in self.data_labels.keys()}
        return json.dumps(d)


# Settings start from the template defaults and are overridden from the JSON file if present.
opts = Options()
if os.path.exists(config_filename):
    opts.load(config_filename)

# Read lazily by the "upscaler_for_img2img" setting; presumably populated elsewhere.
sd_upscalers = []

# Placeholder for the loaded model; presumably assigned elsewhere — not set in this file.
sd_model = None

# Stream used by TotalTQDM for the console progress bar.
progress_print_out = sys.stdout


class TotalTQDM:
    """Second console progress bar that tracks progress across a whole job.

    The bar is sized from ``state.job_count * state.sampling_steps``.
    update() and updateTotal() do nothing when the ``multiple_tqdm`` setting is off.
    """

    def __init__(self):
        self._tqdm = None

    def reset(self):
        """(Re)create the bar, closing any bar that is still open."""
        # Without this, calling reset() on a live bar would leave the old one dangling.
        self.clear()
        self._tqdm = tqdm.tqdm(
            desc="Total progress",
            total=state.job_count * state.sampling_steps,
            position=1,
            file=progress_print_out
        )

    def update(self):
        """Advance the bar by one step, creating it on first use."""
        if not opts.multiple_tqdm:
            return
        if self._tqdm is None:
            self.reset()
        self._tqdm.update()

    def updateTotal(self, new_total):
        """Change the bar's total, creating the bar on first use."""
        if not opts.multiple_tqdm:
            return
        if self._tqdm is None:
            self.reset()
        self._tqdm.total = new_total

    def clear(self):
        """Close and forget the bar, if one exists."""
        if self._tqdm is not None:
            self._tqdm.close()
            self._tqdm = None


total_tqdm = TotalTQDM()

# Memory-usage monitor, started at import time. NOTE(review): presumably a background
# thread that honours the "memmon_poll_rate" setting via opts — confirm in modules.memmon.
mem_mon = modules.memmon.MemUsageMonitor("MemMon", device, opts)
mem_mon.start()