SearchTask.cpp 6.4 KB
Newer Older
G
groot 已提交
1 2 3 4 5
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
G
groot 已提交
6
#include "SearchTask.h"

#include <utility>

#include "metrics/Metrics.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"

namespace zilliz {
J
jinhai 已提交
12
namespace milvus {
G
groot 已提交
13 14 15 16 17 18 19 20 21
namespace engine {

namespace {
void ClusterResult(const std::vector<long> &output_ids,
                   const std::vector<float> &output_distence,
                   uint64_t nq,
                   uint64_t topk,
                   SearchContext::ResultSet &result_set) {
    result_set.clear();
G
groot 已提交
22
    result_set.reserve(nq);
G
groot 已提交
23
    for (auto i = 0; i < nq; i++) {
G
groot 已提交
24
        SearchContext::Id2ScoreMap id_score;
G
groot 已提交
25
        id_score.reserve(topk);
G
groot 已提交
26
        for (auto k = 0; k < topk; k++) {
G
groot 已提交
27
            uint64_t index = i * topk + k;
G
groot 已提交
28 29 30
            if(output_ids[index] < 0) {
                continue;
            }
G
groot 已提交
31
            id_score.push_back(std::make_pair(output_ids[index], output_distence[index]));
G
groot 已提交
32
        }
G
groot 已提交
33
        result_set.emplace_back(id_score);
G
groot 已提交
34 35 36
    }
}

G
groot 已提交
37
void MergeResult(SearchContext::Id2ScoreMap &score_src,
G
groot 已提交
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
                 SearchContext::Id2ScoreMap &score_target,
                 uint64_t topk) {
    //Note: the score_src and score_target are already arranged by score in ascending order
    if(score_src.empty()) {
        return;
    }

    if(score_target.empty()) {
        score_target.swap(score_src);
        return;
    }

    size_t src_count = score_src.size();
    size_t target_count = score_target.size();
    SearchContext::Id2ScoreMap score_merged;
    score_merged.reserve(topk);
    size_t src_index = 0, target_index = 0;
    while(true) {
        //all score_src items are merged, if score_merged.size() still less than topk
        //move items from score_target to score_merged until score_merged.size() equal topk
G
groot 已提交
58
        if(src_index >= src_count) {
G
groot 已提交
59 60
            for(size_t i = target_index; i < target_count && score_merged.size() < topk; ++i) {
                score_merged.push_back(score_target[i]);
G
groot 已提交
61
            }
G
groot 已提交
62 63 64 65 66
            break;
        }

        //all score_target items are merged, if score_merged.size() still less than topk
        //move items from score_src to score_merged until score_merged.size() equal topk
G
groot 已提交
67
        if(target_index >= target_count) {
G
groot 已提交
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
            for(size_t i = src_index; i < src_count && score_merged.size() < topk; ++i) {
                score_merged.push_back(score_src[i]);
            }
            break;
        }

        //compare score, put smallest score to score_merged one by one
        auto& src_pair = score_src[src_index];
        auto& target_pair = score_target[target_index];
        if(src_pair.second > target_pair.second) {
            score_merged.push_back(target_pair);
            target_index++;
        } else {
            score_merged.push_back(src_pair);
            src_index++;
G
groot 已提交
83 84
        }

G
groot 已提交
85 86 87 88
        //score_merged.size() already equal topk
        if(score_merged.size() >= topk) {
            break;
        }
G
groot 已提交
89
    }
G
groot 已提交
90 91

    score_target.swap(score_merged);
G
groot 已提交
92 93
}

G
groot 已提交
94 95 96 97 98 99 100 101 102 103 104 105 106 107
// Fold one index's per-query results (result_src) into the accumulated
// per-query results (result_target).
//
// If result_target is still empty, result_src is taken over wholesale by
// swapping. Otherwise both sets must hold the same number of per-query lists,
// and each pair is combined with MergeResult using `topk`. On a size mismatch,
// an error is logged and result_target is left untouched.
void TopkResult(SearchContext::ResultSet &result_src,
                uint64_t topk,
                SearchContext::ResultSet &result_target) {
    const bool first_contribution = result_target.empty();
    if (first_contribution) {
        result_target.swap(result_src);
        return;
    }

    const bool shapes_match = (result_target.size() == result_src.size());
    if (!shapes_match) {
        SERVER_LOG_ERROR << "Invalid result set";
        return;
    }

    // Walk both sets in lockstep; the sizes were checked above.
    auto src_it = result_src.begin();
    for (SearchContext::Id2ScoreMap &target_scores : result_target) {
        SearchContext::Id2ScoreMap &src_scores = *src_it;
        ++src_it;
        MergeResult(src_scores, target_scores, topk);
    }
}
G
groot 已提交
113

G
groot 已提交
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
// Report a search duration to the histogram that matches the searched file type.
// RAW and TO_INDEX files are recorded as raw-data searches. Every other file
// type is recorded as an index search. total_time is forwarded unchanged.
void CollectDurationMetrics(int index_type, double total_time) {
    switch (index_type) {
        case meta::TableFileSchema::RAW:
        case meta::TableFileSchema::TO_INDEX:
            server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
            break;
        default:
            server::Metrics::GetInstance().SearchIndexDataDurationSecondsHistogramObserve(total_time);
            break;
    }
}

G
groot 已提交
131 132 133 134 135 136 137 138
}

// Construct a schedule task tagged as ScheduleTaskType::kSearch.
// No other state is set up here.
SearchTask::SearchTask()
    : IScheduleTask(ScheduleTaskType::kSearch) {}

std::shared_ptr<IScheduleTask> SearchTask::Execute() {
G
groot 已提交
139
    if(index_engine_ == nullptr) {
G
groot 已提交
140
        return nullptr;
G
groot 已提交
141 142
    }

G
groot 已提交
143 144 145
    SERVER_LOG_INFO << "Searching in index(" << index_id_<< ") with "
                    << search_contexts_.size() << " tasks";

G
groot 已提交
146
    server::TimeRecorder rc("DoSearch index(" + std::to_string(index_id_) + ")");
G
groot 已提交
147

G
groot 已提交
148 149
    auto start_time = METRICS_NOW_TIME;

G
groot 已提交
150 151 152
    std::vector<long> output_ids;
    std::vector<float> output_distence;
    for(auto& context : search_contexts_) {
G
groot 已提交
153
        //step 1: allocate memory
Y
yu yunfeng 已提交
154
        auto inner_k = context->topk();
G
groot 已提交
155 156
        output_ids.resize(inner_k*context->nq());
        output_distence.resize(inner_k*context->nq());
G
groot 已提交
157 158

        try {
G
groot 已提交
159
            //step 2: search
G
groot 已提交
160
            index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
G
groot 已提交
161
                                  output_ids.data());
G
groot 已提交
162 163 164 165 166

            rc.Record("do search");

            //step 3: cluster result
            SearchContext::ResultSet result_set;
Y
yu yunfeng 已提交
167 168
            auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
            ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
G
groot 已提交
169 170 171 172 173 174
            rc.Record("cluster result");

            //step 4: pick up topk result
            TopkResult(result_set, inner_k, context->GetResult());
            rc.Record("reduce topk");

G
groot 已提交
175 176
        } catch (std::exception& ex) {
            SERVER_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
G
groot 已提交
177 178
            context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
            continue;
G
groot 已提交
179 180
        }

G
groot 已提交
181
        //step 5: notify to send result to client
G
groot 已提交
182 183 184
        context->IndexSearchDone(index_id_);
    }

G
groot 已提交
185 186 187 188
    auto end_time = METRICS_NOW_TIME;
    auto total_time = METRICS_MICROSECONDS(start_time, end_time);
    CollectDurationMetrics(index_type_, total_time);

G
groot 已提交
189 190
    rc.Elapse("totally cost");

G
groot 已提交
191
    return nullptr;
G
groot 已提交
192 193 194 195 196
}

}
}
}