SearchTask.cpp 7.2 KB
Newer Older
G
groot 已提交
1 2 3 4 5
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/
G
groot 已提交
6
#include "SearchTask.h"
G
groot 已提交
7
#include "metrics/Metrics.h"
G
groot 已提交
8 9 10 11
#include "utils/Log.h"
#include "utils/TimeRecorder.h"

namespace zilliz {
J
jinhai 已提交
12
namespace milvus {
G
groot 已提交
13 14 15
namespace engine {

namespace {
G
groot 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30
// Reports the elapsed time of one search task to the metrics singleton,
// picking the histogram from the kind of file that was searched.
//
// index_type: compared against meta::TableFileSchema::RAW and TO_INDEX.
// total_time: elapsed time, forwarded unchanged to the histogram.
//             NOTE(review): the caller in this file computes it with
//             METRICS_MICROSECONDS while the histogram names say "Seconds";
//             confirm which unit the histograms expect.
void CollectDurationMetrics(int index_type, double total_time) {
    switch (index_type) {
        // RAW and TO_INDEX files are both recorded in the raw-data histogram.
        case meta::TableFileSchema::RAW:
        case meta::TableFileSchema::TO_INDEX: {
            server::Metrics::GetInstance().SearchRawDataDurationSecondsHistogramObserve(total_time);
            break;
        }
        // Every other file type is recorded in the index-data histogram.
        default: {
            server::Metrics::GetInstance().SearchIndexDataDurationSecondsHistogramObserve(total_time);
            break;
        }
    }
}

G
groot 已提交
33 34 35 36 37 38 39 40
}

// Default constructor: initializes the IScheduleTask base with the
// ScheduleTaskType::kSearch task type. No other members are set here.
SearchTask::SearchTask()
    : IScheduleTask(ScheduleTaskType::kSearch) {
}

std::shared_ptr<IScheduleTask> SearchTask::Execute() {
G
groot 已提交
41
    if(index_engine_ == nullptr) {
G
groot 已提交
42
        return nullptr;
G
groot 已提交
43 44
    }

G
groot 已提交
45 46 47
    SERVER_LOG_INFO << "Searching in index(" << index_id_<< ") with "
                    << search_contexts_.size() << " tasks";

G
groot 已提交
48
    server::TimeRecorder rc("DoSearch index(" + std::to_string(index_id_) + ")");
G
groot 已提交
49

G
groot 已提交
50 51
    auto start_time = METRICS_NOW_TIME;

G
groot 已提交
52 53 54
    std::vector<long> output_ids;
    std::vector<float> output_distence;
    for(auto& context : search_contexts_) {
G
groot 已提交
55
        //step 1: allocate memory
Y
yu yunfeng 已提交
56
        auto inner_k = context->topk();
G
groot 已提交
57 58
        output_ids.resize(inner_k*context->nq());
        output_distence.resize(inner_k*context->nq());
G
groot 已提交
59 60

        try {
G
groot 已提交
61
            //step 2: search
G
groot 已提交
62
            index_engine_->Search(context->nq(), context->vectors(), inner_k, output_distence.data(),
G
groot 已提交
63
                                  output_ids.data());
G
groot 已提交
64 65 66 67 68

            rc.Record("do search");

            //step 3: cluster result
            SearchContext::ResultSet result_set;
69
            auto spec_k = index_engine_->Count() < context->topk() ? index_engine_->Count() : context->topk();
Y
yu yunfeng 已提交
70
            ClusterResult(output_ids, output_distence, context->nq(), spec_k, result_set);
G
groot 已提交
71 72 73
            rc.Record("cluster result");

            //step 4: pick up topk result
74
            SearchTask::TopkResult(result_set, inner_k, context->GetResult());
G
groot 已提交
75 76
            rc.Record("reduce topk");

G
groot 已提交
77 78
        } catch (std::exception& ex) {
            SERVER_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
G
groot 已提交
79 80
            context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
            continue;
G
groot 已提交
81 82
        }

G
groot 已提交
83
        //step 5: notify to send result to client
G
groot 已提交
84 85 86
        context->IndexSearchDone(index_id_);
    }

G
groot 已提交
87 88 89 90
    auto end_time = METRICS_NOW_TIME;
    auto total_time = METRICS_MICROSECONDS(start_time, end_time);
    CollectDurationMetrics(index_type_, total_time);

G
groot 已提交
91 92
    rc.Elapse("totally cost");

G
groot 已提交
93
    return nullptr;
G
groot 已提交
94 95
}

96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208
// Splits the flat search output into one id/distance list per query.
//
// output_ids / output_distence: flat arrays read as nq consecutive rows of
//     topk entries each (entry k of query i is at index i * topk + k).
// nq:         number of queries (rows).
// topk:       number of entries per row.
// result_set: cleared and refilled with nq lists; entries whose id is
//             negative are skipped, so a list may hold fewer than topk pairs.
//
// Returns Status::Error (and leaves result_set untouched) when either array
// does not contain exactly nq * topk elements, otherwise Status::OK.
Status SearchTask::ClusterResult(const std::vector<long> &output_ids,
                                 const std::vector<float> &output_distence,
                                 uint64_t nq,
                                 uint64_t topk,
                                 SearchContext::ResultSet &result_set) {
    const uint64_t expected_size = nq * topk;
    if (output_ids.size() != expected_size || output_distence.size() != expected_size) {
        std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
                " distance array size: " + std::to_string(output_distence.size());
        SERVER_LOG_ERROR << msg;
        return Status::Error(msg);
    }

    result_set.clear();
    result_set.reserve(nq);
    // Loop counters are uint64_t to match nq/topk; a plain int counter would
    // mix signed and unsigned in the comparison and overflow on large inputs.
    for (uint64_t i = 0; i < nq; ++i) {
        SearchContext::Id2DistanceMap id_distance;
        id_distance.reserve(topk);
        const uint64_t row_begin = i * topk;
        for (uint64_t k = 0; k < topk; ++k) {
            const uint64_t index = row_begin + k;
            if (output_ids[index] < 0) {
                continue;  // negative id: no result stored in this slot
            }
            id_distance.push_back(std::make_pair(output_ids[index], output_distence[index]));
        }
        // Move instead of copy: the per-query list is not used again.
        result_set.emplace_back(std::move(id_distance));
    }

    return Status::OK();
}

// Merges two id/distance lists into distance_target, keeping at most topk
// entries with the smallest distances.
//
// Precondition: distance_src and distance_target are each already arranged
// by distance in ascending order.
//
// distance_src:    incoming list. Left unchanged by a real merge; when
//                  distance_target is empty the two lists are swapped, so
//                  distance_src ends up empty.
// distance_target: accumulated list; receives the merged result.
// topk:            upper bound on the size of the merged result.
//
// Always returns Status::OK. An empty distance_src only logs a warning and
// leaves distance_target as it is.
Status SearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
                               SearchContext::Id2DistanceMap &distance_target,
                               uint64_t topk) {
    if (distance_src.empty()) {
        SERVER_LOG_WARNING << "Empty distance source array";
        return Status::OK();
    }

    if (distance_target.empty()) {
        // Nothing to merge with: take over the source list, but still honour
        // the topk bound.
        distance_target.swap(distance_src);
        if (distance_target.size() > topk) {
            distance_target.resize(topk);
        }
        return Status::OK();
    }

    const size_t src_count = distance_src.size();
    const size_t target_count = distance_target.size();

    // The merged list holds min(topk, all available entries) items. Bounding
    // the loop by this limit also covers topk == 0, where nothing is emitted.
    const uint64_t available = static_cast<uint64_t>(src_count) + target_count;
    const uint64_t limit = topk < available ? topk : available;

    SearchContext::Id2DistanceMap distance_merged;
    distance_merged.reserve(limit);

    size_t src_index = 0;
    size_t target_index = 0;
    while (distance_merged.size() < limit) {
        if (src_index >= src_count) {
            // source exhausted: keep draining the target list
            distance_merged.push_back(distance_target[target_index++]);
        } else if (target_index >= target_count) {
            // target exhausted: keep draining the source list
            distance_merged.push_back(distance_src[src_index++]);
        } else if (distance_src[src_index].second > distance_target[target_index].second) {
            // target entry is strictly closer
            distance_merged.push_back(distance_target[target_index++]);
        } else {
            // source entry is closer, or both distances are equal
            distance_merged.push_back(distance_src[src_index++]);
        }
    }

    distance_target.swap(distance_merged);

    return Status::OK();
}

// Folds a freshly produced result set into the accumulated one, query by
// query, via MergeResult.
//
// result_src:    result set to merge in; swapped into result_target when the
//                latter is still empty.
// topk:          forwarded to MergeResult as the per-query limit.
// result_target: accumulated result set that receives the merge.
//
// Returns Status::Error when both sets are non-empty but hold a different
// number of per-query lists; otherwise Status::OK. The Status returned by
// each MergeResult call is not inspected.
Status SearchTask::TopkResult(SearchContext::ResultSet &result_src,
                              uint64_t topk,
                              SearchContext::ResultSet &result_target) {
    // First contribution: adopt the incoming set as-is.
    if (result_target.empty()) {
        result_target.swap(result_src);
        return Status::OK();
    }

    const size_t query_count = result_src.size();
    if (query_count != result_target.size()) {
        std::string msg = "Invalid result set size";
        SERVER_LOG_ERROR << msg;
        return Status::Error(msg);
    }

    // Merge the two lists that belong to the same query position.
    for (size_t query = 0; query < query_count; ++query) {
        SearchTask::MergeResult(result_src[query], result_target[query], topk);
    }

    return Status::OK();
}

G
groot 已提交
209 210 211
}
}
}