SearchTask.cpp 13.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
/*******************************************************************************
 * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
 * Unauthorized copying of this file, via any medium is strictly prohibited.
 * Proprietary and confidential.
 ******************************************************************************/

#include "src/metrics/Metrics.h"
#include "src/utils/TimeRecorder.h"
#include "src/db/engine/EngineFactory.h"
#include "src/db/Log.h"
#include "SearchTask.h"

#include <thread>


namespace zilliz {
namespace milvus {
namespace engine {

// Tuning constants for the parallel result-reduce path. In the visible part of
// this file they are only referenced by the commented-out
// NeedParallelReduce()/ParallelReduce() helpers below.
// Minimum value of nq * topk at which NeedParallelReduce() would return true.
static constexpr size_t PARALLEL_REDUCE_THRESHOLD = 10000;
// Per-thread batch size ParallelReduce() would fall back to when it cannot
// derive one from the hardware thread count.
static constexpr size_t PARALLEL_REDUCE_BATCH = 1000;

J
jinhai 已提交
23 24
// Out-of-class definition of the class-wide mutex that MergeResult() holds
// while it merges one result list into another.
std::mutex XSearchTask::merge_mutex_;

25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65
//bool
//NeedParallelReduce(uint64_t nq, uint64_t topk) {
//    server::ServerConfig &config = server::ServerConfig::GetInstance();
//    server::ConfigNode &db_config = config.GetConfig(server::CONFIG_DB);
//    bool need_parallel = db_config.GetBoolValue(server::CONFIG_DB_PARALLEL_REDUCE, false);
//    if (!need_parallel) {
//        return false;
//    }
//
//    return nq * topk >= PARALLEL_REDUCE_THRESHOLD;
//}
//
//void
//ParallelReduce(std::function<void(size_t, size_t)> &reduce_function, size_t max_index) {
//    size_t reduce_batch = PARALLEL_REDUCE_BATCH;
//
//    auto thread_count = std::thread::hardware_concurrency() - 1; //not all core do this work
//    if (thread_count > 0) {
//        reduce_batch = max_index / thread_count + 1;
//    }
//    ENGINE_LOG_DEBUG << "use " << thread_count <<
//                     " thread parallelly do reduce, each thread process " << reduce_batch << " vectors";
//
//    std::vector<std::shared_ptr<std::thread> > thread_array;
//    size_t from_index = 0;
//    while (from_index < max_index) {
//        size_t to_index = from_index + reduce_batch;
//        if (to_index > max_index) {
//            to_index = max_index;
//        }
//
//        auto reduce_thread = std::make_shared<std::thread>(reduce_function, from_index, to_index);
//        thread_array.push_back(reduce_thread);
//
//        from_index = to_index;
//    }
//
//    for (auto &thread_ptr : thread_array) {
//        thread_ptr->join();
//    }
//}
66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85

// Reports the size of a loaded file to the metrics singleton.
// RAW and TO_INDEX files are recorded under the "raw file" metrics; every
// other file type is recorded under the "index file" metrics.
void
CollectFileMetrics(int file_type, size_t file_size) {
    auto &&metrics = server::Metrics::GetInstance();

    const bool is_raw_file = (file_type == meta::TableFileSchema::RAW) ||
                             (file_type == meta::TableFileSchema::TO_INDEX);
    if (is_raw_file) {
        metrics.RawFileSizeHistogramObserve(file_size);
        metrics.RawFileSizeTotalIncrement(file_size);
        metrics.RawFileSizeGaugeSet(file_size);
        return;
    }

    metrics.IndexFileSizeHistogramObserve(file_size);
    metrics.IndexFileSizeTotalIncrement(file_size);
    metrics.IndexFileSizeGaugeSet(file_size);
}

W
wxyu 已提交
86 87
// Creates a search task bound to one table file.
// When a file is supplied, an execution engine is built from the file's
// dimension, location, engine type, metric type and nlist. When no file is
// supplied, index_engine_ is left unset.
XSearchTask::XSearchTask(TableFileSchemaPtr file)
    : Task(TaskType::SearchTask), file_(file) {
    if (file_) {
        // The schema stores the engine/metric kinds as plain values; convert
        // them explicitly to the types the factory expects.
        index_engine_ = EngineFactory::Build(file_->dimension_,
                                             file_->location_,
                                             static_cast<EngineType>(file_->engine_type_),
                                             static_cast<MetricType>(file_->metric_type_),
                                             file_->nlist_);
    }
}

98 99 100
// Moves the index of this task's file between storage tiers.
//
// type       DISK2CPU calls the engine's Load(), CPU2GPU calls
//            CopyToGpu(device_id), GPU2CPU calls CopyToCpu(); any other value
//            is treated as a failure.
// device_id  only used for CPU2GPU.
//
// On failure every search context is given an error status and is marked as
// done for this file, so that nothing keeps waiting on it. On success the load
// time is accumulated on every context and the file size is reported to the
// metrics.
void
XSearchTask::Load(LoadType type, uint8_t device_id) {
    // Without a file there is no id to report against and nothing to load.
    if (!file_) {
        ENGINE_LOG_ERROR << "Failed to load index file: search task has no file";
        return;
    }

    server::TimeRecorder rc("");
    Status stat = Status::OK();
    std::string error_msg;
    std::string type_str;

    try {
        if (index_engine_ == nullptr) {
            // The engine is only created in the constructor; treat a missing
            // engine as a load failure instead of dereferencing it.
            error_msg = "Index engine is not available";
            stat = Status(SERVER_UNEXPECTED_ERROR, error_msg);
        } else if (type == LoadType::DISK2CPU) {
            stat = index_engine_->Load();
            type_str = "DISK2CPU";
        } else if (type == LoadType::CPU2GPU) {
            stat = index_engine_->CopyToGpu(device_id);
            type_str = "CPU2GPU";
        } else if (type == LoadType::GPU2CPU) {
            stat = index_engine_->CopyToCpu();
            type_str = "GPU2CPU";
        } else {
            error_msg = "Wrong load type";
            stat = Status(SERVER_UNEXPECTED_ERROR, error_msg);
        }
    } catch (const std::exception &ex) {
        //typical error: out of disk space or permission denied
        error_msg = "Failed to load index file: " + std::string(ex.what());
        stat = Status(SERVER_UNEXPECTED_ERROR, error_msg);
    }

    if (!stat.ok()) {
        // The status handed to the contexts only carries the load type, so log
        // the underlying failure here before it is replaced.
        ENGINE_LOG_ERROR << "Load index file failed: " << stat.ToString();

        Status s;
        if (stat.ToString().find("out of memory") != std::string::npos) {
            error_msg = "out of memory: " + type_str;
            s = Status(SERVER_OUT_OF_MEMORY, error_msg);
        } else {
            error_msg = "Failed to load index file: " + type_str;
            s = Status(SERVER_UNEXPECTED_ERROR, error_msg);
        }

        for (auto &context : search_contexts_) {
            // Publish the error before marking the file as done.
            // NOTE(review): presumably IndexSearchDone() can release a waiter
            // that then reads the status - confirm against SearchContext.
            context->GetStatus() = s;
            context->IndexSearchDone(file_->id_);//mark as done avoid dead lock, even failed
        }

        return;
    }

    size_t file_size = index_engine_->PhysicalSize();

    std::string info = "Load file id:" + std::to_string(file_->id_) + " file type:" + std::to_string(file_->file_type_)
        + " size:" + std::to_string(file_size) + " bytes from location: " + file_->location_ + " totally cost";
    double span = rc.ElapseFromBegin(info);
    for (auto &context : search_contexts_) {
        context->AccumLoadCost(span);
    }

    CollectFileMetrics(file_->file_type_, file_size);

    //step 2: remember the file identity for the later Execute() call
    index_id_ = file_->id_;
    index_type_ = file_->file_type_;
}

void
XSearchTask::Execute() {
    if (index_engine_ == nullptr) {
        return;
    }

Y
Yu Kun 已提交
166
    ENGINE_LOG_DEBUG << "Searching in file id:" << index_id_ << " with "
167 168
                     << search_contexts_.size() << " tasks";

Y
Yu Kun 已提交
169
    server::TimeRecorder rc("DoSearch file id:" + std::to_string(index_id_));
170

Y
Yu Kun 已提交
171
    server::CollectDurationMetrics metrics(index_type_);
172 173

    std::vector<long> output_ids;
J
jinhai 已提交
174
    std::vector<float> output_distance;
175 176
    for (auto &context : search_contexts_) {
        //step 1: allocate memory
J
jinhai 已提交
177 178 179 180
        uint64_t nq = context->nq();
        uint64_t topk = context->topk();
        uint64_t nprobe = context->nprobe();
        const float* vectors = context->vectors();
Y
yudong.cai 已提交
181 182 183 184

        output_ids.resize(topk * nq);
        output_distance.resize(topk * nq);
        std::string hdr = "context " + context->Identity() +
Y
Yu Kun 已提交
185 186
            " nq " + std::to_string(nq) +
            " topk " + std::to_string(topk);
187 188 189

        try {
            //step 2: search
Y
yudong.cai 已提交
190
            index_engine_->Search(nq, vectors, topk, nprobe, output_distance.data(), output_ids.data());
191

Y
yudong.cai 已提交
192
            double span = rc.RecordSection(hdr + ", do search");
193 194
            context->AccumSearchCost(span);

195

196 197
            //step 3: cluster result
            SearchContext::ResultSet result_set;
Y
yudong.cai 已提交
198 199
            auto spec_k = index_engine_->Count() < topk ? index_engine_->Count() : topk;
            XSearchTask::ClusterResult(output_ids, output_distance, nq, spec_k, result_set);
200

Y
yudong.cai 已提交
201
            span = rc.RecordSection(hdr + ", cluster result");
202 203
            context->AccumReduceCost(span);

J
jinhai 已提交
204
            // step 4: pick up topk result
Y
yudong.cai 已提交
205
            XSearchTask::TopkResult(result_set, topk, metric_l2, context->GetResult());
206

Y
yudong.cai 已提交
207
            span = rc.RecordSection(hdr + ", reduce topk");
208 209 210 211 212 213 214 215 216 217 218 219
            context->AccumReduceCost(span);
        } catch (std::exception &ex) {
            ENGINE_LOG_ERROR << "SearchTask encounter exception: " << ex.what();
            context->IndexSearchDone(index_id_);//mark as done avoid dead lock, even search failed
            continue;
        }

        //step 5: notify to send result to client
        context->IndexSearchDone(index_id_);
    }

    rc.ElapseFromBegin("totally cost");
220 221 222

    // release index in resource
    index_engine_ = nullptr;
223 224 225
}

// Splits the flat search output into one (id, distance) list per query.
//
// output_ids / output_distance  flat arrays read as nq rows of topk entries
//                               (entry k of query i is at index i * topk + k).
// nq                            number of query rows.
// topk                          number of entries read per row.
// result_set                    cleared and refilled with nq lists; entries
//                               whose id is negative are skipped.
//
// Returns DB_ERROR when either input array holds fewer than nq * topk
// entries (result_set is left untouched in that case), otherwise OK.
Status XSearchTask::ClusterResult(const std::vector<long> &output_ids,
                                  const std::vector<float> &output_distance,
                                  uint64_t nq,
                                  uint64_t topk,
                                  SearchContext::ResultSet &result_set) {
    if (output_ids.size() < nq * topk || output_distance.size() < nq * topk) {
        std::string msg = "Invalid id array size: " + std::to_string(output_ids.size()) +
            " distance array size: " + std::to_string(output_distance.size());
        ENGINE_LOG_ERROR << msg;
        return Status(DB_ERROR, msg);
    }

    result_set.clear();
    result_set.resize(nq);

    for (uint64_t i = 0; i < nq; i++) {
        // Fill the slot in place; it is empty after the resize above.
        auto &id_distance = result_set[i];
        id_distance.reserve(topk);
        for (uint64_t k = 0; k < topk; k++) {
            const uint64_t index = i * topk + k;
            if (output_ids[index] < 0) {
                continue;
            }
            id_distance.push_back(std::make_pair(output_ids[index], output_distance[index]));
        }
    }

    return Status::OK();
}

// Merges distance_src into distance_target, keeping at most topk entries.
//
// Both lists are expected to be already ordered by distance: ascending when
// `ascending` is true, descending otherwise. On equal distances the entry from
// distance_src is taken first.
//
// distance_src     list to merge in; swapped into distance_target (and thus
//                  left empty) when distance_target is empty.
// distance_target  receives the merged list.
// topk             upper bound on the merged list's length.
//
// Always returns OK. An empty distance_src only logs a warning. The merge
// itself runs under the class-wide merge_mutex_.
Status XSearchTask::MergeResult(SearchContext::Id2DistanceMap &distance_src,
                                SearchContext::Id2DistanceMap &distance_target,
                                uint64_t topk,
                                bool ascending) {
    //Note: distance_src and distance_target are already arranged by distance
    if (distance_src.empty()) {
        ENGINE_LOG_WARNING << "Empty distance source array";
        return Status::OK();
    }

    std::unique_lock<std::mutex> lock(merge_mutex_);
    if (distance_target.empty()) {
        distance_target.swap(distance_src);
        return Status::OK();
    }

    const size_t src_count = distance_src.size();
    const size_t target_count = distance_target.size();
    const size_t total_count = src_count + target_count;

    SearchContext::Id2DistanceMap distance_merged;
    // Never more than topk entries are kept, and never more than exist.
    distance_merged.reserve(topk < total_count ? static_cast<size_t>(topk) : total_count);

    size_t src_index = 0, target_index = 0;
    // The size limit is checked before every append, so the merged list cannot
    // grow past topk (including topk == 0).
    while (distance_merged.size() < topk && (src_index < src_count || target_index < target_count)) {
        bool take_target = false;
        if (src_index >= src_count) {
            //distance_src is used up, the rest comes from distance_target
            take_target = true;
        } else if (target_index >= target_count) {
            //distance_target is used up, the rest comes from distance_src
            take_target = false;
        } else {
            // if ascending = true, take the smaller distance first
            // else, take the larger distance first
            const auto &src_distance = distance_src[src_index].second;
            const auto &target_distance = distance_target[target_index].second;
            take_target = ascending ? (src_distance > target_distance)
                                    : (src_distance < target_distance);
        }

        if (take_target) {
            distance_merged.push_back(distance_target[target_index]);
            ++target_index;
        } else {
            distance_merged.push_back(distance_src[src_index]);
            ++src_index;
        }
    }

    distance_target.swap(distance_merged);

    return Status::OK();
}

// Merges the per-query results of one file into the accumulated results.
//
// result_src     per-query lists from the file just searched; swapped into
//                result_target when result_target is empty.
// topk           upper bound passed to MergeResult() for every query.
// ascending      ordering passed to MergeResult() for every query.
// result_target  accumulated per-query lists, updated in place.
//
// Returns DB_ERROR when both sets are non-empty but hold a different number
// of queries, the first non-OK status reported by MergeResult(), or OK.
//
// NOTE(review): only MergeResult() takes merge_mutex_; the empty check and
// swap on result_target below are not under that lock - confirm that callers
// sharing a result_target are serialized.
Status XSearchTask::TopkResult(SearchContext::ResultSet &result_src,
                               uint64_t topk,
                               bool ascending,
                               SearchContext::ResultSet &result_target) {
    if (result_target.empty()) {
        result_target.swap(result_src);
        return Status::OK();
    }

    if (result_src.size() != result_target.size()) {
        std::string msg = "Invalid result set size";
        ENGINE_LOG_ERROR << msg;
        return Status(DB_ERROR, msg);
    }

    // Queries are merged one by one; the parallel reduce path is disabled.
    for (size_t i = 0; i < result_src.size(); i++) {
        Status status = XSearchTask::MergeResult(result_src[i], result_target[i], topk, ascending);
        if (!status.ok()) {
            return status;
        }
    }

    return Status::OK();
}

W
wxyu 已提交
370

371 372 373
}
}
}