未验证 提交 fc74a0f5 编写于 作者: S sunby 提交者: GitHub

Auto create new segments when allocating rows more than (#6665)

max number of rows per segment

If user insert too much rows in a request. Now we will return a failed
response. Maybe auto creating new segments to hold that much rows is a
better way.

issue: #6664
Signed-off-by: Nsunby <bingyi.sun@zilliz.com>
上级 63387c83
...@@ -80,15 +80,6 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI ...@@ -80,15 +80,6 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI
assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests)) assigns := make([]*datapb.SegmentIDAssignment, 0, len(req.SegmentIDRequests))
var appendFailedAssignment = func(err string) {
assigns = append(assigns, &datapb.SegmentIDAssignment{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err,
},
})
}
for _, r := range req.SegmentIDRequests { for _, r := range req.SegmentIDRequests {
log.Debug("Handle assign segment request", log.Debug("Handle assign segment request",
zap.Int64("collectionID", r.GetCollectionID()), zap.Int64("collectionID", r.GetCollectionID()),
...@@ -98,46 +89,39 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI ...@@ -98,46 +89,39 @@ func (s *Server) AssignSegmentID(ctx context.Context, req *datapb.AssignSegmentI
if coll := s.meta.GetCollection(r.CollectionID); coll == nil { if coll := s.meta.GetCollection(r.CollectionID); coll == nil {
if err := s.loadCollectionFromRootCoord(ctx, r.CollectionID); err != nil { if err := s.loadCollectionFromRootCoord(ctx, r.CollectionID); err != nil {
errMsg := fmt.Sprintf("Can not load collection %d", r.CollectionID)
appendFailedAssignment(errMsg)
log.Error("load collection from rootcoord error", log.Error("load collection from rootcoord error",
zap.Int64("collectionID", r.CollectionID), zap.Int64("collectionID", r.CollectionID),
zap.Error(err)) zap.Error(err))
continue continue
} }
} }
//if err := s.validateAllocRequest(r.CollectionID, r.PartitionID, r.ChannelName); err != nil {
//result.Status.Reason = err.Error()
//assigns = append(assigns, result)
//continue
//}
s.cluster.Watch(r.ChannelName, r.CollectionID) s.cluster.Watch(r.ChannelName, r.CollectionID)
segmentID, retCount, expireTs, err := s.segmentManager.AllocSegment(ctx, allocations, err := s.segmentManager.AllocSegment(ctx,
r.CollectionID, r.PartitionID, r.ChannelName, int64(r.Count)) r.CollectionID, r.PartitionID, r.ChannelName, int64(r.Count))
if err != nil { if err != nil {
errMsg := fmt.Sprintf("Allocation of collection %d, partition %d, channel %s, count %d error: %s", log.Warn("failed to alloc segment", zap.Any("request", r), zap.Error(err))
r.CollectionID, r.PartitionID, r.ChannelName, r.Count, err.Error())
appendFailedAssignment(errMsg)
continue continue
} }
log.Debug("Assign segment success", zap.Int64("segmentID", segmentID), log.Debug("Assign segment success", zap.Any("assignments", allocations))
zap.Uint64("expireTs", expireTs))
for _, allocation := range allocations {
result := &datapb.SegmentIDAssignment{ result := &datapb.SegmentIDAssignment{
SegID: segmentID, SegID: allocation.SegmentID,
ChannelName: r.ChannelName, ChannelName: r.ChannelName,
Count: uint32(retCount), Count: uint32(allocation.NumOfRows),
CollectionID: r.CollectionID, CollectionID: r.CollectionID,
PartitionID: r.PartitionID, PartitionID: r.PartitionID,
ExpireTime: expireTs, ExpireTime: allocation.ExpireTime,
Status: &commonpb.Status{ Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_Success, ErrorCode: commonpb.ErrorCode_Success,
Reason: "", Reason: "",
}, },
}
assigns = append(assigns, result)
} }
assigns = append(assigns, result)
} }
return &datapb.AssignSegmentIDResponse{ return &datapb.AssignSegmentIDResponse{
Status: &commonpb.Status{ Status: &commonpb.Status{
......
...@@ -30,15 +30,49 @@ func calBySchemaPolicy(schema *schemapb.CollectionSchema) (int, error) { ...@@ -30,15 +30,49 @@ func calBySchemaPolicy(schema *schemapb.CollectionSchema) (int, error) {
return int(threshold / float64(sizePerRecord)), nil return int(threshold / float64(sizePerRecord)), nil
} }
type allocatePolicy func(segment *SegmentInfo, count int64) bool type AllocatePolicy func(segments []*SegmentInfo, count int64,
maxCountPerSegment int64) ([]*Allocation, []*Allocation)
func AllocatePolicyV1(segments []*SegmentInfo, count int64,
maxCountPerSegment int64) ([]*Allocation, []*Allocation) {
newSegmentAllocations := make([]*Allocation, 0)
existedSegmentAllocations := make([]*Allocation, 0)
// create new segment if count >= max num
for count >= maxCountPerSegment {
allocation := &Allocation{
NumOfRows: maxCountPerSegment,
}
newSegmentAllocations = append(newSegmentAllocations, allocation)
count -= maxCountPerSegment
}
// allocate space for remaining count
if count == 0 {
return newSegmentAllocations, existedSegmentAllocations
}
for _, segment := range segments {
var allocSize int64
for _, allocation := range segment.allocations {
allocSize += allocation.NumOfRows
}
free := segment.GetMaxRowNum() - segment.GetNumOfRows() - allocSize
if free < count {
continue
}
allocation := &Allocation{
SegmentID: segment.GetID(),
NumOfRows: count,
}
existedSegmentAllocations = append(existedSegmentAllocations, allocation)
return newSegmentAllocations, existedSegmentAllocations
}
func allocatePolicyV1(segment *SegmentInfo, count int64) bool { // allocate new segment for remaining count
var allocSize int64 allocation := &Allocation{
for _, allocation := range segment.allocations { NumOfRows: count,
allocSize += allocation.numOfRows
} }
free := segment.GetMaxRowNum() - segment.GetNumOfRows() - allocSize newSegmentAllocations = append(newSegmentAllocations, allocation)
return free >= count return newSegmentAllocations, existedSegmentAllocations
} }
type sealPolicy func(maxCount, writtenCount, allocatedCount int64) bool type sealPolicy func(maxCount, writtenCount, allocatedCount int64) bool
...@@ -54,7 +88,7 @@ func getSegmentCapacityPolicy(sizeFactor float64) segmentSealPolicy { ...@@ -54,7 +88,7 @@ func getSegmentCapacityPolicy(sizeFactor float64) segmentSealPolicy {
return func(segment *SegmentInfo, ts Timestamp) bool { return func(segment *SegmentInfo, ts Timestamp) bool {
var allocSize int64 var allocSize int64
for _, allocation := range segment.allocations { for _, allocation := range segment.allocations {
allocSize += allocation.numOfRows allocSize += allocation.NumOfRows
} }
return float64(segment.currRows) >= sizeFactor*float64(segment.GetMaxRowNum()) return float64(segment.currRows) >= sizeFactor*float64(segment.GetMaxRowNum())
} }
......
...@@ -168,7 +168,7 @@ func SetAllocations(allocations []*Allocation) SegmentInfoOption { ...@@ -168,7 +168,7 @@ func SetAllocations(allocations []*Allocation) SegmentInfoOption {
func AddAllocation(allocation *Allocation) SegmentInfoOption { func AddAllocation(allocation *Allocation) SegmentInfoOption {
return func(segment *SegmentInfo) { return func(segment *SegmentInfo) {
segment.allocations = append(segment.allocations, allocation) segment.allocations = append(segment.allocations, allocation)
segment.LastExpireTime = allocation.expireTime segment.LastExpireTime = allocation.ExpireTime
} }
} }
......
...@@ -29,14 +29,10 @@ import ( ...@@ -29,14 +29,10 @@ import (
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
) )
var errRemainInSufficient = func(requestRows int64) error {
return fmt.Errorf("segment remaining is insufficient for %d", requestRows)
}
// Manager manage segment related operations. // Manager manage segment related operations.
type Manager interface { type Manager interface {
// AllocSegment allocate rows and record the allocation. // AllocSegment allocate rows and record the allocation.
AllocSegment(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64) (UniqueID, int64, Timestamp, error) AllocSegment(ctx context.Context, collectionID, partitionID UniqueID, channelName string, requestRows int64) ([]*Allocation, error)
// DropSegment drop the segment from allocator. // DropSegment drop the segment from allocator.
DropSegment(ctx context.Context, segmentID UniqueID) DropSegment(ctx context.Context, segmentID UniqueID)
// SealAllSegments sealed all segmetns of collection with collectionID and return sealed segments // SealAllSegments sealed all segmetns of collection with collectionID and return sealed segments
...@@ -49,8 +45,9 @@ type Manager interface { ...@@ -49,8 +45,9 @@ type Manager interface {
// allcation entry for segment Allocation record // allcation entry for segment Allocation record
type Allocation struct { type Allocation struct {
numOfRows int64 SegmentID UniqueID
expireTime Timestamp NumOfRows int64
ExpireTime Timestamp
} }
// SegmentManager handles segment related logic // SegmentManager handles segment related logic
...@@ -61,7 +58,7 @@ type SegmentManager struct { ...@@ -61,7 +58,7 @@ type SegmentManager struct {
helper allocHelper helper allocHelper
segments []UniqueID segments []UniqueID
estimatePolicy calUpperLimitPolicy estimatePolicy calUpperLimitPolicy
allocPolicy allocatePolicy allocPolicy AllocatePolicy
segmentSealPolicies []segmentSealPolicy segmentSealPolicies []segmentSealPolicy
channelSealPolicies []channelSealPolicy channelSealPolicies []channelSealPolicy
flushPolicy flushPolicy flushPolicy flushPolicy
...@@ -103,7 +100,7 @@ func withCalUpperLimitPolicy(policy calUpperLimitPolicy) allocOption { ...@@ -103,7 +100,7 @@ func withCalUpperLimitPolicy(policy calUpperLimitPolicy) allocOption {
} }
// get allocOption with allocPolicy // get allocOption with allocPolicy
func withAllocPolicy(policy allocatePolicy) allocOption { func withAllocPolicy(policy AllocatePolicy) allocOption {
return allocFunc(func(manager *SegmentManager) { manager.allocPolicy = policy }) return allocFunc(func(manager *SegmentManager) { manager.allocPolicy = policy })
} }
...@@ -132,8 +129,8 @@ func defaultCalUpperLimitPolicy() calUpperLimitPolicy { ...@@ -132,8 +129,8 @@ func defaultCalUpperLimitPolicy() calUpperLimitPolicy {
return calBySchemaPolicy return calBySchemaPolicy
} }
func defaultAlocatePolicy() allocatePolicy { func defaultAlocatePolicy() AllocatePolicy {
return allocatePolicyV1 return AllocatePolicyV1
} }
func defaultSealPolicy() sealPolicy { func defaultSealPolicy() sealPolicy {
...@@ -188,14 +185,14 @@ func (s *SegmentManager) getAllocation(numOfRows int64) *Allocation { ...@@ -188,14 +185,14 @@ func (s *SegmentManager) getAllocation(numOfRows int64) *Allocation {
v := s.allocPool.Get() v := s.allocPool.Get()
if v == nil { if v == nil {
return &Allocation{ return &Allocation{
numOfRows: numOfRows, NumOfRows: numOfRows,
} }
} }
a, ok := v.(*Allocation) a, ok := v.(*Allocation)
if !ok { if !ok {
a = &Allocation{} a = &Allocation{}
} }
a.numOfRows = numOfRows a.NumOfRows = numOfRows
return a return a
} }
...@@ -206,79 +203,61 @@ func (s *SegmentManager) putAllocation(a *Allocation) { ...@@ -206,79 +203,61 @@ func (s *SegmentManager) putAllocation(a *Allocation) {
// AllocSegment allocate segment per request collcation, partication, channel and rows // AllocSegment allocate segment per request collcation, partication, channel and rows
func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID, func (s *SegmentManager) AllocSegment(ctx context.Context, collectionID UniqueID,
partitionID UniqueID, channelName string, requestRows int64) (segID UniqueID, retCount int64, expireTime Timestamp, err error) { partitionID UniqueID, channelName string, requestRows int64) ([]*Allocation, error) {
sp, _ := trace.StartSpanFromContext(ctx) sp, _ := trace.StartSpanFromContext(ctx)
defer sp.Finish() defer sp.Finish()
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
var segment *SegmentInfo // filter segments
var allocation *Allocation segments := make([]*SegmentInfo, 0)
for _, segmentID := range s.segments { for _, segmentID := range s.segments {
segment = s.meta.GetSegment(segmentID) segment := s.meta.GetSegment(segmentID)
if segment == nil { if segment == nil {
log.Warn("Failed to get seginfo from meta", zap.Int64("id", segmentID), zap.Error(err)) log.Warn("Failed to get seginfo from meta", zap.Int64("id", segmentID))
continue continue
} }
if segment.State == commonpb.SegmentState_Sealed || segment.CollectionID != collectionID || if segment.State == commonpb.SegmentState_Sealed || segment.CollectionID != collectionID ||
segment.PartitionID != partitionID || segment.InsertChannel != channelName { segment.PartitionID != partitionID || segment.InsertChannel != channelName {
continue continue
} }
allocation, err = s.alloc(segment, requestRows) segments = append(segments, segment)
if err != nil {
return
}
if allocation != nil {
break
}
}
if allocation == nil {
segment, err = s.openNewSegment(ctx, collectionID, partitionID, channelName)
if err != nil {
return
}
segment = s.meta.GetSegment(segment.GetID())
if segment == nil {
log.Warn("Failed to get seg into from meta", zap.Int64("id", segment.GetID()), zap.Error(err))
return
}
allocation, err = s.alloc(segment, requestRows)
if err != nil {
return
}
if allocation == nil {
err = errRemainInSufficient(requestRows)
return
}
} }
segID = segment.GetID() // apply allocate policy
retCount = allocation.numOfRows maxCountPerSegment, err := s.estimateMaxNumOfRows(collectionID)
expireTime = allocation.expireTime if err != nil {
return return nil, err
}
func (s *SegmentManager) alloc(segment *SegmentInfo, numOfRows int64) (*Allocation, error) {
var allocSize int64
for _, allocItem := range segment.allocations {
allocSize += allocItem.numOfRows
}
if !s.allocPolicy(segment, numOfRows) {
return nil, nil
} }
newSegmentAllocations, existedSegmentAllocations := s.allocPolicy(segments,
requestRows, int64(maxCountPerSegment))
alloc := s.getAllocation(numOfRows) // create new segments and add allocations
expireTs, err := s.genExpireTs() expireTs, err := s.genExpireTs()
if err != nil { if err != nil {
return nil, err return nil, err
} }
alloc.expireTime = expireTs for _, allocation := range newSegmentAllocations {
segment, err := s.openNewSegment(ctx, collectionID, partitionID, channelName)
if err != nil {
return nil, err
}
allocation.ExpireTime = expireTs
allocation.SegmentID = segment.GetID()
if err := s.meta.AddAllocation(segment.GetID(), allocation); err != nil {
return nil, err
}
}
for _, allocation := range existedSegmentAllocations {
allocation.ExpireTime = expireTs
if err := s.meta.AddAllocation(allocation.SegmentID, allocation); err != nil {
return nil, err
}
}
//safe here since info is a clone, used to pass expireTs out allocations := append(newSegmentAllocations, existedSegmentAllocations...)
s.meta.AddAllocation(segment.GetID(), alloc) return allocations, nil
return alloc, nil
} }
func (s *SegmentManager) genExpireTs() (Timestamp, error) { func (s *SegmentManager) genExpireTs() (Timestamp, error) {
...@@ -425,7 +404,7 @@ func (s *SegmentManager) ExpireAllocations(channel string, ts Timestamp) error { ...@@ -425,7 +404,7 @@ func (s *SegmentManager) ExpireAllocations(channel string, ts Timestamp) error {
continue continue
} }
for i := 0; i < len(segment.allocations); i++ { for i := 0; i < len(segment.allocations); i++ {
if segment.allocations[i].expireTime <= ts { if segment.allocations[i].ExpireTime <= ts {
a := segment.allocations[i] a := segment.allocations[i]
segment.allocations = append(segment.allocations[:i], segment.allocations[i+1:]...) segment.allocations = append(segment.allocations[:i], segment.allocations[i+1:]...)
s.putAllocation(a) s.putAllocation(a)
......
...@@ -11,11 +11,11 @@ package datacoord ...@@ -11,11 +11,11 @@ package datacoord
import ( import (
"context" "context"
"math"
"testing" "testing"
"github.com/milvus-io/milvus/internal/proto/commonpb" "github.com/milvus-io/milvus/internal/proto/commonpb"
"github.com/milvus-io/milvus/internal/proto/datapb" "github.com/milvus-io/milvus/internal/proto/datapb"
"github.com/milvus-io/milvus/internal/proto/schemapb"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
...@@ -32,27 +32,15 @@ func TestAllocSegment(t *testing.T) { ...@@ -32,27 +32,15 @@ func TestAllocSegment(t *testing.T) {
collID, err := mockAllocator.allocID() collID, err := mockAllocator.allocID()
assert.Nil(t, err) assert.Nil(t, err)
meta.AddCollection(&datapb.CollectionInfo{ID: collID, Schema: schema}) meta.AddCollection(&datapb.CollectionInfo{ID: collID, Schema: schema})
cases := []struct {
collectionID UniqueID t.Run("normal allocation", func(t *testing.T) {
partitionID UniqueID allocations, err := segmentManager.AllocSegment(ctx, collID, 100, "c1", 100)
channelName string assert.Nil(t, err)
requestRows int64 assert.EqualValues(t, 1, len(allocations))
expectResult bool assert.EqualValues(t, 100, allocations[0].NumOfRows)
}{ assert.NotEqualValues(t, 0, allocations[0].SegmentID)
{collID, 100, "c1", 100, true}, assert.NotEqualValues(t, 0, allocations[0].ExpireTime)
{collID, 100, "c1", math.MaxInt64, false}, })
}
for _, c := range cases {
id, count, expireTime, err := segmentManager.AllocSegment(ctx, c.collectionID, c.partitionID, c.channelName, c.requestRows)
if c.expectResult {
assert.Nil(t, err)
assert.EqualValues(t, c.requestRows, count)
assert.NotEqualValues(t, 0, id)
assert.NotEqualValues(t, 0, expireTime)
} else {
assert.NotNil(t, err)
}
}
} }
func TestLoadSegmentsFromMeta(t *testing.T) { func TestLoadSegmentsFromMeta(t *testing.T) {
...@@ -116,13 +104,14 @@ func TestSaveSegmentsToMeta(t *testing.T) { ...@@ -116,13 +104,14 @@ func TestSaveSegmentsToMeta(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
meta.AddCollection(&datapb.CollectionInfo{ID: collID, Schema: schema}) meta.AddCollection(&datapb.CollectionInfo{ID: collID, Schema: schema})
segmentManager := newSegmentManager(meta, mockAllocator) segmentManager := newSegmentManager(meta, mockAllocator)
segID, _, expireTs, err := segmentManager.AllocSegment(context.Background(), collID, 0, "c1", 1000) allocations, err := segmentManager.AllocSegment(context.Background(), collID, 0, "c1", 1000)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, 1, len(allocations))
_, err = segmentManager.SealAllSegments(context.Background(), collID) _, err = segmentManager.SealAllSegments(context.Background(), collID)
assert.Nil(t, err) assert.Nil(t, err)
segment := meta.GetSegment(segID) segment := meta.GetSegment(allocations[0].SegmentID)
assert.NotNil(t, segment) assert.NotNil(t, segment)
assert.EqualValues(t, segment.LastExpireTime, expireTs) assert.EqualValues(t, segment.LastExpireTime, allocations[0].ExpireTime)
assert.EqualValues(t, commonpb.SegmentState_Sealed, segment.State) assert.EqualValues(t, commonpb.SegmentState_Sealed, segment.State)
} }
...@@ -137,8 +126,10 @@ func TestDropSegment(t *testing.T) { ...@@ -137,8 +126,10 @@ func TestDropSegment(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
meta.AddCollection(&datapb.CollectionInfo{ID: collID, Schema: schema}) meta.AddCollection(&datapb.CollectionInfo{ID: collID, Schema: schema})
segmentManager := newSegmentManager(meta, mockAllocator) segmentManager := newSegmentManager(meta, mockAllocator)
segID, _, _, err := segmentManager.AllocSegment(context.Background(), collID, 0, "c1", 1000) allocations, err := segmentManager.AllocSegment(context.Background(), collID, 0, "c1", 1000)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, 1, len(allocations))
segID := allocations[0].SegmentID
segment := meta.GetSegment(segID) segment := meta.GetSegment(segID)
assert.NotNil(t, segment) assert.NotNil(t, segment)
...@@ -146,3 +137,25 @@ func TestDropSegment(t *testing.T) { ...@@ -146,3 +137,25 @@ func TestDropSegment(t *testing.T) {
segment = meta.GetSegment(segID) segment = meta.GetSegment(segID)
assert.NotNil(t, segment) assert.NotNil(t, segment)
} }
func TestAllocRowsLargerThanOneSegment(t *testing.T) {
Params.Init()
mockAllocator := newMockAllocator()
meta, err := newMemoryMeta(mockAllocator)
assert.Nil(t, err)
schema := newTestSchema()
collID, err := mockAllocator.allocID()
assert.Nil(t, err)
meta.AddCollection(&datapb.CollectionInfo{ID: collID, Schema: schema})
var mockPolicy = func(schema *schemapb.CollectionSchema) (int, error) {
return 1, nil
}
segmentManager := newSegmentManager(meta, mockAllocator, withCalUpperLimitPolicy(mockPolicy))
allocations, err := segmentManager.AllocSegment(context.TODO(), collID, 0, "c1", 2)
assert.Nil(t, err)
assert.EqualValues(t, 2, len(allocations))
assert.EqualValues(t, 1, allocations[0].NumOfRows)
assert.EqualValues(t, 1, allocations[1].NumOfRows)
}
...@@ -407,7 +407,7 @@ func (s *Server) startActiveCheck(ctx context.Context) { ...@@ -407,7 +407,7 @@ func (s *Server) startActiveCheck(ctx context.Context) {
if ok { if ok {
continue continue
} }
s.Stop() go func() { s.Stop() }()
log.Debug("disconnect with etcd and shutdown data coordinator") log.Debug("disconnect with etcd and shutdown data coordinator")
return return
case <-ctx.Done(): case <-ctx.Done():
...@@ -487,7 +487,6 @@ func (s *Server) Stop() error { ...@@ -487,7 +487,6 @@ func (s *Server) Stop() error {
return nil return nil
} }
log.Debug("DataCoord server shutdown") log.Debug("DataCoord server shutdown")
atomic.StoreInt64(&s.isServing, ServerStateStopped)
s.cluster.Close() s.cluster.Close()
s.stopServerLoop() s.stopServerLoop()
return nil return nil
......
...@@ -11,7 +11,6 @@ package datacoord ...@@ -11,7 +11,6 @@ package datacoord
import ( import (
"context" "context"
"math"
"path" "path"
"strconv" "strconv"
"testing" "testing"
...@@ -26,7 +25,6 @@ import ( ...@@ -26,7 +25,6 @@ import (
"github.com/milvus-io/milvus/internal/types" "github.com/milvus-io/milvus/internal/types"
"github.com/milvus-io/milvus/internal/util/retry" "github.com/milvus-io/milvus/internal/util/retry"
"github.com/milvus-io/milvus/internal/util/sessionutil" "github.com/milvus-io/milvus/internal/util/sessionutil"
"github.com/milvus-io/milvus/internal/util/typeutil"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.etcd.io/etcd/clientv3" "go.etcd.io/etcd/clientv3"
) )
...@@ -57,52 +55,46 @@ func TestAssignSegmentID(t *testing.T) { ...@@ -57,52 +55,46 @@ func TestAssignSegmentID(t *testing.T) {
Schema: schema, Schema: schema,
Partitions: []int64{}, Partitions: []int64{},
}) })
recordSize, err := typeutil.EstimateSizePerRecord(schema)
assert.Nil(t, err)
maxCount := int(Params.SegmentMaxSize * 1024 * 1024 / float64(recordSize))
cases := []struct { t.Run("assign segment normally", func(t *testing.T) {
Description string req := &datapb.SegmentIDRequest{
CollectionID UniqueID Count: 1000,
PartitionID UniqueID ChannelName: channel0,
ChannelName string CollectionID: collID,
Count uint32 PartitionID: partID,
Success bool }
}{
{"assign segment normally", collID, partID, channel0, 1000, true},
{"assign segment with invalid collection", collIDInvalid, partID, channel0, 1000, false},
{"assign with max count", collID, partID, channel0, uint32(maxCount), true},
{"assign with max uint32 count", collID, partID, channel1, math.MaxUint32, false},
}
for _, test := range cases { resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
t.Run(test.Description, func(t *testing.T) { NodeID: 0,
req := &datapb.SegmentIDRequest{ PeerRole: "",
Count: test.Count, SegmentIDRequests: []*datapb.SegmentIDRequest{req},
ChannelName: test.ChannelName, })
CollectionID: test.CollectionID, assert.Nil(t, err)
PartitionID: test.PartitionID, assert.EqualValues(t, 1, len(resp.SegIDAssignments))
} assign := resp.SegIDAssignments[0]
assert.EqualValues(t, commonpb.ErrorCode_Success, assign.Status.ErrorCode)
assert.EqualValues(t, collID, assign.CollectionID)
assert.EqualValues(t, partID, assign.PartitionID)
assert.EqualValues(t, channel0, assign.ChannelName)
assert.EqualValues(t, 1000, assign.Count)
})
resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{ t.Run("assign segment with invalid collection", func(t *testing.T) {
NodeID: 0, req := &datapb.SegmentIDRequest{
PeerRole: "", Count: 1000,
SegmentIDRequests: []*datapb.SegmentIDRequest{req}, ChannelName: channel0,
}) CollectionID: collIDInvalid,
assert.Nil(t, err) PartitionID: partID,
assert.EqualValues(t, 1, len(resp.SegIDAssignments)) }
assign := resp.SegIDAssignments[0]
if test.Success { resp, err := svr.AssignSegmentID(context.TODO(), &datapb.AssignSegmentIDRequest{
assert.EqualValues(t, commonpb.ErrorCode_Success, assign.Status.ErrorCode) NodeID: 0,
assert.EqualValues(t, test.CollectionID, assign.CollectionID) PeerRole: "",
assert.EqualValues(t, test.PartitionID, assign.PartitionID) SegmentIDRequests: []*datapb.SegmentIDRequest{req},
assert.EqualValues(t, test.ChannelName, assign.ChannelName)
assert.EqualValues(t, test.Count, assign.Count)
} else {
assert.NotEqualValues(t, commonpb.ErrorCode_Success, assign.Status.ErrorCode)
}
}) })
} assert.Nil(t, err)
assert.EqualValues(t, 0, len(resp.SegIDAssignments))
})
} }
func TestFlush(t *testing.T) { func TestFlush(t *testing.T) {
...@@ -110,8 +102,12 @@ func TestFlush(t *testing.T) { ...@@ -110,8 +102,12 @@ func TestFlush(t *testing.T) {
defer closeTestServer(t, svr) defer closeTestServer(t, svr)
schema := newTestSchema() schema := newTestSchema()
svr.meta.AddCollection(&datapb.CollectionInfo{ID: 0, Schema: schema, Partitions: []int64{}}) svr.meta.AddCollection(&datapb.CollectionInfo{ID: 0, Schema: schema, Partitions: []int64{}})
segID, _, expireTs, err := svr.segmentManager.AllocSegment(context.TODO(), 0, 1, "channel-1", 1) allocations, err := svr.segmentManager.AllocSegment(context.TODO(), 0, 1, "channel-1", 1)
assert.Nil(t, err) assert.Nil(t, err)
assert.EqualValues(t, 1, len(allocations))
expireTs := allocations[0].ExpireTime
segID := allocations[0].SegmentID
req := &datapb.FlushRequest{ req := &datapb.FlushRequest{
Base: &commonpb.MsgBase{ Base: &commonpb.MsgBase{
MsgType: commonpb.MsgType_Flush, MsgType: commonpb.MsgType_Flush,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册