Helper.h 3.9 KB
Newer Older
X
xiaojun.lin committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

#include <memory>

#include "knowhere/index/vector_index/IndexGPUIVF.h"
#include "knowhere/index/vector_index/IndexGPUIVFPQ.h"
#include "knowhere/index/vector_index/IndexGPUIVFSQ.h"
#include "knowhere/index/vector_index/IndexIVF.h"
#include "knowhere/index/vector_index/IndexIVFPQ.h"
#include "knowhere/index/vector_index/IndexIVFSQ.h"
#include "knowhere/index/vector_index/IndexIVFSQHybrid.h"

// Shared knobs for the IVF-family index tests in this header.
constexpr int DEVICEID = 0;                        // GPU ordinal passed to GPU indexes and configs
constexpr int64_t DIM = 128;                       // vector dimensionality
constexpr int64_t NB = 10000;                      // number of base vectors
constexpr int64_t NQ = 10;                         // number of query vectors
constexpr int64_t K = 10;                          // top-k requested per query
constexpr int64_t PINMEM = int64_t{200} << 20;     // 200 MiB, passed as pinned-memory size to InitDevice
constexpr int64_t TEMPMEM = int64_t{300} << 20;    // 300 MiB, passed as temp-memory size to InitDevice
constexpr int64_t RESNUM = 2;                      // resource count passed to InitDevice

/// Test helper: creates an empty IVF-family index selected by name.
///
/// @param type  one of "IVF", "IVFPQ", "GPUIVF", "GPUIVFPQ", "IVFSQ",
///              "GPUIVFSQ", "IVFSQHybrid". The GPU variants and IVFSQHybrid
///              are constructed with DEVICEID.
/// @return      a newly created index, or nullptr when @p type is not one of
///              the names above.
///
/// Declared `inline` because this function is defined in a header; without it,
/// two test translation units including the header would violate the ODR.
inline knowhere::IVFIndexPtr
IndexFactory(const std::string& type) {
    if (type == "IVF") {
        return std::make_shared<knowhere::IVF>();
    } else if (type == "IVFPQ") {
        return std::make_shared<knowhere::IVFPQ>();
    } else if (type == "GPUIVF") {
        return std::make_shared<knowhere::GPUIVF>(DEVICEID);
    } else if (type == "GPUIVFPQ") {
        return std::make_shared<knowhere::GPUIVFPQ>(DEVICEID);
    } else if (type == "IVFSQ") {
        return std::make_shared<knowhere::IVFSQ>();
    } else if (type == "GPUIVFSQ") {
        return std::make_shared<knowhere::GPUIVFSQ>(DEVICEID);
    } else if (type == "IVFSQHybrid") {
        return std::make_shared<knowhere::IVFSQHybrid>(DEVICEID);
    }
    // Unknown type: return an explicit null instead of flowing off the end of a
    // value-returning function, which is undefined behaviour.
    return nullptr;
}

// Selects which canned index configuration ParamGenerator::Gen builds:
// plain IVF, IVF with product quantization, or IVF with scalar quantization.
enum class ParameterType { ivf, ivfpq, ivfsq };

class ParamGenerator {
 public:
    static ParamGenerator&
    GetInstance() {
        static ParamGenerator instance;
        return instance;
    }

    knowhere::Config
    Gen(const ParameterType& type) {
        if (type == ParameterType::ivf) {
            auto tempconf = std::make_shared<knowhere::IVFCfg>();
            tempconf->d = DIM;
            tempconf->gpu_id = DEVICEID;
            tempconf->nlist = 100;
            tempconf->nprobe = 4;
            tempconf->k = K;
            tempconf->metric_type = knowhere::METRICTYPE::L2;
            return tempconf;
        } else if (type == ParameterType::ivfpq) {
            auto tempconf = std::make_shared<knowhere::IVFPQCfg>();
            tempconf->d = DIM;
            tempconf->gpu_id = DEVICEID;
            tempconf->nlist = 100;
            tempconf->nprobe = 4;
            tempconf->k = K;
            tempconf->m = 4;
            tempconf->nbits = 8;
            tempconf->metric_type = knowhere::METRICTYPE::L2;
            return tempconf;
        } else if (type == ParameterType::ivfsq) {
            auto tempconf = std::make_shared<knowhere::IVFSQCfg>();
            tempconf->d = DIM;
            tempconf->gpu_id = DEVICEID;
            tempconf->nlist = 100;
            tempconf->nprobe = 4;
            tempconf->k = K;
            tempconf->nbits = 8;
            tempconf->metric_type = knowhere::METRICTYPE::L2;
            return tempconf;
        }
    }
};

#include <gtest/gtest.h>

// GoogleTest fixture for GPU index tests. Before each test it calls InitDevice on
// the FaissGpuResourceMgr singleton for DEVICEID, passing PINMEM, TEMPMEM and
// RESNUM. After each test it calls Free() on the same singleton.
class TestGpuIndexBase : public ::testing::Test {
 protected:
    void
    SetUp() override {
        auto&& resource_mgr = knowhere::FaissGpuResourceMgr::GetInstance();
        resource_mgr.InitDevice(DEVICEID, PINMEM, TEMPMEM, RESNUM);
    }

    void
    TearDown() override {
        auto&& resource_mgr = knowhere::FaissGpuResourceMgr::GetInstance();
        resource_mgr.Free();
    }
};