import os.path
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context


def _ensure_deepbooru_model(model_path):
    """Download and unpack the DeepDanbooru v3 model into *model_path* if absent.

    The presence of ``project.json`` in *model_path* is the marker that the
    model has already been extracted; in that case nothing is done.
    """
    if os.path.exists(os.path.join(model_path, 'project.json')):
        return

    # Only needed on first use, so these imports stay off the common path.
    import zipfile
    from basicsr.utils.download_util import load_file_from_url

    load_file_from_url(r"https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip",
                       model_path)
    zip_path = os.path.join(model_path, "deepdanbooru-v3-20211112-sgd-e28.zip")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(model_path)
    # The archive is only removed after a successful extraction.
    os.remove(zip_path)


def _load_tf_and_return_tags(pil_image, threshold):
    """Run DeepDanbooru on *pil_image* and return its tags as a prompt string.

    deepdanbooru, TensorFlow and NumPy are imported lazily inside this
    function, which ``get_deepbooru_tags`` executes in a worker process.

    Args:
        pil_image: image to tag; presumably a PIL ``Image`` (it must support
            ``.convert("RGB")``).
        threshold: minimum model score for a tag to be kept.

    Returns:
        The tags scoring at least *threshold*, excluding ``rating:*`` tags,
        in the model's tag order, joined with ``", "`` and with ``_`` and
        ``:`` replaced by spaces. Each kept tag and its score is also printed,
        highest score first.
    """
    import deepdanbooru as dd
    import tensorflow as tf
    import numpy as np

    this_folder = os.path.dirname(__file__)
    model_path = os.path.abspath(os.path.join(this_folder, '..', 'models', 'deepbooru'))
    _ensure_deepbooru_model(model_path)

    tags = dd.project.load_tags_from_project(model_path)
    model = dd.project.load_model_from_project(
        model_path, compile_model=True
    )

    # Indices 1 and 2 of the model's input shape hold height and width.
    width = model.input_shape[2]
    height = model.input_shape[1]

    # Force 3 channels: RGBA, greyscale or palette images would otherwise
    # yield an array whose channel count does not match the model input.
    image = np.array(pil_image.convert("RGB"))
    image = tf.image.resize(
        image,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    )
    image = image.numpy()  # EagerTensor to np.array
    image = dd.image.transform_and_pad_image(image, width, height)
    image = image / 255.0
    # Add the leading batch dimension of size 1 expected by model.predict.
    image = np.expand_dims(image, axis=0)

    scores = model.predict(image)[0]

    # (score, tag) pairs that pass the threshold, kept in the model's tag order.
    selected = [
        (score, tag)
        for tag, score in zip(tags, scores)
        if score >= threshold and not tag.startswith("rating:")
    ]

    # Sort on the numeric score (ties broken by tag, descending). Sorting the
    # formatted strings misorders values printed in scientific notation.
    print('\n'.join(f'{score} {tag}' for score, tag in sorted(selected, reverse=True)))

    return ', '.join(tag for _, tag in selected).replace('_', ' ').replace(':', ' ')


def subprocess_init_no_cuda():
    """Make CUDA devices invisible to the current process.

    Sets ``CUDA_VISIBLE_DEVICES`` to ``"-1"``; this module passes it as the
    ``ProcessPoolExecutor`` initializer so the worker runs on CPU only.
    """
    from os import environ
    environ.update(CUDA_VISIBLE_DEVICES="-1")


def get_deepbooru_tags(pil_image, threshold=0.5):
    """Tag *pil_image* with DeepDanbooru in a fresh, CPU-only worker process.

    The worker is started with the ``'spawn'`` start method, and
    ``subprocess_init_no_cuda`` runs in it before the tagging task, so the
    task sees no CUDA devices.

    Args:
        pil_image: image to tag; it is pickled and sent to the worker.
        threshold: minimum score for a tag to be included (default 0.5).

    Returns:
        The comma-separated tag string produced by ``_load_tf_and_return_tags``.

    Raises:
        Any exception raised in the worker is re-raised here by
        ``Future.result()``.
    """
    context = get_context('spawn')
    # Exactly one task is submitted, so one worker is enough. The default
    # (os.cpu_count()) may start many spawned interpreters for nothing,
    # and on older Python versions all of them are started up front.
    with ProcessPoolExecutor(max_workers=1, initializer=subprocess_init_no_cuda, mp_context=context) as executor:
        future = executor.submit(_load_tf_and_return_tags, pil_image, threshold)
        return future.result()  # will rethrow any exceptions