// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.

#include "server/grpc_impl/GrpcRequestHandler.h"

#include <fiu-local.h>
#include <algorithm>
#include <exception>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "context/HybridSearchContext.h"
#include "query/BinaryQuery.h"
#include "tracing/TextMapCarrier.h"
#include "tracing/TracerUtil.h"
#include "utils/Log.h"
#include "utils/LogUtil.h"
#include "utils/TimeRecorder.h"

namespace milvus {
namespace server {
Y
Yu Kun 已提交
32
namespace grpc {
K
kun yu 已提交
33

34 35 36 37 38 39 40 41 42 43 44 45 46
// Translate an internal milvus ErrorCode into the ::milvus::grpc::ErrorCode reported to clients.
//
// @param code  internal server/db error code
// @return the mapped gRPC error code; codes without an entry in the table map to UNEXPECTED_ERROR
::milvus::grpc::ErrorCode
ErrorMap(ErrorCode code) {
    static const std::map<ErrorCode, ::milvus::grpc::ErrorCode> code_map = {
        {SERVER_UNEXPECTED_ERROR, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
        {SERVER_UNSUPPORTED_ERROR, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
        {SERVER_NULL_POINTER, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
        {SERVER_INVALID_ARGUMENT, ::milvus::grpc::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_FILE_NOT_FOUND, ::milvus::grpc::ErrorCode::FILE_NOT_FOUND},
        {SERVER_NOT_IMPLEMENT, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
        {SERVER_CANNOT_CREATE_FOLDER, ::milvus::grpc::ErrorCode::CANNOT_CREATE_FOLDER},
        {SERVER_CANNOT_CREATE_FILE, ::milvus::grpc::ErrorCode::CANNOT_CREATE_FILE},
        {SERVER_CANNOT_DELETE_FOLDER, ::milvus::grpc::ErrorCode::CANNOT_DELETE_FOLDER},
        {SERVER_CANNOT_DELETE_FILE, ::milvus::grpc::ErrorCode::CANNOT_DELETE_FILE},
        {SERVER_COLLECTION_NOT_EXIST, ::milvus::grpc::ErrorCode::COLLECTION_NOT_EXISTS},
        {SERVER_INVALID_COLLECTION_NAME, ::milvus::grpc::ErrorCode::ILLEGAL_COLLECTION_NAME},
        {SERVER_INVALID_COLLECTION_DIMENSION, ::milvus::grpc::ErrorCode::ILLEGAL_DIMENSION},
        {SERVER_INVALID_VECTOR_DIMENSION, ::milvus::grpc::ErrorCode::ILLEGAL_DIMENSION},

        {SERVER_INVALID_INDEX_TYPE, ::milvus::grpc::ErrorCode::ILLEGAL_INDEX_TYPE},
        {SERVER_INVALID_ROWRECORD, ::milvus::grpc::ErrorCode::ILLEGAL_ROWRECORD},
        {SERVER_INVALID_ROWRECORD_ARRAY, ::milvus::grpc::ErrorCode::ILLEGAL_ROWRECORD},
        {SERVER_INVALID_TOPK, ::milvus::grpc::ErrorCode::ILLEGAL_TOPK},
        {SERVER_INVALID_NPROBE, ::milvus::grpc::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_INVALID_INDEX_NLIST, ::milvus::grpc::ErrorCode::ILLEGAL_NLIST},
        {SERVER_INVALID_INDEX_METRIC_TYPE, ::milvus::grpc::ErrorCode::ILLEGAL_METRIC_TYPE},
        {SERVER_INVALID_INDEX_FILE_SIZE, ::milvus::grpc::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_ILLEGAL_VECTOR_ID, ::milvus::grpc::ErrorCode::ILLEGAL_VECTOR_ID},
        {SERVER_ILLEGAL_SEARCH_RESULT, ::milvus::grpc::ErrorCode::ILLEGAL_SEARCH_RESULT},
        {SERVER_CACHE_FULL, ::milvus::grpc::ErrorCode::CACHE_FAILED},
        {DB_META_TRANSACTION_FAILED, ::milvus::grpc::ErrorCode::META_FAILED},
        {SERVER_BUILD_INDEX_ERROR, ::milvus::grpc::ErrorCode::BUILD_INDEX_ERROR},
        {SERVER_OUT_OF_MEMORY, ::milvus::grpc::ErrorCode::OUT_OF_MEMORY},
    };

    // One lookup: reuse the iterator instead of find() followed by at().
    auto iter = code_map.find(code);
    if (iter != code_map.end()) {
        return iter->second;
    }
    return ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR;
}

G
groot 已提交
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
namespace {
// Flatten gRPC row records (plus optional caller-supplied ids) into an engine::VectorsData.
//
// Float payloads take precedence: when any record carries float data, only float data is copied;
// binary data is copied only when no record has float data. The vector_count_, float_data_,
// binary_data_ and id_array_ fields of `vectors` are replaced, not appended to.
//
// @param grpc_records   row records from the request
// @param grpc_id_array  ids supplied by the caller (may be empty)
// @param vectors        output container
void
CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::RowRecord>& grpc_records,
               const google::protobuf::RepeatedField<google::protobuf::int64>& grpc_id_array,
               engine::VectorsData& vectors) {
    // step 1: copy vector data
    int64_t float_data_size = 0, binary_data_size = 0;
    for (auto& record : grpc_records) {
        float_data_size += record.float_data_size();
        binary_data_size += record.binary_data().size();
    }

    std::vector<float> float_array(float_data_size, 0.0f);
    std::vector<uint8_t> binary_array(binary_data_size, 0);
    int64_t float_offset = 0, binary_offset = 0;
    // Empty records are skipped. Taking &array[offset] with offset == size() (trailing empty records)
    // indexes past the end, and an empty payload's data() may be null; memcpy requires valid pointers
    // even for a zero-byte copy.
    if (float_data_size > 0) {
        for (auto& record : grpc_records) {
            const int record_floats = record.float_data_size();
            if (record_floats == 0) {
                continue;
            }
            memcpy(float_array.data() + float_offset, record.float_data().data(), record_floats * sizeof(float));
            float_offset += record_floats;
        }
    } else if (binary_data_size > 0) {
        for (auto& record : grpc_records) {
            const std::string& record_bytes = record.binary_data();
            if (record_bytes.empty()) {
                continue;
            }
            memcpy(binary_array.data() + binary_offset, record_bytes.data(), record_bytes.size());
            binary_offset += record_bytes.size();
        }
    }

    // step 2: copy id array (stays empty when the caller supplied no ids)
    std::vector<int64_t> id_array(grpc_id_array.begin(), grpc_id_array.end());

    // step 3: construct vectors
    vectors.vector_count_ = grpc_records.size();
    vectors.float_data_.swap(float_array);
    vectors.binary_data_.swap(binary_array);
    vectors.id_array_.swap(id_array);
}

116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
// Copy an internal TopKQueryResult into the gRPC response message.
//
// Sets row_num and fills `ids` / `distances` from result.id_list_ / result.distance_list_.
// Does nothing when `response` is null.
void
ConstructResults(const TopKQueryResult& result, ::milvus::grpc::TopKQueryResult* response) {
    if (!response) {
        return;
    }

    response->set_row_num(result.row_num_);

    // Only copy when there is data: RepeatedField::mutable_data() of an empty field may be null,
    // and memcpy requires valid pointers even for a zero-byte copy.
    if (!result.id_list_.empty()) {
        response->mutable_ids()->Resize(static_cast<int>(result.id_list_.size()), 0);
        memcpy(response->mutable_ids()->mutable_data(), result.id_list_.data(),
               result.id_list_.size() * sizeof(int64_t));
    }

    if (!result.distance_list_.empty()) {
        response->mutable_distances()->Resize(static_cast<int>(result.distance_list_.size()), 0.0f);
        memcpy(response->mutable_distances()->mutable_data(), result.distance_list_.data(),
               result.distance_list_.size() * sizeof(float));
    }
}

// Fill a gRPC PartitionStat from an internal PartitionStat: total row count, tag, and one
// SegmentStat entry (row count, name, index name, data size) per segment.
// A null destination is ignored.
void
ConstructPartitionStat(const PartitionStat& partition_stat, ::milvus::grpc::PartitionStat* grpc_partition_stat) {
    if (grpc_partition_stat == nullptr) {
        return;
    }

    grpc_partition_stat->set_total_row_count(partition_stat.total_row_num_);
    grpc_partition_stat->set_tag(partition_stat.tag_);

    for (const auto& segment : partition_stat.segments_stat_) {
        auto* segment_out = grpc_partition_stat->add_segments_stat();
        segment_out->set_row_count(segment.row_num_);
        segment_out->set_segment_name(segment.name_);
        segment_out->set_index_name(segment.index_name_);
        segment_out->set_data_size(segment.data_size_);
    }
}

void
G
groot 已提交
151
ConstructCollectionInfo(const CollectionInfo& collection_info, ::milvus::grpc::CollectionInfo* response) {
152 153 154 155
    if (!response) {
        return;
    }

156
    response->set_total_row_count(collection_info.total_row_num_);
157

158
    for (auto& partition_stat : collection_info.partitions_stat_) {
159 160 161 162 163
        ::milvus::grpc::PartitionStat* grpc_partiton_stat = response->mutable_partitions_stat()->Add();
        ConstructPartitionStat(partition_stat, grpc_partiton_stat);
    }
}

G
groot 已提交
164 165
}  // namespace

166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
namespace {

// Metadata key under which a call's request id is stored.
#define REQ_ID ("request_id")

// Counter backing server-generated request ids (static storage, so it starts at zero).
std::atomic<int64_t> _sequential_id;

// Return the next server-generated request id. The increment is atomic.
int64_t
get_sequential_id() {
    return _sequential_id++;
}

// Add `request_id` to the server's initial metadata under REQ_ID.
// Logs an error and does nothing when `context` is null.
void
set_request_id(::grpc::ServerContext* context, const std::string& request_id) {
    if (not context) {
        // error
        LOG_SERVER_ERROR_ << "set_request_id: grpc::ServerContext is nullptr" << std::endl;
        return;
    }

    context->AddInitialMetadata(REQ_ID, request_id);
}

// Read back the request id stored under REQ_ID in the server metadata.
// @return the id, or "INVALID_ID" (after logging) when `context` is null or no REQ_ID entry exists
std::string
get_request_id(::grpc::ServerContext* context) {
    if (not context) {
        // error
        LOG_SERVER_ERROR_ << "get_request_id: grpc::ServerContext is nullptr" << std::endl;
        return "INVALID_ID";
    }

    // Bind by const reference: this runs several times per call, and copying the whole
    // metadata multimap each time is wasted work.
    const auto& server_metadata = context->server_metadata();

    auto request_id_kv = server_metadata.find(REQ_ID);
    if (request_id_kv == server_metadata.end()) {
        // error
        LOG_SERVER_ERROR_ << std::string(REQ_ID) << " not found in grpc.server_metadata" << std::endl;
        return "INVALID_ID";
    }

    // Copy exactly length() bytes rather than relying on data() being NUL-terminated.
    return std::string(request_id_kv->second.data(), request_id_kv->second.length());
}

}  // namespace

Z
Zhiru Zhu 已提交
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
// Store the tracer used to start per-call spans and seed the handler's random generator
// from std::random_device.
GrpcRequestHandler::GrpcRequestHandler(const std::shared_ptr<opentracing::Tracer>& tracer)
    : tracer_(tracer), random_num_generator_() {
    random_num_generator_.seed(std::random_device{}());
}

// Interceptor hook run after a call's initial metadata has been received.
//
// 1. Extracts the trace context header (if any) and starts a span that is a child of it; returns
//    early (logging to stderr) if extraction fails.
// 2. Chooses a request id:
//    - if the client sent "request_id", it is reused. If that id is already registered in
//      context_map_, the first free "<id>_<n>" (n = 1, 2, ...) is used instead, and the chosen id
//      is reserved with a null placeholder;
//    - otherwise a sequential server-side id is generated.
//    In both cases the chosen id is written into the server's initial metadata, because
//    get_request_id() (and therefore GetContext/SetContext/OnPreSendMessage) reads it from there.
// 3. Registers a Context carrying the span under that id.
void
GrpcRequestHandler::OnPostRecvInitialMetaData(
    ::grpc::experimental::ServerRpcInfo* server_rpc_info,
    ::grpc::experimental::InterceptorBatchMethods* interceptor_batch_methods) {
    std::unordered_map<std::string, std::string> text_map;
    auto* metadata_map = interceptor_batch_methods->GetRecvInitialMetadata();
    auto context_kv = metadata_map->find(tracing::TracerUtil::GetTraceContextHeaderName());
    if (context_kv != metadata_map->end()) {
        text_map[std::string(context_kv->first.data(), context_kv->first.length())] =
            std::string(context_kv->second.data(), context_kv->second.length());
    }
    // test debug mode
    //    if (std::string(server_rpc_info->method()).find("Search") != std::string::npos) {
    //        text_map["demo-debug-id"] = "debug-id";
    //    }

    tracing::TextMapCarrier carrier{text_map};
    auto span_context_maybe = tracer_->Extract(carrier);
    if (!span_context_maybe) {
        std::cerr << span_context_maybe.error().message() << std::endl;
        return;
    }
    auto span = tracer_->StartSpan(server_rpc_info->method(), {opentracing::ChildOf(span_context_maybe->get())});

    auto server_context = server_rpc_info->server_context();
    // Bind by reference; there is no need to copy the received metadata map.
    const auto& client_metadata = server_context->client_metadata();

    // if client provide request_id in metadata, milvus just use it,
    // else milvus generate a sequential id.
    std::string request_id;
    auto request_id_kv = client_metadata.find("request_id");
    if (request_id_kv != client_metadata.end()) {
        // Metadata values are string_refs that need not be NUL-terminated: copy exactly length() bytes.
        request_id.assign(request_id_kv->second.data(), request_id_kv->second.length());
        LOG_SERVER_DEBUG_ << "client provide request_id: " << request_id;

        {
            // if request_id is being used by another request, convert it to request_id_n,
            // then reserve the chosen id so concurrent calls cannot pick it too.
            std::lock_guard<std::mutex> lock(context_map_mutex_);
            if (context_map_.find(request_id) != context_map_.end()) {
                int64_t suffix = 1;
                std::string try_request_id;
                do {
                    try_request_id = request_id + "_" + std::to_string(suffix);
                    suffix++;
                } while (context_map_.find(try_request_id) != context_map_.end());
                // Use the unique id from here on (previously the colliding id was kept).
                request_id = try_request_id;
            }
            context_map_[request_id] = nullptr;
        }

        // Publish the id so get_request_id() resolves this call to it instead of "INVALID_ID".
        set_request_id(server_context, request_id);
    } else {
        request_id = std::to_string(get_sequential_id());
        set_request_id(server_context, request_id);
        LOG_SERVER_DEBUG_ << "milvus generate request_id: " << request_id;
    }

    auto trace_context = std::make_shared<tracing::TraceContext>(span);
    auto context = std::make_shared<Context>(request_id);
    context->SetTraceContext(trace_context);
    SetContext(server_rpc_info->server_context(), context);
}

// Interceptor hook run before a message is sent on a call.
//
// Finishes the span of the Context registered for the call's request id and removes the entry.
// Logs an error when no entry exists. A null Context or a null trace context (for example a
// reservation placeholder) is removed without dereferencing it.
void
GrpcRequestHandler::OnPreSendMessage(::grpc::experimental::ServerRpcInfo* server_rpc_info,
                                     ::grpc::experimental::InterceptorBatchMethods* interceptor_batch_methods) {
    std::lock_guard<std::mutex> lock(context_map_mutex_);
    auto request_id = get_request_id(server_rpc_info->server_context());

    auto context_iter = context_map_.find(request_id);
    if (context_iter == context_map_.end()) {
        // error
        LOG_SERVER_ERROR_ << "request_id " << request_id << " not found in context_map_";
        return;
    }

    const auto& context = context_iter->second;
    if (context != nullptr && context->GetTraceContext() != nullptr) {
        context->GetTraceContext()->GetSpan()->Finish();
    }
    context_map_.erase(context_iter);
}

J
Jin Hai 已提交
296
std::shared_ptr<Context>
Z
Zhiru Zhu 已提交
297
GrpcRequestHandler::GetContext(::grpc::ServerContext* server_context) {
298
    std::lock_guard<std::mutex> lock(context_map_mutex_);
299 300
    auto request_id = get_request_id(server_context);
    if (context_map_.find(request_id) == context_map_.end()) {
301
        LOG_SERVER_ERROR_ << "GetContext: request_id " << request_id << " not found in context_map_";
302 303 304
        return nullptr;
    }
    return context_map_[request_id];
Z
Zhiru Zhu 已提交
305 306 307 308
}

// Register (or overwrite) the Context for this call under its request id, holding context_map_mutex_.
void
GrpcRequestHandler::SetContext(::grpc::ServerContext* server_context, const std::shared_ptr<Context>& context) {
    std::lock_guard<std::mutex> guard(context_map_mutex_);
    const auto key = get_request_id(server_context);
    context_map_[key] = context;
}

// Draw a value from random_num_generator_, redrawing until it is non-zero.
// Access to the generator is serialized by random_mutex_.
uint64_t
GrpcRequestHandler::random_id() const {
    std::lock_guard<std::mutex> guard(random_mutex_);
    for (;;) {
        auto candidate = random_num_generator_();
        if (candidate != 0) {
            return candidate;
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

K
kun yu 已提交
326
// CreateCollection RPC: forward the schema fields to the request handler and report its status.
::grpc::Status
GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionSchema* request,
                                     ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    const auto& name = request->collection_name();
    Status status = request_handler_.CreateCollection(GetContext(context), name, request->dimension(),
                                                      request->index_file_size(), request->metric_type());
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
}

// HasCollection RPC: reply with whether the named collection exists, plus the handler status.
::grpc::Status
GrpcRequestHandler::HasCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                  ::milvus::grpc::BoolReply* response) {
    CHECK_NULLPTR_RETURN(request);

    bool exists = false;
    Status status = request_handler_.HasCollection(GetContext(context), request->collection_name(), exists);

    response->set_bool_reply(exists);
    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

// DropCollection RPC: drop the named collection and report the handler status.
::grpc::Status
GrpcRequestHandler::DropCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                   ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    const auto& name = request->collection_name();
    Status status = request_handler_.DropCollection(GetContext(context), name);
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
}

// CreateIndex RPC: parse the JSON value of the EXTRA_PARAM_KEY extra parameter (if present) and
// build an index of the requested type on the collection.
//
// extra_params comes from the client. A value that fails to parse is reported as
// SERVER_INVALID_ARGUMENT instead of letting the parse exception escape the handler, which would
// skip SET_RESPONSE.
::grpc::Status
GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus::grpc::IndexParam* request,
                                ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    milvus::json json_params;
    for (int i = 0; i < request->extra_params_size(); i++) {
        const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
        if (extra.key() == EXTRA_PARAM_KEY) {
            try {
                json_params = json::parse(extra.value());
            } catch (const std::exception& e) {
                Status status(SERVER_INVALID_ARGUMENT, std::string("Invalid extra_params: ") + e.what());
                SET_RESPONSE(response, status, context);
                return ::grpc::Status::OK;
            }
        }
    }

    Status status = request_handler_.CreateIndex(GetContext(context), request->collection_name(), request->index_type(),
                                                 json_params);

    SET_RESPONSE(response, status, context);
    return ::grpc::Status::OK;
}

// Insert RPC: flatten the row records, insert them into the collection/partition, and return the
// vector ids (vectors.id_array_ after the insert call) together with the handler status.
::grpc::Status
GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc::InsertParam* request,
                           ::milvus::grpc::VectorIds* response) {
    CHECK_NULLPTR_RETURN(request);

    LOG_SERVER_INFO_ << LogOut("[%s][%d] Start insert.", "insert", 0);

    // step 1: copy vector data
    engine::VectorsData vectors;
    CopyRowRecords(request->row_record_array(), request->row_id_array(), vectors);

    // step 2: insert vectors
    Status status =
        request_handler_.Insert(GetContext(context), request->collection_name(), vectors, request->partition_tag());

    // step 3: return id array
    // Skip when empty: mutable_data() of an empty RepeatedField may be null, and memcpy requires
    // valid pointers even for a zero-byte copy.
    if (!vectors.id_array_.empty()) {
        response->mutable_vector_id_array()->Resize(static_cast<int>(vectors.id_array_.size()), 0);
        memcpy(response->mutable_vector_id_array()->mutable_data(), vectors.id_array_.data(),
               vectors.id_array_.size() * sizeof(int64_t));
    }

    LOG_SERVER_INFO_ << LogOut("[%s][%d] Insert done.", "insert", 0);
    SET_RESPONSE(response->mutable_status(), status, context);
    return ::grpc::Status::OK;
}

409 410 411 412 413 414 415
// GetVectorByID RPC: fetch a single vector by id. Float data is returned if present; otherwise
// binary data is returned if present. Nothing is written when both are empty.
::grpc::Status
GrpcRequestHandler::GetVectorByID(::grpc::ServerContext* context, const ::milvus::grpc::VectorIdentity* request,
                                  ::milvus::grpc::VectorData* response) {
    CHECK_NULLPTR_RETURN(request);

    std::vector<int64_t> wanted_ids{request->id()};
    engine::VectorsData fetched;
    Status status =
        request_handler_.GetVectorByID(GetContext(context), request->collection_name(), wanted_ids, fetched);

    const auto& floats = fetched.float_data_;
    const auto& bytes = fetched.binary_data_;
    if (!floats.empty()) {
        auto* float_field = response->mutable_vector_data()->mutable_float_data();
        float_field->Resize(floats.size(), 0);
        memcpy(float_field->mutable_data(), floats.data(), floats.size() * sizeof(float));
    } else if (!bytes.empty()) {
        auto* byte_field = response->mutable_vector_data()->mutable_binary_data();
        byte_field->resize(bytes.size());
        memcpy(byte_field->data(), bytes.data(), bytes.size() * sizeof(uint8_t));
    }
    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

// GetVectorIDs RPC: list the vector ids stored in one segment of a collection.
::grpc::Status
GrpcRequestHandler::GetVectorIDs(::grpc::ServerContext* context, const ::milvus::grpc::GetVectorIDsParam* request,
                                 ::milvus::grpc::VectorIds* response) {
    CHECK_NULLPTR_RETURN(request);

    std::vector<int64_t> ids;
    Status status = request_handler_.GetVectorIDs(GetContext(context), request->collection_name(),
                                                  request->segment_name(), ids);

    if (!ids.empty()) {
        auto* id_field = response->mutable_vector_id_array();
        id_field->Resize(ids.size(), -1);
        memcpy(id_field->mutable_data(), ids.data(), ids.size() * sizeof(int64_t));
    }
    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

K
kun yu 已提交
452
// Search RPC: top-k vector search over a collection, optionally restricted to partition tags.
//
// Query vectors are flattened with CopyRowRecords, and the JSON value of the EXTRA_PARAM_KEY extra
// parameter (if present) is parsed. The result is written with ConstructResults. A client-supplied
// extra parameter that fails to parse is reported as SERVER_INVALID_ARGUMENT instead of letting the
// exception escape the handler.
::grpc::Status
GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request,
                           ::milvus::grpc::TopKQueryResult* response) {
    CHECK_NULLPTR_RETURN(request);

    LOG_SERVER_INFO_ << LogOut("[%s][%d] Search start in gRPC server", "search", 0);
    // step 1: copy vector data
    engine::VectorsData vectors;
    CopyRowRecords(request->query_record_array(), google::protobuf::RepeatedField<google::protobuf::int64>(), vectors);

    // step 2: partition tags
    std::vector<std::string> partitions;
    std::copy(request->partition_tag_array().begin(), request->partition_tag_array().end(),
              std::back_inserter(partitions));

    // step 3: parse extra parameters
    milvus::json json_params;
    for (int i = 0; i < request->extra_params_size(); i++) {
        const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
        if (extra.key() == EXTRA_PARAM_KEY) {
            try {
                json_params = json::parse(extra.value());
            } catch (const std::exception& e) {
                Status status(SERVER_INVALID_ARGUMENT, std::string("Invalid extra_params: ") + e.what());
                SET_RESPONSE(response->mutable_status(), status, context);
                return ::grpc::Status::OK;
            }
        }
    }

    // step 4: search vectors
    std::vector<std::string> file_ids;
    TopKQueryResult result;
    fiu_do_on("GrpcRequestHandler.Search.not_empty_file_ids", file_ids.emplace_back("test_file_id"));

    Status status = request_handler_.Search(GetContext(context), request->collection_name(), vectors, request->topk(),
                                            json_params, partitions, file_ids, result);

    // step 5: construct and return result
    ConstructResults(result, response);

    LOG_SERVER_INFO_ << LogOut("[%s][%d] Search done.", "search", 0);

    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

// SearchByID RPC: top-k search using an already-stored vector (by id) as the query.
//
// The JSON value of the EXTRA_PARAM_KEY extra parameter (if present) is parsed. A client-supplied
// value that fails to parse is reported as SERVER_INVALID_ARGUMENT instead of letting the exception
// escape the handler.
::grpc::Status
GrpcRequestHandler::SearchByID(::grpc::ServerContext* context, const ::milvus::grpc::SearchByIDParam* request,
                               ::milvus::grpc::TopKQueryResult* response) {
    CHECK_NULLPTR_RETURN(request);

    // step 1: partition tags
    std::vector<std::string> partitions(request->partition_tag_array().begin(),
                                        request->partition_tag_array().end());

    // step 2: parse extra parameters
    milvus::json json_params;
    for (int i = 0; i < request->extra_params_size(); i++) {
        const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
        if (extra.key() == EXTRA_PARAM_KEY) {
            try {
                json_params = json::parse(extra.value());
            } catch (const std::exception& e) {
                Status status(SERVER_INVALID_ARGUMENT, std::string("Invalid extra_params: ") + e.what());
                SET_RESPONSE(response->mutable_status(), status, context);
                return ::grpc::Status::OK;
            }
        }
    }

    // step 3: search vectors
    TopKQueryResult result;
    Status status = request_handler_.SearchByID(GetContext(context), request->collection_name(), request->id(),
                                                request->topk(), json_params, partitions, result);

    // step 4: construct and return result
    ConstructResults(result, response);

    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

// SearchInFiles RPC: like Search, but restricted to the listed file ids.
//
// The embedded search_param supplies the query vectors, partition tags, top-k and extra parameters.
// A client-supplied EXTRA_PARAM_KEY value that fails to parse is reported as SERVER_INVALID_ARGUMENT
// instead of letting the exception escape the handler.
::grpc::Status
GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request,
                                  ::milvus::grpc::TopKQueryResult* response) {
    CHECK_NULLPTR_RETURN(request);

    auto* search_request = &request->search_param();

    // step 1: copy vector data
    engine::VectorsData vectors;
    CopyRowRecords(search_request->query_record_array(), google::protobuf::RepeatedField<google::protobuf::int64>(),
                   vectors);

    // step 2: copy file id array
    std::vector<std::string> file_ids(request->file_id_array().begin(), request->file_id_array().end());

    // step 3: partition tags
    std::vector<std::string> partitions;
    std::copy(search_request->partition_tag_array().begin(), search_request->partition_tag_array().end(),
              std::back_inserter(partitions));

    // step 4: parse extra parameters
    milvus::json json_params;
    for (int i = 0; i < search_request->extra_params_size(); i++) {
        const ::milvus::grpc::KeyValuePair& extra = search_request->extra_params(i);
        if (extra.key() == EXTRA_PARAM_KEY) {
            try {
                json_params = json::parse(extra.value());
            } catch (const std::exception& e) {
                Status status(SERVER_INVALID_ARGUMENT, std::string("Invalid extra_params: ") + e.what());
                SET_RESPONSE(response->mutable_status(), status, context);
                return ::grpc::Status::OK;
            }
        }
    }

    // step 5: search vectors
    TopKQueryResult result;
    Status status = request_handler_.Search(GetContext(context), search_request->collection_name(), vectors,
                                            search_request->topk(), json_params, partitions, file_ids, result);

    // step 6: construct and return result
    ConstructResults(result, response);

    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

// DescribeCollection RPC: return the collection's name, dimension, index file size and metric type.
::grpc::Status
GrpcRequestHandler::DescribeCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                       ::milvus::grpc::CollectionSchema* response) {
    CHECK_NULLPTR_RETURN(request);

    CollectionSchema schema;
    Status status = request_handler_.DescribeCollection(GetContext(context), request->collection_name(), schema);

    response->set_collection_name(schema.collection_name_);
    response->set_dimension(schema.dimension_);
    response->set_index_file_size(schema.index_file_size_);
    response->set_metric_type(schema.metric_type_);
    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

// CountCollection RPC: return the collection's row count.
::grpc::Status
GrpcRequestHandler::CountCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                    ::milvus::grpc::CollectionRowCount* response) {
    CHECK_NULLPTR_RETURN(request);

    int64_t rows = 0;
    Status status = request_handler_.CountCollection(GetContext(context), request->collection_name(), rows);

    response->set_collection_row_count(rows);
    SET_RESPONSE(response->mutable_status(), status, context);
    return ::grpc::Status::OK;
}

// ShowCollections RPC: list the names of all collections.
::grpc::Status
GrpcRequestHandler::ShowCollections(::grpc::ServerContext* context, const ::milvus::grpc::Command* request,
                                    ::milvus::grpc::CollectionNameList* response) {
    CHECK_NULLPTR_RETURN(request);

    std::vector<std::string> names;
    Status status = request_handler_.ShowCollections(GetContext(context), names);

    for (const auto& name : names) {
        response->add_collection_names(name);
    }
    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

616
// ShowCollectionInfo RPC: return row counts per collection, partition and segment
// (filled by ConstructCollectionInfo).
::grpc::Status
GrpcRequestHandler::ShowCollectionInfo(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                       ::milvus::grpc::CollectionInfo* response) {
    CHECK_NULLPTR_RETURN(request);

    CollectionInfo info;
    Status status = request_handler_.ShowCollectionInfo(GetContext(context), request->collection_name(), info);

    ConstructCollectionInfo(info, response);
    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

K
kun yu 已提交
630
// Cmd RPC: run a server command string and return its textual output.
::grpc::Status
GrpcRequestHandler::Cmd(::grpc::ServerContext* context, const ::milvus::grpc::Command* request,
                        ::milvus::grpc::StringReply* response) {
    CHECK_NULLPTR_RETURN(request);

    std::string cmd_output;
    Status status = request_handler_.Cmd(GetContext(context), request->cmd(), cmd_output);

    response->set_string_reply(cmd_output);
    SET_RESPONSE(response->mutable_status(), status, context);
    return ::grpc::Status::OK;
}

Y
Yu Kun 已提交
643
// DeleteByID RPC: delete the listed vector ids from a collection.
::grpc::Status
GrpcRequestHandler::DeleteByID(::grpc::ServerContext* context, const ::milvus::grpc::DeleteByIDParam* request,
                               ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    // step 1: prepare id array
    std::vector<int64_t> vector_ids(request->id_array().begin(), request->id_array().end());

    // step 2: delete vector
    Status status = request_handler_.DeleteByID(GetContext(context), request->collection_name(), vector_ids);
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
}

// PreloadCollection RPC: ask the request handler to preload the named collection.
::grpc::Status
GrpcRequestHandler::PreloadCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                      ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    const auto& name = request->collection_name();
    Status status = request_handler_.PreloadCollection(GetContext(context), name);
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
}

// DescribeIndex RPC: return the collection's index type and its extra parameters.
// The extra parameters are returned as a single EXTRA_PARAM_KEY entry.
::grpc::Status
GrpcRequestHandler::DescribeIndex(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                  ::milvus::grpc::IndexParam* response) {
    CHECK_NULLPTR_RETURN(request);

    IndexParam index_param;
    Status status = request_handler_.DescribeIndex(GetContext(context), request->collection_name(), index_param);

    response->set_collection_name(index_param.collection_name_);
    response->set_index_type(index_param.index_type_);

    auto* extra = response->add_extra_params();
    extra->set_key(EXTRA_PARAM_KEY);
    extra->set_value(index_param.extra_params_);

    SET_RESPONSE(response->mutable_status(), status, context);
    return ::grpc::Status::OK;
}

// DropIndex RPC: drop the index on the named collection.
::grpc::Status
GrpcRequestHandler::DropIndex(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                              ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    const auto& name = request->collection_name();
    Status status = request_handler_.DropIndex(GetContext(context), name);
    SET_RESPONSE(response, status, context);

    return ::grpc::Status::OK;
}

G
groot 已提交
700 701 702
::grpc::Status
GrpcRequestHandler::CreatePartition(::grpc::ServerContext* context, const ::milvus::grpc::PartitionParam* request,
                                    ::milvus::grpc::Status* response) {
    // Asks request_handler_ to create partition `tag` inside `collection_name`; the handler's
    // outcome is reported through *response.
    CHECK_NULLPTR_RETURN(request);

    const std::string& collection_name = request->collection_name();
    const std::string& partition_tag = request->tag();
    Status create_status = request_handler_.CreatePartition(GetContext(context), collection_name, partition_tag);
    SET_RESPONSE(response, create_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::ShowPartitions(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                   ::milvus::grpc::PartitionList* response) {
    // Lists the partitions of a collection, appending each partition tag to
    // response->partition_tag_array(), and reports the handler's outcome in response->status().
    CHECK_NULLPTR_RETURN(request);

    std::vector<PartitionParam> partition_params;
    Status show_status =
        request_handler_.ShowPartitions(GetContext(context), request->collection_name(), partition_params);

    for (const PartitionParam& param : partition_params) {
        response->add_partition_tag_array(param.tag_);
    }

    SET_RESPONSE(response->mutable_status(), show_status, context);
    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::DropPartition(::grpc::ServerContext* context, const ::milvus::grpc::PartitionParam* request,
                                  ::milvus::grpc::Status* response) {
    // Asks request_handler_ to drop partition `tag` from `collection_name`; the handler's
    // outcome is reported through *response.
    CHECK_NULLPTR_RETURN(request);

    const std::string& collection_name = request->collection_name();
    const std::string& partition_tag = request->tag();
    Status drop_status = request_handler_.DropPartition(GetContext(context), collection_name, partition_tag);
    SET_RESPONSE(response, drop_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::Flush(::grpc::ServerContext* context, const ::milvus::grpc::FlushParam* request,
                          ::milvus::grpc::Status* response) {
    // Flushes every collection named in request->collection_name_array() (in request order)
    // through request_handler_ and reports the outcome through *response.
    CHECK_NULLPTR_RETURN(request);

    const auto& requested_names = request->collection_name_array();
    std::vector<std::string> collection_names(requested_names.begin(), requested_names.end());

    Status flush_status = request_handler_.Flush(GetContext(context), collection_names);
    SET_RESPONSE(response, flush_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::Compact(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                            ::milvus::grpc::Status* response) {
    // Forwards a compaction request for the named collection to request_handler_ and reports
    // the handler's outcome through *response.
    CHECK_NULLPTR_RETURN(request);

    const std::string& collection_name = request->collection_name();
    Status compact_status = request_handler_.Compact(GetContext(context), collection_name);
    SET_RESPONSE(response, compact_status, context);

    return ::grpc::Status::OK;
}

764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806
/*******************************************New Interface*********************************************/

::grpc::Status
GrpcRequestHandler::CreateHybridCollection(::grpc::ServerContext* context, const ::milvus::grpc::Mapping* request,
                                           ::milvus::grpc::Status* response) {
    // Splits the requested field mapping into three lists and forwards them to request_handler_:
    //   - field_types:       non-vector fields as (name, data type)
    //   - vector_dimensions: vector fields as (name, dimension)
    //   - field_params:      one (name, extra param value) entry per field, "" when none is given
    CHECK_NULLPTR_RETURN(request);

    std::vector<std::pair<std::string, engine::meta::hybrid::DataType>> field_types;
    std::vector<std::pair<std::string, uint64_t>> vector_dimensions;
    std::vector<std::pair<std::string, std::string>> field_params;

    for (const auto& field : request->fields()) {
        const auto& field_type = field.type();
        if (!field_type.has_vector_param()) {
            field_types.emplace_back(field.name(),
                                     static_cast<engine::meta::hybrid::DataType>(field_type.data_type()));
        } else {
            vector_dimensions.emplace_back(field.name(), field_type.vector_param().dimension());
        }

        // Only the first extra_param of a field is consumed at present.
        std::string extra_value;
        if (field.extra_params_size() != 0) {
            extra_value = field.extra_params(0).value();
        }
        field_params.emplace_back(field.name(), extra_value);
    }

    Status create_status = request_handler_.CreateHybridCollection(GetContext(context), request->collection_name(),
                                                                   field_types, vector_dimensions, field_params);
    SET_RESPONSE(response, create_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::InsertEntity(::grpc::ServerContext* context, const ::milvus::grpc::HInsertParam* request,
                                 ::milvus::grpc::HEntityIDs* response) {
    // Unpacks an HInsertParam, forwards it to request_handler_.InsertEntity(), then copies the
    // id_array_ of the resulting vector entry into response->entity_id_array().
    //
    // Request layout as consumed here: entities().field_names() lists the attribute fields first
    // and the vector field last; entities().attr_records() is copied byte-for-byte; every
    // result_values(i) entry is keyed under the last field name.
    CHECK_NULLPTR_RETURN(request);

    const auto& entities = request->entities();
    const int field_size = entities.field_names_size();

    // Without any field names there is no vector field name to key on, and field_size - 1 would be
    // -1 (a huge resize of field_names and an out-of-range field_names(-1) read). Reject instead.
    if (field_size == 0) {
        Status status{SERVER_INVALID_ARGUMENT, "Entity field names are empty"};
        SET_RESPONSE(response->mutable_status(), status, context);
        return ::grpc::Status::OK;
    }

    auto attr_size = entities.attr_records().size();
    std::vector<uint8_t> attr_values(attr_size, 0);
    if (attr_size > 0) {
        memcpy(attr_values.data(), entities.attr_records().data(), attr_size);
    }

    uint64_t row_num = entities.row_num();

    // All field names except the last one are attribute fields.
    std::vector<std::string> field_names;
    field_names.reserve(field_size - 1);
    for (int i = 0; i < field_size - 1; ++i) {
        field_names.push_back(entities.field_names(i));
    }

    const std::string& vector_field_name = entities.field_names(field_size - 1);
    std::unordered_map<std::string, engine::VectorsData> vector_datas;
    for (int i = 0; i < entities.result_values_size(); ++i) {
        engine::VectorsData vectors;
        CopyRowRecords(entities.result_values(i).vector_value().value(), request->entity_id_array(), vectors);
        // NOTE(review): every entry uses the same key and emplace never overwrites, so only the first
        // result_values entry is kept. Confirm whether multiple vector fields should be supported.
        vector_datas.emplace(vector_field_name, std::move(vectors));
    }

    std::string collection_name = request->collection_name();
    std::string partition_tag = request->partition_tag();

    Status status = request_handler_.InsertEntity(GetContext(context), collection_name, partition_tag, row_num,
                                                  field_names, attr_values, vector_datas);

    // begin() must not be dereferenced on an empty map (no result_values in the request).
    if (!vector_datas.empty()) {
        const auto& id_array = vector_datas.begin()->second.id_array_;
        auto* entity_ids = response->mutable_entity_id_array();
        entity_ids->Resize(static_cast<int>(id_array.size()), 0);
        if (!id_array.empty()) {
            memcpy(entity_ids->mutable_data(), id_array.data(), id_array.size() * sizeof(int64_t));
        }
    }

    SET_RESPONSE(response->mutable_status(), status, context);
    return ::grpc::Status::OK;
}

void
844
DeSerialization(const ::milvus::grpc::GeneralQuery& general_query, query::BooleanQueryPtr& boolean_clause) {
845
    if (general_query.has_boolean_query()) {
846
        boolean_clause->SetOccur((query::Occur)general_query.boolean_query().occur());
847 848
        for (uint64_t i = 0; i < general_query.boolean_query().general_query_size(); ++i) {
            if (general_query.boolean_query().general_query(i).has_boolean_query()) {
849
                query::BooleanQueryPtr query = std::make_shared<query::BooleanQuery>();
850 851 852 853 854 855 856 857 858
                DeSerialization(general_query.boolean_query().general_query(i), query);
                boolean_clause->AddBooleanQuery(query);
            } else {
                auto leaf_query = std::make_shared<query::LeafQuery>();
                auto query = general_query.boolean_query().general_query(i);
                if (query.has_term_query()) {
                    query::TermQueryPtr term_query = std::make_shared<query::TermQuery>();
                    term_query->field_name = query.term_query().field_name();
                    term_query->boost = query.term_query().boost();
859 860 861
                    auto size = query.term_query().values().size();
                    term_query->field_value.resize(size);
                    memcpy(term_query->field_value.data(), query.term_query().values().data(), size);
862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950
                    leaf_query->term_query = term_query;
                    boolean_clause->AddLeafQuery(leaf_query);
                }
                if (query.has_range_query()) {
                    query::RangeQueryPtr range_query = std::make_shared<query::RangeQuery>();
                    range_query->field_name = query.range_query().field_name();
                    range_query->boost = query.range_query().boost();
                    range_query->compare_expr.resize(query.range_query().operand_size());
                    for (uint64_t j = 0; j < query.range_query().operand_size(); ++j) {
                        range_query->compare_expr[j].compare_operator =
                            query::CompareOperator(query.range_query().operand(j).operator_());
                        range_query->compare_expr[j].operand = query.range_query().operand(j).operand();
                    }
                    leaf_query->range_query = range_query;
                    boolean_clause->AddLeafQuery(leaf_query);
                }
                if (query.has_vector_query()) {
                    query::VectorQueryPtr vector_query = std::make_shared<query::VectorQuery>();

                    engine::VectorsData vectors;
                    CopyRowRecords(query.vector_query().records(),
                                   google::protobuf::RepeatedField<google::protobuf::int64>(), vectors);

                    vector_query->query_vector.float_data = vectors.float_data_;
                    vector_query->query_vector.binary_data = vectors.binary_data_;

                    vector_query->boost = query.vector_query().query_boost();
                    vector_query->field_name = query.vector_query().field_name();
                    vector_query->topk = query.vector_query().topk();

                    milvus::json json_params;
                    for (int j = 0; j < query.vector_query().extra_params_size(); j++) {
                        const ::milvus::grpc::KeyValuePair& extra = query.vector_query().extra_params(j);
                        if (extra.key() == EXTRA_PARAM_KEY) {
                            json_params = json::parse(extra.value());
                        }
                    }
                    vector_query->extra_params = json_params;
                    leaf_query->vector_query = vector_query;
                    boolean_clause->AddLeafQuery(leaf_query);
                }
            }
        }
    }
}

::grpc::Status
GrpcRequestHandler::HybridSearch(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParam* request,
                                 ::milvus::grpc::TopKQueryResult* response) {
    // Runs a hybrid search:
    //   1. deserialize request->general_query() into a BooleanQuery tree,
    //   2. lower it into a binary query tree and validate that tree,
    //   3. forward it, together with the partition tags, to request_handler_,
    //   4. pack the search result into *response.
    // Failures are reported through response->status(); once past the null-request check the
    // RPC itself returns ::grpc::Status::OK.
    CHECK_NULLPTR_RETURN(request);

    context::HybridSearchContextPtr hybrid_search_context = std::make_shared<context::HybridSearchContext>();

    // Step 1
    query::BooleanQueryPtr root_clause = std::make_shared<query::BooleanQuery>();
    DeSerialization(request->general_query(), root_clause);

    // Step 2
    query::GeneralQueryPtr general_query = std::make_shared<query::GeneralQuery>();
    general_query->bin = std::make_shared<query::BinaryQuery>();
    query::GenBinaryQuery(root_clause, general_query->bin);

    if (!query::ValidateBinaryQuery(general_query->bin)) {
        Status invalid_status{SERVER_INVALID_BINARY_QUERY, "Generate wrong binary query tree"};
        SET_RESPONSE(response->mutable_status(), invalid_status, context);
        return ::grpc::Status::OK;
    }

    hybrid_search_context->general_query_ = general_query;

    // Step 3
    const auto& requested_tags = request->partition_tag_array();
    std::vector<std::string> partition_list(requested_tags.begin(), requested_tags.end());

    TopKQueryResult result;
    Status search_status = request_handler_.HybridSearch(GetContext(context), hybrid_search_context,
                                                         request->collection_name(), partition_list, general_query,
                                                         result);

    // Step 4: results are packed even when the search failed; the status tells the client.
    ConstructResults(result, response);
    SET_RESPONSE(response->mutable_status(), search_status, context);

    return ::grpc::Status::OK;
}

S
starlord 已提交
951 952 953
}  // namespace grpc
}  // namespace server
}  // namespace milvus