未验证 提交 3bc17d8c 编写于 作者: W Wang XiangYu 提交者: GitHub

Fix multi client search crash in tracing module fix #1789 fix #1832 (#1899)

* Fix multi client search crash in tracing module fix #1789 fix #1832
Signed-off-by: Nwxyu <xy.wang@zilliz.com>

* add lock for every context_map_ access
Signed-off-by: Nwxyu <xy.wang@zilliz.com>

* remove never used variable
Signed-off-by: Nwxyu <xy.wang@zilliz.com>
上级 c8a59b27
......@@ -7,6 +7,8 @@ Please mark all change in change log and use the issue from GitHub
## Bug
- \#1276 SQLite throw exception after create 50000+ partitions in a table
- \#1762 Server is not forbidden to create new partition which tag is `_default`
- \#1789 Fix multi-client search cause server crash
- \#1832 Fix crash in tracing module
- \#1873 Fix index file serialize to incorrect path
- \#1881 Fix Annoy index search failure
......
......@@ -301,6 +301,7 @@ if (DEFINED ENV{MILVUS_GRPC_URL})
set(GRPC_SOURCE_URL "$ENV{MILVUS_GRPC_URL}")
else ()
set(GRPC_SOURCE_URL
"https://github.com/milvus-io/grpc-milvus/archive/${GRPC_VERSION}.zip"
"https://github.com/youny626/grpc-milvus/archive/${GRPC_VERSION}.zip"
"https://gitee.com/quicksilver/grpc-milvus/repository/archive/${GRPC_VERSION}.zip")
endif ()
......
......@@ -13,6 +13,7 @@
#include <fiu-local.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
......@@ -158,6 +159,50 @@ ConstructCollectionInfo(const CollectionInfo& collection_info, ::milvus::grpc::C
} // namespace
namespace {
#define REQ_ID ("request_id")
std::atomic<int64_t> _sequential_id;
int64_t
get_sequential_id() {
return _sequential_id++;
}
void
set_request_id(::grpc::ServerContext* context, const std::string& request_id) {
if (not context) {
// error
SERVER_LOG_ERROR << "set_request_id: grpc::ServerContext is nullptr" << std::endl;
return;
}
context->AddInitialMetadata(REQ_ID, request_id);
}
std::string
get_request_id(::grpc::ServerContext* context) {
if (not context) {
// error
SERVER_LOG_ERROR << "get_request_id: grpc::ServerContext is nullptr" << std::endl;
return "INVALID_ID";
}
auto server_metadata = context->server_metadata();
auto request_id_kv = server_metadata.find(REQ_ID);
if (request_id_kv == server_metadata.end()) {
// error
SERVER_LOG_ERROR << std::string(REQ_ID) << " not found in grpc.server_metadata" << std::endl;
return "INVALID_ID";
}
return request_id_kv->second.data();
}
} // namespace
GrpcRequestHandler::GrpcRequestHandler(const std::shared_ptr<opentracing::Tracer>& tracer)
: tracer_(tracer), random_num_generator_() {
std::random_device random_device;
......@@ -187,16 +232,42 @@ GrpcRequestHandler::OnPostRecvInitialMetaData(
return;
}
auto span = tracer_->StartSpan(server_rpc_info->method(), {opentracing::ChildOf(span_context_maybe->get())});
auto server_context = server_rpc_info->server_context();
auto client_metadata = server_context->client_metadata();
// TODO: request id
// if client provide request_id in metadata, milvus just use it,
// else milvus generate a sequential id.
std::string request_id;
auto request_id_kv = client_metadata.find("request_id");
if (request_id_kv != client_metadata.end()) {
request_id = request_id_kv->second.data();
SERVER_LOG_DEBUG << "client provide request_id: " << request_id;
// if request_id is being used by another request,
// convert it to request_id_n.
std::lock_guard<std::mutex> lock(context_map_mutex_);
if (context_map_.find(request_id) == context_map_.end()) {
// if not found exist, mark
context_map_[request_id] = nullptr;
} else {
// Finding a unused suffix
int64_t suffix = 1;
std::string try_request_id;
bool exist = true;
do {
try_request_id = request_id + "_" + std::to_string(suffix);
exist = context_map_.find(try_request_id) != context_map_.end();
suffix++;
} while (exist);
context_map_[try_request_id] = nullptr;
}
} else {
request_id = std::to_string(random_id()) + std::to_string(random_id());
request_id = std::to_string(get_sequential_id());
set_request_id(server_context, request_id);
SERVER_LOG_DEBUG << "milvus generate request_id: " << request_id;
}
auto trace_context = std::make_shared<tracing::TraceContext>(span);
auto context = std::make_shared<Context>(request_id);
context->SetTraceContext(trace_context);
......@@ -207,23 +278,33 @@ void
GrpcRequestHandler::OnPreSendMessage(::grpc::experimental::ServerRpcInfo* server_rpc_info,
::grpc::experimental::InterceptorBatchMethods* interceptor_batch_methods) {
std::lock_guard<std::mutex> lock(context_map_mutex_);
context_map_[server_rpc_info->server_context()]->GetTraceContext()->GetSpan()->Finish();
auto search = context_map_.find(server_rpc_info->server_context());
if (search != context_map_.end()) {
context_map_.erase(search);
auto request_id = get_request_id(server_rpc_info->server_context());
if (context_map_.find(request_id) == context_map_.end()) {
// error
SERVER_LOG_ERROR << "request_id " << request_id << " not found in context_map_";
return;
}
context_map_[request_id]->GetTraceContext()->GetSpan()->Finish();
context_map_.erase(request_id);
}
const std::shared_ptr<Context>&
GrpcRequestHandler::GetContext(::grpc::ServerContext* server_context) {
std::lock_guard<std::mutex> lock(context_map_mutex_);
return context_map_[server_context];
auto request_id = get_request_id(server_context);
if (context_map_.find(request_id) == context_map_.end()) {
SERVER_LOG_ERROR << "GetContext: request_id " << request_id << " not found in context_map_";
return nullptr;
}
return context_map_[request_id];
}
void
GrpcRequestHandler::SetContext(::grpc::ServerContext* server_context, const std::shared_ptr<Context>& context) {
std::lock_guard<std::mutex> lock(context_map_mutex_);
context_map_[server_context] = context;
auto request_id = get_request_id(server_context);
context_map_[request_id] = context;
}
uint64_t
......@@ -244,7 +325,7 @@ GrpcRequestHandler::CreateCollection(::grpc::ServerContext* context, const ::mil
CHECK_NULLPTR_RETURN(request);
Status status =
request_handler_.CreateCollection(context_map_[context], request->collection_name(), request->dimension(),
request_handler_.CreateCollection(GetContext(context), request->collection_name(), request->dimension(),
request->index_file_size(), request->metric_type());
SET_RESPONSE(response, status, context);
......@@ -258,7 +339,7 @@ GrpcRequestHandler::HasCollection(::grpc::ServerContext* context, const ::milvus
bool has_collection = false;
Status status = request_handler_.HasCollection(context_map_[context], request->collection_name(), has_collection);
Status status = request_handler_.HasCollection(GetContext(context), request->collection_name(), has_collection);
response->set_bool_reply(has_collection);
SET_RESPONSE(response->mutable_status(), status, context);
......@@ -270,7 +351,7 @@ GrpcRequestHandler::DropCollection(::grpc::ServerContext* context, const ::milvu
::milvus::grpc::Status* response) {
CHECK_NULLPTR_RETURN(request);
Status status = request_handler_.DropCollection(context_map_[context], request->collection_name());
Status status = request_handler_.DropCollection(GetContext(context), request->collection_name());
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......@@ -289,8 +370,8 @@ GrpcRequestHandler::CreateIndex(::grpc::ServerContext* context, const ::milvus::
}
}
Status status = request_handler_.CreateIndex(context_map_[context], request->collection_name(),
request->index_type(), json_params);
Status status = request_handler_.CreateIndex(GetContext(context), request->collection_name(), request->index_type(),
json_params);
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......@@ -309,7 +390,7 @@ GrpcRequestHandler::Insert(::grpc::ServerContext* context, const ::milvus::grpc:
// step 2: insert vectors
Status status =
request_handler_.Insert(context_map_[context], request->collection_name(), vectors, request->partition_tag());
request_handler_.Insert(GetContext(context), request->collection_name(), vectors, request->partition_tag());
// step 3: return id array
response->mutable_vector_id_array()->Resize(static_cast<int>(vectors.id_array_.size()), 0);
......@@ -329,7 +410,7 @@ GrpcRequestHandler::GetVectorByID(::grpc::ServerContext* context, const ::milvus
std::vector<int64_t> vector_ids = {request->id()};
engine::VectorsData vectors;
Status status =
request_handler_.GetVectorByID(context_map_[context], request->collection_name(), vector_ids, vectors);
request_handler_.GetVectorByID(GetContext(context), request->collection_name(), vector_ids, vectors);
if (!vectors.float_data_.empty()) {
response->mutable_vector_data()->mutable_float_data()->Resize(vectors.float_data_.size(), 0);
......@@ -351,7 +432,7 @@ GrpcRequestHandler::GetVectorIDs(::grpc::ServerContext* context, const ::milvus:
CHECK_NULLPTR_RETURN(request);
std::vector<int64_t> vector_ids;
Status status = request_handler_.GetVectorIDs(context_map_[context], request->collection_name(),
Status status = request_handler_.GetVectorIDs(GetContext(context), request->collection_name(),
request->segment_name(), vector_ids);
if (!vector_ids.empty()) {
......@@ -393,7 +474,8 @@ GrpcRequestHandler::Search(::grpc::ServerContext* context, const ::milvus::grpc:
std::vector<std::string> file_ids;
TopKQueryResult result;
fiu_do_on("GrpcRequestHandler.Search.not_empty_file_ids", file_ids.emplace_back("test_file_id"));
Status status = request_handler_.Search(context_map_[context], request->collection_name(), vectors, request->topk(),
Status status = request_handler_.Search(GetContext(context), request->collection_name(), vectors, request->topk(),
json_params, partitions, file_ids, result);
// step 5: construct and return result
......@@ -428,7 +510,7 @@ GrpcRequestHandler::SearchByID(::grpc::ServerContext* context, const ::milvus::g
// step 3: search vectors
TopKQueryResult result;
Status status = request_handler_.SearchByID(context_map_[context], request->collection_name(), request->id(),
Status status = request_handler_.SearchByID(GetContext(context), request->collection_name(), request->id(),
request->topk(), json_params, partitions, result);
// step 4: construct and return result
......@@ -474,7 +556,7 @@ GrpcRequestHandler::SearchInFiles(::grpc::ServerContext* context, const ::milvus
// step 5: search vectors
TopKQueryResult result;
Status status = request_handler_.Search(context_map_[context], search_request->collection_name(), vectors,
Status status = request_handler_.Search(GetContext(context), search_request->collection_name(), vectors,
search_request->topk(), json_params, partitions, file_ids, result);
// step 6: construct and return result
......@@ -492,7 +574,7 @@ GrpcRequestHandler::DescribeCollection(::grpc::ServerContext* context, const ::m
CollectionSchema collection_schema;
Status status =
request_handler_.DescribeCollection(context_map_[context], request->collection_name(), collection_schema);
request_handler_.DescribeCollection(GetContext(context), request->collection_name(), collection_schema);
response->set_collection_name(collection_schema.collection_name_);
response->set_dimension(collection_schema.dimension_);
response->set_index_file_size(collection_schema.index_file_size_);
......@@ -508,7 +590,7 @@ GrpcRequestHandler::CountCollection(::grpc::ServerContext* context, const ::milv
CHECK_NULLPTR_RETURN(request);
int64_t row_count = 0;
Status status = request_handler_.CountCollection(context_map_[context], request->collection_name(), row_count);
Status status = request_handler_.CountCollection(GetContext(context), request->collection_name(), row_count);
response->set_collection_row_count(row_count);
SET_RESPONSE(response->mutable_status(), status, context);
return ::grpc::Status::OK;
......@@ -520,7 +602,7 @@ GrpcRequestHandler::ShowCollections(::grpc::ServerContext* context, const ::milv
CHECK_NULLPTR_RETURN(request);
std::vector<std::string> collections;
Status status = request_handler_.ShowCollections(context_map_[context], collections);
Status status = request_handler_.ShowCollections(GetContext(context), collections);
for (auto& collection : collections) {
response->add_collection_names(collection);
}
......@@ -536,7 +618,7 @@ GrpcRequestHandler::ShowCollectionInfo(::grpc::ServerContext* context, const ::m
CollectionInfo collection_info;
Status status =
request_handler_.ShowCollectionInfo(context_map_[context], request->collection_name(), collection_info);
request_handler_.ShowCollectionInfo(GetContext(context), request->collection_name(), collection_info);
ConstructCollectionInfo(collection_info, response);
SET_RESPONSE(response->mutable_status(), status, context);
......@@ -549,7 +631,7 @@ GrpcRequestHandler::Cmd(::grpc::ServerContext* context, const ::milvus::grpc::Co
CHECK_NULLPTR_RETURN(request);
std::string reply;
Status status = request_handler_.Cmd(context_map_[context], request->cmd(), reply);
Status status = request_handler_.Cmd(GetContext(context), request->cmd(), reply);
response->set_string_reply(reply);
SET_RESPONSE(response->mutable_status(), status, context);
......@@ -568,7 +650,7 @@ GrpcRequestHandler::DeleteByID(::grpc::ServerContext* context, const ::milvus::g
}
// step 2: delete vector
Status status = request_handler_.DeleteByID(context_map_[context], request->collection_name(), vector_ids);
Status status = request_handler_.DeleteByID(GetContext(context), request->collection_name(), vector_ids);
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......@@ -579,7 +661,7 @@ GrpcRequestHandler::PreloadCollection(::grpc::ServerContext* context, const ::mi
::milvus::grpc::Status* response) {
CHECK_NULLPTR_RETURN(request);
Status status = request_handler_.PreloadCollection(context_map_[context], request->collection_name());
Status status = request_handler_.PreloadCollection(GetContext(context), request->collection_name());
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......@@ -591,7 +673,7 @@ GrpcRequestHandler::DescribeIndex(::grpc::ServerContext* context, const ::milvus
CHECK_NULLPTR_RETURN(request);
IndexParam param;
Status status = request_handler_.DescribeIndex(context_map_[context], request->collection_name(), param);
Status status = request_handler_.DescribeIndex(GetContext(context), request->collection_name(), param);
response->set_collection_name(param.collection_name_);
response->set_index_type(param.index_type_);
::milvus::grpc::KeyValuePair* kv = response->add_extra_params();
......@@ -607,7 +689,7 @@ GrpcRequestHandler::DropIndex(::grpc::ServerContext* context, const ::milvus::gr
::milvus::grpc::Status* response) {
CHECK_NULLPTR_RETURN(request);
Status status = request_handler_.DropIndex(context_map_[context], request->collection_name());
Status status = request_handler_.DropIndex(GetContext(context), request->collection_name());
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......@@ -618,7 +700,7 @@ GrpcRequestHandler::CreatePartition(::grpc::ServerContext* context, const ::milv
::milvus::grpc::Status* response) {
CHECK_NULLPTR_RETURN(request);
Status status = request_handler_.CreatePartition(context_map_[context], request->collection_name(), request->tag());
Status status = request_handler_.CreatePartition(GetContext(context), request->collection_name(), request->tag());
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......@@ -630,7 +712,7 @@ GrpcRequestHandler::ShowPartitions(::grpc::ServerContext* context, const ::milvu
CHECK_NULLPTR_RETURN(request);
std::vector<PartitionParam> partitions;
Status status = request_handler_.ShowPartitions(context_map_[context], request->collection_name(), partitions);
Status status = request_handler_.ShowPartitions(GetContext(context), request->collection_name(), partitions);
for (auto& partition : partitions) {
response->add_partition_tag_array(partition.tag_);
}
......@@ -645,7 +727,7 @@ GrpcRequestHandler::DropPartition(::grpc::ServerContext* context, const ::milvus
::milvus::grpc::Status* response) {
CHECK_NULLPTR_RETURN(request);
Status status = request_handler_.DropPartition(context_map_[context], request->collection_name(), request->tag());
Status status = request_handler_.DropPartition(GetContext(context), request->collection_name(), request->tag());
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......@@ -660,7 +742,7 @@ GrpcRequestHandler::Flush(::grpc::ServerContext* context, const ::milvus::grpc::
for (int32_t i = 0; i < request->collection_name_array().size(); i++) {
collection_names.push_back(request->collection_name_array(i));
}
Status status = request_handler_.Flush(context_map_[context], collection_names);
Status status = request_handler_.Flush(GetContext(context), collection_names);
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......@@ -671,7 +753,7 @@ GrpcRequestHandler::Compact(::grpc::ServerContext* context, const ::milvus::grpc
::milvus::grpc::Status* response) {
CHECK_NULLPTR_RETURN(request);
Status status = request_handler_.Compact(context_map_[context], request->collection_name());
Status status = request_handler_.Compact(GetContext(context), request->collection_name());
SET_RESPONSE(response, status, context);
return ::grpc::Status::OK;
......
......@@ -11,6 +11,7 @@
#pragma once
#include <grpcpp/server_context.h>
#include <server/context/Context.h>
#include <cstdint>
......@@ -311,7 +312,8 @@ class GrpcRequestHandler final : public ::milvus::grpc::MilvusService::Service,
private:
RequestHandler request_handler_;
std::unordered_map<::grpc::ServerContext*, std::shared_ptr<Context>> context_map_;
// std::unordered_map<::grpc::ServerContext*, std::shared_ptr<Context>> context_map_;
std::unordered_map<std::string, std::shared_ptr<Context>> context_map_;
std::shared_ptr<opentracing::Tracer> tracer_;
// std::unordered_map<::grpc::ServerContext*, std::unique_ptr<opentracing::Span>> span_map_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册