vec_index.cpp 9.6 KB
Newer Older
X
MS-154  
xj.lin 已提交
1 2 3 4 5 6
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include "knowhere/index/vector_index/ivf.h"
X
xj.lin 已提交
7
#include "knowhere/index/vector_index/idmap.h"
X
MS-154  
xj.lin 已提交
8 9
#include "knowhere/index/vector_index/gpu_ivf.h"
#include "knowhere/index/vector_index/cpu_kdt_rng.h"
X
xj.lin 已提交
10
#include "knowhere/index/vector_index/nsg_index.h"
X
xj.lin 已提交
11
#include "knowhere/common/exception.h"
X
MS-154  
xj.lin 已提交
12 13 14

#include "vec_index.h"
#include "vec_impl.h"
X
xj.lin 已提交
15
#include "wrapper_log.h"
X
MS-154  
xj.lin 已提交
16 17 18


namespace zilliz {
X
xj.lin 已提交
19
namespace milvus {
X
MS-154  
xj.lin 已提交
20 21
namespace engine {

22 23
// Reference dataset size (in rows) used by AutoGenParams to scale the default "nlist":
// IVF indexes get about 16384 lists per TYPICAL_COUNT rows.
static constexpr float TYPICAL_COUNT = 1000000.0;

X
xj.lin 已提交
24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
/// Minimal binary file writer used to persist serialized indexes.
/// The file is opened (truncated) in the constructor; check `fs.is_open()` /
/// the stream state to detect failures.
struct FileIOWriter {
    std::fstream fs;
    std::string name;

    FileIOWriter(const std::string &fname);
    ~FileIOWriter();

    /// Write `size` bytes from `ptr` at the current position.
    /// @return `size` if the stream is still good afterwards, otherwise 0.
    size_t operator()(void *ptr, size_t size);
};

/// Minimal binary file reader used to load serialized indexes.
/// The file is opened in the constructor; check `fs.is_open()` to detect failures.
struct FileIOReader {
    std::fstream fs;
    std::string name;

    FileIOReader(const std::string &fname);
    ~FileIOReader();

    /// Read up to `size` bytes into `ptr` from the current position.
    /// @return the number of bytes actually read (less than `size` at end of file).
    size_t operator()(void *ptr, size_t size);

    /// Read up to `size` bytes into `ptr` starting at absolute offset `pos`.
    /// @return the number of bytes actually read; 0 if the seek fails.
    size_t operator()(void *ptr, size_t size, size_t pos);
};

FileIOReader::FileIOReader(const std::string &fname)
    : fs(fname, std::ios::in | std::ios::binary), name(fname) {
}

FileIOReader::~FileIOReader() {
    fs.close();
}

size_t FileIOReader::operator()(void *ptr, size_t size) {
    fs.read(reinterpret_cast<char *>(ptr), static_cast<std::streamsize>(size));
    // gcount() reports how many bytes the last read really delivered.
    return static_cast<size_t>(fs.gcount());
}

size_t FileIOReader::operator()(void *ptr, size_t size, size_t pos) {
    // A previous short read leaves eof/fail bits set; clear them so the seek can succeed.
    fs.clear();
    fs.seekg(static_cast<std::streamoff>(pos));
    if (!fs) {
        return 0;
    }
    return (*this)(ptr, size);
}

FileIOWriter::FileIOWriter(const std::string &fname)
    : fs(fname, std::ios::out | std::ios::binary), name(fname) {
}

FileIOWriter::~FileIOWriter() {
    fs.close();
}

size_t FileIOWriter::operator()(void *ptr, size_t size) {
    fs.write(reinterpret_cast<char *>(ptr), static_cast<std::streamsize>(size));
    return fs ? size : 0;
}


74
/// Create an empty index wrapper for the requested index type.
///
/// @param type  index type to instantiate.
/// @param cfg   configuration; only "gpu_id" (default 0) is read here. It selects the
///              GPU device for GPU-backed and hybrid (MIX) index types.
/// @return the wrapped index, or nullptr if `type` is not supported.
VecIndexPtr GetVecIndexFactory(const IndexType &type, const Config &cfg) {
    std::shared_ptr<zilliz::knowhere::VectorIndex> index;
    auto gpu_device = cfg.get_with_default("gpu_id", 0);
    switch (type) {
        case IndexType::FAISS_IDMAP: {
            index = std::make_shared<zilliz::knowhere::IDMAP>();
            return std::make_shared<BFIndex>(index);
        }
        case IndexType::FAISS_IVFFLAT_CPU: {
            index = std::make_shared<zilliz::knowhere::IVF>();
            break;
        }
        case IndexType::FAISS_IVFFLAT_GPU: {
            // TODO(linxj): normalize parameters
            index = std::make_shared<zilliz::knowhere::GPUIVF>(gpu_device);
            break;
        }
        case IndexType::FAISS_IVFFLAT_MIX: {
            // Honor "gpu_id" like FAISS_IVFSQ8_MIX does (this used to be hard-coded to device 0).
            index = std::make_shared<zilliz::knowhere::GPUIVF>(gpu_device);
            return std::make_shared<IVFMixIndex>(index, IndexType::FAISS_IVFFLAT_MIX);
        }
        case IndexType::FAISS_IVFPQ_CPU: {
            index = std::make_shared<zilliz::knowhere::IVFPQ>();
            break;
        }
        case IndexType::FAISS_IVFPQ_GPU: {
            index = std::make_shared<zilliz::knowhere::GPUIVFPQ>(gpu_device);
            break;
        }
        case IndexType::SPTAG_KDT_RNT_CPU: {
            index = std::make_shared<zilliz::knowhere::CPUKDTRNG>();
            break;
        }
        case IndexType::FAISS_IVFSQ8_MIX: {
            index = std::make_shared<zilliz::knowhere::GPUIVFSQ>(gpu_device);
            return std::make_shared<IVFMixIndex>(index, IndexType::FAISS_IVFSQ8_MIX);
        }
        case IndexType::FAISS_IVFSQ8_CPU: {
            index = std::make_shared<zilliz::knowhere::IVFSQ>();
            break;
        }
        case IndexType::FAISS_IVFSQ8_GPU: {
            index = std::make_shared<zilliz::knowhere::GPUIVFSQ>(gpu_device);
            break;
        }
        case IndexType::NSG_MIX: {  // TODO(linxj): bug.
            index = std::make_shared<zilliz::knowhere::NSG>(gpu_device);
            break;
        }
        default: {
            return nullptr;
        }
    }
    return std::make_shared<VecIndexImpl>(index, type);
}

X
xj.lin 已提交
130
/// Instantiate an index of `index_type` and load its serialized contents.
///
/// @return the loaded index, or nullptr if `index_type` is not supported by
///         GetVecIndexFactory (previously this dereferenced the null result).
VecIndexPtr LoadVecIndex(const IndexType &index_type, const zilliz::knowhere::BinarySet &index_binary) {
    auto index = GetVecIndexFactory(index_type);
    if (index == nullptr) {
        WRAPPER_LOG_ERROR << "LoadVecIndex: unsupported index type " << static_cast<int>(index_type);
        return nullptr;
    }
    index->Load(index_binary);
    return index;
}

X
xj.lin 已提交
136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174
/// Load an index previously written by write_index().
///
/// Expected layout (native byte order):
///   IndexType, then repeated { size_t meta_length, meta bytes,
///                              size_t bin_length,  bin bytes }.
///
/// @return the loaded index, or nullptr if the file cannot be opened, is
///         truncated/corrupted, or holds an unsupported index type.
VecIndexPtr read_index(const std::string &location) {
    FileIOReader reader(location);
    if (!reader.fs.is_open()) {
        WRAPPER_LOG_ERROR << "Failed to open index file: " << location;
        return nullptr;
    }

    reader.fs.seekg(0, reader.fs.end);
    const std::streamoff end_pos = reader.fs.tellg();
    if (end_pos < 0) {
        WRAPPER_LOG_ERROR << "Failed to determine size of index file: " << location;
        return nullptr;
    }
    const size_t length = static_cast<size_t>(end_pos);
    reader.fs.seekg(0);

    // Read exactly `size` bytes at the current position; false on a short read.
    auto read_exact = [&reader](void *dst, size_t size) -> bool {
        reader.fs.read(static_cast<char *>(dst), static_cast<std::streamsize>(size));
        return static_cast<size_t>(reader.fs.gcount()) == size;
    };
    auto corrupted = [&location]() -> VecIndexPtr {
        WRAPPER_LOG_ERROR << "Truncated or corrupted index file: " << location;
        return nullptr;
    };

    size_t rp = 0;
    auto current_type = IndexType::INVALID;
    if (!read_exact(&current_type, sizeof(current_type))) {
        return corrupted();
    }
    rp += sizeof(current_type);

    knowhere::BinarySet load_data_list;
    while (rp < length) {
        size_t meta_length = 0;
        if (!read_exact(&meta_length, sizeof(meta_length))) {
            return corrupted();
        }
        rp += sizeof(meta_length);
        // Reject lengths that exceed the remaining file before allocating.
        if (meta_length > length - rp) {
            return corrupted();
        }
        std::string meta(meta_length, '\0');
        if (!read_exact(&meta[0], meta_length)) {
            return corrupted();
        }
        rp += meta_length;

        size_t bin_length = 0;
        if (!read_exact(&bin_length, sizeof(bin_length))) {
            return corrupted();
        }
        rp += sizeof(bin_length);
        if (bin_length > length - rp) {
            return corrupted();
        }
        // Array allocation must be released with delete[], so give the shared_ptr an array deleter.
        std::shared_ptr<uint8_t> binptr(new uint8_t[bin_length], std::default_delete<uint8_t[]>());
        if (!read_exact(binptr.get(), bin_length)) {
            return corrupted();
        }
        rp += bin_length;

        load_data_list.Append(meta, binptr, bin_length);
    }

    return LoadVecIndex(current_type, load_data_list);
}

X
xj.lin 已提交
175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195
/// Serialize `index` and write it to the file at `location`.
///
/// Layout (native byte order): IndexType, then for every serialized entry
/// { size_t meta_length, meta bytes, int64_t bin_length, bin bytes }.
/// read_index() reads both length fields as size_t, so the format assumes an
/// 8-byte size_t.
///
/// @return KNOWHERE_SUCCESS on success; KNOWHERE_UNEXPECTED_ERROR if knowhere throws;
///         KNOWHERE_ERROR for a null index, a file that cannot be opened, a failed
///         write, or any other std::exception.
server::KnowhereError write_index(VecIndexPtr index, const std::string &location) {
    if (index == nullptr) {
        WRAPPER_LOG_ERROR << "write_index: index is null";
        return server::KNOWHERE_ERROR;
    }
    try {
        auto binaryset = index->Serialize();
        auto index_type = index->GetType();

        FileIOWriter writer(location);
        if (!writer.fs.is_open()) {
            WRAPPER_LOG_ERROR << "Failed to open index file for writing: " << location;
            return server::KNOWHERE_ERROR;
        }

        writer(&index_type, sizeof(IndexType));
        for (auto &iter : binaryset.binary_map_) {
            const auto &meta = iter.first;
            size_t meta_length = meta.length();
            writer(&meta_length, sizeof(meta_length));
            writer(const_cast<char *>(meta.data()), meta_length);

            const auto &binary = iter.second;
            int64_t binary_length = binary->size;
            writer(&binary_length, sizeof(binary_length));
            writer(binary->data.get(), binary_length);
        }

        // Surface I/O errors (e.g. disk full) instead of reporting success.
        writer.fs.flush();
        if (!writer.fs) {
            WRAPPER_LOG_ERROR << "Failed to write index file: " << location;
            return server::KNOWHERE_ERROR;
        }
    } catch (knowhere::KnowhereException &e) {
        WRAPPER_LOG_ERROR << e.what();
        return server::KNOWHERE_UNEXPECTED_ERROR;
    } catch (std::exception &e) {
        WRAPPER_LOG_ERROR << e.what();
        return server::KNOWHERE_ERROR;
    }
    return server::KNOWHERE_SUCCESS;
}

X
xj.lin 已提交
203 204 205

// TODO(linxj): redo here.
void AutoGenParams(const IndexType &type, const long &size, zilliz::knowhere::Config &cfg) {
X
fix  
xj.lin 已提交
206
    auto nlist = cfg.get_with_default("nlist", 0);
W
wxyu 已提交
207
    if (size <= TYPICAL_COUNT / 16384 + 1) {
208 209
        //handle less row count, avoid nlist set to 0
        cfg["nlist"] = 1;
W
wxyu 已提交
210
    } else if (int(size / TYPICAL_COUNT) * nlist == 0) {
211 212 213
        //calculate a proper nlist if nlist not specified or size less than TYPICAL_COUNT
        cfg["nlist"] = int(size / TYPICAL_COUNT * 16384);
    }
X
fix  
xj.lin 已提交
214

X
xj.lin 已提交
215
    if (!cfg.contains("gpu_id")) { cfg["gpu_id"] = int(0); }
X
xj.lin 已提交
216
    if (!cfg.contains("metric_type")) { cfg["metric_type"] = "L2"; }
X
xj.lin 已提交
217 218 219 220 221 222

    switch (type) {
        case IndexType::FAISS_IVFSQ8_MIX: {
            if (!cfg.contains("nbits")) { cfg["nbits"] = int(8); }
            break;
        }
X
xj.lin 已提交
223 224 225
        case IndexType::NSG_MIX: {
            auto scale_factor = round(cfg["dim"].as<int>() / 128.0);
            scale_factor = scale_factor >= 4 ? 4 : scale_factor;
X
xj.lin 已提交
226 227
            cfg["nlist"] = int(size / 1000000.0 * 8192);
            if (!cfg.contains("nprobe")) { cfg["nprobe"] = 6 + 10 * scale_factor; }
X
xj.lin 已提交
228
            if (!cfg.contains("knng")) { cfg["knng"] = 100 + 100 * scale_factor; }
X
xj.lin 已提交
229 230
            if (!cfg.contains("search_length")) { cfg["search_length"] = 40 + 5 * scale_factor; }
            if (!cfg.contains("out_degree")) { cfg["out_degree"] = 50 + 5 * scale_factor; }
X
xj.lin 已提交
231
            if (!cfg.contains("candidate_pool_size")) { cfg["candidate_pool_size"] = 200 + 100 * scale_factor; }
X
xj.lin 已提交
232
            WRAPPER_LOG_DEBUG << pretty_print(cfg);
X
xj.lin 已提交
233 234
            break;
        }
X
xj.lin 已提交
235 236 237
    }
}

238 239 240 241 242 243
// Maximum nprobe accepted for GPU searches; ParameterValidation() clamps larger
// values down to this limit. It is 2048 when CUDA_VERSION > 9000 (newer than CUDA 9.0)
// and 1024 otherwise, including when CUDA_VERSION is not defined (an undefined
// macro evaluates to 0 inside #if).
// NOTE(review): presumably mirrors a faiss GPU nprobe/k limit -- confirm.
// "NRPOBE" is the macro's actual (misspelled) name; rename only together with all users.
#if CUDA_VERSION > 9000
#define GPU_MAX_NRPOBE 2048
#else
#define GPU_MAX_NRPOBE 1024
#endif

244
void ParameterValidation(const IndexType &type, Config &cfg) {
245 246 247 248 249 250 251
    switch (type) {
        case IndexType::FAISS_IVFSQ8_GPU:
        case IndexType::FAISS_IVFFLAT_GPU:
        case IndexType::FAISS_IVFPQ_GPU: {
            if (cfg.get_with_default("nprobe", 0) != 0) {
                auto nprobe = cfg["nprobe"].as<int>();
                if (nprobe > GPU_MAX_NRPOBE) {
252
                    WRAPPER_LOG_WARNING << "When search with GPU, nprobe shoud be no more than " << GPU_MAX_NRPOBE << ", but you passed " << nprobe
253 254 255 256 257 258 259 260 261 262
                                      << ". Search with " << GPU_MAX_NRPOBE << " instead";
                    cfg.insert_or_assign("nprobe", GPU_MAX_NRPOBE);
                }
            }
            break;
        }
        default:break;
    }
}

X
xj.lin 已提交
263 264
/// Map a GPU or hybrid (MIX) IVF index type to its CPU-only counterpart.
/// Any other type is returned unchanged.
IndexType ConvertToCpuIndexType(const IndexType &type) {
    // TODO(linxj): add IDMAP
    if (type == IndexType::FAISS_IVFFLAT_GPU || type == IndexType::FAISS_IVFFLAT_MIX) {
        return IndexType::FAISS_IVFFLAT_CPU;
    }
    if (type == IndexType::FAISS_IVFSQ8_GPU || type == IndexType::FAISS_IVFSQ8_MIX) {
        return IndexType::FAISS_IVFSQ8_CPU;
    }
    return type;
}

X
xj.lin 已提交
280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296
/// Map a CPU or hybrid (MIX) IVF index type to its GPU counterpart.
/// Any other type is returned unchanged.
IndexType ConvertToGpuIndexType(const IndexType &type) {
    IndexType gpu_type = type;
    if (type == IndexType::FAISS_IVFFLAT_CPU || type == IndexType::FAISS_IVFFLAT_MIX) {
        gpu_type = IndexType::FAISS_IVFFLAT_GPU;
    } else if (type == IndexType::FAISS_IVFSQ8_CPU || type == IndexType::FAISS_IVFSQ8_MIX) {
        gpu_type = IndexType::FAISS_IVFSQ8_GPU;
    }
    return gpu_type;
}


X
MS-154  
xj.lin 已提交
297 298 299
}
}
}