# interrogate.py
import contextlib
import os
import sys
import traceback
from collections import namedtuple
import re

import torch

from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

import modules.shared as shared
from modules import devices, paths, lowvram

blip_image_eval_size = 384
blip_model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_caption_capfilt_large.pth'
clip_model_name = 'ViT-L/14'

Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")


class InterrogateModels:
    blip_model = None
    clip_model = None
    clip_preprocess = None
    categories = None
    dtype = None

    def __init__(self, content_dir):
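        # builds the category list from the files found in content_dir (see the comment below)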
        self.categories = []

        if os.path.exists(content_dir):
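            # one text file per category, one term per line; an optional ".topN."
            # marker in the filename sets how many top matches to report (default 1)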
            for filename in os.listdir(content_dir):
                m = re_topn.search(filename)
                topn = 1 if m is None else int(m.group(1))

                with open(os.path.join(content_dir, filename), "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.categories.append(Category(name=filename, topn=topn, items=lines))

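    # loads the BLIP captioning model; the "models.blip" import expects the BLIP
    # repository to be importable, and its checkout is also referenced via paths.paths["BLIP"]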
    def load_blip_model(self):
        import models.blip

        blip_model = models.blip.blip_decoder(pretrained=blip_model_url, image_size=blip_image_eval_size, vit='base', med_config=os.path.join(paths.paths["BLIP"], "configs", "med_config.json"))
        blip_model.eval()

        return blip_model

    def load_clip_model(self):
        import clip

        model, preprocess = clip.load(clip_model_name)
        model.eval()
        model = model.to(shared.device)

        return model, preprocess

    def load(self):
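        # lazily load BLIP and CLIP, cast them to fp16 unless --no-half is set,
        # and move both to the active device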
        if self.blip_model is None:
            self.blip_model = self.load_blip_model()
            if not shared.cmd_opts.no_half:
                self.blip_model = self.blip_model.half()

        self.blip_model = self.blip_model.to(shared.device)

        if self.clip_model is None:
            self.clip_model, self.clip_preprocess = self.load_clip_model()
            if not shared.cmd_opts.no_half:
                self.clip_model = self.clip_model.half()

        self.clip_model = self.clip_model.to(shared.device)

        self.dtype = next(self.clip_model.parameters()).dtype

    def send_clip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.clip_model is not None:
                self.clip_model = self.clip_model.to(devices.cpu)

    def send_blip_to_ram(self):
        if not shared.opts.interrogate_keep_models_in_memory:
            if self.blip_model is not None:
                self.blip_model = self.blip_model.to(devices.cpu)

    def unload(self):
        self.send_clip_to_ram()
        self.send_blip_to_ram()

        devices.torch_gc()

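    # zero-shot ranking with CLIP: encodes every entry of text_array, compares it
    # against the precomputed image features and returns the top_count best
    # (text, probability in percent) pairs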
    def rank(self, image_features, text_array, top_count=1):
        import clip

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

        top_count = min(top_count, len(text_array))
        text_tokens = clip.tokenize(list(text_array), truncate=True).to(shared.device)
        text_features = self.clip_model.encode_text(text_tokens).type(self.dtype)
        text_features /= text_features.norm(dim=-1, keepdim=True)

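        # both feature sets are L2-normalized, so the dot product is cosine similarity;
        # softmax over 100 * similarity follows CLIP's zero-shot classification recipe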
        similarity = torch.zeros((1, len(text_array))).to(shared.device)
        for i in range(image_features.shape[0]):
            similarity += (100.0 * image_features[i].unsqueeze(0) @ text_features.T).softmax(dim=-1)
        similarity /= image_features.shape[0]

        top_probs, top_labels = similarity.cpu().topk(top_count, dim=-1)
        return [(text_array[top_labels[0][i].numpy()], (top_probs[0][i].numpy()*100)) for i in range(top_count)]

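    # resizes the image to blip_image_eval_size squared, normalizes it with the
    # CLIP-style mean/std used by BLIP, and runs BLIP beam-search captioning with
    # the length limits configured in settings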
    def generate_caption(self, pil_image):
        gpu_image = transforms.Compose([
            transforms.Resize((blip_image_eval_size, blip_image_eval_size), interpolation=InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
        ])(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)

        with torch.no_grad():
            caption = self.blip_model.generate(gpu_image, sample=False, num_beams=shared.opts.interrogate_clip_num_beams, min_length=shared.opts.interrogate_clip_min_length, max_length=shared.opts.interrogate_clip_max_length)

        return caption[0]

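    # full pipeline: optionally push other models out of VRAM, load BLIP + CLIP,
    # caption the image with BLIP, then use CLIP to append the best-matching
    # artist and category terms to the caption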
    def interrogate(self, pil_image, include_ranks=False):
        res = ""  # start empty so the except-branch below can still append "<error>"

        try:

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()
                devices.torch_gc()

            self.load()

            caption = self.generate_caption(pil_image)
            self.send_blip_to_ram()
            devices.torch_gc()

            res = caption

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(shared.device)

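            # run CLIP under autocast when --precision autocast is used; otherwise keep the model's own dtype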
            precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
            with torch.no_grad(), precision_scope("cuda"):
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

                image_features /= image_features.norm(dim=-1, keepdim=True)

                if shared.opts.interrogate_use_builtin_artists:
                    artist = self.rank(image_features, ["by " + artist.name for artist in shared.artist_db.artists])[0]

                    res += ", " + artist[0]

                for name, topn, items in self.categories:
                    matches = self.rank(image_features, items, top_count=topn)
                    for match, score in matches:
                        if include_ranks:
                            res += f", ({match}:{score})"
                        else:
                            res += ", " + match

        except Exception:
            print(f"Error interrogating", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            res += "<error>"

        self.unload()

        return res
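
# A minimal usage sketch (an assumption, not part of this module: it presumes the
# webui environment is initialized, this file is importable as modules.interrogate,
# and the "interrogate" directory and image path below are placeholders):
#
#     from PIL import Image
#     from modules.interrogate import InterrogateModels
#
#     interrogator = InterrogateModels("interrogate")
#     prompt = interrogator.interrogate(Image.open("example.png").convert("RGB"))
#     print(prompt)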