# vector_engine.py (blame view: lines 1-8 committed by jinhai)
from engine.model.group_table import GroupTable
from engine.model.file_table import FileTable
from engine.controller.raw_file_handler import RawFileHandler
from engine.controller.group_handler import GroupHandler
from engine.controller.index_file_handler import IndexFileHandler
from engine.settings import ROW_LIMIT
from flask import jsonify
from engine import db
# (blame view: lines 9-10 committed by xj.lin)
from engine.ingestion import build_index
from engine.controller.scheduler import Scheduler
# (blame view: line 11 committed by xj.lin)
from engine.ingestion import serialize
# (blame view: lines 12-14 committed by jinhai)
import sys, os

class VectorEngine(object):
    """Facade over the group/file metadata tables and an in-process raw-vector cache.

    Every public method is a ``@staticmethod`` whose first return element is a
    status code: ``SUCCESS_CODE``, ``FAULT_CODE`` or ``GROUP_NOT_EXIST``.

    Vectors that have not yet been built into an index are kept in memory, keyed
    by group name, in ``group_vector_dict`` (the vectors) and
    ``group_vector_id_dict`` (their integer ids, as a parallel list). The cache
    lives only as long as the process does.
    """

    # group name -> list of vectors not yet built into an index (created lazily).
    group_vector_dict = None
    # group name -> list of integer vector ids, parallel to group_vector_dict[group].
    group_vector_id_dict = None

    SUCCESS_CODE = 0
    FAULT_CODE = 1
    GROUP_NOT_EXIST = 2

    @staticmethod
    def AddGroup(group_id, dimension):
        """Create group ``group_id`` holding vectors with ``dimension`` components.

        Returns:
            ``(FAULT_CODE, group_id, file_number)`` if the group already exists,
            otherwise ``(SUCCESS_CODE, group_id, 0)`` after creating the group's
            directory and database row.
        """
        group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
        if group:
            print('Already create the group: ', group_id)
            return VectorEngine.FAULT_CODE, group_id, group.file_number

        print('To create the group: ', group_id)
        new_group = GroupTable(group_id, dimension)
        GroupHandler.CreateGroupDirectory(group_id)

        # add into database
        db.session.add(new_group)
        db.session.commit()
        return VectorEngine.SUCCESS_CODE, group_id, 0

    @staticmethod
    def GetGroup(group_id):
        """Look up group ``group_id``.

        Returns:
            ``(SUCCESS_CODE, group_id, file_number)`` if it exists, otherwise
            ``(FAULT_CODE, group_id, 0)``.
        """
        group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
        if group:
            return VectorEngine.SUCCESS_CODE, group_id, group.file_number
        return VectorEngine.FAULT_CODE, group_id, 0

    @staticmethod
    def DeleteGroup(group_id):
        """Delete group ``group_id``: its row, its file records, its directory and its cached vectors.

        Returns:
            ``(SUCCESS_CODE, group_id, file_number)``. ``file_number`` is 0 when the
            group did not exist; deleting a missing group is reported as success.
        """
        group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
        if not group:
            return VectorEngine.SUCCESS_CODE, group_id, 0

        # Read this before the delete/commit so we never touch attributes of a
        # deleted (and possibly expired) instance.
        file_number = group.file_number

        records = FileTable.query.filter(FileTable.group_name == group_id).all()
        for record in records:
            print("record.group_name: ", record.group_name)
            db.session.delete(record)
        db.session.delete(group)
        # A single commit removes the group row and its file records together.
        db.session.commit()
        GroupHandler.DeleteGroupDirectory(group_id)

        # Drop the in-memory raw vectors as well. Otherwise re-creating a group
        # with the same name would bring stale vectors back into search results.
        VectorEngine.ClearRawFile(group_id)
        return VectorEngine.SUCCESS_CODE, group_id, file_number

    @staticmethod
    def GetGroupList():
        """Return ``(SUCCESS_CODE, [{'group_name': ..., 'file_number': ...}, ...])`` for all groups."""
        group_list = [
            {'group_name': group.group_name, 'file_number': group.file_number}
            for group in GroupTable.query.all()
        ]
        print(group_list)
        return VectorEngine.SUCCESS_CODE, group_list

    @staticmethod
    def AddVector(group_id, vectors):
        """Append ``vectors`` to group ``group_id`` and return their string ids.

        Each vector is cached in memory and counted against the group's current
        'raw' file record. When that record reaches ``ROW_LIMIT`` rows, the cached
        vectors are built into an index and written to disk. The record then
        becomes an 'index' record, and the group's cache is cleared.

        Returns:
            ``(SUCCESS_CODE, ['<group_id>.<vector_id>', ...])`` on success, or
            ``(GROUP_NOT_EXIST, 'invalid')`` if the group does not exist.
        """
        print(group_id, vectors)
        code, _, _ = VectorEngine.GetGroup(group_id)
        if code == VectorEngine.FAULT_CODE:
            return VectorEngine.GROUP_NOT_EXIST, 'invalid'

        # The group row does not change while inserting, so fetch it only once.
        group = GroupTable.query.filter(GroupTable.group_name == group_id).first()

        vector_str_list = []
        for vector in vectors:
            # Re-query on every pass: the previous pass may have created the raw
            # record or turned it into an index record.
            file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first()

            if file:
                print('insert into exist file')
                # create vector id
                vector_id = file.seq_no + 1
                # insert into raw file
                VectorEngine.InsertVectorIntoRawFile(group_id, file.filename, vector, vector_id)

                # check if the file can be indexed
                if file.row_number + 1 >= ROW_LIMIT:
                    VectorEngine._BuildIndexFromRawFile(group_id, group.dimension, file)
                else:
                    # we still can insert into exist raw file, update database
                    FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({
                        'row_number': file.row_number + 1,
                        'seq_no': file.seq_no + 1,
                    })
                    db.session.commit()
                    print('Update db for raw file insertion')
            else:
                print('add a new raw file')
                # Continue the id sequence instead of restarting at 0. Otherwise
                # ids would collide with vectors already built into an index.
                vector_id = VectorEngine._NextVectorId(group_id)
                # The first raw file keeps its historical name. Later ones are made
                # unique so their '<name>_index' files do not overwrite earlier indexes.
                if vector_id == 0:
                    raw_filename = group_id + '.raw'
                else:
                    raw_filename = group_id + '.' + str(vector_id) + '.raw'
                # create and insert vector into raw file
                VectorEngine.InsertVectorIntoRawFile(group_id, raw_filename, vector, vector_id)
                # insert a record into database; seq_no tracks the last id handed out
                record = FileTable(group_id, raw_filename, 'raw', 1)
                record.seq_no = vector_id
                db.session.add(record)
                db.session.commit()

            vector_str_list.append(group_id + '.' + str(vector_id))

        return VectorEngine.SUCCESS_CODE, vector_str_list

    @staticmethod
    def _NextVectorId(group_id):
        """Return the id for the first vector of a new raw file in ``group_id``.

        Returns 0 when the group has no file records. Otherwise returns one past the
        largest ``seq_no`` recorded, which the raw/index records set to the last id
        they handed out.
        """
        records = FileTable.query.filter(FileTable.group_name == group_id).all()
        if not records:
            return 0
        return max(record.seq_no for record in records) + 1

    @staticmethod
    def _BuildIndexFromRawFile(group_id, dimension, file):
        """Build an index from ``group_id``'s cached raw vectors and persist it.

        Writes the serialized index to ``file.filename + '_index'``. Turns the
        group's 'raw' record into an 'index' record pointing at that file. Then
        clears the group's raw-vector cache, since those vectors now live in the index.
        """
        raw_vector_array, raw_vector_id_array = VectorEngine.GetVectorListFromRawFile(group_id)

        # create index
        index_builder = build_index.FactoryIndex()
        index = index_builder().build(dimension, raw_vector_array, raw_vector_id_array)

        # TODO(jinhai): store index into Cache
        index_filename = file.filename + '_index'
        serialize.write_index(file_name=index_filename, index=index)

        FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').update({
            'row_number': file.row_number + 1,
            'type': 'index',
            'filename': index_filename,
            'seq_no': file.seq_no + 1,
        })
        db.session.commit()
        # The original code assigned to a misspelled attribute (group_dict), so the
        # cache was never cleared. Clear only this group's cache.
        VectorEngine.ClearRawFile(group_id)

    @staticmethod
    def SearchVector(group_id, vector, limit):
        """Search group ``group_id`` for vectors near ``vector``.

        Passes the group's index filenames, its cached raw vectors/ids and its
        dimension to ``Scheduler().search`` together with ``limit``. ``limit`` is
        presumably the number of results wanted (verify against Scheduler).

        Returns:
            ``(SUCCESS_CODE, ['<group_id>.<id>', ...])`` for the ids the scheduler
            returns, or ``(GROUP_NOT_EXIST, {})`` if the group does not exist.
        """
        # Check the group exist
        code, _, _ = VectorEngine.GetGroup(group_id)
        if code == VectorEngine.FAULT_CODE:
            return VectorEngine.GROUP_NOT_EXIST, {}

        group = GroupTable.query.filter(GroupTable.group_name == group_id).first()
        # find all files
        files = FileTable.query.filter(FileTable.group_name == group_id).all()

        index_map = {}
        index_map['index'] = [f.filename for f in files if f.type == 'index']
        index_map['raw'], index_map['raw_id'] = VectorEngine.GetVectorListFromRawFile(group_id, "fakename")  # TODO: pass by key, get from storage
        index_map['dimension'] = group.dimension

        result = Scheduler().search(index_map, [vector], limit)

        vector_ids_str = [group_id + '.' + str(int_id) for int_id in result]
        return VectorEngine.SUCCESS_CODE, vector_ids_str

    @staticmethod
    def CreateIndex(group_id):
        """Prepare index creation for the group's current raw file (currently only logs the path).

        Returns:
            ``GROUP_NOT_EXIST`` if the group is missing, ``FAULT_CODE`` if the group
            has no raw file to index, otherwise ``SUCCESS_CODE``.
        """
        # Check the group exist
        code, _, _ = VectorEngine.GetGroup(group_id)
        if code == VectorEngine.FAULT_CODE:
            return VectorEngine.GROUP_NOT_EXIST

        file = FileTable.query.filter(FileTable.group_name == group_id).filter(FileTable.type == 'raw').first()
        if file is None:
            # No pending raw file (nothing inserted yet, or it was already indexed).
            return VectorEngine.FAULT_CODE

        path = os.path.join(GroupHandler.GetGroupDirectory(group_id), file.filename)
        print('Going to create index for: ', path)
        return VectorEngine.SUCCESS_CODE

    @staticmethod
    def InsertVectorIntoRawFile(group_id, filename, vector, vector_id):
        """Append ``vector`` and ``vector_id`` to the in-memory cache for ``group_id``.

        ``filename`` is not used for storage; it is returned unchanged.
        """
        if VectorEngine.group_vector_dict is None:
            VectorEngine.group_vector_dict = dict()
        if VectorEngine.group_vector_id_dict is None:
            VectorEngine.group_vector_id_dict = dict()

        VectorEngine.group_vector_dict.setdefault(group_id, []).append(vector)
        VectorEngine.group_vector_id_dict.setdefault(group_id, []).append(vector_id)

        print("cache size: ", len(VectorEngine.group_vector_dict[group_id]))
        return filename

    @staticmethod
    def GetVectorListFromRawFile(group_id, filename="todo"):
        """Return ``(vectors, ids)`` cached for ``group_id``, converted by ``serialize``.

        A group with nothing cached (never inserted in this process, or just
        indexed) yields conversions of empty lists instead of raising. ``filename``
        is currently unused.
        """
        vectors = (VectorEngine.group_vector_dict or {}).get(group_id, [])
        vector_ids = (VectorEngine.group_vector_id_dict or {}).get(group_id, [])
        return serialize.to_array(vectors), serialize.to_int_array(vector_ids)

    @staticmethod
    def ClearRawFile(group_id):
        """Drop ``group_id``'s cached vectors and ids; a no-op if none are cached.

        Returns:
            ``SUCCESS_CODE``.
        """
        # Log only the group: dumping the whole cache is costly on the insert path.
        print("ClearRawFile: ", group_id)
        if VectorEngine.group_vector_dict is not None:
            VectorEngine.group_vector_dict.pop(group_id, None)
        if VectorEngine.group_vector_id_dict is not None:
            VectorEngine.group_vector_id_dict.pop(group_id, None)
        return VectorEngine.SUCCESS_CODE