build_index.py 1.3 KB
Newer Older
X
xj.lin 已提交
1 2 3 4 5
import faiss
from enum import Enum, unique


@unique
class INDEXDEVICES(Enum):
    """Device targets that can be passed as ``DEVICE`` to ``Index.build``.

    ``@unique`` guarantees that no two members share a value.
    """

    CPU = 0
    GPU = 1
    MULTI_GPU = 2


def FactoryIndex(index_name="DefaultIndex"):
    """Look up an index class by name and return it.

    The class object itself is returned, not an instance; the caller is
    expected to invoke ``__init__()``.

    Args:
        index_name: Name under which an ``Index`` class (``Index`` itself or
            one of its subclasses) is bound in this module's globals.

    Returns:
        The class bound to ``index_name``.

    Raises:
        KeyError: If ``index_name`` is not bound in this module, or is bound
            to something that is not an ``Index`` class.
    """
    candidate = globals().get(index_name)
    # A bare globals() lookup would also hand back modules, functions or the
    # device enum; only index classes are valid results of this factory.
    if not (isinstance(candidate, type) and issubclass(candidate, Index)):
        raise KeyError("unknown index name: %r" % (index_name,))
    return candidate  # invoke __init__() by user
X
xj.lin 已提交
15 16 17


class Index():
    """Base class for index builders; also hosts shared static helpers."""

    def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
        """Build an index; the base implementation does nothing.

        Subclasses override this method. Called on the base class it
        returns ``None``.

        Args:
            d: Passed by subclasses to the faiss index constructor
                (presumably the vector dimensionality -- confirm).
            vectors: Vectors to put into the index.
            DEVICE: An ``INDEXDEVICES`` member; defaults to ``CPU``.
        """

    @staticmethod
    def increase(trained_index, vectors):
        """Append ``vectors`` to ``trained_index`` through its ``add`` method.

        The index is modified in place; nothing is returned.
        """
        trained_index.add(vectors)

    @staticmethod
    def serialize(index):
        """Serialize ``index`` in memory and return the result as an array.

        The index is written with ``faiss.write_index`` into a
        ``faiss.VectorIOWriter`` and the writer's buffer is converted with
        ``faiss.vector_to_array``.
        """
        writer = faiss.VectorIOWriter()
        faiss.write_index(index, writer)
        return faiss.vector_to_array(writer.data)


class DefaultIndex(Index):
    """Index builder backed by ``faiss.IndexFlatL2``."""

    def __init__(self, *args, **kwargs):
        """Accept and ignore any arguments."""
        # maybe need to specify parameters
        pass

    def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
        """Create a ``faiss.IndexFlatL2(d)``, add ``vectors`` and return it.

        Args:
            d: Argument forwarded to ``faiss.IndexFlatL2``.
            vectors: Vectors added to the new index.
            DEVICE: Accepted for interface parity with ``Index.build``;
                currently ignored by this implementation.

        Returns:
            The populated index.
        """
        index = faiss.IndexFlatL2(d)  # trained
        index.add(vectors)
        return index


class LowMemoryIndex(Index):
    """Placeholder for a low-memory index builder; ``build`` is a stub."""

    def __init__(self, *args, **kwargs):
        """Set the private tuning parameters; all arguments are ignored."""
        # Double-underscore names are name-mangled, so they are reachable as
        # self.__nlist etc. only from inside this class.
        self.__nlist = 100
        self.__bytes_per_vector = 8
        self.__bits_per_sub_vector = 8

    def build(self, d, vectors, DEVICE=INDEXDEVICES.CPU):
        """Not implemented yet: performs no work and returns ``None``."""
        # TODO: implement. The disabled draft was:
        # quantizer = faiss.IndexFlatL2(d)
        # index = faiss.IndexIVFPQ(quantizer, d, self.nlist,
        #                          self.__bytes_per_vector, self.__bits_per_sub_vector)
        # return index
        # NOTE(review): the draft reads self.nlist, which is never assigned;
        # __init__ only sets self.__nlist. The draft also neither trains the
        # index nor adds ``vectors`` -- confirm the intended behavior.
        pass