提交 07127989 编写于 作者: W wxyu

MS-631 IVFSQ8H Index support


Former-commit-id: 21e17a20794e4fde31e79c4bbd4e26d46c79d886
上级 6769b6d9
...@@ -36,6 +36,7 @@ Please mark all change in change log and use the ticket from JIRA. ...@@ -36,6 +36,7 @@ Please mark all change in change log and use the ticket from JIRA.
## New Feature ## New Feature
- MS-627 - Integrate new index: IVFSQHybrid - MS-627 - Integrate new index: IVFSQHybrid
- MS-631 - IVFSQ8H Index support
## Task ## Task
- MS-554 - Change license to Apache 2.0 - MS-554 - Change license to Apache 2.0
......
...@@ -32,7 +32,7 @@ CpuCacheMgr::CpuCacheMgr() { ...@@ -32,7 +32,7 @@ CpuCacheMgr::CpuCacheMgr() {
server::Config& config = server::Config::GetInstance(); server::Config& config = server::Config::GetInstance();
Status s; Status s;
int32_t cpu_cache_cap; int64_t cpu_cache_cap;
s = config.GetCacheConfigCpuCacheCapacity(cpu_cache_cap); s = config.GetCacheConfigCpuCacheCapacity(cpu_cache_cap);
if (!s.ok()) { if (!s.ok()) {
SERVER_LOG_ERROR << s.message(); SERVER_LOG_ERROR << s.message();
......
...@@ -36,12 +36,12 @@ GpuCacheMgr::GpuCacheMgr() { ...@@ -36,12 +36,12 @@ GpuCacheMgr::GpuCacheMgr() {
server::Config& config = server::Config::GetInstance(); server::Config& config = server::Config::GetInstance();
Status s; Status s;
int32_t gpu_cache_cap; int64_t gpu_cache_cap;
s = config.GetCacheConfigGpuCacheCapacity(gpu_cache_cap); s = config.GetCacheConfigGpuCacheCapacity(gpu_cache_cap);
if (!s.ok()) { if (!s.ok()) {
SERVER_LOG_ERROR << s.message(); SERVER_LOG_ERROR << s.message();
} }
int32_t cap = gpu_cache_cap * G_BYTE; int64_t cap = gpu_cache_cap * G_BYTE;
cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL << 32); cache_ = std::make_shared<Cache<DataObjPtr>>(cap, 1UL << 32);
float gpu_mem_threshold; float gpu_mem_threshold;
......
...@@ -100,16 +100,20 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t& device_id, const Config& config) { ...@@ -100,16 +100,20 @@ IVFSQHybrid::CopyCpuToGpu(const int64_t& device_id, const Config& config) {
void void
IVFSQHybrid::LoadImpl(const BinarySet& index_binary) { IVFSQHybrid::LoadImpl(const BinarySet& index_binary) {
FaissBaseIndex::LoadImpl(index_binary); // load on cpu FaissBaseIndex::LoadImpl(index_binary); // load on cpu
auto* ivf_index = dynamic_cast<faiss::IndexIVF*>(index_.get());
ivf_index->backup_quantizer();
} }
void void
IVFSQHybrid::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels, IVFSQHybrid::search_impl(int64_t n, const float* data, int64_t k, float* distances, int64_t* labels,
const Config& cfg) { const Config& cfg) {
if (gpu_mode) { if (gpu_mode == 2) {
GPUIVF::search_impl(n, data, k, distances, labels, cfg); GPUIVF::search_impl(n, data, k, distances, labels, cfg);
} else { } else if (gpu_mode == 1) {
ResScope rs(res_, gpu_id_); ResScope rs(res_, gpu_id_);
IVF::search_impl(n, data, k, distances, labels, cfg); IVF::search_impl(n, data, k, distances, labels, cfg);
} else if (gpu_mode == 0) {
IVF::search_impl(n, data, k, distances, labels, cfg);
} }
} }
...@@ -137,8 +141,12 @@ IVFSQHybrid::LoadQuantizer(const Config& conf) { ...@@ -137,8 +141,12 @@ IVFSQHybrid::LoadQuantizer(const Config& conf) {
delete gpu_index; delete gpu_index;
auto q = std::make_shared<FaissIVFQuantizer>(); auto q = std::make_shared<FaissIVFQuantizer>();
q->quantizer = index_composition->quantizer;
auto& q_ptr = index_composition->quantizer;
q->size = q_ptr->d * q_ptr->getNumVecs() * sizeof(float);
q->quantizer = q_ptr;
res_ = res; res_ = res;
gpu_mode = 1;
return q; return q;
} else { } else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
...@@ -156,7 +164,7 @@ IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) { ...@@ -156,7 +164,7 @@ IVFSQHybrid::SetQuantizer(const QuantizerPtr& q) {
faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat*>(ivf_index->quantizer); faiss::gpu::GpuIndexFlat* is_gpu_flat_index = dynamic_cast<faiss::gpu::GpuIndexFlat*>(ivf_index->quantizer);
if (is_gpu_flat_index == nullptr) { if (is_gpu_flat_index == nullptr) {
delete ivf_index->quantizer; // delete ivf_index->quantizer;
ivf_index->quantizer = ivf_quantizer->quantizer; ivf_index->quantizer = ivf_quantizer->quantizer;
} }
} }
...@@ -199,10 +207,18 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) { ...@@ -199,10 +207,18 @@ IVFSQHybrid::LoadData(const knowhere::QuantizerPtr& q, const Config& conf) {
auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option); auto gpu_index = faiss::gpu::index_cpu_to_gpu(res->faiss_res.get(), gpu_id_, index_composition, &option);
index_.reset(gpu_index); index_.reset(gpu_index);
gpu_mode = true; // all in gpu gpu_mode = 2; // all in gpu
} else { } else {
KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource"); KNOWHERE_THROW_MSG("CopyCpuToGpu Error, can't get gpu_resource");
} }
} }
// Releases the GPU-side flat quantizer owned by this wrapper.
// Note: `delete` on a null pointer is a well-defined no-op in C++, so the
// previous explicit null check was redundant. Resetting the pointer to
// nullptr afterwards guards against accidental double-delete if this
// destructor logic is ever refactored.
FaissIVFQuantizer::~FaissIVFQuantizer() {
    delete quantizer;
    quantizer = nullptr;
}
} // namespace knowhere } // namespace knowhere
...@@ -27,23 +27,25 @@ namespace knowhere { ...@@ -27,23 +27,25 @@ namespace knowhere {
struct FaissIVFQuantizer : public Quantizer { struct FaissIVFQuantizer : public Quantizer {
faiss::gpu::GpuIndexFlat* quantizer = nullptr; faiss::gpu::GpuIndexFlat* quantizer = nullptr;
~FaissIVFQuantizer() override;
}; };
using FaissIVFQuantizerPtr = std::shared_ptr<FaissIVFQuantizer>; using FaissIVFQuantizerPtr = std::shared_ptr<FaissIVFQuantizer>;
class IVFSQHybrid : public GPUIVFSQ { class IVFSQHybrid : public GPUIVFSQ {
public: public:
explicit IVFSQHybrid(const int& device_id) : GPUIVFSQ(device_id) { explicit IVFSQHybrid(const int& device_id) : GPUIVFSQ(device_id) {
gpu_mode = false; gpu_mode = 0;
} }
explicit IVFSQHybrid(std::shared_ptr<faiss::Index> index) : GPUIVFSQ(-1) { explicit IVFSQHybrid(std::shared_ptr<faiss::Index> index) : GPUIVFSQ(-1) {
index_ = index; index_ = index;
gpu_mode = false; gpu_mode = 0;
} }
explicit IVFSQHybrid(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource) explicit IVFSQHybrid(std::shared_ptr<faiss::Index> index, const int64_t& device_id, ResPtr& resource)
: GPUIVFSQ(index, device_id, resource) { : GPUIVFSQ(index, device_id, resource) {
gpu_mode = true; gpu_mode = 2;
} }
public: public:
...@@ -76,7 +78,7 @@ class IVFSQHybrid : public GPUIVFSQ { ...@@ -76,7 +78,7 @@ class IVFSQHybrid : public GPUIVFSQ {
LoadImpl(const BinarySet& index_binary) override; LoadImpl(const BinarySet& index_binary) override;
protected: protected:
bool gpu_mode = false; int64_t gpu_mode = 0; // 0,1,2
}; };
} // namespace knowhere } // namespace knowhere
...@@ -24,11 +24,13 @@ namespace knowhere { ...@@ -24,11 +24,13 @@ namespace knowhere {
struct Quantizer { struct Quantizer {
virtual ~Quantizer() = default; virtual ~Quantizer() = default;
int64_t size = -1;
}; };
using QuantizerPtr = std::shared_ptr<Quantizer>; using QuantizerPtr = std::shared_ptr<Quantizer>;
struct QuantizerCfg : Cfg { struct QuantizerCfg : Cfg {
uint64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data int64_t mode = -1; // 0: all data, 1: copy quantizer, 2: copy data
}; };
using QuantizerConfig = std::shared_ptr<QuantizerCfg>; using QuantizerConfig = std::shared_ptr<QuantizerCfg>;
......
...@@ -32,7 +32,8 @@ enum class EngineType { ...@@ -32,7 +32,8 @@ enum class EngineType {
FAISS_IVFFLAT, FAISS_IVFFLAT,
FAISS_IVFSQ8, FAISS_IVFSQ8,
NSG_MIX, NSG_MIX,
MAX_VALUE = NSG_MIX, FAISS_IVFSQ8H,
MAX_VALUE = FAISS_IVFSQ8H,
}; };
enum class MetricType { enum class MetricType {
......
...@@ -33,10 +33,31 @@ ...@@ -33,10 +33,31 @@
#include <stdexcept> #include <stdexcept>
#include <utility> #include <utility>
#include <src/scheduler/Utils.h>
namespace milvus { namespace milvus {
namespace engine { namespace engine {
class CachedQuantizer : public cache::DataObj {
public:
explicit
CachedQuantizer(knowhere::QuantizerPtr data)
: data_(std::move(data)) {}
knowhere::QuantizerPtr
Data() {
return data_;
}
int64_t
Size() override {
return data_->size;
}
private:
knowhere::QuantizerPtr data_;
};
ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type, ExecutionEngineImpl::ExecutionEngineImpl(uint16_t dimension, const std::string& location, EngineType index_type,
MetricType metric_type, int32_t nlist) MetricType metric_type, int32_t nlist)
: location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) { : location_(location), dim_(dimension), index_type_(index_type), metric_type_(metric_type), nlist_(nlist) {
...@@ -83,6 +104,10 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { ...@@ -83,6 +104,10 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
index = GetVecIndexFactory(IndexType::NSG_MIX); index = GetVecIndexFactory(IndexType::NSG_MIX);
break; break;
} }
case EngineType::FAISS_IVFSQ8H: {
index = GetVecIndexFactory(IndexType::FAISS_IVFSQ8_HYBRID);
break;
}
default: { default: {
ENGINE_LOG_ERROR << "Invalid engine type"; ENGINE_LOG_ERROR << "Invalid engine type";
return nullptr; return nullptr;
...@@ -92,57 +117,63 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) { ...@@ -92,57 +117,63 @@ ExecutionEngineImpl::CreatetVecIndex(EngineType type) {
} }
void void
ExecutionEngineImpl::HybridLoad() { ExecutionEngineImpl::HybridLoad() const {
// if (index_type_ != EngineType::FAISS_IVFSQ8Hybrid) { if (index_type_ != EngineType::FAISS_IVFSQ8H) {
// return; return;
// } }
//
// const std::string key = location_ + ".quantizer"; const std::string key = location_ + ".quantizer";
// std::vector<uint64_t> gpus; std::vector<uint64_t> gpus = scheduler::get_gpu_pool();
//
// // cache hit // cache hit
// { {
// int64_t selected = -1; const int64_t NOT_FOUND = -1;
// void* quantizer = nullptr; int64_t device_id = NOT_FOUND;
// for (auto& gpu : gpus) { knowhere::QuantizerPtr quantizer = nullptr;
// auto cache = cache::GpuCacheMgr::GetInstance(gpu);
// if (auto quan = cache->GetIndex(key)) { for (auto& gpu : gpus) {
// selected = gpu; auto cache = cache::GpuCacheMgr::GetInstance(gpu);
// quantizer = quan; if (auto cached_quantizer = cache->GetIndex(key)) {
// } device_id = gpu;
// } quantizer = std::static_pointer_cast<CachedQuantizer>(cached_quantizer)->Data();
// }
// if (selected != -1) { }
// // set quantizer into index;
// return; if (device_id != NOT_FOUND) {
// } index_->SetQuantizer(quantizer);
// } return;
// }
// // cache miss }
// {
// std::vector<int64_t> all_free_mem; // cache miss
// for (auto& gpu : gpus) { {
// auto cache = cache::GpuCacheMgr::GetInstance(gpu); std::vector<int64_t> all_free_mem;
// auto free_mem = cache->CacheCapacity() - cache->CacheUsage(); for (auto& gpu : gpus) {
// all_free_mem.push_back(free_mem); auto cache = cache::GpuCacheMgr::GetInstance(gpu);
// } auto free_mem = cache->CacheCapacity() - cache->CacheUsage();
// all_free_mem.push_back(free_mem);
// auto max_e = std::max_element(all_free_mem.begin(), all_free_mem.end()); }
// auto best = std::distance(all_free_mem.begin(), max_e);
// auto max_e = std::max_element(all_free_mem.begin(), all_free_mem.end());
// // load to best device; auto best_index = std::distance(all_free_mem.begin(), max_e);
// // cache quantizer auto best_device_id = gpus[best_index];
// }
// auto quantizer_conf = std::make_shared<knowhere::QuantizerCfg>();
// // if index_type == Hybrid quantizer_conf->mode = 1;
// quantizer_conf->gpu_id = best_device_id;
// // 1. quantizer in which gpu auto quantizer = index_->LoadQuantizer(quantizer_conf);
// index_->SetQuantizer(quantizer);
// // 2.1 which gpu cache best auto cache_quantizer = std::make_shared<CachedQuantizer>(quantizer);
// cache::GpuCacheMgr::GetInstance(best_device_id)->InsertItem(key, cache_quantizer);
// // 2.2 load to that gpu cache }
// }
// // set quantizer into index
void
ExecutionEngineImpl::HybridUnset() const {
if (index_type_ != EngineType::FAISS_IVFSQ8H) {
return;
}
index_->UnsetQuantizer();
} }
Status Status
...@@ -375,7 +406,12 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr ...@@ -375,7 +406,12 @@ ExecutionEngineImpl::Search(int64_t n, const float* data, int64_t k, int64_t npr
auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType()); auto adapter = AdapterMgr::GetInstance().GetAdapter(index_->GetType());
auto conf = adapter->MatchSearch(temp_conf, index_->GetType()); auto conf = adapter->MatchSearch(temp_conf, index_->GetType());
HybridLoad();
auto status = index_->Search(n, data, distances, labels, conf); auto status = index_->Search(n, data, distances, labels, conf);
HybridUnset();
if (!status.ok()) { if (!status.ok()) {
ENGINE_LOG_ERROR << "Search error"; ENGINE_LOG_ERROR << "Search error";
} }
......
...@@ -108,7 +108,10 @@ class ExecutionEngineImpl : public ExecutionEngine { ...@@ -108,7 +108,10 @@ class ExecutionEngineImpl : public ExecutionEngine {
Load(const std::string& location); Load(const std::string& location);
void void
HybridLoad(); HybridLoad() const;
void
HybridUnset() const;
protected: protected:
VecIndexPtr index_ = nullptr; VecIndexPtr index_ = nullptr;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "task/Task.h" #include "task/Task.h"
#include <utility> #include <utility>
#include <src/scheduler/optimizer/Optimizer.h>
namespace milvus { namespace milvus {
namespace scheduler { namespace scheduler {
...@@ -66,6 +67,9 @@ JobMgr::worker_function() { ...@@ -66,6 +67,9 @@ JobMgr::worker_function() {
} }
auto tasks = build_task(job); auto tasks = build_task(job);
// TODO: optimizer all task
// disk resources NEVER be empty. // disk resources NEVER be empty.
if (auto disk = res_mgr_->GetDiskResources()[0].lock()) { if (auto disk = res_mgr_->GetDiskResources()[0].lock()) {
for (auto& task : tasks) { for (auto& task : tasks) {
......
...@@ -15,7 +15,9 @@ ...@@ -15,7 +15,9 @@
// specific language governing permissions and limitations // specific language governing permissions and limitations
// under the License. // under the License.
#include "server/Config.h"
#include "scheduler/Utils.h" #include "scheduler/Utils.h"
#include "utils/Log.h"
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <chrono> #include <chrono>
...@@ -38,5 +40,42 @@ get_num_gpu() { ...@@ -38,5 +40,42 @@ get_num_gpu() {
return n_devices; return n_devices;
} }
// Returns the sorted, de-duplicated list of GPU device ids named in the
// resource pool configuration ("gpu<N>" entries; "cpu" entries are skipped).
// A malformed entry or an id beyond the number of installed GPUs is treated
// as a fatal configuration error: it is logged and the process exits,
// matching the previous behavior but now with a diagnostic message.
std::vector<uint64_t>
get_gpu_pool() {
    std::vector<uint64_t> gpu_pool;

    server::Config& config = server::Config::GetInstance();
    std::vector<std::string> pool;
    Status s = config.GetResourceConfigPool(pool);
    if (!s.ok()) {
        SERVER_LOG_ERROR << s.message();
    }

    const uint64_t num_gpu = scheduler::get_num_gpu();

    // std::set keeps the ids sorted and de-duplicated.
    std::set<uint64_t> gpu_ids;
    for (auto& resource : pool) {
        if (resource == "cpu") {
            continue;
        }
        // Expect the form "gpu<N>" with a non-empty numeric suffix.
        if (resource.length() < 4 || resource.substr(0, 3) != "gpu") {
            SERVER_LOG_ERROR << "Invalid resource in pool: " << resource;
            exit(-1);
        }
        int gpu_id = -1;
        try {
            gpu_id = std::stoi(resource.substr(3));
        } catch (...) {
            // std::stoi throws on a non-numeric suffix; previously this
            // escaped as an uncaught exception.
            SERVER_LOG_ERROR << "Invalid gpu id in resource: " << resource;
            exit(-1);
        }
        // Explicit signed check before the unsigned comparison; the old
        // code compared int against uint64_t, relying on wraparound for
        // negative ids.
        if (gpu_id < 0 || static_cast<uint64_t>(gpu_id) >= num_gpu) {
            SERVER_LOG_ERROR << "gpu id out of range: " << resource;
            exit(-1);
        }
        gpu_ids.insert(static_cast<uint64_t>(gpu_id));
    }

    gpu_pool.assign(gpu_ids.begin(), gpu_ids.end());
    return gpu_pool;
}
} // namespace scheduler } // namespace scheduler
} // namespace milvus } // namespace milvus
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
// under the License. // under the License.
#include <cstdint> #include <cstdint>
#include <vector>
namespace milvus { namespace milvus {
namespace scheduler { namespace scheduler {
...@@ -26,5 +27,8 @@ get_current_timestamp(); ...@@ -26,5 +27,8 @@ get_current_timestamp();
uint64_t uint64_t
get_num_gpu(); get_num_gpu();
std::vector<uint64_t>
get_gpu_pool();
} // namespace scheduler } // namespace scheduler
} // namespace milvus } // namespace milvus
...@@ -23,11 +23,14 @@ namespace scheduler { ...@@ -23,11 +23,14 @@ namespace scheduler {
bool bool
HybridPass::Run(const TaskPtr& task) { HybridPass::Run(const TaskPtr& task) {
// TODO: Index::IVFSQ8Hybrid, if nq < threshold set cpu, else set gpu // TODO: future, Index::IVFSQ8H, if nq < threshold set cpu, else set gpu
if (task->Type() != TaskType::SearchTask) if (task->Type() != TaskType::SearchTask)
return false; return false;
auto search_task = std::static_pointer_cast<XSearchTask>(task); auto search_task = std::static_pointer_cast<XSearchTask>(task);
// if (search_task->file_->engine_type_ == engine::EngineType::FAISS_IVFSQ8Hybrid) if (search_task->file_->engine_type_ == (int)engine::EngineType::FAISS_IVFSQ8H) {
// TODO: make specified label
return true;
}
return false; return false;
} }
......
...@@ -36,6 +36,7 @@ enum class IndexType { ...@@ -36,6 +36,7 @@ enum class IndexType {
gpu_ivfflat, gpu_ivfflat,
gpu_ivfsq8, gpu_ivfsq8,
mix_nsg, mix_nsg,
ivfsq8h,
}; };
enum class MetricType { enum class MetricType {
......
...@@ -161,7 +161,7 @@ Config::ValidateConfig() { ...@@ -161,7 +161,7 @@ Config::ValidateConfig() {
} }
/* cache config */ /* cache config */
int32_t cache_cpu_cache_capacity; int64_t cache_cpu_cache_capacity;
s = GetCacheConfigCpuCacheCapacity(cache_cpu_cache_capacity); s = GetCacheConfigCpuCacheCapacity(cache_cpu_cache_capacity);
if (!s.ok()) { if (!s.ok()) {
return s; return s;
...@@ -173,7 +173,7 @@ Config::ValidateConfig() { ...@@ -173,7 +173,7 @@ Config::ValidateConfig() {
return s; return s;
} }
int32_t cache_gpu_cache_capacity; int64_t cache_gpu_cache_capacity;
s = GetCacheConfigGpuCacheCapacity(cache_gpu_cache_capacity); s = GetCacheConfigGpuCacheCapacity(cache_gpu_cache_capacity);
if (!s.ok()) { if (!s.ok()) {
return s; return s;
...@@ -789,7 +789,7 @@ Config::GetMetricConfigPrometheusPort(std::string& value) { ...@@ -789,7 +789,7 @@ Config::GetMetricConfigPrometheusPort(std::string& value) {
} }
Status Status
Config::GetCacheConfigCpuCacheCapacity(int32_t& value) { Config::GetCacheConfigCpuCacheCapacity(int64_t& value) {
std::string str = std::string str =
GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT); GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_CPU_CACHE_CAPACITY, CONFIG_CACHE_CPU_CACHE_CAPACITY_DEFAULT);
Status s = CheckCacheConfigCpuCacheCapacity(str); Status s = CheckCacheConfigCpuCacheCapacity(str);
...@@ -815,7 +815,7 @@ Config::GetCacheConfigCpuCacheThreshold(float& value) { ...@@ -815,7 +815,7 @@ Config::GetCacheConfigCpuCacheThreshold(float& value) {
} }
Status Status
Config::GetCacheConfigGpuCacheCapacity(int32_t& value) { Config::GetCacheConfigGpuCacheCapacity(int64_t& value) {
std::string str = std::string str =
GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_GPU_CACHE_CAPACITY, CONFIG_CACHE_GPU_CACHE_CAPACITY_DEFAULT); GetConfigStr(CONFIG_CACHE, CONFIG_CACHE_GPU_CACHE_CAPACITY, CONFIG_CACHE_GPU_CACHE_CAPACITY_DEFAULT);
Status s = CheckCacheConfigGpuCacheCapacity(str); Status s = CheckCacheConfigGpuCacheCapacity(str);
......
...@@ -221,11 +221,11 @@ class Config { ...@@ -221,11 +221,11 @@ class Config {
/* cache config */ /* cache config */
Status Status
GetCacheConfigCpuCacheCapacity(int32_t& value); GetCacheConfigCpuCacheCapacity(int64_t& value);
Status Status
GetCacheConfigCpuCacheThreshold(float& value); GetCacheConfigCpuCacheThreshold(float& value);
Status Status
GetCacheConfigGpuCacheCapacity(int32_t& value); GetCacheConfigGpuCacheCapacity(int64_t& value);
Status Status
GetCacheConfigGpuCacheThreshold(float& value); GetCacheConfigGpuCacheThreshold(float& value);
Status Status
......
...@@ -49,6 +49,7 @@ AdapterMgr::RegisterAdapter() { ...@@ -49,6 +49,7 @@ AdapterMgr::RegisterAdapter() {
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_CPU, ivfsq8_cpu); REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_CPU, ivfsq8_cpu);
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_GPU, ivfsq8_gpu); REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_GPU, ivfsq8_gpu);
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_MIX, ivfsq8_mix); REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_MIX, ivfsq8_mix);
REGISTER_CONF_ADAPTER(IVFSQConfAdapter, IndexType::FAISS_IVFSQ8_HYBRID, ivfsq8_h);
REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexType::FAISS_IVFPQ_CPU, ivfpq_cpu); REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexType::FAISS_IVFPQ_CPU, ivfpq_cpu);
REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexType::FAISS_IVFPQ_GPU, ivfpq_gpu); REGISTER_CONF_ADAPTER(IVFPQConfAdapter, IndexType::FAISS_IVFPQ_GPU, ivfpq_gpu);
......
...@@ -93,6 +93,10 @@ class IVFMixIndex : public VecIndexImpl { ...@@ -93,6 +93,10 @@ class IVFMixIndex : public VecIndexImpl {
class IVFHybridIndex : public IVFMixIndex { class IVFHybridIndex : public IVFMixIndex {
public: public:
explicit IVFHybridIndex(std::shared_ptr<knowhere::VectorIndex> index, const IndexType& type)
: IVFMixIndex(std::move(index), type) {
}
knowhere::QuantizerPtr knowhere::QuantizerPtr
LoadQuantizer(const Config& conf) override; LoadQuantizer(const Config& conf) override;
......
...@@ -145,7 +145,7 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) { ...@@ -145,7 +145,7 @@ GetVecIndexFactory(const IndexType& type, const Config& cfg) {
} }
case IndexType::FAISS_IVFSQ8_HYBRID: { case IndexType::FAISS_IVFSQ8_HYBRID: {
index = std::make_shared<knowhere::IVFSQHybrid>(gpu_device); index = std::make_shared<knowhere::IVFSQHybrid>(gpu_device);
break; return std::make_shared<IVFHybridIndex>(index, IndexType::FAISS_IVFSQ8_HYBRID);
} }
case IndexType::NSG_MIX: { // TODO(linxj): bug. case IndexType::NSG_MIX: { // TODO(linxj): bug.
index = std::make_shared<knowhere::NSG>(gpu_device); index = std::make_shared<knowhere::NSG>(gpu_device);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册