From 76ef3d75f61253516c024553335d9083d9660a8a Mon Sep 17 00:00:00 2001 From: JC_Array Date: Mon, 10 Oct 2022 18:01:49 -0500 Subject: [PATCH] added deepbooru settings (threshold and sort by alpha or likelyhood) --- modules/deepbooru.py | 36 +++++++++++++++++++++++++----------- modules/shared.py | 6 ++++++ 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/modules/deepbooru.py b/modules/deepbooru.py index ebdba5e08..e31e92c09 100644 --- a/modules/deepbooru.py +++ b/modules/deepbooru.py @@ -3,31 +3,32 @@ from concurrent.futures import ProcessPoolExecutor import multiprocessing import time - -def get_deepbooru_tags(pil_image, threshold=0.5): +def get_deepbooru_tags(pil_image): """ This method is for running only one image at a time for simple use. Used to the img2img interrogate. """ from modules import shared # prevents circular reference - create_deepbooru_process(threshold) + create_deepbooru_process(shared.opts.deepbooru_threshold, shared.opts.deepbooru_sort_alpha) shared.deepbooru_process_return["value"] = -1 shared.deepbooru_process_queue.put(pil_image) while shared.deepbooru_process_return["value"] == -1: time.sleep(0.2) + tags = shared.deepbooru_process_return["value"] release_process() + return tags -def deepbooru_process(queue, deepbooru_process_return, threshold): +def deepbooru_process(queue, deepbooru_process_return, threshold, alpha_sort): model, tags = get_deepbooru_tags_model() while True: # while process is running, keep monitoring queue for new image pil_image = queue.get() if pil_image == "QUIT": break else: - deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold) + deepbooru_process_return["value"] = get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort) -def create_deepbooru_process(threshold=0.5): +def create_deepbooru_process(threshold, alpha_sort): """ Creates deepbooru process. A queue is created to send images into the process. This enables multiple images to be processed in a row without reloading the model or creating a new process. To return the data, a shared @@ -40,7 +41,7 @@ def create_deepbooru_process(threshold=0.5): shared.deepbooru_process_queue = shared.deepbooru_process_manager.Queue() shared.deepbooru_process_return = shared.deepbooru_process_manager.dict() shared.deepbooru_process_return["value"] = -1 - shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold)) + shared.deepbooru_process = multiprocessing.Process(target=deepbooru_process, args=(shared.deepbooru_process_queue, shared.deepbooru_process_return, threshold, alpha_sort)) shared.deepbooru_process.start() @@ -80,7 +81,7 @@ def get_deepbooru_tags_model(): return model, tags -def get_deepbooru_tags_from_model(model, tags, pil_image, threshold=0.5): +def get_deepbooru_tags_from_model(model, tags, pil_image, threshold, alpha_sort): import deepdanbooru as dd import tensorflow as tf import numpy as np @@ -105,15 +106,28 @@ def get_deepbooru_tags_from_model(model, tags, pil_image, threshold=0.5): for i, tag in enumerate(tags): result_dict[tag] = y[i] - result_tags_out = [] + + unsorted_tags_in_theshold = [] result_tags_print = [] for tag in tags: if result_dict[tag] >= threshold: if tag.startswith("rating:"): continue - result_tags_out.append(tag) + unsorted_tags_in_theshold.append((result_dict[tag], tag)) result_tags_print.append(f'{result_dict[tag]} {tag}') + # sort tags + result_tags_out = [] + sort_ndx = 0 + print(alpha_sort) + if alpha_sort: + sort_ndx = 1 + + # sort by reverse by likelihood and normal for alpha + unsorted_tags_in_theshold.sort(key=lambda y: y[sort_ndx], reverse=(not alpha_sort)) + for weight, tag in unsorted_tags_in_theshold: + result_tags_out.append(tag) + print('\n'.join(sorted(result_tags_print, reverse=True))) - return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') \ No newline at end of file + return ', '.join(result_tags_out).replace('_', ' ').replace(':', ' ') diff --git a/modules/shared.py b/modules/shared.py index 1995a99a7..2e307809b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -261,6 +261,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 's_noise': OptionInfo(1.0, "sigma noise", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), })) +if cmd_opts.deepdanbooru: + options_templates.update(options_section(('deepbooru-params', "DeepBooru parameters"), { + "deepbooru_sort_alpha": OptionInfo(True, "Sort Alphabetical", gr.Checkbox), + 'deepbooru_threshold': OptionInfo(0.5, "Threshold", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + })) + class Options: data = None -- GitLab