build_gallery.py 8.0 KB
Newer Older
F
Felix 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

import cv2
D
dongshuilong 已提交
21
import faiss
F
Felix 已提交
22 23
import numpy as np
from tqdm import tqdm
D
dongshuilong 已提交
24
import pickle
F
Felix 已提交
25 26 27 28 29 30 31 32

from python.predict_rec import RecPredictor

from utils import logger
from utils import config

def split_datafile(data_file, image_root, delimiter="\t"):
    '''
    Parse a gallery list file into image paths and their raw records.

    Args:
        data_file: path of a UTF-8 text file; every non-blank line holds an
            image path followed by at least one more field, separated by
            ``delimiter``
        image_root: directory joined in front of every image path
        delimiter: field separator used inside each line

    Returns:
        (gallery_images, gallery_docs): the joined image paths and the
        whitespace-stripped original lines, both in file order

    Raises:
        AssertionError: if a non-blank line has fewer than 2 fields
    '''
    gallery_images = []
    gallery_docs = []
    with open(data_file, 'r', encoding='utf-8') as f:
        # iterate the file lazily instead of loading every line up front
        for ori_line in f:
            doc = ori_line.strip()
            # blank lines (e.g. a trailing empty line) carry no record
            if not doc:
                continue
            fields = doc.split(delimiter)
            text_num = len(fields)
            assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"

            gallery_images.append(os.path.join(image_root, fields[0]))
            gallery_docs.append(doc)

    return gallery_images, gallery_docs


class GalleryBuilder(object):
    '''
    Build or update a faiss gallery index from a list of images.

    Constructing the object runs the whole pipeline: a ``RecPredictor`` is
    created from ``config`` and ``build`` is called with
    ``config['IndexProcess']``.
    '''

    def __init__(self, config):
        self.config = config
        self.rec_predictor = RecPredictor(config)
        assert 'IndexProcess' in config.keys(), "Index config not found ... "
        self.build(config['IndexProcess'])

    def build(self, config):
        '''
        Create, extend or shrink the index stored in ``config["index_dir"]``.

        Keys read from ``config``: ``index_operation`` ("new", "append" or
        "remove", default "new"), ``data_file``, ``image_root``,
        ``delimiter`` (default tab), ``index_dir``, ``index_method``
        (default "HNSW32"), ``dist_type``, ``embedding_size`` and
        ``batch_size``.

        Two files are written to ``index_dir``:
            vector.index: the faiss index
            id_map.pkl: pickled dict mapping each id to its data_file line

        Raises:
            AssertionError: unsupported operation, missing index files for
                "append"/"remove", or index / id_map size mismatch
            RuntimeError: "remove" requested with index_method HNSW32
        '''
        operation_method = config.get("index_operation", "new").lower()
        # validate before the expensive feature extraction below
        assert operation_method in [
            "new", "remove", "append"
        ], "Only append, remove and new operation are supported"

        index_dir = config["index_dir"]
        index_path = os.path.join(index_dir, "vector.index")
        id_map_path = os.path.join(index_dir, "id_map.pkl")
        is_binary = config["dist_type"] == "hamming"

        # use the same default as index creation so a missing key is not an error
        if config.get("index_method", "HNSW32") == "HNSW32":
            logger.warning(
                "The HNSW32 method dose not support 'remove' operation")
            if operation_method == "remove":
                raise RuntimeError(
                    "The index_method: HNSW32 dose not support 'remove' operation"
                )

        gallery_images, gallery_docs = split_datafile(
            config['data_file'], config['image_root'],
            config.get('delimiter', "\t"))

        # when removing data from the index, features are not needed
        if operation_method != "remove":
            gallery_features = self._extract_features(gallery_images, config)

        if operation_method in ["remove", "append"]:
            # if remove or append, vector.index and id_map.pkl must exist
            missing = self._missing_index_files(index_dir)
            assert not missing, "The {} dose not exist in {} when 'index_operation' is not None".format(
                missing[0], index_dir)
            # binary indexes are written with write_index_binary, so they
            # have to be read back with the matching reader
            if is_binary:
                index = faiss.read_index_binary(index_path)
            else:
                index = faiss.read_index(index_path)
            # NOTE(review): pickle.load executes arbitrary code from the
            # file; only point index_dir at trusted data.
            with open(id_map_path, 'rb') as fd:
                ids = pickle.load(fd)
            assert index.ntotal == len(ids.keys(
            )), "data number in index is not equal in in id_map"
        else:
            os.makedirs(index_dir, exist_ok=True)
            index = self._create_index(config, len(gallery_images))
            ids = {}

        if operation_method != "remove":
            # ids of the new data continue after the largest existing id
            start_id = max(ids.keys()) + 1 if ids else 0
            ids_now = (
                np.arange(0, len(gallery_images)) + start_id).astype(np.int64)

            if is_binary:
                # The binary index is not wrapped in an id map, so vectors
                # are added without explicit ids, for "append" as well as
                # "new" (otherwise id_map would grow without the index).
                # NOTE(review): presumably faiss numbers them sequentially,
                # which matches ids_now only if nothing was removed before
                # - confirm.
                index.add(gallery_features)
            else:
                # only train when building a new index file
                if operation_method == "new":
                    index.train(gallery_features)
                index.add_with_ids(gallery_features, ids_now)

            for i, d in zip(list(ids_now), gallery_docs):
                ids[i] = d
        else:
            # remove ids in id_map, remove index data in faiss index
            docs_to_remove = set(gallery_docs)
            stale_keys = [k for k, doc in ids.items() if doc in docs_to_remove]
            # explicit dtype: an empty list would otherwise become float64
            index.remove_ids(np.asarray(stale_keys, dtype=np.int64))
            for k in stale_keys:
                del ids[k]

        # store faiss index file and id_map file
        if is_binary:
            faiss.write_index_binary(index, index_path)
        else:
            faiss.write_index(index, index_path)

        with open(id_map_path, 'wb') as fd:
            pickle.dump(ids, fd)

    @staticmethod
    def _missing_index_files(index_dir):
        '''Return the names of the index files that do not exist in index_dir.'''
        return [
            name for name in ("vector.index", "id_map.pkl")
            if not os.path.exists(os.path.join(index_dir, name))
        ]

    @staticmethod
    def _resolve_index_method(config, gallery_size):
        '''Return the faiss factory string for config and the gallery size.'''
        index_method = config.get("index_method", "HNSW32")

        # for IVF, derive the cluster count from the gallery size; clamp to
        # at least 1 so small galleries do not yield "IVF0"
        if index_method == "IVF":
            nlist = max(1, min(gallery_size // 8, 65536))
            index_method = index_method + str(nlist) + ",Flat"

        # for binary index, add B at head of index_method
        if config["dist_type"] == "hamming":
            index_method = "B" + index_method
        return index_method

    def _create_index(self, config, gallery_size):
        '''Create an empty faiss index as described by config.'''
        index_method = self._resolve_index_method(config, gallery_size)
        if config["dist_type"] == "hamming":
            return faiss.index_binary_factory(config["embedding_size"],
                                              index_method)

        dist_type = faiss.METRIC_INNER_PRODUCT if config[
            "dist_type"] == "IP" else faiss.METRIC_L2
        index = faiss.index_factory(config["embedding_size"], index_method,
                                    dist_type)
        # wrap the index so vectors can be added with explicit ids
        return faiss.IndexIDMap2(index)

    def _extract_features(self, gallery_images, config):
        '''
        Run the recognition predictor over all gallery images in batches.

        Args:
            gallery_images: list of image file paths
            config: index config; reads ``dist_type``, ``embedding_size``
                and ``batch_size`` (default 32)

        Returns:
            numpy array with one row per image: uint8 with
            ``embedding_size // 8`` columns when dist_type is "hamming",
            otherwise float32 with ``embedding_size`` columns

        Raises:
            SystemExit: if an image cannot be read
        '''
        if config["dist_type"] == "hamming":
            gallery_features = np.zeros(
                [len(gallery_images), config['embedding_size'] // 8],
                dtype=np.uint8)
        else:
            gallery_features = np.zeros(
                [len(gallery_images), config['embedding_size']],
                dtype=np.float32)

        # construct batch imgs and do inference
        batch_size = config.get("batch_size", 32)
        batch_img = []
        for i, image_file in enumerate(tqdm(gallery_images)):
            img = cv2.imread(image_file)
            if img is None:
                logger.error("img empty, please check {}".format(image_file))
                # exit with a non-zero status so callers can see the failure
                sys.exit(1)
            # reverse the channel axis (cv2.imread yields BGR order)
            img = img[:, :, ::-1]
            batch_img.append(img)

            if (i + 1) % batch_size == 0:
                rec_feat = self.rec_predictor.predict(batch_img)
                gallery_features[i - batch_size + 1:i + 1, :] = rec_feat
                batch_img = []

        # the last, partially filled batch goes into the trailing rows
        if len(batch_img) > 0:
            rec_feat = self.rec_predictor.predict(batch_img)
            gallery_features[-len(batch_img):, :] = rec_feat
            batch_img = []

        return gallery_features
littletomatodonkey's avatar
littletomatodonkey 已提交
202

F
Felix 已提交
203 204

def main(config):
    '''
    Entry point: build the gallery index described by ``config``.

    The work happens while ``GalleryBuilder`` is being constructed; the
    instance itself is not kept and nothing is returned.
    '''
    GalleryBuilder(config)


if __name__ == "__main__":
    # Parse the command line, load the config file with any overrides
    # applied, then run the build. The loaded config gets its own name so
    # the imported ``config`` module is not rebound.
    args = config.parse_args()
    loaded_config = config.get_config(
        args.config, overrides=args.override, show=True)
    main(loaded_config)