Unverified commit 69252f81 authored by congqixia, committed by GitHub

Implement memory replica in Proxy, QueryNode and QueryCoord (#16470)

Related to #16298 #16291 #16154
Co-authored-by: sunby <bingyi.sun@zilliz.com>
Co-authored-by: yangxuan <xuan.yang@zilliz.com>
Co-authored-by: yah01 <yang.cen@zilliz.com>
Co-authored-by: Letian Jiang <letian.jiang@zilliz.com>
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
Parent 9aa557f8
@@ -290,7 +290,7 @@ const char descriptor_table_protodef_common_2eproto[] PROTOBUF_SECTION_VARIABLE(
"led\020\004*\202\001\n\014SegmentState\022\024\n\020SegmentStateNo"
"ne\020\000\022\014\n\010NotExist\020\001\022\013\n\007Growing\020\002\022\n\n\006Seale"
"d\020\003\022\013\n\007Flushed\020\004\022\014\n\010Flushing\020\005\022\013\n\007Droppe"
-"d\020\006\022\r\n\tImporting\020\007*\261\n\n\007MsgType\022\r\n\tUndefi"
+"d\020\006\022\r\n\tImporting\020\007*\307\n\n\007MsgType\022\r\n\tUndefi"
"ned\020\000\022\024\n\020CreateCollection\020d\022\022\n\016DropColle"
"ction\020e\022\021\n\rHasCollection\020f\022\026\n\022DescribeCo"
"llection\020g\022\023\n\017ShowCollections\020h\022\024\n\020GetSy"
@@ -314,27 +314,28 @@ const char descriptor_table_protodef_common_2eproto[] PROTOBUF_SECTION_VARIABLE(
"atchDmChannels\020\374\003\022\025\n\020RemoveDmChannels\020\375\003"
"\022\027\n\022WatchQueryChannels\020\376\003\022\030\n\023RemoveQuery"
"Channels\020\377\003\022\035\n\030SealedSegmentsChangeInfo\020"
-"\200\004\022\027\n\022WatchDeltaChannels\020\201\004\022\020\n\013SegmentIn"
-"fo\020\330\004\022\017\n\nSystemInfo\020\331\004\022\024\n\017GetRecoveryInf"
-"o\020\332\004\022\024\n\017GetSegmentState\020\333\004\022\r\n\010TimeTick\020\260"
-"\t\022\023\n\016QueryNodeStats\020\261\t\022\016\n\tLoadIndex\020\262\t\022\016"
-"\n\tRequestID\020\263\t\022\017\n\nRequestTSO\020\264\t\022\024\n\017Alloc"
-"ateSegment\020\265\t\022\026\n\021SegmentStatistics\020\266\t\022\025\n"
-"\020SegmentFlushDone\020\267\t\022\017\n\nDataNodeTt\020\270\t\022\025\n"
-"\020CreateCredential\020\334\013\022\022\n\rGetCredential\020\335\013"
-"\022\025\n\020DeleteCredential\020\336\013\022\025\n\020UpdateCredent"
-"ial\020\337\013\022\026\n\021ListCredUsernames\020\340\013*\"\n\007DslTyp"
-"e\022\007\n\003Dsl\020\000\022\016\n\nBoolExprV1\020\001*B\n\017Compaction"
-"State\022\021\n\rUndefiedState\020\000\022\r\n\tExecuting\020\001\022"
-"\r\n\tCompleted\020\002*X\n\020ConsistencyLevel\022\n\n\006St"
-"rong\020\000\022\013\n\007Session\020\001\022\013\n\007Bounded\020\002\022\016\n\nEven"
-"tually\020\003\022\016\n\nCustomized\020\004*\227\001\n\013ImportState"
-"\022\021\n\rImportPending\020\000\022\020\n\014ImportFailed\020\001\022\021\n"
-"\rImportStarted\020\002\022\024\n\020ImportDownloaded\020\003\022\020"
-"\n\014ImportParsed\020\004\022\023\n\017ImportPersisted\020\005\022\023\n"
-"\017ImportCompleted\020\006BW\n\016io.milvus.grpcB\013Co"
-"mmonProtoP\001Z3github.com/milvus-io/milvus"
-"/internal/proto/commonpb\240\001\001b\006proto3"
+"\200\004\022\027\n\022WatchDeltaChannels\020\201\004\022\024\n\017GetShardL"
+"eaders\020\202\004\022\020\n\013SegmentInfo\020\330\004\022\017\n\nSystemInf"
+"o\020\331\004\022\024\n\017GetRecoveryInfo\020\332\004\022\024\n\017GetSegment"
+"State\020\333\004\022\r\n\010TimeTick\020\260\t\022\023\n\016QueryNodeStat"
+"s\020\261\t\022\016\n\tLoadIndex\020\262\t\022\016\n\tRequestID\020\263\t\022\017\n\n"
+"RequestTSO\020\264\t\022\024\n\017AllocateSegment\020\265\t\022\026\n\021S"
+"egmentStatistics\020\266\t\022\025\n\020SegmentFlushDone\020"
+"\267\t\022\017\n\nDataNodeTt\020\270\t\022\025\n\020CreateCredential\020"
+"\334\013\022\022\n\rGetCredential\020\335\013\022\025\n\020DeleteCredenti"
+"al\020\336\013\022\025\n\020UpdateCredential\020\337\013\022\026\n\021ListCred"
+"Usernames\020\340\013*\"\n\007DslType\022\007\n\003Dsl\020\000\022\016\n\nBool"
+"ExprV1\020\001*B\n\017CompactionState\022\021\n\rUndefiedS"
+"tate\020\000\022\r\n\tExecuting\020\001\022\r\n\tCompleted\020\002*X\n\020"
+"ConsistencyLevel\022\n\n\006Strong\020\000\022\013\n\007Session\020"
+"\001\022\013\n\007Bounded\020\002\022\016\n\nEventually\020\003\022\016\n\nCustom"
+"ized\020\004*\227\001\n\013ImportState\022\021\n\rImportPending\020"
+"\000\022\020\n\014ImportFailed\020\001\022\021\n\rImportStarted\020\002\022\024"
+"\n\020ImportDownloaded\020\003\022\020\n\014ImportParsed\020\004\022\023"
+"\n\017ImportPersisted\020\005\022\023\n\017ImportCompleted\020\006"
+"BW\n\016io.milvus.grpcB\013CommonProtoP\001Z3githu"
+"b.com/milvus-io/milvus/internal/proto/co"
+"mmonpb\240\001\001b\006proto3"
;
static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_common_2eproto_deps[1] = {
};
@@ -351,7 +352,7 @@ static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_com
static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_common_2eproto_once;
static bool descriptor_table_common_2eproto_initialized = false;
const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_common_2eproto = {
-  &descriptor_table_common_2eproto_initialized, descriptor_table_protodef_common_2eproto, "common.proto", 3275,
+  &descriptor_table_common_2eproto_initialized, descriptor_table_protodef_common_2eproto, "common.proto", 3297,
  &descriptor_table_common_2eproto_once, descriptor_table_common_2eproto_sccs, descriptor_table_common_2eproto_deps, 8, 0,
  schemas, file_default_instances, TableStruct_common_2eproto::offsets,
  file_level_metadata_common_2eproto, 8, file_level_enum_descriptors_common_2eproto, file_level_service_descriptors_common_2eproto,
@@ -497,6 +498,7 @@ bool MsgType_IsValid(int value) {
    case 511:
    case 512:
    case 513:
+   case 514:
    case 600:
    case 601:
    case 602:
......
@@ -262,6 +262,7 @@ enum MsgType : int {
  RemoveQueryChannels = 511,
  SealedSegmentsChangeInfo = 512,
  WatchDeltaChannels = 513,
+ GetShardLeaders = 514,
  SegmentInfo = 600,
  SystemInfo = 601,
  GetRecoveryInfo = 602,
......
This diff is collapsed.
@@ -3888,6 +3888,7 @@ class LoadPartitionsRequest :
    kDbNameFieldNumber = 2,
    kCollectionNameFieldNumber = 3,
    kBaseFieldNumber = 1,
+   kReplicaNumberFieldNumber = 5,
  };
  // repeated string partition_names = 4;
  int partition_names_size() const;
@@ -3936,6 +3937,11 @@ class LoadPartitionsRequest :
  ::milvus::proto::common::MsgBase* mutable_base();
  void set_allocated_base(::milvus::proto::common::MsgBase* base);

+  // int32 replica_number = 5;
+  void clear_replica_number();
+  ::PROTOBUF_NAMESPACE_ID::int32 replica_number() const;
+  void set_replica_number(::PROTOBUF_NAMESPACE_ID::int32 value);
+
  // @@protoc_insertion_point(class_scope:milvus.proto.milvus.LoadPartitionsRequest)
 private:
  class _Internal;
@@ -3945,6 +3951,7 @@ class LoadPartitionsRequest :
  ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr db_name_;
  ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr collection_name_;
  ::milvus::proto::common::MsgBase* base_;
+  ::PROTOBUF_NAMESPACE_ID::int32 replica_number_;
  mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
  friend struct ::TableStruct_milvus_2eproto;
};
@@ -11873,6 +11880,7 @@ class LoadBalanceRequest :
  enum : int {
    kDstNodeIDsFieldNumber = 3,
    kSealedSegmentIDsFieldNumber = 4,
+   kCollectionNameFieldNumber = 5,
    kBaseFieldNumber = 1,
    kSrcNodeIDFieldNumber = 2,
  };
@@ -11898,6 +11906,17 @@ class LoadBalanceRequest :
  ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >*
      mutable_sealed_segmentids();

+  // string collectionName = 5;
+  void clear_collectionname();
+  const std::string& collectionname() const;
+  void set_collectionname(const std::string& value);
+  void set_collectionname(std::string&& value);
+  void set_collectionname(const char* value);
+  void set_collectionname(const char* value, size_t size);
+  std::string* mutable_collectionname();
+  std::string* release_collectionname();
+  void set_allocated_collectionname(std::string* collectionname);
+
  // .milvus.proto.common.MsgBase base = 1;
  bool has_base() const;
  void clear_base();
@@ -11920,6 +11939,7 @@ class LoadBalanceRequest :
  mutable std::atomic<int> _dst_nodeids_cached_byte_size_;
  ::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > sealed_segmentids_;
  mutable std::atomic<int> _sealed_segmentids_cached_byte_size_;
+  ::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr collectionname_;
  ::milvus::proto::common::MsgBase* base_;
  ::PROTOBUF_NAMESPACE_ID::int64 src_nodeid_;
  mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
@@ -18987,6 +19007,20 @@ LoadPartitionsRequest::mutable_partition_names() {
  return &partition_names_;
}
// int32 replica_number = 5;
inline void LoadPartitionsRequest::clear_replica_number() {
replica_number_ = 0;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 LoadPartitionsRequest::replica_number() const {
// @@protoc_insertion_point(field_get:milvus.proto.milvus.LoadPartitionsRequest.replica_number)
return replica_number_;
}
inline void LoadPartitionsRequest::set_replica_number(::PROTOBUF_NAMESPACE_ID::int32 value) {
replica_number_ = value;
// @@protoc_insertion_point(field_set:milvus.proto.milvus.LoadPartitionsRequest.replica_number)
}
// -------------------------------------------------------------------

// ReleasePartitionsRequest

@@ -26343,6 +26377,57 @@ LoadBalanceRequest::mutable_sealed_segmentids() {
  return &sealed_segmentids_;
}
// string collectionName = 5;
inline void LoadBalanceRequest::clear_collectionname() {
collectionname_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline const std::string& LoadBalanceRequest::collectionname() const {
// @@protoc_insertion_point(field_get:milvus.proto.milvus.LoadBalanceRequest.collectionName)
return collectionname_.GetNoArena();
}
inline void LoadBalanceRequest::set_collectionname(const std::string& value) {
collectionname_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value);
// @@protoc_insertion_point(field_set:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
inline void LoadBalanceRequest::set_collectionname(std::string&& value) {
collectionname_.SetNoArena(
&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value));
// @@protoc_insertion_point(field_set_rvalue:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
inline void LoadBalanceRequest::set_collectionname(const char* value) {
GOOGLE_DCHECK(value != nullptr);
collectionname_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value));
// @@protoc_insertion_point(field_set_char:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
inline void LoadBalanceRequest::set_collectionname(const char* value, size_t size) {
collectionname_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
::std::string(reinterpret_cast<const char*>(value), size));
// @@protoc_insertion_point(field_set_pointer:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
inline std::string* LoadBalanceRequest::mutable_collectionname() {
// @@protoc_insertion_point(field_mutable:milvus.proto.milvus.LoadBalanceRequest.collectionName)
return collectionname_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline std::string* LoadBalanceRequest::release_collectionname() {
// @@protoc_insertion_point(field_release:milvus.proto.milvus.LoadBalanceRequest.collectionName)
return collectionname_.ReleaseNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline void LoadBalanceRequest::set_allocated_collectionname(std::string* collectionname) {
if (collectionname != nullptr) {
} else {
}
collectionname_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), collectionname);
// @@protoc_insertion_point(field_set_allocated:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
// -------------------------------------------------------------------

// ManualCompactionRequest
@@ -528,6 +528,7 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
	segment2Binlogs := make(map[UniqueID][]*datapb.FieldBinlog)
	segment2StatsBinlogs := make(map[UniqueID][]*datapb.FieldBinlog)
	segment2DeltaBinlogs := make(map[UniqueID][]*datapb.FieldBinlog)
+	segment2InsertChannel := make(map[UniqueID]string)
	segmentsNumOfRows := make(map[UniqueID]int64)
	flushedIDs := make(map[int64]struct{})
@@ -542,6 +543,7 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
	if segment.State != commonpb.SegmentState_Flushed && segment.State != commonpb.SegmentState_Flushing {
		continue
	}
+	segment2InsertChannel[segment.ID] = segment.InsertChannel
	binlogs := segment.GetBinlogs()
	if len(binlogs) == 0 {
@@ -590,11 +592,12 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
	binlogs := make([]*datapb.SegmentBinlogs, 0, len(segment2Binlogs))
	for segmentID := range flushedIDs {
		sbl := &datapb.SegmentBinlogs{
			SegmentID:     segmentID,
			NumOfRows:     segmentsNumOfRows[segmentID],
			FieldBinlogs:  segment2Binlogs[segmentID],
			Statslogs:     segment2StatsBinlogs[segmentID],
			Deltalogs:     segment2DeltaBinlogs[segmentID],
+			InsertChannel: segment2InsertChannel[segmentID],
		}
		binlogs = append(binlogs, sbl)
	}
......
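With the new `insert_channel` field, each `SegmentBinlogs` entry in a `GetRecoveryInfoResponse` now carries the DML channel its segment was written to, which is what the replica logic needs when grouping segments under the right shard leader. A minimal consumer sketch; the variable names are assumptions, not part of this commit:

// Sketch: group recovered segments by their insert (DML) channel.
// Assumes resp is a *datapb.GetRecoveryInfoResponse returned by DataCoord.
segmentsByChannel := make(map[string][]*datapb.SegmentBinlogs)
for _, sbl := range resp.GetBinlogs() {
	ch := sbl.GetInsertChannel() // field added by this commit
	segmentsByChannel[ch] = append(segmentsByChannel[ch], sbl)
}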
@@ -21,7 +21,6 @@ import (
	"errors"
	"testing"

-	"github.com/milvus-io/milvus/internal/proxy"
	"github.com/milvus-io/milvus/internal/util/mock"
	"github.com/milvus-io/milvus/internal/util/typeutil"
	"github.com/stretchr/testify/assert"
@@ -29,7 +28,7 @@ import (
)

func Test_NewClient(t *testing.T) {
-	proxy.Params.InitOnce()
+	ClientParams.InitOnce(typeutil.QueryNodeRole)
	ctx := context.Background()
	client, err := NewClient(ctx, "")
......
@@ -149,6 +149,7 @@ enum MsgType {
	RemoveQueryChannels = 511;
	SealedSegmentsChangeInfo = 512;
	WatchDeltaChannels = 513;
+	GetShardLeaders = 514;

	/* DATA SERVICE */
	SegmentInfo = 600;
@@ -221,4 +222,4 @@ enum ImportState {
	ImportParsed = 4;
	ImportPersisted = 5;
	ImportCompleted = 6;
-}
\ No newline at end of file
+}
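`GetShardLeaders = 514` reserves a message type for the new QueryCoord RPC of the same name. A hedged sketch of how a caller tags such a request, mirroring the MetaCache code later in this commit (`collectionID` is assumed to be resolved beforehand):

// Sketch: tag a shard-leader lookup with the new message type.
req := &querypb.GetShardLeadersRequest{
	Base: &commonpb.MsgBase{
		MsgType:  commonpb.MsgType_GetShardLeaders,
		SourceID: Params.ProxyCfg.ProxyID,
	},
	CollectionID: collectionID,
}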
@@ -303,6 +303,7 @@ message SegmentBinlogs {
	int64 num_of_rows = 3;
	repeated FieldBinlog statslogs = 4;
	repeated FieldBinlog deltalogs = 5;
+	string insert_channel = 6;
}

message FieldBinlog{
@@ -367,6 +368,7 @@ message CompactionSegmentBinlogs {
	repeated FieldBinlog fieldBinlogs = 2;
	repeated FieldBinlog field2StatslogPaths = 3;
	repeated FieldBinlog deltalogs = 4;
+	string insert_channel = 5;
}

message CompactionPlan {
......
@@ -343,6 +343,8 @@ message LoadPartitionsRequest {
	string collection_name = 3;
	// The partition names you want to load
	repeated string partition_names = 4;
+	// The replicas number you would load, 1 by default
+	int32 replica_number = 5;
}

/*
@@ -755,6 +757,7 @@ message LoadBalanceRequest {
	int64 src_nodeID = 2;
	repeated int64 dst_nodeIDs = 3;
	repeated int64 sealed_segmentIDs = 4;
+	string collectionName = 5;
}

message ManualCompactionRequest {
@@ -857,8 +860,6 @@ message ShardReplica {
	repeated int64 node_ids = 4;
}

service ProxyService {
	rpc RegisterLink(RegisterLinkRequest) returns (RegisterLinkResponse) {}
}
@@ -911,3 +912,4 @@ message ListCredUsersRequest {
	// Not useful for now
	common.MsgBase base = 1;
}
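On the client side both additions are plain request fields. A minimal sketch, assuming an already-connected `milvuspb.MilvusServiceClient` named `client` and pre-resolved node and segment IDs:

// Sketch: load partitions with two in-memory replicas (replica_number = 5).
_, err := client.LoadPartitions(ctx, &milvuspb.LoadPartitionsRequest{
	DbName:         "default",
	CollectionName: "my_collection",
	PartitionNames: []string{"p0"},
	ReplicaNumber:  2, // 1 by default
})

// Sketch: scope a manual load balance to one collection (collectionName = 5).
_, err = client.LoadBalance(ctx, &milvuspb.LoadBalanceRequest{
	CollectionName:   "my_collection",
	SrcNodeID:        srcNodeID,
	DstNodeIDs:       []int64{dstNodeID},
	SealedSegmentIDs: sealedSegmentIDs,
})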
@@ -201,6 +201,7 @@ message WatchDmChannelsRequest {
	schema.CollectionSchema schema = 6;
	repeated data.SegmentInfo exclude_infos = 7;
	LoadMetaInfo load_meta = 8;
+	int64 replicaID = 9;
}

message WatchDeltaChannelsRequest {
@@ -224,6 +225,7 @@ message SegmentLoadInfo {
	repeated int64 compactionFrom = 10; // segmentIDs compacted from
	repeated FieldIndexInfo index_infos = 11;
	int64 segment_size = 12;
+	string insert_channel = 13;
}

message FieldIndexInfo {
@@ -245,6 +247,7 @@ message LoadSegmentsRequest {
	int64 source_nodeID = 5;
	int64 collectionID = 6;
	LoadMetaInfo load_meta = 7;
+	int64 replicaID = 8;
}

message ReleaseSegmentsRequest {
@@ -281,6 +284,7 @@ message LoadBalanceRequest {
	TriggerCondition balance_reason = 3;
	repeated int64 dst_nodeIDs = 4;
	repeated int64 sealed_segmentIDs = 5;
+	int64 collectionID = 6;
}

//-------------------- internal meta proto------------------
@@ -312,6 +316,7 @@ message DmChannelWatchInfo {
	int64 collectionID = 1;
	string dmChannel = 2;
	int64 nodeID_loaded = 3;
+	int64 replicaID = 4;
}

message QueryChannelInfo {
@@ -380,4 +385,3 @@ message SealedSegmentsChangeInfo {
	common.MsgBase base = 1;
	repeated SegmentChangeInfo infos = 2;
}
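QueryCoord now threads a replica ID through both the channel-watch and the segment-load paths, so a QueryNode knows which replica a channel or segment belongs to. A hedged sketch using only fields visible in the proto above; the surrounding values are assumptions:

// Sketch: replica-aware requests from QueryCoord to a QueryNode.
loadReq := &querypb.LoadSegmentsRequest{
	SourceNodeID: sourceNodeID, // source_nodeID = 5
	CollectionID: collectionID, // collectionID = 6
	ReplicaID:    replicaID,    // new field, replicaID = 8
}
watchReq := &querypb.WatchDmChannelsRequest{
	LoadMeta:  loadMeta,  // load_meta = 8
	ReplicaID: replicaID, // new field, replicaID = 9
}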
@@ -31,6 +31,7 @@ import (
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/metrics"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/datapb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
@@ -144,8 +145,8 @@ func (node *Proxy) ReleaseDQLMessageStream(ctx context.Context, request *proxypb
	}, nil
}

-// TODO(dragondriver): add more detailed ut for ConsistencyLevel, should we support multiple consistency level in Proxy?
// CreateCollection create a collection by the schema.
+// TODO(dragondriver): add more detailed ut for ConsistencyLevel, should we support multiple consistency level in Proxy?
func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
	if !node.checkHealthy() {
		return unhealthyStatus(), nil
@@ -2399,11 +2400,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
		},
		ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
	},
-	resultBuf: make(chan []*internalpb.SearchResults, 1),
-	query:     request,
-	chMgr:     node.chMgr,
-	qc:        node.queryCoord,
-	tr:        timerecord.NewTimeRecorder("search"),
+	request:            request,
+	qc:                 node.queryCoord,
+	tr:                 timerecord.NewTimeRecorder("search"),
+	getQueryNodePolicy: defaultGetQueryNodePolicy,
}
travelTs := request.TravelTimestamp
@@ -2516,11 +2516,11 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
	metrics.ProxySearchCount.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
		metrics.SearchLabel, metrics.SuccessLabel).Inc()
	metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
-		metrics.SearchLabel).Set(float64(qt.result.Results.NumQueries))
+		metrics.SearchLabel).Set(float64(qt.result.GetResults().GetNumQueries()))
	searchDur := tr.ElapseSpan().Milliseconds()
	metrics.ProxySearchLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
		metrics.SearchLabel).Observe(float64(searchDur))
-	metrics.ProxySearchLatencyPerNQ.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(searchDur) / float64(qt.result.Results.NumQueries))
+	metrics.ProxySearchLatencyPerNQ.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(searchDur) / float64(qt.result.GetResults().GetNumQueries()))
	return qt.result, nil
}
@@ -2641,10 +2641,10 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
		},
		ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
	},
-	resultBuf: make(chan []*internalpb.RetrieveResults),
-	query:     request,
-	chMgr:     node.chMgr,
-	qc:        node.queryCoord,
+	request:            request,
+	qc:                 node.queryCoord,
+	getQueryNodePolicy: defaultGetQueryNodePolicy,
+	queryShardPolicy:   roundRobinPolicy,
}

method := "Query"
@@ -3058,11 +3058,12 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
		},
		ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
	},
-	resultBuf: make(chan []*internalpb.RetrieveResults),
-	query:     queryRequest,
-	chMgr:     node.chMgr,
-	qc:        node.queryCoord,
-	ids:       ids.IdArray,
+	request:            queryRequest,
+	qc:                 node.queryCoord,
+	ids:                ids.IdArray,
+	getQueryNodePolicy: defaultGetQueryNodePolicy,
+	queryShardPolicy:   roundRobinPolicy,
}

err := node.sched.dqQueue.Enqueue(qt)
@@ -3715,6 +3716,7 @@ func (node *Proxy) RegisterLink(ctx context.Context, req *milvuspb.RegisterLinkR
	}, nil
}
+// GetMetrics gets the metrics of proxy
// TODO(dragondriver): cache the Metrics and set a retention to the cache
func (node *Proxy) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
	log.Debug("Proxy.GetMetrics",
@@ -3817,6 +3819,13 @@ func (node *Proxy) LoadBalance(ctx context.Context, req *milvuspb.LoadBalanceReq
	status := &commonpb.Status{
		ErrorCode: commonpb.ErrorCode_UnexpectedError,
	}
+	collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetCollectionName())
+	if err != nil {
+		log.Error("failed to get collection id", zap.String("collection name", req.GetCollectionName()), zap.Error(err))
+		status.Reason = err.Error()
+		return status, nil
+	}
	infoResp, err := node.queryCoord.LoadBalance(ctx, &querypb.LoadBalanceRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_LoadBalanceSegments,
@@ -3828,6 +3837,7 @@ func (node *Proxy) LoadBalance(ctx context.Context, req *milvuspb.LoadBalanceReq
		DstNodeIDs:       req.DstNodeIDs,
		BalanceReason:    querypb.TriggerCondition_GrpcRequest,
		SealedSegmentIDs: req.SealedSegmentIDs,
+		CollectionID:     collectionID,
	})
	if err != nil {
		log.Error("Failed to LoadBalance from Query Coordinator",
@@ -3873,6 +3883,7 @@ func (node *Proxy) ManualCompaction(ctx context.Context, req *milvuspb.ManualCom
	return resp, err
}

+// GetCompactionStateWithPlans returns the compactions states with the given plan ID
func (node *Proxy) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) {
	log.Info("received GetCompactionStateWithPlans request", zap.Int64("compactionID", req.GetCompactionID()))
	resp := &milvuspb.GetCompactionPlansResponse{}
@@ -3979,7 +3990,7 @@ func (node *Proxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReq
	return resp, err
}

-// Check import task state from datanode
+// GetImportState checks import task state from datanode
func (node *Proxy) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) {
	log.Info("received get import state request", zap.Int64("taskID", req.GetTask()))
	resp := &milvuspb.GetImportStateResponse{}
@@ -4206,3 +4217,19 @@ func (node *Proxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUser
		Usernames: usernames,
	}, nil
}
// SendSearchResult needs to be removed TODO
func (node *Proxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Not implemented",
}, nil
}
// SendRetrieveResult needs to be removed TODO
func (node *Proxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Not implemented",
}, nil
}
@@ -31,6 +31,7 @@ import (
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
+	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/types"
@@ -52,6 +53,7 @@ type Cache interface {
	GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
	// GetCollectionSchema get collection's schema.
	GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
+	GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error)
	RemoveCollection(ctx context.Context, collectionName string)
	RemovePartition(ctx context.Context, collectionName string, partitionName string)
@@ -67,6 +69,7 @@ type collectionInfo struct {
	collID typeutil.UniqueID
	schema *schemapb.CollectionSchema
	partInfo map[string]*partitionInfo
+	shardLeaders []*querypb.ShardLeadersList
	createdTimestamp uint64
	createdUtcTimestamp uint64
}
@@ -160,6 +163,7 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, collectionName string
		collInfo = m.collInfo[collectionName]
		metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
	}
	metrics.ProxyCacheHitCounter.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10), "GetCollectionInfo", metrics.CacheHitLabel).Inc()
	return &collectionInfo{
		collID: collInfo.collID,
@@ -167,6 +171,7 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, collectionName string
		partInfo: collInfo.partInfo,
		createdTimestamp: collInfo.createdTimestamp,
		createdUtcTimestamp: collInfo.createdUtcTimestamp,
+		shardLeaders: collInfo.shardLeaders,
	}, nil
}
@@ -520,3 +525,41 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
	return usernames, nil
}
// GetShards update cache if withCache == false
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error) {
info, err := m.GetCollectionInfo(ctx, collectionName)
if err != nil {
return nil, err
}
if withCache {
if len(info.shardLeaders) > 0 {
return info.shardLeaders, nil
}
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
zap.String("collectionName", collectionName))
}
m.mu.Lock()
defer m.mu.Unlock()
req := &querypb.GetShardLeadersRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_GetShardLeaders,
SourceID: Params.ProxyCfg.ProxyID,
},
CollectionID: info.collID,
}
resp, err := qc.GetShardLeaders(ctx, req)
if err != nil {
return nil, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
}
shards := resp.GetShards()
m.collInfo[collectionName].shardLeaders = shards
return shards, nil
}
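A usage sketch for the new cache path: ask for the cached leaders first, then force a refresh through QueryCoord when the caller decides the cached view is stale. The retry policy shown is an assumption, not part of this commit:

// Sketch: resolve shard leaders through the proxy meta cache.
shards, err := globalMetaCache.GetShards(ctx, true, collectionName, node.queryCoord)
if err != nil {
	// Bypass the cached leaders and fetch a fresh list from QueryCoord.
	shards, err = globalMetaCache.GetShards(ctx, false, collectionName, node.queryCoord)
}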
@@ -22,17 +22,17 @@ import (
	"fmt"
	"testing"

-	"github.com/milvus-io/milvus/internal/util/crypto"
-	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
+	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
	"github.com/milvus-io/milvus/internal/types"
+	"github.com/milvus-io/milvus/internal/util/crypto"
	"github.com/milvus-io/milvus/internal/util/typeutil"
	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
)

type MockRootCoordClientInterface struct {
@@ -310,3 +310,51 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
	log.Debug(err.Error())
	assert.Equal(t, id, typeutil.UniqueID(0))
}
func TestMetaCache_GetShards(t *testing.T) {
client := &MockRootCoordClientInterface{}
err := InitMetaCache(client)
require.Nil(t, err)
var (
ctx = context.TODO()
collectionName = "collection1"
qc = NewQueryCoordMock()
)
qc.Init()
qc.Start()
defer qc.Stop()
t.Run("No collection in meta cache", func(t *testing.T) {
shards, err := globalMetaCache.GetShards(ctx, true, "non-exists", qc)
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) {
qc.validShardLeaders = false
shards, err := globalMetaCache.GetShards(ctx, false, collectionName, qc)
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info", func(t *testing.T) {
qc.validShardLeaders = true
shards, err := globalMetaCache.GetShards(ctx, true, collectionName, qc)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
assert.Equal(t, 3, len(shards[0].GetNodeIds()))
// get from cache
qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName, qc)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
assert.Equal(t, 3, len(shards[0].GetNodeIds()))
})
}
@@ -93,8 +93,7 @@ type Proxy struct {
	factory dependency.Factory

	searchResultCh chan *internalpb.SearchResults
-	retrieveResultCh chan *internalpb.RetrieveResults

	// Add callback functions at different stages
	startCallbacks []func()
@@ -107,11 +106,10 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
	ctx1, cancel := context.WithCancel(ctx)
	n := 1024 // better to be configurable
	node := &Proxy{
		ctx:            ctx1,
		cancel:         cancel,
		factory:        factory,
		searchResultCh: make(chan *internalpb.SearchResults, n),
-		retrieveResultCh: make(chan *internalpb.RetrieveResults, n),
	}
	node.UpdateStateCode(internalpb.StateCode_Abnormal)
	logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
@@ -228,9 +226,7 @@ func (node *Proxy) Init() error {
	log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole))

	log.Debug("create task scheduler", zap.String("role", typeutil.ProxyRole))
-	node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.factory,
-		schedOptWithSearchResultCh(node.searchResultCh),
-		schedOptWithRetrieveResultCh(node.retrieveResultCh))
+	node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.factory)
	if err != nil {
		log.Warn("failed to create task scheduler", zap.Error(err), zap.String("role", typeutil.ProxyRole))
		return err
......
@@ -17,12 +17,7 @@
package proxy

import (
-	"bytes"
	"context"
-	"encoding/binary"
-	"encoding/json"
-	"fmt"
-	"math/rand"
	"net"
	"os"
	"strconv"
@@ -30,81 +25,56 @@ import (
	"testing"
	"time"

-	"github.com/golang/protobuf/proto"
	ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
+	"github.com/prometheus/client_golang/prometheus"
+	"github.com/stretchr/testify/assert"
+	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"

-	"github.com/milvus-io/milvus/internal/common"
-	"github.com/milvus-io/milvus/internal/log"
-	"github.com/milvus-io/milvus/internal/metrics"
-	"github.com/milvus-io/milvus/internal/rootcoord"
	"github.com/milvus-io/milvus/internal/util/crypto"
	"github.com/milvus-io/milvus/internal/util/dependency"
-	"github.com/milvus-io/milvus/internal/util/trace"
-	"github.com/prometheus/client_golang/prometheus"
-	"github.com/milvus-io/milvus/internal/util/paramtable"
+	"github.com/milvus-io/milvus/internal/util/distance"
	"github.com/milvus-io/milvus/internal/util/etcd"
-	"github.com/milvus-io/milvus/internal/util/sessionutil"
-	"github.com/milvus-io/milvus/internal/util/tsoutil"
-	"go.uber.org/zap"
+	"github.com/milvus-io/milvus/internal/util/funcutil"
+	"github.com/milvus-io/milvus/internal/util/logutil"

+	"github.com/milvus-io/milvus/internal/common"
	"github.com/milvus-io/milvus/internal/util/metricsinfo"
+	"github.com/milvus-io/milvus/internal/util/paramtable"
+	"github.com/milvus-io/milvus/internal/util/sessionutil"
+	"github.com/milvus-io/milvus/internal/util/trace"
+	"github.com/milvus-io/milvus/internal/util/typeutil"

-	"github.com/milvus-io/milvus/internal/util/distance"
+	"github.com/milvus-io/milvus/internal/proto/commonpb"
+	"github.com/milvus-io/milvus/internal/proto/internalpb"
+	"github.com/milvus-io/milvus/internal/proto/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/proxypb"
	"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
+	"github.com/golang/protobuf/proto"
	"github.com/milvus-io/milvus/internal/proto/schemapb"
-	"github.com/milvus-io/milvus/internal/proto/milvuspb"
-	"github.com/milvus-io/milvus/internal/proto/commonpb"
-	"github.com/milvus-io/milvus/internal/proto/internalpb"

-	grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
-	grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
-	grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
-	"github.com/milvus-io/milvus/internal/util/funcutil"
-	"github.com/milvus-io/milvus/internal/util/typeutil"
-	rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
-	grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
-	grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
	grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord"
+	grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
	grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode"
+	grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
-	grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
+	grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
+	grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
	grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord"
+	grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
+	grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
	grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord"
+	rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"

	"github.com/milvus-io/milvus/internal/datacoord"
	"github.com/milvus-io/milvus/internal/datanode"
	"github.com/milvus-io/milvus/internal/indexcoord"
	"github.com/milvus-io/milvus/internal/indexnode"
-	"github.com/milvus-io/milvus/internal/querynode"
-	"github.com/stretchr/testify/assert"
	"github.com/milvus-io/milvus/internal/querycoord"
+	"github.com/milvus-io/milvus/internal/querynode"

+	"github.com/milvus-io/milvus/internal/log"
+	"github.com/milvus-io/milvus/internal/metrics"
+	"github.com/milvus-io/milvus/internal/rootcoord"
-	"github.com/milvus-io/milvus/internal/util/logutil"
)
const (
@@ -620,12 +590,12 @@ func TestProxy(t *testing.T) {
	rowNum := 3000
	indexName := "_default"
	nlist := 10
-	nprobe := 10
+	// nprobe := 10
-	topk := 10
+	// topk := 10
	// add a test parameter
-	roundDecimal := 6
+	// roundDecimal := 6
	nq := 10
-	expr := fmt.Sprintf("%s > 0", int64Field)
+	// expr := fmt.Sprintf("%s > 0", int64Field)
	var segmentIDs []int64

	// an int64 field (pk) & a float vector field
@@ -721,76 +691,6 @@ func TestProxy(t *testing.T) {
	}
}
constructPlaceholderGroup := func() *milvuspb.PlaceholderGroup {
values := make([][]byte, 0, nq)
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, common.Endian, f)
assert.NoError(t, err)
bs = append(bs, buffer.Bytes()...)
}
values = append(values, bs)
}
return &milvuspb.PlaceholderGroup{
Placeholders: []*milvuspb.PlaceholderValue{
{
Tag: "$0",
Type: milvuspb.PlaceholderType_FloatVector,
Values: values,
},
},
}
}
constructSearchRequest := func() *milvuspb.SearchRequest {
params := make(map[string]string)
params["nprobe"] = strconv.Itoa(nprobe)
b, err := json.Marshal(params)
assert.NoError(t, err)
plg := constructPlaceholderGroup()
plgBs, err := proto.Marshal(plg)
assert.NoError(t, err)
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: nil,
SearchParams: []*commonpb.KeyValuePair{
{
Key: MetricTypeKey,
Value: distance.L2,
},
{
Key: SearchParamsKey,
Value: string(b),
},
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: strconv.Itoa(topk),
},
{
Key: RoundDecimalKey,
Value: strconv.Itoa(roundDecimal),
},
},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
}
	wg.Add(1)
	t.Run("create collection", func(t *testing.T) {
		defer wg.Done()
@@ -1368,103 +1268,178 @@ func TestProxy(t *testing.T) {
		assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	})
	// nprobe := 10
	// topk := 10
	// roundDecimal := 6
	// expr := fmt.Sprintf("%s > 0", int64Field)
	// constructPlaceholderGroup := func() *milvuspb.PlaceholderGroup {
	// 	values := make([][]byte, 0, nq)
	// 	for i := 0; i < nq; i++ {
	// 		bs := make([]byte, 0, dim*4)
	// 		for j := 0; j < dim; j++ {
	// 			var buffer bytes.Buffer
	// 			f := rand.Float32()
	// 			err := binary.Write(&buffer, common.Endian, f)
	// 			assert.NoError(t, err)
	// 			bs = append(bs, buffer.Bytes()...)
	// 		}
	// 		values = append(values, bs)
	// 	}
	//
	// 	return &milvuspb.PlaceholderGroup{
	// 		Placeholders: []*milvuspb.PlaceholderValue{
	// 			{
	// 				Tag:    "$0",
	// 				Type:   milvuspb.PlaceholderType_FloatVector,
	// 				Values: values,
	// 			},
	// 		},
	// 	}
	// }
	//
	// constructSearchRequest := func() *milvuspb.SearchRequest {
	// 	params := make(map[string]string)
	// 	params["nprobe"] = strconv.Itoa(nprobe)
	// 	b, err := json.Marshal(params)
	// 	assert.NoError(t, err)
	// 	plg := constructPlaceholderGroup()
	// 	plgBs, err := proto.Marshal(plg)
	// 	assert.NoError(t, err)
	//
	// 	return &milvuspb.SearchRequest{
	// 		Base:             nil,
	// 		DbName:           dbName,
	// 		CollectionName:   collectionName,
	// 		PartitionNames:   nil,
	// 		Dsl:              expr,
	// 		PlaceholderGroup: plgBs,
	// 		DslType:          commonpb.DslType_BoolExprV1,
	// 		OutputFields:     nil,
	// 		SearchParams: []*commonpb.KeyValuePair{
	// 			{
	// 				Key:   MetricTypeKey,
	// 				Value: distance.L2,
	// 			},
	// 			{
	// 				Key:   SearchParamsKey,
	// 				Value: string(b),
	// 			},
	// 			{
	// 				Key:   AnnsFieldKey,
	// 				Value: floatVecField,
	// 			},
	// 			{
	// 				Key:   TopKKey,
	// 				Value: strconv.Itoa(topk),
	// 			},
	// 			{
	// 				Key:   RoundDecimalKey,
	// 				Value: strconv.Itoa(roundDecimal),
	// 			},
	// 		},
	// 		TravelTimestamp:    0,
	// 		GuaranteeTimestamp: 0,
	// 	}
	// }

	// TODO(Goose): reopen after joint-tests
	// if loaded {
	// 	wg.Add(1)
	// 	t.Run("search", func(t *testing.T) {
	// 		defer wg.Done()
	// 		req := constructSearchRequest()
	//
	// 		resp, err := proxy.Search(ctx, req)
	// 		assert.NoError(t, err)
	// 		assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	// 	})
	//
	// 	wg.Add(1)
	// 	t.Run("search_travel", func(t *testing.T) {
	// 		defer wg.Done()
	// 		past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration-100) * time.Second)
	// 		travelTs := tsoutil.ComposeTSByTime(past, 0)
	// 		req := constructSearchRequest()
	// 		req.TravelTimestamp = travelTs
	// 		//resp, err := proxy.Search(ctx, req)
	// 		res, err := proxy.Search(ctx, req)
	// 		assert.NoError(t, err)
	// 		assert.NotEqual(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
	// 	})
	//
	// 	wg.Add(1)
	// 	t.Run("search_travel_succ", func(t *testing.T) {
	// 		defer wg.Done()
	// 		past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration+100) * time.Second)
	// 		travelTs := tsoutil.ComposeTSByTime(past, 0)
	// 		req := constructSearchRequest()
	// 		req.TravelTimestamp = travelTs
	// 		//resp, err := proxy.Search(ctx, req)
	// 		res, err := proxy.Search(ctx, req)
	// 		assert.NoError(t, err)
	// 		assert.Equal(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
	// 	})
	//
	// 	wg.Add(1)
	// 	t.Run("query", func(t *testing.T) {
	// 		defer wg.Done()
	// 		//resp, err := proxy.Query(ctx, &milvuspb.QueryRequest{
	// 		_, err := proxy.Query(ctx, &milvuspb.QueryRequest{
	// 			Base:               nil,
	// 			DbName:             dbName,
	// 			CollectionName:     collectionName,
	// 			Expr:               expr,
	// 			OutputFields:       nil,
	// 			PartitionNames:     nil,
	// 			TravelTimestamp:    0,
	// 			GuaranteeTimestamp: 0,
	// 		})
	// 		assert.NoError(t, err)
	// 		// FIXME(dragondriver)
	// 		// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
	// 		// TODO(dragondriver): compare query result
	// 	})
	//
	// 	wg.Add(1)
	// 	t.Run("query_travel", func(t *testing.T) {
	// 		defer wg.Done()
	// 		past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration-100) * time.Second)
	// 		travelTs := tsoutil.ComposeTSByTime(past, 0)
	// 		queryReq := &milvuspb.QueryRequest{
	// 			Base:               nil,
	// 			DbName:             dbName,
	// 			CollectionName:     collectionName,
	// 			Expr:               expr,
	// 			OutputFields:       nil,
	// 			PartitionNames:     nil,
	// 			TravelTimestamp:    travelTs,
	// 			GuaranteeTimestamp: 0,
	// 		}
	// 		res, err := proxy.Query(ctx, queryReq)
	// 		assert.NoError(t, err)
	// 		assert.NotEqual(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
	// 	})
	//
	// 	wg.Add(1)
	// 	t.Run("query_travel_succ", func(t *testing.T) {
	// 		defer wg.Done()
	// 		past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration+100) * time.Second)
	// 		travelTs := tsoutil.ComposeTSByTime(past, 0)
	// 		queryReq := &milvuspb.QueryRequest{
	// 			Base:               nil,
	// 			DbName:             dbName,
	// 			CollectionName:     collectionName,
	// 			Expr:               expr,
	// 			OutputFields:       nil,
	// 			PartitionNames:     nil,
	// 			TravelTimestamp:    travelTs,
	// 			GuaranteeTimestamp: 0,
	// 		}
	// 		res, err := proxy.Query(ctx, queryReq)
	// 		assert.NoError(t, err)
	// 		assert.Equal(t, commonpb.ErrorCode_EmptyCollection, res.Status.ErrorCode)
	// 	})
	// }
	wg.Add(1)
	t.Run("calculate distance", func(t *testing.T) {
@@ -1683,6 +1658,7 @@ func TestProxy(t *testing.T) {
			DbName:         dbName,
			CollectionName: collectionName,
			PartitionNames: []string{partitionName},
			ReplicaNumber:  1,
		})
		assert.NoError(t, err)
		assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
@@ -1693,6 +1669,7 @@ func TestProxy(t *testing.T) {
			DbName:         dbName,
			CollectionName: collectionName,
			PartitionNames: []string{otherPartitionName},
			ReplicaNumber:  1,
		})
		assert.NoError(t, err)
		assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
@@ -1703,6 +1680,7 @@ func TestProxy(t *testing.T) {
			DbName:         dbName,
			CollectionName: otherCollectionName,
			PartitionNames: []string{partitionName},
			ReplicaNumber:  1,
		})
		assert.NoError(t, err)
		assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
...
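The `ReplicaNumber: 1` fields above exercise the in-memory replica setting that this patch threads through the load path. A minimal client-side sketch, assuming a ready `milvuspb.MilvusServiceClient` named `client` and a collection named `book` (both illustrative, not part of this patch):

	// Keep one in-memory copy of the collection; a ReplicaNumber
	// greater than 1 asks QueryCoord to spread sealed segments over
	// several QueryNode groups so searches can fail over between them.
	status, err := client.LoadCollection(ctx, &milvuspb.LoadCollectionRequest{
		DbName:         "default",
		CollectionName: "book",
		ReplicaNumber:  1,
	})
	if err != nil || status.GetErrorCode() != commonpb.ErrorCode_Success {
		// handle the failed load here
	}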
@@ -36,12 +36,20 @@ type QueryCoordMockOption func(mock *QueryCoordMock)
type queryCoordShowCollectionsFuncType func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
type queryCoordShowPartitionsFuncType func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error)
func SetQueryCoordShowCollectionsFunc(f queryCoordShowCollectionsFuncType) QueryCoordMockOption {
	return func(mock *QueryCoordMock) {
		mock.showCollectionsFunc = f
	}
}
func withValidShardLeaders() QueryCoordMockOption {
	return func(mock *QueryCoordMock) {
		mock.validShardLeaders = true
	}
}
type QueryCoordMock struct {
	nodeID  typeutil.UniqueID
	address string
@@ -54,9 +62,12 @@ type QueryCoordMock struct {
	showCollectionsFunc queryCoordShowCollectionsFuncType
	getMetricsFunc      getMetricsFuncType
	showPartitionsFunc  queryCoordShowPartitionsFuncType

	statisticsChannel string
	timeTickChannel   string

	validShardLeaders bool
}
func (coord *QueryCoordMock) updateState(state internalpb.StateCode) {
@@ -223,6 +234,14 @@ func (coord *QueryCoordMock) ReleaseCollection(ctx context.Context, req *querypb
	}, nil
}
func (coord *QueryCoordMock) SetShowPartitionsFunc(f queryCoordShowPartitionsFuncType) {
	coord.showPartitionsFunc = f
}

func (coord *QueryCoordMock) ResetShowPartitionsFunc() {
	coord.showPartitionsFunc = nil
}
func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
	if !coord.healthy() {
		return &querypb.ShowPartitionsResponse{
@@ -233,6 +252,10 @@ func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.Sh
		}, nil
	}

	if coord.showPartitionsFunc != nil {
		return coord.showPartitionsFunc(ctx, req)
	}

	panic("implement me")
}
@@ -360,6 +383,21 @@ func (coord *QueryCoordMock) GetShardLeaders(ctx context.Context, req *querypb.G
		}, nil
	}
	if coord.validShardLeaders {
		return &querypb.GetShardLeadersResponse{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_Success,
			},
			Shards: []*querypb.ShardLeadersList{
				{
					ChannelName: "channel-1",
					NodeIds:     []int64{1, 2, 3},
					NodeAddrs:   []string{"localhost:9000", "localhost:9001", "localhost:9002"},
				},
			},
		}, nil
	}
	return &querypb.GetShardLeadersResponse{
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_UnexpectedError,
...
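A short sketch of how a proxy test can drive the mock above; `ctx` and `collectionID` are assumed to come from the surrounding test:

	qc := NewQueryCoordMock(withValidShardLeaders())
	qc.Start()
	defer qc.Stop()

	// With the option set, the mock reports a single shard ("channel-1")
	// led by nodes 1, 2 and 3; without it, the call falls through to the
	// ErrorCode_UnexpectedError branch shown above.
	resp, err := qc.GetShardLeaders(ctx, &querypb.GetShardLeadersRequest{
		CollectionID: collectionID,
	})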
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"sync/atomic"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
var _ types.QueryNode = &QueryNodeMock{}
type QueryNodeMock struct {
	nodeID  typeutil.UniqueID
	address string
	state   atomic.Value // internal.StateCode

	withSearchResult *internalpb.SearchResults
	withQueryResult  *internalpb.RetrieveResults
}
func (m *QueryNodeMock) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
return m.withSearchResult, nil
}
func (m *QueryNodeMock) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
return m.withQueryResult, nil
}
// TODO
func (m *QueryNodeMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) WatchDeltaChannels(ctx context.Context, req *querypb.WatchDeltaChannelsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
return nil, nil
}
func (m *QueryNodeMock) Init() error { return nil }
func (m *QueryNodeMock) Start() error { return nil }
func (m *QueryNodeMock) Stop() error { return nil }
func (m *QueryNodeMock) Register() error { return nil }
func (m *QueryNodeMock) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
return nil, nil
}
func (m *QueryNodeMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}
func (m *QueryNodeMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}
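The mock simply echoes whatever results it was primed with, so proxy tests can exercise the new shard-query path without a live QueryNode. A sketch, with `ctx` assumed from the test:

	qn := &QueryNodeMock{}
	qn.withQueryResult = &internalpb.RetrieveResults{
		Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
	}

	// Query ignores the request and returns the primed result with a nil error.
	res, err := qn.Query(ctx, &querypb.QueryRequest{})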
package proxy
import (
"context"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
func (node *Proxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
node.searchResultCh <- req
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (node *Proxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
node.retrieveResultCh <- req
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
@@ -2675,9 +2675,10 @@ func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) {
			Timestamp: lct.Base.Timestamp,
			SourceID:  lct.Base.SourceID,
		},
		DbID:          0,
		CollectionID:  collID,
		Schema:        collSchema,
		ReplicaNumber: lct.ReplicaNumber,
	}
	log.Debug("send LoadCollectionRequest to query coordinator", zap.String("role", typeutil.ProxyRole),
		zap.Int64("msgID", request.Base.MsgID), zap.Int64("collectionID", request.CollectionID),
@@ -2869,10 +2870,11 @@ func (lpt *loadPartitionsTask) Execute(ctx context.Context) error {
			Timestamp: lpt.Base.Timestamp,
			SourceID:  lpt.Base.SourceID,
		},
		DbID:          0,
		CollectionID:  collID,
		PartitionIDs:  partitionIDs,
		Schema:        collSchema,
		ReplicaNumber: lpt.ReplicaNumber,
	}
	lpt.result, err = lpt.queryCoord.LoadPartitions(ctx, request)
	return err
...
package proxy
import (
"context"
"errors"
"fmt"
qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"go.uber.org/zap"
)
type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)
type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error
// TODO: add another policy to enable the use of cache
// defaultGetQueryNodePolicy creates a QueryNode client for every address every time
func defaultGetQueryNodePolicy(ctx context.Context, address string) (types.QueryNode, error) {
	qn, err := qnClient.NewClient(ctx, address)
	if err != nil {
		return nil, err
	}

	if err := qn.Init(); err != nil {
		return nil, err
	}

	if err := qn.Start(); err != nil {
		return nil, err
	}
	return qn, nil
}

var (
	errBegin               = errors.New("begin error")
	errInvalidShardLeaders = errors.New("Invalid shard leader")
)

func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
	var (
		err     = errBegin
		current = 0
		qn      types.QueryNode
	)
	replicaNum := len(leaders.GetNodeIds())

	for err != nil && current < replicaNum {
		currentID := leaders.GetNodeIds()[current]
		if err != errBegin {
			log.Warn("retry with another QueryNode", zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID))
		}

		qn, err = getQueryNodePolicy(ctx, leaders.GetNodeAddrs()[current])
		if err != nil {
			log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
				zap.Error(err))
			current++
			continue
		}

		defer qn.Stop()
		err = query(currentID, qn)
		if err != nil {
			log.Warn("fail to Query with shard leader",
				zap.String("leader", leaders.GetChannelName()),
				zap.Int64("nodeID", currentID),
				zap.Error(err))
		}
		current++
	}

	if current == replicaNum && err != nil {
		return fmt.Errorf("no shard leaders available for channel: %s, leaders: %v, err: %s", leaders.GetChannelName(), leaders.GetNodeIds(), err.Error())
	}
	return nil
}
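A sketch of how the proxy is expected to combine the two policies for one shard; the `leaders` value and the query closure are illustrative, not taken from this patch:

	leaders := &querypb.ShardLeadersList{
		ChannelName: "by-dev-rootcoord-dml_0",
		NodeIds:     []int64{1, 2},
		NodeAddrs:   []string{"localhost:21123", "localhost:21124"},
	}

	// Try each listed leader in turn until one answers; the error is
	// non-nil only if every replica for this channel failed.
	err := roundRobinPolicy(ctx, defaultGetQueryNodePolicy,
		func(nodeID UniqueID, qn types.QueryNode) error {
			_, err := qn.Search(ctx, &querypb.SearchRequest{})
			return err
		}, leaders)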
This diff is collapsed.
@@ -3,16 +3,15 @@ package proxy
import (
	"context"
	"fmt"
	"testing"
	"time"

	"github.com/golang/protobuf/proto"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/milvus-io/milvus/internal/common"
	"github.com/milvus-io/milvus/internal/types"
	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/internalpb"
@@ -25,25 +24,35 @@ import (
)

func TestQueryTask_all(t *testing.T) {
	Params.Init()

	var (
		err error

		ctx = context.TODO()

		rc = NewRootCoordMock()
		qc = NewQueryCoordMock(withValidShardLeaders())
		qn = &QueryNodeMock{}

		shardsNum      = int32(2)
		collectionName = t.Name() + funcutil.GenRandomStr()

		expr   = fmt.Sprintf("%s > 0", testInt64Field)
		hitNum = 10
	)

	mockGetQueryNodePolicy := func(ctx context.Context, address string) (types.QueryNode, error) {
		return qn, nil
	}

	rc.Start()
	defer rc.Stop()
	qc.Start()
	defer qc.Stop()

	err = InitMetaCache(rc)
	assert.NoError(t, err)

	fieldName2Types := map[string]schemapb.DataType{
		testBoolField:  schemapb.DataType_Bool,
		testInt32Field: schemapb.DataType_Int32,
@@ -56,9 +65,6 @@ func TestQueryTask_all(t *testing.T) {
		fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
	}

	schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
	marshaledSchema, err := proto.Marshal(schema)
	assert.NoError(t, err)
@@ -66,165 +72,66 @@ func TestQueryTask_all(t *testing.T) {
	createColT := &createCollectionTask{
		Condition: NewTaskCondition(ctx),
		CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
			CollectionName: collectionName,
			Schema:         marshaledSchema,
			ShardsNum:      shardsNum,
		},
		ctx:       ctx,
		rootCoord: rc,
	}

	require.NoError(t, createColT.OnEnqueue())
	require.NoError(t, createColT.PreExecute(ctx))
	require.NoError(t, createColT.Execute(ctx))
	require.NoError(t, createColT.PostExecute(ctx))

	collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
	assert.NoError(t, err)

	status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
		Base: &commonpb.MsgBase{
			MsgType:  commonpb.MsgType_LoadCollection,
			SourceID: Params.ProxyCfg.ProxyID,
		},
		CollectionID: collectionID,
	})
	require.NoError(t, err)
	require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)

	task := &queryTask{
		Condition: NewTaskCondition(ctx),
		RetrieveRequest: &internalpb.RetrieveRequest{
			Base: &commonpb.MsgBase{
				MsgType:  commonpb.MsgType_Retrieve,
				SourceID: Params.ProxyCfg.ProxyID,
			},
			CollectionID:   collectionID,
			OutputFieldsId: make([]int64, len(fieldName2Types)),
		},
		ctx: ctx,
		result: &milvuspb.QueryResults{
			Status: &commonpb.Status{
				ErrorCode: commonpb.ErrorCode_Success,
			},
		},
		request: &milvuspb.QueryRequest{
			Base: &commonpb.MsgBase{
				MsgType:  commonpb.MsgType_Retrieve,
				SourceID: Params.ProxyCfg.ProxyID,
			},
			CollectionName: collectionName,
			Expr:           expr,
		},
		qc: qc,

		getQueryNodePolicy: mockGetQueryNodePolicy,
		queryShardPolicy:   roundRobinPolicy,
	}

	for i := 0; i < len(fieldName2Types); i++ {
		task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
	}
// simple mock for query node
// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
err = chMgr.createDQLStream(collectionID)
assert.NoError(t, err)
stream, err := chMgr.getDQLStream(collectionID)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
consumeCtx, cancel := context.WithCancel(ctx)
go func() {
defer wg.Done()
for {
select {
case <-consumeCtx.Done():
return
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.RetrieveMsg)
assert.True(t, ok)
// TODO(dragondriver): construct result according to the request
result1 := &internalpb.RetrieveResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RetrieveResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: strconv.Itoa(int(Params.ProxyCfg.ProxyID)),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: generateInt64Array(hitNum),
},
},
},
SealedSegmentIDsRetrieved: nil,
ChannelIDsRetrieved: nil,
GlobalSealedSegmentIDs: nil,
}
fieldID := common.StartOfUserFieldID
for fieldName, dataType := range fieldName2Types {
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), hitNum))
fieldID++
}
// send search result
task.resultBuf <- []*internalpb.RetrieveResults{result1}
}
}
}
}()
	assert.NoError(t, task.OnEnqueue())

	// test query task with timeout
@@ -236,11 +143,29 @@ func TestQueryTask_all(t *testing.T) {
	assert.NoError(t, task.PreExecute(ctx))
	// after preExecute
	assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
	result1 := &internalpb.RetrieveResults{
		Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
		Status: &commonpb.Status{
			ErrorCode: commonpb.ErrorCode_Success,
		},
		Ids: &schemapb.IDs{
			IdField: &schemapb.IDs_IntId{
				IntId: &schemapb.LongArray{Data: generateInt64Array(hitNum)},
			},
		},
	}

	fieldID := common.StartOfUserFieldID
	for fieldName, dataType := range fieldName2Types {
		result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), hitNum))
		fieldID++
	}

	qn.withQueryResult = result1

	task.ctx = ctx
	assert.NoError(t, task.Execute(ctx))

	assert.NoError(t, task.PostExecute(ctx))

	cancel()
	wg.Wait()
}
@@ -23,16 +23,10 @@ import (
	"fmt"
	"sync"

	"github.com/milvus-io/milvus/internal/util/typeutil"
	"github.com/milvus-io/milvus/internal/util/funcutil"

	"go.uber.org/zap"

	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/trace" "github.com/milvus-io/milvus/internal/util/trace"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
oplog "github.com/opentracing/opentracing-go/log" oplog "github.com/opentracing/opentracing-go/log"
...@@ -385,25 +379,10 @@ type taskScheduler struct { ...@@ -385,25 +379,10 @@ type taskScheduler struct {
cancel context.CancelFunc cancel context.CancelFunc
msFactory msgstream.Factory msFactory msgstream.Factory
	searchResultCh   chan *internalpb.SearchResults
	retrieveResultCh chan *internalpb.RetrieveResults
}

type schedOpt func(*taskScheduler)
func schedOptWithSearchResultCh(ch chan *internalpb.SearchResults) schedOpt {
return func(sched *taskScheduler) {
sched.searchResultCh = ch
}
}
func schedOptWithRetrieveResultCh(ch chan *internalpb.RetrieveResults) schedOpt {
return func(sched *taskScheduler) {
sched.retrieveResultCh = ch
}
}
func newTaskScheduler(ctx context.Context,
	idAllocatorIns idAllocatorInterface,
	tsoAllocatorIns tsoAllocator,
@@ -551,265 +530,6 @@ func (sched *taskScheduler) queryLoop() {
		}
	}
type resultBufHeader struct {
msgID UniqueID
usedVChans map[interface{}]struct{} // set of vChan
receivedVChansSet map[interface{}]struct{} // set of vChan
receivedSealedSegmentIDsSet map[interface{}]struct{} // set of UniqueID
receivedGlobalSegmentIDsSet map[interface{}]struct{} // set of UniqueID
haveError bool
}
type searchResultBuf struct {
resultBufHeader
resultBuf []*internalpb.SearchResults
}
type queryResultBuf struct {
resultBufHeader
resultBuf []*internalpb.RetrieveResults
}
func newSearchResultBuf(msgID UniqueID) *searchResultBuf {
return &searchResultBuf{
resultBufHeader: resultBufHeader{
usedVChans: make(map[interface{}]struct{}),
receivedVChansSet: make(map[interface{}]struct{}),
receivedSealedSegmentIDsSet: make(map[interface{}]struct{}),
receivedGlobalSegmentIDsSet: make(map[interface{}]struct{}),
haveError: false,
msgID: msgID,
},
resultBuf: make([]*internalpb.SearchResults, 0),
}
}
func newQueryResultBuf(msgID UniqueID) *queryResultBuf {
return &queryResultBuf{
resultBufHeader: resultBufHeader{
usedVChans: make(map[interface{}]struct{}),
receivedVChansSet: make(map[interface{}]struct{}),
receivedSealedSegmentIDsSet: make(map[interface{}]struct{}),
receivedGlobalSegmentIDsSet: make(map[interface{}]struct{}),
haveError: false,
msgID: msgID,
},
resultBuf: make([]*internalpb.RetrieveResults, 0),
}
}
func (sr *resultBufHeader) readyToReduce() bool {
if sr.haveError {
log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("haveError", true))
return true
}
log.Debug("check if result buf is ready to reduce",
zap.String("role", typeutil.ProxyRole),
zap.Int64("MsgID", sr.msgID),
zap.Any("receivedVChansSet", funcutil.SetToSlice(sr.receivedVChansSet)),
zap.Any("usedVChans", funcutil.SetToSlice(sr.usedVChans)),
zap.Any("receivedSealedSegmentIDsSet", funcutil.SetToSlice(sr.receivedSealedSegmentIDsSet)),
zap.Any("receivedGlobalSegmentIDsSet", funcutil.SetToSlice(sr.receivedGlobalSegmentIDsSet)))
ret1 := funcutil.SetContain(sr.receivedVChansSet, sr.usedVChans)
if !ret1 {
return false
}
return funcutil.SetContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet)
}
func (sr *resultBufHeader) addPartialResult(vchans []vChan, searchSegIDs, globalSegIDs []UniqueID) {
for _, vchan := range vchans {
sr.receivedVChansSet[vchan] = struct{}{}
}
for _, sealedSegment := range searchSegIDs {
sr.receivedSealedSegmentIDsSet[sealedSegment] = struct{}{}
}
for _, globalSegment := range globalSegIDs {
sr.receivedGlobalSegmentIDsSet[globalSegment] = struct{}{}
}
}
func (sr *searchResultBuf) addPartialResult(result *internalpb.SearchResults) {
sr.resultBuf = append(sr.resultBuf, result)
if result.Status.ErrorCode != commonpb.ErrorCode_Success {
sr.haveError = true
return
}
sr.resultBufHeader.addPartialResult(result.ChannelIDsSearched, result.SealedSegmentIDsSearched,
result.GlobalSealedSegmentIDs)
}
func (qr *queryResultBuf) addPartialResult(result *internalpb.RetrieveResults) {
qr.resultBuf = append(qr.resultBuf, result)
if result.Status.ErrorCode != commonpb.ErrorCode_Success {
qr.haveError = true
return
}
qr.resultBufHeader.addPartialResult(result.ChannelIDsRetrieved, result.SealedSegmentIDsRetrieved,
result.GlobalSealedSegmentIDs)
}
func (sched *taskScheduler) collectionResultLoopV2() {
defer sched.wg.Done()
searchResultBufs := make(map[UniqueID]*searchResultBuf)
searchResultBufFlags := newIDCache(Params.ProxyCfg.BufFlagExpireTime, Params.ProxyCfg.BufFlagCleanupInterval) // if value is true, we can ignore searchResult
queryResultBufs := make(map[UniqueID]*queryResultBuf)
queryResultBufFlags := newIDCache(Params.ProxyCfg.BufFlagExpireTime, Params.ProxyCfg.BufFlagCleanupInterval) // if value is true, we can ignore queryResult
processSearchResult := func(results *internalpb.SearchResults) error {
reqID := results.Base.MsgID
ignoreThisResult, ok := searchResultBufFlags.Get(reqID)
if !ok {
searchResultBufFlags.Set(reqID, false)
ignoreThisResult = false
}
if ignoreThisResult {
log.Debug("got a search result, but we should ignore", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
return nil
}
log.Debug("got a search result", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
t := sched.getTaskByReqID(reqID)
if t == nil {
log.Debug("got a search result, but not in task scheduler", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
delete(searchResultBufs, reqID)
searchResultBufFlags.Set(reqID, true)
}
st, ok := t.(*searchTask)
if !ok {
log.Debug("got a search result, but the related task is not of search task", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
delete(searchResultBufs, reqID)
searchResultBufFlags.Set(reqID, true)
return nil
}
resultBuf, ok := searchResultBufs[reqID]
if !ok {
log.Debug("first receive search result of this task", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
resultBuf = newSearchResultBuf(reqID)
vchans, err := st.getVChannels()
if err != nil {
delete(searchResultBufs, reqID)
log.Warn("failed to get virtual channels", zap.String("role", typeutil.ProxyRole), zap.Error(err), zap.Int64("reqID", reqID))
return err
}
for _, vchan := range vchans {
resultBuf.usedVChans[vchan] = struct{}{}
}
searchResultBufs[reqID] = resultBuf
}
resultBuf.addPartialResult(results)
colName := t.(*searchTask).query.CollectionName
log.Debug("process search result", zap.String("role", typeutil.ProxyRole), zap.String("collection", colName), zap.Int64("reqID", reqID), zap.Int("answer cnt", len(searchResultBufs[reqID].resultBuf)))
if resultBuf.readyToReduce() {
log.Debug("process search result, ready to reduce", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
searchResultBufFlags.Set(reqID, true)
st.resultBuf <- resultBuf.resultBuf
delete(searchResultBufs, reqID)
}
return nil
}
processRetrieveResult := func(results *internalpb.RetrieveResults) error {
reqID := results.Base.MsgID
ignoreThisResult, ok := queryResultBufFlags.Get(reqID)
if !ok {
queryResultBufFlags.Set(reqID, false)
ignoreThisResult = false
}
if ignoreThisResult {
log.Debug("got a retrieve result, but we should ignore", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
return nil
}
log.Debug("got a retrieve result", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
t := sched.getTaskByReqID(reqID)
if t == nil {
log.Debug("got a retrieve result, but not in task scheduler", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
delete(queryResultBufs, reqID)
queryResultBufFlags.Set(reqID, true)
}
st, ok := t.(*queryTask)
if !ok {
log.Debug("got a retrieve result, but the related task is not of retrieve task", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
delete(queryResultBufs, reqID)
queryResultBufFlags.Set(reqID, true)
return nil
}
resultBuf, ok := queryResultBufs[reqID]
if !ok {
log.Debug("first receive retrieve result of this task", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
resultBuf = newQueryResultBuf(reqID)
vchans, err := st.getVChannels()
if err != nil {
delete(queryResultBufs, reqID)
log.Warn("failed to get virtual channels", zap.String("role", typeutil.ProxyRole), zap.Error(err), zap.Int64("reqID", reqID))
return err
}
for _, vchan := range vchans {
resultBuf.usedVChans[vchan] = struct{}{}
}
queryResultBufs[reqID] = resultBuf
}
resultBuf.addPartialResult(results)
colName := t.(*queryTask).query.CollectionName
log.Debug("process retrieve result", zap.String("role", typeutil.ProxyRole), zap.String("collection", colName), zap.Int64("reqID", reqID), zap.Int("answer cnt", len(queryResultBufs[reqID].resultBuf)))
if resultBuf.readyToReduce() {
log.Debug("process retrieve result, ready to reduce", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
queryResultBufFlags.Set(reqID, true)
st.resultBuf <- resultBuf.resultBuf
delete(queryResultBufs, reqID)
}
return nil
}
for {
select {
case <-sched.ctx.Done():
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "context done"))
return
case sr, ok := <-sched.searchResultCh:
if !ok {
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "search result channel closed"))
return
}
if err := processSearchResult(sr); err != nil {
log.Warn("failed to process search result", zap.Error(err))
}
case rr, ok := <-sched.retrieveResultCh:
if !ok {
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "retrieve result channel closed"))
return
}
if err := processRetrieveResult(rr); err != nil {
log.Warn("failed to process retrieve result", zap.Error(err))
}
}
}
}
func (sched *taskScheduler) Start() error {
	sched.wg.Add(1)
	go sched.definitionLoop()
@@ -820,10 +540,6 @@ func (sched *taskScheduler) Start() error {
	sched.wg.Add(1)
	go sched.queryLoop()
sched.wg.Add(1)
// go sched.collectResultLoop()
go sched.collectionResultLoopV2()
	return nil
}
...
This diff is collapsed.
This diff is collapsed.
@@ -30,6 +30,7 @@ import (
	"github.com/milvus-io/milvus/internal/log"
	"github.com/milvus-io/milvus/internal/mq/msgstream"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/util/dependency"
	"github.com/milvus-io/milvus/internal/util/funcutil"
)
@@ -51,7 +52,7 @@ type channelUnsubscribeHandler struct {
}

// newChannelUnsubscribeHandler creates a new handler service to unsubscribe channels
func newChannelUnsubscribeHandler(ctx context.Context, kv *etcdkv.EtcdKV, factory msgstream.Factory) (*channelUnsubscribeHandler, error) {
func newChannelUnsubscribeHandler(ctx context.Context, kv *etcdkv.EtcdKV, factory dependency.Factory) (*channelUnsubscribeHandler, error) {
	childCtx, cancel := context.WithCancel(ctx)
	handler := &channelUnsubscribeHandler{
		ctx: childCtx,
...
This diff is collapsed.