# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import ctypes import paddle import numpy.ctypeslib as ctl import numpy as np import os import json from ctypes import * from numpy.ctypeslib import ndpointer __dir__ = os.path.dirname(os.path.abspath(__file__)) so_path = os.path.join(__dir__, "index.so") lib = ctypes.cdll.LoadLibrary(so_path) class IndexContext(Structure): _fields_=[("graph",c_void_p), ("data",c_void_p)] # for mobius IP index build_mobius_index = lib.build_mobius_index build_mobius_index.restype = None build_mobius_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_double, ctypes.c_char_p] search_mobius_index = lib.search_mobius_index search_mobius_index.restype = None search_mobius_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,ctypes.c_int,POINTER(IndexContext),ctl.ndpointer(np.uint64, flags='aligned, c_contiguous'),ctl.ndpointer(np.float64, flags='aligned, c_contiguous')] load_mobius_index_prefix = lib.load_mobius_index_prefix load_mobius_index_prefix.restype = None load_mobius_index_prefix.argtypes = [ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p] save_mobius_index_prefix = lib.save_mobius_index_prefix save_mobius_index_prefix.restype = None save_mobius_index_prefix.argtypes = [POINTER(IndexContext), ctypes.c_char_p] # for L2 index build_l2_index = lib.build_l2_index build_l2_index.restype = None build_l2_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p] search_l2_index = lib.search_l2_index search_l2_index.restype = None search_l2_index.argtypes = [ctl.ndpointer(np.float32, flags='aligned, c_contiguous'), ctypes.c_int, ctypes.c_int,ctypes.c_int,POINTER(IndexContext),ctl.ndpointer(np.uint64, flags='aligned, c_contiguous'),ctl.ndpointer(np.float64, flags='aligned, c_contiguous')] load_l2_index_prefix = lib.load_l2_index_prefix load_l2_index_prefix.restype = None load_l2_index_prefix.argtypes = [ctypes.c_int, ctypes.c_int, POINTER(IndexContext), ctypes.c_char_p] save_l2_index_prefix = lib.save_l2_index_prefix save_l2_index_prefix.restype = None save_l2_index_prefix.argtypes = [POINTER(IndexContext), ctypes.c_char_p] release_context = lib.release_context release_context.restype = None release_context.argtypes = [POINTER(IndexContext)] class Graph_Index(object): """ graph index """ def __init__(self, dist_type="IP"): self.dim = 0 self.total_num = 0 self.dist_type = dist_type self.mobius_pow = 2.0 self.index_context = IndexContext(0,0) self.gallery_doc_dict = {} self.with_attr = False assert dist_type in ["IP", "L2"], "Only support IP and L2 distance ..." def build(self, gallery_vectors, gallery_docs=[], pq_size=100, index_path='graph_index/'): """ build index """ if paddle.is_tensor(gallery_vectors): gallery_vectors = gallery_vectors.numpy() assert gallery_vectors.ndim == 2, "Input vector must be 2D ..." self.total_num = gallery_vectors.shape[0] self.dim = gallery_vectors.shape[1] assert (len(gallery_docs) == self.total_num if len(gallery_docs)>0 else True) print("training index -> num: {}, dim: {}, dist_type: {}".format(self.total_num, self.dim, self.dist_type)) if not os.path.exists(index_path): os.makedirs(index_path) if self.dist_type == "IP": build_mobius_index(gallery_vectors,self.total_num,self.dim, pq_size, self.mobius_pow, create_string_buffer((index_path+"/index").encode('utf-8'))) load_mobius_index_prefix(self.total_num, self.dim, ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8'))) else: build_l2_index(gallery_vectors,self.total_num,self.dim, pq_size, create_string_buffer((index_path+"/index").encode('utf-8'))) load_l2_index_prefix(self.total_num, self.dim, ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8'))) self.gallery_doc_dict = {} if len(gallery_docs) > 0: self.with_attr = True for i in range(gallery_vectors.shape[0]): self.gallery_doc_dict[str(i)] = gallery_docs[i] self.gallery_doc_dict["total_num"] = self.total_num self.gallery_doc_dict["dim"] = self.dim self.gallery_doc_dict["dist_type"] = self.dist_type self.gallery_doc_dict["with_attr"] = self.with_attr with open(index_path + "/info.json", "w") as f: json.dump(self.gallery_doc_dict, f) print("finished creating index ...") def search(self, query, return_k=10, search_budget=100): """ search """ ret_id = np.zeros(return_k, dtype=np.uint64) ret_score = np.zeros(return_k, dtype=np.float64) if paddle.is_tensor(query): query = query.numpy() if self.dist_type == "IP": search_mobius_index(query,self.dim,search_budget,return_k,ctypes.byref(self.index_context),ret_id,ret_score) else: search_l2_index(query,self.dim,search_budget,return_k,ctypes.byref(self.index_context),ret_id,ret_score) ret_id = ret_id.tolist() ret_doc = [] if self.with_attr: for i in range(return_k): ret_doc.append(self.gallery_doc_dict[str(ret_id[i])]) return ret_score, ret_doc else: return ret_score, ret_id def dump(self, index_path): if not os.path.exists(index_path): os.makedirs(index_path) if self.dist_type == "IP": save_mobius_index_prefix(ctypes.byref(self.index_context),create_string_buffer((index_path+"/index").encode('utf-8'))) else: save_l2_index_prefix(ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8'))) with open(index_path + "/info.json", "w") as f: json.dump(self.gallery_doc_dict, f) def load(self, index_path): self.gallery_doc_dict = {} with open(index_path + "/info.json", "r") as f: self.gallery_doc_dict = json.load(f) self.total_num = self.gallery_doc_dict["total_num"] self.dim = self.gallery_doc_dict["dim"] self.dist_type = self.gallery_doc_dict["dist_type"] self.with_attr = self.gallery_doc_dict["with_attr"] if self.dist_type == "IP": load_mobius_index_prefix(self.total_num,self.dim,ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8'))) else: load_l2_index_prefix(self.total_num,self.dim,ctypes.byref(self.index_context), create_string_buffer((index_path+"/index").encode('utf-8')))