未验证 提交 84bddd0a 编写于 作者: D dragondriver 提交者: GitHub

Add unittest cases for proxy (#7364)

Signed-off-by: Ndragondriver <jiquan.long@zilliz.com>
上级 e6de86a4
......@@ -138,8 +138,11 @@ func (s *Server) init() error {
proxy.Params.Init()
log.Debug("init params done ...")
// NetworkPort & IP don't matter here, NetworkAddress matters
proxy.Params.NetworkPort = Params.Port
proxy.Params.IP = Params.IP
proxy.Params.NetworkAddress = Params.Address
// for purpose of ID Allocator
proxy.Params.RootCoordAddress = Params.RootCoordAddress
......
......@@ -24,9 +24,6 @@ import (
"go.uber.org/zap"
)
type vChan = string
type pChan = string
type channelsMgr interface {
getChannels(collectionID UniqueID) ([]pChan, error)
getVChannels(collectionID UniqueID) ([]vChan, error)
......@@ -179,7 +176,12 @@ func (mgr *singleTypeChannelsMgr) getAllVIDs(collectionID UniqueID) ([]int, erro
mgr.collMtx.RLock()
defer mgr.collMtx.RUnlock()
return mgr.collectionID2VIDs[collectionID], nil
ids, exist := mgr.collectionID2VIDs[collectionID]
if !exist {
return nil, fmt.Errorf("collection %d not found", collectionID)
}
return ids, nil
}
func (mgr *singleTypeChannelsMgr) getVChansByVID(vid int) ([]vChan, error) {
......@@ -339,10 +341,15 @@ func (mgr *singleTypeChannelsMgr) getVChannels(collectionID UniqueID) ([]vChan,
func (mgr *singleTypeChannelsMgr) createMsgStream(collectionID UniqueID) error {
channels, err := mgr.getChannelsFunc(collectionID)
log.Debug("singleTypeChannelsMgr", zap.Any("createMsgStream.getChannels", channels))
if err != nil {
log.Warn("failed to create message stream",
zap.Int64("collection_id", collectionID),
zap.Error(err))
return err
}
log.Debug("singleTypeChannelsMgr",
zap.Int64("collection_id", collectionID),
zap.Any("createMsgStream.getChannels", channels))
mgr.updateChannels(channels)
......@@ -480,13 +487,13 @@ func (mgr *channelsMgrImpl) removeAllDMLStream() error {
return mgr.dmlChannelsMgr.removeAllStream()
}
func newChannelsMgr(
func newChannelsMgrImpl(
getDmlChannelsFunc getChannelsFuncType,
dmlRepackFunc repackFuncType,
getDqlChannelsFunc getChannelsFuncType,
dqlRepackFunc repackFuncType,
msgStreamFactory msgstream.Factory,
) channelsMgr {
) *channelsMgrImpl {
return &channelsMgrImpl{
dmlChannelsMgr: newSingleTypeChannelsMgr(getDmlChannelsFunc, msgStreamFactory, dmlRepackFunc, dmlStreamType),
dqlChannelsMgr: newSingleTypeChannelsMgr(getDqlChannelsFunc, msgStreamFactory, dqlRepackFunc, dqlStreamType),
......
......@@ -11,7 +11,6 @@
package proxy
/*
import (
"testing"
......@@ -38,7 +37,7 @@ func TestChannelsMgrImpl_getChannels(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(getUniqueIntGeneratorIns().get())
......@@ -56,7 +55,7 @@ func TestChannelsMgrImpl_getVChannels(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(getUniqueIntGeneratorIns().get())
......@@ -74,7 +73,7 @@ func TestChannelsMgrImpl_createDMLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(getUniqueIntGeneratorIns().get())
......@@ -96,7 +95,7 @@ func TestChannelsMgrImpl_getDMLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(getUniqueIntGeneratorIns().get())
......@@ -114,7 +113,7 @@ func TestChannelsMgrImpl_removeDMLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(getUniqueIntGeneratorIns().get())
......@@ -141,7 +140,7 @@ func TestChannelsMgrImpl_removeAllDMLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
num := 10
......@@ -156,21 +155,12 @@ func TestChannelsMgrImpl_createDQLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(getUniqueIntGeneratorIns().get())
_, err := mgr.getChannels(collID)
assert.NotEqual(t, nil, err)
_, err = mgr.getVChannels(collID)
assert.NotEqual(t, nil, err)
err = mgr.createDQLStream(collID)
assert.Equal(t, nil, err)
_, err = mgr.getChannels(collID)
assert.Equal(t, nil, err)
_, err = mgr.getVChannels(collID)
err := mgr.createDQLStream(collID)
assert.Equal(t, nil, err)
}
......@@ -178,7 +168,7 @@ func TestChannelsMgrImpl_getDQLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(getUniqueIntGeneratorIns().get())
......@@ -196,7 +186,7 @@ func TestChannelsMgrImpl_removeDQLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
collID := UniqueID(getUniqueIntGeneratorIns().get())
......@@ -223,7 +213,7 @@ func TestChannelsMgrImpl_removeAllDQLMsgStream(t *testing.T) {
master := newMockGetChannelsService()
query := newMockGetChannelsService()
factory := msgstream.NewSimpleMsgStreamFactory()
mgr := newChannelsMgr(master.GetChannels, nil, query.GetChannels, nil, factory)
mgr := newChannelsMgrImpl(master.GetChannels, nil, query.GetChannels, nil, factory)
defer mgr.removeAllDMLStream()
num := 10
......@@ -233,4 +223,3 @@ func TestChannelsMgrImpl_removeAllDQLMsgStream(t *testing.T) {
assert.Equal(t, nil, err)
}
}
*/
......@@ -21,11 +21,6 @@ import (
"go.uber.org/zap"
)
type pChanStatistics struct {
minTs Timestamp
maxTs Timestamp
}
// ticker can update ts only when the minTs greater than the ts of ticker, we can use maxTs to update current later
type getPChanStatisticsFuncType func() (map[pChan]*pChanStatistics, error)
......
......@@ -11,6 +11,181 @@
package proxy
import (
"encoding/json"
"testing"
"github.com/milvus-io/milvus/internal/log"
"go.uber.org/zap"
"github.com/stretchr/testify/assert"
)
func Test_parseDummyRequestType(t *testing.T) {
var err error
// not in json format
notInJSONFormatStr := "not in json format string"
_, err = parseDummyRequestType(notInJSONFormatStr)
assert.NotNil(t, err)
// only contain other field, in json format
otherField := "other_field"
otherFieldValue := "not important"
m1 := make(map[string]string)
m1[otherField] = otherFieldValue
bs1, err := json.Marshal(m1)
assert.Nil(t, err)
log.Info("Test_parseDummyRequestType",
zap.String("json", string(bs1)))
ret1, err := parseDummyRequestType(string(bs1))
assert.Nil(t, err)
assert.Equal(t, 0, len(ret1.RequestType))
// normal case
key := "request_type"
value := "value"
m2 := make(map[string]string)
m2[key] = value
bs2, err := json.Marshal(m2)
assert.Nil(t, err)
log.Info("Test_parseDummyRequestType",
zap.String("json", string(bs2)))
ret2, err := parseDummyRequestType(string(bs2))
assert.Nil(t, err)
assert.Equal(t, value, ret2.RequestType)
// contain other field and request_type
m3 := make(map[string]string)
m3[key] = value
m3[otherField] = otherFieldValue
bs3, err := json.Marshal(m3)
assert.Nil(t, err)
log.Info("Test_parseDummyRequestType",
zap.String("json", string(bs3)))
ret3, err := parseDummyRequestType(string(bs3))
assert.Nil(t, err)
assert.Equal(t, value, ret3.RequestType)
}
func Test_parseDummyQueryRequest(t *testing.T) {
var err error
// not in json format
notInJSONFormatStr := "not in json format string"
_, err = parseDummyQueryRequest(notInJSONFormatStr)
assert.NotNil(t, err)
// only contain other field, in json format
otherField := "other_field"
otherFieldValue := "not important"
m1 := make(map[string]interface{})
m1[otherField] = otherFieldValue
bs1, err := json.Marshal(m1)
log.Info("Test_parseDummyQueryRequest",
zap.String("json", string(bs1)))
assert.Nil(t, err)
ret1, err := parseDummyQueryRequest(string(bs1))
assert.Nil(t, err)
assert.Equal(t, 0, len(ret1.RequestType))
assert.Equal(t, 0, len(ret1.DbName))
assert.Equal(t, 0, len(ret1.CollectionName))
assert.Equal(t, 0, len(ret1.PartitionNames))
assert.Equal(t, 0, len(ret1.Expr))
assert.Equal(t, 0, len(ret1.OutputFields))
requestTypeKey := "request_type"
requestTypeValue := "request_type"
dbNameKey := "dbname"
dbNameValue := "dbname"
collectionNameKey := "collection_name"
collectionNameValue := "collection_name"
partitionNamesKey := "partition_names"
partitionNamesValue := []string{"partition_names"}
exprKey := "expr"
exprValue := "expr"
outputFieldsKey := "output_fields"
outputFieldsValue := []string{"output_fields"}
// all fields
m2 := make(map[string]interface{})
m2[requestTypeKey] = requestTypeValue
m2[dbNameKey] = dbNameValue
m2[collectionNameKey] = collectionNameValue
m2[partitionNamesKey] = partitionNamesValue
m2[exprKey] = exprValue
m2[outputFieldsKey] = outputFieldsValue
bs2, err := json.Marshal(m2)
log.Info("Test_parseDummyQueryRequest",
zap.String("json", string(bs2)))
assert.Nil(t, err)
ret2, err := parseDummyQueryRequest(string(bs2))
assert.Nil(t, err)
assert.Equal(t, requestTypeValue, ret2.RequestType)
assert.Equal(t, dbNameValue, ret2.DbName)
assert.Equal(t, collectionNameValue, ret2.CollectionName)
assert.Equal(t, partitionNamesValue, ret2.PartitionNames)
assert.Equal(t, exprValue, ret2.Expr)
assert.Equal(t, outputFieldsValue, ret2.OutputFields)
// all fields and other field
m3 := make(map[string]interface{})
m3[otherField] = otherFieldValue
m3[requestTypeKey] = requestTypeValue
m3[dbNameKey] = dbNameValue
m3[collectionNameKey] = collectionNameValue
m3[partitionNamesKey] = partitionNamesValue
m3[exprKey] = exprValue
m3[outputFieldsKey] = outputFieldsValue
bs3, err := json.Marshal(m3)
log.Info("Test_parseDummyQueryRequest",
zap.String("json", string(bs3)))
assert.Nil(t, err)
ret3, err := parseDummyQueryRequest(string(bs3))
assert.Nil(t, err)
assert.Equal(t, requestTypeValue, ret3.RequestType)
assert.Equal(t, dbNameValue, ret3.DbName)
assert.Equal(t, collectionNameValue, ret3.CollectionName)
assert.Equal(t, partitionNamesValue, ret3.PartitionNames)
assert.Equal(t, exprValue, ret3.Expr)
assert.Equal(t, outputFieldsValue, ret3.OutputFields)
// partial fields
m4 := make(map[string]interface{})
m4[requestTypeKey] = requestTypeValue
m4[dbNameKey] = dbNameValue
bs4, err := json.Marshal(m4)
log.Info("Test_parseDummyQueryRequest",
zap.String("json", string(bs4)))
assert.Nil(t, err)
ret4, err := parseDummyQueryRequest(string(bs4))
assert.Nil(t, err)
assert.Equal(t, requestTypeValue, ret4.RequestType)
assert.Equal(t, dbNameValue, ret4.DbName)
assert.Equal(t, collectionNameValue, ret2.CollectionName)
assert.Equal(t, partitionNamesValue, ret2.PartitionNames)
assert.Equal(t, exprValue, ret2.Expr)
assert.Equal(t, outputFieldsValue, ret2.OutputFields)
// partial fields and other field
m5 := make(map[string]interface{})
m5[otherField] = otherFieldValue
m5[requestTypeKey] = requestTypeValue
m5[dbNameKey] = dbNameValue
bs5, err := json.Marshal(m5)
log.Info("Test_parseDummyQueryRequest",
zap.String("json", string(bs5)))
assert.Nil(t, err)
ret5, err := parseDummyQueryRequest(string(bs5))
assert.Nil(t, err)
assert.Equal(t, requestTypeValue, ret5.RequestType)
assert.Equal(t, dbNameValue, ret5.DbName)
assert.Equal(t, collectionNameValue, ret2.CollectionName)
assert.Equal(t, partitionNamesValue, ret2.PartitionNames)
assert.Equal(t, exprValue, ret2.Expr)
assert.Equal(t, outputFieldsValue, ret2.OutputFields)
}
// func TestParseDummyQueryRequest(t *testing.T) {
// invalidStr := `{"request_type":"query"`
// _, err := parseDummyQueryRequest(invalidStr)
......
......@@ -14,28 +14,134 @@ package proxy
import (
"testing"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/milvus-io/milvus/internal/log"
"go.uber.org/zap"
)
func TestMsgProxyIsUnhealthy(t *testing.T) {
func Test_errInvalidNumRows(t *testing.T) {
invalidNumRowsList := []uint32{
0,
16384,
}
for _, invalidNumRows := range invalidNumRowsList {
log.Info("Test_errInvalidNumRows",
zap.Error(errInvalidNumRows(invalidNumRows)))
}
}
func Test_errNumRowsLessThanOrEqualToZero(t *testing.T) {
invalidNumRowsList := []uint32{
0,
16384,
}
for _, invalidNumRows := range invalidNumRowsList {
log.Info("Test_errNumRowsLessThanOrEqualToZero",
zap.Error(errNumRowsLessThanOrEqualToZero(invalidNumRows)))
}
}
func Test_errEmptyFieldData(t *testing.T) {
log.Info("Test_errEmptyFieldData",
zap.Error(errEmptyFieldData))
}
func Test_errFieldsLessThanNeeded(t *testing.T) {
cases := []struct {
fieldsNum int
neededNum int
}{
{0, 1},
{1, 2},
}
for _, test := range cases {
log.Info("Test_errFieldsLessThanNeeded",
zap.Error(errFieldsLessThanNeeded(test.fieldsNum, test.neededNum)))
}
}
func Test_errUnsupportedDataType(t *testing.T) {
unsupportedDTypes := []schemapb.DataType{
schemapb.DataType_None,
}
for _, dType := range unsupportedDTypes {
log.Info("Test_errUnsupportedDataType",
zap.Error(errUnsupportedDataType(dType)))
}
}
func Test_errUnsupportedDType(t *testing.T) {
unsupportedDTypes := []string{
"bytes",
"None",
}
for _, dType := range unsupportedDTypes {
log.Info("Test_errUnsupportedDType",
zap.Error(errUnsupportedDType(dType)))
}
}
func Test_errInvalidDim(t *testing.T) {
invalidDimList := []int{
0,
-1,
}
for _, invalidDim := range invalidDimList {
log.Info("Test_errInvalidDim",
zap.Error(errInvalidDim(invalidDim)))
}
}
func Test_errDimLessThanOrEqualToZero(t *testing.T) {
invalidDimList := []int{
0,
-1,
}
for _, invalidDim := range invalidDimList {
log.Info("Test_errDimLessThanOrEqualToZero",
zap.Error(errDimLessThanOrEqualToZero(invalidDim)))
}
}
func Test_errDimShouldDivide8(t *testing.T) {
invalidDimList := []int{
0,
1,
7,
}
for _, invalidDim := range invalidDimList {
log.Info("Test_errDimShouldDivide8",
zap.Error(errDimShouldDivide8(invalidDim)))
}
}
func Test_msgProxyIsUnhealthy(t *testing.T) {
ids := []UniqueID{
1,
}
for _, id := range ids {
log.Info("TestMsgProxyIsUnhealthy",
log.Info("Test_msgProxyIsUnhealthy",
zap.String("msg", msgProxyIsUnhealthy(id)))
}
}
func TestErrProxyIsUnhealthy(t *testing.T) {
func Test_errProxyIsUnhealthy(t *testing.T) {
ids := []UniqueID{
1,
}
for _, id := range ids {
log.Info("TestErrProxyIsUnhealthy",
log.Info("Test_errProxyIsUnhealthy",
zap.Error(errProxyIsUnhealthy(id)))
}
}
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import (
"context"
"errors"
"fmt"
"sort"
"sync"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/msgstream"
"go.uber.org/zap"
)
type insertChannelsMap struct {
collectionID2InsertChannels map[UniqueID]int // the value of map is the location of insertChannels & insertMsgStreams
insertChannels [][]string // it's a little confusing to use []string as the key of map
insertMsgStreams []msgstream.MsgStream // maybe there's a better way to implement Set, just agilely now
droppedBitMap []int // 0 -> normal, 1 -> dropped
usageHistogram []int // message stream can be closed only when the use count is zero
// TODO: use fine grained lock
mtx sync.RWMutex
nodeInstance *Proxy
msFactory msgstream.Factory
}
func (m *insertChannelsMap) CreateInsertMsgStream(collID UniqueID, channels []string) error {
m.mtx.Lock()
defer m.mtx.Unlock()
_, ok := m.collectionID2InsertChannels[collID]
if ok {
return errors.New("impossible and forbidden to create message stream twice")
}
sort.Slice(channels, func(i, j int) bool {
return channels[i] <= channels[j]
})
for loc, existedChannels := range m.insertChannels {
if m.droppedBitMap[loc] == 0 && funcutil.SortedSliceEqual(existedChannels, channels) {
m.collectionID2InsertChannels[collID] = loc
m.usageHistogram[loc]++
return nil
}
}
m.insertChannels = append(m.insertChannels, channels)
m.collectionID2InsertChannels[collID] = len(m.insertChannels) - 1
stream, _ := m.msFactory.NewMsgStream(context.Background())
stream.AsProducer(channels)
log.Debug("proxy", zap.Strings("proxy AsProducer: ", channels))
stream.SetRepackFunc(insertRepackFunc)
stream.Start()
m.insertMsgStreams = append(m.insertMsgStreams, stream)
m.droppedBitMap = append(m.droppedBitMap, 0)
m.usageHistogram = append(m.usageHistogram, 1)
return nil
}
func (m *insertChannelsMap) CloseInsertMsgStream(collID UniqueID) error {
m.mtx.Lock()
defer m.mtx.Unlock()
loc, ok := m.collectionID2InsertChannels[collID]
if !ok {
return fmt.Errorf("cannot find collection with id %d", collID)
}
if m.droppedBitMap[loc] != 0 {
return errors.New("insert message stream already closed")
}
if m.usageHistogram[loc] <= 0 {
return errors.New("insert message stream already closed")
}
m.usageHistogram[loc]--
if m.usageHistogram[loc] <= 0 {
m.insertMsgStreams[loc].Close()
m.droppedBitMap[loc] = 1
log.Warn("close insert message stream ...")
}
delete(m.collectionID2InsertChannels, collID)
return nil
}
func (m *insertChannelsMap) GetInsertChannels(collID UniqueID) ([]string, error) {
m.mtx.RLock()
defer m.mtx.RUnlock()
loc, ok := m.collectionID2InsertChannels[collID]
if !ok {
return nil, fmt.Errorf("cannot find collection with id: %d", collID)
}
if m.droppedBitMap[loc] != 0 {
return nil, errors.New("insert message stream already closed")
}
ret := append([]string(nil), m.insertChannels[loc]...)
return ret, nil
}
func (m *insertChannelsMap) GetInsertMsgStream(collID UniqueID) (msgstream.MsgStream, error) {
m.mtx.RLock()
defer m.mtx.RUnlock()
loc, ok := m.collectionID2InsertChannels[collID]
if !ok {
return nil, fmt.Errorf("cannot find collection with id: %d", collID)
}
if m.droppedBitMap[loc] != 0 {
return nil, errors.New("insert message stream already closed")
}
return m.insertMsgStreams[loc], nil
}
func (m *insertChannelsMap) CloseAllMsgStream() {
m.mtx.Lock()
defer m.mtx.Unlock()
for loc, stream := range m.insertMsgStreams {
if m.droppedBitMap[loc] == 0 && m.usageHistogram[loc] >= 1 {
stream.Close()
}
}
m.collectionID2InsertChannels = make(map[UniqueID]int)
m.insertChannels = make([][]string, 0)
m.insertMsgStreams = make([]msgstream.MsgStream, 0)
m.droppedBitMap = make([]int, 0)
m.usageHistogram = make([]int, 0)
}
func newInsertChannelsMap(node *Proxy) *insertChannelsMap {
return &insertChannelsMap{
collectionID2InsertChannels: make(map[UniqueID]int),
insertChannels: make([][]string, 0),
insertMsgStreams: make([]msgstream.MsgStream, 0),
droppedBitMap: make([]int, 0),
usageHistogram: make([]int, 0),
nodeInstance: node,
msFactory: node.msFactory,
}
}
var globalInsertChannelsMap *insertChannelsMap
var initGlobalInsertChannelsMapOnce sync.Once
// change to singleton mode later? Such as GetInsertChannelsMapInstance like GetConfAdapterMgrInstance.
func initGlobalInsertChannelsMap(node *Proxy) {
initGlobalInsertChannelsMapOnce.Do(func() {
globalInsertChannelsMap = newInsertChannelsMap(node)
})
}
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import (
"testing"
"github.com/milvus-io/milvus/internal/util/funcutil"
"github.com/stretchr/testify/assert"
"github.com/milvus-io/milvus/internal/msgstream"
)
func TestInsertChannelsMap_CreateInsertMsgStream(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &Proxy{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
err = m.CreateInsertMsgStream(1, []string{"1"})
assert.Equal(t, nil, err)
// duplicated
err = m.CreateInsertMsgStream(1, []string{"1"})
assert.NotEqual(t, nil, err)
// duplicated
err = m.CreateInsertMsgStream(1, []string{"1", "2"})
assert.NotEqual(t, nil, err)
// use same channels
err = m.CreateInsertMsgStream(2, []string{"1"})
assert.Equal(t, nil, err)
err = m.CreateInsertMsgStream(3, []string{"3"})
assert.Equal(t, nil, err)
}
func TestInsertChannelsMap_CloseInsertMsgStream(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &Proxy{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
_ = m.CreateInsertMsgStream(1, []string{"1"})
_ = m.CreateInsertMsgStream(2, []string{"1"})
_ = m.CreateInsertMsgStream(3, []string{"3"})
// don't exist
err = m.CloseInsertMsgStream(0)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(1)
assert.Equal(t, nil, err)
// close twice
err = m.CloseInsertMsgStream(1)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(2)
assert.Equal(t, nil, err)
// close twice
err = m.CloseInsertMsgStream(2)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(3)
assert.Equal(t, nil, err)
// close twice
err = m.CloseInsertMsgStream(3)
assert.NotEqual(t, nil, err)
}
func TestInsertChannelsMap_GetInsertChannels(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &Proxy{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
var channels []string
_ = m.CreateInsertMsgStream(1, []string{"1"})
_ = m.CreateInsertMsgStream(2, []string{"1"})
_ = m.CreateInsertMsgStream(3, []string{"3"})
// don't exist
channels, err = m.GetInsertChannels(0)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
channels, err = m.GetInsertChannels(1)
assert.Equal(t, nil, err)
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"1"}))
channels, err = m.GetInsertChannels(2)
assert.Equal(t, nil, err)
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"1"}))
channels, err = m.GetInsertChannels(3)
assert.Equal(t, nil, err)
assert.Equal(t, true, funcutil.SortedSliceEqual(channels, []string{"3"}))
_ = m.CloseInsertMsgStream(1)
channels, err = m.GetInsertChannels(1)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
_ = m.CloseInsertMsgStream(2)
channels, err = m.GetInsertChannels(2)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
_ = m.CloseInsertMsgStream(3)
channels, err = m.GetInsertChannels(3)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
}
func TestInsertChannelsMap_GetInsertMsgStream(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &Proxy{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
var stream msgstream.MsgStream
_ = m.CreateInsertMsgStream(1, []string{"1"})
_ = m.CreateInsertMsgStream(2, []string{"1"})
_ = m.CreateInsertMsgStream(3, []string{"3"})
// don't exist
stream, err = m.GetInsertMsgStream(0)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
stream, err = m.GetInsertMsgStream(1)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, stream)
stream, err = m.GetInsertMsgStream(2)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, stream)
stream, err = m.GetInsertMsgStream(3)
assert.Equal(t, nil, err)
assert.NotEqual(t, nil, stream)
_ = m.CloseInsertMsgStream(1)
stream, err = m.GetInsertMsgStream(1)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
_ = m.CloseInsertMsgStream(2)
stream, err = m.GetInsertMsgStream(2)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
_ = m.CloseInsertMsgStream(3)
stream, err = m.GetInsertMsgStream(3)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
}
func TestInsertChannelsMap_CloseAllMsgStream(t *testing.T) {
msFactory := msgstream.NewSimpleMsgStreamFactory()
node := &Proxy{
segAssigner: nil,
msFactory: msFactory,
}
m := newInsertChannelsMap(node)
var err error
var stream msgstream.MsgStream
var channels []string
_ = m.CreateInsertMsgStream(1, []string{"1"})
_ = m.CreateInsertMsgStream(2, []string{"1"})
_ = m.CreateInsertMsgStream(3, []string{"3"})
m.CloseAllMsgStream()
err = m.CloseInsertMsgStream(1)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(2)
assert.NotEqual(t, nil, err)
err = m.CloseInsertMsgStream(3)
assert.NotEqual(t, nil, err)
channels, err = m.GetInsertChannels(1)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
channels, err = m.GetInsertChannels(2)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
channels, err = m.GetInsertChannels(3)
assert.NotEqual(t, nil, err)
assert.Equal(t, 0, len(channels))
stream, err = m.GetInsertMsgStream(1)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
stream, err = m.GetInsertMsgStream(2)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
stream, err = m.GetInsertMsgStream(3)
assert.NotEqual(t, nil, err)
assert.Equal(t, nil, stream)
}
......@@ -31,16 +31,21 @@ const (
type ParamTable struct {
paramtable.BaseTable
NetworkPort int
IP string
// NetworkPort & IP are not used
NetworkPort int
IP string
NetworkAddress string
Alias string
// TODO(dragondriver): maybe using the Proxy + ProxyID as the alias is more reasonable
Alias string
EtcdEndpoints []string
MetaRootPath string
RootCoordAddress string
PulsarAddress string
RocksmqPath string
RocksmqPath string // not used in Proxy
ProxyID UniqueID
TimeTickInterval time.Duration
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import "testing"
func TestParamTable(t *testing.T) {
Params.Init()
t.Run("EtcdEndPoints", func(t *testing.T) {
t.Logf("EtcdEndPoints: %v", Params.EtcdEndpoints)
})
t.Run("MetaRootPath", func(t *testing.T) {
t.Logf("MetaRootPath: %s", Params.MetaRootPath)
})
t.Run("PulsarAddress", func(t *testing.T) {
t.Logf("PulsarAddress: %s", Params.PulsarAddress)
})
t.Run("RocksmqPath", func(t *testing.T) {
t.Logf("RocksmqPath: %s", Params.RocksmqPath)
})
t.Run("TimeTickInterval", func(t *testing.T) {
t.Logf("TimeTickInterval: %v", Params.TimeTickInterval)
})
t.Run("ProxySubName", func(t *testing.T) {
t.Logf("ProxySubName: %s", Params.ProxySubName)
})
t.Run("ProxyTimeTickChannelNames", func(t *testing.T) {
t.Logf("ProxyTimeTickChannelNames: %v", Params.ProxyTimeTickChannelNames)
})
t.Run("MsgStreamTimeTickBufSize", func(t *testing.T) {
t.Logf("MsgStreamTimeTickBufSize: %d", Params.MsgStreamTimeTickBufSize)
})
t.Run("MaxNameLength", func(t *testing.T) {
t.Logf("MaxNameLength: %d", Params.MaxNameLength)
})
t.Run("MaxFieldNum", func(t *testing.T) {
t.Logf("MaxFieldNum: %d", Params.MaxFieldNum)
})
t.Run("MaxDimension", func(t *testing.T) {
t.Logf("MaxDimension: %d", Params.MaxDimension)
})
t.Run("DefaultPartitionName", func(t *testing.T) {
t.Logf("DefaultPartitionName: %s", Params.DefaultPartitionName)
})
t.Run("DefaultIndexName", func(t *testing.T) {
t.Logf("DefaultIndexName: %s", Params.DefaultIndexName)
})
t.Run("PulsarMaxMessageSize", func(t *testing.T) {
t.Logf("PulsarMaxMessageSize: %d", Params.PulsarMaxMessageSize)
})
t.Run("RoleName", func(t *testing.T) {
t.Logf("RoleName: %s", Params.RoleName)
})
}
......@@ -250,7 +250,7 @@ func (node *Proxy) Init() error {
return m, nil
}
chMgr := newChannelsMgr(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory)
chMgr := newChannelsMgrImpl(getDmlChannelsFunc, defaultInsertRepackFunc, getDqlChannelsFunc, nil, node.msFactory)
node.chMgr = chMgr
node.sched, err = NewTaskScheduler(node.ctx, node.idAllocator, node.tsoAllocator, node.msFactory)
......
......@@ -12,11 +12,23 @@
package proxy
import (
"fmt"
"github.com/milvus-io/milvus/internal/msgstream"
)
func insertRepackFunc(tsMsgs []msgstream.TsMsg,
hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
func insertRepackFunc(
tsMsgs []msgstream.TsMsg,
hashKeys [][]int32,
) (map[int32]*msgstream.MsgPack, error) {
if len(hashKeys) < len(tsMsgs) {
return nil, fmt.Errorf(
"the length of hash keys (%d) is less than the length of messages (%d)",
len(hashKeys),
len(tsMsgs),
)
}
result := make(map[int32]*msgstream.MsgPack)
for i, request := range tsMsgs {
......@@ -28,17 +40,32 @@ func insertRepackFunc(tsMsgs []msgstream.TsMsg,
result[key] = &msgstream.MsgPack{}
}
result[key].Msgs = append(result[key].Msgs, request)
} else {
return nil, fmt.Errorf("no hash key for %dth message", i)
}
}
return result, nil
}
func defaultInsertRepackFunc(tsMsgs []msgstream.TsMsg, hashKeys [][]int32) (map[int32]*msgstream.MsgPack, error) {
func defaultInsertRepackFunc(
tsMsgs []msgstream.TsMsg,
hashKeys [][]int32,
) (map[int32]*msgstream.MsgPack, error) {
if len(hashKeys) < len(tsMsgs) {
return nil, fmt.Errorf(
"the length of hash keys (%d) is less than the length of messages (%d)",
len(hashKeys),
len(tsMsgs),
)
}
// after assigning segment id to msg, tsMsgs was already re-bucketed
pack := make(map[int32]*msgstream.MsgPack)
for idx, msg := range tsMsgs {
if len(hashKeys[idx]) <= 0 {
continue
return nil, fmt.Errorf("no hash key for %dth message", idx)
}
key := hashKeys[idx][0]
_, ok := pack[key]
......
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
import (
"math/rand"
"testing"
"github.com/milvus-io/milvus/internal/msgstream"
"github.com/stretchr/testify/assert"
)
func Test_insertRepackFunc(t *testing.T) {
var err error
// tsMsgs is empty
ret1, err := insertRepackFunc(nil, [][]int32{{1, 2}})
assert.Nil(t, err)
assert.Equal(t, 0, len(ret1))
// hashKeys is empty
tsMsgs2 := []msgstream.TsMsg{
&msgstream.InsertMsg{}, // not important
&msgstream.InsertMsg{}, // not important
}
ret2, err := insertRepackFunc(tsMsgs2, nil)
assert.NotNil(t, err)
assert.Nil(t, ret2)
// len(hashKeys) < len(tsMsgs), 1 < 2
ret2, err = insertRepackFunc(tsMsgs2, [][]int32{{1, 2}})
assert.NotNil(t, err)
assert.Nil(t, ret2)
// both tsMsgs & hashKeys are empty
ret3, err := insertRepackFunc(nil, nil)
assert.Nil(t, err)
assert.Equal(t, 0, len(ret3))
num := rand.Int()%100 + 1
tsMsgs4 := make([]msgstream.TsMsg, 0)
for i := 0; i < num; i++ {
tsMsgs4 = append(tsMsgs4, &msgstream.InsertMsg{
// not important
})
}
// len(hashKeys) = len(tsMsgs), but no hash key
hashKeys1 := make([][]int32, num)
ret4, err := insertRepackFunc(tsMsgs4, hashKeys1)
assert.NotNil(t, err)
assert.Nil(t, ret4)
// all messages are shuffled to same bucket
hashKeys2 := make([][]int32, num)
key := int32(0)
for i := 0; i < num; i++ {
hashKeys2[i] = []int32{key}
}
ret5, err := insertRepackFunc(tsMsgs4, hashKeys2)
assert.Nil(t, err)
assert.Equal(t, 1, len(ret5))
assert.Equal(t, num, len(ret5[key].Msgs))
// evenly shuffle
hashKeys3 := make([][]int32, num)
for i := 0; i < num; i++ {
hashKeys3[i] = []int32{int32(i)}
}
ret6, err := insertRepackFunc(tsMsgs4, hashKeys3)
assert.Nil(t, err)
assert.Equal(t, num, len(ret6))
for key := range ret6 {
assert.Equal(t, 1, len(ret6[key].Msgs))
}
// randomly shuffle
histogram := make(map[int32]int) // key -> key num
hashKeys4 := make([][]int32, num)
for i := 0; i < num; i++ {
k := int32(rand.Uint32())
hashKeys4[i] = []int32{k}
_, exist := histogram[k]
if exist {
histogram[k]++
} else {
histogram[k] = 1
}
}
ret7, err := insertRepackFunc(tsMsgs4, hashKeys4)
assert.Nil(t, err)
assert.Equal(t, len(histogram), len(ret7))
for key := range ret7 {
assert.Equal(t, histogram[key], len(ret7[key].Msgs))
}
}
func Test_defaultInsertRepackFunc(t *testing.T) {
var err error
// tsMsgs is empty
ret1, err := defaultInsertRepackFunc(nil, [][]int32{{1, 2}})
assert.Nil(t, err)
assert.Equal(t, 0, len(ret1))
// hashKeys is empty
tsMsgs2 := []msgstream.TsMsg{
&msgstream.InsertMsg{}, // not important
&msgstream.InsertMsg{}, // not important
}
ret2, err := defaultInsertRepackFunc(tsMsgs2, nil)
assert.NotNil(t, err)
assert.Nil(t, ret2)
// len(hashKeys) < len(tsMsgs), 1 < 2
ret2, err = defaultInsertRepackFunc(tsMsgs2, [][]int32{{1, 2}})
assert.NotNil(t, err)
assert.Nil(t, ret2)
// both tsMsgs & hashKeys are empty
ret3, err := defaultInsertRepackFunc(nil, nil)
assert.Nil(t, err)
assert.Equal(t, 0, len(ret3))
num := rand.Int()%100 + 1
tsMsgs4 := make([]msgstream.TsMsg, 0)
for i := 0; i < num; i++ {
tsMsgs4 = append(tsMsgs4, &msgstream.InsertMsg{
// not important
})
}
// len(hashKeys) = len(tsMsgs), but no hash key
hashKeys1 := make([][]int32, num)
ret4, err := defaultInsertRepackFunc(tsMsgs4, hashKeys1)
assert.NotNil(t, err)
assert.Nil(t, ret4)
// all messages are shuffled to same bucket
hashKeys2 := make([][]int32, num)
key := int32(0)
for i := 0; i < num; i++ {
hashKeys2[i] = []int32{key}
}
ret5, err := defaultInsertRepackFunc(tsMsgs4, hashKeys2)
assert.Nil(t, err)
assert.Equal(t, 1, len(ret5))
assert.Equal(t, num, len(ret5[key].Msgs))
// evenly shuffle
hashKeys3 := make([][]int32, num)
for i := 0; i < num; i++ {
hashKeys3[i] = []int32{int32(i)}
}
ret6, err := defaultInsertRepackFunc(tsMsgs4, hashKeys3)
assert.Nil(t, err)
assert.Equal(t, num, len(ret6))
for key := range ret6 {
assert.Equal(t, 1, len(ret6[key].Msgs))
}
// randomly shuffle
histogram := make(map[int32]int) // key -> key num
hashKeys4 := make([][]int32, num)
for i := 0; i < num; i++ {
k := int32(rand.Uint32())
hashKeys4[i] = []int32{k}
_, exist := histogram[k]
if exist {
histogram[k]++
} else {
histogram[k] = 1
}
}
ret7, err := defaultInsertRepackFunc(tsMsgs4, hashKeys4)
assert.Nil(t, err)
assert.Equal(t, len(histogram), len(ret7))
for key := range ret7 {
assert.Equal(t, histogram[key], len(ret7[key].Msgs))
}
}
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
package proxy
type vChan = string
type pChan = string
type pChanStatistics struct {
minTs Timestamp
maxTs Timestamp
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册