未验证 提交 9d5ddb28 编写于 作者: C cqy123456 提交者: GitHub

#3576 optimize code structure (#3581)

Signed-off-by: Ncqy <yaya645@126.com>
上级 bf5bdf25
......@@ -60,6 +60,9 @@ using VectorDistances = std::vector<VectorDistance>;
using ResultIds = std::vector<faiss::Index::idx_t>;
using ResultDistances = std::vector<faiss::Index::distance_t>;
using ConCurrentBitset = faiss::ConcurrentBitset;
using ConCurrentBitsetPtr = faiss::ConcurrentBitsetPtr;
///////////////////////////////////////////////////////////////////////////////////////////////////
enum class DataType {
NONE = 0,
......
......@@ -300,9 +300,9 @@ ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context,
Status
ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
try {
faiss::ConcurrentBitsetPtr bitset;
ConCurrentBitsetPtr bitset;
std::string vector_placeholder;
faiss::ConcurrentBitsetPtr list;
ConCurrentBitsetPtr list;
SegmentPtr segment_ptr;
segment_reader_->GetSegment(segment_ptr);
......@@ -360,13 +360,12 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) {
}
Status
ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& general_query,
faiss::ConcurrentBitsetPtr& bitset,
ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& general_query, ConCurrentBitsetPtr& bitset,
std::unordered_map<std::string, DataType>& attr_type,
std::string& vector_placeholder) {
Status status = Status::OK();
if (general_query->leaf == nullptr) {
faiss::ConcurrentBitsetPtr left_bitset, right_bitset;
ConCurrentBitsetPtr left_bitset, right_bitset;
if (general_query->bin->left_query != nullptr) {
status = ExecBinaryQuery(general_query->bin->left_query, left_bitset, attr_type, vector_placeholder);
if (!status.ok()) {
......@@ -412,16 +411,16 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener
return status;
} else {
if (general_query->leaf->term_query != nullptr) {
bitset = std::make_shared<faiss::ConcurrentBitset>(entity_count_);
bitset = std::make_shared<ConCurrentBitset>(entity_count_);
STATUS_CHECK(ProcessTermQuery(bitset, general_query->leaf->term_query, attr_type));
}
if (general_query->leaf->range_query != nullptr) {
bitset = std::make_shared<faiss::ConcurrentBitset>(entity_count_);
bitset = std::make_shared<ConCurrentBitset>(entity_count_);
STATUS_CHECK(ProcessRangeQuery(attr_type, bitset, general_query->leaf->range_query));
}
if (!general_query->leaf->vector_placeholder.empty()) {
// skip vector query
bitset = std::make_shared<faiss::ConcurrentBitset>(entity_count_, 255);
bitset = std::make_shared<ConCurrentBitset>(entity_count_, 255);
vector_placeholder = general_query->leaf->vector_placeholder;
}
}
......@@ -430,8 +429,7 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener
template <typename T>
Status
ProcessIndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr,
milvus::json& term_values_json) {
ProcessIndexedTermQuery(ConCurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, milvus::json& term_values_json) {
try {
auto T_index = std::dynamic_pointer_cast<knowhere::StructuredIndexSort<T>>(index_ptr);
if (not T_index) {
......@@ -453,7 +451,7 @@ ProcessIndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr&
}
Status
ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const std::string& field_name,
ExecutionEngineImpl::IndexedTermQuery(ConCurrentBitsetPtr& bitset, const std::string& field_name,
const DataType& data_type, milvus::json& term_values_json) {
SegmentPtr segment_ptr;
segment_reader_->GetSegment(segment_ptr);
......@@ -493,7 +491,7 @@ ExecutionEngineImpl::IndexedTermQuery(faiss::ConcurrentBitsetPtr& bitset, const
}
Status
ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const query::TermQueryPtr& term_query,
ExecutionEngineImpl::ProcessTermQuery(ConCurrentBitsetPtr& bitset, const query::TermQueryPtr& term_query,
std::unordered_map<std::string, DataType>& attr_type) {
try {
auto term_query_json = term_query->json_obj;
......@@ -520,8 +518,7 @@ ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const
template <typename T>
Status
ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr,
milvus::json& range_values_json) {
ProcessIndexedRangeQuery(ConCurrentBitsetPtr& bitset, knowhere::IndexPtr& index_ptr, milvus::json& range_values_json) {
try {
auto T_index = std::dynamic_pointer_cast<knowhere::StructuredIndexSort<T>>(index_ptr);
......@@ -543,7 +540,7 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr&
}
Status
ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const DataType& data_type,
ExecutionEngineImpl::IndexedRangeQuery(ConCurrentBitsetPtr& bitset, const DataType& data_type,
knowhere::IndexPtr& index_ptr, milvus::json& range_values_json) {
auto status = Status::OK();
switch (data_type) {
......@@ -579,7 +576,7 @@ ExecutionEngineImpl::IndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, const
Status
ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map<std::string, DataType>& attr_type,
faiss::ConcurrentBitsetPtr& bitset, const query::RangeQueryPtr& range_query) {
ConCurrentBitsetPtr& bitset, const query::RangeQueryPtr& range_query) {
SegmentPtr segment_ptr;
segment_reader_->GetSegment(segment_ptr);
try {
......@@ -809,7 +806,7 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col
LOG_ENGINE_DEBUG_ << "Index config: " << conf.dump();
std::vector<idx_t> uids;
faiss::ConcurrentBitsetPtr blacklist;
ConCurrentBitsetPtr blacklist;
knowhere::DatasetPtr dataset;
if (from_index) {
dataset =
......
......@@ -13,10 +13,13 @@
#ifdef MILVUS_GPU_VERSION
#include "knowhere/index/vector_index/helpers/FaissGpuResourceMgr.h"
#endif
#include <faiss/Clustering.h>
#include <faiss/utils/distances.h>
#include "config/ServerConfig.h"
#include "faiss/FaissHook.h"
#include "scheduler/Utils.h"
#include "utils/ConfigUtils.h"
#include "utils/Error.h"
#include "utils/Log.h"
......@@ -60,6 +63,35 @@ KnowhereResource::Initialize() {
return Status(KNOWHERE_UNEXPECTED_ERROR, "FAISS hook fail, CPU not supported!");
}
// engine config
int64_t omp_thread = config.engine.omp_thread_num();
if (omp_thread > 0) {
omp_set_num_threads(omp_thread);
LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread;
} else {
int64_t sys_thread_cnt = 8;
if (milvus::server::GetSystemAvailableThreads(sys_thread_cnt)) {
omp_thread = static_cast<int32_t>(ceil(sys_thread_cnt * 0.5));
omp_set_num_threads(omp_thread);
}
}
// init faiss global variable
int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold;
int64_t clustering_type = config.engine.clustering_type();
switch (clustering_type) {
case ClusteringType::K_MEANS:
default:
faiss::clustering_type = faiss::ClusteringType::K_MEANS;
break;
case ClusteringType::K_MEANS_PLUS_PLUS:
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
break;
}
#ifdef MILVUS_GPU_VERSION
bool enable_gpu = config.gpu.enable();
fiu_do_on("KnowhereResource.Initialize.disable_gpu", enable_gpu = false);
......
......@@ -16,14 +16,10 @@
#include <string>
#include <vector>
#include <faiss/Clustering.h>
#include <faiss/utils/distances.h>
#include "config/ServerConfig.h"
#include "db/DBFactory.h"
#include "db/snapshot/OperationExecutor.h"
#include "utils/CommonUtil.h"
#include "utils/ConfigUtils.h"
#include "utils/Log.h"
#include "utils/StringHelpFunctions.h"
......@@ -66,35 +62,6 @@ DBWrapper::StartService() {
opt.transcript_enable_ = config.transcript.enable();
opt.replay_script_path_ = config.transcript.replay();
// engine config
int64_t omp_thread = config.engine.omp_thread_num();
if (omp_thread > 0) {
omp_set_num_threads(omp_thread);
LOG_SERVER_DEBUG_ << "Specify openmp thread number: " << omp_thread;
} else {
int64_t sys_thread_cnt = 8;
if (GetSystemAvailableThreads(sys_thread_cnt)) {
omp_thread = static_cast<int32_t>(ceil(sys_thread_cnt * 0.5));
omp_set_num_threads(omp_thread);
}
}
// init faiss global variable
int64_t use_blas_threshold = config.engine.use_blas_threshold();
faiss::distance_compute_blas_threshold = use_blas_threshold;
int64_t clustering_type = config.engine.clustering_type();
switch (clustering_type) {
case ClusteringType::K_MEANS:
default:
faiss::clustering_type = faiss::ClusteringType::K_MEANS;
break;
case ClusteringType::K_MEANS_PLUS_PLUS:
faiss::clustering_type = faiss::ClusteringType::K_MEANS_PLUS_PLUS;
break;
}
// create db root folder
s = CommonUtil::CreateDirectory(opt.meta_.path_);
if (!s.ok()) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册