interface.py 6.7 KB
Newer Older
F
Felix 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
import ctypes
import paddle
import numpy.ctypeslib as ctl
import numpy as np
import os
import json

from ctypes import *
from numpy.ctypeslib import ndpointer

# Locate the native index library. "./index.so" (relative to the current
# working directory) is tried first. Only when that file is absent AND a copy
# exists next to this module is the module-local copy used, so importing this
# module does not depend on the directory the process was started from.
# If neither file exists, the CWD-relative path is still the one passed to the
# loader, so the resulting OSError names "./index.so".
_lib_path = "./index.so"
if not os.path.exists(_lib_path) and "__file__" in globals():
    _module_local_lib = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "index.so")
    if os.path.exists(_module_local_lib):
        _lib_path = _module_local_lib
lib = ctypes.cdll.LoadLibrary(_lib_path)

class IndexContext(ctypes.Structure):
    """
    C struct handed to the native library by pointer.

    It carries two untyped pointers, ``graph`` and ``data``; their contents
    are owned and interpreted by the native side only.
    """
    _fields_ = [
        ("graph", ctypes.c_void_p),
        ("data", ctypes.c_void_p),
    ]

# Argument types shared by the native entry points declared below.
_float32_array = ctl.ndpointer(np.float32, flags='aligned, c_contiguous')
_uint64_array = ctl.ndpointer(np.uint64, flags='aligned, c_contiguous')
_float64_array = ctl.ndpointer(np.float64, flags='aligned, c_contiguous')
_context_pointer = POINTER(IndexContext)


def _declare(symbol, argtypes):
    """Fetch `symbol` from `lib`, mark it as returning nothing, and set its argtypes."""
    native = getattr(lib, symbol)
    native.restype = None
    native.argtypes = argtypes
    return native


# for mobius IP index
build_mobius_index = _declare("build_mobius_index", [
    _float32_array,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_double,
    ctypes.c_char_p,
])

search_mobius_index = _declare("search_mobius_index", [
    _float32_array,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    _context_pointer,
    _uint64_array,
    _float64_array,
])

load_mobius_index_prefix = _declare("load_mobius_index_prefix", [
    ctypes.c_int,
    ctypes.c_int,
    _context_pointer,
    ctypes.c_char_p,
])

save_mobius_index_prefix = _declare("save_mobius_index_prefix", [
    _context_pointer,
    ctypes.c_char_p,
])


# for L2 index
build_l2_index = _declare("build_l2_index", [
    _float32_array,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_char_p,
])

search_l2_index = _declare("search_l2_index", [
    _float32_array,
    ctypes.c_int,
    ctypes.c_int,
    ctypes.c_int,
    _context_pointer,
    _uint64_array,
    _float64_array,
])

load_l2_index_prefix = _declare("load_l2_index_prefix", [
    ctypes.c_int,
    ctypes.c_int,
    _context_pointer,
    ctypes.c_char_p,
])

save_l2_index_prefix = _declare("save_l2_index_prefix", [
    _context_pointer,
    ctypes.c_char_p,
])

# shared by both index types
release_context = _declare("release_context", [_context_pointer])



class Graph_Index(object):
    """
    Graph index over a gallery of vectors, served by the native index library.

    Two distance types are supported: "IP" (routed to the ``*_mobius_*``
    native functions) and "L2" (routed to the ``*_l2_*`` native functions).
    Optional per-vector documents are kept in a Python dict and persisted
    next to the native index files as ``info.json``.

    NOTE(review): nothing in this class calls ``release_context``. If the
    native loaders allocate memory behind ``index_context``, repeated
    build()/load() calls would leak it -- confirm against the C source.
    """

    def __init__(self, dist_type="IP"):
        """
        Create an empty index handle.

        Args:
            dist_type: "IP" or "L2".

        Raises:
            AssertionError: if `dist_type` is neither "IP" nor "L2".
        """
        assert dist_type in ["IP", "L2"], "Only support IP and L2 distance ..."
        self.dim = 0
        self.total_num = 0
        self.dist_type = dist_type
        # Forwarded to build_mobius_index as its c_double argument.
        self.mobius_pow = 2.0
        # Both pointers start out NULL; the native load functions fill them in.
        self.index_context = IndexContext(0, 0)
        # Maps str(vector id) -> document, plus the metadata keys
        # "total_num", "dim", "dist_type" and "with_attr".
        self.gallery_doc_dict = {}
        self.with_attr = False

    @staticmethod
    def _index_prefix(index_path):
        """Return ``index_path + "/index"`` as a UTF-8 C string buffer for the native calls."""
        return create_string_buffer((index_path + "/index").encode('utf-8'))

    @staticmethod
    def _as_float32(array):
        """
        Return `array` as an aligned, C-contiguous float32 ndarray.

        Paddle tensors are converted with ``.numpy()`` first. The native
        functions are declared with float32 'aligned, c_contiguous' ndpointer
        arguments, so any other dtype or layout would be rejected by ctypes;
        arrays that already satisfy the requirements are returned without a copy.
        """
        if paddle.is_tensor(array):
            array = array.numpy()
        return np.require(array, dtype=np.float32, requirements=['C', 'A'])

    def build(self, gallery_vectors, gallery_docs=None, pq_size=100, index_path='graph_index/'):
        """
        build index

        Builds the native index files under `index_path`, loads them into
        this object's context, and writes ``info.json`` beside them.

        Args:
            gallery_vectors: 2D array-like or paddle tensor of shape
                (num, dim); converted to float32 before the native call.
            gallery_docs: optional sequence holding one JSON-serializable
                document per vector. When given, search() returns documents
                instead of ids. ``None`` or an empty sequence means no docs.
            pq_size: forwarded unchanged to the native build function.
            index_path: directory for the index files; created if missing.

        Raises:
            AssertionError: if the vectors are not 2D, or if `gallery_docs`
                is non-empty and its length differs from the vector count.
        """
        # A None default avoids sharing one mutable list across calls.
        if gallery_docs is None:
            gallery_docs = []

        gallery_vectors = self._as_float32(gallery_vectors)
        assert gallery_vectors.ndim == 2, "Input vector must be 2D ..."

        self.total_num = gallery_vectors.shape[0]
        self.dim = gallery_vectors.shape[1]

        has_docs = len(gallery_docs) > 0
        assert (not has_docs) or len(gallery_docs) == self.total_num, \
            "gallery_docs must contain exactly one entry per gallery vector ..."

        print("training index -> num: {}, dim: {}, dist_type: {}".format(self.total_num, self.dim, self.dist_type))

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(index_path, exist_ok=True)

        # The native build writes the index to disk; it is then loaded back
        # into index_context so this object can be searched immediately.
        if self.dist_type == "IP":
            build_mobius_index(gallery_vectors, self.total_num, self.dim, pq_size,
                               self.mobius_pow, self._index_prefix(index_path))
            load_mobius_index_prefix(self.total_num, self.dim,
                                     ctypes.byref(self.index_context),
                                     self._index_prefix(index_path))
        else:
            build_l2_index(gallery_vectors, self.total_num, self.dim, pq_size,
                           self._index_prefix(index_path))
            load_l2_index_prefix(self.total_num, self.dim,
                                 ctypes.byref(self.index_context),
                                 self._index_prefix(index_path))

        # Reset with_attr on every build: a rebuild without docs must not
        # keep a stale True from an earlier build, or search() would look up
        # documents that no longer exist.
        self.with_attr = has_docs
        self.gallery_doc_dict = {}
        if has_docs:
            for i in range(self.total_num):
                self.gallery_doc_dict[str(i)] = gallery_docs[i]

        self.gallery_doc_dict["total_num"] = self.total_num
        self.gallery_doc_dict["dim"] = self.dim
        self.gallery_doc_dict["dist_type"] = self.dist_type
        self.gallery_doc_dict["with_attr"] = self.with_attr

        with open(index_path + "/info.json", "w", encoding="utf-8") as f:
            json.dump(self.gallery_doc_dict, f)

        print("finished creating index ...")

    def search(self, query, return_k=10, search_budget=100):
        """
        search

        Args:
            query: query vector (array-like or paddle tensor); converted to
                float32 before the native call.
            return_k: number of result slots allocated and requested.
            search_budget: forwarded unchanged to the native search function.

        Returns:
            ``(ret_score, ret_doc)`` when the index was built with documents,
            otherwise ``(ret_score, ret_id)``. ``ret_score`` is a float64
            ndarray of length `return_k`; ``ret_id`` / ``ret_doc`` are lists
            of the same length.
        """
        # Output buffers filled in place by the native search.
        ret_id = np.zeros(return_k, dtype=np.uint64)
        ret_score = np.zeros(return_k, dtype=np.float64)

        # NOTE(review): the native call receives self.dim alongside the query
        # pointer; presumably it reads that many floats, so a shorter query
        # would be read out of bounds -- confirm and validate if so.
        query = self._as_float32(query)

        if self.dist_type == "IP":
            search_mobius_index(query, self.dim, search_budget, return_k,
                                ctypes.byref(self.index_context), ret_id, ret_score)
        else:
            search_l2_index(query, self.dim, search_budget, return_k,
                            ctypes.byref(self.index_context), ret_id, ret_score)

        ret_id = ret_id.tolist()
        if not self.with_attr:
            return ret_score, ret_id

        # Documents are keyed by the string form of the vector id.
        ret_doc = [self.gallery_doc_dict[str(doc_id)] for doc_id in ret_id]
        return ret_score, ret_doc

    def dump(self, index_path):
        """
        Save the loaded native index and ``info.json`` under `index_path`.

        Args:
            index_path: target directory; created if missing.
        """
        os.makedirs(index_path, exist_ok=True)

        if self.dist_type == "IP":
            save_mobius_index_prefix(ctypes.byref(self.index_context),
                                     self._index_prefix(index_path))
        else:
            save_l2_index_prefix(ctypes.byref(self.index_context),
                                 self._index_prefix(index_path))

        with open(index_path + "/info.json", "w", encoding="utf-8") as f:
            json.dump(self.gallery_doc_dict, f)

    def load(self, index_path):
        """
        Load an index previously written by build() or dump().

        Reads ``info.json`` to restore ``total_num``, ``dim``, ``dist_type``
        and ``with_attr`` (overriding the constructor's `dist_type`), then
        loads the native index files into this object's context.

        Args:
            index_path: directory containing ``info.json`` and the index files.
        """
        with open(index_path + "/info.json", "r", encoding="utf-8") as f:
            self.gallery_doc_dict = json.load(f)

        self.total_num = self.gallery_doc_dict["total_num"]
        self.dim = self.gallery_doc_dict["dim"]
        self.dist_type = self.gallery_doc_dict["dist_type"]
        self.with_attr = self.gallery_doc_dict["with_attr"]

        if self.dist_type == "IP":
            load_mobius_index_prefix(self.total_num, self.dim,
                                     ctypes.byref(self.index_context),
                                     self._index_prefix(index_path))
        else:
            load_l2_index_prefix(self.total_num, self.dim,
                                 ctypes.byref(self.index_context),
                                 self._index_prefix(index_path))