# deepbooru.py (1.9 KB)
# Committed by Greendayle.
import os.path
from concurrent.futures import ProcessPoolExecutor

import numpy as np
import deepdanbooru as dd
import tensorflow as tf


def _load_tf_and_return_tags(pil_image, threshold):
    this_folder = os.path.dirname(__file__)
    model_path = os.path.join(this_folder, '..', 'models', 'deepbooru', 'deepdanbooru-v3-20211112-sgd-e28')
    if not os.path.exists(model_path):
        return "Download https://github.com/KichangKim/DeepDanbooru/releases/download/v3-20211112-sgd-e28/deepdanbooru-v3-20211112-sgd-e28.zip unpack and put into models/deepbooru"

    tags = dd.project.load_tags_from_project(model_path)
    model = dd.project.load_model_from_project(
        model_path, compile_model=True
    )

    width = model.input_shape[2]
    height = model.input_shape[1]
    image = np.array(pil_image)
    image = tf.image.resize(
        image,
        size=(height, width),
        method=tf.image.ResizeMethod.AREA,
        preserve_aspect_ratio=True,
    )
    image = image.numpy()  # EagerTensor to np.array
    image = dd.image.transform_and_pad_image(image, width, height)
    image = image / 255.0
    image_shape = image.shape
    image = image.reshape((1, image_shape[0], image_shape[1], image_shape[2]))

    y = model.predict(image)[0]

    result_dict = {}

    for i, tag in enumerate(tags):
        result_dict[tag] = y[i]



    result_tags_out = []
    result_tags_print = []
    for tag in tags:
        if result_dict[tag] >= threshold:
            result_tags_out.append(tag)
            result_tags_print.append(f'{result_dict[tag]} {tag}')

    print('\n'.join(sorted(result_tags_print, reverse=True)))

    return ', '.join(result_tags_out)


def get_deepbooru_tags(pil_image, threshold=0.5):
    """Tag *pil_image* with DeepDanbooru, running the work in a worker process.

    ``_load_tf_and_return_tags`` is executed in a ``ProcessPoolExecutor``, so
    its arguments must be picklable. Leaving the ``with`` block shuts the
    executor down and waits for the worker to exit. Presumably this is meant
    to keep TensorFlow's state out of the calling process.

    Args:
        pil_image: image forwarded to ``_load_tf_and_return_tags``.
        threshold: minimum tag score (inclusive); defaults to 0.5.

    Returns:
        Whatever ``_load_tf_and_return_tags`` returns (the tag string, or its
        download-instruction message when the model is missing).

    Raises:
        Any exception raised inside the worker is re-raised here by
        ``Future.result()``.
    """
    # Exactly one task is submitted, so one worker is enough. The default
    # (os.cpu_count()) can launch a process per CPU for no benefit.
    with ProcessPoolExecutor(max_workers=1) as executor:
        future = executor.submit(_load_tf_and_return_tags, pil_image, threshold)
        return future.result()  # re-raises any exception from the worker