未验证 提交 18cad3a1 编写于 作者: S smellthemoon 提交者: GitHub

Optimization of delete and insert (#20990)

Signed-off-by: Nlixinguo <xinguo.li@zilliz.com>
Signed-off-by: Nlixinguo <xinguo.li@zilliz.com>
Co-authored-by: Nlixinguo <xinguo.li@zilliz.com>
上级 eb7ef01b
......@@ -1983,7 +1983,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
ctx: ctx,
Condition: NewTaskCondition(ctx),
// req: request,
BaseInsertTask: BaseInsertTask{
insertMsg: &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: request.HashKeys,
},
......@@ -2007,8 +2007,8 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
chTicker: node.chTicker,
}
if len(it.PartitionName) <= 0 {
it.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()
if len(it.insertMsg.PartitionName) <= 0 {
it.insertMsg.PartitionName = Params.CommonCfg.DefaultPartitionName.GetValue()
}
constructFailedResponse := func(err error) *milvuspb.MutationResult {
......@@ -2045,7 +2045,6 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
log.Debug("Detail of insert request in Proxy",
zap.String("role", typeutil.ProxyRole),
zap.Int64("msgID", it.Base.MsgID),
zap.Uint64("BeginTS", it.BeginTs()),
zap.Uint64("EndTS", it.EndTs()),
zap.String("db", request.DbName),
......@@ -2082,6 +2081,7 @@ func (node *Proxy) Insert(ctx context.Context, request *milvuspb.InsertRequest)
metrics.ProxyInsertVectors.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Add(float64(successCnt))
metrics.ProxyMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel).Observe(float64(tr.ElapseSpan().Milliseconds()))
metrics.ProxyCollectionMutationLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.InsertLabel, request.CollectionName).Observe(float64(tr.ElapseSpan().Milliseconds()))
log.Debug("lxg debug", zap.Any("insertResult", it.result))
return it.result, nil
}
......@@ -2112,7 +2112,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
ctx: ctx,
Condition: NewTaskCondition(ctx),
deleteExpr: request.Expr,
BaseDeleteTask: BaseDeleteTask{
deleteMsg: &BaseDeleteTask{
BaseMsg: msgstream.BaseMsg{
HashValues: request.HashKeys,
},
......@@ -2154,7 +2154,7 @@ func (node *Proxy) Delete(ctx context.Context, request *milvuspb.DeleteRequest)
log.Debug("Detail of delete request in Proxy",
zap.String("role", typeutil.ProxyRole),
zap.Uint64("timestamp", dt.Base.Timestamp),
zap.Uint64("timestamp", dt.deleteMsg.Base.Timestamp),
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.String("partition", request.PartitionName),
......
// Licensed to the LF AI & Data foundation under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package proxy
import (
"context"
"github.com/milvus-io/milvus-proto/go-api/commonpb"
"github.com/milvus-io/milvus-proto/go-api/milvuspb"
"github.com/milvus-io/milvus-proto/go-api/schemapb"
"github.com/milvus-io/milvus/internal/allocator"
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
func assignSegmentID(ctx context.Context, insertMsg *msgstream.InsertMsg, result *milvuspb.MutationResult, channelNames []string, idAllocator *allocator.IDAllocator, segIDAssigner *segIDAssigner) (*msgstream.MsgPack, error) {
threshold := Params.PulsarCfg.MaxMessageSize.GetAsInt()
log.Debug("assign segmentid", zap.Int("threshold", threshold))
msgPack := &msgstream.MsgPack{
BeginTs: insertMsg.BeginTs(),
EndTs: insertMsg.EndTs(),
}
// generate hash value for every primary key
if len(insertMsg.HashValues) != 0 {
log.Warn("the hashvalues passed through client is not supported now, and will be overwritten")
}
insertMsg.HashValues = typeutil.HashPK2Channels(result.IDs, channelNames)
// groupedHashKeys represents the dmChannel index
channel2RowOffsets := make(map[string][]int) // channelName to count
channelMaxTSMap := make(map[string]Timestamp) // channelName to max Timestamp
// assert len(it.hashValues) < maxInt
for offset, channelID := range insertMsg.HashValues {
channelName := channelNames[channelID]
if _, ok := channel2RowOffsets[channelName]; !ok {
channel2RowOffsets[channelName] = []int{}
}
channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset)
if _, ok := channelMaxTSMap[channelName]; !ok {
channelMaxTSMap[channelName] = typeutil.ZeroTimestamp
}
ts := insertMsg.Timestamps[offset]
if channelMaxTSMap[channelName] < ts {
channelMaxTSMap[channelName] = ts
}
}
// pre-alloc msg id by batch
var idBegin, idEnd int64
var err error
// fetch next id, if not id available, fetch next batch
// lazy fetch, get first batch after first getMsgID called
getMsgID := func() (int64, error) {
if idBegin == idEnd {
err = retry.Do(ctx, func() error {
idBegin, idEnd, err = idAllocator.Alloc(16)
return err
})
if err != nil {
log.Error("failed to allocate msg id", zap.Int64("base.MsgID", insertMsg.Base.MsgID), zap.Error(err))
return 0, err
}
}
result := idBegin
idBegin++
return result, nil
}
// create empty insert message
createInsertMsg := func(segmentID UniqueID, channelName string, msgID int64) *msgstream.InsertMsg {
insertReq := internalpb.InsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
commonpbutil.WithMsgID(msgID),
commonpbutil.WithTimeStamp(insertMsg.BeginTimestamp), // entity's timestamp was set to equal it.BeginTimestamp in preExecute()
commonpbutil.WithSourceID(insertMsg.Base.SourceID),
),
CollectionID: insertMsg.CollectionID,
PartitionID: insertMsg.PartitionID,
CollectionName: insertMsg.CollectionName,
PartitionName: insertMsg.PartitionName,
SegmentID: segmentID,
ShardName: channelName,
Version: internalpb.InsertDataVersion_ColumnBased,
}
insertReq.FieldsData = make([]*schemapb.FieldData, len(insertMsg.GetFieldsData()))
msg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: ctx,
},
InsertRequest: insertReq,
}
return msg
}
// repack the row data corresponding to the offset to insertMsg
getInsertMsgsBySegmentID := func(segmentID UniqueID, rowOffsets []int, channelName string, maxMessageSize int) ([]msgstream.TsMsg, error) {
repackedMsgs := make([]msgstream.TsMsg, 0)
requestSize := 0
msgID, err := getMsgID()
if err != nil {
return nil, err
}
msg := createInsertMsg(segmentID, channelName, msgID)
for _, offset := range rowOffsets {
curRowMessageSize, err := typeutil.EstimateEntitySize(insertMsg.GetFieldsData(), offset)
if err != nil {
return nil, err
}
// if insertMsg's size is greater than the threshold, split into multiple insertMsgs
if requestSize+curRowMessageSize >= maxMessageSize {
repackedMsgs = append(repackedMsgs, msg)
msgID, err = getMsgID()
if err != nil {
return nil, err
}
msg = createInsertMsg(segmentID, channelName, msgID)
requestSize = 0
}
typeutil.AppendFieldData(msg.FieldsData, insertMsg.GetFieldsData(), int64(offset))
msg.HashValues = append(msg.HashValues, insertMsg.HashValues[offset])
msg.Timestamps = append(msg.Timestamps, insertMsg.Timestamps[offset])
msg.RowIDs = append(msg.RowIDs, insertMsg.RowIDs[offset])
msg.NumRows++
requestSize += curRowMessageSize
}
repackedMsgs = append(repackedMsgs, msg)
return repackedMsgs, nil
}
// get allocated segmentID info for every dmChannel and repack insertMsgs for every segmentID
for channelName, rowOffsets := range channel2RowOffsets {
assignedSegmentInfos, err := segIDAssigner.GetSegmentID(insertMsg.CollectionID, insertMsg.PartitionID, channelName, uint32(len(rowOffsets)), channelMaxTSMap[channelName])
if err != nil {
log.Error("allocate segmentID for insert data failed", zap.Int64("collectionID", insertMsg.CollectionID), zap.String("channel name", channelName),
zap.Int("allocate count", len(rowOffsets)),
zap.Error(err))
return nil, err
}
startPos := 0
for segmentID, count := range assignedSegmentInfos {
subRowOffsets := rowOffsets[startPos : startPos+int(count)]
insertMsgs, err := getInsertMsgsBySegmentID(segmentID, subRowOffsets, channelName, threshold)
if err != nil {
log.Error("repack insert data to insert msgs failed", zap.Int64("collectionID", insertMsg.CollectionID),
zap.Error(err))
return nil, err
}
msgPack.Msgs = append(msgPack.Msgs, insertMsgs...)
startPos += int(count)
}
}
return msgPack, nil
}
......@@ -27,7 +27,7 @@ type BaseDeleteTask = msgstream.DeleteMsg
type deleteTask struct {
Condition
BaseDeleteTask
deleteMsg *BaseDeleteTask
ctx context.Context
deleteExpr string
//req *milvuspb.DeleteRequest
......@@ -46,15 +46,15 @@ func (dt *deleteTask) TraceCtx() context.Context {
}
func (dt *deleteTask) ID() UniqueID {
return dt.Base.MsgID
return dt.deleteMsg.Base.MsgID
}
func (dt *deleteTask) SetID(uid UniqueID) {
dt.Base.MsgID = uid
dt.deleteMsg.Base.MsgID = uid
}
func (dt *deleteTask) Type() commonpb.MsgType {
return dt.Base.MsgType
return dt.deleteMsg.Base.MsgType
}
func (dt *deleteTask) Name() string {
......@@ -62,19 +62,19 @@ func (dt *deleteTask) Name() string {
}
func (dt *deleteTask) BeginTs() Timestamp {
return dt.Base.Timestamp
return dt.deleteMsg.Base.Timestamp
}
func (dt *deleteTask) EndTs() Timestamp {
return dt.Base.Timestamp
return dt.deleteMsg.Base.Timestamp
}
func (dt *deleteTask) SetTs(ts Timestamp) {
dt.Base.Timestamp = ts
dt.deleteMsg.Base.Timestamp = ts
}
func (dt *deleteTask) OnEnqueue() error {
dt.DeleteRequest.Base = commonpbutil.NewMsgBase()
dt.deleteMsg.Base = commonpbutil.NewMsgBase()
return nil
}
......@@ -99,7 +99,7 @@ func (dt *deleteTask) getPChanStats() (map[pChan]pChanStatistics, error) {
}
func (dt *deleteTask) getChannels() ([]pChan, error) {
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.CollectionName)
collID, err := globalMetaCache.GetCollectionID(dt.ctx, dt.deleteMsg.CollectionName)
if err != nil {
return nil, err
}
......@@ -154,8 +154,8 @@ func getPrimaryKeysFromExpr(schema *schemapb.CollectionSchema, expr string) (res
}
func (dt *deleteTask) PreExecute(ctx context.Context) error {
dt.Base.MsgType = commonpb.MsgType_Delete
dt.Base.SourceID = paramtable.GetNodeID()
dt.deleteMsg.Base.MsgType = commonpb.MsgType_Delete
dt.deleteMsg.Base.SourceID = paramtable.GetNodeID()
dt.result = &milvuspb.MutationResult{
Status: &commonpb.Status{
......@@ -167,7 +167,7 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
Timestamp: dt.BeginTs(),
}
collName := dt.CollectionName
collName := dt.deleteMsg.CollectionName
if err := validateCollectionName(collName); err != nil {
log.Info("Invalid collection name", zap.String("collectionName", collName), zap.Error(err))
return err
......@@ -177,12 +177,12 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
log.Info("Failed to get collection id", zap.String("collectionName", collName), zap.Error(err))
return err
}
dt.DeleteRequest.CollectionID = collID
dt.deleteMsg.CollectionID = collID
dt.collectionID = collID
// If partitionName is not empty, partitionID will be set.
if len(dt.PartitionName) > 0 {
partName := dt.PartitionName
if len(dt.deleteMsg.PartitionName) > 0 {
partName := dt.deleteMsg.PartitionName
if err := validatePartitionTag(partName, true); err != nil {
log.Info("Invalid partition name", zap.String("partitionName", partName), zap.Error(err))
return err
......@@ -192,9 +192,9 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
log.Info("Failed to get partition id", zap.String("collectionName", collName), zap.String("partitionName", partName), zap.Error(err))
return err
}
dt.DeleteRequest.PartitionID = partID
dt.deleteMsg.PartitionID = partID
} else {
dt.DeleteRequest.PartitionID = common.InvalidPartitionID
dt.deleteMsg.PartitionID = common.InvalidPartitionID
}
schema, err := globalMetaCache.GetCollectionSchema(ctx, collName)
......@@ -211,17 +211,17 @@ func (dt *deleteTask) PreExecute(ctx context.Context) error {
return err
}
dt.DeleteRequest.NumRows = numRow
dt.DeleteRequest.PrimaryKeys = primaryKeys
log.Debug("get primary keys from expr", zap.Int64("len of primary keys", dt.DeleteRequest.NumRows))
dt.deleteMsg.NumRows = numRow
dt.deleteMsg.PrimaryKeys = primaryKeys
log.Debug("get primary keys from expr", zap.Int64("len of primary keys", dt.deleteMsg.NumRows))
// set result
dt.result.IDs = primaryKeys
dt.result.DeleteCnt = dt.DeleteRequest.NumRows
dt.result.DeleteCnt = dt.deleteMsg.NumRows
dt.Timestamps = make([]uint64, numRow)
for index := range dt.Timestamps {
dt.Timestamps[index] = dt.BeginTs()
dt.deleteMsg.Timestamps = make([]uint64, numRow)
for index := range dt.deleteMsg.Timestamps {
dt.deleteMsg.Timestamps[index] = dt.BeginTs()
}
return nil
......@@ -233,7 +233,7 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute delete %d", dt.ID()))
collID := dt.DeleteRequest.CollectionID
collID := dt.deleteMsg.CollectionID
stream, err := dt.chMgr.getOrCreateDmlStream(collID)
if err != nil {
return err
......@@ -247,10 +247,10 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
dt.result.Status.Reason = err.Error()
return err
}
dt.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames)
dt.deleteMsg.HashValues = typeutil.HashPK2Channels(dt.result.IDs, channelNames)
log.Debug("send delete request to virtual channels",
zap.String("collection", dt.GetCollectionName()),
zap.String("collection", dt.deleteMsg.GetCollectionName()),
zap.Int64("collection_id", collID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", dt.ID()))
......@@ -258,19 +258,19 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
tr.Record("get vchannels")
// repack delete msg by dmChannel
result := make(map[uint32]msgstream.TsMsg)
collectionName := dt.CollectionName
collectionID := dt.CollectionID
partitionID := dt.PartitionID
partitionName := dt.PartitionName
proxyID := dt.Base.SourceID
for index, key := range dt.HashValues {
ts := dt.Timestamps[index]
collectionName := dt.deleteMsg.CollectionName
collectionID := dt.deleteMsg.CollectionID
partitionID := dt.deleteMsg.PartitionID
partitionName := dt.deleteMsg.PartitionName
proxyID := dt.deleteMsg.Base.SourceID
for index, key := range dt.deleteMsg.HashValues {
ts := dt.deleteMsg.Timestamps[index]
_, ok := result[key]
if !ok {
sliceRequest := internalpb.DeleteRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Delete),
commonpbutil.WithMsgID(dt.Base.MsgID),
commonpbutil.WithMsgID(dt.deleteMsg.Base.MsgID),
commonpbutil.WithTimeStamp(ts),
commonpbutil.WithSourceID(proxyID),
),
......@@ -289,9 +289,9 @@ func (dt *deleteTask) Execute(ctx context.Context) (err error) {
result[key] = deleteMsg
}
curMsg := result[key].(*msgstream.DeleteMsg)
curMsg.HashValues = append(curMsg.HashValues, dt.HashValues[index])
curMsg.Timestamps = append(curMsg.Timestamps, dt.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, dt.PrimaryKeys, index)
curMsg.HashValues = append(curMsg.HashValues, dt.deleteMsg.HashValues[index])
curMsg.Timestamps = append(curMsg.Timestamps, dt.deleteMsg.Timestamps[index])
typeutil.AppendIDs(curMsg.PrimaryKeys, dt.deleteMsg.PrimaryKeys, index)
curMsg.NumRows++
}
......
......@@ -12,21 +12,17 @@ import (
"github.com/milvus-io/milvus/internal/log"
"github.com/milvus-io/milvus/internal/metrics"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/internalpb"
"github.com/milvus-io/milvus/internal/util/commonpbutil"
"github.com/milvus-io/milvus/internal/util/paramtable"
"github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/timerecord"
"github.com/milvus-io/milvus/internal/util/trace"
"github.com/milvus-io/milvus/internal/util/typeutil"
"go.uber.org/zap"
)
type insertTask struct {
BaseInsertTask
// req *milvuspb.InsertRequest
Condition
ctx context.Context
insertMsg *BaseInsertTask
ctx context.Context
result *milvuspb.MutationResult
idAllocator *allocator.IDAllocator
......@@ -44,11 +40,11 @@ func (it *insertTask) TraceCtx() context.Context {
}
func (it *insertTask) ID() UniqueID {
return it.Base.MsgID
return it.insertMsg.Base.MsgID
}
func (it *insertTask) SetID(uid UniqueID) {
it.Base.MsgID = uid
it.insertMsg.Base.MsgID = uid
}
func (it *insertTask) Name() string {
......@@ -56,20 +52,20 @@ func (it *insertTask) Name() string {
}
func (it *insertTask) Type() commonpb.MsgType {
return it.Base.MsgType
return it.insertMsg.Base.MsgType
}
func (it *insertTask) BeginTs() Timestamp {
return it.BeginTimestamp
return it.insertMsg.BeginTimestamp
}
func (it *insertTask) SetTs(ts Timestamp) {
it.BeginTimestamp = ts
it.EndTimestamp = ts
it.insertMsg.BeginTimestamp = ts
it.insertMsg.EndTimestamp = ts
}
func (it *insertTask) EndTs() Timestamp {
return it.EndTimestamp
return it.insertMsg.EndTimestamp
}
func (it *insertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
......@@ -93,7 +89,7 @@ func (it *insertTask) getPChanStats() (map[pChan]pChanStatistics, error) {
}
func (it *insertTask) getChannels() ([]pChan, error) {
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.CollectionName)
collID, err := globalMetaCache.GetCollectionID(it.ctx, it.insertMsg.CollectionName)
if err != nil {
return nil, err
}
......@@ -104,71 +100,6 @@ func (it *insertTask) OnEnqueue() error {
return nil
}
func (it *insertTask) checkLengthOfFieldsData() error {
neededFieldsNum := 0
for _, field := range it.schema.Fields {
if !field.AutoID {
neededFieldsNum++
}
}
if len(it.FieldsData) < neededFieldsNum {
return errFieldsLessThanNeeded(len(it.FieldsData), neededFieldsNum)
}
return nil
}
func (it *insertTask) checkPrimaryFieldData() error {
rowNums := uint32(it.NRows())
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
if it.NRows() <= 0 {
return errNumRowsLessThanOrEqualToZero(rowNums)
}
if err := it.checkLengthOfFieldsData(); err != nil {
return err
}
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(it.schema)
if err != nil {
log.Error("get primary field schema failed", zap.String("collectionName", it.CollectionName), zap.Any("schema", it.schema), zap.Error(err))
return err
}
// get primaryFieldData whether autoID is true or not
var primaryFieldData *schemapb.FieldData
if !primaryFieldSchema.AutoID {
primaryFieldData, err = typeutil.GetPrimaryFieldData(it.GetFieldsData(), primaryFieldSchema)
if err != nil {
log.Error("get primary field data failed", zap.String("collectionName", it.CollectionName), zap.Error(err))
return err
}
} else {
// check primary key data not exist
if typeutil.IsPrimaryFieldDataExist(it.GetFieldsData(), primaryFieldSchema) {
return fmt.Errorf("can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name)
}
// if autoID == true, currently only support autoID for int64 PrimaryField
primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, it.RowIDs)
if err != nil {
log.Error("generate primary field data failed when autoID == true", zap.String("collectionName", it.CollectionName), zap.Error(err))
return err
}
// if autoID == true, set the primary field data
it.FieldsData = append(it.FieldsData, primaryFieldData)
}
// parse primaryFieldData to result.IDs, and as returned primary keys
it.result.IDs, err = parsePrimaryFieldData2IDs(primaryFieldData)
if err != nil {
log.Error("parse primary field data to IDs failed", zap.String("collectionName", it.CollectionName), zap.Error(err))
return err
}
return nil
}
func (it *insertTask) PreExecute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-PreExecute")
defer sp.Finish()
......@@ -183,13 +114,13 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
Timestamp: it.EndTs(),
}
collectionName := it.CollectionName
collectionName := it.insertMsg.CollectionName
if err := validateCollectionName(collectionName); err != nil {
log.Error("valid collection name failed", zap.String("collectionName", collectionName), zap.Error(err))
return err
}
partitionTag := it.PartitionName
partitionTag := it.insertMsg.PartitionName
if err := validatePartitionTag(partitionTag, true); err != nil {
log.Error("valid partition name failed", zap.String("partition name", partitionTag), zap.Error(err))
return err
......@@ -202,7 +133,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
it.schema = collSchema
rowNums := uint32(it.NRows())
rowNums := uint32(it.insertMsg.NRows())
// set insertTask.rowIDs
var rowIDBegin UniqueID
var rowIDEnd UniqueID
......@@ -210,16 +141,16 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
rowIDBegin, rowIDEnd, _ = it.idAllocator.Alloc(rowNums)
metrics.ProxyApplyPrimaryKeyLatency.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Observe(float64(tr.ElapseSpan().Milliseconds()))
it.RowIDs = make([]UniqueID, rowNums)
it.insertMsg.RowIDs = make([]UniqueID, rowNums)
for i := rowIDBegin; i < rowIDEnd; i++ {
offset := i - rowIDBegin
it.RowIDs[offset] = i
it.insertMsg.RowIDs[offset] = i
}
// set insertTask.timeStamps
rowNum := it.NRows()
it.Timestamps = make([]uint64, rowNum)
for index := range it.Timestamps {
it.Timestamps[index] = it.BeginTimestamp
rowNum := it.insertMsg.NRows()
it.insertMsg.Timestamps = make([]uint64, rowNum)
for index := range it.insertMsg.Timestamps {
it.insertMsg.Timestamps[index] = it.insertMsg.BeginTimestamp
}
// set result.SuccIndex
......@@ -231,7 +162,8 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
// check primaryFieldData whether autoID is true or not
// set rowIDs as primary data if autoID == true
err = it.checkPrimaryFieldData()
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
it.result.IDs, err = checkPrimaryFieldData(it.schema, it.insertMsg)
log := log.Ctx(ctx).With(zap.String("collectionName", collectionName))
if err != nil {
log.Error("check primary field data and hash primary key failed",
......@@ -240,7 +172,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
// set field ID to insert field data
err = fillFieldIDBySchema(it.GetFieldsData(), collSchema)
err = fillFieldIDBySchema(it.insertMsg.GetFieldsData(), collSchema)
if err != nil {
log.Error("set fieldID to fieldData failed",
zap.Error(err))
......@@ -248,7 +180,7 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
}
// check that all field's number rows are equal
if err = it.CheckAligned(); err != nil {
if err = it.insertMsg.CheckAligned(); err != nil {
log.Error("field data is not aligned",
zap.Error(err))
return err
......@@ -259,160 +191,6 @@ func (it *insertTask) PreExecute(ctx context.Context) error {
return nil
}
func (it *insertTask) assignSegmentID(channelNames []string) (*msgstream.MsgPack, error) {
threshold := Params.PulsarCfg.MaxMessageSize.GetAsInt()
log.Debug("assign segmentid", zap.Int("threshold", threshold))
result := &msgstream.MsgPack{
BeginTs: it.BeginTs(),
EndTs: it.EndTs(),
}
// generate hash value for every primary key
if len(it.HashValues) != 0 {
log.Warn("the hashvalues passed through client is not supported now, and will be overwritten")
}
it.HashValues = typeutil.HashPK2Channels(it.result.IDs, channelNames)
// groupedHashKeys represents the dmChannel index
channel2RowOffsets := make(map[string][]int) // channelName to count
channelMaxTSMap := make(map[string]Timestamp) // channelName to max Timestamp
// assert len(it.hashValues) < maxInt
for offset, channelID := range it.HashValues {
channelName := channelNames[channelID]
if _, ok := channel2RowOffsets[channelName]; !ok {
channel2RowOffsets[channelName] = []int{}
}
channel2RowOffsets[channelName] = append(channel2RowOffsets[channelName], offset)
if _, ok := channelMaxTSMap[channelName]; !ok {
channelMaxTSMap[channelName] = typeutil.ZeroTimestamp
}
ts := it.Timestamps[offset]
if channelMaxTSMap[channelName] < ts {
channelMaxTSMap[channelName] = ts
}
}
// pre-alloc msg id by batch
var idBegin, idEnd int64
var err error
// fetch next id, if not id available, fetch next batch
// lazy fetch, get first batch after first getMsgID called
getMsgID := func() (int64, error) {
if idBegin == idEnd {
err = retry.Do(it.ctx, func() error {
idBegin, idEnd, err = it.idAllocator.Alloc(16)
return err
})
if err != nil {
log.Error("failed to allocate msg id", zap.Int64("base.MsgID", it.Base.MsgID), zap.Error(err))
return 0, err
}
}
result := idBegin
idBegin++
return result, nil
}
// create empty insert message
createInsertMsg := func(segmentID UniqueID, channelName string, msgID int64) *msgstream.InsertMsg {
insertReq := internalpb.InsertRequest{
Base: commonpbutil.NewMsgBase(
commonpbutil.WithMsgType(commonpb.MsgType_Insert),
commonpbutil.WithMsgID(msgID),
commonpbutil.WithTimeStamp(it.BeginTimestamp), // entity's timestamp was set to equal it.BeginTimestamp in preExecute()
commonpbutil.WithSourceID(it.Base.SourceID),
),
CollectionID: it.CollectionID,
PartitionID: it.PartitionID,
CollectionName: it.CollectionName,
PartitionName: it.PartitionName,
SegmentID: segmentID,
ShardName: channelName,
Version: internalpb.InsertDataVersion_ColumnBased,
}
insertReq.FieldsData = make([]*schemapb.FieldData, len(it.GetFieldsData()))
insertMsg := &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
Ctx: it.TraceCtx(),
},
InsertRequest: insertReq,
}
return insertMsg
}
// repack the row data corresponding to the offset to insertMsg
getInsertMsgsBySegmentID := func(segmentID UniqueID, rowOffsets []int, channelName string, maxMessageSize int) ([]msgstream.TsMsg, error) {
repackedMsgs := make([]msgstream.TsMsg, 0)
requestSize := 0
msgID, err := getMsgID()
if err != nil {
return nil, err
}
insertMsg := createInsertMsg(segmentID, channelName, msgID)
for _, offset := range rowOffsets {
curRowMessageSize, err := typeutil.EstimateEntitySize(it.InsertRequest.GetFieldsData(), offset)
if err != nil {
return nil, err
}
// if insertMsg's size is greater than the threshold, split into multiple insertMsgs
if requestSize+curRowMessageSize >= maxMessageSize {
repackedMsgs = append(repackedMsgs, insertMsg)
msgID, err = getMsgID()
if err != nil {
return nil, err
}
insertMsg = createInsertMsg(segmentID, channelName, msgID)
requestSize = 0
}
typeutil.AppendFieldData(insertMsg.FieldsData, it.GetFieldsData(), int64(offset))
insertMsg.HashValues = append(insertMsg.HashValues, it.HashValues[offset])
insertMsg.Timestamps = append(insertMsg.Timestamps, it.Timestamps[offset])
insertMsg.RowIDs = append(insertMsg.RowIDs, it.RowIDs[offset])
insertMsg.NumRows++
requestSize += curRowMessageSize
}
repackedMsgs = append(repackedMsgs, insertMsg)
return repackedMsgs, nil
}
// get allocated segmentID info for every dmChannel and repack insertMsgs for every segmentID
for channelName, rowOffsets := range channel2RowOffsets {
assignedSegmentInfos, err := it.segIDAssigner.GetSegmentID(it.CollectionID, it.PartitionID, channelName, uint32(len(rowOffsets)), channelMaxTSMap[channelName])
if err != nil {
log.Error("allocate segmentID for insert data failed",
zap.Int64("collectionID", it.CollectionID),
zap.String("channel name", channelName),
zap.Int("allocate count", len(rowOffsets)),
zap.Error(err))
return nil, err
}
startPos := 0
for segmentID, count := range assignedSegmentInfos {
subRowOffsets := rowOffsets[startPos : startPos+int(count)]
insertMsgs, err := getInsertMsgsBySegmentID(segmentID, subRowOffsets, channelName, threshold)
if err != nil {
log.Error("repack insert data to insert msgs failed",
zap.Int64("collectionID", it.CollectionID),
zap.Error(err))
return nil, err
}
result.Msgs = append(result.Msgs, insertMsgs...)
startPos += int(count)
}
}
return result, nil
}
func (it *insertTask) Execute(ctx context.Context) error {
sp, ctx := trace.StartSpanFromContextWithOperationName(it.ctx, "Proxy-Insert-Execute")
defer sp.Finish()
......@@ -420,15 +198,15 @@ func (it *insertTask) Execute(ctx context.Context) error {
tr := timerecord.NewTimeRecorder(fmt.Sprintf("proxy execute insert %d", it.ID()))
defer tr.Elapse("insert execute done")
collectionName := it.CollectionName
collectionName := it.insertMsg.CollectionName
collID, err := globalMetaCache.GetCollectionID(ctx, collectionName)
if err != nil {
return err
}
it.CollectionID = collID
it.insertMsg.CollectionID = collID
var partitionID UniqueID
if len(it.PartitionName) > 0 {
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.PartitionName)
if len(it.insertMsg.PartitionName) > 0 {
partitionID, err = globalMetaCache.GetPartitionID(ctx, collectionName, it.insertMsg.PartitionName)
if err != nil {
return err
}
......@@ -438,7 +216,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
return err
}
}
it.PartitionID = partitionID
it.insertMsg.PartitionID = partitionID
tr.Record("get collection id & partition id from cache")
stream, err := it.chMgr.getOrCreateDmlStream(collID)
......@@ -458,15 +236,16 @@ func (it *insertTask) Execute(ctx context.Context) error {
}
log.Ctx(ctx).Debug("send insert request to virtual channels",
zap.String("collection", it.GetCollectionName()),
zap.String("partition", it.GetPartitionName()),
zap.String("collection", it.insertMsg.GetCollectionName()),
zap.String("partition", it.insertMsg.GetPartitionName()),
zap.Int64("collection_id", collID),
zap.Int64("partition_id", partitionID),
zap.Strings("virtual_channels", channelNames),
zap.Int64("task_id", it.ID()))
// assign segmentID for insert data and repack data by segmentID
msgPack, err := it.assignSegmentID(channelNames)
var msgPack *msgstream.MsgPack
msgPack, err = assignSegmentID(it.TraceCtx(), it.insertMsg, it.result, channelNames, it.idAllocator, it.segIDAssigner)
if err != nil {
log.Error("assign segmentID and repack insert data failed",
zap.Int64("collectionID", collID),
......@@ -477,7 +256,7 @@ func (it *insertTask) Execute(ctx context.Context) error {
}
log.Debug("assign segmentID for insert data success",
zap.Int64("collectionID", collID),
zap.String("collectionName", it.CollectionName))
zap.String("collectionName", it.insertMsg.CollectionName))
tr.Record("assign segment id")
err = stream.Produce(msgPack)
if err != nil {
......
......@@ -9,140 +9,12 @@ import (
"github.com/stretchr/testify/assert"
)
func TestInsertTask_checkLengthOfFieldsData(t *testing.T) {
var err error
// schema is empty, though won't happen in system
case1 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{},
},
BaseInsertTask: BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
DbName: "TestInsertTask_checkLengthOfFieldsData",
CollectionName: "TestInsertTask_checkLengthOfFieldsData",
PartitionName: "TestInsertTask_checkLengthOfFieldsData",
},
},
}
err = case1.checkLengthOfFieldsData()
assert.Equal(t, nil, err)
// schema has two fields, neither of them are autoID
case2 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
},
}
// passed fields is empty
// case2.BaseInsertTask = BaseInsertTask{
// InsertRequest: internalpb.InsertRequest{
// Base: &commonpb.MsgBase{
// MsgType: commonpb.MsgType_Insert,
// MsgID: 0,
// SourceID: paramtable.GetNodeID(),
// },
// },
// }
err = case2.checkLengthOfFieldsData()
assert.NotEqual(t, nil, err)
// the num of passed fields is less than needed
case2.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
}
err = case2.checkLengthOfFieldsData()
assert.NotEqual(t, nil, err)
// satisfied
case2.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
{
Type: schemapb.DataType_Int64,
},
}
err = case2.checkLengthOfFieldsData()
assert.Equal(t, nil, err)
// schema has two field, one of them are autoID
case3 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: true,
DataType: schemapb.DataType_Int64,
},
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
},
}
// passed fields is empty
// case3.req = &milvuspb.InsertRequest{}
err = case3.checkLengthOfFieldsData()
assert.NotEqual(t, nil, err)
// satisfied
case3.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
}
err = case3.checkLengthOfFieldsData()
assert.Equal(t, nil, err)
// schema has one field which is autoID
case4 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: true,
DataType: schemapb.DataType_Int64,
},
},
},
}
// passed fields is empty
// satisfied
// case4.req = &milvuspb.InsertRequest{}
err = case4.checkLengthOfFieldsData()
assert.Equal(t, nil, err)
}
func TestInsertTask_CheckAligned(t *testing.T) {
var err error
// passed NumRows is less than 0
case1 := insertTask{
BaseInsertTask: BaseInsertTask{
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
......@@ -151,7 +23,7 @@ func TestInsertTask_CheckAligned(t *testing.T) {
},
},
}
err = case1.CheckAligned()
err = case1.insertMsg.CheckAligned()
assert.NoError(t, err)
// checkLengthOfFieldsData was already checked by TestInsertTask_checkLengthOfFieldsData
......@@ -170,7 +42,7 @@ func TestInsertTask_CheckAligned(t *testing.T) {
numRows := 20
dim := 128
case2 := insertTask{
BaseInsertTask: BaseInsertTask{
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
......@@ -200,8 +72,8 @@ func TestInsertTask_CheckAligned(t *testing.T) {
}
// satisfied
case2.NumRows = uint64(numRows)
case2.FieldsData = []*schemapb.FieldData{
case2.insertMsg.NumRows = uint64(numRows)
case2.insertMsg.FieldsData = []*schemapb.FieldData{
newScalarFieldData(boolFieldSchema, "Bool", numRows),
newScalarFieldData(int8FieldSchema, "Int8", numRows),
newScalarFieldData(int16FieldSchema, "Int16", numRows),
......@@ -213,136 +85,136 @@ func TestInsertTask_CheckAligned(t *testing.T) {
newBinaryVectorFieldData("BinaryVector", numRows, dim),
newScalarFieldData(varCharFieldSchema, "VarChar", numRows),
}
err = case2.CheckAligned()
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
// less bool data
case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows/2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more bool data
case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows*2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[0] = newScalarFieldData(boolFieldSchema, "Bool", numRows)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
// less int8 data
case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows/2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more int8 data
case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows*2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[1] = newScalarFieldData(int8FieldSchema, "Int8", numRows)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
// less int16 data
case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows/2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more int16 data
case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows*2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[2] = newScalarFieldData(int16FieldSchema, "Int16", numRows)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
// less int32 data
case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows/2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more int32 data
case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows*2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[3] = newScalarFieldData(int32FieldSchema, "Int32", numRows)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
// less int64 data
case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows/2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more int64 data
case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows*2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[4] = newScalarFieldData(int64FieldSchema, "Int64", numRows)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
// less float data
case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows/2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more float data
case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows*2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[5] = newScalarFieldData(floatFieldSchema, "Float", numRows)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, nil, err)
// less double data
case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows/2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more double data
case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows*2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[6] = newScalarFieldData(doubleFieldSchema, "Double", numRows)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, nil, err)
// less float vectors
case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows/2, dim)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more float vectors
case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows*2, dim)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[7] = newFloatVectorFieldData("FloatVector", numRows, dim)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
// less binary vectors
case2.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows/2, dim)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more binary vectors
case2.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows*2, dim)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[7] = newBinaryVectorFieldData("BinaryVector", numRows, dim)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
// less double data
case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows/2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// more double data
case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows*2)
err = case2.insertMsg.CheckAligned()
assert.Error(t, err)
// revert
case2.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows)
err = case2.CheckAligned()
case2.insertMsg.FieldsData[8] = newScalarFieldData(varCharFieldSchema, "VarChar", numRows)
err = case2.insertMsg.CheckAligned()
assert.NoError(t, err)
}
......@@ -1390,7 +1390,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
t.Run("insert", func(t *testing.T) {
hash := generateHashKeys(nb)
task := &insertTask{
BaseInsertTask: BaseInsertTask{
insertMsg: &BaseInsertTask{
BaseMsg: msgstream.BaseMsg{
HashValues: hash,
},
......@@ -1434,7 +1434,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
}
for fieldName, dataType := range fieldName2Types {
task.FieldsData = append(task.FieldsData, generateFieldData(dataType, fieldName, nb))
task.insertMsg.FieldsData = append(task.insertMsg.FieldsData, generateFieldData(dataType, fieldName, nb))
}
assert.NoError(t, task.OnEnqueue())
......@@ -1446,7 +1446,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
t.Run("delete", func(t *testing.T) {
task := &deleteTask{
Condition: NewTaskCondition(ctx),
BaseDeleteTask: msgstream.DeleteMsg{
deleteMsg: &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{
......@@ -1486,7 +1486,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
task.SetID(id)
assert.Equal(t, id, task.ID())
task.Base.MsgType = commonpb.MsgType_Delete
task.deleteMsg.Base.MsgType = commonpb.MsgType_Delete
assert.Equal(t, commonpb.MsgType_Delete, task.Type())
ts := Timestamp(time.Now().UnixNano())
......@@ -1500,7 +1500,7 @@ func TestTask_Int64PrimaryKey(t *testing.T) {
task2 := &deleteTask{
Condition: NewTaskCondition(ctx),
BaseDeleteTask: msgstream.DeleteMsg{
deleteMsg: &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{
......@@ -1643,7 +1643,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
t.Run("insert", func(t *testing.T) {
hash := generateHashKeys(nb)
task := &insertTask{
BaseInsertTask: BaseInsertTask{
insertMsg: &BaseInsertTask{
BaseMsg: msgstream.BaseMsg{
HashValues: hash,
},
......@@ -1688,7 +1688,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
fieldID := common.StartOfUserFieldID
for fieldName, dataType := range fieldName2Types {
task.FieldsData = append(task.FieldsData, generateFieldData(dataType, fieldName, nb))
task.insertMsg.FieldsData = append(task.insertMsg.FieldsData, generateFieldData(dataType, fieldName, nb))
fieldID++
}
......@@ -1701,7 +1701,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
t.Run("delete", func(t *testing.T) {
task := &deleteTask{
Condition: NewTaskCondition(ctx),
BaseDeleteTask: msgstream.DeleteMsg{
deleteMsg: &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{
......@@ -1741,7 +1741,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
task.SetID(id)
assert.Equal(t, id, task.ID())
task.Base.MsgType = commonpb.MsgType_Delete
task.deleteMsg.Base.MsgType = commonpb.MsgType_Delete
assert.Equal(t, commonpb.MsgType_Delete, task.Type())
ts := Timestamp(time.Now().UnixNano())
......@@ -1755,7 +1755,7 @@ func TestTask_VarCharPrimaryKey(t *testing.T) {
task2 := &deleteTask{
Condition: NewTaskCondition(ctx),
BaseDeleteTask: msgstream.DeleteMsg{
deleteMsg: &msgstream.DeleteMsg{
BaseMsg: msgstream.BaseMsg{},
DeleteRequest: internalpb.DeleteRequest{
Base: &commonpb.MsgBase{
......
......@@ -24,6 +24,7 @@ import (
"strings"
"time"
"github.com/milvus-io/milvus/internal/mq/msgstream"
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/types"
......@@ -872,3 +873,69 @@ func isPartitionLoaded(ctx context.Context, qc types.QueryCoord, collID int64, p
}
return false, nil
}
func checkLengthOfFieldsData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) error {
neededFieldsNum := 0
for _, field := range schema.Fields {
if !field.AutoID {
neededFieldsNum++
}
}
if len(insertMsg.FieldsData) < neededFieldsNum {
return errFieldsLessThanNeeded(len(insertMsg.FieldsData), neededFieldsNum)
}
return nil
}
func checkPrimaryFieldData(schema *schemapb.CollectionSchema, insertMsg *msgstream.InsertMsg) (*schemapb.IDs, error) {
rowNums := uint32(insertMsg.NRows())
// TODO(dragondriver): in fact, NumRows is not trustable, we should check all input fields
if insertMsg.NRows() <= 0 {
return nil, errNumRowsLessThanOrEqualToZero(rowNums)
}
if err := checkLengthOfFieldsData(schema, insertMsg); err != nil {
return nil, err
}
primaryFieldSchema, err := typeutil.GetPrimaryFieldSchema(schema)
if err != nil {
log.Error("get primary field schema failed", zap.String("collectionName", insertMsg.CollectionName), zap.Any("schema", schema), zap.Error(err))
return nil, err
}
// get primaryFieldData whether autoID is true or not
var primaryFieldData *schemapb.FieldData
if !primaryFieldSchema.AutoID {
primaryFieldData, err = typeutil.GetPrimaryFieldData(insertMsg.GetFieldsData(), primaryFieldSchema)
if err != nil {
log.Error("get primary field data failed", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
return nil, err
}
} else {
// check primary key data not exist
if typeutil.IsPrimaryFieldDataExist(insertMsg.GetFieldsData(), primaryFieldSchema) {
return nil, fmt.Errorf("can not assign primary field data when auto id enabled %v", primaryFieldSchema.Name)
}
// if autoID == true, currently only support autoID for int64 PrimaryField
primaryFieldData, err = autoGenPrimaryFieldData(primaryFieldSchema, insertMsg.GetRowIDs())
if err != nil {
log.Error("generate primary field data failed when autoID == true", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
return nil, err
}
// if autoID == true, set the primary field data
// insertMsg.fieldsData need append primaryFieldData
insertMsg.FieldsData = append(insertMsg.FieldsData, primaryFieldData)
}
// parse primaryFieldData to result.IDs, and as returned primary keys
ids, err := parsePrimaryFieldData2IDs(primaryFieldData)
if err != nil {
log.Error("parse primary field data to IDs failed", zap.String("collectionName", insertMsg.CollectionName), zap.Error(err))
return nil, err
}
return ids, nil
}
......@@ -924,3 +924,305 @@ func Test_isPartitionIsLoaded(t *testing.T) {
assert.False(t, loaded)
})
}
func Test_InsertTaskCheckLengthOfFieldsData(t *testing.T) {
var err error
// schema is empty, though won't happen in system
case1 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
DbName: "TestInsertTask_checkLengthOfFieldsData",
CollectionName: "TestInsertTask_checkLengthOfFieldsData",
PartitionName: "TestInsertTask_checkLengthOfFieldsData",
},
},
}
err = checkLengthOfFieldsData(case1.schema, case1.insertMsg)
assert.Equal(t, nil, err)
// schema has two fields, neither of them are autoID
case2 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
},
},
}
// passed fields is empty
// case2.BaseInsertTask = BaseInsertTask{
// InsertRequest: internalpb.insertRequest{
// Base: &commonpb.MsgBase{
// MsgType: commonpb.MsgType_Insert,
// MsgID: 0,
// SourceID: paramtable.GetNodeID(),
// },
// },
// }
err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
assert.NotEqual(t, nil, err)
// the num of passed fields is less than needed
case2.insertMsg.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
}
err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
assert.NotEqual(t, nil, err)
// satisfied
case2.insertMsg.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
{
Type: schemapb.DataType_Int64,
},
}
err = checkLengthOfFieldsData(case2.schema, case2.insertMsg)
assert.Equal(t, nil, err)
// schema has two field, one of them are autoID
case3 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: true,
DataType: schemapb.DataType_Int64,
},
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
},
},
}
// passed fields is empty
// case3.req = &milvuspb.InsertRequest{}
err = checkLengthOfFieldsData(case3.schema, case3.insertMsg)
assert.NotEqual(t, nil, err)
// satisfied
case3.insertMsg.FieldsData = []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
}
err = checkLengthOfFieldsData(case3.schema, case3.insertMsg)
assert.Equal(t, nil, err)
// schema has one field which is autoID
case4 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkLengthOfFieldsData",
Description: "TestInsertTask_checkLengthOfFieldsData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: true,
DataType: schemapb.DataType_Int64,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
},
},
}
// passed fields is empty
// satisfied
// case4.req = &milvuspb.InsertRequest{}
err = checkLengthOfFieldsData(case4.schema, case4.insertMsg)
assert.Equal(t, nil, err)
}
func Test_InsertTaskCheckPrimaryFieldData(t *testing.T) {
// schema is empty, though won't happen in system
// num_rows(0) should be greater than 0
case1 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkPrimaryFieldData",
Description: "TestInsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
DbName: "TestInsertTask_checkPrimaryFieldData",
CollectionName: "TestInsertTask_checkPrimaryFieldData",
PartitionName: "TestInsertTask_checkPrimaryFieldData",
},
},
}
_, err := checkPrimaryFieldData(case1.schema, case1.insertMsg)
assert.NotEqual(t, nil, err)
// the num of passed fields is less than needed
case2 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkPrimaryFieldData",
Description: "TestInsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
{
AutoID: false,
DataType: schemapb.DataType_Int64,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
},
},
Version: internalpb.InsertDataVersion_RowBased,
},
},
}
_, err = checkPrimaryFieldData(case2.schema, case2.insertMsg)
assert.NotEqual(t, nil, err)
// autoID == false, no primary field schema
// primary field is not found
case3 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkPrimaryFieldData",
Description: "TestInsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "int64Field",
DataType: schemapb.DataType_Int64,
},
{
Name: "floatField",
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{},
{},
},
},
},
}
_, err = checkPrimaryFieldData(case3.schema, case3.insertMsg)
assert.NotEqual(t, nil, err)
// autoID == true, has primary field schema, but primary field data exist
// can not assign primary field data when auto id enabled int64Field
case4 := insertTask{
schema: &schemapb.CollectionSchema{
Name: "TestInsertTask_checkPrimaryFieldData",
Description: "TestInsertTask_checkPrimaryFieldData",
AutoID: false,
Fields: []*schemapb.FieldSchema{
{
Name: "int64Field",
FieldID: 1,
DataType: schemapb.DataType_Int64,
},
{
Name: "floatField",
FieldID: 2,
DataType: schemapb.DataType_Float,
},
},
},
insertMsg: &BaseInsertTask{
InsertRequest: internalpb.InsertRequest{
Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Insert,
},
RowData: []*commonpb.Blob{
{},
{},
},
FieldsData: []*schemapb.FieldData{
{
Type: schemapb.DataType_Int64,
FieldName: "int64Field",
},
},
},
},
}
case4.schema.Fields[0].IsPrimaryKey = true
case4.schema.Fields[0].AutoID = true
case4.insertMsg.FieldsData[0] = newScalarFieldData(case4.schema.Fields[0], case4.schema.Fields[0].Name, 10)
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
assert.NotEqual(t, nil, err)
// autoID == true, has primary field schema, but DataType don't match
// the data type of the data and the schema do not match
case4.schema.Fields[0].IsPrimaryKey = false
case4.schema.Fields[1].IsPrimaryKey = true
case4.schema.Fields[1].AutoID = true
_, err = checkPrimaryFieldData(case4.schema, case4.insertMsg)
assert.NotEqual(t, nil, err)
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册