Unverified commit 69252f81, authored by congqixia, committed by GitHub

Implement memory replica in Proxy, QueryNode and QueryCoord (#16470)

Related to #16298 #16291 #16154
Co-authored-by: sunby <bingyi.sun@zilliz.com>
Co-authored-by: yangxuan <xuan.yang@zilliz.com>
Co-authored-by: yah01 <yang.cen@zilliz.com>
Co-authored-by: Letian Jiang <letian.jiang@zilliz.com>
Signed-off-by: Congqi Xia <congqi.xia@zilliz.com>
Parent 9aa557f8
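
A minimal sketch of the API surface this commit adds, assuming the generated Go bindings for the proto changes below (ReplicaNumber on LoadPartitionsRequest, the new GetShardLeaders MsgType and RPC) and assuming types.ProxyComponent / types.QueryCoord expose these calls the same way the concrete Proxy and the QueryCoordMock in this diff do; the helper function and its wiring are illustrative only, not part of the commit.

package example

import (
	"context"
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/milvuspb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/types"
)

// loadWithReplicas loads a partition with the requested number of in-memory
// replicas, then asks QueryCoord which QueryNodes lead each DML channel.
func loadWithReplicas(ctx context.Context, proxy types.ProxyComponent, qc types.QueryCoord,
	collectionName string, collectionID int64, replicaNum int32) error {
	status, err := proxy.LoadPartitions(ctx, &milvuspb.LoadPartitionsRequest{
		CollectionName: collectionName,
		PartitionNames: []string{"_default"},
		ReplicaNumber:  replicaNum, // 1 keeps the previous single-replica behavior
	})
	if err != nil {
		return err
	}
	if status.GetErrorCode() != commonpb.ErrorCode_Success {
		return fmt.Errorf("load failed: %s", status.GetReason())
	}

	resp, err := qc.GetShardLeaders(ctx, &querypb.GetShardLeadersRequest{
		Base:         &commonpb.MsgBase{MsgType: commonpb.MsgType_GetShardLeaders},
		CollectionID: collectionID,
	})
	if err != nil {
		return err
	}
	for _, shard := range resp.GetShards() {
		// One entry per DML channel; NodeIds/NodeAddrs list the replica leaders.
		fmt.Println(shard.GetChannelName(), shard.GetNodeIds(), shard.GetNodeAddrs())
	}
	return nil
}
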
......@@ -290,7 +290,7 @@ const char descriptor_table_protodef_common_2eproto[] PROTOBUF_SECTION_VARIABLE(
"led\020\004*\202\001\n\014SegmentState\022\024\n\020SegmentStateNo"
"ne\020\000\022\014\n\010NotExist\020\001\022\013\n\007Growing\020\002\022\n\n\006Seale"
"d\020\003\022\013\n\007Flushed\020\004\022\014\n\010Flushing\020\005\022\013\n\007Droppe"
"d\020\006\022\r\n\tImporting\020\007*\261\n\n\007MsgType\022\r\n\tUndefi"
"d\020\006\022\r\n\tImporting\020\007*\307\n\n\007MsgType\022\r\n\tUndefi"
"ned\020\000\022\024\n\020CreateCollection\020d\022\022\n\016DropColle"
"ction\020e\022\021\n\rHasCollection\020f\022\026\n\022DescribeCo"
"llection\020g\022\023\n\017ShowCollections\020h\022\024\n\020GetSy"
......@@ -314,27 +314,28 @@ const char descriptor_table_protodef_common_2eproto[] PROTOBUF_SECTION_VARIABLE(
"atchDmChannels\020\374\003\022\025\n\020RemoveDmChannels\020\375\003"
"\022\027\n\022WatchQueryChannels\020\376\003\022\030\n\023RemoveQuery"
"Channels\020\377\003\022\035\n\030SealedSegmentsChangeInfo\020"
"\200\004\022\027\n\022WatchDeltaChannels\020\201\004\022\020\n\013SegmentIn"
"fo\020\330\004\022\017\n\nSystemInfo\020\331\004\022\024\n\017GetRecoveryInf"
"o\020\332\004\022\024\n\017GetSegmentState\020\333\004\022\r\n\010TimeTick\020\260"
"\t\022\023\n\016QueryNodeStats\020\261\t\022\016\n\tLoadIndex\020\262\t\022\016"
"\n\tRequestID\020\263\t\022\017\n\nRequestTSO\020\264\t\022\024\n\017Alloc"
"ateSegment\020\265\t\022\026\n\021SegmentStatistics\020\266\t\022\025\n"
"\020SegmentFlushDone\020\267\t\022\017\n\nDataNodeTt\020\270\t\022\025\n"
"\020CreateCredential\020\334\013\022\022\n\rGetCredential\020\335\013"
"\022\025\n\020DeleteCredential\020\336\013\022\025\n\020UpdateCredent"
"ial\020\337\013\022\026\n\021ListCredUsernames\020\340\013*\"\n\007DslTyp"
"e\022\007\n\003Dsl\020\000\022\016\n\nBoolExprV1\020\001*B\n\017Compaction"
"State\022\021\n\rUndefiedState\020\000\022\r\n\tExecuting\020\001\022"
"\r\n\tCompleted\020\002*X\n\020ConsistencyLevel\022\n\n\006St"
"rong\020\000\022\013\n\007Session\020\001\022\013\n\007Bounded\020\002\022\016\n\nEven"
"tually\020\003\022\016\n\nCustomized\020\004*\227\001\n\013ImportState"
"\022\021\n\rImportPending\020\000\022\020\n\014ImportFailed\020\001\022\021\n"
"\rImportStarted\020\002\022\024\n\020ImportDownloaded\020\003\022\020"
"\n\014ImportParsed\020\004\022\023\n\017ImportPersisted\020\005\022\023\n"
"\017ImportCompleted\020\006BW\n\016io.milvus.grpcB\013Co"
"mmonProtoP\001Z3github.com/milvus-io/milvus"
"/internal/proto/commonpb\240\001\001b\006proto3"
"\200\004\022\027\n\022WatchDeltaChannels\020\201\004\022\024\n\017GetShardL"
"eaders\020\202\004\022\020\n\013SegmentInfo\020\330\004\022\017\n\nSystemInf"
"o\020\331\004\022\024\n\017GetRecoveryInfo\020\332\004\022\024\n\017GetSegment"
"State\020\333\004\022\r\n\010TimeTick\020\260\t\022\023\n\016QueryNodeStat"
"s\020\261\t\022\016\n\tLoadIndex\020\262\t\022\016\n\tRequestID\020\263\t\022\017\n\n"
"RequestTSO\020\264\t\022\024\n\017AllocateSegment\020\265\t\022\026\n\021S"
"egmentStatistics\020\266\t\022\025\n\020SegmentFlushDone\020"
"\267\t\022\017\n\nDataNodeTt\020\270\t\022\025\n\020CreateCredential\020"
"\334\013\022\022\n\rGetCredential\020\335\013\022\025\n\020DeleteCredenti"
"al\020\336\013\022\025\n\020UpdateCredential\020\337\013\022\026\n\021ListCred"
"Usernames\020\340\013*\"\n\007DslType\022\007\n\003Dsl\020\000\022\016\n\nBool"
"ExprV1\020\001*B\n\017CompactionState\022\021\n\rUndefiedS"
"tate\020\000\022\r\n\tExecuting\020\001\022\r\n\tCompleted\020\002*X\n\020"
"ConsistencyLevel\022\n\n\006Strong\020\000\022\013\n\007Session\020"
"\001\022\013\n\007Bounded\020\002\022\016\n\nEventually\020\003\022\016\n\nCustom"
"ized\020\004*\227\001\n\013ImportState\022\021\n\rImportPending\020"
"\000\022\020\n\014ImportFailed\020\001\022\021\n\rImportStarted\020\002\022\024"
"\n\020ImportDownloaded\020\003\022\020\n\014ImportParsed\020\004\022\023"
"\n\017ImportPersisted\020\005\022\023\n\017ImportCompleted\020\006"
"BW\n\016io.milvus.grpcB\013CommonProtoP\001Z3githu"
"b.com/milvus-io/milvus/internal/proto/co"
"mmonpb\240\001\001b\006proto3"
;
static const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable*const descriptor_table_common_2eproto_deps[1] = {
};
......@@ -351,7 +352,7 @@ static ::PROTOBUF_NAMESPACE_ID::internal::SCCInfoBase*const descriptor_table_com
static ::PROTOBUF_NAMESPACE_ID::internal::once_flag descriptor_table_common_2eproto_once;
static bool descriptor_table_common_2eproto_initialized = false;
const ::PROTOBUF_NAMESPACE_ID::internal::DescriptorTable descriptor_table_common_2eproto = {
&descriptor_table_common_2eproto_initialized, descriptor_table_protodef_common_2eproto, "common.proto", 3275,
&descriptor_table_common_2eproto_initialized, descriptor_table_protodef_common_2eproto, "common.proto", 3297,
&descriptor_table_common_2eproto_once, descriptor_table_common_2eproto_sccs, descriptor_table_common_2eproto_deps, 8, 0,
schemas, file_default_instances, TableStruct_common_2eproto::offsets,
file_level_metadata_common_2eproto, 8, file_level_enum_descriptors_common_2eproto, file_level_service_descriptors_common_2eproto,
......@@ -497,6 +498,7 @@ bool MsgType_IsValid(int value) {
case 511:
case 512:
case 513:
case 514:
case 600:
case 601:
case 602:
......
......@@ -262,6 +262,7 @@ enum MsgType : int {
RemoveQueryChannels = 511,
SealedSegmentsChangeInfo = 512,
WatchDeltaChannels = 513,
GetShardLeaders = 514,
SegmentInfo = 600,
SystemInfo = 601,
GetRecoveryInfo = 602,
......
This diff has been collapsed.
......@@ -3888,6 +3888,7 @@ class LoadPartitionsRequest :
kDbNameFieldNumber = 2,
kCollectionNameFieldNumber = 3,
kBaseFieldNumber = 1,
kReplicaNumberFieldNumber = 5,
};
// repeated string partition_names = 4;
int partition_names_size() const;
......@@ -3936,6 +3937,11 @@ class LoadPartitionsRequest :
::milvus::proto::common::MsgBase* mutable_base();
void set_allocated_base(::milvus::proto::common::MsgBase* base);
// int32 replica_number = 5;
void clear_replica_number();
::PROTOBUF_NAMESPACE_ID::int32 replica_number() const;
void set_replica_number(::PROTOBUF_NAMESPACE_ID::int32 value);
// @@protoc_insertion_point(class_scope:milvus.proto.milvus.LoadPartitionsRequest)
private:
class _Internal;
......@@ -3945,6 +3951,7 @@ class LoadPartitionsRequest :
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr db_name_;
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr collection_name_;
::milvus::proto::common::MsgBase* base_;
::PROTOBUF_NAMESPACE_ID::int32 replica_number_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
friend struct ::TableStruct_milvus_2eproto;
};
......@@ -11873,6 +11880,7 @@ class LoadBalanceRequest :
enum : int {
kDstNodeIDsFieldNumber = 3,
kSealedSegmentIDsFieldNumber = 4,
kCollectionNameFieldNumber = 5,
kBaseFieldNumber = 1,
kSrcNodeIDFieldNumber = 2,
};
......@@ -11898,6 +11906,17 @@ class LoadBalanceRequest :
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 >*
mutable_sealed_segmentids();
// string collectionName = 5;
void clear_collectionname();
const std::string& collectionname() const;
void set_collectionname(const std::string& value);
void set_collectionname(std::string&& value);
void set_collectionname(const char* value);
void set_collectionname(const char* value, size_t size);
std::string* mutable_collectionname();
std::string* release_collectionname();
void set_allocated_collectionname(std::string* collectionname);
// .milvus.proto.common.MsgBase base = 1;
bool has_base() const;
void clear_base();
......@@ -11920,6 +11939,7 @@ class LoadBalanceRequest :
mutable std::atomic<int> _dst_nodeids_cached_byte_size_;
::PROTOBUF_NAMESPACE_ID::RepeatedField< ::PROTOBUF_NAMESPACE_ID::int64 > sealed_segmentids_;
mutable std::atomic<int> _sealed_segmentids_cached_byte_size_;
::PROTOBUF_NAMESPACE_ID::internal::ArenaStringPtr collectionname_;
::milvus::proto::common::MsgBase* base_;
::PROTOBUF_NAMESPACE_ID::int64 src_nodeid_;
mutable ::PROTOBUF_NAMESPACE_ID::internal::CachedSize _cached_size_;
......@@ -18987,6 +19007,20 @@ LoadPartitionsRequest::mutable_partition_names() {
return &partition_names_;
}
// int32 replica_number = 5;
inline void LoadPartitionsRequest::clear_replica_number() {
replica_number_ = 0;
}
inline ::PROTOBUF_NAMESPACE_ID::int32 LoadPartitionsRequest::replica_number() const {
// @@protoc_insertion_point(field_get:milvus.proto.milvus.LoadPartitionsRequest.replica_number)
return replica_number_;
}
inline void LoadPartitionsRequest::set_replica_number(::PROTOBUF_NAMESPACE_ID::int32 value) {
replica_number_ = value;
// @@protoc_insertion_point(field_set:milvus.proto.milvus.LoadPartitionsRequest.replica_number)
}
// -------------------------------------------------------------------
// ReleasePartitionsRequest
......@@ -26343,6 +26377,57 @@ LoadBalanceRequest::mutable_sealed_segmentids() {
return &sealed_segmentids_;
}
// string collectionName = 5;
inline void LoadBalanceRequest::clear_collectionname() {
collectionname_.ClearToEmptyNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline const std::string& LoadBalanceRequest::collectionname() const {
// @@protoc_insertion_point(field_get:milvus.proto.milvus.LoadBalanceRequest.collectionName)
return collectionname_.GetNoArena();
}
inline void LoadBalanceRequest::set_collectionname(const std::string& value) {
collectionname_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), value);
// @@protoc_insertion_point(field_set:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
inline void LoadBalanceRequest::set_collectionname(std::string&& value) {
collectionname_.SetNoArena(
&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::move(value));
// @@protoc_insertion_point(field_set_rvalue:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
inline void LoadBalanceRequest::set_collectionname(const char* value) {
GOOGLE_DCHECK(value != nullptr);
collectionname_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), ::std::string(value));
// @@protoc_insertion_point(field_set_char:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
inline void LoadBalanceRequest::set_collectionname(const char* value, size_t size) {
collectionname_.SetNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(),
::std::string(reinterpret_cast<const char*>(value), size));
// @@protoc_insertion_point(field_set_pointer:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
inline std::string* LoadBalanceRequest::mutable_collectionname() {
// @@protoc_insertion_point(field_mutable:milvus.proto.milvus.LoadBalanceRequest.collectionName)
return collectionname_.MutableNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline std::string* LoadBalanceRequest::release_collectionname() {
// @@protoc_insertion_point(field_release:milvus.proto.milvus.LoadBalanceRequest.collectionName)
return collectionname_.ReleaseNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited());
}
inline void LoadBalanceRequest::set_allocated_collectionname(std::string* collectionname) {
if (collectionname != nullptr) {
} else {
}
collectionname_.SetAllocatedNoArena(&::PROTOBUF_NAMESPACE_ID::internal::GetEmptyStringAlreadyInited(), collectionname);
// @@protoc_insertion_point(field_set_allocated:milvus.proto.milvus.LoadBalanceRequest.collectionName)
}
// -------------------------------------------------------------------
// ManualCompactionRequest
......@@ -528,6 +528,7 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
segment2Binlogs := make(map[UniqueID][]*datapb.FieldBinlog)
segment2StatsBinlogs := make(map[UniqueID][]*datapb.FieldBinlog)
segment2DeltaBinlogs := make(map[UniqueID][]*datapb.FieldBinlog)
segment2InsertChannel := make(map[UniqueID]string)
segmentsNumOfRows := make(map[UniqueID]int64)
flushedIDs := make(map[int64]struct{})
......@@ -542,6 +543,7 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
if segment.State != commonpb.SegmentState_Flushed && segment.State != commonpb.SegmentState_Flushing {
continue
}
segment2InsertChannel[segment.ID] = segment.InsertChannel
binlogs := segment.GetBinlogs()
if len(binlogs) == 0 {
......@@ -590,11 +592,12 @@ func (s *Server) GetRecoveryInfo(ctx context.Context, req *datapb.GetRecoveryInf
binlogs := make([]*datapb.SegmentBinlogs, 0, len(segment2Binlogs))
for segmentID := range flushedIDs {
sbl := &datapb.SegmentBinlogs{
SegmentID: segmentID,
NumOfRows: segmentsNumOfRows[segmentID],
FieldBinlogs: segment2Binlogs[segmentID],
Statslogs: segment2StatsBinlogs[segmentID],
Deltalogs: segment2DeltaBinlogs[segmentID],
SegmentID: segmentID,
NumOfRows: segmentsNumOfRows[segmentID],
FieldBinlogs: segment2Binlogs[segmentID],
Statslogs: segment2StatsBinlogs[segmentID],
Deltalogs: segment2DeltaBinlogs[segmentID],
InsertChannel: segment2InsertChannel[segmentID],
}
binlogs = append(binlogs, sbl)
}
......
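
The insert_channel value carried into datapb.SegmentBinlogs above lets a GetRecoveryInfo caller group flushed segments by their DML channel, which is what per-channel replica assignment needs. A hedged sketch under that assumption; the helper name is invented:

package example

import "github.com/milvus-io/milvus/internal/proto/datapb"

// groupSegmentsByChannel buckets recovered segment binlogs by the DML (insert)
// channel they belong to, using the InsertChannel field added in this commit.
func groupSegmentsByChannel(segments []*datapb.SegmentBinlogs) map[string][]int64 {
	byChannel := make(map[string][]int64)
	for _, seg := range segments {
		byChannel[seg.GetInsertChannel()] = append(byChannel[seg.GetInsertChannel()], seg.GetSegmentID())
	}
	return byChannel
}
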
......@@ -21,7 +21,6 @@ import (
"errors"
"testing"
"github.com/milvus-io/milvus/internal/proxy"
"github.com/milvus-io/milvus/internal/util/mock"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
......@@ -29,7 +28,7 @@ import (
)
func Test_NewClient(t *testing.T) {
proxy.Params.InitOnce()
ClientParams.InitOnce(typeutil.QueryNodeRole)
ctx := context.Background()
client, err := NewClient(ctx, "")
......
......@@ -149,6 +149,7 @@ enum MsgType {
RemoveQueryChannels = 511;
SealedSegmentsChangeInfo = 512;
WatchDeltaChannels = 513;
GetShardLeaders = 514;
/* DATA SERVICE */
SegmentInfo = 600;
......@@ -221,4 +222,4 @@ enum ImportState {
ImportParsed = 4;
ImportPersisted = 5;
ImportCompleted = 6;
}
\ No newline at end of file
}
......@@ -303,6 +303,7 @@ message SegmentBinlogs {
int64 num_of_rows = 3;
repeated FieldBinlog statslogs = 4;
repeated FieldBinlog deltalogs = 5;
string insert_channel = 6;
}
message FieldBinlog{
......@@ -367,6 +368,7 @@ message CompactionSegmentBinlogs {
repeated FieldBinlog fieldBinlogs = 2;
repeated FieldBinlog field2StatslogPaths = 3;
repeated FieldBinlog deltalogs = 4;
string insert_channel = 5;
}
message CompactionPlan {
......
......@@ -343,6 +343,8 @@ message LoadPartitionsRequest {
string collection_name = 3;
// The partition names you want to load
repeated string partition_names = 4;
// The number of replicas to load, 1 by default
int32 replica_number = 5;
}
/*
......@@ -755,6 +757,7 @@ message LoadBalanceRequest {
int64 src_nodeID = 2;
repeated int64 dst_nodeIDs = 3;
repeated int64 sealed_segmentIDs = 4;
string collectionName = 5;
}
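
With collectionName on the client-facing LoadBalanceRequest, the Proxy (see proxy/impl.go further down in this diff) can resolve the collection ID from its meta cache before forwarding the request to QueryCoord. A hedged example of such a request; the IDs are placeholders and the helper is illustrative only:

package example

import "github.com/milvus-io/milvus/internal/proto/milvuspb"

// newLoadBalanceRequest builds a client-side load-balance request. The Proxy
// resolves CollectionName to a collection ID and passes it to QueryCoord as
// querypb.LoadBalanceRequest.CollectionID.
func newLoadBalanceRequest() *milvuspb.LoadBalanceRequest {
	return &milvuspb.LoadBalanceRequest{
		CollectionName:   "my_collection",   // new field 5 (collectionName)
		DstNodeIDs:       []int64{2, 3},     // QueryNodes to move segments onto
		SealedSegmentIDs: []int64{100, 101}, // sealed segments to balance
	}
}
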
message ManualCompactionRequest {
......@@ -857,8 +860,6 @@ message ShardReplica {
repeated int64 node_ids = 4;
}
service ProxyService {
rpc RegisterLink(RegisterLinkRequest) returns (RegisterLinkResponse) {}
}
......@@ -911,3 +912,4 @@ message ListCredUsersRequest {
// Not useful for now
common.MsgBase base = 1;
}
......@@ -201,6 +201,7 @@ message WatchDmChannelsRequest {
schema.CollectionSchema schema = 6;
repeated data.SegmentInfo exclude_infos = 7;
LoadMetaInfo load_meta = 8;
int64 replicaID = 9;
}
message WatchDeltaChannelsRequest {
......@@ -224,6 +225,7 @@ message SegmentLoadInfo {
repeated int64 compactionFrom = 10; // segmentIDs compacted from
repeated FieldIndexInfo index_infos = 11;
int64 segment_size = 12;
string insert_channel = 13;
}
message FieldIndexInfo {
......@@ -245,6 +247,7 @@ message LoadSegmentsRequest {
int64 source_nodeID = 5;
int64 collectionID = 6;
LoadMetaInfo load_meta = 7;
int64 replicaID = 8;
}
message ReleaseSegmentsRequest {
......@@ -281,6 +284,7 @@ message LoadBalanceRequest {
TriggerCondition balance_reason = 3;
repeated int64 dst_nodeIDs = 4;
repeated int64 sealed_segmentIDs = 5;
int64 collectionID = 6;
}
//-------------------- internal meta proto------------------
......@@ -312,6 +316,7 @@ message DmChannelWatchInfo {
int64 collectionID = 1;
string dmChannel = 2;
int64 nodeID_loaded = 3;
int64 replicaID = 4;
}
message QueryChannelInfo {
......@@ -380,4 +385,3 @@ message SealedSegmentsChangeInfo {
common.MsgBase base = 1;
repeated SegmentChangeInfo infos = 2;
}
......@@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
......@@ -144,8 +145,8 @@ func (node *Proxy) ReleaseDQLMessageStream(ctx context.Context, request *proxypb
}, nil
}
// TODO(dragondriver): add more detailed ut for ConsistencyLevel, should we support multiple consistency level in Proxy?
// CreateCollection creates a collection by the schema.
// TODO(dragondriver): add more detailed ut for ConsistencyLevel, should we support multiple consistency level in Proxy?
func (node *Proxy) CreateCollection(ctx context.Context, request *milvuspb.CreateCollectionRequest) (*commonpb.Status, error) {
if !node.checkHealthy() {
return unhealthyStatus(), nil
......@@ -2399,11 +2400,10 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
},
resultBuf: make(chan []*internalpb.SearchResults, 1),
query: request,
chMgr: node.chMgr,
qc: node.queryCoord,
tr: timerecord.NewTimeRecorder("search"),
request: request,
qc: node.queryCoord,
tr: timerecord.NewTimeRecorder("search"),
getQueryNodePolicy: defaultGetQueryNodePolicy,
}
travelTs := request.TravelTimestamp
......@@ -2516,11 +2516,11 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
metrics.ProxySearchCount.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
metrics.SearchLabel, metrics.SuccessLabel).Inc()
metrics.ProxySearchVectors.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
metrics.SearchLabel).Set(float64(qt.result.Results.NumQueries))
metrics.SearchLabel).Set(float64(qt.result.GetResults().GetNumQueries()))
searchDur := tr.ElapseSpan().Milliseconds()
metrics.ProxySearchLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
metrics.SearchLabel).Observe(float64(searchDur))
metrics.ProxySearchLatencyPerNQ.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(searchDur) / float64(qt.result.Results.NumQueries))
metrics.ProxySearchLatencyPerNQ.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(searchDur) / float64(qt.result.GetResults().GetNumQueries()))
return qt.result, nil
}
......@@ -2641,10 +2641,10 @@ func (node *Proxy) Query(ctx context.Context, request *milvuspb.QueryRequest) (*
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
},
resultBuf: make(chan []*internalpb.RetrieveResults),
query: request,
chMgr: node.chMgr,
qc: node.queryCoord,
request: request,
qc: node.queryCoord,
getQueryNodePolicy: defaultGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
}
method := "Query"
......@@ -3058,11 +3058,12 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
},
ResultChannelID: strconv.FormatInt(Params.ProxyCfg.ProxyID, 10),
},
resultBuf: make(chan []*internalpb.RetrieveResults),
query: queryRequest,
chMgr: node.chMgr,
qc: node.queryCoord,
ids: ids.IdArray,
request: queryRequest,
qc: node.queryCoord,
ids: ids.IdArray,
getQueryNodePolicy: defaultGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
}
err := node.sched.dqQueue.Enqueue(qt)
......@@ -3715,6 +3716,7 @@ func (node *Proxy) RegisterLink(ctx context.Context, req *milvuspb.RegisterLinkR
}, nil
}
// GetMetrics gets the metrics of proxy
// TODO(dragondriver): cache the Metrics and set a retention to the cache
func (node *Proxy) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
log.Debug("Proxy.GetMetrics",
......@@ -3817,6 +3819,13 @@ func (node *Proxy) LoadBalance(ctx context.Context, req *milvuspb.LoadBalanceReq
status := &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
}
collectionID, err := globalMetaCache.GetCollectionID(ctx, req.GetCollectionName())
if err != nil {
log.Error("failed to get collection id", zap.String("collection name", req.GetCollectionName()), zap.Error(err))
status.Reason = err.Error()
return status, nil
}
infoResp, err := node.queryCoord.LoadBalance(ctx, &querypb.LoadBalanceRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadBalanceSegments,
......@@ -3828,6 +3837,7 @@ func (node *Proxy) LoadBalance(ctx context.Context, req *milvuspb.LoadBalanceReq
DstNodeIDs: req.DstNodeIDs,
BalanceReason: querypb.TriggerCondition_GrpcRequest,
SealedSegmentIDs: req.SealedSegmentIDs,
CollectionID: collectionID,
})
if err != nil {
log.Error("Failed to LoadBalance from Query Coordinator",
......@@ -3873,6 +3883,7 @@ func (node *Proxy) ManualCompaction(ctx context.Context, req *milvuspb.ManualCom
return resp, err
}
// GetCompactionStateWithPlans returns the compaction states with the given compaction ID
func (node *Proxy) GetCompactionStateWithPlans(ctx context.Context, req *milvuspb.GetCompactionPlansRequest) (*milvuspb.GetCompactionPlansResponse, error) {
log.Info("received GetCompactionStateWithPlans request", zap.Int64("compactionID", req.GetCompactionID()))
resp := &milvuspb.GetCompactionPlansResponse{}
......@@ -3979,7 +3990,7 @@ func (node *Proxy) GetReplicas(ctx context.Context, req *milvuspb.GetReplicasReq
return resp, err
}
// Check import task state from datanode
// GetImportState checks import task state from datanode
func (node *Proxy) GetImportState(ctx context.Context, req *milvuspb.GetImportStateRequest) (*milvuspb.GetImportStateResponse, error) {
log.Info("received get import state request", zap.Int64("taskID", req.GetTask()))
resp := &milvuspb.GetImportStateResponse{}
......@@ -4206,3 +4217,19 @@ func (node *Proxy) ListCredUsers(ctx context.Context, req *milvuspb.ListCredUser
Usernames: usernames,
}, nil
}
// TODO: SendSearchResult needs to be removed
func (node *Proxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Not implemented",
}, nil
}
// TODO: SendRetrieveResult needs to be removed
func (node *Proxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Not implemented",
}, nil
}
......@@ -31,6 +31,7 @@ import (
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/types"
......@@ -52,6 +53,7 @@ type Cache interface {
GetPartitionInfo(ctx context.Context, collectionName string, partitionName string) (*partitionInfo, error)
// GetCollectionSchema get collection's schema.
GetCollectionSchema(ctx context.Context, collectionName string) (*schemapb.CollectionSchema, error)
GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error)
RemoveCollection(ctx context.Context, collectionName string)
RemovePartition(ctx context.Context, collectionName string, partitionName string)
......@@ -67,6 +69,7 @@ type collectionInfo struct {
collID typeutil.UniqueID
schema *schemapb.CollectionSchema
partInfo map[string]*partitionInfo
shardLeaders []*querypb.ShardLeadersList
createdTimestamp uint64
createdUtcTimestamp uint64
}
......@@ -160,6 +163,7 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, collectionName string
collInfo = m.collInfo[collectionName]
metrics.ProxyUpdateCacheLatency.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
}
metrics.ProxyCacheHitCounter.WithLabelValues(strconv.FormatInt(Params.ProxyCfg.ProxyID, 10), "GetCollectionInfo", metrics.CacheHitLabel).Inc()
return &collectionInfo{
collID: collInfo.collID,
......@@ -167,6 +171,7 @@ func (m *MetaCache) GetCollectionInfo(ctx context.Context, collectionName string
partInfo: collInfo.partInfo,
createdTimestamp: collInfo.createdTimestamp,
createdUtcTimestamp: collInfo.createdUtcTimestamp,
shardLeaders: collInfo.shardLeaders,
}, nil
}
......@@ -520,3 +525,41 @@ func (m *MetaCache) GetCredUsernames(ctx context.Context) ([]string, error) {
return usernames, nil
}
// GetShards updates the cache if withCache == false, otherwise it serves cached shard leaders when present
func (m *MetaCache) GetShards(ctx context.Context, withCache bool, collectionName string, qc types.QueryCoord) ([]*querypb.ShardLeadersList, error) {
info, err := m.GetCollectionInfo(ctx, collectionName)
if err != nil {
return nil, err
}
if withCache {
if len(info.shardLeaders) > 0 {
return info.shardLeaders, nil
}
log.Info("no shard cache for collection, try to get shard leaders from QueryCoord",
zap.String("collectionName", collectionName))
}
m.mu.Lock()
defer m.mu.Unlock()
req := &querypb.GetShardLeadersRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_GetShardLeaders,
SourceID: Params.ProxyCfg.ProxyID,
},
CollectionID: info.collID,
}
resp, err := qc.GetShardLeaders(ctx, req)
if err != nil {
return nil, err
}
if resp.Status.ErrorCode != commonpb.ErrorCode_Success {
return nil, fmt.Errorf("fail to get shard leaders from QueryCoord: %s", resp.Status.Reason)
}
shards := resp.GetShards()
m.collInfo[collectionName].shardLeaders = shards
return shards, nil
}
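
A hedged sketch of how a caller in package proxy might consume GetShards: withCache == true serves the cached shard leaders when present and falls through to QueryCoord only on a miss, while withCache == false always refreshes the cache. The helper below is hypothetical, not part of the commit:

package proxy

import (
	"context"

	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/types"
)

// shardLeadersFor returns the shard leaders for a collection, optionally
// forcing a refresh of the Proxy-side cache (e.g. after a stale-leader error).
func shardLeadersFor(ctx context.Context, collectionName string, qc types.QueryCoord, refresh bool) ([]*querypb.ShardLeadersList, error) {
	return globalMetaCache.GetShards(ctx, !refresh, collectionName, qc)
}
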
......@@ -22,17 +22,17 @@ import (
"fmt"
"testing"
"github.com/milvus-io/milvus/internal/util/crypto"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/crypto"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type MockRootCoordClientInterface struct {
......@@ -310,3 +310,51 @@ func TestMetaCache_GetPartitionError(t *testing.T) {
log.Debug(err.Error())
assert.Equal(t, id, typeutil.UniqueID(0))
}
func TestMetaCache_GetShards(t *testing.T) {
client := &MockRootCoordClientInterface{}
err := InitMetaCache(client)
require.Nil(t, err)
var (
ctx = context.TODO()
collectionName = "collection1"
qc = NewQueryCoordMock()
)
qc.Init()
qc.Start()
defer qc.Stop()
t.Run("No collection in meta cache", func(t *testing.T) {
shards, err := globalMetaCache.GetShards(ctx, true, "non-exists", qc)
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info invalid shardLeaders", func(t *testing.T) {
qc.validShardLeaders = false
shards, err := globalMetaCache.GetShards(ctx, false, collectionName, qc)
assert.Error(t, err)
assert.Empty(t, shards)
})
t.Run("without shardLeaders in collection info", func(t *testing.T) {
qc.validShardLeaders = true
shards, err := globalMetaCache.GetShards(ctx, true, collectionName, qc)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
assert.Equal(t, 3, len(shards[0].GetNodeIds()))
// get from cache
qc.validShardLeaders = false
shards, err = globalMetaCache.GetShards(ctx, true, collectionName, qc)
assert.NoError(t, err)
assert.NotEmpty(t, shards)
assert.Equal(t, 1, len(shards))
assert.Equal(t, 3, len(shards[0].GetNodeAddrs()))
assert.Equal(t, 3, len(shards[0].GetNodeIds()))
})
}
......@@ -93,8 +93,7 @@ type Proxy struct {
factory dependency.Factory
searchResultCh chan *internalpb.SearchResults
retrieveResultCh chan *internalpb.RetrieveResults
searchResultCh chan *internalpb.SearchResults
// Add callback functions at different stages
startCallbacks []func()
......@@ -107,11 +106,10 @@ func NewProxy(ctx context.Context, factory dependency.Factory) (*Proxy, error) {
ctx1, cancel := context.WithCancel(ctx)
n := 1024 // better to be configurable
node := &Proxy{
ctx: ctx1,
cancel: cancel,
factory: factory,
searchResultCh: make(chan *internalpb.SearchResults, n),
retrieveResultCh: make(chan *internalpb.RetrieveResults, n),
ctx: ctx1,
cancel: cancel,
factory: factory,
searchResultCh: make(chan *internalpb.SearchResults, n),
}
node.UpdateStateCode(internalpb.StateCode_Abnormal)
logutil.Logger(ctx).Debug("create a new Proxy instance", zap.Any("state", node.stateCode.Load()))
......@@ -228,9 +226,7 @@ func (node *Proxy) Init() error {
log.Debug("create channels manager done", zap.String("role", typeutil.ProxyRole))
log.Debug("create task scheduler", zap.String("role", typeutil.ProxyRole))
node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.factory,
schedOptWithSearchResultCh(node.searchResultCh),
schedOptWithRetrieveResultCh(node.retrieveResultCh))
node.sched, err = newTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.factory)
if err != nil {
log.Warn("failed to create task scheduler", zap.Error(err), zap.String("role", typeutil.ProxyRole))
return err
......
......@@ -17,12 +17,7 @@
package proxy
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"math/rand"
"net"
"os"
"strconv"
......@@ -30,81 +25,56 @@ import (
"testing"
"time"
"github.com/golang/protobuf/proto"
ot "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/rootcoord"
"github.com/milvus-io/milvus/internal/util/crypto"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/prometheus/client_golang/prometheus"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/tsoutil"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/util/metricsinfo"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/internal/util/distance"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/proxypb"
"github.com/milvus-io/milvus/internal/proto/rootcoordpb"
"github.com/golang/protobuf/proto"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
grpcdatacoordclient "github.com/milvus-io/milvus/internal/distributed/datacoord"
grpcdatacoordclient2 "github.com/milvus-io/milvus/internal/distributed/datacoord/client"
grpcdatanode "github.com/milvus-io/milvus/internal/distributed/datanode"
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
grpcindexcoord "github.com/milvus-io/milvus/internal/distributed/indexcoord"
grpcindexcoordclient "github.com/milvus-io/milvus/internal/distributed/indexcoord/client"
grpcindexnode "github.com/milvus-io/milvus/internal/distributed/indexnode"
grpcquerycoord "github.com/milvus-io/milvus/internal/distributed/querycoord"
grpcquerycoordclient "github.com/milvus-io/milvus/internal/distributed/querycoord/client"
grpcquerynode "github.com/milvus-io/milvus/internal/distributed/querynode"
grpcrootcoord "github.com/milvus-io/milvus/internal/distributed/rootcoord"
rcc "github.com/milvus-io/milvus/internal/distributed/rootcoord/client"
"github.com/milvus-io/milvus/internal/datacoord"
"github.com/milvus-io/milvus/internal/datanode"
"github.com/milvus-io/milvus/internal/indexcoord"
"github.com/milvus-io/milvus/internal/indexnode"
"github.com/milvus-io/milvus/internal/querynode"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/querycoord"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/rootcoord"
"github.com/milvus-io/milvus/internal/util/logutil"
"github.com/milvus-io/milvus/internal/querynode"
)
const (
......@@ -620,12 +590,12 @@ func TestProxy(t *testing.T) {
rowNum := 3000
indexName := "_default"
nlist := 10
nprobe := 10
topk := 10
// nprobe := 10
// topk := 10
// add a test parameter
roundDecimal := 6
// roundDecimal := 6
nq := 10
expr := fmt.Sprintf("%s > 0", int64Field)
// expr := fmt.Sprintf("%s > 0", int64Field)
var segmentIDs []int64
// an int64 field (pk) & a float vector field
......@@ -721,76 +691,6 @@ func TestProxy(t *testing.T) {
}
}
constructPlaceholderGroup := func() *milvuspb.PlaceholderGroup {
values := make([][]byte, 0, nq)
for i := 0; i < nq; i++ {
bs := make([]byte, 0, dim*4)
for j := 0; j < dim; j++ {
var buffer bytes.Buffer
f := rand.Float32()
err := binary.Write(&buffer, common.Endian, f)
assert.NoError(t, err)
bs = append(bs, buffer.Bytes()...)
}
values = append(values, bs)
}
return &milvuspb.PlaceholderGroup{
Placeholders: []*milvuspb.PlaceholderValue{
{
Tag: "$0",
Type: milvuspb.PlaceholderType_FloatVector,
Values: values,
},
},
}
}
constructSearchRequest := func() *milvuspb.SearchRequest {
params := make(map[string]string)
params["nprobe"] = strconv.Itoa(nprobe)
b, err := json.Marshal(params)
assert.NoError(t, err)
plg := constructPlaceholderGroup()
plgBs, err := proto.Marshal(plg)
assert.NoError(t, err)
return &milvuspb.SearchRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
PartitionNames: nil,
Dsl: expr,
PlaceholderGroup: plgBs,
DslType: commonpb.DslType_BoolExprV1,
OutputFields: nil,
SearchParams: []*commonpb.KeyValuePair{
{
Key: MetricTypeKey,
Value: distance.L2,
},
{
Key: SearchParamsKey,
Value: string(b),
},
{
Key: AnnsFieldKey,
Value: floatVecField,
},
{
Key: TopKKey,
Value: strconv.Itoa(topk),
},
{
Key: RoundDecimalKey,
Value: strconv.Itoa(roundDecimal),
},
},
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
}
}
wg.Add(1)
t.Run("create collection", func(t *testing.T) {
defer wg.Done()
......@@ -1368,103 +1268,178 @@ func TestProxy(t *testing.T) {
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
if loaded {
wg.Add(1)
t.Run("search", func(t *testing.T) {
defer wg.Done()
req := constructSearchRequest()
resp, err := proxy.Search(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
})
wg.Add(1)
t.Run("search_travel", func(t *testing.T) {
defer wg.Done()
past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration-100) * time.Second)
travelTs := tsoutil.ComposeTSByTime(past, 0)
req := constructSearchRequest()
req.TravelTimestamp = travelTs
//resp, err := proxy.Search(ctx, req)
res, err := proxy.Search(ctx, req)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
})
wg.Add(1)
t.Run("search_travel_succ", func(t *testing.T) {
defer wg.Done()
past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration+100) * time.Second)
travelTs := tsoutil.ComposeTSByTime(past, 0)
req := constructSearchRequest()
req.TravelTimestamp = travelTs
//resp, err := proxy.Search(ctx, req)
res, err := proxy.Search(ctx, req)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
})
wg.Add(1)
t.Run("query", func(t *testing.T) {
defer wg.Done()
//resp, err := proxy.Query(ctx, &milvuspb.QueryRequest{
_, err := proxy.Query(ctx, &milvuspb.QueryRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Expr: expr,
OutputFields: nil,
PartitionNames: nil,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
})
assert.NoError(t, err)
// FIXME(dragondriver)
// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// TODO(dragondriver): compare query result
})
wg.Add(1)
t.Run("query_travel", func(t *testing.T) {
defer wg.Done()
past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration-100) * time.Second)
travelTs := tsoutil.ComposeTSByTime(past, 0)
queryReq := &milvuspb.QueryRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Expr: expr,
OutputFields: nil,
PartitionNames: nil,
TravelTimestamp: travelTs,
GuaranteeTimestamp: 0,
}
res, err := proxy.Query(ctx, queryReq)
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
})
wg.Add(1)
t.Run("query_travel_succ", func(t *testing.T) {
defer wg.Done()
past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration+100) * time.Second)
travelTs := tsoutil.ComposeTSByTime(past, 0)
queryReq := &milvuspb.QueryRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Expr: expr,
OutputFields: nil,
PartitionNames: nil,
TravelTimestamp: travelTs,
GuaranteeTimestamp: 0,
}
res, err := proxy.Query(ctx, queryReq)
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_EmptyCollection, res.Status.ErrorCode)
})
}
// nprobe := 10
// topk := 10
// roundDecimal := 6
// expr := fmt.Sprintf("%s > 0", int64Field)
// constructPlaceholderGroup := func() *milvuspb.PlaceholderGroup {
// values := make([][]byte, 0, nq)
// for i := 0; i < nq; i++ {
// bs := make([]byte, 0, dim*4)
// for j := 0; j < dim; j++ {
// var buffer bytes.Buffer
// f := rand.Float32()
// err := binary.Write(&buffer, common.Endian, f)
// assert.NoError(t, err)
// bs = append(bs, buffer.Bytes()...)
// }
// values = append(values, bs)
// }
//
// return &milvuspb.PlaceholderGroup{
// Placeholders: []*milvuspb.PlaceholderValue{
// {
// Tag: "$0",
// Type: milvuspb.PlaceholderType_FloatVector,
// Values: values,
// },
// },
// }
// }
//
// constructSearchRequest := func() *milvuspb.SearchRequest {
// params := make(map[string]string)
// params["nprobe"] = strconv.Itoa(nprobe)
// b, err := json.Marshal(params)
// assert.NoError(t, err)
// plg := constructPlaceholderGroup()
// plgBs, err := proto.Marshal(plg)
// assert.NoError(t, err)
//
// return &milvuspb.SearchRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: collectionName,
// PartitionNames: nil,
// Dsl: expr,
// PlaceholderGroup: plgBs,
// DslType: commonpb.DslType_BoolExprV1,
// OutputFields: nil,
// SearchParams: []*commonpb.KeyValuePair{
// {
// Key: MetricTypeKey,
// Value: distance.L2,
// },
// {
// Key: SearchParamsKey,
// Value: string(b),
// },
// {
// Key: AnnsFieldKey,
// Value: floatVecField,
// },
// {
// Key: TopKKey,
// Value: strconv.Itoa(topk),
// },
// {
// Key: RoundDecimalKey,
// Value: strconv.Itoa(roundDecimal),
// },
// },
// TravelTimestamp: 0,
// GuaranteeTimestamp: 0,
// }
// }
// TODO(Goose): reopen after joint-tests
// if loaded {
// wg.Add(1)
// t.Run("search", func(t *testing.T) {
// defer wg.Done()
// req := constructSearchRequest()
//
// resp, err := proxy.Search(ctx, req)
// assert.NoError(t, err)
// assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// })
//
// wg.Add(1)
// t.Run("search_travel", func(t *testing.T) {
// defer wg.Done()
// past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration-100) * time.Second)
// travelTs := tsoutil.ComposeTSByTime(past, 0)
// req := constructSearchRequest()
// req.TravelTimestamp = travelTs
// //resp, err := proxy.Search(ctx, req)
// res, err := proxy.Search(ctx, req)
// assert.NoError(t, err)
// assert.NotEqual(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
// })
//
// wg.Add(1)
// t.Run("search_travel_succ", func(t *testing.T) {
// defer wg.Done()
// past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration+100) * time.Second)
// travelTs := tsoutil.ComposeTSByTime(past, 0)
// req := constructSearchRequest()
// req.TravelTimestamp = travelTs
// //resp, err := proxy.Search(ctx, req)
// res, err := proxy.Search(ctx, req)
// assert.NoError(t, err)
// assert.Equal(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
// })
//
// wg.Add(1)
// t.Run("query", func(t *testing.T) {
// defer wg.Done()
// //resp, err := proxy.Query(ctx, &milvuspb.QueryRequest{
// _, err := proxy.Query(ctx, &milvuspb.QueryRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: collectionName,
// Expr: expr,
// OutputFields: nil,
// PartitionNames: nil,
// TravelTimestamp: 0,
// GuaranteeTimestamp: 0,
// })
// assert.NoError(t, err)
// // FIXME(dragondriver)
// // assert.Equal(t, commonpb.ErrorCode_Success, resp.Status.ErrorCode)
// // TODO(dragondriver): compare query result
// })
//
// wg.Add(1)
// t.Run("query_travel", func(t *testing.T) {
// defer wg.Done()
// past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration-100) * time.Second)
// travelTs := tsoutil.ComposeTSByTime(past, 0)
// queryReq := &milvuspb.QueryRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: collectionName,
// Expr: expr,
// OutputFields: nil,
// PartitionNames: nil,
// TravelTimestamp: travelTs,
// GuaranteeTimestamp: 0,
// }
// res, err := proxy.Query(ctx, queryReq)
// assert.NoError(t, err)
// assert.NotEqual(t, commonpb.ErrorCode_Success, res.Status.ErrorCode)
// })
//
// wg.Add(1)
// t.Run("query_travel_succ", func(t *testing.T) {
// defer wg.Done()
// past := time.Now().Add(time.Duration(-1*Params.CommonCfg.RetentionDuration+100) * time.Second)
// travelTs := tsoutil.ComposeTSByTime(past, 0)
// queryReq := &milvuspb.QueryRequest{
// Base: nil,
// DbName: dbName,
// CollectionName: collectionName,
// Expr: expr,
// OutputFields: nil,
// PartitionNames: nil,
// TravelTimestamp: travelTs,
// GuaranteeTimestamp: 0,
// }
// res, err := proxy.Query(ctx, queryReq)
// assert.NoError(t, err)
// assert.Equal(t, commonpb.ErrorCode_EmptyCollection, res.Status.ErrorCode)
// })
// }
wg.Add(1)
t.Run("calculate distance", func(t *testing.T) {
......@@ -1683,6 +1658,7 @@ func TestProxy(t *testing.T) {
DbName: dbName,
CollectionName: collectionName,
PartitionNames: []string{partitionName},
ReplicaNumber: 1,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, resp.ErrorCode)
......@@ -1693,6 +1669,7 @@ func TestProxy(t *testing.T) {
DbName: dbName,
CollectionName: collectionName,
PartitionNames: []string{otherPartitionName},
ReplicaNumber: 1,
})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
......@@ -1703,6 +1680,7 @@ func TestProxy(t *testing.T) {
DbName: dbName,
CollectionName: otherCollectionName,
PartitionNames: []string{partitionName},
ReplicaNumber: 1,
})
assert.NoError(t, err)
assert.NotEqual(t, commonpb.ErrorCode_Success, resp.ErrorCode)
......
......@@ -36,12 +36,20 @@ type QueryCoordMockOption func(mock *QueryCoordMock)
type queryCoordShowCollectionsFuncType func(ctx context.Context, request *querypb.ShowCollectionsRequest) (*querypb.ShowCollectionsResponse, error)
type queryCoordShowPartitionsFuncType func(ctx context.Context, request *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error)
func SetQueryCoordShowCollectionsFunc(f queryCoordShowCollectionsFuncType) QueryCoordMockOption {
return func(mock *QueryCoordMock) {
mock.showCollectionsFunc = f
}
}
func withValidShardLeaders() QueryCoordMockOption {
return func(mock *QueryCoordMock) {
mock.validShardLeaders = true
}
}
type QueryCoordMock struct {
nodeID typeutil.UniqueID
address string
......@@ -54,9 +62,12 @@ type QueryCoordMock struct {
showCollectionsFunc queryCoordShowCollectionsFuncType
getMetricsFunc getMetricsFuncType
showPartitionsFunc queryCoordShowPartitionsFuncType
statisticsChannel string
timeTickChannel string
validShardLeaders bool
}
func (coord *QueryCoordMock) updateState(state internalpb.StateCode) {
......@@ -223,6 +234,14 @@ func (coord *QueryCoordMock) ReleaseCollection(ctx context.Context, req *querypb
}, nil
}
func (coord *QueryCoordMock) SetShowPartitionsFunc(f queryCoordShowPartitionsFuncType) {
coord.showPartitionsFunc = f
}
func (coord *QueryCoordMock) ResetShowPartitionsFunc() {
coord.showPartitionsFunc = nil
}
func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.ShowPartitionsRequest) (*querypb.ShowPartitionsResponse, error) {
if !coord.healthy() {
return &querypb.ShowPartitionsResponse{
......@@ -233,6 +252,10 @@ func (coord *QueryCoordMock) ShowPartitions(ctx context.Context, req *querypb.Sh
}, nil
}
if coord.showPartitionsFunc != nil {
return coord.showPartitionsFunc(ctx, req)
}
panic("implement me")
}
......@@ -360,6 +383,21 @@ func (coord *QueryCoordMock) GetShardLeaders(ctx context.Context, req *querypb.G
}, nil
}
if coord.validShardLeaders {
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Shards: []*querypb.ShardLeadersList{
{
ChannelName: "channel-1",
NodeIds: []int64{1, 2, 3},
NodeAddrs: []string{"localhost:9000", "localhost:9001", "localhost:9002"},
},
},
}, nil
}
return &querypb.GetShardLeadersResponse{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"sync/atomic"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/proto/milvuspb"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
var _ types.QueryNode = &QueryNodeMock{}
type QueryNodeMock struct {
nodeID typeutil.UniqueID
address string
state atomic.Value // internal.StateCode
withSearchResult *internalpb.SearchResults
withQueryResult *internalpb.RetrieveResults
}
func (m *QueryNodeMock) Search(ctx context.Context, req *querypb.SearchRequest) (*internalpb.SearchResults, error) {
return m.withSearchResult, nil
}
func (m *QueryNodeMock) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
return m.withQueryResult, nil
}
// TODO
func (m *QueryNodeMock) AddQueryChannel(ctx context.Context, req *querypb.AddQueryChannelRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) RemoveQueryChannel(ctx context.Context, req *querypb.RemoveQueryChannelRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) WatchDmChannels(ctx context.Context, req *querypb.WatchDmChannelsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) WatchDeltaChannels(ctx context.Context, req *querypb.WatchDeltaChannelsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) LoadSegments(ctx context.Context, req *querypb.LoadSegmentsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) ReleaseCollection(ctx context.Context, req *querypb.ReleaseCollectionRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) ReleasePartitions(ctx context.Context, req *querypb.ReleasePartitionsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) ReleaseSegments(ctx context.Context, req *querypb.ReleaseSegmentsRequest) (*commonpb.Status, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) GetSegmentInfo(ctx context.Context, req *querypb.GetSegmentInfoRequest) (*querypb.GetSegmentInfoResponse, error) {
return nil, nil
}
// TODO
func (m *QueryNodeMock) GetMetrics(ctx context.Context, req *milvuspb.GetMetricsRequest) (*milvuspb.GetMetricsResponse, error) {
return nil, nil
}
func (m *QueryNodeMock) Init() error { return nil }
func (m *QueryNodeMock) Start() error { return nil }
func (m *QueryNodeMock) Stop() error { return nil }
func (m *QueryNodeMock) Register() error { return nil }
func (m *QueryNodeMock) GetComponentStates(ctx context.Context) (*internalpb.ComponentStates, error) {
return nil, nil
}
func (m *QueryNodeMock) GetStatisticsChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}
func (m *QueryNodeMock) GetTimeTickChannel(ctx context.Context) (*milvuspb.StringResponse, error) {
return nil, nil
}
package proxy
import (
"context"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
)
func (node *Proxy) SendSearchResult(ctx context.Context, req *internalpb.SearchResults) (*commonpb.Status, error) {
node.searchResultCh <- req
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
func (node *Proxy) SendRetrieveResult(ctx context.Context, req *internalpb.RetrieveResults) (*commonpb.Status, error) {
node.retrieveResultCh <- req
return &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
}, nil
}
......@@ -2675,9 +2675,10 @@ func (lct *loadCollectionTask) Execute(ctx context.Context) (err error) {
Timestamp: lct.Base.Timestamp,
SourceID: lct.Base.SourceID,
},
DbID: 0,
CollectionID: collID,
Schema: collSchema,
DbID: 0,
CollectionID: collID,
Schema: collSchema,
ReplicaNumber: lct.ReplicaNumber,
}
log.Debug("send LoadCollectionRequest to query coordinator", zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", request.Base.MsgID), zap.Int64("collectionID", request.CollectionID),
......@@ -2869,10 +2870,11 @@ func (lpt *loadPartitionsTask) Execute(ctx context.Context) error {
Timestamp: lpt.Base.Timestamp,
SourceID: lpt.Base.SourceID,
},
DbID: 0,
CollectionID: collID,
PartitionIDs: partitionIDs,
Schema: collSchema,
DbID: 0,
CollectionID: collID,
PartitionIDs: partitionIDs,
Schema: collSchema,
ReplicaNumber: lpt.ReplicaNumber,
}
lpt.result, err = lpt.queryCoord.LoadPartitions(ctx, request)
return err
......
package proxy
import (
"context"
"errors"
"fmt"
qnClient "github.com/milvus-io/milvus/internal/distributed/querynode/client"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
"go.uber.org/zap"
)
type getQueryNodePolicy func(context.Context, string) (types.QueryNode, error)
type pickShardPolicy func(ctx context.Context, policy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error
// TODO: add another policy to enable the use of cache
// defaultGetQueryNodePolicy creates a new QueryNode client for the given address on every call
func defaultGetQueryNodePolicy(ctx context.Context, address string) (types.QueryNode, error) {
qn, err := qnClient.NewClient(ctx, address)
if err != nil {
return nil, err
}
if err := qn.Init(); err != nil {
return nil, err
}
if err := qn.Start(); err != nil {
return nil, err
}
return qn, nil
}
var (
errBegin = errors.New("begin error")
errInvalidShardLeaders = errors.New("Invalid shard leader")
)
func roundRobinPolicy(ctx context.Context, getQueryNodePolicy getQueryNodePolicy, query func(UniqueID, types.QueryNode) error, leaders *querypb.ShardLeadersList) error {
var (
err = errBegin
current = 0
qn types.QueryNode
)
replicaNum := len(leaders.GetNodeIds())
for err != nil && current < replicaNum {
currentID := leaders.GetNodeIds()[current]
if err != errBegin {
log.Warn("retry with another QueryNode", zap.String("leader", leaders.GetChannelName()), zap.Int64("nodeID", currentID))
}
qn, err = getQueryNodePolicy(ctx, leaders.GetNodeAddrs()[current])
if err != nil {
log.Warn("fail to get valid QueryNode", zap.Int64("nodeID", currentID),
zap.Error(err))
current++
continue
}
defer qn.Stop()
err = query(currentID, qn)
if err != nil {
log.Warn("fail to Query with shard leader",
zap.String("leader", leaders.GetChannelName()),
zap.Int64("nodeID", currentID),
zap.Error(err))
}
current++
}
if current == replicaNum && err != nil {
return fmt.Errorf("no shard leaders available for channel: %s, leaders: %v, err: %s", leaders.GetChannelName(), leaders.GetNodeIds(), err.Error())
}
return nil
}
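
A hedged sketch (hypothetical helper, written as if it lived in package proxy) of how the query/search tasks in this commit are expected to drive the two policies above: the task supplies a per-node closure, and roundRobinPolicy retries it across the replica leaders of one DML channel until a node succeeds.

package proxy

import (
	"context"
	"fmt"

	"github.com/milvus-io/milvus/internal/proto/commonpb"
	"github.com/milvus-io/milvus/internal/proto/querypb"
	"github.com/milvus-io/milvus/internal/types"
)

// queryOneShard issues a Query against one shard, trying each replica leader
// in round-robin order until one of them answers successfully.
func queryOneShard(ctx context.Context, leaders *querypb.ShardLeadersList, req *querypb.QueryRequest) error {
	query := func(nodeID UniqueID, qn types.QueryNode) error {
		resp, err := qn.Query(ctx, req)
		if err != nil {
			return err
		}
		if resp.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success {
			return fmt.Errorf("QueryNode %d failed: %s", nodeID, resp.GetStatus().GetReason())
		}
		// The real tasks collect resp into their result buffers here.
		return nil
	}
	return roundRobinPolicy(ctx, defaultGetQueryNodePolicy, query, leaders)
}
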
This diff has been collapsed.
......@@ -3,16 +3,15 @@ package proxy
import (
"context"
"fmt"
"strconv"
"sync"
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/milvus-io/milvus/internal/common"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
......@@ -25,25 +24,35 @@ import (
)
func TestQueryTask_all(t *testing.T) {
var err error
Params.Init()
Params.ProxyCfg.RetrieveResultChannelNames = []string{funcutil.GenRandomStr()}
rc := NewRootCoordMock()
var (
err error
ctx = context.TODO()
rc = NewRootCoordMock()
qc = NewQueryCoordMock(withValidShardLeaders())
qn = &QueryNodeMock{}
shardsNum = int32(2)
collectionName = t.Name() + funcutil.GenRandomStr()
expr = fmt.Sprintf("%s > 0", testInt64Field)
hitNum = 10
)
mockGetQueryNodePolicy := func(ctx context.Context, address string) (types.QueryNode, error) {
return qn, nil
}
rc.Start()
defer rc.Stop()
ctx := context.Background()
qc.Start()
defer qc.Stop()
err = InitMetaCache(rc)
assert.NoError(t, err)
shardsNum := int32(2)
prefix := "TestQueryTask_all"
dbName := ""
collectionName := prefix + funcutil.GenRandomStr()
fieldName2Types := map[string]schemapb.DataType{
testBoolField: schemapb.DataType_Bool,
testInt32Field: schemapb.DataType_Int32,
......@@ -56,9 +65,6 @@ func TestQueryTask_all(t *testing.T) {
fieldName2Types[testBinaryVecField] = schemapb.DataType_BinaryVector
}
expr := fmt.Sprintf("%s > 0", testInt64Field)
hitNum := 10
schema := constructCollectionSchemaByDataType(collectionName, fieldName2Types, testInt64Field, false)
marshaledSchema, err := proto.Marshal(schema)
assert.NoError(t, err)
......@@ -66,165 +72,66 @@ func TestQueryTask_all(t *testing.T) {
createColT := &createCollectionTask{
Condition: NewTaskCondition(ctx),
CreateCollectionRequest: &milvuspb.CreateCollectionRequest{
Base: nil,
DbName: dbName,
CollectionName: collectionName,
Schema: marshaledSchema,
ShardsNum: shardsNum,
},
ctx: ctx,
rootCoord: rc,
result: nil,
schema: nil,
}
assert.NoError(t, createColT.OnEnqueue())
assert.NoError(t, createColT.PreExecute(ctx))
assert.NoError(t, createColT.Execute(ctx))
assert.NoError(t, createColT.PostExecute(ctx))
dmlChannelsFunc := getDmlChannelsFunc(ctx, rc)
query := newMockGetChannelsService()
factory := newSimpleMockMsgStreamFactory()
chMgr := newChannelsMgrImpl(dmlChannelsFunc, nil, query.GetChannels, nil, factory)
defer chMgr.removeAllDMLStream()
defer chMgr.removeAllDQLStream()
require.NoError(t, createColT.OnEnqueue())
require.NoError(t, createColT.PreExecute(ctx))
require.NoError(t, createColT.Execute(ctx))
require.NoError(t, createColT.PostExecute(ctx))
collectionID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
assert.NoError(t, err)
qc := NewQueryCoordMock()
qc.Start()
defer qc.Stop()
status, err := qc.LoadCollection(ctx, &querypb.LoadCollectionRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_LoadCollection,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
MsgType: commonpb.MsgType_LoadCollection,
SourceID: Params.ProxyCfg.ProxyID,
},
DbID: 0,
CollectionID: collectionID,
Schema: nil,
})
assert.NoError(t, err)
assert.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
require.NoError(t, err)
require.Equal(t, commonpb.ErrorCode_Success, status.ErrorCode)
// test begins
task := &queryTask{
Condition: NewTaskCondition(ctx),
RetrieveRequest: &internalpb.RetrieveRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
MsgType: commonpb.MsgType_Retrieve,
SourceID: Params.ProxyCfg.ProxyID,
},
ResultChannelID: strconv.Itoa(int(Params.ProxyCfg.ProxyID)),
DbID: 0,
CollectionID: collectionID,
PartitionIDs: nil,
SerializedExprPlan: nil,
OutputFieldsId: make([]int64, len(fieldName2Types)),
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
CollectionID: collectionID,
OutputFieldsId: make([]int64, len(fieldName2Types)),
},
ctx: ctx,
resultBuf: make(chan []*internalpb.RetrieveResults),
ctx: ctx,
result: &milvuspb.QueryResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
FieldsData: nil,
},
query: &milvuspb.QueryRequest{
request: &milvuspb.QueryRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Retrieve,
MsgID: 0,
Timestamp: 0,
SourceID: Params.ProxyCfg.ProxyID,
MsgType: commonpb.MsgType_Retrieve,
SourceID: Params.ProxyCfg.ProxyID,
},
DbName: dbName,
CollectionName: collectionName,
Expr: expr,
OutputFields: nil,
PartitionNames: nil,
TravelTimestamp: 0,
GuaranteeTimestamp: 0,
CollectionName: collectionName,
Expr: expr,
},
chMgr: chMgr,
qc: qc,
ids: nil,
qc: qc,
getQueryNodePolicy: mockGetQueryNodePolicy,
queryShardPolicy: roundRobinPolicy,
}
for i := 0; i < len(fieldName2Types); i++ {
task.RetrieveRequest.OutputFieldsId[i] = int64(common.StartOfUserFieldID + i)
}
// simple mock for query node
// TODO(dragondriver): should we replace this mock using RocksMq or MemMsgStream?
err = chMgr.createDQLStream(collectionID)
assert.NoError(t, err)
stream, err := chMgr.getDQLStream(collectionID)
assert.NoError(t, err)
var wg sync.WaitGroup
wg.Add(1)
consumeCtx, cancel := context.WithCancel(ctx)
go func() {
defer wg.Done()
for {
select {
case <-consumeCtx.Done():
return
case pack, ok := <-stream.Chan():
assert.True(t, ok)
if pack == nil {
continue
}
for _, msg := range pack.Msgs {
_, ok := msg.(*msgstream.RetrieveMsg)
assert.True(t, ok)
// TODO(dragondriver): construct result according to the request
result1 := &internalpb.RetrieveResults{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_RetrieveResult,
MsgID: 0,
Timestamp: 0,
SourceID: 0,
},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
Reason: "",
},
ResultChannelID: strconv.Itoa(int(Params.ProxyCfg.ProxyID)),
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{
Data: generateInt64Array(hitNum),
},
},
},
SealedSegmentIDsRetrieved: nil,
ChannelIDsRetrieved: nil,
GlobalSealedSegmentIDs: nil,
}
fieldID := common.StartOfUserFieldID
for fieldName, dataType := range fieldName2Types {
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), hitNum))
fieldID++
}
// send search result
task.resultBuf <- []*internalpb.RetrieveResults{result1}
}
}
}
}()
assert.NoError(t, task.OnEnqueue())
// test query task with timeout
......@@ -236,11 +143,29 @@ func TestQueryTask_all(t *testing.T) {
assert.NoError(t, task.PreExecute(ctx))
// after preExecute
assert.Greater(t, task.TimeoutTimestamp, typeutil.ZeroTimestamp)
task.ctx = ctx
result1 := &internalpb.RetrieveResults{
Base: &commonpb.MsgBase{MsgType: commonpb.MsgType_RetrieveResult},
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success,
},
Ids: &schemapb.IDs{
IdField: &schemapb.IDs_IntId{
IntId: &schemapb.LongArray{Data: generateInt64Array(hitNum)},
},
},
}
fieldID := common.StartOfUserFieldID
for fieldName, dataType := range fieldName2Types {
result1.FieldsData = append(result1.FieldsData, generateFieldData(dataType, fieldName, int64(fieldID), hitNum))
fieldID++
}
qn.withQueryResult = result1
task.ctx = ctx
assert.NoError(t, task.Execute(ctx))
assert.NoError(t, task.PostExecute(ctx))
cancel()
wg.Wait()
assert.NoError(t, task.PostExecute(ctx))
}
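The rewritten test no longer pumps retrieve results through a DQL msgstream; it injects a `getQueryNodePolicy` that returns a `QueryNodeMock` preloaded with a canned `RetrieveResults` (`qn.withQueryResult = result1`). The sketch below illustrates that injection pattern with simplified stand-in types; `queryNode`, `mockQueryNode`, and `task` here are illustrative, not the actual Milvus definitions.

```go
package main

import (
	"context"
	"fmt"
)

// queryNode is a minimal stand-in for types.QueryNode in this sketch.
type queryNode interface {
	Query(ctx context.Context) (string, error)
}

// mockQueryNode returns a canned result, much like QueryNodeMock with withQueryResult above.
type mockQueryNode struct{ result string }

func (m *mockQueryNode) Query(ctx context.Context) (string, error) { return m.result, nil }

// task resolves its QueryNode through an injected policy, so tests can bypass gRPC entirely.
type task struct {
	getQueryNodePolicy func(ctx context.Context, addr string) (queryNode, error)
}

func (t *task) execute(ctx context.Context, addr string) (string, error) {
	qn, err := t.getQueryNodePolicy(ctx, addr)
	if err != nil {
		return "", err
	}
	return qn.Query(ctx)
}

func main() {
	tk := &task{
		getQueryNodePolicy: func(ctx context.Context, addr string) (queryNode, error) {
			return &mockQueryNode{result: "10 hits"}, nil
		},
	}
	out, err := tk.execute(context.Background(), "ignored-address")
	fmt.Println(out, err)
}
```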
......@@ -23,16 +23,10 @@ import (
"fmt"
"sync"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/milvus-io/milvus/internal/util/funcutil"
"go.uber.org/zap"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/opentracing/opentracing-go"
oplog "github.com/opentracing/opentracing-go/log"
......@@ -385,25 +379,10 @@ type taskScheduler struct {
cancel context.CancelFunc
msFactory msgstream.Factory
searchResultCh chan *internalpb.SearchResults
retrieveResultCh chan *internalpb.RetrieveResults
}
type schedOpt func(*taskScheduler)
func schedOptWithSearchResultCh(ch chan *internalpb.SearchResults) schedOpt {
return func(sched *taskScheduler) {
sched.searchResultCh = ch
}
}
func schedOptWithRetrieveResultCh(ch chan *internalpb.RetrieveResults) schedOpt {
return func(sched *taskScheduler) {
sched.retrieveResultCh = ch
}
}
func newTaskScheduler(ctx context.Context,
idAllocatorIns idAllocatorInterface,
tsoAllocatorIns tsoAllocator,
......@@ -551,265 +530,6 @@ func (sched *taskScheduler) queryLoop() {
}
}
type resultBufHeader struct {
msgID UniqueID
usedVChans map[interface{}]struct{} // set of vChan
receivedVChansSet map[interface{}]struct{} // set of vChan
receivedSealedSegmentIDsSet map[interface{}]struct{} // set of UniqueID
receivedGlobalSegmentIDsSet map[interface{}]struct{} // set of UniqueID
haveError bool
}
type searchResultBuf struct {
resultBufHeader
resultBuf []*internalpb.SearchResults
}
type queryResultBuf struct {
resultBufHeader
resultBuf []*internalpb.RetrieveResults
}
func newSearchResultBuf(msgID UniqueID) *searchResultBuf {
return &searchResultBuf{
resultBufHeader: resultBufHeader{
usedVChans: make(map[interface{}]struct{}),
receivedVChansSet: make(map[interface{}]struct{}),
receivedSealedSegmentIDsSet: make(map[interface{}]struct{}),
receivedGlobalSegmentIDsSet: make(map[interface{}]struct{}),
haveError: false,
msgID: msgID,
},
resultBuf: make([]*internalpb.SearchResults, 0),
}
}
func newQueryResultBuf(msgID UniqueID) *queryResultBuf {
return &queryResultBuf{
resultBufHeader: resultBufHeader{
usedVChans: make(map[interface{}]struct{}),
receivedVChansSet: make(map[interface{}]struct{}),
receivedSealedSegmentIDsSet: make(map[interface{}]struct{}),
receivedGlobalSegmentIDsSet: make(map[interface{}]struct{}),
haveError: false,
msgID: msgID,
},
resultBuf: make([]*internalpb.RetrieveResults, 0),
}
}
func (sr *resultBufHeader) readyToReduce() bool {
if sr.haveError {
log.Debug("Proxy searchResultBuf readyToReduce", zap.Any("haveError", true))
return true
}
log.Debug("check if result buf is ready to reduce",
zap.String("role", typeutil.ProxyRole),
zap.Int64("MsgID", sr.msgID),
zap.Any("receivedVChansSet", funcutil.SetToSlice(sr.receivedVChansSet)),
zap.Any("usedVChans", funcutil.SetToSlice(sr.usedVChans)),
zap.Any("receivedSealedSegmentIDsSet", funcutil.SetToSlice(sr.receivedSealedSegmentIDsSet)),
zap.Any("receivedGlobalSegmentIDsSet", funcutil.SetToSlice(sr.receivedGlobalSegmentIDsSet)))
ret1 := funcutil.SetContain(sr.receivedVChansSet, sr.usedVChans)
if !ret1 {
return false
}
return funcutil.SetContain(sr.receivedSealedSegmentIDsSet, sr.receivedGlobalSegmentIDsSet)
}
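The result buffer in this hunk gates reduction on two set-containment checks: every virtual channel the request used must have reported back, and every globally sealed segment must appear among the sealed segments actually searched or retrieved. Below is a minimal sketch of that readiness test; `setContain` is a stand-in written to match how `funcutil.SetContain` is used here, and the real helper's exact signature is an assumption.

```go
package main

import "fmt"

// setContain reports whether super contains every key of sub.
// It is a stand-in for funcutil.SetContain as used in readyToReduce above.
func setContain(super, sub map[interface{}]struct{}) bool {
	if len(super) < len(sub) {
		return false
	}
	for k := range sub {
		if _, ok := super[k]; !ok {
			return false
		}
	}
	return true
}

// readyToReduce mirrors the two containment checks applied by the result buffer.
func readyToReduce(receivedVChans, usedVChans, receivedSealed, receivedGlobal map[interface{}]struct{}) bool {
	return setContain(receivedVChans, usedVChans) && setContain(receivedSealed, receivedGlobal)
}

func main() {
	used := map[interface{}]struct{}{"vchan-0": {}, "vchan-1": {}}
	received := map[interface{}]struct{}{"vchan-0": {}}
	sealed := map[interface{}]struct{}{int64(100): {}}
	global := map[interface{}]struct{}{int64(100): {}}

	// vchan-1 has not reported back yet, so the buffer is not ready to reduce.
	fmt.Println(readyToReduce(received, used, sealed, global)) // false

	received["vchan-1"] = struct{}{}
	fmt.Println(readyToReduce(received, used, sealed, global)) // true
}
```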
func (sr *resultBufHeader) addPartialResult(vchans []vChan, searchSegIDs, globalSegIDs []UniqueID) {
for _, vchan := range vchans {
sr.receivedVChansSet[vchan] = struct{}{}
}
for _, sealedSegment := range searchSegIDs {
sr.receivedSealedSegmentIDsSet[sealedSegment] = struct{}{}
}
for _, globalSegment := range globalSegIDs {
sr.receivedGlobalSegmentIDsSet[globalSegment] = struct{}{}
}
}
func (sr *searchResultBuf) addPartialResult(result *internalpb.SearchResults) {
sr.resultBuf = append(sr.resultBuf, result)
if result.Status.ErrorCode != commonpb.ErrorCode_Success {
sr.haveError = true
return
}
sr.resultBufHeader.addPartialResult(result.ChannelIDsSearched, result.SealedSegmentIDsSearched,
result.GlobalSealedSegmentIDs)
}
func (qr *queryResultBuf) addPartialResult(result *internalpb.RetrieveResults) {
qr.resultBuf = append(qr.resultBuf, result)
if result.Status.ErrorCode != commonpb.ErrorCode_Success {
qr.haveError = true
return
}
qr.resultBufHeader.addPartialResult(result.ChannelIDsRetrieved, result.SealedSegmentIDsRetrieved,
result.GlobalSealedSegmentIDs)
}
func (sched *taskScheduler) collectionResultLoopV2() {
defer sched.wg.Done()
searchResultBufs := make(map[UniqueID]*searchResultBuf)
searchResultBufFlags := newIDCache(Params.ProxyCfg.BufFlagExpireTime, Params.ProxyCfg.BufFlagCleanupInterval) // if value is true, we can ignore searchResult
queryResultBufs := make(map[UniqueID]*queryResultBuf)
queryResultBufFlags := newIDCache(Params.ProxyCfg.BufFlagExpireTime, Params.ProxyCfg.BufFlagCleanupInterval) // if value is true, we can ignore queryResult
processSearchResult := func(results *internalpb.SearchResults) error {
reqID := results.Base.MsgID
ignoreThisResult, ok := searchResultBufFlags.Get(reqID)
if !ok {
searchResultBufFlags.Set(reqID, false)
ignoreThisResult = false
}
if ignoreThisResult {
log.Debug("got a search result, but we should ignore", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
return nil
}
log.Debug("got a search result", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
t := sched.getTaskByReqID(reqID)
if t == nil {
log.Debug("got a search result, but not in task scheduler", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
delete(searchResultBufs, reqID)
searchResultBufFlags.Set(reqID, true)
}
st, ok := t.(*searchTask)
if !ok {
log.Debug("got a search result, but the related task is not of search task", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
delete(searchResultBufs, reqID)
searchResultBufFlags.Set(reqID, true)
return nil
}
resultBuf, ok := searchResultBufs[reqID]
if !ok {
log.Debug("first receive search result of this task", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
resultBuf = newSearchResultBuf(reqID)
vchans, err := st.getVChannels()
if err != nil {
delete(searchResultBufs, reqID)
log.Warn("failed to get virtual channels", zap.String("role", typeutil.ProxyRole), zap.Error(err), zap.Int64("reqID", reqID))
return err
}
for _, vchan := range vchans {
resultBuf.usedVChans[vchan] = struct{}{}
}
searchResultBufs[reqID] = resultBuf
}
resultBuf.addPartialResult(results)
colName := t.(*searchTask).query.CollectionName
log.Debug("process search result", zap.String("role", typeutil.ProxyRole), zap.String("collection", colName), zap.Int64("reqID", reqID), zap.Int("answer cnt", len(searchResultBufs[reqID].resultBuf)))
if resultBuf.readyToReduce() {
log.Debug("process search result, ready to reduce", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
searchResultBufFlags.Set(reqID, true)
st.resultBuf <- resultBuf.resultBuf
delete(searchResultBufs, reqID)
}
return nil
}
processRetrieveResult := func(results *internalpb.RetrieveResults) error {
reqID := results.Base.MsgID
ignoreThisResult, ok := queryResultBufFlags.Get(reqID)
if !ok {
queryResultBufFlags.Set(reqID, false)
ignoreThisResult = false
}
if ignoreThisResult {
log.Debug("got a retrieve result, but we should ignore", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
return nil
}
log.Debug("got a retrieve result", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
t := sched.getTaskByReqID(reqID)
if t == nil {
log.Debug("got a retrieve result, but not in task scheduler", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
delete(queryResultBufs, reqID)
queryResultBufFlags.Set(reqID, true)
}
st, ok := t.(*queryTask)
if !ok {
log.Debug("got a retrieve result, but the related task is not of retrieve task", zap.String("role", typeutil.ProxyRole), zap.Int64("ReqID", reqID))
delete(queryResultBufs, reqID)
queryResultBufFlags.Set(reqID, true)
return nil
}
resultBuf, ok := queryResultBufs[reqID]
if !ok {
log.Debug("first receive retrieve result of this task", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
resultBuf = newQueryResultBuf(reqID)
vchans, err := st.getVChannels()
if err != nil {
delete(queryResultBufs, reqID)
log.Warn("failed to get virtual channels", zap.String("role", typeutil.ProxyRole), zap.Error(err), zap.Int64("reqID", reqID))
return err
}
for _, vchan := range vchans {
resultBuf.usedVChans[vchan] = struct{}{}
}
queryResultBufs[reqID] = resultBuf
}
resultBuf.addPartialResult(results)
colName := t.(*queryTask).query.CollectionName
log.Debug("process retrieve result", zap.String("role", typeutil.ProxyRole), zap.String("collection", colName), zap.Int64("reqID", reqID), zap.Int("answer cnt", len(queryResultBufs[reqID].resultBuf)))
if resultBuf.readyToReduce() {
log.Debug("process retrieve result, ready to reduce", zap.String("role", typeutil.ProxyRole), zap.Int64("reqID", reqID))
queryResultBufFlags.Set(reqID, true)
st.resultBuf <- resultBuf.resultBuf
delete(queryResultBufs, reqID)
}
return nil
}
for {
select {
case <-sched.ctx.Done():
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "context done"))
return
case sr, ok := <-sched.searchResultCh:
if !ok {
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "search result channel closed"))
return
}
if err := processSearchResult(sr); err != nil {
log.Warn("failed to process search result", zap.Error(err))
}
case rr, ok := <-sched.retrieveResultCh:
if !ok {
log.Info("task scheduler's result loop of Proxy exit", zap.String("reason", "retrieve result channel closed"))
return
}
if err := processRetrieveResult(rr); err != nil {
log.Warn("failed to process retrieve result", zap.Error(err))
}
}
}
}
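collectionResultLoopV2 is a single goroutine that multiplexes the search and retrieve result channels with a select and exits when the scheduler context is cancelled or either channel is closed. The sketch below shows that shutdown-aware fan-in loop in isolation, with a simplified `result` payload standing in for the internalpb messages.

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// result stands in for internalpb.SearchResults / internalpb.RetrieveResults payloads.
type result struct {
	reqID int64
}

// resultLoop drains two result channels until the context is done or a channel closes.
func resultLoop(ctx context.Context, wg *sync.WaitGroup, searchCh, retrieveCh <-chan result) {
	defer wg.Done()
	for {
		select {
		case <-ctx.Done():
			fmt.Println("result loop exit: context done")
			return
		case sr, ok := <-searchCh:
			if !ok {
				fmt.Println("result loop exit: search channel closed")
				return
			}
			fmt.Println("got search result for request", sr.reqID)
		case rr, ok := <-retrieveCh:
			if !ok {
				fmt.Println("result loop exit: retrieve channel closed")
				return
			}
			fmt.Println("got retrieve result for request", rr.reqID)
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	searchCh := make(chan result, 1)
	retrieveCh := make(chan result, 1)

	var wg sync.WaitGroup
	wg.Add(1)
	go resultLoop(ctx, &wg, searchCh, retrieveCh)

	searchCh <- result{reqID: 1}
	retrieveCh <- result{reqID: 1}
	cancel() // stop the loop the same way sched.ctx.Done() does above
	wg.Wait()
}
```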
func (sched *taskScheduler) Start() error {
sched.wg.Add(1)
go sched.definitionLoop()
......@@ -820,10 +540,6 @@ func (sched *taskScheduler) Start() error {
sched.wg.Add(1)
go sched.queryLoop()
sched.wg.Add(1)
// go sched.collectResultLoop()
go sched.collectionResultLoopV2()
return nil
}
......
......@@ -30,6 +30,7 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/util/dependency"
"github.com/milvus-io/milvus/internal/util/funcutil"
)
......@@ -51,7 +52,7 @@ type channelUnsubscribeHandler struct {
}
// newChannelUnsubscribeHandler creates a new handler service to unsubscribe channels
func newChannelUnsubscribeHandler(ctx context.Context, kv *etcdkv.EtcdKV, factory msgstream.Factory) (*channelUnsubscribeHandler, error) {
func newChannelUnsubscribeHandler(ctx context.Context, kv *etcdkv.EtcdKV, factory dependency.Factory) (*channelUnsubscribeHandler, error) {
childCtx, cancel := context.WithCancel(ctx)
handler := &channelUnsubscribeHandler{
ctx: childCtx,
......