SearchTaskQueue.cpp 5.3 KB
Newer Older
G
groot 已提交
1 2 3 4 5 6 7 8 9 10
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
#include "SearchTaskQueue.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"

namespace zilliz {
J
jinhai 已提交
11
namespace milvus {
G
groot 已提交
12 13 14 15 16 17 18 19 20
namespace engine {

namespace {
void ClusterResult(const std::vector<long> &output_ids,
                   const std::vector<float> &output_distence,
                   uint64_t nq,
                   uint64_t topk,
                   SearchContext::ResultSet &result_set) {
    result_set.clear();
G
groot 已提交
21
    result_set.reserve(nq);
G
groot 已提交
22
    for (auto i = 0; i < nq; i++) {
G
groot 已提交
23
        SearchContext::Id2ScoreMap id_score;
G
groot 已提交
24
        id_score.reserve(topk);
G
groot 已提交
25
        for (auto k = 0; k < topk; k++) {
G
groot 已提交
26
            uint64_t index = i * topk + k;
G
groot 已提交
27 28 29
            if(output_ids[index] < 0) {
                continue;
            }
G
groot 已提交
30
            id_score.push_back(std::make_pair(output_ids[index], output_distence[index]));
G
groot 已提交
31
        }
G
groot 已提交
32
        result_set.emplace_back(id_score);
G
groot 已提交
33 34 35
    }
}

G
groot 已提交
36
void MergeResult(SearchContext::Id2ScoreMap &score_src,
G
groot 已提交
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
                 SearchContext::Id2ScoreMap &score_target,
                 uint64_t topk) {
    //Note: the score_src and score_target are already arranged by score in ascending order
    if(score_src.empty()) {
        return;
    }

    if(score_target.empty()) {
        score_target.swap(score_src);
        return;
    }

    size_t src_count = score_src.size();
    size_t target_count = score_target.size();
    SearchContext::Id2ScoreMap score_merged;
    score_merged.reserve(topk);
    size_t src_index = 0, target_index = 0;
    while(true) {
        //all score_src items are merged, if score_merged.size() still less than topk
        //move items from score_target to score_merged until score_merged.size() equal topk
        if(src_index >= src_count - 1) {
            for(size_t i = target_index; i < target_count && score_merged.size() < topk; ++i) {
                score_merged.push_back(score_target[i]);
G
groot 已提交
60
            }
G
groot 已提交
61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81
            break;
        }

        //all score_target items are merged, if score_merged.size() still less than topk
        //move items from score_src to score_merged until score_merged.size() equal topk
        if(target_index >= target_count - 1) {
            for(size_t i = src_index; i < src_count && score_merged.size() < topk; ++i) {
                score_merged.push_back(score_src[i]);
            }
            break;
        }

        //compare score, put smallest score to score_merged one by one
        auto& src_pair = score_src[src_index];
        auto& target_pair = score_target[target_index];
        if(src_pair.second > target_pair.second) {
            score_merged.push_back(target_pair);
            target_index++;
        } else {
            score_merged.push_back(src_pair);
            src_index++;
G
groot 已提交
82 83
        }

G
groot 已提交
84 85 86 87
        //score_merged.size() already equal topk
        if(score_merged.size() >= topk) {
            break;
        }
G
groot 已提交
88
    }
G
groot 已提交
89 90

    score_target.swap(score_merged);
G
groot 已提交
91 92
}

G
groot 已提交
93 94 95 96 97 98 99 100 101 102 103 104 105 106
// Fold one index's per-query results (result_src) into the accumulated results
// (result_target), query by query.
//   - result_target empty: the two result sets are swapped.
//   - query counts differ: an error is logged and result_target is left as is.
//   - otherwise:           each query's list is merged via MergeResult, keeping
//                          at most topk entries per query.
void TopkResult(SearchContext::ResultSet &result_src,
                uint64_t topk,
                SearchContext::ResultSet &result_target) {
    if (result_target.empty()) {
        // First contribution for this context: adopt it wholesale.
        result_target.swap(result_src);
    } else if (result_src.size() == result_target.size()) {
        const size_t query_count = result_src.size();
        for (size_t q = 0; q < query_count; ++q) {
            MergeResult(result_src[q], result_target[q], topk);
        }
    } else {
        SERVER_LOG_ERROR << "Invalid result set";
    }
}
G
groot 已提交
112

G
groot 已提交
113 114
}

G
groot 已提交
115
bool SearchTask::DoSearch() {
G
groot 已提交
116 117 118 119
    if(index_engine_ == nullptr) {
        return false;
    }

G
groot 已提交
120
    server::TimeRecorder rc("DoSearch index(" + std::to_string(index_id_) + ")");
G
groot 已提交
121 122 123 124

    std::vector<long> output_ids;
    std::vector<float> output_distence;
    for(auto& context : search_contexts_) {
G
groot 已提交
125
        //step 1: allocate memory
G
groot 已提交
126 127 128
        auto inner_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
        output_ids.resize(inner_k*context->nq());
        output_distence.resize(inner_k*context->nq());
G
groot 已提交
129 130

        try {
G
groot 已提交
131
            //step 2: search
G
groot 已提交
132
            index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
G
groot 已提交
133
                                  output_ids.data());
G
groot 已提交
134 135 136 137 138 139 140 141 142 143 144 145

            rc.Record("do search");

            //step 3: cluster result
            SearchContext::ResultSet result_set;
            ClusterResult(output_ids, output_distence, context->nq(), inner_k, result_set);
            rc.Record("cluster result");

            //step 4: pick up topk result
            TopkResult(result_set, inner_k, context->GetResult());
            rc.Record("reduce topk");

G
groot 已提交
146 147
        } catch (std::exception& ex) {
            SERVER_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
G
groot 已提交
148 149
            context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
            continue;
G
groot 已提交
150 151
        }

G
groot 已提交
152
        //step 5: notify to send result to client
G
groot 已提交
153 154 155 156 157 158 159 160 161 162 163
        context->IndexSearchDone(index_id_);
    }

    rc.Elapse("totally cost");

    return true;
}

}
}
}