提交 8715cd1f 编写于 作者: X xige-16 提交者: yefu.chen

Fix key error when loading index

Signed-off-by: Nxige-16 <xi.ge@zilliz.com>
上级 6bcca0a9
......@@ -86,10 +86,14 @@ AppendIndex(CLoadIndexInfo c_load_index_info, CBinarySet c_binary_set) {
auto& index_params = load_index_info->index_params;
bool find_index_type = index_params.count("index_type") > 0 ? true : false;
bool find_index_mode = index_params.count("index_mode") > 0 ? true : false;
Assert(find_index_mode == true);
Assert(find_index_type == true);
auto mode = index_params["index_mode"] == "CPU" ? milvus::knowhere::IndexMode::MODE_CPU
: milvus::knowhere::IndexMode::MODE_GPU;
milvus::knowhere::IndexMode mode;
if (find_index_mode) {
mode = index_params["index_mode"] == "CPU" ? milvus::knowhere::IndexMode::MODE_CPU
: milvus::knowhere::IndexMode::MODE_GPU;
} else {
mode = milvus::knowhere::IndexMode::MODE_CPU;
}
load_index_info->index =
milvus::knowhere::VecIndexFactory::GetInstance().CreateVecIndex(index_params["index_type"], mode);
load_index_info->index->Load(*binary_set);
......
......@@ -375,7 +375,7 @@ generate_data(int N) {
}
std::string
generate_collection_shema(std::string metric_type, std::string dim) {
generate_collection_shema(std::string metric_type, std::string dim, bool is_binary) {
schema::CollectionSchema collection_schema;
collection_schema.set_name("collection_test");
collection_schema.set_autoid(true);
......@@ -383,7 +383,11 @@ generate_collection_shema(std::string metric_type, std::string dim) {
auto vec_field_schema = collection_schema.add_fields();
vec_field_schema->set_name("fakevec");
vec_field_schema->set_fieldid(100);
vec_field_schema->set_data_type(schema::DataType::VECTOR_FLOAT);
if (is_binary) {
vec_field_schema->set_data_type(schema::DataType::VECTOR_BINARY);
} else {
vec_field_schema->set_data_type(schema::DataType::VECTOR_FLOAT);
}
auto metric_type_param = vec_field_schema->add_index_params();
metric_type_param->set_key("metric_type");
metric_type_param->set_value(metric_type);
......@@ -838,7 +842,7 @@ TEST(CApiTest, UpdateSegmentIndex_Without_Predicate) {
constexpr auto DIM = 16;
constexpr auto K = 5;
std::string schema_string = generate_collection_shema("L2", "16");
std::string schema_string = generate_collection_shema("L2", "16", false);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
auto segment = NewSegment(collection, 0);
......@@ -958,12 +962,12 @@ TEST(CApiTest, UpdateSegmentIndex_Without_Predicate) {
DeleteSegment(segment);
}
TEST(CApiTest, UpdateSegmentIndex_With_Predicate_Range) {
TEST(CApiTest, UpdateSegmentIndex_With_float_Predicate_Range) {
// insert data to segment
constexpr auto DIM = 16;
constexpr auto K = 5;
std::string schema_string = generate_collection_shema("L2", "16");
std::string schema_string = generate_collection_shema("L2", "16", false);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
auto segment = NewSegment(collection, 0);
......@@ -1096,12 +1100,12 @@ TEST(CApiTest, UpdateSegmentIndex_With_Predicate_Range) {
DeleteSegment(segment);
}
TEST(CApiTest, UpdateSegmentIndex_With_Predicate_Term) {
TEST(CApiTest, UpdateSegmentIndex_With_float_Predicate_Term) {
// insert data to segment
constexpr auto DIM = 16;
constexpr auto K = 5;
std::string schema_string = generate_collection_shema("L2", "16");
std::string schema_string = generate_collection_shema("L2", "16", false);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
auto segment = NewSegment(collection, 0);
......@@ -1224,6 +1228,290 @@ TEST(CApiTest, UpdateSegmentIndex_With_Predicate_Term) {
search_result_on_raw_index->result_distances_[offset]);
}
DeleteLoadIndexInfo(c_load_index_info);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteQueryResult(c_search_result_on_smallIndex);
DeleteQueryResult(c_search_result_on_bigIndex);
DeleteCollection(collection);
DeleteSegment(segment);
}
TEST(CApiTest, UpdateSegmentIndex_With_binary_Predicate_Range) {
// insert data to segment
constexpr auto DIM = 16;
constexpr auto K = 5;
std::string schema_string = generate_collection_shema("JACCARD", "16", true);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
auto segment = NewSegment(collection, 0);
auto N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto vec_col = dataset.get_col<uint8_t>(0);
auto query_ptr = vec_col.data() + 420000 * DIM / 8;
PreInsert(segment, N);
auto ins_res = Insert(segment, 0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_.raw_data,
dataset.raw_.sizeof_per_row, dataset.raw_.count);
assert(ins_res.error_code == Success);
const char* dsl_string = R"({
"bool": {
"must": [
{
"range": {
"counter": {
"GE": 420000,
"LT": 420010
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "JACCARD",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";
// create place_holder_group
int num_queries = 5;
auto raw_group = CreateBinaryPlaceholderGroupFromBlob(num_queries, DIM, query_ptr);
auto blob = raw_group.SerializeAsString();
// search on segment's small index
void* plan = nullptr;
auto status = CreatePlan(collection, dsl_string, &plan);
assert(status.error_code == Success);
void* placeholderGroup = nullptr;
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
assert(status.error_code == Success);
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
Timestamp time = 10000000;
CQueryResult c_search_result_on_smallIndex;
auto res_before_load_index =
Search(segment, plan, placeholderGroups.data(), &time, 1, &c_search_result_on_smallIndex);
assert(res_before_load_index.error_code == Success);
// load index to segment
auto conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nprobe, 10},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::JACCARD},
};
auto indexing = generate_index(vec_col.data(), conf, DIM, K, N, IndexEnum::INDEX_FAISS_BIN_IVFFLAT);
// gen query dataset
auto query_dataset = milvus::knowhere::GenDataset(num_queries, DIM, query_ptr);
auto result_on_index = indexing->Query(query_dataset, conf, nullptr);
auto ids = result_on_index->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto dis = result_on_index->Get<float*>(milvus::knowhere::meta::DISTANCE);
std::vector<int64_t> vec_ids(ids, ids + K * num_queries);
std::vector<float> vec_dis;
for (int j = 0; j < K * num_queries; ++j) {
vec_dis.push_back(dis[j] * -1);
}
auto search_result_on_raw_index = (QueryResult*)c_search_result_on_smallIndex;
search_result_on_raw_index->internal_seg_offsets_ = vec_ids;
search_result_on_raw_index->result_distances_ = vec_dis;
auto binary_set = indexing->Serialize(conf);
void* c_load_index_info = nullptr;
status = NewLoadIndexInfo(&c_load_index_info);
assert(status.error_code == Success);
std::string index_type_key = "index_type";
std::string index_type_value = "BIN_IVF_FLAT";
std::string index_mode_key = "index_mode";
std::string index_mode_value = "cpu";
std::string metric_type_key = "metric_type";
std::string metric_type_value = "JACCARD";
AppendIndexParam(c_load_index_info, index_type_key.c_str(), index_type_value.c_str());
AppendIndexParam(c_load_index_info, index_mode_key.c_str(), index_mode_value.c_str());
AppendIndexParam(c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str());
AppendFieldInfo(c_load_index_info, "fakevec", 0);
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
status = UpdateSegmentIndex(segment, c_load_index_info);
assert(status.error_code == Success);
CQueryResult c_search_result_on_bigIndex;
auto res_after_load_index = Search(segment, plan, placeholderGroups.data(), &time, 1, &c_search_result_on_bigIndex);
assert(res_after_load_index.error_code == Success);
auto search_result_on_bigIndex = (*(QueryResult*)c_search_result_on_bigIndex);
for (int i = 0; i < num_queries; ++i) {
auto offset = i * K;
ASSERT_EQ(search_result_on_bigIndex.internal_seg_offsets_[offset], 420000 + i);
ASSERT_EQ(search_result_on_bigIndex.result_distances_[offset],
search_result_on_raw_index->result_distances_[offset]);
}
DeleteLoadIndexInfo(c_load_index_info);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
DeleteQueryResult(c_search_result_on_smallIndex);
DeleteQueryResult(c_search_result_on_bigIndex);
DeleteCollection(collection);
DeleteSegment(segment);
}
TEST(CApiTest, UpdateSegmentIndex_With_binary_Predicate_Term) {
// insert data to segment
constexpr auto DIM = 16;
constexpr auto K = 5;
std::string schema_string = generate_collection_shema("JACCARD", "16", true);
auto collection = NewCollection(schema_string.c_str());
auto schema = ((segcore::Collection*)collection)->get_schema();
auto segment = NewSegment(collection, 0);
auto N = 1000 * 1000;
auto dataset = DataGen(schema, N);
auto vec_col = dataset.get_col<uint8_t>(0);
auto query_ptr = vec_col.data() + 420000 * DIM / 8;
PreInsert(segment, N);
auto ins_res = Insert(segment, 0, N, dataset.row_ids_.data(), dataset.timestamps_.data(), dataset.raw_.raw_data,
dataset.raw_.sizeof_per_row, dataset.raw_.count);
assert(ins_res.error_code == Success);
const char* dsl_string = R"({
"bool": {
"must": [
{
"term": {
"counter": {
"values": [420000, 420001, 420002, 420003, 420004]
}
}
},
{
"vector": {
"fakevec": {
"metric_type": "JACCARD",
"params": {
"nprobe": 10
},
"query": "$0",
"topk": 5
}
}
}
]
}
})";
// create place_holder_group
int num_queries = 5;
auto raw_group = CreateBinaryPlaceholderGroupFromBlob(num_queries, DIM, query_ptr);
auto blob = raw_group.SerializeAsString();
// search on segment's small index
void* plan = nullptr;
auto status = CreatePlan(collection, dsl_string, &plan);
assert(status.error_code == Success);
void* placeholderGroup = nullptr;
status = ParsePlaceholderGroup(plan, blob.data(), blob.length(), &placeholderGroup);
assert(status.error_code == Success);
std::vector<CPlaceholderGroup> placeholderGroups;
placeholderGroups.push_back(placeholderGroup);
Timestamp time = 10000000;
CQueryResult c_search_result_on_smallIndex;
auto res_before_load_index =
Search(segment, plan, placeholderGroups.data(), &time, 1, &c_search_result_on_smallIndex);
assert(res_before_load_index.error_code == Success);
// load index to segment
auto conf = milvus::knowhere::Config{
{milvus::knowhere::meta::DIM, DIM},
{milvus::knowhere::meta::TOPK, K},
{milvus::knowhere::IndexParams::nprobe, 10},
{milvus::knowhere::IndexParams::nlist, 100},
{milvus::knowhere::IndexParams::m, 4},
{milvus::knowhere::IndexParams::nbits, 8},
{milvus::knowhere::Metric::TYPE, milvus::knowhere::Metric::JACCARD},
};
auto indexing = generate_index(vec_col.data(), conf, DIM, K, N, IndexEnum::INDEX_FAISS_BIN_IVFFLAT);
// gen query dataset
auto query_dataset = milvus::knowhere::GenDataset(num_queries, DIM, query_ptr);
auto result_on_index = indexing->Query(query_dataset, conf, nullptr);
auto ids = result_on_index->Get<int64_t*>(milvus::knowhere::meta::IDS);
auto dis = result_on_index->Get<float*>(milvus::knowhere::meta::DISTANCE);
std::vector<int64_t> vec_ids(ids, ids + K * num_queries);
std::vector<float> vec_dis;
for (int j = 0; j < K * num_queries; ++j) {
vec_dis.push_back(dis[j] * -1);
}
auto search_result_on_raw_index = (QueryResult*)c_search_result_on_smallIndex;
search_result_on_raw_index->internal_seg_offsets_ = vec_ids;
search_result_on_raw_index->result_distances_ = vec_dis;
auto binary_set = indexing->Serialize(conf);
void* c_load_index_info = nullptr;
status = NewLoadIndexInfo(&c_load_index_info);
assert(status.error_code == Success);
std::string index_type_key = "index_type";
std::string index_type_value = "BIN_IVF_FLAT";
std::string index_mode_key = "index_mode";
std::string index_mode_value = "cpu";
std::string metric_type_key = "metric_type";
std::string metric_type_value = "JACCARD";
AppendIndexParam(c_load_index_info, index_type_key.c_str(), index_type_value.c_str());
AppendIndexParam(c_load_index_info, index_mode_key.c_str(), index_mode_value.c_str());
AppendIndexParam(c_load_index_info, metric_type_key.c_str(), metric_type_value.c_str());
AppendFieldInfo(c_load_index_info, "fakevec", 0);
AppendIndex(c_load_index_info, (CBinarySet)&binary_set);
status = UpdateSegmentIndex(segment, c_load_index_info);
assert(status.error_code == Success);
CQueryResult c_search_result_on_bigIndex;
auto res_after_load_index = Search(segment, plan, placeholderGroups.data(), &time, 1, &c_search_result_on_bigIndex);
assert(res_after_load_index.error_code == Success);
std::vector<CQueryResult> results;
results.push_back(c_search_result_on_bigIndex);
bool is_selected[1] = {false};
status = ReduceQueryResults(results.data(), 1, is_selected);
assert(status.error_code == Success);
FillTargetEntry(segment, plan, c_search_result_on_bigIndex);
auto search_result_on_bigIndex = (*(QueryResult*)c_search_result_on_bigIndex);
for (int i = 0; i < num_queries; ++i) {
auto offset = i * K;
ASSERT_EQ(search_result_on_bigIndex.internal_seg_offsets_[offset], 420000 + i);
ASSERT_EQ(search_result_on_bigIndex.result_distances_[offset],
search_result_on_raw_index->result_distances_[offset]);
}
DeleteLoadIndexInfo(c_load_index_info);
DeletePlan(plan);
DeletePlaceholderGroup(placeholderGroup);
......
......@@ -37,7 +37,7 @@ const (
nlist = 100
m = 4
nbits = 8
nb = 8 * 10000
nb = 10000
nprobe = 8
sliceSize = 4
efConstruction = 200
......@@ -192,15 +192,19 @@ func generateParams(indexType, metricType string) (map[string]string, map[string
func generateFloatVectors() []float32 {
vectors := make([]float32, 0)
for i := 0; i < nb; i++ {
vectors = append(vectors, rand.Float32())
for j := 0; j < dim; j++ {
vectors = append(vectors, rand.Float32())
}
}
return vectors
}
func generateBinaryVectors() []byte {
vectors := make([]byte, 0)
for i := 0; i < nb/8; i++ {
vectors = append(vectors, byte(rand.Intn(8)))
for i := 0; i < nb; i++ {
for j := 0; j < dim/8; j++ {
vectors = append(vectors, byte(rand.Intn(8)))
}
}
return vectors
}
......
......@@ -138,7 +138,7 @@ func TestCollectionReplica_addPartitionsByCollectionMeta(t *testing.T) {
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
collectionMeta.PartitionTags = []string{"p0", "p1", "p2"}
err := node.replica.addPartitionsByCollectionMeta(collectionMeta)
......@@ -162,7 +162,7 @@ func TestCollectionReplica_removePartitionsByCollectionMeta(t *testing.T) {
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
collectionMeta.PartitionTags = []string{"p0"}
err := node.replica.addPartitionsByCollectionMeta(collectionMeta)
......@@ -187,7 +187,7 @@ func TestCollectionReplica_getPartitionByTag(t *testing.T) {
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
for _, tag := range collectionMeta.PartitionTags {
err := node.replica.addPartition(collectionID, tag)
......@@ -206,7 +206,7 @@ func TestCollectionReplica_hasPartition(t *testing.T) {
collectionID := UniqueID(0)
initTestMeta(t, node, collectionName, collectionID, 0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
err := node.replica.addPartition(collectionID, collectionMeta.PartitionTags[0])
assert.NoError(t, err)
hasPartition := node.replica.hasPartition(collectionID, "default")
......
......@@ -23,7 +23,7 @@ func TestCollection_Partitions(t *testing.T) {
func TestCollection_newCollection(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -36,7 +36,7 @@ func TestCollection_newCollection(t *testing.T) {
func TestCollection_deleteCollection(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......
......@@ -10,6 +10,8 @@ package querynode
import "C"
import (
"errors"
"fmt"
"path/filepath"
"strconv"
"unsafe"
)
......@@ -77,7 +79,9 @@ func (li *LoadIndexInfo) appendIndex(bytesIndex [][]byte, indexKeys []string) er
for i, byteIndex := range bytesIndex {
indexPtr := unsafe.Pointer(&byteIndex[0])
indexLen := C.long(len(byteIndex))
indexKey := C.CString(indexKeys[i])
binarySetKey := filepath.Base(indexKeys[i])
fmt.Println("index key = ", binarySetKey)
indexKey := C.CString(binarySetKey)
status = C.AppendBinaryIndex(cBinarySet, indexPtr, indexLen, indexKey)
errorCode = status.error_code
if errorCode != 0 {
......
......@@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"log"
"path/filepath"
"sort"
"strconv"
"strings"
......@@ -224,9 +223,8 @@ func (lis *loadIndexService) loadIndex(indexPath []string) ([][]byte, error) {
index := make([][]byte, 0)
for _, path := range indexPath {
// get binarySetKey from indexPath
binarySetKey := filepath.Base(path)
indexPiece, err := (*lis.client).Load(binarySetKey)
fmt.Println("load path = ", indexPath)
indexPiece, err := (*lis.client).Load(path)
if err != nil {
return nil, err
}
......
package querynode
import (
"encoding/binary"
"fmt"
"log"
"math"
"math/rand"
"sort"
"strconv"
"testing"
"github.com/golang/protobuf/proto"
"github.com/stretchr/testify/assert"
"github.com/zilliztech/milvus-distributed/internal/indexbuilder"
minioKV "github.com/zilliztech/milvus-distributed/internal/kv/minio"
"github.com/zilliztech/milvus-distributed/internal/msgstream"
"github.com/zilliztech/milvus-distributed/internal/proto/commonpb"
"github.com/zilliztech/milvus-distributed/internal/proto/internalpb"
"github.com/zilliztech/milvus-distributed/internal/proto/servicepb"
"github.com/zilliztech/milvus-distributed/internal/querynode/client"
)
......@@ -21,20 +29,195 @@ func TestLoadIndexService(t *testing.T) {
initTestMeta(t, node, "collection0", collectionID, segmentID)
// loadIndexService and statsService
node.loadIndexService = newLoadIndexService(node.queryNodeLoopCtx, node.replica)
go node.loadIndexService.start()
node.statsService = newStatsService(node.queryNodeLoopCtx, node.replica, node.loadIndexService.fieldStatsChan)
go node.statsService.start()
oldSearchChannelNames := Params.SearchChannelNames
var newSearchChannelNames []string
for _, channel := range oldSearchChannelNames {
newSearchChannelNames = append(newSearchChannelNames, channel+"new")
}
Params.SearchChannelNames = newSearchChannelNames
oldSearchResultChannelNames := Params.SearchChannelNames
var newSearchResultChannelNames []string
for _, channel := range oldSearchResultChannelNames {
newSearchResultChannelNames = append(newSearchResultChannelNames, channel+"new")
}
Params.SearchResultChannelNames = newSearchResultChannelNames
go node.Start()
//generate insert data
const msgLength = 1000
const receiveBufSize = 1024
const DIM = 16
var insertRowBlob []*commonpb.Blob
var timestamps []uint64
var rowIDs []int64
var hashValues []uint32
for n := 0; n < msgLength; n++ {
rowData := make([]byte, 0)
for i := 0; i < DIM; i++ {
vec := make([]byte, 4)
binary.LittleEndian.PutUint32(vec, math.Float32bits(float32(n*i)))
rowData = append(rowData, vec...)
}
age := make([]byte, 4)
binary.LittleEndian.PutUint32(age, 1)
rowData = append(rowData, age...)
blob := &commonpb.Blob{
Value: rowData,
}
insertRowBlob = append(insertRowBlob, blob)
timestamps = append(timestamps, uint64(n))
rowIDs = append(rowIDs, int64(n))
hashValues = append(hashValues, uint32(n))
}
var insertMsg msgstream.TsMsg = &msgstream.InsertMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: hashValues,
},
InsertRequest: internalpb.InsertRequest{
MsgType: internalpb.MsgType_kInsert,
ReqID: 0,
CollectionName: "collection0",
PartitionTag: "default",
SegmentID: segmentID,
ChannelID: int64(0),
ProxyID: int64(0),
Timestamps: timestamps,
RowIDs: rowIDs,
RowData: insertRowBlob,
},
}
insertMsgPack := msgstream.MsgPack{
BeginTs: 0,
EndTs: math.MaxUint64,
Msgs: []msgstream.TsMsg{insertMsg},
}
// generate timeTick
timeTickMsg := &msgstream.TimeTickMsg{
BaseMsg: msgstream.BaseMsg{
BeginTimestamp: 0,
EndTimestamp: 0,
HashValues: []uint32{0},
},
TimeTickMsg: internalpb.TimeTickMsg{
MsgType: internalpb.MsgType_kTimeTick,
PeerID: UniqueID(0),
Timestamp: math.MaxUint64,
},
}
timeTickMsgPack := &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{timeTickMsg},
}
// pulsar produce
insertChannels := Params.InsertChannelNames
ddChannels := Params.DDChannelNames
insertStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
insertStream.SetPulsarClient(Params.PulsarAddress)
insertStream.CreatePulsarProducers(insertChannels)
ddStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
ddStream.SetPulsarClient(Params.PulsarAddress)
ddStream.CreatePulsarProducers(ddChannels)
var insertMsgStream msgstream.MsgStream = insertStream
insertMsgStream.Start()
var ddMsgStream msgstream.MsgStream = ddStream
ddMsgStream.Start()
err := insertMsgStream.Produce(&insertMsgPack)
assert.NoError(t, err)
err = insertMsgStream.Broadcast(timeTickMsgPack)
assert.NoError(t, err)
err = ddMsgStream.Broadcast(timeTickMsgPack)
assert.NoError(t, err)
// generator searchRowData
var searchRowData []float32
for i := 0; i < DIM; i++ {
searchRowData = append(searchRowData, float32(42*i))
}
//generate search data and send search msg
dslString := "{\"bool\": { \n\"vector\": {\n \"vec\": {\n \"metric_type\": \"L2\", \n \"params\": {\n \"nprobe\": 10 \n},\n \"query\": \"$0\",\"topk\": 10 \n } \n } \n } \n }"
var searchRowByteData []byte
for i := range searchRowData {
vec := make([]byte, 4)
binary.LittleEndian.PutUint32(vec, math.Float32bits(searchRowData[i]))
searchRowByteData = append(searchRowByteData, vec...)
}
placeholderValue := servicepb.PlaceholderValue{
Tag: "$0",
Type: servicepb.PlaceholderType_VECTOR_FLOAT,
Values: [][]byte{searchRowByteData},
}
placeholderGroup := servicepb.PlaceholderGroup{
Placeholders: []*servicepb.PlaceholderValue{&placeholderValue},
}
placeGroupByte, err := proto.Marshal(&placeholderGroup)
if err != nil {
log.Print("marshal placeholderGroup failed")
}
query := servicepb.Query{
CollectionName: "collection0",
PartitionTags: []string{"default"},
Dsl: dslString,
PlaceholderGroup: placeGroupByte,
}
queryByte, err := proto.Marshal(&query)
if err != nil {
log.Print("marshal query failed")
}
blob := commonpb.Blob{
Value: queryByte,
}
fn := func(n int64) *msgstream.MsgPack {
searchMsg := &msgstream.SearchMsg{
BaseMsg: msgstream.BaseMsg{
HashValues: []uint32{0},
},
SearchRequest: internalpb.SearchRequest{
MsgType: internalpb.MsgType_kSearch,
ReqID: n,
ProxyID: int64(1),
Timestamp: uint64(msgLength),
ResultChannelID: int64(0),
Query: &blob,
},
}
return &msgstream.MsgPack{
Msgs: []msgstream.TsMsg{searchMsg},
}
}
searchStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
searchStream.SetPulsarClient(Params.PulsarAddress)
searchStream.CreatePulsarProducers(newSearchChannelNames)
searchStream.Start()
err = searchStream.Produce(fn(1))
assert.NoError(t, err)
//get search result
searchResultStream := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, receiveBufSize)
searchResultStream.SetPulsarClient(Params.PulsarAddress)
unmarshalDispatcher := msgstream.NewUnmarshalDispatcher()
searchResultStream.CreatePulsarConsumers(newSearchResultChannelNames, "loadIndexTestSubSearchResult", unmarshalDispatcher, receiveBufSize)
searchResultStream.Start()
searchResult := searchResultStream.Consume()
assert.NotNil(t, searchResult)
unMarshaledHit := servicepb.Hits{}
err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit)
assert.Nil(t, err)
// gen load index message pack
const msgLength = 10000
indexParams := make(map[string]string)
indexParams["index_type"] = "IVF_PQ"
indexParams["index_mode"] = "cpu"
indexParams["dim"] = "16"
indexParams["k"] = "10"
indexParams["nlist"] = "100"
indexParams["nprobe"] = "4"
indexParams["nprobe"] = "10"
indexParams["m"] = "4"
indexParams["nbits"] = "8"
indexParams["metric_type"] = "L2"
......@@ -51,20 +234,16 @@ func TestLoadIndexService(t *testing.T) {
// generator index
typeParams := make(map[string]string)
typeParams["dim"] = "16"
index, err := indexbuilder.NewCIndex(typeParams, indexParams)
assert.Nil(t, err)
const DIM = 16
var vec = [DIM]float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
var indexRowData []float32
for i := 0; i < msgLength; i++ {
for i, ele := range vec {
indexRowData = append(indexRowData, ele+float32(i*4))
for n := 0; n < msgLength; n++ {
for i := 0; i < DIM; i++ {
indexRowData = append(indexRowData, float32(n*i))
}
}
index, err := indexbuilder.NewCIndex(typeParams, indexParams)
assert.Nil(t, err)
err = index.BuildFloatVecIndexWithoutIds(indexRowData)
assert.Equal(t, err, nil)
binarySet, err := index.Serialize()
assert.Equal(t, err, nil)
option := &minioKV.Option{
Address: Params.MinioEndPoint,
......@@ -77,22 +256,29 @@ func TestLoadIndexService(t *testing.T) {
minioKV, err := minioKV.NewMinIOKV(node.queryNodeLoopCtx, option)
assert.Equal(t, err, nil)
//save index to minio
binarySet, err := index.Serialize()
assert.Equal(t, err, nil)
indexPaths := make([]string, 0)
for _, index := range binarySet {
indexPaths = append(indexPaths, index.Key)
minioKV.Save(index.Key, string(index.Value))
path := strconv.Itoa(int(segmentID)) + "/" + index.Key
indexPaths = append(indexPaths, path)
minioKV.Save(path, string(index.Value))
}
//test index search result
indexResult, err := index.QueryOnFloatVecIndexWithParam(searchRowData, indexParams)
assert.Equal(t, err, nil)
// create loadIndexClient
fieldID := UniqueID(100)
loadIndexChannelNames := Params.LoadIndexChannelNames
pulsarURL := Params.PulsarAddress
client := client.NewLoadIndexClient(node.queryNodeLoopCtx, pulsarURL, loadIndexChannelNames)
client := client.NewLoadIndexClient(node.queryNodeLoopCtx, Params.PulsarAddress, loadIndexChannelNames)
client.LoadIndex(indexPaths, segmentID, fieldID, "vec", indexParams)
// init message stream consumer and do checks
statsMs := msgstream.NewPulsarMsgStream(node.queryNodeLoopCtx, Params.StatsReceiveBufSize)
statsMs.SetPulsarClient(pulsarURL)
statsMs.SetPulsarClient(Params.PulsarAddress)
statsMs.CreatePulsarConsumers([]string{Params.StatsChannelName}, Params.MsgChannelSubName, msgstream.NewUnmarshalDispatcher(), Params.StatsReceiveBufSize)
statsMs.Start()
......@@ -127,6 +313,23 @@ func TestLoadIndexService(t *testing.T) {
}
}
err = searchStream.Produce(fn(2))
assert.NoError(t, err)
searchResult = searchResultStream.Consume()
assert.NotNil(t, searchResult)
err = proto.Unmarshal(searchResult.Msgs[0].(*msgstream.SearchResultMsg).Hits[0], &unMarshaledHit)
assert.Nil(t, err)
idsIndex := indexResult.IDs()
idsSegment := unMarshaledHit.IDs
assert.Equal(t, len(idsIndex), len(idsSegment))
for i := 0; i < len(idsIndex); i++ {
assert.Equal(t, idsIndex[i], idsSegment[i])
}
Params.SearchChannelNames = oldSearchChannelNames
Params.SearchResultChannelNames = oldSearchResultChannelNames
fmt.Println("loadIndex floatVector test Done!")
defer assert.Equal(t, findFiledStats, true)
<-node.queryNodeLoopCtx.Done()
node.Close()
......
......@@ -98,7 +98,7 @@ func TestMetaService_isSegmentChannelRangeInQueryNodeChannelRange(t *testing.T)
func TestMetaService_printCollectionStruct(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
printCollectionStruct(collectionMeta)
}
......
......@@ -14,7 +14,7 @@ import (
func TestPlan_Plan(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -36,7 +36,7 @@ func TestPlan_Plan(t *testing.T) {
func TestPlan_PlaceholderGroup(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......
......@@ -16,7 +16,7 @@ import (
"github.com/zilliztech/milvus-distributed/internal/proto/schemapb"
)
const ctxTimeInMillisecond = 2000
const ctxTimeInMillisecond = 5000
const closeWithDeadline = true
func setup() {
......@@ -24,24 +24,46 @@ func setup() {
Params.MetaRootPath = "/etcd/test/root/querynode"
}
func genTestCollectionMeta(collectionName string, collectionID UniqueID) *etcdpb.CollectionMeta {
fieldVec := schemapb.FieldSchema{
FieldID: UniqueID(100),
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
func genTestCollectionMeta(collectionName string, collectionID UniqueID, isBinary bool) *etcdpb.CollectionMeta {
var fieldVec schemapb.FieldSchema
if isBinary {
fieldVec = schemapb.FieldSchema{
FieldID: UniqueID(100),
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_BINARY,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
},
IndexParams: []*commonpb.KeyValuePair{
{
Key: "metric_type",
Value: "L2",
IndexParams: []*commonpb.KeyValuePair{
{
Key: "metric_type",
Value: "JACCARD",
},
},
},
}
} else {
fieldVec = schemapb.FieldSchema{
FieldID: UniqueID(100),
Name: "vec",
IsPrimaryKey: false,
DataType: schemapb.DataType_VECTOR_FLOAT,
TypeParams: []*commonpb.KeyValuePair{
{
Key: "dim",
Value: "16",
},
},
IndexParams: []*commonpb.KeyValuePair{
{
Key: "metric_type",
Value: "L2",
},
},
}
}
fieldInt := schemapb.FieldSchema{
......@@ -71,7 +93,7 @@ func genTestCollectionMeta(collectionName string, collectionID UniqueID) *etcdpb
}
func initTestMeta(t *testing.T, node *QueryNode, collectionName string, collectionID UniqueID, segmentID UniqueID) {
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......
......@@ -16,7 +16,7 @@ func TestReduce_AllFunc(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
segmentID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......
......@@ -17,7 +17,7 @@ import (
func TestSegment_newSegment(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -35,7 +35,7 @@ func TestSegment_newSegment(t *testing.T) {
func TestSegment_deleteSegment(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -55,7 +55,7 @@ func TestSegment_deleteSegment(t *testing.T) {
func TestSegment_getRowCount(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -106,7 +106,7 @@ func TestSegment_getRowCount(t *testing.T) {
func TestSegment_getDeletedCount(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -163,7 +163,7 @@ func TestSegment_getDeletedCount(t *testing.T) {
func TestSegment_getMemSize(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -215,7 +215,7 @@ func TestSegment_getMemSize(t *testing.T) {
func TestSegment_segmentInsert(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -261,7 +261,7 @@ func TestSegment_segmentInsert(t *testing.T) {
func TestSegment_segmentDelete(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -314,7 +314,7 @@ func TestSegment_segmentDelete(t *testing.T) {
func TestSegment_segmentSearch(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -399,7 +399,7 @@ func TestSegment_segmentSearch(t *testing.T) {
func TestSegment_segmentPreInsert(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......@@ -441,7 +441,7 @@ func TestSegment_segmentPreInsert(t *testing.T) {
func TestSegment_segmentPreDelete(t *testing.T) {
collectionName := "collection0"
collectionID := UniqueID(0)
collectionMeta := genTestCollectionMeta(collectionName, collectionID)
collectionMeta := genTestCollectionMeta(collectionName, collectionID, false)
schemaBlob := proto.MarshalTextString(collectionMeta.Schema)
assert.NotEqual(t, "", schemaBlob)
......
......@@ -38,6 +38,7 @@ func (colReplica *collectionReplicaImpl) addCollection(collectionID UniqueID, sc
var newCollection = newCollection(collectionID, schemaBlob)
colReplica.collections = append(colReplica.collections, newCollection)
fmt.Println("yyy, create collection: ", newCollection.Name())
return nil
}
......@@ -51,6 +52,8 @@ func (colReplica *collectionReplicaImpl) removeCollection(collectionID UniqueID)
for _, col := range colReplica.collections {
if col.ID() != collectionID {
tmpCollections = append(tmpCollections, col)
} else {
fmt.Println("yyy, drop collection name: ", col.Name())
}
}
colReplica.collections = tmpCollections
......
......@@ -115,7 +115,7 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
//collSchema, err := ibNode.getCollectionSchemaByName(collectionName)
if err != nil {
// GOOSE TODO add error handler
log.Println("Get meta wrong:", err)
log.Println("bbb, Get meta wrong:", err)
continue
}
......@@ -457,7 +457,7 @@ func (ibNode *insertBufferNode) Operate(in []*Msg) []*Msg {
collSchema, err := ibNode.getCollectionSchemaByID(collectionID)
if err != nil {
// GOOSE TODO add error handler
log.Println("Get meta wrong: ", err)
log.Println("aaa, Get meta wrong: ", err)
}
collMeta := &etcdpb.CollectionMeta{
Schema: collSchema,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册