GrpcRequestHandler.cpp 39.0 KB
Newer Older
1
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
11 12 13

#include "server/grpc_impl/GrpcRequestHandler.h"

S
shengjh 已提交
14
#include <fiu-local.h>
J
JinHai-CN 已提交
15
#include <algorithm>
Z
Zhiru Zhu 已提交
16
#include <memory>
17
#include <string>
Z
Zhiru Zhu 已提交
18
#include <unordered_map>
19
#include <utility>
Z
Zhiru Zhu 已提交
20 21
#include <vector>

22 23
#include "context/HybridSearchContext.h"
#include "query/BinaryQuery.h"
Z
Zhiru Zhu 已提交
24 25
#include "tracing/TextMapCarrier.h"
#include "tracing/TracerUtil.h"
B
BossZou 已提交
26
#include "utils/Log.h"
27
#include "utils/LogUtil.h"
K
kun yu 已提交
28 29 30 31
#include "utils/TimeRecorder.h"

namespace milvus {
namespace server {
Y
Yu Kun 已提交
32
namespace grpc {
K
kun yu 已提交
33

34 35 36 37 38 39 40 41 42 43 44 45 46
::milvus::grpc::ErrorCode
ErrorMap(ErrorCode code) {
    // Translation table from server-side error codes to the gRPC wire enum.
    // Codes that have no entry here are reported to clients as UNEXPECTED_ERROR.
    static const std::map<ErrorCode, ::milvus::grpc::ErrorCode> code_map = {
        {SERVER_UNEXPECTED_ERROR, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
        {SERVER_UNSUPPORTED_ERROR, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
        {SERVER_NULL_POINTER, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
        {SERVER_INVALID_ARGUMENT, ::milvus::grpc::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_FILE_NOT_FOUND, ::milvus::grpc::ErrorCode::FILE_NOT_FOUND},
        {SERVER_NOT_IMPLEMENT, ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR},
        {SERVER_CANNOT_CREATE_FOLDER, ::milvus::grpc::ErrorCode::CANNOT_CREATE_FOLDER},
        {SERVER_CANNOT_CREATE_FILE, ::milvus::grpc::ErrorCode::CANNOT_CREATE_FILE},
        {SERVER_CANNOT_DELETE_FOLDER, ::milvus::grpc::ErrorCode::CANNOT_DELETE_FOLDER},
        {SERVER_CANNOT_DELETE_FILE, ::milvus::grpc::ErrorCode::CANNOT_DELETE_FILE},
        {SERVER_COLLECTION_NOT_EXIST, ::milvus::grpc::ErrorCode::COLLECTION_NOT_EXISTS},
        {SERVER_INVALID_COLLECTION_NAME, ::milvus::grpc::ErrorCode::ILLEGAL_COLLECTION_NAME},
        {SERVER_INVALID_COLLECTION_DIMENSION, ::milvus::grpc::ErrorCode::ILLEGAL_DIMENSION},
        {SERVER_INVALID_VECTOR_DIMENSION, ::milvus::grpc::ErrorCode::ILLEGAL_DIMENSION},

        {SERVER_INVALID_INDEX_TYPE, ::milvus::grpc::ErrorCode::ILLEGAL_INDEX_TYPE},
        {SERVER_INVALID_ROWRECORD, ::milvus::grpc::ErrorCode::ILLEGAL_ROWRECORD},
        {SERVER_INVALID_ROWRECORD_ARRAY, ::milvus::grpc::ErrorCode::ILLEGAL_ROWRECORD},
        {SERVER_INVALID_TOPK, ::milvus::grpc::ErrorCode::ILLEGAL_TOPK},
        {SERVER_INVALID_NPROBE, ::milvus::grpc::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_INVALID_INDEX_NLIST, ::milvus::grpc::ErrorCode::ILLEGAL_NLIST},
        {SERVER_INVALID_INDEX_METRIC_TYPE, ::milvus::grpc::ErrorCode::ILLEGAL_METRIC_TYPE},
        {SERVER_INVALID_INDEX_FILE_SIZE, ::milvus::grpc::ErrorCode::ILLEGAL_ARGUMENT},
        {SERVER_ILLEGAL_VECTOR_ID, ::milvus::grpc::ErrorCode::ILLEGAL_VECTOR_ID},
        {SERVER_ILLEGAL_SEARCH_RESULT, ::milvus::grpc::ErrorCode::ILLEGAL_SEARCH_RESULT},
        {SERVER_CACHE_FULL, ::milvus::grpc::ErrorCode::CACHE_FAILED},
        {DB_META_TRANSACTION_FAILED, ::milvus::grpc::ErrorCode::META_FAILED},
        {SERVER_BUILD_INDEX_ERROR, ::milvus::grpc::ErrorCode::BUILD_INDEX_ERROR},
        {SERVER_OUT_OF_MEMORY, ::milvus::grpc::ErrorCode::OUT_OF_MEMORY},
    };

    // Single lookup; fall back to the generic error for unmapped codes.
    const auto mapped = code_map.find(code);
    return (mapped == code_map.end()) ? ::milvus::grpc::ErrorCode::UNEXPECTED_ERROR : mapped->second;
}

G
groot 已提交
75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
namespace {
void
CopyRowRecords(const google::protobuf::RepeatedPtrField<::milvus::grpc::RowRecord>& grpc_records,
               const google::protobuf::RepeatedField<google::protobuf::int64>& grpc_id_array,
               engine::VectorsData& vectors) {
    // Flattens the gRPC row records (and optional id array) into `vectors`.
    //
    // step 1: size the flattened buffers from the per-record payloads.
    int64_t float_data_size = 0, binary_data_size = 0;
    for (const auto& record : grpc_records) {
        float_data_size += record.float_data_size();
        binary_data_size += record.binary_data().size();
    }

    // Both buffers are sized up front, but only one payload kind is copied:
    // float data wins whenever any record carries it, otherwise binary data is used.
    std::vector<float> float_array(float_data_size, 0.0f);
    std::vector<uint8_t> binary_array(binary_data_size, 0);
    if (float_data_size > 0) {
        // Copy through an iterator cursor instead of memcpy into &array[offset]; the latter
        // indexes operator[] at size() (undefined behaviour) when trailing records are empty.
        auto float_cursor = float_array.begin();
        for (const auto& record : grpc_records) {
            float_cursor = std::copy(record.float_data().begin(), record.float_data().end(), float_cursor);
        }
    } else if (binary_data_size > 0) {
        auto binary_cursor = binary_array.begin();
        for (const auto& record : grpc_records) {
            const std::string& bytes = record.binary_data();
            binary_cursor = std::copy(bytes.begin(), bytes.end(), binary_cursor);
        }
    }

    // step 2: copy id array (empty when the caller supplied no ids)
    std::vector<int64_t> id_array(grpc_id_array.begin(), grpc_id_array.end());

    // step 3: construct vectors
    vectors.vector_count_ = grpc_records.size();
    vectors.float_data_.swap(float_array);
    vectors.binary_data_.swap(binary_array);
    vectors.id_array_.swap(id_array);
}

116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150
void
ConstructResults(const TopKQueryResult& result, ::milvus::grpc::TopKQueryResult* response) {
    // Copies a server-side top-k result into the gRPC reply message.
    // A null reply pointer is tolerated and simply ignored.
    if (response == nullptr) {
        return;
    }

    response->set_row_num(result.row_num_);

    // ids: flat array copied element-for-element.
    auto* grpc_ids = response->mutable_ids();
    grpc_ids->Resize(static_cast<int>(result.id_list_.size()), 0);
    std::copy(result.id_list_.begin(), result.id_list_.end(), grpc_ids->mutable_data());

    // distances: flat array copied element-for-element.
    auto* grpc_distances = response->mutable_distances();
    grpc_distances->Resize(static_cast<int>(result.distance_list_.size()), 0.0);
    std::copy(result.distance_list_.begin(), result.distance_list_.end(), grpc_distances->mutable_data());
}

void
ConstructPartitionStat(const PartitionStat& partition_stat, ::milvus::grpc::PartitionStat* grpc_partition_stat) {
    // Mirrors one partition's statistics (tag, row count, per-segment stats) into the gRPC message.
    // A null destination is tolerated and simply ignored.
    if (grpc_partition_stat == nullptr) {
        return;
    }

    grpc_partition_stat->set_tag(partition_stat.tag_);
    grpc_partition_stat->set_total_row_count(partition_stat.total_row_num_);

    // One SegmentStat entry is appended per engine segment, preserving order.
    for (const auto& segment : partition_stat.segments_stat_) {
        auto* grpc_segment = grpc_partition_stat->add_segments_stat();
        grpc_segment->set_segment_name(segment.name_);
        grpc_segment->set_row_count(segment.row_num_);
        grpc_segment->set_index_name(segment.index_name_);
        grpc_segment->set_data_size(segment.data_size_);
    }
}

void
G
groot 已提交
151
ConstructCollectionInfo(const CollectionInfo& collection_info, ::milvus::grpc::CollectionInfo* response) {
152 153 154 155
    if (!response) {
        return;
    }

156
    response->set_total_row_count(collection_info.total_row_num_);
157

158
    for (auto& partition_stat : collection_info.partitions_stat_) {
159 160 161 162 163
        ::milvus::grpc::PartitionStat* grpc_partiton_stat = response->mutable_partitions_stat()->Add();
        ConstructPartitionStat(partition_stat, grpc_partiton_stat);
    }
}

G
groot 已提交
164 165
}  // namespace

166 167 168 169 170 171 172 173 174 175 176 177 178 179 180
namespace {

// Metadata key under which the request id is attached to, and read back from, the server context.
#define REQ_ID ("request_id")

// Counter backing auto-generated request ids (namespace-scope storage, so it starts at zero).
std::atomic<int64_t> _sequential_id;

// Returns the current counter value and atomically advances it by one.
int64_t
get_sequential_id() {
    return _sequential_id.fetch_add(1);
}

// Attaches `request_id` to `context` under REQ_ID via AddInitialMetadata.
// Logs an error and does nothing when `context` is null.
void
set_request_id(::grpc::ServerContext* context, const std::string& request_id) {
    if (context == nullptr) {
        LOG_SERVER_ERROR_ << "set_request_id: grpc::ServerContext is nullptr" << std::endl;
        return;
    }

    context->AddInitialMetadata(REQ_ID, request_id);
}

// Looks REQ_ID up in the context's server_metadata(); returns "INVALID_ID" (and logs) when the
// context is null or the key is absent. Assumes server_metadata() reflects entries added by
// set_request_id — TODO confirm against the gRPC build in use.
std::string
get_request_id(::grpc::ServerContext* context) {
    if (context == nullptr) {
        LOG_SERVER_ERROR_ << "get_request_id: grpc::ServerContext is nullptr" << std::endl;
        return "INVALID_ID";
    }

    const auto& server_metadata = context->server_metadata();
    const auto found = server_metadata.find(REQ_ID);
    if (found != server_metadata.end()) {
        return found->second.data();
    }

    LOG_SERVER_ERROR_ << std::string(REQ_ID) << " not found in grpc.server_metadata" << std::endl;
    return "INVALID_ID";
}

}  // namespace

Z
Zhiru Zhu 已提交
210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
GrpcRequestHandler::GrpcRequestHandler(const std::shared_ptr<opentracing::Tracer>& tracer)
    : tracer_(tracer), random_num_generator_() {
    // Re-seed the id generator from std::random_device so generated ids differ across instances.
    random_num_generator_.seed(std::random_device{}());
}

void
GrpcRequestHandler::OnPostRecvInitialMetaData(
    ::grpc::experimental::ServerRpcInfo* server_rpc_info,
    ::grpc::experimental::InterceptorBatchMethods* interceptor_batch_methods) {
    // Interceptor hook run after the client's initial metadata arrives. It:
    //   1. starts a tracing span as a child of the client's trace context (if present);
    //   2. settles on a unique request id (client-supplied, de-duplicated, or generated);
    //   3. publishes that id on the server context and registers the per-request Context.
    std::unordered_map<std::string, std::string> text_map;
    auto* metadata_map = interceptor_batch_methods->GetRecvInitialMetadata();
    auto context_kv = metadata_map->find(tracing::TracerUtil::GetTraceContextHeaderName());
    if (context_kv != metadata_map->end()) {
        text_map[std::string(context_kv->first.data(), context_kv->first.length())] =
            std::string(context_kv->second.data(), context_kv->second.length());
    }
    // test debug mode
    //    if (std::string(server_rpc_info->method()).find("Search") != std::string::npos) {
    //        text_map["demo-debug-id"] = "debug-id";
    //    }

    tracing::TextMapCarrier carrier{text_map};
    auto span_context_maybe = tracer_->Extract(carrier);
    if (!span_context_maybe) {
        std::cerr << span_context_maybe.error().message() << std::endl;
        return;
    }
    auto span = tracer_->StartSpan(server_rpc_info->method(), {opentracing::ChildOf(span_context_maybe->get())});

    auto server_context = server_rpc_info->server_context();
    const auto& client_metadata = server_context->client_metadata();

    // If the client provides request_id in metadata, milvus uses it (suffixed with _N when
    // already in use); otherwise milvus generates a sequential id.
    std::string request_id;
    auto request_id_kv = client_metadata.find("request_id");
    if (request_id_kv != client_metadata.end()) {
        // Metadata values are grpc::string_ref, which is not NUL-terminated: copy exactly length() bytes.
        request_id = std::string(request_id_kv->second.data(), request_id_kv->second.length());
        LOG_SERVER_DEBUG_ << "client provide request_id: " << request_id;

        std::lock_guard<std::mutex> lock(context_map_mutex_);
        if (context_map_.find(request_id) != context_map_.end()) {
            // The id is taken by an in-flight request: find the first free request_id_N and
            // switch to it, so this request's context is stored under its own key.
            int64_t suffix = 1;
            std::string try_request_id;
            do {
                try_request_id = request_id + "_" + std::to_string(suffix++);
            } while (context_map_.find(try_request_id) != context_map_.end());
            request_id = try_request_id;
        }
        // Reserve the key until SetContext below stores the real Context.
        context_map_[request_id] = nullptr;
    } else {
        request_id = std::to_string(get_sequential_id());
        LOG_SERVER_DEBUG_ << "milvus generate request_id: " << request_id;
    }

    // Publish the chosen id on the server context in BOTH branches: SetContext, GetContext and
    // OnPreSendMessage all recover the key via get_request_id(server_context). Without this, a
    // client-supplied id resolved to "INVALID_ID" and its reserved entry was never released.
    set_request_id(server_context, request_id);

    auto trace_context = std::make_shared<tracing::TraceContext>(span);
    auto context = std::make_shared<Context>(request_id);
    context->SetTraceContext(trace_context);
    SetContext(server_context, context);
}

void
GrpcRequestHandler::OnPreSendMessage(::grpc::experimental::ServerRpcInfo* server_rpc_info,
                                     ::grpc::experimental::InterceptorBatchMethods* interceptor_batch_methods) {
    // Hook run before a response message is sent: finish the request's tracing span and
    // remove its context_map_ entry so the map does not grow without bound.
    std::lock_guard<std::mutex> lock(context_map_mutex_);
    auto request_id = get_request_id(server_rpc_info->server_context());

    auto context_it = context_map_.find(request_id);
    if (context_it == context_map_.end()) {
        // error
        LOG_SERVER_ERROR_ << "request_id " << request_id << " not found in context_map_";
        return;
    }

    // The entry may still be the nullptr placeholder reserved in OnPostRecvInitialMetaData
    // (if SetContext never replaced it); dereferencing it unconditionally would crash.
    const auto& request_context = context_it->second;
    if (request_context != nullptr && request_context->GetTraceContext() != nullptr) {
        request_context->GetTraceContext()->GetSpan()->Finish();
    }
    context_map_.erase(context_it);
}

J
Jin Hai 已提交
296
std::shared_ptr<Context>
Z
Zhiru Zhu 已提交
297
GrpcRequestHandler::GetContext(::grpc::ServerContext* server_context) {
298
    std::lock_guard<std::mutex> lock(context_map_mutex_);
299 300
    auto request_id = get_request_id(server_context);
    if (context_map_.find(request_id) == context_map_.end()) {
301
        LOG_SERVER_ERROR_ << "GetContext: request_id " << request_id << " not found in context_map_";
302 303 304
        return nullptr;
    }
    return context_map_[request_id];
Z
Zhiru Zhu 已提交
305 306 307 308
}

void
GrpcRequestHandler::SetContext(::grpc::ServerContext* server_context, const std::shared_ptr<Context>& context) {
    // Stores `context` under the request id published on `server_context`, overwriting any
    // existing entry (including a nullptr placeholder).
    std::lock_guard<std::mutex> guard(context_map_mutex_);
    const std::string key = get_request_id(server_context);
    auto& slot = context_map_[key];
    slot = context;
}

uint64_t
GrpcRequestHandler::random_id() const {
    // Draws from the shared generator under random_mutex_, retrying until the value is non-zero
    // (zero is never returned).
    std::lock_guard<std::mutex> guard(random_mutex_);
    for (;;) {
        auto candidate = random_num_generator_();
        if (candidate != 0) {
            return candidate;
        }
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////

K
kun yu 已提交
326
::grpc::Status
GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionSchema* request,
                                     ::milvus::grpc::Status* response) {
    // Forwards the requested schema to the request handler and reports its status.
    // The gRPC call itself always succeeds; errors travel inside `response`.
    CHECK_NULLPTR_RETURN(request);

    Status create_status = request_handler_.CreateCollection(GetContext(context), request->collection_name(),
                                                             request->dimension(), request->index_file_size(),
                                                             request->metric_type());
    SET_RESPONSE(response, create_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::HasCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                  ::milvus::grpc::BoolReply* response) {
    // Replies with whether the named collection exists, plus the handler's status.
    CHECK_NULLPTR_RETURN(request);

    bool exists = false;  // initialized before the handler call fills it
    Status lookup_status = request_handler_.HasCollection(GetContext(context), request->collection_name(), exists);

    response->set_bool_reply(exists);
    SET_RESPONSE(response->mutable_status(), lookup_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::DropCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                   ::milvus::grpc::Status* response) {
    // Asks the request handler to drop the named collection and reports its status.
    CHECK_NULLPTR_RETURN(request);

    Status drop_status = request_handler_.DropCollection(GetContext(context), request->collection_name());
    SET_RESPONSE(response, drop_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus::grpc::IndexParam* request,
                                ::milvus::grpc::Status* response) {
    // Builds an index on the collection using the client's index type and JSON extra params.
    CHECK_NULLPTR_RETURN(request);

    // Parse the JSON carried under EXTRA_PARAM_KEY (the last matching entry wins). The text comes
    // straight from the client, so a malformed document is reported as an invalid argument instead
    // of letting json::parse's exception escape the handler. An escaping exception means no reply
    // message is sent, so the request's span/context entry would never be cleaned up.
    milvus::json json_params;
    for (int i = 0; i < request->extra_params_size(); i++) {
        const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
        if (extra.key() == EXTRA_PARAM_KEY) {
            try {
                json_params = json::parse(extra.value());
            } catch (const json::exception& e) {
                Status parse_status(SERVER_INVALID_ARGUMENT, std::string("Invalid extra_params: ") + e.what());
                SET_RESPONSE(response, parse_status, context);
                return ::grpc::Status::OK;
            }
        }
    }

    Status status = request_handler_.CreateIndex(GetContext(context), request->collection_name(), request->index_type(),
                                                 json_params);

    SET_RESPONSE(response, status, context);
    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc::InsertParam* request,
                           ::milvus::grpc::VectorIds* response) {
    // Inserts the request's vectors and echoes the resulting id array back to the client.
    CHECK_NULLPTR_RETURN(request);

    LOG_SERVER_INFO_ << LogOut("[%s][%d] Start insert.", "insert", 0);

    // Flatten the row records (and any caller-supplied ids) into engine form.
    engine::VectorsData insert_vectors;
    CopyRowRecords(request->row_record_array(), request->row_id_array(), insert_vectors);

    Status insert_status = request_handler_.Insert(GetContext(context), request->collection_name(), insert_vectors,
                                                   request->partition_tag());

    // id_array_ is read back after the call (presumably holding the ids assigned by the insert).
    const auto& assigned_ids = insert_vectors.id_array_;
    auto* id_field = response->mutable_vector_id_array();
    id_field->Resize(static_cast<int>(assigned_ids.size()), 0);
    std::copy(assigned_ids.begin(), assigned_ids.end(), id_field->mutable_data());

    LOG_SERVER_INFO_ << LogOut("[%s][%d] Insert done.", "insert", 0);
    SET_RESPONSE(response->mutable_status(), insert_status, context);
    return ::grpc::Status::OK;
}

409 410 411 412 413 414 415
::grpc::Status
GrpcRequestHandler::GetVectorByID(::grpc::ServerContext* context, const ::milvus::grpc::VectorIdentity* request,
                                  ::milvus::grpc::VectorData* response) {
    // Fetches one vector by id. Float data takes precedence over binary data. The vector_data
    // sub-message is only touched when there is something to copy.
    CHECK_NULLPTR_RETURN(request);

    // The handler API takes a batch of ids; wrap the single requested id.
    std::vector<int64_t> id_batch = {request->id()};
    engine::VectorsData fetched;
    Status fetch_status =
        request_handler_.GetVectorByID(GetContext(context), request->collection_name(), id_batch, fetched);

    if (!fetched.float_data_.empty()) {
        auto* grpc_floats = response->mutable_vector_data()->mutable_float_data();
        grpc_floats->Resize(static_cast<int>(fetched.float_data_.size()), 0);
        std::copy(fetched.float_data_.begin(), fetched.float_data_.end(), grpc_floats->mutable_data());
    } else if (!fetched.binary_data_.empty()) {
        auto* grpc_bytes = response->mutable_vector_data()->mutable_binary_data();
        grpc_bytes->assign(fetched.binary_data_.begin(), fetched.binary_data_.end());
    }
    SET_RESPONSE(response->mutable_status(), fetch_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::GetVectorIDs(::grpc::ServerContext* context, const ::milvus::grpc::GetVectorIDsParam* request,
                                 ::milvus::grpc::VectorIds* response) {
    // Lists the vector ids stored in one segment of a collection.
    CHECK_NULLPTR_RETURN(request);

    std::vector<int64_t> segment_ids;
    Status ids_status = request_handler_.GetVectorIDs(GetContext(context), request->collection_name(),
                                                      request->segment_name(), segment_ids);

    if (!segment_ids.empty()) {
        auto* id_field = response->mutable_vector_id_array();
        id_field->Resize(static_cast<int>(segment_ids.size()), -1);
        std::copy(segment_ids.begin(), segment_ids.end(), id_field->mutable_data());
    }
    SET_RESPONSE(response->mutable_status(), ids_status, context);

    return ::grpc::Status::OK;
}

K
kun yu 已提交
452
::grpc::Status
GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc::SearchParam* request,
                           ::milvus::grpc::TopKQueryResult* response) {
    // Top-k vector search over a collection (optionally restricted to partition tags).
    CHECK_NULLPTR_RETURN(request);

    LOG_SERVER_INFO_ << LogOut("[%s][%d] Search start in gRPC server", "search", 0);
    // step 1: copy vector data
    engine::VectorsData vectors;
    CopyRowRecords(request->query_record_array(), google::protobuf::RepeatedField<google::protobuf::int64>(), vectors);

    // step 2: partition tags
    std::vector<std::string> partitions;
    std::copy(request->partition_tag_array().begin(), request->partition_tag_array().end(),
              std::back_inserter(partitions));

    // step 3: parse extra parameters. The JSON is client-supplied, so a malformed document is
    // reported as an invalid argument rather than letting json::parse's exception escape.
    milvus::json json_params;
    for (int i = 0; i < request->extra_params_size(); i++) {
        const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
        if (extra.key() == EXTRA_PARAM_KEY) {
            try {
                json_params = json::parse(extra.value());
            } catch (const json::exception& e) {
                Status parse_status(SERVER_INVALID_ARGUMENT, std::string("Invalid extra_params: ") + e.what());
                SET_RESPONSE(response->mutable_status(), parse_status, context);
                return ::grpc::Status::OK;
            }
        }
    }

    // step 4: search vectors
    std::vector<std::string> file_ids;
    TopKQueryResult result;
    fiu_do_on("GrpcRequestHandler.Search.not_empty_file_ids", file_ids.emplace_back("test_file_id"));

    Status status = request_handler_.Search(GetContext(context), request->collection_name(), vectors, request->topk(),
                                            json_params, partitions, file_ids, result);

    // step 5: construct and return result
    ConstructResults(result, response);

    LOG_SERVER_INFO_ << LogOut("[%s][%d] Search done.", "search", 0);

    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::SearchByID(::grpc::ServerContext* context, const ::milvus::grpc::SearchByIDParam* request,
                               ::milvus::grpc::TopKQueryResult* response) {
    // Top-k search using a stored vector (identified by id) as the query.
    CHECK_NULLPTR_RETURN(request);

    // step 1: partition tags
    std::vector<std::string> partitions;
    for (auto& partition : request->partition_tag_array()) {
        partitions.emplace_back(partition);
    }

    // step 2: parse extra parameters. The JSON is client-supplied, so a malformed document is
    // reported as an invalid argument rather than letting json::parse's exception escape.
    milvus::json json_params;
    for (int i = 0; i < request->extra_params_size(); i++) {
        const ::milvus::grpc::KeyValuePair& extra = request->extra_params(i);
        if (extra.key() == EXTRA_PARAM_KEY) {
            try {
                json_params = json::parse(extra.value());
            } catch (const json::exception& e) {
                Status parse_status(SERVER_INVALID_ARGUMENT, std::string("Invalid extra_params: ") + e.what());
                SET_RESPONSE(response->mutable_status(), parse_status, context);
                return ::grpc::Status::OK;
            }
        }
    }

    // step 3: search vectors
    TopKQueryResult result;
    Status status = request_handler_.SearchByID(GetContext(context), request->collection_name(), request->id(),
                                                request->topk(), json_params, partitions, result);

    // step 4: construct and return result
    ConstructResults(result, response);

    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus::grpc::SearchInFilesParam* request,
                                  ::milvus::grpc::TopKQueryResult* response) {
    // Top-k search restricted to the given file ids; the nested search_param carries the query.
    CHECK_NULLPTR_RETURN(request);

    auto* search_request = &request->search_param();

    // step 1: copy vector data
    engine::VectorsData vectors;
    CopyRowRecords(search_request->query_record_array(), google::protobuf::RepeatedField<google::protobuf::int64>(),
                   vectors);

    // step 2: copy file id array
    std::vector<std::string> file_ids;
    for (auto& file_id : request->file_id_array()) {
        file_ids.emplace_back(file_id);
    }

    // step 3: partition tags
    std::vector<std::string> partitions;
    for (auto& partition : search_request->partition_tag_array()) {
        partitions.emplace_back(partition);
    }

    // step 4: parse extra parameters. The JSON is client-supplied, so a malformed document is
    // reported as an invalid argument rather than letting json::parse's exception escape.
    milvus::json json_params;
    for (int i = 0; i < search_request->extra_params_size(); i++) {
        const ::milvus::grpc::KeyValuePair& extra = search_request->extra_params(i);
        if (extra.key() == EXTRA_PARAM_KEY) {
            try {
                json_params = json::parse(extra.value());
            } catch (const json::exception& e) {
                Status parse_status(SERVER_INVALID_ARGUMENT, std::string("Invalid extra_params: ") + e.what());
                SET_RESPONSE(response->mutable_status(), parse_status, context);
                return ::grpc::Status::OK;
            }
        }
    }

    // step 5: search vectors
    TopKQueryResult result;
    Status status = request_handler_.Search(GetContext(context), search_request->collection_name(), vectors,
                                            search_request->topk(), json_params, partitions, file_ids, result);

    // step 6: construct and return result
    ConstructResults(result, response);

    SET_RESPONSE(response->mutable_status(), status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::DescribeCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                       ::milvus::grpc::CollectionSchema* response) {
    // Returns the stored schema (name, dimension, index file size, metric) of a collection.
    CHECK_NULLPTR_RETURN(request);

    CollectionSchema schema;
    Status describe_status = request_handler_.DescribeCollection(GetContext(context), request->collection_name(), schema);

    // Fields are copied even on failure, mirroring whatever the handler left in `schema`.
    response->set_collection_name(schema.collection_name_);
    response->set_dimension(schema.dimension_);
    response->set_index_file_size(schema.index_file_size_);
    response->set_metric_type(schema.metric_type_);

    SET_RESPONSE(response->mutable_status(), describe_status, context);
    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::CountCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                    ::milvus::grpc::CollectionRowCount* response) {
    // Replies with the collection's row count (0 if the handler leaves it untouched).
    CHECK_NULLPTR_RETURN(request);

    int64_t total_rows = 0;
    Status count_status = request_handler_.CountCollection(GetContext(context), request->collection_name(), total_rows);

    response->set_collection_row_count(total_rows);
    SET_RESPONSE(response->mutable_status(), count_status, context);
    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::ShowCollections(::grpc::ServerContext* context, const ::milvus::grpc::Command* request,
                                    ::milvus::grpc::CollectionNameList* response) {
    // Lists every collection name known to the request handler.
    CHECK_NULLPTR_RETURN(request);

    std::vector<std::string> collection_names;
    Status list_status = request_handler_.ShowCollections(GetContext(context), collection_names);

    for (const std::string& name : collection_names) {
        response->add_collection_names(name);
    }
    SET_RESPONSE(response->mutable_status(), list_status, context);

    return ::grpc::Status::OK;
}

617
::grpc::Status
GrpcRequestHandler::ShowCollectionInfo(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                       ::milvus::grpc::CollectionInfo* response) {
    // Returns collection-level and per-partition statistics for one collection.
    CHECK_NULLPTR_RETURN(request);

    CollectionInfo info;
    Status info_status = request_handler_.ShowCollectionInfo(GetContext(context), request->collection_name(), info);

    ConstructCollectionInfo(info, response);
    SET_RESPONSE(response->mutable_status(), info_status, context);

    return ::grpc::Status::OK;
}

K
kun yu 已提交
631
::grpc::Status
GrpcRequestHandler::Cmd(::grpc::ServerContext* context, const ::milvus::grpc::Command* request,
                        ::milvus::grpc::StringReply* response) {
    // Executes a textual server command and returns its string output.
    CHECK_NULLPTR_RETURN(request);

    std::string command_output;
    Status cmd_status = request_handler_.Cmd(GetContext(context), request->cmd(), command_output);

    response->set_string_reply(command_output);
    SET_RESPONSE(response->mutable_status(), cmd_status, context);

    return ::grpc::Status::OK;
}

Y
Yu Kun 已提交
644
::grpc::Status
GrpcRequestHandler::DeleteByID(::grpc::ServerContext* context, const ::milvus::grpc::DeleteByIDParam* request,
                               ::milvus::grpc::Status* response) {
    // Deletes the listed vector ids from a collection.
    CHECK_NULLPTR_RETURN(request);

    // Copy the repeated id field into the handler's vector type in one step.
    std::vector<int64_t> ids_to_delete(request->id_array().begin(), request->id_array().end());

    Status delete_status = request_handler_.DeleteByID(GetContext(context), request->collection_name(), ids_to_delete);
    SET_RESPONSE(response, delete_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::PreloadCollection(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                      ::milvus::grpc::Status* response) {
    // Asks the request handler to preload the named collection and reports its status.
    CHECK_NULLPTR_RETURN(request);

    Status preload_status = request_handler_.PreloadCollection(GetContext(context), request->collection_name());
    SET_RESPONSE(response, preload_status, context);

    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::DescribeIndex(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                  ::milvus::grpc::IndexParam* response) {
    // Returns the collection's index type and its extra params.
    CHECK_NULLPTR_RETURN(request);

    IndexParam index_param;
    Status describe_status = request_handler_.DescribeIndex(GetContext(context), request->collection_name(), index_param);

    response->set_collection_name(index_param.collection_name_);
    response->set_index_type(index_param.index_type_);

    // Extra params go back as a single entry under EXTRA_PARAM_KEY (the key CreateIndex parses as JSON).
    ::milvus::grpc::KeyValuePair* extra_param = response->add_extra_params();
    extra_param->set_key(EXTRA_PARAM_KEY);
    extra_param->set_value(index_param.extra_params_);

    SET_RESPONSE(response->mutable_status(), describe_status, context);
    return ::grpc::Status::OK;
}

::grpc::Status
GrpcRequestHandler::DropIndex(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                              ::milvus::grpc::Status* response) {
    // Asks the request handler to drop the collection's index and reports its status.
    CHECK_NULLPTR_RETURN(request);

    Status drop_status = request_handler_.DropIndex(GetContext(context), request->collection_name());
    SET_RESPONSE(response, drop_status, context);

    return ::grpc::Status::OK;
}

G
groot 已提交
701 702 703
// Create a partition identified by `tag` inside a collection.
// The handler's outcome is reported via SET_RESPONSE into `response`; the gRPC-level status is always OK.
::grpc::Status
GrpcRequestHandler::CreatePartition(::grpc::ServerContext* context, const ::milvus::grpc::PartitionParam* request,
                                    ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    const std::string& collection_name = request->collection_name();
    const std::string& partition_tag = request->tag();
    Status create_status = request_handler_.CreatePartition(GetContext(context), collection_name, partition_tag);
    SET_RESPONSE(response, create_status, context);

    return ::grpc::Status::OK;
}

// List the partition tags of a collection.
//
// Tags produced by the request handler are appended to `response` in order, even if the call failed (no status
// check precedes the copy). The call's status goes into response->status; the gRPC-level status is always OK.
::grpc::Status
GrpcRequestHandler::ShowPartitions(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                                   ::milvus::grpc::PartitionList* response) {
    CHECK_NULLPTR_RETURN(request);

    std::vector<PartitionParam> partition_params;
    Status show_status =
        request_handler_.ShowPartitions(GetContext(context), request->collection_name(), partition_params);

    for (const auto& partition_param : partition_params) {
        response->add_partition_tag_array(partition_param.tag_);
    }

    SET_RESPONSE(response->mutable_status(), show_status, context);
    return ::grpc::Status::OK;
}

// Drop the partition identified by `tag` from a collection.
// The handler's outcome is reported via SET_RESPONSE into `response`; the gRPC-level status is always OK.
::grpc::Status
GrpcRequestHandler::DropPartition(::grpc::ServerContext* context, const ::milvus::grpc::PartitionParam* request,
                                  ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    const std::string& collection_name = request->collection_name();
    const std::string& partition_tag = request->tag();
    Status drop_status = request_handler_.DropPartition(GetContext(context), collection_name, partition_tag);
    SET_RESPONSE(response, drop_status, context);

    return ::grpc::Status::OK;
}

// Flush every collection listed in the request's collection_name_array, in request order.
// The handler's outcome is reported via SET_RESPONSE into `response`; the gRPC-level status is always OK.
::grpc::Status
GrpcRequestHandler::Flush(::grpc::ServerContext* context, const ::milvus::grpc::FlushParam* request,
                          ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    const auto& requested_names = request->collection_name_array();
    std::vector<std::string> collection_names(requested_names.begin(), requested_names.end());

    Status flush_status = request_handler_.Flush(GetContext(context), collection_names);
    SET_RESPONSE(response, flush_status, context);

    return ::grpc::Status::OK;
}

// Compact the named collection through the request handler.
// The handler's outcome is reported via SET_RESPONSE into `response`; the gRPC-level status is always OK.
::grpc::Status
GrpcRequestHandler::Compact(::grpc::ServerContext* context, const ::milvus::grpc::CollectionName* request,
                            ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    const std::string& collection_name = request->collection_name();
    Status compact_status = request_handler_.Compact(GetContext(context), collection_name);
    SET_RESPONSE(response, compact_status, context);

    return ::grpc::Status::OK;
}

765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807
/*******************************************New Interface*********************************************/

// Create a hybrid (attribute + vector) collection from a field mapping.
//
// Fields are split into two lists:
//   - a field whose type carries a vector_param contributes (name, dimension) to vector_dimensions;
//   - every other field contributes (name, data type) to field_types.
// Every field also contributes (name, extra param) to field_params. Only the field's first extra_params entry is
// used, or "" when it has none. The outcome is reported via SET_RESPONSE; the gRPC-level status is always OK.
::grpc::Status
GrpcRequestHandler::CreateHybridCollection(::grpc::ServerContext* context, const ::milvus::grpc::Mapping* request,
                                           ::milvus::grpc::Status* response) {
    CHECK_NULLPTR_RETURN(request);

    std::vector<std::pair<std::string, engine::meta::hybrid::DataType>> field_types;
    std::vector<std::pair<std::string, uint64_t>> vector_dimensions;
    std::vector<std::pair<std::string, std::string>> field_params;

    for (const auto& field : request->fields()) {
        const std::string& field_name = field.name();
        const auto& field_type = field.type();

        if (field_type.has_vector_param()) {
            vector_dimensions.emplace_back(field_name, field_type.vector_param().dimension());
        } else {
            field_types.emplace_back(field_name,
                                     static_cast<engine::meta::hybrid::DataType>(field_type.data_type()));
        }

        // Currently only one extra_param is consulted per field.
        std::string extra_param = field.extra_params_size() != 0 ? field.extra_params(0).value() : std::string();
        field_params.emplace_back(field_name, std::move(extra_param));
    }

    Status create_status = request_handler_.CreateHybridCollection(GetContext(context), request->collection_name(),
                                                                   field_types, vector_dimensions, field_params);
    SET_RESPONSE(response, create_status, context);

    return ::grpc::Status::OK;
}

// Insert a batch of hybrid entities (attribute records + vector records) into a collection partition.
//
// Request layout, as consumed here:
//   - entities.attr_records: raw attribute payload, copied byte-for-byte into attr_values;
//   - entities.field_names:  every entry except the last is an attribute field name; the last entry is the name
//                            that keys the vector data;
//   - entities.result_values: vector records, each converted with CopyRowRecords together with entity_id_array.
// An empty field_names list is rejected with SERVER_INVALID_ARGUMENT. The entity ids written back to `response`
// are read from the first VectorsData after the handler call. When no vector data was supplied, no ids are
// written. The gRPC-level status is always OK.
::grpc::Status
GrpcRequestHandler::InsertEntity(::grpc::ServerContext* context, const ::milvus::grpc::HInsertParam* request,
                                 ::milvus::grpc::HEntityIDs* response) {
    CHECK_NULLPTR_RETURN(request);

    const auto& entities = request->entities();

    // At least one name (the vector field name) is required. Without this check `field_size - 1` goes negative:
    // resize() would be asked for ~SIZE_MAX elements and field_names(-1) would index out of range.
    const int field_size = entities.field_names_size();
    if (field_size < 1) {
        Status status{SERVER_INVALID_ARGUMENT, "Insert entity request carries no field names"};
        SET_RESPONSE(response->mutable_status(), status, context);
        return ::grpc::Status::OK;
    }

    auto attr_size = entities.attr_records().size();
    std::vector<uint8_t> attr_values(attr_size, 0);
    memcpy(attr_values.data(), entities.attr_records().data(), attr_size);

    uint64_t row_num = entities.row_num();

    // Attribute field names: every entry but the last.
    std::vector<std::string> field_names;
    field_names.reserve(field_size - 1);
    for (int i = 0; i < field_size - 1; ++i) {
        field_names.push_back(entities.field_names(i));
    }

    // All vector records are keyed by the same (last) field name, so emplace() keeps only the first one and any
    // later entries are dropped.
    // NOTE(review): this presumably means only one vector field per entity is supported — confirm.
    const std::string& vector_field_name = entities.field_names(field_size - 1);
    std::unordered_map<std::string, engine::VectorsData> vector_datas;
    for (int i = 0; i < entities.result_values_size(); ++i) {
        engine::VectorsData vectors;
        CopyRowRecords(entities.result_values(i).vector_value().value(), request->entity_id_array(), vectors);
        vector_datas.emplace(vector_field_name, std::move(vectors));
    }

    std::string collection_name = request->collection_name();
    std::string partition_tag = request->partition_tag();
    Status status = request_handler_.InsertEntity(GetContext(context), collection_name, partition_tag, row_num,
                                                  field_names, attr_values, vector_datas);

    // begin() of an empty map is end() and must not be dereferenced, so only echo ids when vector data exists.
    if (!vector_datas.empty()) {
        const auto& id_array = vector_datas.begin()->second.id_array_;
        response->mutable_entity_id_array()->Resize(static_cast<int>(id_array.size()), 0);
        memcpy(response->mutable_entity_id_array()->mutable_data(), id_array.data(),
               id_array.size() * sizeof(int64_t));
    }

    SET_RESPONSE(response->mutable_status(), status, context);
    return ::grpc::Status::OK;
}

void
845
DeSerialization(const ::milvus::grpc::GeneralQuery& general_query, query::BooleanQueryPtr& boolean_clause) {
846
    if (general_query.has_boolean_query()) {
847
        boolean_clause->SetOccur((query::Occur)general_query.boolean_query().occur());
848 849
        for (uint64_t i = 0; i < general_query.boolean_query().general_query_size(); ++i) {
            if (general_query.boolean_query().general_query(i).has_boolean_query()) {
850
                query::BooleanQueryPtr query = std::make_shared<query::BooleanQuery>();
851 852 853 854 855 856 857 858 859
                DeSerialization(general_query.boolean_query().general_query(i), query);
                boolean_clause->AddBooleanQuery(query);
            } else {
                auto leaf_query = std::make_shared<query::LeafQuery>();
                auto query = general_query.boolean_query().general_query(i);
                if (query.has_term_query()) {
                    query::TermQueryPtr term_query = std::make_shared<query::TermQuery>();
                    term_query->field_name = query.term_query().field_name();
                    term_query->boost = query.term_query().boost();
860 861 862
                    auto size = query.term_query().values().size();
                    term_query->field_value.resize(size);
                    memcpy(term_query->field_value.data(), query.term_query().values().data(), size);
863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951
                    leaf_query->term_query = term_query;
                    boolean_clause->AddLeafQuery(leaf_query);
                }
                if (query.has_range_query()) {
                    query::RangeQueryPtr range_query = std::make_shared<query::RangeQuery>();
                    range_query->field_name = query.range_query().field_name();
                    range_query->boost = query.range_query().boost();
                    range_query->compare_expr.resize(query.range_query().operand_size());
                    for (uint64_t j = 0; j < query.range_query().operand_size(); ++j) {
                        range_query->compare_expr[j].compare_operator =
                            query::CompareOperator(query.range_query().operand(j).operator_());
                        range_query->compare_expr[j].operand = query.range_query().operand(j).operand();
                    }
                    leaf_query->range_query = range_query;
                    boolean_clause->AddLeafQuery(leaf_query);
                }
                if (query.has_vector_query()) {
                    query::VectorQueryPtr vector_query = std::make_shared<query::VectorQuery>();

                    engine::VectorsData vectors;
                    CopyRowRecords(query.vector_query().records(),
                                   google::protobuf::RepeatedField<google::protobuf::int64>(), vectors);

                    vector_query->query_vector.float_data = vectors.float_data_;
                    vector_query->query_vector.binary_data = vectors.binary_data_;

                    vector_query->boost = query.vector_query().query_boost();
                    vector_query->field_name = query.vector_query().field_name();
                    vector_query->topk = query.vector_query().topk();

                    milvus::json json_params;
                    for (int j = 0; j < query.vector_query().extra_params_size(); j++) {
                        const ::milvus::grpc::KeyValuePair& extra = query.vector_query().extra_params(j);
                        if (extra.key() == EXTRA_PARAM_KEY) {
                            json_params = json::parse(extra.value());
                        }
                    }
                    vector_query->extra_params = json_params;
                    leaf_query->vector_query = vector_query;
                    boolean_clause->AddLeafQuery(leaf_query);
                }
            }
        }
    }
}

// Hybrid (attribute + vector) search entry point.
//
// Processing steps:
//   1. Deserialize the request's general query into a boolean query tree and lower it to a binary query.
//   2. Validate the binary query. An invalid tree is reported as SERVER_INVALID_BINARY_QUERY and no search runs.
//   3. Otherwise run the search over the requested partition tags.
//   4. Serialize the results into `response`; this happens regardless of the search status.
// The gRPC-level status is always OK.
::grpc::Status
GrpcRequestHandler::HybridSearch(::grpc::ServerContext* context, const ::milvus::grpc::HSearchParam* request,
                                 ::milvus::grpc::TopKQueryResult* response) {
    CHECK_NULLPTR_RETURN(request);

    context::HybridSearchContextPtr hybrid_search_context = std::make_shared<context::HybridSearchContext>();

    // protobuf -> boolean query tree -> binary query tree
    query::BooleanQueryPtr boolean_query = std::make_shared<query::BooleanQuery>();
    DeSerialization(request->general_query(), boolean_query);

    query::GeneralQueryPtr general_query = std::make_shared<query::GeneralQuery>();
    general_query->bin = std::make_shared<query::BinaryQuery>();
    query::GenBinaryQuery(boolean_query, general_query->bin);

    if (!query::ValidateBinaryQuery(general_query->bin)) {
        Status invalid_status{SERVER_INVALID_BINARY_QUERY, "Generate wrong binary query tree"};
        SET_RESPONSE(response->mutable_status(), invalid_status, context);
        return ::grpc::Status::OK;
    }
    hybrid_search_context->general_query_ = general_query;

    const auto& partition_tags = request->partition_tag_array();
    std::vector<std::string> partition_list(partition_tags.begin(), partition_tags.end());

    TopKQueryResult result;
    Status search_status = request_handler_.HybridSearch(GetContext(context), hybrid_search_context,
                                                         request->collection_name(), partition_list,
                                                         general_query, result);

    ConstructResults(result, response);
    SET_RESPONSE(response->mutable_status(), search_status, context);

    return ::grpc::Status::OK;
}

S
starlord 已提交
952 953 954
}  // namespace grpc
}  // namespace server
}  // namespace milvus