styles.py 4.6 KB
Newer Older
A
AUTOMATIC 已提交
1
import csv
2
import os
A
AUTOMATIC 已提交
3
import os.path
4
import re
5 6
import typing
import shutil
A
AUTOMATIC 已提交
7 8


9 10 11 12
class PromptStyle(typing.NamedTuple):
    """A named style made of a positive and a negative prompt template.

    Either template may contain the literal placeholder ``{prompt}``;
    ``merge_prompts`` substitutes the user's prompt for it, and appends the
    template to the prompt when the placeholder is absent.
    """
    name: str  # identifier; StyleDatabase uses it as the dictionary key
    prompt: str  # template merged into the positive prompt
    negative_prompt: str  # template merged into the negative prompt
A
AUTOMATIC 已提交
13 14


A
AUTOMATIC 已提交
15 16 17 18 19 20
def merge_prompts(style_prompt: str, prompt: str) -> str:
    """Combine a style template with a user prompt.

    Args:
        style_prompt: Style template. If it contains the literal placeholder
            ``{prompt}``, every occurrence is replaced by *prompt*.
        prompt: The user's prompt text.

    Returns:
        The template with the placeholder substituted, or - when the template
        has no placeholder - the stripped prompt followed by the stripped
        template, joined with ``", "``. Empty pieces are left out of the join,
        so no dangling separator is produced.
    """
    if "{prompt}" in style_prompt:
        # Substitution keeps both strings exactly as given (no stripping).
        return style_prompt.replace("{prompt}", prompt)

    # filter(None, ...) drops whichever side is empty after stripping.
    parts = filter(None, (prompt.strip(), style_prompt.strip()))
    return ", ".join(parts)
A
AUTOMATIC 已提交
23 24


A
AUTOMATIC 已提交
25 26 27
def apply_styles_to_prompt(prompt, styles):
    """Merge each style template in *styles* into *prompt*, in order.

    Args:
        prompt: The starting prompt text.
        styles: Iterable of style template strings. Each one is merged with
            ``merge_prompts`` into the result of the previous merge, so later
            templates see the output of earlier ones.

    Returns:
        The prompt after all templates have been merged in. With no styles the
        prompt is returned unchanged.
    """
    for style in styles:
        prompt = merge_prompts(style, prompt)

    return prompt
A
AUTOMATIC 已提交
30 31


32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69
# Matches runs of two or more spaces so they can be collapsed to one.
re_spaces = re.compile("  +")


def extract_style_text_from_prompt(style_text, prompt):
    """Try to remove one style template from a prompt.

    Both strings are stripped and have runs of spaces collapsed before they
    are compared.

    Args:
        style_text: Style template, optionally containing the literal
            placeholder ``{prompt}``.
        prompt: Prompt text the template may have been merged into.

    Returns:
        ``(True, remainder)`` when the template matches, where *remainder* is
        the normalised prompt with the template text removed; otherwise
        ``(False, prompt)`` with *prompt* exactly as it was passed in.
    """
    stripped_prompt = re_spaces.sub(" ", prompt.strip())
    stripped_style_text = re_spaces.sub(" ", style_text.strip())

    if "{prompt}" in stripped_style_text:
        left, _, right = stripped_style_text.partition("{prompt}")

        # A template with several placeholders cannot be reversed
        # unambiguously; report "no match" instead of failing to unpack.
        if "{prompt}" in right:
            return False, prompt

        # The prefix and suffix must not overlap inside the prompt, otherwise
        # a prompt shorter than the template's fixed text would match.
        fits = len(left) + len(right) <= len(stripped_prompt)
        if fits and stripped_prompt.startswith(left) and stripped_prompt.endswith(right):
            return True, stripped_prompt[len(left):len(stripped_prompt) - len(right)]

        return False, prompt

    if not stripped_prompt.endswith(stripped_style_text):
        return False, prompt

    remainder = stripped_prompt[:len(stripped_prompt) - len(stripped_style_text)]

    # Drop the ", " separator that joins a prompt to an appended template.
    if remainder.endswith(', '):
        remainder = remainder[:-2]

    return True, remainder


def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt):
    """Try to remove *style* from a positive/negative prompt pair.

    The style only counts as present when its positive template matches
    *prompt* and its negative template matches *negative_prompt*. A style with
    both templates empty never matches.

    Returns:
        ``(True, positive_remainder, negative_remainder)`` on a match,
        otherwise ``(False, prompt, negative_prompt)`` with both inputs
        returned untouched.
    """
    no_match = (False, prompt, negative_prompt)

    if not (style.prompt or style.negative_prompt):
        return no_match

    positive_found, positive_rest = extract_style_text_from_prompt(style.prompt, prompt)
    if not positive_found:
        return no_match

    negative_found, negative_rest = extract_style_text_from_prompt(style.negative_prompt, negative_prompt)
    if not negative_found:
        return no_match

    return True, positive_rest, negative_rest


A
AUTOMATIC 已提交
70 71 72
class StyleDatabase:
    """Collection of prompt styles loaded from, and saved to, a CSV file.

    Attributes are plain instance state:
      * ``no_style`` - placeholder style with empty templates, used whenever
        a requested style name is unknown.
      * ``styles``   - dict mapping style name to ``PromptStyle``.
      * ``path``     - CSV file that ``reload`` reads.
    """

    def __init__(self, path: str):
        """Remember *path* and load the styles stored there, if any."""
        self.no_style = PromptStyle("None", "", "")
        self.styles = {}
        self.path = path

        self.reload()

    def reload(self):
        """Discard the loaded styles and re-read them from ``self.path``.

        A missing file is not an error; the database is simply left empty.
        """
        self.styles.clear()

        if not os.path.exists(self.path):
            return

        with open(self.path, "r", encoding="utf-8-sig", newline='') as file:
            reader = csv.DictReader(file, skipinitialspace=True)
            for row in reader:
                # Support loading old CSV format with "name, text"-columns
                prompt = row["prompt"] if "prompt" in row else row["text"]
                negative_prompt = row.get("negative_prompt", "")
                # DictReader fills fields missing from a short row with None;
                # store "" instead so later string operations cannot fail.
                self.styles[row["name"]] = PromptStyle(row["name"], prompt or "", negative_prompt or "")

    def get_style_prompts(self, styles):
        """Return the positive template for each name in *styles*.

        Unknown names yield the empty template of ``self.no_style``.
        """
        return [self.styles.get(x, self.no_style).prompt for x in styles]

    def get_negative_style_prompts(self, styles):
        """Return the negative template for each name in *styles*.

        Unknown names yield the empty template of ``self.no_style``.
        """
        return [self.styles.get(x, self.no_style).negative_prompt for x in styles]

    def apply_styles_to_prompt(self, prompt, styles):
        """Merge the positive templates of the named *styles* into *prompt*."""
        return apply_styles_to_prompt(prompt, self.get_style_prompts(styles))

    def apply_negative_styles_to_prompt(self, prompt, styles):
        """Merge the negative templates of the named *styles* into *prompt*."""
        return apply_styles_to_prompt(prompt, self.get_negative_style_prompts(styles))

    def save_styles(self, path: str) -> None:
        """Write every loaded style to the CSV file at *path*.

        If *path* already exists it is first copied to ``<path>.bak``.
        """
        # Always keep a backup file around
        if os.path.exists(path):
            shutil.copy(path, f"{path}.bak")

        with open(path, "w", encoding="utf-8-sig", newline='') as file:
            writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
            writer.writeheader()
            writer.writerows(style._asdict() for style in self.styles.values())

    def extract_styles_from_prompt(self, prompt, negative_prompt):
        """Find which loaded styles are present in a prompt pair and strip them.

        Repeatedly looks for a style that matches both prompts, removes it and
        searches again with the remaining styles until none matches. Each style
        is extracted at most once.

        Returns:
            ``(names, prompt, negative_prompt)`` where *names* lists the
            extracted style names in reverse order of discovery and the two
            prompts are what is left after removing those styles.
        """
        extracted = []

        applicable_styles = list(self.styles.values())

        while True:
            found_style = None

            for style in applicable_styles:
                is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt)
                if is_match:
                    found_style = style
                    prompt = new_prompt
                    negative_prompt = new_neg_prompt
                    break

            if found_style is None:
                break

            applicable_styles.remove(found_style)
            extracted.append(found_style.name)

        # Styles are peeled off outermost-first; reversing gives the order in
        # which they would be applied again.
        return list(reversed(extracted)), prompt, negative_prompt