vec_index.cpp 7.1 KB
Newer Older
X
MS-154  
xj.lin 已提交
1 2 3 4 5 6
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include "knowhere/index/vector_index/ivf.h"
X
xj.lin 已提交
7
#include "knowhere/index/vector_index/idmap.h"
X
MS-154  
xj.lin 已提交
8 9
#include "knowhere/index/vector_index/gpu_ivf.h"
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
X
xj.lin 已提交
10
#include "knowhere/index/vector_index/nsg_index.h"
X
xj.lin 已提交
11
#include "knowhere/common/exception.h"
X
MS-154  
xj.lin 已提交
12 13 14

#include "vec_index.h"
#include "vec_impl.h"
X
xj.lin 已提交
15
#include "wrapper_log.h"
X
MS-154  
xj.lin 已提交
16 17 18


namespace zilliz {
X
xj.lin 已提交
19
namespace milvus {
X
MS-154  
xj.lin 已提交
20 21
namespace engine {

X
xj.lin 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
// Thin wrapper around a binary output file stream.
// `name` keeps the path that was handed to the constructor and `fs` is the
// stream opened on that path; both stay public because callers in this file
// inspect `fs` directly.
struct FileIOWriter {
    std::fstream fs;
    std::string name;

    // `explicit` so that a path string is never silently converted into a writer.
    explicit FileIOWriter(const std::string &fname);
    ~FileIOWriter();

    // Writes `size` bytes starting at `ptr` to the stream.
    size_t operator()(void *ptr, size_t size);
};

// Thin wrapper around a binary input file stream.
// `name` keeps the path that was handed to the constructor and `fs` is the
// stream opened on that path; both stay public because callers in this file
// seek and query `fs` directly.
struct FileIOReader {
    std::fstream fs;
    std::string name;

    // `explicit` so that a path string is never silently converted into a reader.
    explicit FileIOReader(const std::string &fname);
    ~FileIOReader();

    // Reads `size` bytes from the current stream position into `ptr`.
    size_t operator()(void *ptr, size_t size);
    // Overload that additionally takes a position argument.
    size_t operator()(void *ptr, size_t size, size_t pos);
};

// Remembers the path and opens it for binary reading.
// Members are initialised in declaration order (fs, then name), so the stream
// is opened from the `fname` argument rather than from `name`.
FileIOReader::FileIOReader(const std::string &fname)
    : fs(fname, std::ios::in | std::ios::binary),
      name(fname) {
}

// Closes the input stream when the reader is destroyed.
FileIOReader::~FileIOReader() {
    this->fs.close();
}

// Reads up to `size` bytes from the current stream position into `ptr`.
//
// Returns the number of bytes actually extracted by this call, which is less
// than `size` when the stream hits end-of-file or fails. The function used to
// fall off its end without returning a value, which is undefined behaviour
// for a function declared to return size_t.
size_t FileIOReader::operator()(void *ptr, size_t size) {
    fs.read(static_cast<char *>(ptr), static_cast<std::streamsize>(size));
    return static_cast<size_t>(fs.gcount());
}

// Positional overload. It is currently a stub: none of the arguments are
// used, nothing is read, and the result is always 0.
size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
    static_cast<void>(ptr);
    static_cast<void>(size);
    static_cast<void>(pos);

    const size_t stub_result = 0;
    return stub_result;
}

// Remembers the path and opens it for binary writing.
// Members are initialised in declaration order (fs, then name), so the stream
// is opened from the `fname` argument rather than from `name`.
FileIOWriter::FileIOWriter(const std::string &fname)
    : fs(fname, std::ios::out | std::ios::binary),
      name(fname) {
}

// Closes the output stream when the writer is destroyed.
FileIOWriter::~FileIOWriter() {
    this->fs.close();
}

// Writes `size` bytes starting at `ptr` to the stream.
//
// Returns `size` when the stream is still in a good state after the write and
// 0 otherwise. The function used to fall off its end without returning a
// value, which is undefined behaviour for a function declared to return size_t.
size_t FileIOWriter::operator()(void *ptr, size_t size) {
    fs.write(static_cast<const char *>(ptr), static_cast<std::streamsize>(size));
    return fs.good() ? size : 0;
}


X
xj.lin 已提交
72
// Builds an empty index wrapper for the requested index type.
//
// Each supported type creates a knowhere backend and wraps it:
//   * FAISS_IDMAP              -> BFIndex
//   * FAISS_IVFFLAT_MIX,
//     FAISS_IVFSQ8_MIX         -> IVFMixIndex (tagged with that type)
//   * every other known type   -> VecIndexImpl (tagged with `type`)
// Unknown types yield nullptr, so callers must check the result.
// The GPU/NSG backends are constructed with the literal argument 0
// (presumably a device id -- confirm against the knowhere headers).
VecIndexPtr GetVecIndexFactory(const IndexType &type) {
    namespace kw = zilliz::knowhere;

    std::shared_ptr<kw::VectorIndex> backend;
    switch (type) {
        case IndexType::FAISS_IDMAP:
            backend = std::make_shared<kw::IDMAP>();
            return std::make_shared<BFIndex>(backend);

        case IndexType::FAISS_IVFFLAT_CPU:
            backend = std::make_shared<kw::IVF>();
            break;

        case IndexType::FAISS_IVFFLAT_GPU:
            backend = std::make_shared<kw::GPUIVF>(0);
            break;

        case IndexType::FAISS_IVFFLAT_MIX:
            backend = std::make_shared<kw::GPUIVF>(0);
            return std::make_shared<IVFMixIndex>(backend, IndexType::FAISS_IVFFLAT_MIX);

        case IndexType::FAISS_IVFPQ_CPU:
            backend = std::make_shared<kw::IVFPQ>();
            break;

        case IndexType::FAISS_IVFPQ_GPU:
            backend = std::make_shared<kw::GPUIVFPQ>(0);
            break;

        case IndexType::SPTAG_KDT_RNT_CPU:
            backend = std::make_shared<kw::CPUKDTRNG>();
            break;

        case IndexType::FAISS_IVFSQ8_MIX:
            backend = std::make_shared<kw::GPUIVFSQ>(0);
            return std::make_shared<IVFMixIndex>(backend, IndexType::FAISS_IVFSQ8_MIX);

        case IndexType::NSG_MIX:  // TODO(linxj): bug.
            backend = std::make_shared<kw::NSG>(0);
            break;

        default:
            return nullptr;
    }

    // Types that did not return early share the generic wrapper.
    return std::make_shared<VecIndexImpl>(backend, type);
}

X
xj.lin 已提交
118
// Creates an index of `index_type` and populates it from `index_binary`.
//
// Returns the loaded index, or nullptr when the factory does not recognise
// `index_type`. The factory result used to be dereferenced unconditionally,
// which crashed for unknown types such as IndexType::INVALID.
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
    auto index = GetVecIndexFactory(index_type);
    if (index == nullptr) {
        return nullptr;
    }
    index->Load(index_binary);
    return index;
}

X
xj.lin 已提交
124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162
// Reads a serialized index from the file at `location`.
//
// File layout as consumed here: an IndexType value, followed by zero or more
// records of the form
//     [size_t meta_length][meta bytes][size_t bin_length][bin bytes]
// Every record is appended to a BinarySet under the name held in `meta`, and
// the set is then handed to LoadVecIndex.
//
// Returns nullptr when the file cannot be opened, is truncated, or contains a
// length field that points past the end of the file. Otherwise returns
// whatever LoadVecIndex returns.
VecIndexPtr read_index(const std::string &location) {
    knowhere::BinarySet load_data_list;
    FileIOReader reader(location);
    if (!reader.fs.is_open()) {
        return nullptr;
    }

    // Determine the total file size so every length field can be validated.
    reader.fs.seekg(0, reader.fs.end);
    const std::streamoff end_pos = reader.fs.tellg();
    if (end_pos < 0) {
        return nullptr;
    }
    const size_t length = static_cast<size_t>(end_pos);
    reader.fs.seekg(0);

    size_t rp = 0;  // number of bytes consumed so far
    auto current_type = IndexType::INVALID;
    if (length < sizeof(current_type)) {
        return nullptr;
    }
    reader(&current_type, sizeof(current_type));
    if (!reader.fs) {
        return nullptr;
    }
    rp += sizeof(current_type);

    while (rp < length) {
        // --- record name ---
        size_t meta_length = 0;
        if (length - rp < sizeof(meta_length)) {
            return nullptr;
        }
        reader(&meta_length, sizeof(meta_length));
        rp += sizeof(meta_length);
        if (!reader.fs || meta_length > length - rp) {
            return nullptr;
        }

        // std::string owns the buffer, so nothing leaks on any exit path.
        std::string meta(meta_length, '\0');
        if (meta_length > 0) {
            reader(&meta[0], meta_length);
        }
        rp += meta_length;

        // --- record payload ---
        size_t bin_length = 0;
        if (!reader.fs || length - rp < sizeof(bin_length)) {
            return nullptr;
        }
        reader(&bin_length, sizeof(bin_length));
        rp += sizeof(bin_length);
        if (!reader.fs || bin_length > length - rp) {
            return nullptr;
        }

        // The buffer comes from new[], so it must be released with delete[];
        // a shared_ptr<uint8_t> with the default deleter would call scalar delete.
        std::shared_ptr<uint8_t> binptr(new uint8_t[bin_length], std::default_delete<uint8_t[]>());
        reader(binptr.get(), bin_length);
        if (!reader.fs) {
            return nullptr;
        }
        rp += bin_length;

        load_data_list.Append(meta, binptr, bin_length);
    }

    return LoadVecIndex(current_type, load_data_list);
}

X
xj.lin 已提交
163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183
// Serializes `index` and writes it to the file at `location`.
//
// File layout produced here: the IndexType value, followed by one record per
// entry of the serialized binary set:
//     [size_t meta_length][meta bytes][int64_t binary_length][binary bytes]
//
// Returns:
//   KNOWHERE_SUCCESS           everything was written;
//   KNOWHERE_UNEXPECTED_ERROR  a knowhere::KnowhereException was thrown;
//   KNOWHERE_ERROR             `index` is null, the file could not be opened
//                              or written, or another std::exception was thrown.
server::KnowhereError write_index(VecIndexPtr index, const std::string &location) {
    if (index == nullptr) {
        return server::KNOWHERE_ERROR;
    }

    try {
        auto binaryset = index->Serialize();
        auto index_type = index->GetType();

        FileIOWriter writer(location);
        if (!writer.fs.is_open()) {
            // Previously an unopenable path was silently reported as success.
            return server::KNOWHERE_ERROR;
        }

        writer(&index_type, sizeof(IndexType));
        for (const auto &iter : binaryset.binary_map_) {
            const std::string &meta = iter.first;
            size_t meta_length = meta.length();
            writer(&meta_length, sizeof(meta_length));
            // The writer takes a non-const pointer but only reads from it.
            writer(const_cast<char *>(meta.data()), meta_length);

            const auto &binary = iter.second;
            // Kept as int64_t so the on-disk width of the field does not change.
            int64_t binary_length = binary->size;
            writer(&binary_length, sizeof(binary_length));
            const void *payload = binary->data.get();
            writer(const_cast<void *>(payload), static_cast<size_t>(binary_length));
        }

        // Surface I/O failures (e.g. a full disk) instead of claiming success.
        writer.fs.flush();
        if (!writer.fs) {
            return server::KNOWHERE_ERROR;
        }
    } catch (const knowhere::KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return server::KNOWHERE_UNEXPECTED_ERROR;
    } catch (const std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return server::KNOWHERE_ERROR;
    }
    return server::KNOWHERE_SUCCESS;
}

X
xj.lin 已提交
191 192 193 194 195

// TODO(linxj): redo here.
void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) {
    if (!cfg.contains("nlist")) { cfg["nlist"] = int(size / 1000000.0 * 16384); }
    if (!cfg.contains("gpu_id")) { cfg["gpu_id"] = int(0); }
X
xj.lin 已提交
196
    if (!cfg.contains("metric_type")) { cfg["metric_type"] = "L2"; }
X
xj.lin 已提交
197 198 199 200 201 202

    switch (type) {
        case IndexType::FAISS_IVFSQ8_MIX: {
            if (!cfg.contains("nbits")) { cfg["nbits"] = int(8); }
            break;
        }
X
xj.lin 已提交
203 204 205
        case IndexType::NSG_MIX: {
            auto scale_factor = round(cfg["dim"].as<int>() / 128.0);
            scale_factor = scale_factor >= 4 ? 4 : scale_factor;
X
xj.lin 已提交
206 207
            cfg["nlist"] = int(size / 1000000.0 * 8192);
            if (!cfg.contains("nprobe")) { cfg["nprobe"] = 6 + 10 * scale_factor; }
X
xj.lin 已提交
208
            if (!cfg.contains("knng")) { cfg["knng"] = 100 + 100 * scale_factor; }
X
xj.lin 已提交
209 210
            if (!cfg.contains("search_length")) { cfg["search_length"] = 40 + 5 * scale_factor; }
            if (!cfg.contains("out_degree")) { cfg["out_degree"] = 50 + 5 * scale_factor; }
X
xj.lin 已提交
211
            if (!cfg.contains("candidate_pool_size")) { cfg["candidate_pool_size"] = 200 + 100 * scale_factor; }
X
xj.lin 已提交
212
            WRAPPER_LOG_DEBUG << pretty_print(cfg);
X
xj.lin 已提交
213 214
            break;
        }
X
xj.lin 已提交
215 216 217
    }
}

X
MS-154  
xj.lin 已提交
218 219 220
}
}
}