From 9d5ddb28e492a3859fa08241cde856b60460e8f5 Mon Sep 17 00:00:00 2001 From: cqy123456 <39671710+cqy123456@users.noreply.github.com> Date: Thu, 3 Sep 2020 20:21:55 +0800 Subject: [PATCH] #3576 optimize code structure (#3581) Signed-off-by: cqy --- core/src/db/Types.h | 3 ++ core/src/db/engine/ExecutionEngineImpl.cpp | 31 +++++++++---------- core/src/index/archive/KnowhereResource.cpp | 32 ++++++++++++++++++++ core/src/server/DBWrapper.cpp | 33 --------------------- 4 files changed, 49 insertions(+), 50 deletions(-) diff --git a/core/src/db/Types.h b/core/src/db/Types.h index 2dc2a983..ba3e1ade 100644 --- a/core/src/db/Types.h +++ b/core/src/db/Types.h @@ -60,6 +60,9 @@ using VectorDistances = std::vector; using ResultIds = std::vector; using ResultDistances = std::vector; +using ConCurrentBitset = faiss::ConcurrentBitset; +using ConCurrentBitsetPtr = faiss::ConcurrentBitsetPtr; + /////////////////////////////////////////////////////////////////////////////////////////////////// enum class DataType { NONE = 0, diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index 54b2a7d5..8fc282df 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -300,9 +300,9 @@ ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context, Status ExecutionEngineImpl::Search(ExecutionEngineContext& context) { try { - faiss::ConcurrentBitsetPtr bitset; + ConCurrentBitsetPtr bitset; std::string vector_placeholder; - faiss::ConcurrentBitsetPtr list; + ConCurrentBitsetPtr list; SegmentPtr segment_ptr; segment_reader_->GetSegment(segment_ptr); @@ -360,13 +360,12 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) { } Status -ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& general_query, - faiss::ConcurrentBitsetPtr& bitset, +ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& general_query, ConCurrentBitsetPtr& bitset, std::unordered_map& attr_type, std::string& vector_placeholder) { Status status = Status::OK(); if (general_query->leaf == nullptr) { - faiss::ConcurrentBitsetPtr left_bitset, right_bitset; + ConCurrentBitsetPtr left_bitset, right_bitset; if (general_query->bin->left_query != nullptr) { status = ExecBinaryQuery(general_query->bin->left_query, left_bitset, attr_type, vector_placeholder); if (!status.ok()) { @@ -412,16 +411,16 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener return status; } else { if (general_query->leaf->term_query != nullptr) { - bitset = std::make_shared(entity_count_); + bitset = std::make_shared(entity_count_); STATUS_CHECK(ProcessTermQuery(bitset, general_query->leaf->term_query, attr_type)); } if (general_query->leaf->range_query != nullptr) { - bitset = std::make_shared(entity_count_); + bitset = std::make_shared(entity_count_); STATUS_CHECK(ProcessRangeQuery(attr_type, bitset, general_query->leaf->range_query)); } if (!general_query->leaf->vector_placeholder.empty()) { // skip vector query - bitset = std::make_shared(entity_count_, 255); + bitset = std::make_shared(entity_count_, 255); vector_placeholder = general_query->leaf->vector_placeholder; } } @@ -430,8 +429,7 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener template Status -ProcessIndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, - milvus::json& term_values_json) { +ProcessIndexedTermQuery(ConCurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, milvus::json& term_values_json) { try { auto T_index = std::dynamic_pointer_cast>(index_ptr); if (not T_index) { @@ -453,7 +451,7 @@ ProcessIndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& } Status -ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const std::string& field_name, +ExecutionEngineImpl::IndexedTermQuery(ConCurrentBitsetPtr& bitset, const std::string& field_name, const DataType& data_type, milvus::json& term_values_json) { SegmentPtr segment_ptr; segment_reader_->GetSegment(segment_ptr); @@ -493,7 +491,7 @@ ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const } Status -ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const query::TermQueryPtr& term_query, +ExecutionEngineImpl::ProcessTermQuery(ConCurrentBitsetPtr& bitset, const query::TermQueryPtr& term_query, std::unordered_map& attr_type) { try { auto term_query_json = term_query->json_obj; @@ -520,8 +518,7 @@ ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const template Status -ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, - milvus::json& range_values_json) { +ProcessIndexedRangeQuery(ConCurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, milvus::json& range_values_json) { try { auto T_index = std::dynamic_pointer_cast>(index_ptr); @@ -543,7 +540,7 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& } Status -ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const DataType& data_type, +ExecutionEngineImpl::IndexedRangeQuery(ConCurrentBitsetPtr& bitset, const DataType& data_type, knowhere::IndexPtr& index_ptr, milvus::json& range_values_json) { auto status = Status::OK(); switch (data_type) { @@ -579,7 +576,7 @@ ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const Status ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map& attr_type, - faiss::ConcurrentBitsetPtr& bitset, const query::RangeQueryPtr& range_query) { + ConCurrentBitsetPtr& bitset, const query::RangeQueryPtr& range_query) { SegmentPtr segment_ptr; segment_reader_->GetSegment(segment_ptr); try { @@ -809,7 +806,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col LOG_ENGINE_DEBUG_ << "Index config: " << conf.dump(); std::vector uids; - faiss::ConcurrentBitsetPtr blacklist; + ConCurrentBitsetPtr blacklist; knowhere::DatasetPtr dataset; if (from_index) { dataset = diff --git a/core/src/index/archive/KnowhereResource.cpp b/core/src/index/archive/KnowhereResource.cpp index 28646de8..cea8eca9 100644 --- a/core/src/index/archive/KnowhereResource.cpp +++ b/core/src/index/archive/KnowhereResource.cpp @@ -13,10 +13,13 @@ #ifdef MILVUS_GPU_VERSION #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #endif +#include +#include #include "config/ServerConfig.h" #include "faiss/FaissHook.h" #include "scheduler/Utils.h" +#include "utils/ConfigUtils.h" #include "utils/Error.h" #include "utils/Log.h" @@ -60,6 +63,35 @@ KnowhereResource::Initialize() { return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!"); } + // engine config + int64_t omp_thread = config.engine.omp_thread_num(); + + if (omp_thread > 0) { + omp_set_num_threads(omp_thread); + LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread; + } else { + int64_t sys_thread_cnt = 8; + if (milvus::server::GetSystemAvailableThreads(sys_thread_cnt)) { + omp_thread = static_cast(ceil(sys_thread_cnt * 0.5)); + omp_set_num_threads(omp_thread); + } + } + + // init faiss global variable + int64_t use_blas_threshold = config.engine.use_blas_threshold(); + faiss::distance_compute_blas_threshold = use_blas_threshold; + + int64_t clustering_type = config.engine.clustering_type(); + switch (clustering_type) { + case ClusteringType::K_MEANS: + default: + faiss::clustering_type = faiss::ClusteringType::K_MEANS; + break; + case ClusteringType::K_MEANS_PLUS_PLUS: + faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS; + break; + } + #ifdef MILVUS_GPU_VERSION bool enable_gpu = config.gpu.enable(); fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false); diff --git a/core/src/server/DBWrapper.cpp b/core/src/server/DBWrapper.cpp index 65e0d3ac..1a326dbc 100644 --- a/core/src/server/DBWrapper.cpp +++ b/core/src/server/DBWrapper.cpp @@ -16,14 +16,10 @@ #include #include -#include -#include - #include "config/ServerConfig.h" #include "db/DBFactory.h" #include "db/snapshot/OperationExecutor.h" #include "utils/CommonUtil.h" -#include "utils/ConfigUtils.h" #include "utils/Log.h" #include "utils/StringHelpFunctions.h" @@ -66,35 +62,6 @@ DBWrapper::StartService() { opt.transcript_enable_ = config.transcript.enable(); opt.replay_script_path_ = config.transcript.replay(); - // engine config - int64_t omp_thread = config.engine.omp_thread_num(); - - if (omp_thread > 0) { - omp_set_num_threads(omp_thread); - LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread; - } else { - int64_t sys_thread_cnt = 8; - if (GetSystemAvailableThreads(sys_thread_cnt)) { - omp_thread = static_cast(ceil(sys_thread_cnt * 0.5)); - omp_set_num_threads(omp_thread); - } - } - - // init faiss global variable - int64_t use_blas_threshold = config.engine.use_blas_threshold(); - faiss::distance_compute_blas_threshold = use_blas_threshold; - - int64_t clustering_type = config.engine.clustering_type(); - switch (clustering_type) { - case ClusteringType::K_MEANS: - default: - faiss::clustering_type = faiss::ClusteringType::K_MEANS; - break; - case ClusteringType::K_MEANS_PLUS_PLUS: - faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS; - break; - } - // create db root folder s = CommonUtil::CreateDirectory(opt.meta_.path_); if (!s.ok()) { -- GitLab