diff --git a/core/src/db/DBImpl.cpp b/core/src/db/DBImpl.cpp index 4ca8c7c7f866ce3f4b3f221056f9d2ff46bb5971..2afe537c354c6a184ceab31f35889641b1ffbec3 100644 --- a/core/src/db/DBImpl.cpp +++ b/core/src/db/DBImpl.cpp @@ -573,7 +573,52 @@ DBImpl::Query(const server::ContextPtr& context, const query::QueryPtr& query_pt TimeRecorder rc("DBImpl::Query"); - scheduler::SearchJobPtr job = std::make_shared(nullptr, options_, query_ptr); + snapshot::ScopedSnapshotT ss; + STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, query_ptr->collection_id)); + auto ss_id = ss->GetID(); + + /* collect all valid segment */ + std::vector segment_visitors; + auto exec = [&](const snapshot::Segment::Ptr& segment, snapshot::SegmentIterator* handler) -> Status { + auto p_id = segment->GetPartitionId(); + auto p_ptr = ss->GetResource(p_id); + auto& p_name = p_ptr->GetName(); + + /* check partition match pattern */ + bool match = false; + if (query_ptr->partitions.empty()) { + match = true; + } else { + for (auto& pattern : query_ptr->partitions) { + if (StringHelpFunctions::IsRegexMatch(p_name, pattern)) { + match = true; + break; + } + } + } + + if (match) { + auto visitor = SegmentVisitor::Build(ss, segment->GetID()); + if (!visitor) { + return Status(milvus::SS_ERROR, "Cannot build segment visitor"); + } + segment_visitors.push_back(visitor); + } + return Status::OK(); + }; + + auto segment_iter = std::make_shared(ss, exec); + segment_iter->Iterate(); + STATUS_CHECK(segment_iter->GetStatus()); + + LOG_ENGINE_DEBUG_ << LogOut("Engine query begin, segment count: %ld", segment_visitors.size()); + + engine::snapshot::IDS_TYPE segment_ids; + for (auto& sv : segment_visitors) { + segment_ids.emplace_back(sv->GetSegment()->GetID()); + } + + scheduler::SearchJobPtr job = std::make_shared(nullptr, ss, options_, query_ptr, segment_ids); /* put search job to scheduler and wait job finish */ scheduler::JobMgrInst::GetInstance()->Put(job); @@ -583,61 +628,8 @@ DBImpl::Query(const server::ContextPtr& context, const query::QueryPtr& query_pt return job->status(); } - // snapshot::ScopedSnapshotT ss; - // STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); - // - // /* collect all valid segment */ - // std::vector segment_visitors; - // auto exec = [&] (const snapshot::Segment::Ptr& segment, snapshot::SegmentIterator* handler) -> Status { - // auto p_id = segment->GetPartitionId(); - // auto p_ptr = ss->GetResource(p_id); - // auto& p_name = p_ptr->GetName(); - // - // /* check partition match pattern */ - // bool match = false; - // if (partition_patterns.empty()) { - // match = true; - // } else { - // for (auto &pattern : partition_patterns) { - // if (StringHelpFunctions::IsRegexMatch(p_name, pattern)) { - // match = true; - // break; - // } - // } - // } - // - // if (match) { - // auto visitor = SegmentVisitor::Build(ss, segment->GetID()); - // if (!visitor) { - // return Status(milvus::SS_ERROR, "Cannot build segment visitor"); - // } - // segment_visitors.push_back(visitor); - // } - // return Status::OK(); - // }; - // - // auto segment_iter = std::make_shared(ss, exec); - // segment_iter->Iterate(); - // STATUS_CHECK(segment_iter->GetStatus()); - // - // LOG_ENGINE_DEBUG_ << LogOut("Engine query begin, segment count: %ld", segment_visitors.size()); - // - // VectorsData vectors; - // scheduler::SearchJobPtr job = - // std::make_shared(tracer.Context(), general_query, query_ptr, attr_type, vectors); - // for (auto& sv : segment_visitors) { - // job->AddSegmentVisitor(sv); - // } - // - // // step 2: put search job to scheduler and wait result - // scheduler::JobMgrInst::GetInstance()->Put(job); - // job->WaitResult(); - // - // if (!job->GetStatus().ok()) { - // return job->GetStatus(); - // } - // - // // step 3: construct results + result = job->query_result(); + // step 3: construct results // result.row_num_ = job->vector_count(); // result.result_ids_ = job->GetResultIds(); // result.result_distances_ = job->GetResultDistances(); diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index 3c6c7401b3fbcdbc7d1a723fb70407d6ec907bfb..3e3e20ec30abacde73bad5e31b1402f2849d7c18 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -191,6 +191,84 @@ ExecutionEngineImpl::CopyToGpu(uint64_t device_id) { return Status::OK(); } +void +MapAndCopyResult(const knowhere::DatasetPtr& dataset, const std::vector& uids, int64_t nq, + int64_t k, float* distances, int64_t* labels) { + int64_t* res_ids = dataset->Get(knowhere::meta::IDS); + float* res_dist = dataset->Get(knowhere::meta::DISTANCE); + + memcpy(distances, res_dist, sizeof(float) * nq * k); + + /* map offsets to ids */ + int64_t num = nq * k; + for (int64_t i = 0; i < num; ++i) { + int64_t offset = res_ids[i]; + if (offset != -1) { + labels[i] = uids[offset]; + } else { + labels[i] = -1; + } + } + + free(res_ids); + free(res_dist); +} + +Status +ExecutionEngineImpl::VecSearch(milvus::engine::ExecutionEngineContext& context, + const query::VectorQueryPtr& vector_param, knowhere::VecIndexPtr& vec_index, + bool hybrid) { + TimeRecorder rc(LogOut("[%s][%ld] ExecutionEngineImpl::Search", "search", 0)); + + if (vec_index == nullptr) { + LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ExecutionEngineImpl: index is null, failed to search", "search", 0); + return Status(DB_ERROR, "index is null"); + } + + uint64_t nq = 0; + auto query_vector = vector_param->query_vector; + if (!query_vector.float_data.empty()) { + nq = vector_param->query_vector.float_data.size() / vec_index->Dim(); + } else if (!query_vector.binary_data.empty()) { + nq = vector_param->query_vector.binary_data.size() * 8 / vec_index->Dim(); + } + uint64_t topk = vector_param->topk; + + context.query_result_ = std::make_shared(); + context.query_result_->result_ids_.resize(topk * nq); + context.query_result_->result_distances_.resize(topk * nq); + + milvus::json conf = vector_param->extra_params; + conf[knowhere::meta::TOPK] = topk; + auto adapter = knowhere::AdapterMgr::GetInstance().GetAdapter(vec_index->index_type()); + if (!adapter->CheckSearch(conf, vec_index->index_type(), vec_index->index_mode())) { + LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] Illegal search params", "search", 0); + throw Exception(DB_ERROR, "Illegal search params"); + } + + if (hybrid) { + // HybridLoad(); + } + + rc.RecordSection("query prepare"); + knowhere::DatasetPtr dataset; + if (!query_vector.float_data.empty()) { + dataset = knowhere::GenDataset(nq, vec_index->Dim(), query_vector.float_data.data()); + } else { + dataset = knowhere::GenDataset(nq, vec_index->Dim(), query_vector.binary_data.data()); + } + auto result = vec_index->Query(dataset, conf); + + MapAndCopyResult(result, vec_index->GetUids(), nq, topk, context.query_result_->result_distances_.data(), + context.query_result_->result_ids_.data()); + + if (hybrid) { + // HybridUnset(); + } + + return Status::OK(); +} + Status ExecutionEngineImpl::Search(ExecutionEngineContext& context) { try { @@ -212,7 +290,6 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) { if (field->GetFtype() == (int)engine::meta::DataType::VECTOR_FLOAT || field->GetFtype() == (int)engine::meta::DataType::VECTOR_BINARY) { segment_ptr->GetVectorIndex(field->GetName(), vec_index); - break; } else if (type == (int)engine::meta::DataType::UID) { continue; } else { @@ -236,8 +313,14 @@ ExecutionEngineImpl::Search(ExecutionEngineContext& context) { } vec_index->SetBlacklist(list); - auto vector_query = context.query_ptr_->vectors.at(vector_placeholder); + auto& vector_param = context.query_ptr_->vectors.at(vector_placeholder); + if (!vector_param->query_vector.float_data.empty()) { + vector_param->nq = vector_param->query_vector.float_data.size() / vec_index->Dim(); + } else if (!vector_param->query_vector.binary_data.empty()) { + vector_param->nq = vector_param->query_vector.binary_data.size() * 8 / vec_index->Dim(); + } + status = VecSearch(context, context.query_ptr_->vectors.at(vector_placeholder), vec_index); if (!status.ok()) { return status; } diff --git a/core/src/db/engine/ExecutionEngineImpl.h b/core/src/db/engine/ExecutionEngineImpl.h index fc02fe4efce58c9a563c7ff494c7cdcb85be0435..6f7c6491b590bb86bb9d3541a5dff29d1115b883 100644 --- a/core/src/db/engine/ExecutionEngineImpl.h +++ b/core/src/db/engine/ExecutionEngineImpl.h @@ -41,6 +41,10 @@ class ExecutionEngineImpl : public ExecutionEngine { BuildIndex() override; private: + Status + VecSearch(ExecutionEngineContext& context, const query::VectorQueryPtr& vector_param, + knowhere::VecIndexPtr& vec_index, bool hybrid = false); + knowhere::VecIndexPtr CreateVecIndex(const std::string& index_name); diff --git a/core/src/query/GeneralQuery.h b/core/src/query/GeneralQuery.h index 356a3ff500e2167771036d21750b7ddf01d9d8cc..e1634dd90a7c2456a57fddd8beb3943cbdec02c8 100644 --- a/core/src/query/GeneralQuery.h +++ b/core/src/query/GeneralQuery.h @@ -77,6 +77,7 @@ struct VectorQuery { std::string field_name; milvus::json extra_params = {}; int64_t topk; + int64_t nq; float boost; VectorRecord query_vector; }; diff --git a/core/src/scheduler/job/SearchJob.cpp b/core/src/scheduler/job/SearchJob.cpp index f58ff74aba606ce52cd6d53b5ed6a956167a9a76..8ad01fa7adff3dd5d76a9c7db32e95882aabcf69 100644 --- a/core/src/scheduler/job/SearchJob.cpp +++ b/core/src/scheduler/job/SearchJob.cpp @@ -16,15 +16,21 @@ namespace milvus { namespace scheduler { -SearchJob::SearchJob(const server::ContextPtr& context, engine::DBOptions options, const query::QueryPtr& query_ptr) - : Job(JobType::SEARCH), context_(context), options_(options), query_ptr_(query_ptr) { - GetSegmentsFromQuery(query_ptr, segment_ids_); +SearchJob::SearchJob(const server::ContextPtr& context, const engine::snapshot::ScopedSnapshotT& snapshot, + engine::DBOptions options, const query::QueryPtr& query_ptr, + const engine::snapshot::IDS_TYPE& segment_ids) + : Job(JobType::SEARCH), + context_(context), + snapshot_(snapshot), + options_(options), + query_ptr_(query_ptr), + segment_ids_(segment_ids) { } void SearchJob::OnCreateTasks(JobTasks& tasks) { for (auto& id : segment_ids_) { - auto task = std::make_shared(context_, options_, query_ptr_, id, nullptr); + auto task = std::make_shared(context_, snapshot_, options_, query_ptr_, id, nullptr); task->job_ = this; tasks.emplace_back(task); } @@ -40,10 +46,5 @@ SearchJob::Dump() const { return ret; } -void -SearchJob::GetSegmentsFromQuery(const query::QueryPtr& query_ptr, engine::snapshot::IDS_TYPE& segment_ids) { - // TODO -} - } // namespace scheduler } // namespace milvus diff --git a/core/src/scheduler/job/SearchJob.h b/core/src/scheduler/job/SearchJob.h index 05d075751bfbe2f37df749294f1c5491fc9360f5..4658ddb70d94c5860bceb4bdf1957995c472e6e1 100644 --- a/core/src/scheduler/job/SearchJob.h +++ b/core/src/scheduler/job/SearchJob.h @@ -39,7 +39,9 @@ namespace scheduler { class SearchJob : public Job { public: - SearchJob(const server::ContextPtr& context, engine::DBOptions options, const query::QueryPtr& query_ptr); + SearchJob(const server::ContextPtr& context, const engine::snapshot::ScopedSnapshotT& snapshot, + engine::DBOptions options, const query::QueryPtr& query_ptr, + const engine::snapshot::IDS_TYPE& segment_ids); public: json @@ -74,13 +76,9 @@ class SearchJob : public Job { void OnCreateTasks(JobTasks& tasks) override; - private: - void - GetSegmentsFromQuery(const query::QueryPtr& query_ptr, engine::snapshot::IDS_TYPE& segment_ids); - private: const server::ContextPtr context_; - + engine::snapshot::ScopedSnapshotT snapshot_; engine::DBOptions options_; query::QueryPtr query_ptr_; diff --git a/core/src/scheduler/task/SearchTask.cpp b/core/src/scheduler/task/SearchTask.cpp index 7407a6aa452117f772bfcad03c45eb5d632eb890..0ea75ccaa819da1b72177c6258c83322d419ca14 100644 --- a/core/src/scheduler/task/SearchTask.cpp +++ b/core/src/scheduler/task/SearchTask.cpp @@ -29,10 +29,12 @@ namespace milvus { namespace scheduler { -SearchTask::SearchTask(const server::ContextPtr& context, const engine::DBOptions& options, - const query::QueryPtr& query_ptr, engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label) +SearchTask::SearchTask(const server::ContextPtr& context, engine::snapshot::ScopedSnapshotT snapshot, + const engine::DBOptions& options, const query::QueryPtr& query_ptr, + engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label) : Task(TaskType::SearchTask, std::move(label)), context_(context), + snapshot_(snapshot), options_(options), query_ptr_(query_ptr), segment_id_(segment_id) { @@ -42,9 +44,7 @@ SearchTask::SearchTask(const server::ContextPtr& context, const engine::DBOption void SearchTask::CreateExecEngine() { if (execution_engine_ == nullptr && query_ptr_ != nullptr) { - engine::snapshot::ScopedSnapshotT latest_ss; - engine::snapshot::Snapshots::GetInstance().GetSnapshot(latest_ss, query_ptr_->collection_id); - execution_engine_ = engine::EngineFactory::Build(latest_ss, options_.meta_.path_, segment_id_); + execution_engine_ = engine::EngineFactory::Build(snapshot_, options_.meta_.path_, segment_id_); } } @@ -106,29 +106,37 @@ SearchTask::OnExecute() { return Status(DB_ERROR, "execution engine is null"); } + // auto search_job = std::static_pointer_cast(std::shared_ptr(job_)); + auto search_job = static_cast(job_); try { /* step 2: search */ engine::ExecutionEngineContext context; context.query_ptr_ = query_ptr_; context.query_result_ = std::make_shared(); - auto status = execution_engine_->Search(context); - - if (!status.ok()) { - return status; - } + STATUS_CHECK(execution_engine_->Search(context)); rc.RecordSection("search done"); /* step 3: pick up topk result */ - // auto spec_k = file_->row_count_ < topk ? file_->row_count_ : topk; - // if (spec_k == 0) { - // LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty file. file location = %s", - // "search", 0, - // file_->location_.c_str()); - // } else { - // std::unique_lock lock(search_job->mutex()); - // XSearchTask::MergeTopkToResultSet(result, spec_k, nq, topk, ascending_, search_job->GetQueryResult()); - // } + // TODO(yukun): Remove hardcode here + auto vector_param = context.query_ptr_->vectors.begin()->second; + auto topk = vector_param->topk; + auto segment_ptr = snapshot_->GetSegmentCommitBySegmentId(segment_id_); + auto spec_k = segment_ptr->GetRowCount() < topk ? segment_ptr->GetRowCount() : topk; + int64_t nq = vector_param->nq; + if (spec_k == 0) { + LOG_ENGINE_WARNING_ << LogOut("[%s][%ld] Searching in an empty segment. segment id = %d", "search", 0, + segment_ptr->GetID()); + } else { + // std::unique_lock lock(search_job->mutex()); + if (!search_job->query_result()) { + search_job->query_result() = std::make_shared(); + search_job->query_result()->row_num_ = nq; + } + SearchTask::MergeTopkToResultSet(context.query_result_->result_ids_, + context.query_result_->result_distances_, spec_k, nq, topk, + ascending_reduce_, search_job->query_result()); + } rc.RecordSection("reduce topk done"); } catch (std::exception& ex) { @@ -140,6 +148,72 @@ SearchTask::OnExecute() { return Status::OK(); } +void +SearchTask::MergeTopkToResultSet(const engine::ResultIds& src_ids, const engine::ResultDistances& src_distances, + size_t src_k, size_t nq, size_t topk, bool ascending, engine::QueryResultPtr& result) { + if (src_ids.empty()) { + LOG_ENGINE_DEBUG_ << LogOut("[%s][%d] Search result is empty.", "search", 0); + return; + } + + size_t tar_k = result->result_ids_.size() / nq; + size_t buf_k = std::min(topk, src_k + tar_k); + + engine::ResultIds buf_ids(nq * buf_k, -1); + engine::ResultDistances buf_distances(nq * buf_k, 0.0); + for (uint64_t i = 0; i < nq; i++) { + size_t buf_k_j = 0, src_k_j = 0, tar_k_j = 0; + size_t buf_idx, src_idx, tar_idx; + + size_t buf_k_multi_i = buf_k * i; + size_t src_k_multi_i = topk * i; + size_t tar_k_multi_i = tar_k * i; + + while (buf_k_j < buf_k && src_k_j < src_k && tar_k_j < tar_k) { + src_idx = src_k_multi_i + src_k_j; + tar_idx = tar_k_multi_i + tar_k_j; + buf_idx = buf_k_multi_i + buf_k_j; + + if ((result->result_ids_[tar_idx] == -1) || // initialized value + (ascending && src_distances[src_idx] < result->result_distances_[tar_idx]) || + (!ascending && src_distances[src_idx] > result->result_distances_[tar_idx])) { + buf_ids[buf_idx] = src_ids[src_idx]; + buf_distances[buf_idx] = src_distances[src_idx]; + src_k_j++; + } else { + buf_ids[buf_idx] = result->result_ids_[tar_idx]; + buf_distances[buf_idx] = result->result_distances_[tar_idx]; + tar_k_j++; + } + buf_k_j++; + } + + if (buf_k_j < buf_k) { + if (src_k_j < src_k) { + while (buf_k_j < buf_k && src_k_j < src_k) { + buf_idx = buf_k_multi_i + buf_k_j; + src_idx = src_k_multi_i + src_k_j; + buf_ids[buf_idx] = src_ids[src_idx]; + buf_distances[buf_idx] = src_distances[src_idx]; + src_k_j++; + buf_k_j++; + } + } else { + while (buf_k_j < buf_k && tar_k_j < tar_k) { + buf_idx = buf_k_multi_i + buf_k_j; + tar_idx = tar_k_multi_i + tar_k_j; + buf_ids[buf_idx] = result->result_ids_[tar_idx]; + buf_distances[buf_idx] = result->result_distances_[tar_idx]; + tar_k_j++; + buf_k_j++; + } + } + } + } + result->result_ids_.swap(buf_ids); + result->result_distances_.swap(buf_distances); +} + int64_t SearchTask::nq() { return 0; diff --git a/core/src/scheduler/task/SearchTask.h b/core/src/scheduler/task/SearchTask.h index 4a43e0c875d6cfa964bb32fd0433e64f19c95a5b..b73b697bcc0d517a7142350a78051e9368179ceb 100644 --- a/core/src/scheduler/task/SearchTask.h +++ b/core/src/scheduler/task/SearchTask.h @@ -26,8 +26,9 @@ namespace scheduler { class SearchTask : public Task { public: - explicit SearchTask(const server::ContextPtr& context, const engine::DBOptions& options, - const query::QueryPtr& query_ptr, engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label); + explicit SearchTask(const server::ContextPtr& context, engine::snapshot::ScopedSnapshotT snapshot, + const engine::DBOptions& options, const query::QueryPtr& query_ptr, + engine::snapshot::ID_TYPE segment_id, TaskLabelPtr label); inline json Dump() const override { @@ -44,6 +45,10 @@ class SearchTask : public Task { Status OnExecute() override; + static void + MergeTopkToResultSet(const engine::ResultIds& src_ids, const engine::ResultDistances& src_distances, size_t src_k, + size_t nq, size_t topk, bool ascending, engine::QueryResultPtr& result); + int64_t nq(); @@ -53,12 +58,17 @@ class SearchTask : public Task { public: const std::shared_ptr context_; + engine::snapshot::ScopedSnapshotT snapshot_; const engine::DBOptions& options_; query::QueryPtr query_ptr_; engine::snapshot::ID_TYPE segment_id_; engine::ExecutionEnginePtr execution_engine_; + + // distance -- value 0 means two vectors equal, ascending reduce, L2/HAMMING/JACCARD/TONIMOTO ... + // similarity -- infinity value means two vectors equal, descending reduce, IP + bool ascending_reduce_ = true; }; } // namespace scheduler diff --git a/core/src/server/grpc_impl/GrpcRequestHandler.cpp b/core/src/server/grpc_impl/GrpcRequestHandler.cpp index 52c7b79994a93facc74b54587c96bed0ae52914a..8d58fb47f61bb96e0095d618ff2599c8c2cb527d 100644 --- a/core/src/server/grpc_impl/GrpcRequestHandler.cpp +++ b/core/src/server/grpc_impl/GrpcRequestHandler.cpp @@ -643,7 +643,9 @@ GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::mil // Currently only one extra_param if (field.extra_params_size() != 0) { - field_schema.field_params_ = json::parse(field.extra_params(0).value()); + if (!field.extra_params(0).value().empty()) { + field_schema.field_params_ = json::parse(field.extra_params(0).value()); + } } for (int j = 0; j < field.index_params_size(); j++) { diff --git a/sdk/examples/simple/src/ClientTest.cpp b/sdk/examples/simple/src/ClientTest.cpp index 9c35c2d5c850a22c7fcd265cfea4e6fb8959b30d..dd6d128b08927cb63e30e5a6f2cb45d074e02bec 100644 --- a/sdk/examples/simple/src/ClientTest.cpp +++ b/sdk/examples/simple/src/ClientTest.cpp @@ -114,7 +114,7 @@ ClientTest::CreateCollection(const std::string& collection_name) { field_ptr4->extra_params = extra_params_4.dump(); JSON extra_params; - extra_params["segment_size"] = 1024; + extra_params["segment_row_count"] = 1024; milvus::Mapping mapping = {collection_name, {field_ptr1, field_ptr2, field_ptr3, field_ptr4}}; milvus::Status stat = conn_->CreateCollection(mapping, extra_params.dump()); @@ -352,5 +352,5 @@ ClientTest::Test() { // entities // // DropIndex(collection_name, "field_vec", "index_3"); - DropCollection(collection_name); + // DropCollection(collection_name); }