diff --git a/cpp/src/scheduler/task/SearchTask.cpp b/cpp/src/scheduler/task/SearchTask.cpp index 8d76075014984902992cddfa5663ecfdda966be4..59f6509c02bebf22a02d6f72ae2426d577740bba 100644 --- a/cpp/src/scheduler/task/SearchTask.cpp +++ b/cpp/src/scheduler/task/SearchTask.cpp @@ -163,32 +163,37 @@ XSearchTask::Execute() { std::vector output_distance; for (auto &context : search_contexts_) { //step 1: allocate memory - auto inner_k = context->topk(); - auto nprobe = context->nprobe(); - output_ids.resize(inner_k * context->nq()); - output_distance.resize(inner_k * context->nq()); + uint64_t nq = context->nq(); + uint64_t topk = context->topk(); + uint64_t nprobe = context->nprobe(); + const float* vectors = context->vectors(); + + output_ids.resize(topk * nq); + output_distance.resize(topk * nq); + std::string hdr = "context " + context->Identity() + + " nq " + std::to_string(nq) + + " topk " + std::to_string(topk); try { //step 2: search - index_engine_->Search(context->nq(), context->vectors(), inner_k, nprobe, output_distance.data(), - output_ids.data()); + index_engine_->Search(nq, vectors, topk, nprobe, output_distance.data(), output_ids.data()); - double span = rc.RecordSection("do search for context:" + context->Identity()); + double span = rc.RecordSection(hdr + ", do search"); context->AccumSearchCost(span); //step 3: cluster result SearchContext::ResultSet result_set; - auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk(); - XSearchTask::ClusterResult(output_ids, output_distance, context->nq(), spec_k, result_set); + auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk; + XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set); - span = rc.RecordSection("cluster result for context:" + context->Identity()); + span = rc.RecordSection(hdr + ", cluster result"); context->AccumReduceCost(span); // step 4: pick up topk result - XSearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult()); + XSearchTask::TopkResult(result_set, topk, metric_l2, context->GetResult()); - span = rc.RecordSection("reduce topk for context:" + context->Identity()); + span = rc.RecordSection(hdr + ", reduce topk"); context->AccumReduceCost(span); } catch (std::exception &ex) { ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what();