build_gallery.py 3.3 KB
Newer Older
F
Felix 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

import copy
import cv2
import numpy as np
from tqdm import tqdm

from python.predict_rec import RecPredictor
from vector_search import Graph_Index

from utils import logger
from utils import config


def split_datafile(data_file, image_root, delimiter="\t"):
    '''
    Parse a gallery list file into image paths and their associated docs.

    Each non-blank line is stripped and split on ``delimiter``. The first
    field is an image path relative to ``image_root``; the second field is
    the doc string for that image. Any further fields are ignored. Blank
    (or whitespace-only) lines are skipped.

    Args:
        data_file: path to the UTF-8 text file listing the gallery images.
        image_root: directory prepended to every image path.
        delimiter: field separator used in ``data_file`` (tab by default).

    Returns:
        (gallery_images, gallery_docs): two lists of equal length, with the
        joined image paths and the corresponding doc strings, in file order.

    Raises:
        AssertionError: if a non-blank line has fewer than 2 fields.
    '''
    gallery_images = []
    gallery_docs = []
    with open(data_file, 'r', encoding='utf-8') as f:
        for ori_line in f:
            stripped = ori_line.strip()
            if not stripped:
                # A blank line (e.g. an extra trailing newline) would split
                # into a single empty field and trip the check below; skip it.
                continue
            line = stripped.split(delimiter)
            text_num = len(line)
            assert text_num >= 2, f"line({ori_line}) must be splitted into at least 2 parts, but got {text_num}"
            gallery_images.append(os.path.join(image_root, line[0]))
            gallery_docs.append(line[1])

    return gallery_images, gallery_docs


class GalleryBuilder(object):
    """Build a vector-search index over a gallery of images.

    Constructing an instance runs the whole pipeline:
    1. read the gallery list with ``split_datafile``;
    2. compute one feature vector per image with ``RecPredictor``;
    3. build a ``Graph_Index`` from those vectors.

    The resulting index object is kept on ``self.Searcher``.
    """

    def __init__(self, config):
        """
        Args:
            config: parsed configuration. It must contain an
                ``IndexProcess`` section, which is passed to ``build``. The
                whole config is also passed to ``RecPredictor``.

        Raises:
            AssertionError: if ``IndexProcess`` is missing from ``config``.
        """
        # Validate the config before constructing the (presumably expensive)
        # predictor, so a bad config fails fast.
        assert 'IndexProcess' in config, "Index config not found ... "
        self.config = config
        self.rec_predictor = RecPredictor(config)
        self.build(config['IndexProcess'])

    def build(self, config):
        """Extract gallery features and build the search index.

        Args:
            config: the ``IndexProcess`` config section. Keys read:
                ``data_file``, ``image_root``, ``delimiter`` (for
                ``split_datafile``); ``embedding_size`` (feature width);
                ``dist_type``, ``pq_size``, ``index_path`` and
                ``append_index`` (forwarded to ``Graph_Index``).

        If any gallery image cannot be read, an error is logged and the
        process exits with status 1.
        """
        gallery_images, gallery_docs = split_datafile(
            config['data_file'], config['image_root'], config['delimiter'])

        # Extract one embedding per gallery image, stored row-wise.
        gallery_features = np.zeros(
            [len(gallery_images), config['embedding_size']], dtype=np.float32)

        for i, image_file in enumerate(tqdm(gallery_images)):
            img = cv2.imread(image_file)
            if img is None:
                logger.error("img empty, please check {}".format(image_file))
                # The builtin exit() (a `site` helper) exited with status 0,
                # so a missing image looked like a successful build. Exit
                # non-zero so shells/CI see the failure.
                sys.exit(1)
            # cv2.imread returns channels in BGR order; reverse to RGB.
            img = img[:, :, ::-1]
            gallery_features[i, :] = self.rec_predictor.predict(img)

        # Build the index from the extracted features.
        self.Searcher = Graph_Index(dist_type=config['dist_type'])
        self.Searcher.build(
            gallery_vectors=gallery_features,
            gallery_docs=gallery_docs,
            pq_size=config['pq_size'],
            index_path=config['index_path'],
            append_index=config["append_index"])
littletomatodonkey's avatar
littletomatodonkey 已提交
91

F
Felix 已提交
92 93 94 95 96 97 98 99 100 101

def main(config):
    """Build the gallery index described by ``config``.

    All work happens in ``GalleryBuilder``'s constructor; nothing is
    returned.
    """
    GalleryBuilder(config)


if __name__ == "__main__":
    args = config.parse_args()
    # Bind the parsed config to a new name instead of rebinding `config`,
    # which would shadow the imported `utils.config` module.
    cfg = config.get_config(args.config, overrides=args.override, show=True)
    main(cfg)