From b7f45e67dcf48914d2f34d4ace977a431a5aa12e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 11 Feb 2024 12:56:53 +0300 Subject: [PATCH] add before_token_counter callback and use it for prompt comments --- modules/processing_scripts/comments.py | 12 +++++++++++- modules/script_callbacks.py | 26 ++++++++++++++++++++++++++ modules/ui.py | 6 ++++++ 3 files changed, 43 insertions(+), 1 deletion(-) diff --git a/modules/processing_scripts/comments.py b/modules/processing_scripts/comments.py index 316356c77..638e39f29 100644 --- a/modules/processing_scripts/comments.py +++ b/modules/processing_scripts/comments.py @@ -1,4 +1,4 @@ -from modules import scripts, shared +from modules import scripts, shared, script_callbacks import re @@ -27,6 +27,16 @@ class ScriptStripComments(scripts.Script): p.main_negative_prompt = strip_comments(p.main_negative_prompt) +def before_token_counter(params: script_callbacks.BeforeTokenCounterParams): + if not shared.opts.enable_prompt_comments: + return + + params.prompt = strip_comments(params.prompt) + + +script_callbacks.on_before_token_counter(before_token_counter) + + shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), { "enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."), })) diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a54cb3ebb..08bc52564 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,3 +1,4 @@ +import dataclasses import inspect import os from collections import namedtuple @@ -106,6 +107,15 @@ class ImageGridLoopParams: self.rows = rows +@dataclasses.dataclass +class BeforeTokenCounterParams: + prompt: str + steps: int + styles: list + + is_positive: bool = True + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( callbacks_app_started=[], @@ -128,6 +138,7 @@ callback_map = dict( callbacks_on_reload=[], callbacks_list_optimizers=[], callbacks_list_unets=[], + callbacks_before_token_counter=[], ) @@ -309,6 +320,14 @@ def list_unets_callback(): return res +def before_token_counter_callback(params: BeforeTokenCounterParams): + for c in callback_map['callbacks_before_token_counter']: + try: + c.callback(params) + except Exception: + report_exception(c, 'before_token_counter') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if stack else 'unknown file' @@ -483,3 +502,10 @@ def on_list_unets(callback): The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" add_callback(callback_map['callbacks_list_unets'], callback) + + +def on_before_token_counter(callback): + """register a function to be called when UI is counting tokens for a prompt. + The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.""" + + add_callback(callback_map['callbacks_before_token_counter'], callback) diff --git a/modules/ui.py b/modules/ui.py index 5284a6300..dcba8e885 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -152,6 +152,12 @@ def connect_clear_prompt(button): def update_token_counter(text, steps, styles, *, is_positive=True): + params = script_callbacks.BeforeTokenCounterParams(text, steps, styles, is_positive=is_positive) + script_callbacks.before_token_counter_callback(params) + text = params.prompt + steps = params.steps + styles = params.styles + is_positive = params.is_positive if shared.opts.include_styles_into_token_counters: apply_styles = shared.prompt_styles.apply_styles_to_prompt if is_positive else shared.prompt_styles.apply_negative_styles_to_prompt -- GitLab