# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import json
import os
import platform
import sys
from ctypes import *

import numpy as np
import numpy.ctypeslib as ctl
import paddle
from numpy.ctypeslib import ndpointer

# Directory containing this module; the native index library and the README
# referenced in the error message below are looked up next to it.
# NOTE: deliberately not named ``__dir__``. A module-level ``__dir__`` is the
# hook that ``dir(module)`` calls (PEP 562), so binding it to a string makes
# ``dir()`` on this module raise ``TypeError``.
_module_dir = os.path.dirname(os.path.abspath(__file__))

# The shared library file name depends on the host platform.
if platform.system() == "Windows":
    lib_filename = "index.dll"
else:
    lib_filename = "index.so"
so_path = os.path.join(_module_dir, lib_filename)

# Load the native library once at import time; every binding below hangs off
# ``lib``. If loading fails, print a hint and terminate the process.
try:
    lib = ctypes.cdll.LoadLibrary(so_path)
except Exception as ex:
    readme_path = os.path.join(_module_dir, "README.md")
    print(
        f"Error happened when load lib {so_path} with msg {ex},\nplease refer to {readme_path} to rebuild your library."
    )
    # ``sys.exit`` instead of the bare ``exit`` helper: ``exit`` is injected
    # by the ``site`` module for interactive use and is not guaranteed to
    # exist (e.g. under ``python -S`` or in frozen applications).
    sys.exit(-1)


class IndexContext(Structure):
    """ctypes structure passed by reference to the native index functions.

    It carries two opaque pointers, ``graph`` and ``data``. This module only
    creates the structure (initialised with null pointers) and hands it to the
    native load/search/save functions; it never dereferences the pointers
    itself.
    """

    _fields_ = [
        ("graph", c_void_p),
        ("data", c_void_p),
    ]


# for mobius IP index
#
# ctypes prototypes for the native functions used when ``dist_type == "IP"``.
# The per-argument comments name the value Graph_Index passes in that slot.

build_mobius_index = lib.build_mobius_index
build_mobius_index.restype = None
build_mobius_index.argtypes = [
    ctl.ndpointer(np.float32, flags='aligned, c_contiguous'),  # gallery vectors
    ctypes.c_int,  # total_num
    ctypes.c_int,  # dim
    ctypes.c_int,  # pq_size
    ctypes.c_double,  # mobius_pow
    ctypes.c_char_p,  # index file prefix
]

search_mobius_index = lib.search_mobius_index
search_mobius_index.restype = None
search_mobius_index.argtypes = [
    ctl.ndpointer(np.float32, flags='aligned, c_contiguous'),  # query
    ctypes.c_int,  # dim
    ctypes.c_int,  # search_budget
    ctypes.c_int,  # return_k
    POINTER(IndexContext),  # loaded index context
    ctl.ndpointer(np.uint64, flags='aligned, c_contiguous'),  # output: ids
    ctl.ndpointer(np.float64, flags='aligned, c_contiguous'),  # output: scores
]

load_mobius_index_prefix = lib.load_mobius_index_prefix
load_mobius_index_prefix.restype = None
load_mobius_index_prefix.argtypes = [
    ctypes.c_int,  # total_num
    ctypes.c_int,  # dim
    POINTER(IndexContext),  # context to fill in
    ctypes.c_char_p,  # index file prefix
]

save_mobius_index_prefix = lib.save_mobius_index_prefix
save_mobius_index_prefix.restype = None
save_mobius_index_prefix.argtypes = [
    POINTER(IndexContext),  # loaded index context
    ctypes.c_char_p,  # index file prefix
]

# for L2 index
#
# ctypes prototypes for the native functions used when ``dist_type == "L2"``.
# The per-argument comments name the value Graph_Index passes in that slot.

build_l2_index = lib.build_l2_index
build_l2_index.restype = None
build_l2_index.argtypes = [
    ctl.ndpointer(np.float32, flags='aligned, c_contiguous'),  # gallery vectors
    ctypes.c_int,  # total_num
    ctypes.c_int,  # dim
    ctypes.c_int,  # pq_size
    ctypes.c_char_p,  # index file prefix
]

search_l2_index = lib.search_l2_index
search_l2_index.restype = None
search_l2_index.argtypes = [
    ctl.ndpointer(np.float32, flags='aligned, c_contiguous'),  # query
    ctypes.c_int,  # dim
    ctypes.c_int,  # search_budget
    ctypes.c_int,  # return_k
    POINTER(IndexContext),  # loaded index context
    ctl.ndpointer(np.uint64, flags='aligned, c_contiguous'),  # output: ids
    ctl.ndpointer(np.float64, flags='aligned, c_contiguous'),  # output: scores
]

load_l2_index_prefix = lib.load_l2_index_prefix
load_l2_index_prefix.restype = None
load_l2_index_prefix.argtypes = [
    ctypes.c_int,  # total_num
    ctypes.c_int,  # dim
    POINTER(IndexContext),  # context to fill in
    ctypes.c_char_p,  # index file prefix
]

save_l2_index_prefix = lib.save_l2_index_prefix
save_l2_index_prefix.restype = None
save_l2_index_prefix.argtypes = [
    POINTER(IndexContext),  # loaded index context
    ctypes.c_char_p,  # index file prefix
]

# Bound here for completeness; nothing in this module calls it.
release_context = lib.release_context
release_context.restype = None
release_context.argtypes = [POINTER(IndexContext)]


class Graph_Index(object):
    """Python front end for the native graph index.

    An instance wraps one native ``IndexContext`` together with a JSON
    serialisable dictionary (``gallery_doc_dict``) that stores the index
    metadata (``total_num``, ``dim``, ``dist_type``, ``with_attr``) and,
    optionally, one document per gallery vector keyed by ``str(row_index)``.
    """

    def __init__(self, dist_type="IP"):
        """Create an empty index.

        Args:
            dist_type: ``"IP"`` or ``"L2"``; selects which family of native
                functions is used.

        Raises:
            AssertionError: if ``dist_type`` is neither ``"IP"`` nor ``"L2"``.
        """
        # Validate first so an unsupported type fails before any native
        # structure is created.
        assert dist_type in ["IP", "L2"], "Only support IP and L2 distance ..."
        self.dim = 0
        self.total_num = 0
        self.dist_type = dist_type
        self.mobius_pow = 2.0
        self.index_context = IndexContext(0, 0)
        self.gallery_doc_dict = {}
        self.with_attr = False

    @staticmethod
    def _index_prefix(index_path):
        """Return the native index file prefix under ``index_path`` as a C buffer."""
        return ctypes.create_string_buffer(
            os.path.join(index_path, "index").encode('utf-8'))

    @staticmethod
    def _as_native_array(vectors):
        """Convert ``vectors`` to the float32, C-contiguous, aligned layout the bindings accept."""
        if paddle.is_tensor(vectors):
            vectors = vectors.numpy()
        # The ndpointer argtypes reject any other dtype/layout with TypeError,
        # so coerce here instead of forcing every caller to do it.
        return np.require(vectors, dtype=np.float32, requirements=['C', 'A'])

    def _load_native_index(self, index_path):
        """Load the native index stored under ``index_path`` into ``self.index_context``."""
        # NOTE(review): the previous context is overwritten without calling
        # release_context; presumably it should be released first -- confirm
        # against the native library's ownership rules.
        if self.dist_type == "IP":
            loader = load_mobius_index_prefix
        else:
            loader = load_l2_index_prefix
        loader(self.total_num, self.dim,
               ctypes.byref(self.index_context),
               self._index_prefix(index_path))

    def _write_info(self, index_path):
        """Write ``gallery_doc_dict`` to ``info.json`` inside ``index_path``."""
        info_file = os.path.join(index_path, "info.json")
        with open(info_file, "w", encoding="utf-8") as f:
            json.dump(self.gallery_doc_dict, f)

    def build(self,
              gallery_vectors,
              gallery_docs=None,
              pq_size=100,
              index_path='graph_index/'):
        """Build the index from ``gallery_vectors`` and persist it to ``index_path``.

        Args:
            gallery_vectors: 2D collection of vectors (paddle tensor or
                anything numpy can convert); converted to float32.
            gallery_docs: optional sequence with one JSON-serialisable entry
                per gallery vector. When given, ``search`` returns these
                entries instead of raw ids.
            pq_size: forwarded unchanged to the native build function.
            index_path: directory that receives the native index files and
                ``info.json``; created if missing.

        Raises:
            AssertionError: if ``gallery_vectors`` is not 2D, or if
                ``gallery_docs`` is non-empty and its length differs from the
                number of vectors.
        """
        if gallery_docs is None:
            gallery_docs = []
        gallery_vectors = self._as_native_array(gallery_vectors)
        assert gallery_vectors.ndim == 2, "Input vector must be 2D ..."

        self.total_num, self.dim = gallery_vectors.shape

        has_docs = len(gallery_docs) > 0
        assert not has_docs or len(gallery_docs) == self.total_num

        print("training index -> num: {}, dim: {}, dist_type: {}".format(
            self.total_num, self.dim, self.dist_type))

        os.makedirs(index_path, exist_ok=True)

        # Build writes the index files; the freshly written index is then
        # loaded back so this instance can be searched right away.
        if self.dist_type == "IP":
            build_mobius_index(gallery_vectors, self.total_num, self.dim,
                               pq_size, self.mobius_pow,
                               self._index_prefix(index_path))
        else:
            build_l2_index(gallery_vectors, self.total_num, self.dim, pq_size,
                           self._index_prefix(index_path))
        self._load_native_index(index_path)

        # Reset both the documents and the flag: a rebuild without documents
        # must not keep ``with_attr`` from an earlier build, otherwise
        # ``search`` would look ids up in an empty mapping.
        self.with_attr = has_docs
        self.gallery_doc_dict = {}
        if has_docs:
            # Keys are strings so the mapping round-trips through JSON.
            for i in range(self.total_num):
                self.gallery_doc_dict[str(i)] = gallery_docs[i]

        self.gallery_doc_dict["total_num"] = self.total_num
        self.gallery_doc_dict["dim"] = self.dim
        self.gallery_doc_dict["dist_type"] = self.dist_type
        self.gallery_doc_dict["with_attr"] = self.with_attr

        self._write_info(index_path)

        print("finished creating index ...")

    def search(self, query, return_k=10, search_budget=100):
        """Search the loaded index for ``query``.

        Args:
            query: query vector (paddle tensor or anything numpy can
                convert); converted to float32.
            return_k: number of result slots allocated and requested.
            search_budget: forwarded unchanged to the native search function.

        Returns:
            ``(scores, docs)`` when the index was built with documents,
            otherwise ``(scores, ids)``. ``scores`` is a float64 numpy array
            of length ``return_k``; ``ids`` is a list of ints.
        """
        query = self._as_native_array(query)
        # Output buffers filled in place by the native call.
        ret_id = np.zeros(return_k, dtype=np.uint64)
        ret_score = np.zeros(return_k, dtype=np.float64)

        if self.dist_type == "IP":
            searcher = search_mobius_index
        else:
            searcher = search_l2_index
        searcher(query, self.dim, search_budget, return_k,
                 ctypes.byref(self.index_context), ret_id, ret_score)

        ret_id = ret_id.tolist()
        if self.with_attr:
            ret_doc = [self.gallery_doc_dict[str(doc_id)] for doc_id in ret_id]
            return ret_score, ret_doc
        return ret_score, ret_id

    def dump(self, index_path):
        """Save the native index and ``info.json`` into ``index_path`` (created if missing)."""
        os.makedirs(index_path, exist_ok=True)

        if self.dist_type == "IP":
            saver = save_mobius_index_prefix
        else:
            saver = save_l2_index_prefix
        saver(ctypes.byref(self.index_context), self._index_prefix(index_path))

        self._write_info(index_path)

    def load(self, index_path):
        """Restore metadata from ``info.json`` in ``index_path`` and load the native index.

        ``dist_type`` is taken from the stored metadata and replaces the value
        given to the constructor.
        """
        info_file = os.path.join(index_path, "info.json")
        with open(info_file, "r", encoding="utf-8") as f:
            self.gallery_doc_dict = json.load(f)

        self.total_num = self.gallery_doc_dict["total_num"]
        self.dim = self.gallery_doc_dict["dim"]
        self.dist_type = self.gallery_doc_dict["dist_type"]
        self.with_attr = self.gallery_doc_dict["with_attr"]

        self._load_native_index(index_path)