未验证 提交 9d5ddb28 编写于 作者: C cqy123456 提交者: GitHub

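#3576 optimize code structure: add ConCurrentBitset/ConCurrentBitsetPtr aliases and move the OpenMP thread and faiss global initialization from DBWrapper::StartService to KnowhereResource::Initialize (#3581)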

Signed-off-by: Ncqy <yaya645@126.com>
上级 bf5bdf25
...@@ -60,6 +60,9 @@ using VectorDistances = std::vector<VectorDistance>; ...@@ -60,6 +60,9 @@ using VectorDistances = std::vector<VectorDistance>;
using ResultIds = std::vector<faiss::Index::idx_t>; using ResultIds = std::vector<faiss::Index::idx_t>;
using ResultDistances = std::vector<faiss::Index::distance_t>; using ResultDistances = std::vector<faiss::Index::distance_t>;
using ConCurrentBitset = faiss::ConcurrentBitset;
using ConCurrentBitsetPtr = faiss::ConcurrentBitsetPtr;
/////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////
enum class DataType { enum class DataType {
NONE = 0, NONE = 0,
......
...@@ -300,9 +300,9 @@ ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context, ...@@ -300,9 +300,9 @@ ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context,
Status Status
ExecutionEngineImpl::Search(ExecutionEngineContext& context) { ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
try { try {
faiss::ConcurrentBitsetPtr bitset; ConCurrentBitsetPtr bitset;
std::string vector_placeholder; std::string vector_placeholder;
faiss::ConcurrentBitsetPtr list; ConCurrentBitsetPtr list;
SegmentPtr segment_ptr; SegmentPtr segment_ptr;
segment_reader_->GetSegment(segment_ptr); segment_reader_->GetSegment(segment_ptr);
...@@ -360,13 +360,12 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) { ...@@ -360,13 +360,12 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
} }
Status Status
ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& general_query, ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& general_query, ConCurrentBitsetPtr& bitset,
faiss::ConcurrentBitsetPtr& bitset,
std::unordered_map<std::string, DataType>& attr_type, std::unordered_map<std::string, DataType>& attr_type,
std::string& vector_placeholder) { std::string& vector_placeholder) {
Status status = Status::OK(); Status status = Status::OK();
if (general_query->leaf == nullptr) { if (general_query->leaf == nullptr) {
faiss::ConcurrentBitsetPtr left_bitset, right_bitset; ConCurrentBitsetPtr left_bitset, right_bitset;
if (general_query->bin->left_query != nullptr) { if (general_query->bin->left_query != nullptr) {
status = ExecBinaryQuery(general_query->bin->left_query, left_bitset, attr_type, vector_placeholder); status = ExecBinaryQuery(general_query->bin->left_query, left_bitset, attr_type, vector_placeholder);
if (!status.ok()) { if (!status.ok()) {
...@@ -412,16 +411,16 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener ...@@ -412,16 +411,16 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener
return status; return status;
} else { } else {
if (general_query->leaf->term_query != nullptr) { if (general_query->leaf->term_query != nullptr) {
bitset = std::make_shared<faiss::ConcurrentBitset>(entity_count_); bitset = std::make_shared<ConCurrentBitset>(entity_count_);
STATUS_CHECK(ProcessTermQuery(bitset, general_query->leaf->term_query, attr_type)); STATUS_CHECK(ProcessTermQuery(bitset, general_query->leaf->term_query, attr_type));
} }
if (general_query->leaf->range_query != nullptr) { if (general_query->leaf->range_query != nullptr) {
bitset = std::make_shared<faiss::ConcurrentBitset>(entity_count_); bitset = std::make_shared<ConCurrentBitset>(entity_count_);
STATUS_CHECK(ProcessRangeQuery(attr_type, bitset, general_query->leaf->range_query)); STATUS_CHECK(ProcessRangeQuery(attr_type, bitset, general_query->leaf->range_query));
} }
if (!general_query->leaf->vector_placeholder.empty()) { if (!general_query->leaf->vector_placeholder.empty()) {
// skip vector query // skip vector query
bitset = std::make_shared<faiss::ConcurrentBitset>(entity_count_, 255); bitset = std::make_shared<ConCurrentBitset>(entity_count_, 255);
vector_placeholder = general_query->leaf->vector_placeholder; vector_placeholder = general_query->leaf->vector_placeholder;
} }
} }
...@@ -430,8 +429,7 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener ...@@ -430,8 +429,7 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener
template <typename T> template <typename T>
Status Status
ProcessIndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, ProcessIndexedTermQuery(ConCurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, milvus::json& term_values_json) {
milvus::json& term_values_json) {
try { try {
auto T_index = std::dynamic_pointer_cast<knowhere::StructuredIndexSort<T>>(index_ptr); auto T_index = std::dynamic_pointer_cast<knowhere::StructuredIndexSort<T>>(index_ptr);
if (not T_index) { if (not T_index) {
...@@ -453,7 +451,7 @@ ProcessIndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& ...@@ -453,7 +451,7 @@ ProcessIndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr&
} }
Status Status
ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const std::string& field_name, ExecutionEngineImpl::IndexedTermQuery(ConCurrentBitsetPtr& bitset, const std::string& field_name,
const DataType& data_type, milvus::json& term_values_json) { const DataType& data_type, milvus::json& term_values_json) {
SegmentPtr segment_ptr; SegmentPtr segment_ptr;
segment_reader_->GetSegment(segment_ptr); segment_reader_->GetSegment(segment_ptr);
...@@ -493,7 +491,7 @@ ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const ...@@ -493,7 +491,7 @@ ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const
} }
Status Status
ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const query::TermQueryPtr& term_query, ExecutionEngineImpl::ProcessTermQuery(ConCurrentBitsetPtr& bitset, const query::TermQueryPtr& term_query,
std::unordered_map<std::string, DataType>& attr_type) { std::unordered_map<std::string, DataType>& attr_type) {
try { try {
auto term_query_json = term_query->json_obj; auto term_query_json = term_query->json_obj;
...@@ -520,8 +518,7 @@ ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const ...@@ -520,8 +518,7 @@ ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const
template <typename T> template <typename T>
Status Status
ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, ProcessIndexedRangeQuery(ConCurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, milvus::json& range_values_json) {
milvus::json& range_values_json) {
try { try {
auto T_index = std::dynamic_pointer_cast<knowhere::StructuredIndexSort<T>>(index_ptr); auto T_index = std::dynamic_pointer_cast<knowhere::StructuredIndexSort<T>>(index_ptr);
...@@ -543,7 +540,7 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& ...@@ -543,7 +540,7 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr&
} }
Status Status
ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const DataType& data_type, ExecutionEngineImpl::IndexedRangeQuery(ConCurrentBitsetPtr& bitset, const DataType& data_type,
knowhere::IndexPtr& index_ptr, milvus::json& range_values_json) { knowhere::IndexPtr& index_ptr, milvus::json& range_values_json) {
auto status = Status::OK(); auto status = Status::OK();
switch (data_type) { switch (data_type) {
...@@ -579,7 +576,7 @@ ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const ...@@ -579,7 +576,7 @@ ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const
Status Status
ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map<std::string, DataType>& attr_type, ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map<std::string, DataType>& attr_type,
faiss::ConcurrentBitsetPtr& bitset, const query::RangeQueryPtr& range_query) { ConCurrentBitsetPtr& bitset, const query::RangeQueryPtr& range_query) {
SegmentPtr segment_ptr; SegmentPtr segment_ptr;
segment_reader_->GetSegment(segment_ptr); segment_reader_->GetSegment(segment_ptr);
try { try {
...@@ -809,7 +806,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col ...@@ -809,7 +806,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col
LOG_ENGINE_DEBUG_ << "Index config: " << conf.dump(); LOG_ENGINE_DEBUG_ << "Index config: " << conf.dump();
std::vector<idx_t> uids; std::vector<idx_t> uids;
faiss::ConcurrentBitsetPtr blacklist; ConCurrentBitsetPtr blacklist;
knowhere::DatasetPtr dataset; knowhere::DatasetPtr dataset;
if (from_index) { if (from_index) {
dataset = dataset =
......
...@@ -13,10 +13,13 @@ ...@@ -13,10 +13,13 @@
#ifdef MILVUS_GPU_VERSION #ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h" #include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif #endif
#include <faiss/Clustering.h>
#include <faiss/utils/distances.h>
#include "config/ServerConfig.h" #include "config/ServerConfig.h"
#include "faiss/FaissHook.h" #include "faiss/FaissHook.h"
#include "scheduler/Utils.h" #include "scheduler/Utils.h"
#include "utils/ConfigUtils.h"
#include "utils/Error.h" #include "utils/Error.h"
#include "utils/Log.h" #include "utils/Log.h"
...@@ -60,6 +63,35 @@ KnowhereResource::Initialize() { ...@@ -60,6 +63,35 @@ KnowhereResource::Initialize() {
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!"); return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
} }
// engine config
int64_t omp_thread = config.engine.omp_thread_num();
if (omp_thread > 0) {
omp_set_num_threads(omp_thread);
LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread;
} else {
int64_t sys_thread_cnt = 8;
if (milvus::server::GetSystemAvailableThreads(sys_thread_cnt)) {
omp_thread = static_cast<int32_t>(ceil(sys_thread_cnt * 0.5));
omp_set_num_threads(omp_thread);
}
}
// init faiss global variable
int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold;
int64_t clustering_type = config.engine.clustering_type();
switch (clustering_type) {
case ClusteringType::K_MEANS:
default:
faiss::clustering_type = faiss::ClusteringType::K_MEANS;
break;
case ClusteringType::K_MEANS_PLUS_PLUS:
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
break;
}
#ifdef MILVUS_GPU_VERSION #ifdef MILVUS_GPU_VERSION
bool enable_gpu = config.gpu.enable(); bool enable_gpu = config.gpu.enable();
fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false); fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
......
...@@ -16,14 +16,10 @@ ...@@ -16,14 +16,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <faiss/Clustering.h>
#include <faiss/utils/distances.h>
#include "config/ServerConfig.h" #include "config/ServerConfig.h"
#include "db/DBFactory.h" #include "db/DBFactory.h"
#include "db/snapshot/OperationExecutor.h" #include "db/snapshot/OperationExecutor.h"
#include "utils/CommonUtil.h" #include "utils/CommonUtil.h"
#include "utils/ConfigUtils.h"
#include "utils/Log.h" #include "utils/Log.h"
#include "utils/StringHelpFunctions.h" #include "utils/StringHelpFunctions.h"
...@@ -66,35 +62,6 @@ DBWrapper::StartService() { ...@@ -66,35 +62,6 @@ DBWrapper::StartService() {
opt.transcript_enable_ = config.transcript.enable(); opt.transcript_enable_ = config.transcript.enable();
opt.replay_script_path_ = config.transcript.replay(); opt.replay_script_path_ = config.transcript.replay();
// engine config
int64_t omp_thread = config.engine.omp_thread_num();
if (omp_thread > 0) {
omp_set_num_threads(omp_thread);
LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread;
} else {
int64_t sys_thread_cnt = 8;
if (GetSystemAvailableThreads(sys_thread_cnt)) {
omp_thread = static_cast<int32_t>(ceil(sys_thread_cnt * 0.5));
omp_set_num_threads(omp_thread);
}
}
// init faiss global variable
int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold;
int64_t clustering_type = config.engine.clustering_type();
switch (clustering_type) {
case ClusteringType::K_MEANS:
default:
faiss::clustering_type = faiss::ClusteringType::K_MEANS;
break;
case ClusteringType::K_MEANS_PLUS_PLUS:
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
break;
}
// create db root folder // create db root folder
s = CommonUtil::CreateDirectory(opt.meta_.path_); s = CommonUtil::CreateDirectory(opt.meta_.path_);
if (!s.ok()) { if (!s.ok()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册