styles.py 3.6 KB
Newer Older
1 2 3
# We need this so Python doesn't complain about the unknown StableDiffusionProcessing-typehint at runtime
from __future__ import annotations

A
AUTOMATIC 已提交
4
import csv
5
import os
A
AUTOMATIC 已提交
6
import os.path
7 8 9 10
import typing
import collections.abc as abc
import tempfile
import shutil
A
AUTOMATIC 已提交
11

12 13 14
if typing.TYPE_CHECKING:
    # Only import this when code is being type-checked, it doesn't have any effect at runtime
    from .processing import StableDiffusionProcessing
A
AUTOMATIC 已提交
15 16


17 18 19 20
# One saved style: its display name, the text merged into the positive prompt,
# and the text merged into the negative prompt (both combined via merge_prompts()).
# Declared with the functional NamedTuple form; the field order below is the
# column order StyleDatabase.save_styles() writes via PromptStyle._fields.
PromptStyle = typing.NamedTuple(
    "PromptStyle",
    [
        ("name", str),
        ("prompt", str),
        ("negative_prompt", str),
    ],
)
A
AUTOMATIC 已提交
21 22


A
AUTOMATIC 已提交
23 24 25 26 27 28
def merge_prompts(style_prompt: str, prompt: str) -> str:
    """Combine a style's text with the user's prompt.

    If *style_prompt* contains the literal placeholder ``{prompt}``, every
    occurrence is replaced by *prompt* verbatim. Otherwise both strings are
    stripped of surrounding whitespace and joined as ``"<prompt>, <style>"``,
    leaving out whichever side is empty after stripping.
    """
    placeholder = "{prompt}"

    if placeholder not in style_prompt:
        # No placeholder: append the style after the prompt, skipping empty sides.
        pieces = [text for text in (prompt.strip(), style_prompt.strip()) if text]
        return ", ".join(pieces)

    return style_prompt.replace(placeholder, prompt)
A
AUTOMATIC 已提交
31 32


A
AUTOMATIC 已提交
33 34 35
def apply_styles_to_prompt(prompt, styles):
    """Fold each style text in *styles* into *prompt*, left to right.

    Each step calls ``merge_prompts(style_text, merged)``, so every style is
    applied to the result of the previous ones. With no styles the prompt is
    returned unchanged.
    """
    merged = prompt
    for style_text in styles:
        merged = merge_prompts(style_text, merged)
    return merged
A
AUTOMATIC 已提交
38 39


A
AUTOMATIC 已提交
40 41 42 43
class StyleDatabase:
    """Named prompt styles loaded from, and saved to, a CSV file.

    ``self.styles`` maps style name -> PromptStyle and always starts with a
    ``"None"`` entry whose texts are empty. ``self.no_style`` holds that entry
    and is the fallback for any style name not present in ``self.styles``.
    """

    def __init__(self, path: str):
        """Load styles from the CSV file at *path*; a missing file is not an error.

        Each row needs a ``name`` column and either a ``prompt`` column or, in
        the old format, a ``text`` column. ``negative_prompt`` is optional and
        defaults to ``""``.
        """
        self.no_style = PromptStyle("None", "", "")
        self.styles = {"None": self.no_style}

        if not os.path.exists(path):
            return

        # "utf-8-sig" drops a leading byte-order mark if one is present (spreadsheet
        # tools such as Excel often add it on save). With plain "utf8" the BOM ends
        # up in the first header ("\ufeffname") and row["name"] raises KeyError.
        # Files without a BOM decode exactly as they did with "utf8".
        with open(path, "r", encoding="utf-8-sig", newline='') as file:
            reader = csv.DictReader(file)
            for row in reader:
                # Support loading old CSV format with "name, text"-columns
                prompt = row["prompt"] if "prompt" in row else row["text"]
                negative_prompt = row.get("negative_prompt", "")
                # csv.DictReader fills cells missing from a short row with None;
                # store "" instead so merge_prompts() always receives strings.
                self.styles[row["name"]] = PromptStyle(row["name"], prompt or "", negative_prompt or "")

    def get_style_prompts(self, styles):
        """Return the prompt text of each style name in *styles*, in order ("" for unknown names)."""
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        """Return the negative-prompt text of each style name in *styles*, in order ("" for unknown names)."""
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        """Merge the prompt texts of the named *styles* into *prompt*, in order."""
        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles])

    def apply_negative_styles_to_prompt(self, prompt, styles):
        """Merge the negative-prompt texts of the named *styles* into *prompt*, in order."""
        return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles])

    def apply_styles(self, p: StableDiffusionProcessing) -> None:
        """Apply the styles named in ``p.styles`` to ``p.prompt`` and ``p.negative_prompt`` in place.

        Either attribute may be a single string or a list of strings; for a
        list, every element is styled individually.
        """
        if isinstance(p.prompt, list):
            p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt]
        else:
            p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles)

        if isinstance(p.negative_prompt, list):
            p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt]
        else:
            p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)

    def save_styles(self, path: str) -> None:
        """Write every style in ``self.styles`` (including ``"None"``) to *path* as CSV.

        The data is first written to a temporary file in the same directory and
        then swapped in with os.replace(), so a failure while writing never
        clobbers the existing file. The previous file, if any, is copied to
        ``path + ".bak"``.
        """
        # Create the temp file next to the target so the final os.replace() is a
        # same-filesystem rename (atomic) rather than a cross-device copy out of /tmp.
        directory = os.path.dirname(os.path.abspath(path))
        fd, temp_path = tempfile.mkstemp(".csv", dir=directory)
        try:
            with os.fdopen(fd, "w", encoding="utf8", newline='') as file:
                # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
                # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
                writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
                writer.writeheader()
                writer.writerows(style._asdict() for style in self.styles.values())

            # Always keep a backup file around. Copy rather than move, so *path*
            # keeps existing until the new version atomically replaces it.
            if os.path.exists(path):
                shutil.copy2(path, path + ".bak")
                # mkstemp() creates the file with 0600 permissions; keep the original file's mode.
                shutil.copymode(path, temp_path)

            os.replace(temp_path, path)
        finally:
            # On success temp_path was renamed away; otherwise remove the partial file.
            if os.path.exists(temp_path):
                os.remove(temp_path)