import contextlib
import os
import sys
import traceback
from collections import namedtuple
import re

import torch

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths, lowvram

blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

# a category file may encode how many top matches to keep in its name, e.g. "flavors.top3.txt" keeps the top 3
re_topn = re.compile(r"\.top(\d+)\.")

class InterrogateModels:
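    """Wraps the BLIP captioning model and the CLIP model used to rank category terms for the interrogate feature."""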
    blip_model = None
    clip_model = None
    clip_preprocess = None
    categories = None
    dtype = None

    def __init__(self, content_dir):
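        """Read category word lists from content_dir; a ".topN." part in a filename sets how many matches to keep."""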
        self.categories = []

        if os.path.exists(content_dir):
            for filename in os.listdir(content_dir):
                m = re_topn.search(filename)
                topn = 1 if m is None else int(m.group(1))

                with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.categories.append(Category(name=filename, topn=topn, items=lines))

    def load_blip_model(self):
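        """Build the BLIP caption decoder from the pretrained checkpoint and put it in eval mode; load() moves it to the device."""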
        import models.blip

        blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
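        """Load the ViT-L/14 CLIP model and its preprocessing transform, moving the model to the configured device."""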
        import clip

        model, preprocess = clip.load(clip_model_name)
        model.eval()
        model = model.to(shared.device)

        return model, preprocess

    def load(self):
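        """Lazily create both models, convert them to half precision unless --no-half is set, move them to the device and record the CLIP parameter dtype."""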
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()
            if not shared.cmd_opts.no_half:
                self.blip_model = self.blip_model.half()

        self.blip_model = self.blip_model.to(shared.device)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
            if not shared.cmd_opts.no_half:
                self.clip_model = self.clip_model.half()

        self.clip_model = self.clip_model.to(shared.device)

        self.dtype = next(self.clip_model.parameters()).dtype

    def send_clip_to_ram(self):
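        """Move the CLIP model back to CPU unless interrogate_keep_models_in_memory is enabled."""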
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

    def send_blip_to_ram(self):
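        """Move the BLIP model back to CPU unless interrogate_keep_models_in_memory is enabled."""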
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def unload(self):
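        """Send both models to CPU and collect garbage to free VRAM."""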
        self.send_clip_to_ram()
        self.send_blip_to_ram()

        devices.torch_gc()

    def rank(self, image_features, text_array, top_count=1):
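        """Score each candidate text against the image features with CLIP and return the top_count (text, percentage) pairs."""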
        import clip

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

        top_count = min(top_count, len(text_array))
        text_tokens = clip.tokenize(text_array).to(shared.device)
        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # softmax the image-to-text similarities and average them over the image feature rows
        similarity = torch.zeros((1, len(text_array))).to(shared.device)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]

    def generate_caption(self, pil_image):
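        """Resize and normalize the image for BLIP, then generate a caption with beam search using the configured length limits."""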
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]

    def interrogate(self, pil_image):
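        """Produce a BLIP caption for the image, then extend it with the best-matching artist and category terms ranked by CLIP."""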
        res = None

        try:

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()
                devices.torch_gc()

            self.load()

            caption = self.generate_caption(pil_image)
            self.send_blip_to_ram()
            devices.torch_gc()

            res = caption

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)

            precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
            with torch.no_grad(), precision_scope("cuda"):
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

                image_features /= image_features.norm(dim=-1, keepdim=True)

                if shared.opts.interrogate_use_builtin_artists:
                    artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]

                    res += ", " + artist[0]

                for name, topn, items in self.categories:
                    matches = self.rank(image_features, items, top_count=topn)
                    for match, score in matches:
                        res += ", " + match

        except Exception:
            print("Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # mark the result so callers can tell the interrogation failed; res may still be None if captioning never ran
            res = (res or "") + "<error>"

        self.unload()

        return res