未验证 提交 6fec7ed6 编写于 作者: B bigsheeper 提交者: GitHub

Add search benchmark in QueryNode (#15854)

Signed-off-by: bigsheeper <yihao.dai@zilliz.com>
上级 1e4d2a37
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package querynode
import (
"context"
"os"
"runtime/pprof"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/zap/zapcore"
"github.com/milvus-io/milvus/internal/log"
msgstream2 "github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/util/typeutil"
)
// Sizing knobs shared by the search benchmarks in this file.
const (
// maxNQ is the total number of query vectors each benchmark iteration aims
// to issue; the benchmark helpers split it into messages of nq vectors each.
maxNQ = 100
// nb is the number of rows loaded into the segment that is searched.
nb = 10000
)
// benchmarkQueryCollectionSearch measures queryCollection.search against a
// single sealed segment of nb rows created by genSealedSegmentWithMsgLength
// (no index is loaded through the segment loader here).
//
// Each benchmark iteration sends maxNQ/nq search messages carrying nq query
// vectors each, and always at least one message so that nq values larger
// than maxNQ are still measured. A CPU profile of the timed section is
// written to "nq_<nq>.perf" in the current working directory.
//
// nq must be positive.
func benchmarkQueryCollectionSearch(nq int, b *testing.B) {
	if nq <= 0 {
		// Guards the maxNQ/nq division below.
		b.Fatalf("nq must be positive, got %d", nq)
	}

	// Keep logging out of the measurement; restore the level afterwards.
	log.SetLevel(zapcore.ErrorLevel)
	defer log.SetLevel(zapcore.DebugLevel)

	tx, cancel := context.WithCancel(context.Background())
	queryCollection, err := genSimpleQueryCollection(tx, cancel)
	assert.NoError(b, err)

	// search only one segment: drop the pre-generated segments first
	err = queryCollection.streaming.replica.removeSegment(defaultSegmentID)
	assert.NoError(b, err)
	err = queryCollection.historical.replica.removeSegment(defaultSegmentID)
	assert.NoError(b, err)

	assert.Equal(b, 0, queryCollection.historical.replica.getSegmentNum())
	assert.Equal(b, 0, queryCollection.streaming.replica.getSegmentNum())

	segment, err := genSealedSegmentWithMsgLength(nb)
	assert.NoError(b, err)
	err = queryCollection.historical.replica.setSegment(segment)
	assert.NoError(b, err)

	// Search results are sent through the session manager; use a mocked proxy.
	sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
	sessionManager.AddSession(&NodeInfo{
		NodeID:  0,
		Address: "",
	})
	queryCollection.sessionManager = sessionManager

	// segment check
	assert.Equal(b, 1, queryCollection.historical.replica.getSegmentNum())
	assert.Equal(b, 0, queryCollection.streaming.replica.getSegmentNum())
	seg, err := queryCollection.historical.replica.getSegmentByID(defaultSegmentID)
	assert.NoError(b, err)
	assert.Equal(b, int64(nb), seg.getRowCount())
	sizePerRecord, err := typeutil.EstimateSizePerRecord(genSimpleSegCoreSchema())
	assert.NoError(b, err)
	expectSize := sizePerRecord * nb
	assert.Equal(b, seg.getMemSize(), int64(expectSize))

	// warming up
	msgTmp, err := genSearchMsg(10)
	assert.NoError(b, err)
	for j := 0; j < 10000; j++ {
		err = queryCollection.search(msgTmp)
		assert.NoError(b, err)
	}

	// Integer division yields zero messages once nq exceeds maxNQ, which
	// would turn the timed loop into a no-op; always send at least one.
	msgNum := maxNQ / nq
	if msgNum < 1 {
		msgNum = 1
	}
	msgs := make([]*msgstream2.SearchMsg, msgNum)
	for i := range msgs {
		msg, err := genSearchMsg(nq)
		assert.NoError(b, err)
		msgs[i] = msg
	}

	f, err := os.Create("nq_" + strconv.Itoa(nq) + ".perf")
	if err != nil {
		b.Fatalf("create CPU profile file: %v", err)
	}
	// Deferred before StopCPUProfile so that, in LIFO order, the file is
	// closed only after the profiler has flushed its data into it.
	defer f.Close()

	if err = pprof.StartCPUProfile(f); err != nil {
		b.Fatalf("start CPU profile: %v", err)
	}
	defer pprof.StopCPUProfile()

	// start benchmark
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, msg := range msgs {
			err = queryCollection.search(msg)
			assert.NoError(b, err)
		}
	}
}
// benchmarkQueryCollectionSearchIndex measures queryCollection.search against
// a single sealed segment of nb rows for which an index of the given
// indexType (built with the L2 metric) is loaded via loadIndexForSegment.
//
// Each benchmark iteration sends maxNQ/nq search messages carrying nq query
// vectors each, and always at least one message so that nq values larger
// than maxNQ are still measured. A CPU profile of the timed section is
// written to "<indexType>_nq_<nq>.perf" in the current working directory.
//
// nq must be positive.
func benchmarkQueryCollectionSearchIndex(nq int, indexType string, b *testing.B) {
	if nq <= 0 {
		// Guards the maxNQ/nq division below.
		b.Fatalf("nq must be positive, got %d", nq)
	}

	// Keep logging out of the measurement; restore the level afterwards.
	log.SetLevel(zapcore.ErrorLevel)
	defer log.SetLevel(zapcore.DebugLevel)

	tx, cancel := context.WithCancel(context.Background())
	queryCollection, err := genSimpleQueryCollection(tx, cancel)
	assert.NoError(b, err)

	// Drop the pre-generated segments so only the indexed one is searched.
	err = queryCollection.historical.replica.removeSegment(defaultSegmentID)
	assert.NoError(b, err)
	err = queryCollection.streaming.replica.removeSegment(defaultSegmentID)
	assert.NoError(b, err)

	assert.Equal(b, 0, queryCollection.historical.replica.getSegmentNum())
	assert.Equal(b, 0, queryCollection.streaming.replica.getSegmentNum())

	node, err := genSimpleQueryNode(tx)
	assert.NoError(b, err)
	// Point the loader at the query collection's historical replica so the
	// segment it loads is the one the search below operates on.
	node.loader.historicalReplica = queryCollection.historical.replica

	err = loadIndexForSegment(tx, node, defaultSegmentID, nb, indexType, L2)
	assert.NoError(b, err)

	// Search results are sent through the session manager; use a mocked proxy.
	sessionManager := NewSessionManager(withSessionCreator(mockProxyCreator()))
	sessionManager.AddSession(&NodeInfo{
		NodeID:  0,
		Address: "",
	})
	queryCollection.sessionManager = sessionManager

	// segment check
	assert.Equal(b, 1, queryCollection.historical.replica.getSegmentNum())
	assert.Equal(b, 0, queryCollection.streaming.replica.getSegmentNum())
	seg, err := queryCollection.historical.replica.getSegmentByID(defaultSegmentID)
	assert.NoError(b, err)
	assert.Equal(b, int64(nb), seg.getRowCount())
	sizePerRecord, err := typeutil.EstimateSizePerRecord(genSimpleSegCoreSchema())
	assert.NoError(b, err)
	expectSize := sizePerRecord * nb
	assert.Equal(b, seg.getMemSize(), int64(expectSize))

	// warming up
	msgTmp, err := genSearchMsg(10)
	assert.NoError(b, err)
	for j := 0; j < 10000; j++ {
		err = queryCollection.search(msgTmp)
		assert.NoError(b, err)
	}

	// Integer division yields zero messages once nq exceeds maxNQ, which
	// would turn the timed loop into a no-op; always send at least one.
	msgNum := maxNQ / nq
	if msgNum < 1 {
		msgNum = 1
	}
	msgs := make([]*msgstream2.SearchMsg, msgNum)
	for i := range msgs {
		msg, err := genSearchMsg(nq)
		assert.NoError(b, err)
		msgs[i] = msg
	}

	f, err := os.Create(indexType + "_nq_" + strconv.Itoa(nq) + ".perf")
	if err != nil {
		b.Fatalf("create CPU profile file: %v", err)
	}
	// Deferred before StopCPUProfile so that, in LIFO order, the file is
	// closed only after the profiler has flushed its data into it.
	defer f.Close()

	if err = pprof.StartCPUProfile(f); err != nil {
		b.Fatalf("start CPU profile: %v", err)
	}
	defer pprof.StopCPUProfile()

	// start benchmark
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		for _, msg := range msgs {
			err = queryCollection.search(msg)
			assert.NoError(b, err)
		}
	}
}
// BenchmarkSearch_NQ1 invokes benchmarkQueryCollectionSearch with nq = 1.
func BenchmarkSearch_NQ1(b *testing.B) {
	const nq = 1
	benchmarkQueryCollectionSearch(nq, b)
}

// BenchmarkSearch_NQ10 invokes benchmarkQueryCollectionSearch with nq = 10.
func BenchmarkSearch_NQ10(b *testing.B) {
	const nq = 10
	benchmarkQueryCollectionSearch(nq, b)
}

// BenchmarkSearch_NQ100 invokes benchmarkQueryCollectionSearch with nq = 100.
func BenchmarkSearch_NQ100(b *testing.B) {
	const nq = 100
	benchmarkQueryCollectionSearch(nq, b)
}

// BenchmarkSearch_NQ1000 invokes benchmarkQueryCollectionSearch with nq = 1000.
func BenchmarkSearch_NQ1000(b *testing.B) {
	const nq = 1000
	benchmarkQueryCollectionSearch(nq, b)
}

// BenchmarkSearch_NQ10000 invokes benchmarkQueryCollectionSearch with nq = 10000.
func BenchmarkSearch_NQ10000(b *testing.B) {
	const nq = 10000
	benchmarkQueryCollectionSearch(nq, b)
}
// BenchmarkSearch_IVFFLAT_NQ1 invokes benchmarkQueryCollectionSearchIndex
// with nq = 1 and the IVF_FLAT index type.
func BenchmarkSearch_IVFFLAT_NQ1(b *testing.B) {
	const nq = 1
	benchmarkQueryCollectionSearchIndex(nq, IndexFaissIVFFlat, b)
}

// BenchmarkSearch_IVFFLAT_NQ10 invokes benchmarkQueryCollectionSearchIndex
// with nq = 10 and the IVF_FLAT index type.
func BenchmarkSearch_IVFFLAT_NQ10(b *testing.B) {
	const nq = 10
	benchmarkQueryCollectionSearchIndex(nq, IndexFaissIVFFlat, b)
}

// BenchmarkSearch_IVFFLAT_NQ100 invokes benchmarkQueryCollectionSearchIndex
// with nq = 100 and the IVF_FLAT index type.
func BenchmarkSearch_IVFFLAT_NQ100(b *testing.B) {
	const nq = 100
	benchmarkQueryCollectionSearchIndex(nq, IndexFaissIVFFlat, b)
}

// BenchmarkSearch_IVFFLAT_NQ1000 invokes benchmarkQueryCollectionSearchIndex
// with nq = 1000 and the IVF_FLAT index type.
func BenchmarkSearch_IVFFLAT_NQ1000(b *testing.B) {
	const nq = 1000
	benchmarkQueryCollectionSearchIndex(nq, IndexFaissIVFFlat, b)
}

// BenchmarkSearch_IVFFLAT_NQ10000 invokes benchmarkQueryCollectionSearchIndex
// with nq = 10000 and the IVF_FLAT index type.
func BenchmarkSearch_IVFFLAT_NQ10000(b *testing.B) {
	const nq = 10000
	benchmarkQueryCollectionSearchIndex(nq, IndexFaissIVFFlat, b)
}
......@@ -19,6 +19,7 @@ package querynode
import (
"context"
"errors"
"fmt"
"math"
"math/rand"
"strconv"
......@@ -43,6 +44,7 @@ import (
"github.com/milvus-io/milvus/internal/storage"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/internal/util/etcd"
"github.com/milvus-io/milvus/internal/util/funcutil"
)
// ---------- unittest util functions ----------
......@@ -61,7 +63,8 @@ const (
defaultRoundDecimal = int64(6)
defaultDim = 128
defaultNProb = 10
defaultMetricType = "JACCARD"
defaultMetricType = L2
defaultNQ = 10
defaultDMLChannel = "query-node-unittest-DML-0"
defaultDeltaChannel = "query-node-unittest-delta-channel-0"
......@@ -89,6 +92,44 @@ const (
indexName = "query-node-index-0"
)
// Index type names, metric type names and index parameter values used by the
// test helpers in this file (see genIndexParams).
const (
// index type
IndexFaissIDMap = "FLAT"
IndexFaissIVFFlat = "IVF_FLAT"
IndexFaissIVFPQ = "IVF_PQ"
IndexFaissIVFSQ8 = "IVF_SQ8"
IndexFaissIVFSQ8H = "IVF_SQ8_HYBRID"
IndexFaissBinIDMap = "BIN_FLAT"
IndexFaissBinIVFFlat = "BIN_IVF_FLAT"
IndexNsg = "NSG"
IndexHNSW = "HNSW"
IndexRHNSWFlat = "RHNSW_FLAT"
IndexRHNSWPQ = "RHNSW_PQ"
IndexRHNSWSQ = "RHNSW_SQ"
IndexANNOY = "ANNOY"
IndexNGTPANNG = "NGT_PANNG"
IndexNGTONNG = "NGT_ONNG"
// metric type
L2 = "L2"
IP = "IP"
hamming = "HAMMING"
Jaccard = "JACCARD"
tanimoto = "TANIMOTO"
// index parameter values; genIndexParams stringifies these into the
// per-index-type parameter map
nlist = 100
m = 4
nbits = 8
nprobe = 8
sliceSize = 4
efConstruction = 200
ef = 200
edgeSize = 10
epsilon = 0.1
maxSearchEdges = 50
)
// ---------- unittest util functions ----------
// functions of init meta and generate meta
type vecFieldParam struct {
......@@ -222,6 +263,69 @@ func genIndexBinarySet() ([][]byte, error) {
return bytesSet, nil
}
// loadIndexForSegment prepares insert binlogs and an index of the given
// indexType/metricType for segmentID, loads the segment as a sealed segment
// through node.loader, and verifies that the loaded segment exposes vector
// field info for simpleVecField.
//
// msgLength is the number of rows generated for both the binlogs and the
// index. The segment is created under defaultCollectionID and
// defaultPartitionID. An error is returned if any step fails or if the
// loaded segment reports a nil vector field info.
func loadIndexForSegment(ctx context.Context, node *QueryNode, segmentID UniqueID, msgLength int, indexType string, metricType string) error {
	schema := genSimpleInsertDataSchema()

	// generate insert binlog; use the caller's segmentID so the binlogs, the
	// index files and the load request all refer to the same segment
	fieldBinlog, err := saveBinLog(ctx, defaultCollectionID, defaultPartitionID, segmentID, msgLength, schema)
	if err != nil {
		return fmt.Errorf("save binlog for segment %d: %w", segmentID, err)
	}

	// generate index file for segment
	indexPaths, err := generateAndSaveIndex(segmentID, msgLength, indexType, metricType)
	if err != nil {
		return fmt.Errorf("generate index for segment %d: %w", segmentID, err)
	}
	_, indexParams := genIndexParams(indexType, metricType)
	indexInfo := &querypb.VecFieldIndexInfo{
		FieldID:        simpleVecField.id,
		EnableIndex:    true,
		IndexName:      indexName,
		IndexID:        indexID,
		BuildID:        buildID,
		IndexParams:    funcutil.Map2KeyValuePair(indexParams),
		IndexFilePaths: indexPaths,
	}

	loader := node.loader
	req := &querypb.LoadSegmentsRequest{
		Base: &commonpb.MsgBase{
			MsgType: commonpb.MsgType_LoadSegments,
			MsgID:   rand.Int63(),
		},
		DstNodeID: 0,
		Schema:    schema,
		Infos: []*querypb.SegmentLoadInfo{
			{
				SegmentID:    segmentID,
				PartitionID:  defaultPartitionID,
				CollectionID: defaultCollectionID,
				BinlogPaths:  fieldBinlog,
				IndexInfos:   []*querypb.VecFieldIndexInfo{indexInfo},
			},
		},
	}

	if err = loader.loadSegment(req, segmentTypeSealed); err != nil {
		return fmt.Errorf("load segment %d: %w", segmentID, err)
	}

	// Confirm the segment landed in the loader's historical replica and
	// carries info for the vector field.
	segment, err := loader.historicalReplica.getSegmentByID(segmentID)
	if err != nil {
		return err
	}
	vecFieldInfo, err := segment.getVectorFieldInfo(simpleVecField.id)
	if err != nil {
		return err
	}
	if vecFieldInfo == nil {
		return fmt.Errorf("nil vecFieldInfo for segment %d, load index failed", segmentID)
	}
	return nil
}
func generateIndex(segmentID UniqueID) ([]string, error) {
indexParams := genSimpleIndexParams()
......@@ -303,6 +407,177 @@ func generateIndex(segmentID UniqueID) ([]string, error) {
return indexPaths, nil
}
// generateAndSaveIndex builds an index of the given indexType/metricType over
// msgLength random float vectors of dimension defaultDim, serializes it with
// the index file binlog codec, and saves every resulting blob to MinIO under
// "<segmentID>/<blob key>".
//
// It returns the storage paths of the saved blobs in serialization order, or
// the first error encountered.
//
// NOTE(review): the index returned by indexnode.NewCIndex is not released
// here; confirm whether it needs an explicit cleanup call.
func generateAndSaveIndex(segmentID UniqueID, msgLength int, indexType, metricType string) ([]string, error) {
	typeParams, indexParams := genIndexParams(indexType, metricType)

	// Random row data: msgLength vectors of defaultDim floats each. The
	// capacity is reserved up front to avoid repeated slice growth.
	var indexRowData []float32
	if msgLength > 0 {
		indexRowData = make([]float32, 0, msgLength*defaultDim)
	}
	for n := 0; n < msgLength; n++ {
		for i := 0; i < defaultDim; i++ {
			indexRowData = append(indexRowData, rand.Float32())
		}
	}

	index, err := indexnode.NewCIndex(typeParams, indexParams)
	if err != nil {
		return nil, err
	}

	err = index.BuildFloatVecIndexWithoutIds(indexRowData)
	if err != nil {
		return nil, err
	}

	option := &minioKV.Option{
		Address:           Params.MinioCfg.Address,
		AccessKeyID:       Params.MinioCfg.AccessKeyID,
		SecretAccessKeyID: Params.MinioCfg.SecretAccessKey,
		UseSSL:            Params.MinioCfg.UseSSL,
		BucketName:        Params.MinioCfg.BucketName,
		CreateBucket:      true,
	}

	kv, err := minioKV.NewMinIOKV(context.Background(), option)
	if err != nil {
		return nil, err
	}

	// save index to minio
	binarySet, err := index.Serialize()
	if err != nil {
		return nil, err
	}

	// serialize index params; tag the blobs with the caller's segmentID so
	// they agree with the storage path built below
	indexCodec := storage.NewIndexFileBinlogCodec()
	serializedIndexBlobs, err := indexCodec.Serialize(
		buildID,
		0,
		defaultCollectionID,
		defaultPartitionID,
		segmentID,
		simpleVecField.id,
		indexParams,
		indexName,
		indexID,
		binarySet,
	)
	if err != nil {
		return nil, err
	}

	indexPaths := make([]string, 0, len(serializedIndexBlobs))
	pathPrefix := strconv.Itoa(int(segmentID)) + "/"
	for _, blob := range serializedIndexBlobs {
		p := pathPrefix + blob.Key
		indexPaths = append(indexPaths, p)
		if err := kv.Save(p, string(blob.Value)); err != nil {
			return nil, err
		}
	}

	return indexPaths, nil
}
func genIndexParams(indexType, metricType string) (map[string]string, map[string]string) {
typeParams := make(map[string]string)
indexParams := make(map[string]string)
indexParams["index_type"] = indexType
indexParams["metric_type"] = metricType
if indexType == IndexFaissIDMap { // float vector
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexFaissIVFFlat {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["nlist"] = strconv.Itoa(nlist)
} else if indexType == IndexFaissIVFPQ {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["nlist"] = strconv.Itoa(nlist)
indexParams["m"] = strconv.Itoa(m)
indexParams["nbits"] = strconv.Itoa(nbits)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexFaissIVFSQ8 {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["nlist"] = strconv.Itoa(nlist)
indexParams["nbits"] = strconv.Itoa(nbits)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexFaissIVFSQ8H {
// TODO: enable gpu
} else if indexType == IndexNsg {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["nlist"] = strconv.Itoa(163)
indexParams["nprobe"] = strconv.Itoa(nprobe)
indexParams["knng"] = strconv.Itoa(20)
indexParams["search_length"] = strconv.Itoa(40)
indexParams["out_degree"] = strconv.Itoa(30)
indexParams["candidate_pool_size"] = strconv.Itoa(100)
} else if indexType == IndexHNSW {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["M"] = strconv.Itoa(16)
indexParams["efConstruction"] = strconv.Itoa(efConstruction)
//indexParams["ef"] = strconv.Itoa(ef)
} else if indexType == IndexRHNSWFlat {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["m"] = strconv.Itoa(16)
indexParams["efConstruction"] = strconv.Itoa(efConstruction)
indexParams["ef"] = strconv.Itoa(ef)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexRHNSWPQ {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["m"] = strconv.Itoa(16)
indexParams["efConstruction"] = strconv.Itoa(efConstruction)
indexParams["ef"] = strconv.Itoa(ef)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
indexParams["PQM"] = strconv.Itoa(8)
} else if indexType == IndexRHNSWSQ {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["m"] = strconv.Itoa(16)
indexParams["efConstruction"] = strconv.Itoa(efConstruction)
indexParams["ef"] = strconv.Itoa(ef)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexANNOY {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["n_trees"] = strconv.Itoa(4)
indexParams["search_k"] = strconv.Itoa(100)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexNGTPANNG {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["edge_size"] = strconv.Itoa(edgeSize)
indexParams["epsilon"] = fmt.Sprint(epsilon)
indexParams["max_search_edges"] = strconv.Itoa(maxSearchEdges)
indexParams["forcedly_pruned_edge_size"] = strconv.Itoa(60)
indexParams["selectively_pruned_edge_size"] = strconv.Itoa(30)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexNGTONNG {
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["edge_size"] = strconv.Itoa(edgeSize)
indexParams["epsilon"] = fmt.Sprint(epsilon)
indexParams["max_search_edges"] = strconv.Itoa(maxSearchEdges)
indexParams["outgoing_edge_size"] = strconv.Itoa(5)
indexParams["incoming_edge_size"] = strconv.Itoa(40)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexFaissBinIVFFlat { // binary vector
indexParams["dim"] = strconv.Itoa(defaultDim)
indexParams["nlist"] = strconv.Itoa(nlist)
indexParams["m"] = strconv.Itoa(m)
indexParams["nbits"] = strconv.Itoa(nbits)
indexParams["SLICE_SIZE"] = strconv.Itoa(sliceSize)
} else if indexType == IndexFaissBinIDMap {
indexParams["dim"] = strconv.Itoa(defaultDim)
} else {
panic("")
}
return typeParams, indexParams
}
func genSimpleSegCoreSchema() *schemapb.CollectionSchema {
fieldVec := genFloatVectorField(simpleVecField)
fieldInt := genConstantField(simpleConstField)
......@@ -866,6 +1141,18 @@ func genSimpleSealedSegment() (*Segment, error) {
defaultMsgLength)
}
// genSealedSegmentWithMsgLength calls genSealedSegment with the simple
// seg-core and insert-data schemas, the default collection, partition,
// segment and DML channel identifiers, and the given msgLength.
func genSealedSegmentWithMsgLength(msgLength int) (*Segment, error) {
	return genSealedSegment(
		genSimpleSegCoreSchema(),
		genSimpleInsertDataSchema(),
		defaultCollectionID,
		defaultPartitionID,
		defaultSegmentID,
		defaultDMLChannel,
		msgLength,
	)
}
func genSimpleReplica() (ReplicaInterface, error) {
kv, err := genEtcdKV()
if err != nil {
......@@ -990,13 +1277,13 @@ func genSimpleDSL() (string, error) {
return genDSL(schema, defaultNProb, defaultTopK, defaultRoundDecimal)
}
func genSimplePlaceHolderGroup() ([]byte, error) {
func genPlaceHolderGroup(nq int) ([]byte, error) {
placeholderValue := &milvuspb.PlaceholderValue{
Tag: "$0",
Type: milvuspb.PlaceholderType_FloatVector,
Values: make([][]byte, 0),
}
for i := 0; i < int(defaultTopK); i++ {
for i := 0; i < nq; i++ {
var vec = make([]float32, defaultDim)
for j := 0; j < defaultDim; j++ {
vec[j] = rand.Float32()
......@@ -1021,6 +1308,10 @@ func genSimplePlaceHolderGroup() ([]byte, error) {
return placeGroupByte, nil
}
// genSimplePlaceHolderGroup returns the result of genPlaceHolderGroup called
// with defaultNQ.
func genSimplePlaceHolderGroup() ([]byte, error) {
	placeholderGroup, err := genPlaceHolderGroup(defaultNQ)
	return placeholderGroup, err
}
func genSimpleSearchPlanAndRequests() (*SearchPlan, []*searchRequest, error) {
schema := genSimpleSegCoreSchema()
collection := newCollection(defaultCollectionID, schema)
......@@ -1111,8 +1402,8 @@ func genSimpleRetrievePlan() (*RetrievePlan, error) {
return plan, err
}
func genSimpleSearchRequest() (*internalpb.SearchRequest, error) {
placeHolder, err := genSimplePlaceHolderGroup()
func genSearchRequest(nq int) (*internalpb.SearchRequest, error) {
placeHolder, err := genPlaceHolderGroup(nq)
if err != nil {
return nil, err
}
......@@ -1130,6 +1421,10 @@ func genSimpleSearchRequest() (*internalpb.SearchRequest, error) {
}, nil
}
// genSimpleSearchRequest returns the result of genSearchRequest called with
// defaultNQ.
func genSimpleSearchRequest() (*internalpb.SearchRequest, error) {
	searchReq, err := genSearchRequest(defaultNQ)
	return searchReq, err
}
func genSimpleRetrieveRequest() (*internalpb.RetrieveRequest, error) {
expr, err := genSimpleRetrievePlanExpr()
if err != nil {
......@@ -1149,6 +1444,19 @@ func genSimpleRetrieveRequest() (*internalpb.RetrieveRequest, error) {
}, nil
}
// genSearchMsg wraps the request produced by genSearchRequest(nq) in a
// msgstream.SearchMsg whose base message comes from genMsgStreamBaseMsg,
// and calls SetTimeRecorder on it before returning. Any error from
// genSearchRequest is returned unchanged with a nil message.
func genSearchMsg(nq int) (*msgstream.SearchMsg, error) {
	searchReq, err := genSearchRequest(nq)
	if err != nil {
		return nil, err
	}

	searchMsg := new(msgstream.SearchMsg)
	searchMsg.BaseMsg = genMsgStreamBaseMsg()
	searchMsg.SearchRequest = *searchReq
	searchMsg.SetTimeRecorder()

	return searchMsg, nil
}
func genSimpleSearchMsg() (*msgstream.SearchMsg, error) {
req, err := genSimpleSearchRequest()
if err != nil {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册