SearchTask.cpp 8.1 KB
Newer Older
G
groot 已提交
1 2 3 4 5
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
G
groot 已提交
6
#include "SearchTask.h"
#include "metrics/Metrics.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"

#include <algorithm>
#include <utility>

namespace zilliz {
J
jinhai 已提交
12
namespace milvus {
G
groot 已提交
13 14 15
namespace engine {

namespace {
G
groot 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
// Forwards the elapsed search time to the metrics histogram that matches the
// kind of file that was searched.
//   index_type: file type, compared against meta::TableFileSchema constants.
//   total_time: elapsed time, passed through unchanged to the observer.
// RAW and TO_INDEX files are both reported to the raw-data histogram; every
// other type is reported to the index-data histogram.
void CollectDurationMetrics(int index_type, double total_time) {
    const bool searched_raw_data = (index_type == meta::TableFileSchema::RAW)
                                || (index_type == meta::TableFileSchema::TO_INDEX);

    if (searched_raw_data) {
        server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
    } else {
        server::Metrics::GetInstance().SearchIndexDataDurationSecondsHistogramObserve(total_time);
    }
}

33 34 35 36 37 38
std::string GetMetricType() {
    server::ServerConfig &config = server::ServerConfig::GetInstance();
    server::ConfigNode engine_config = config.GetConfig(server::CONFIG_ENGINE);
    return engine_config.GetValue(server::CONFIG_METRICTYPE, "L2");
}

G
groot 已提交
39 40 41 42
}

// Creates a search task and clears the metric_l2 flag when the configured
// metric type is anything other than "L2". The flag is left untouched when the
// metric type is "L2".
SearchTask::SearchTask()
    : IScheduleTask(ScheduleTaskType::kSearch) {
    const bool configured_for_l2 = (GetMetricType() == "L2");
    if (!configured_for_l2) {
        metric_l2 = false;
    }
}

std::shared_ptr<IScheduleTask> SearchTask::Execute() {
G
groot 已提交
50
    if(index_engine_ == nullptr) {
G
groot 已提交
51
        return nullptr;
G
groot 已提交
52 53
    }

G
groot 已提交
54 55 56
    SERVER_LOG_INFO << "Searching in index(" << index_id_<< ") with "
                    << search_contexts_.size() << " tasks";

G
groot 已提交
57
    server::TimeRecorder rc("DoSearch index(" + std::to_string(index_id_) + ")");
G
groot 已提交
58

G
groot 已提交
59 60
    auto start_time = METRICS_NOW_TIME;

G
groot 已提交
61 62 63
    std::vector<long> output_ids;
    std::vector<float> output_distence;
    for(auto& context : search_contexts_) {
G
groot 已提交
64
        //step 1: allocate memory
Y
yu yunfeng 已提交
65
        auto inner_k = context->topk();
G
groot 已提交
66 67
        output_ids.resize(inner_k*context->nq());
        output_distence.resize(inner_k*context->nq());
G
groot 已提交
68 69

        try {
G
groot 已提交
70
            //step 2: search
G
groot 已提交
71
            index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
G
groot 已提交
72
                                  output_ids.data());
G
groot 已提交
73 74 75 76 77

            rc.Record("do search");

            //step 3: cluster result
            SearchContext::ResultSet result_set;
Y
yu yunfeng 已提交
78
            auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
79
            SearchTask::ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
G
groot 已提交
80 81 82
            rc.Record("cluster result");

            //step 4: pick up topk result
83
            SearchTask::TopkResult(result_set, inner_k, metric_l2, context->GetResult());
G
groot 已提交
84 85
            rc.Record("reduce topk");

G
groot 已提交
86 87
        } catch (std::exception& ex) {
            SERVER_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
G
groot 已提交
88 89
            context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
            continue;
G
groot 已提交
90 91
        }

G
groot 已提交
92
        //step 5: notify to send result to client
G
groot 已提交
93 94 95
        context->IndexSearchDone(index_id_);
    }

G
groot 已提交
96 97 98 99
    auto end_time = METRICS_NOW_TIME;
    auto total_time = METRICS_MICROSECONDS(start_time, end_time);
    CollectDurationMetrics(index_type_, total_time);

G
groot 已提交
100 101
    rc.Elapse("totally cost");

G
groot 已提交
102
    return nullptr;
G
groot 已提交
103 104
}

105 106 107 108 109
// Splits the flat search output into one (id, distance) list per query.
//
//   output_ids / output_distence: flat arrays read with a per-query stride of
//       topk, i.e. entry k of query i is at index i * topk + k.
//   nq:         number of queries.
//   topk:       number of slots read per query.
//   result_set: cleared, then filled with nq lists in query order.
//
// Slots whose id is negative are skipped, so a list may hold fewer than topk
// entries. Returns an error Status (and leaves result_set untouched) when
// either input array holds fewer than nq * topk elements.
Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
                                 const std::vector<float> &output_distence,
                                 uint64_t nq,
                                 uint64_t topk,
                                 SearchContext::ResultSet &result_set) {
    if(output_ids.size() < nq*topk || output_distence.size() < nq*topk) {
        std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
                " distance array size: " + std::to_string(output_distence.size());
        SERVER_LOG_ERROR << msg;
        return Status::Error(msg);
    }

    result_set.clear();
    result_set.reserve(nq);
    // Counters use the same unsigned 64-bit type as the bounds they are compared to.
    for (uint64_t i = 0; i < nq; i++) {
        SearchContext::Id2DistanceMap id_distance;
        id_distance.reserve(topk);
        for (uint64_t k = 0; k < topk; k++) {
            const uint64_t index = i * topk + k;
            if(output_ids[index] < 0) {
                continue;
            }
            id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
        }
        // The per-query list is not used again, so hand it over instead of copying it.
        result_set.emplace_back(std::move(id_distance));
    }

    return Status::OK();
}

// Merges two (id, distance) lists into distance_target, keeping at most topk
// entries.
//
//   distance_src:    list to merge in; left empty when it is swapped into an
//                    empty target, otherwise left unchanged.
//   distance_target: receives the merged list.
//   topk:            maximum number of entries produced by the merge.
//   ascending:       true  -> the smaller distance is taken first,
//                    false -> the larger distance is taken first.
//
// Note: both lists are expected to be already sorted in the order selected by
// `ascending`; this function only interleaves them. On equal distances the
// entry from distance_src is taken first.
// Note: when distance_target is empty the source list is swapped in as-is and
// is not truncated to topk. When distance_src is empty only a warning is logged.
// Always returns Status::OK().
Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
                               SearchContext::Id2DistanceMap &distance_target,
                               uint64_t topk,
                               bool ascending) {
    if(distance_src.empty()) {
        SERVER_LOG_WARNING << "Empty distance source array";
        return Status::OK();
    }

    if(distance_target.empty()) {
        distance_target.swap(distance_src);
        return Status::OK();
    }

    const size_t src_count = distance_src.size();
    const size_t target_count = distance_target.size();

    SearchContext::Id2DistanceMap distance_merged;
    // Never reserve more than can actually be produced, even if topk is very large.
    distance_merged.reserve(std::min<uint64_t>(topk, src_count + target_count));

    size_t src_index = 0, target_index = 0;
    // Checking the limit before each push keeps the result empty for topk == 0.
    while(distance_merged.size() < topk) {
        const bool src_done = src_index >= src_count;
        const bool target_done = target_index >= target_count;
        if(src_done && target_done) {
            break;
        }

        //one side is exhausted: keep draining the other side until topk is reached
        if(src_done) {
            distance_merged.push_back(distance_target[target_index++]);
            continue;
        }
        if(target_done) {
            distance_merged.push_back(distance_src[src_index++]);
            continue;
        }

        //compare distance,
        // if ascending = true, take the smaller distance first
        // else, take the larger distance first
        const auto& src_pair = distance_src[src_index];
        const auto& target_pair = distance_target[target_index];
        const bool take_target = ascending ? (src_pair.second > target_pair.second)
                                           : (src_pair.second < target_pair.second);
        if(take_target) {
            distance_merged.push_back(target_pair);
            target_index++;
        } else {
            distance_merged.push_back(src_pair);
            src_index++;
        }
    }

    distance_target.swap(distance_merged);

    return Status::OK();
}

// Folds the per-query results in result_src into result_target.
//
//   result_src:    results to merge in, one list per query.
//   topk:          forwarded to MergeResult as the per-query limit.
//   ascending:     forwarded to MergeResult as the merge order.
//   result_target: accumulated results; updated in place.
//
// An empty target simply takes over the source via swap. Otherwise both sets
// must hold the same number of lists (an error Status is returned if not) and
// list i of the source is merged into list i of the target. The Status
// returned by each MergeResult call is not inspected.
Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
                              uint64_t topk,
                              bool ascending,
                              SearchContext::ResultSet &result_target) {
    // Nothing accumulated yet: adopt the source wholesale.
    if (result_target.empty()) {
        result_target.swap(result_src);
        return Status::OK();
    }

    if (result_src.size() != result_target.size()) {
        std::string msg = "Invalid result set size";
        SERVER_LOG_ERROR << msg;
        return Status::Error(msg);
    }

    const size_t query_count = result_src.size();
    for (size_t row = 0; row < query_count; ++row) {
        SearchTask::MergeResult(result_src[row], result_target[row], topk, ascending);
    }

    return Status::OK();
}

G
groot 已提交
232 233 234
}
}
}