未验证 提交 bdb8396e 编写于 作者: G groot 提交者: GitHub

Fix CalcDistance wrong result when fetting vectors from collection (#6976)

* Fix CalcDistance wrong result when fetting vectors from collection
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>

* Fix CalcDistance wrong result when fetting vectors from collection
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>

* preset capacity
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>

* typo
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>

* error check
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>

* code lint
Signed-off-by: Nyhmo <yihua.mo@zilliz.com>
上级 3c3975b5
......@@ -1360,6 +1360,10 @@ func (node *Proxy) Retrieve(ctx context.Context, request *milvuspb.RetrieveReque
zap.Any("partitions", request.PartitionNames),
zap.Any("len(Ids)", len(request.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
defer func() {
idsCount := 0
if rt.result != nil {
idsCount = len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)
}
log.Debug("Retrieve Done",
zap.Error(err),
zap.String("role", Params.RoleName),
......@@ -1368,7 +1372,7 @@ func (node *Proxy) Retrieve(ctx context.Context, request *milvuspb.RetrieveReque
zap.String("db", request.DbName),
zap.String("collection", request.CollectionName),
zap.Any("partitions", request.PartitionNames),
zap.Any("len(Ids)", len(rt.result.Ids.IdField.(*schemapb.IDs_IntId).IntId.Data)))
zap.Any("len(Ids)", idsCount))
}()
err = rt.WaitToFinish()
......@@ -1593,6 +1597,80 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
return node.Retrieve(ctx, retrieveRequest)
}
// the vectors retrieved are random order, we need re-arrange the vectors by the order of input ids
arrangeFunc := func(ids *milvuspb.VectorIDs, retrievedFields []*schemapb.FieldData) (*schemapb.VectorField, error) {
var retrievedIds *schemapb.ScalarField
var retrievedVectors *schemapb.VectorField
for _, fieldData := range retrievedFields {
if fieldData.FieldName == ids.FieldName {
retrievedVectors = fieldData.GetVectors()
}
if fieldData.Type == schemapb.DataType_Int64 {
retrievedIds = fieldData.GetScalars()
}
}
if retrievedIds == nil || retrievedVectors == nil {
return nil, errors.New("Failed to fetch vectors")
}
dict := make(map[int64]int)
for index, id := range retrievedIds.GetLongData().Data {
dict[id] = index
}
inputIds := ids.IdArray.GetIntId().Data
if retrievedVectors.GetFloatVector() != nil {
floatArr := retrievedVectors.GetFloatVector().Data
element := retrievedVectors.GetDim()
result := make([]float32, 0, int64(len(inputIds))*element)
for _, id := range inputIds {
index, ok := dict[id]
if !ok {
log.Error("id not found in CalcDistance", zap.Int64("id", id))
return nil, errors.New("Failed to fetch vectors by id: " + fmt.Sprintln(id))
}
result = append(result, floatArr[int64(index)*element:int64(index+1)*element]...)
}
return &schemapb.VectorField{
Dim: element,
Data: &schemapb.VectorField_FloatVector{
FloatVector: &schemapb.FloatArray{
Data: result,
},
},
}, nil
}
if retrievedVectors.GetBinaryVector() != nil {
binaryArr := retrievedVectors.GetBinaryVector()
element := retrievedVectors.GetDim()
if element%8 != 0 {
element = element + 8 - element%8
}
result := make([]byte, 0, int64(len(inputIds))*element)
for _, id := range inputIds {
index, ok := dict[id]
if !ok {
log.Error("id not found in CalcDistance", zap.Int64("id", id))
return nil, errors.New("Failed to fetch vectors by id: " + fmt.Sprintln(id))
}
result = append(result, binaryArr[int64(index)*element:int64(index+1)*element]...)
}
return &schemapb.VectorField{
Dim: element * 8,
Data: &schemapb.VectorField_BinaryVector{
BinaryVector: result,
},
}, nil
}
return nil, errors.New("Failed to fetch vectors")
}
vectorsLeft := request.GetOpLeft().GetDataArray()
opLeft := request.GetOpLeft().GetIdArray()
if opLeft != nil {
......@@ -1606,11 +1684,14 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}, nil
}
for _, fieldData := range result.FieldsData {
if fieldData.FieldName == opLeft.FieldName {
vectorsLeft = fieldData.GetVectors()
break
}
vectorsLeft, err = arrangeFunc(opLeft, result.FieldsData)
if err != nil {
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
}
......@@ -1636,11 +1717,14 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}, nil
}
for _, fieldData := range result.FieldsData {
if fieldData.FieldName == opRight.FieldName {
vectorsRight = fieldData.GetVectors()
break
}
vectorsRight, err = arrangeFunc(opRight, result.FieldsData)
if err != nil {
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: err.Error(),
},
}, nil
}
}
......@@ -1653,7 +1737,16 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}, nil
}
if vectorsLeft.Dim == vectorsRight.Dim && vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
if vectorsLeft.Dim != vectorsRight.Dim {
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
Reason: "Vectors dimension is not equal",
},
}, nil
}
if vectorsLeft.GetFloatVector() != nil && vectorsRight.GetFloatVector() != nil {
distances, err := distance.CalcFloatDistance(vectorsLeft.Dim, vectorsLeft.GetFloatVector().Data, vectorsRight.GetFloatVector().Data, metric)
if err != nil {
return &milvuspb.CalcDistanceResults{
......@@ -1674,7 +1767,7 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}, nil
}
if vectorsLeft.Dim == vectorsRight.Dim && vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
if vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetBinaryVector() != nil {
hamming, err := distance.CalcHammingDistance(vectorsLeft.Dim, vectorsLeft.GetBinaryVector(), vectorsRight.GetBinaryVector())
if err != nil {
return &milvuspb.CalcDistanceResults{
......@@ -1719,6 +1812,10 @@ func (node *Proxy) CalcDistance(ctx context.Context, request *milvuspb.CalcDista
}
err = errors.New("Unexpected error")
if (vectorsLeft.GetBinaryVector() != nil && vectorsRight.GetFloatVector() != nil) || (vectorsLeft.GetFloatVector() != nil && vectorsRight.GetBinaryVector() != nil) {
err = errors.New("Cannot calculate distance between binary vectors and float vectors")
}
return &milvuspb.CalcDistanceResults{
Status: &commonpb.Status{
ErrorCode: commonpb.ErrorCode_UnexpectedError,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册