diff --git a/reader/read_node/query_node.go b/reader/read_node/query_node.go
index da41ad6052ac574aeab177588ca89b1ba25dc3b1..d766d0cd38b0bdc19a17e5c40b729617e47a5eed 100644
--- a/reader/read_node/query_node.go
+++ b/reader/read_node/query_node.go
@@ -17,9 +17,6 @@ import (
 	"encoding/json"
 	"fmt"
 	"github.com/czs007/suvlim/conf"
-	msgPb "github.com/czs007/suvlim/pkg/master/grpc/message"
-	"github.com/czs007/suvlim/pkg/master/kv"
-	"github.com/czs007/suvlim/reader/message_client"
 	"github.com/stretchr/testify/assert"
 	"log"
 	"sort"
@@ -27,6 +24,9 @@ import (
 	"sync/atomic"
 	"time"
 
+	msgPb "github.com/czs007/suvlim/pkg/master/grpc/message"
+	"github.com/czs007/suvlim/pkg/master/kv"
+	"github.com/czs007/suvlim/reader/message_client"
 	//"github.com/stretchr/testify/assert"
 )
 
@@ -69,21 +69,8 @@ type QueryInfo struct {
 
 type MsgCounter struct {
 	InsertCounter int64
-	InsertTime    time.Time
-
 	DeleteCounter int64
-	DeleteTime    time.Time
-
 	SearchCounter int64
-	SearchTime    time.Time
-}
-
-type InsertLog struct {
-	MsgLength              int
-	DurationInMilliseconds int64
-	InsertTime             time.Time
-	NumSince               int64
-	Speed                  float64
 }
 
 type QueryNode struct {
@@ -99,7 +86,6 @@ type QueryNode struct {
 	insertData        InsertData
 	kvBase            *kv.EtcdKVBase
 	msgCounter        *MsgCounter
-	InsertLogs        []InsertLog
 }
 
 func NewQueryNode(queryNodeId uint64, timeSync uint64) *QueryNode {
@@ -109,7 +95,7 @@ func NewQueryNode(queryNodeId uint64, timeSync uint64) *QueryNode {
 		ReadTimeSyncMin: timeSync,
 		ReadTimeSyncMax: timeSync,
 		WriteTimeSync:   timeSync,
-		ServiceTimeSync: timeSync,
+		ServiceTimeSync: timeSync,
 		TSOTimeSync:     timeSync,
 	}
 
@@ -149,7 +135,7 @@ func CreateQueryNode(queryNodeId uint64, timeSync uint64, mc *message_client.Mes
 		ReadTimeSyncMin: timeSync,
 		ReadTimeSyncMax: timeSync,
 		WriteTimeSync:   timeSync,
-		ServiceTimeSync: timeSync,
+		ServiceTimeSync: timeSync,
 		TSOTimeSync:     timeSync,
 	}
 
@@ -176,7 +162,6 @@ func CreateQueryNode(queryNodeId uint64, timeSync uint64, mc *message_client.Mes
 		queryNodeTimeSync: queryNodeTimeSync,
 		buffer:            buffer,
 		msgCounter:        &msgCounter,
-		InsertLogs:        make([]InsertLog, 0),
 	}
 }
 
@@ -261,11 +246,13 @@ func (node *QueryNode) InitQueryNodeCollection() {
 
 func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
 	const Debug = true
-	const CountInsertMsgBaseline = 1000 * 1000
-	var BaselineCounter int64 = 0
-	node.msgCounter.InsertTime = time.Now()
+	const CountMsgNum = 1000 * 1000
 
 	if Debug {
+		var printFlag = true
+		var startTime = true
+		var start time.Time
+
 		for {
 			var msgLen = node.PrepareBatchMsg()
 			var timeRange = TimeRange{node.messageClient.TimeSyncStart(), node.messageClient.TimeSyncEnd()}
@@ -277,9 +264,10 @@ func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
 				continue
 			}
 
-			if node.msgCounter.InsertCounter/CountInsertMsgBaseline == BaselineCounter {
-				node.WriteQueryLog()
-				BaselineCounter++
+			if startTime {
+				fmt.Println("============> Start Test <============")
+				startTime = false
+				start = time.Now()
 			}
 
 			node.QueryNodeDataInit()
@@ -291,6 +279,13 @@ func (node *QueryNode) RunInsertDelete(wg *sync.WaitGroup) {
 			node.DoInsertAndDelete()
 			//fmt.Println("DoInsertAndDelete Done")
 			node.queryNodeTimeSync.UpdateSearchTimeSync(timeRange)
+
+			// Test insert time
+			if printFlag && node.msgCounter.InsertCounter >= CountMsgNum {
+				printFlag = false
+				timeSince := time.Since(start)
+				fmt.Println("============> Do", node.msgCounter.InsertCounter, "Insert in", timeSince, "<============")
+			}
 		}
 	}
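The hunks above replace the per-batch `QueryLog`/`WriteQueryLog` bookkeeping with a one-shot throughput measurement: stamp a start time on the first batch, then print once the insert counter crosses `CountMsgNum`. A minimal standalone sketch of that pattern, with an illustrative loop and threshold standing in for the node's real batch source:

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	const countMsgNum = 5 // stand-in for the patch's 1000 * 1000
	var (
		insertCounter int64
		printFlag     = true
		started       = false
		start         time.Time
	)

	// Stand-in for the PrepareBatchMsg loop in RunInsertDelete.
	for _, batch := range [][]int64{{1, 2, 3}, {4, 5}, {6}} {
		if !started {
			started = true
			start = time.Now() // stamp the start on the first batch
		}
		insertCounter += int64(len(batch))

		// Print exactly once, after the target message count is reached.
		if printFlag && insertCounter >= countMsgNum {
			printFlag = false
			fmt.Println("inserted", insertCounter, "messages in", time.Since(start))
		}
	}
}
```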
@@ -339,14 +334,14 @@ func (node *QueryNode) RunSearch(wg *sync.WaitGroup) {
 			node.messageClient.SearchMsg = append(node.messageClient.SearchMsg, msg)
 			fmt.Println("Do Search...")
 			//for {
-			//if node.messageClient.SearchMsg[0].Timestamp < node.queryNodeTimeSync.ServiceTimeSync {
-			var status = node.Search(node.messageClient.SearchMsg)
-			if status.ErrorCode != 0 {
-				fmt.Println("Search Failed")
-				node.PublishFailedSearchResult()
-			}
-			//break
-			//}
+			//if node.messageClient.SearchMsg[0].Timestamp < node.queryNodeTimeSync.ServiceTimeSync {
+			var status = node.Search(node.messageClient.SearchMsg)
+			if status.ErrorCode != 0 {
+				fmt.Println("Search Failed")
+				node.PublishFailedSearchResult()
+			}
+			//break
+			//}
 			//}
 		default:
 		}
@@ -490,9 +485,9 @@ func (node *QueryNode) PreInsertAndDelete() msgPb.Status {
 func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
 	var wg sync.WaitGroup
 	// Do insert
-	for segmentID := range node.insertData.insertRecords {
+	for segmentID, records := range node.insertData.insertRecords {
 		wg.Add(1)
-		go node.DoInsert(segmentID, &wg)
+		go node.DoInsert(segmentID, &records, &wg)
 	}
 
 	// Do delete
@@ -510,7 +505,7 @@ func (node *QueryNode) DoInsertAndDelete() msgPb.Status {
 	return msgPb.Status{ErrorCode: msgPb.ErrorCode_SUCCESS}
 }
 
-func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Status {
+func (node *QueryNode) DoInsert(segmentID int64, records *[][]byte, wg *sync.WaitGroup) msgPb.Status {
 	fmt.Println("Doing insert..., len = ", len(node.insertData.insertIDs[segmentID]))
 	var targetSegment, err = node.GetSegmentBySegmentID(segmentID)
 	if err != nil {
@@ -520,12 +515,10 @@ func (node *QueryNode) DoInsert(segmentID int64, wg *sync.WaitGroup) msgPb.Statu
 	ids := node.insertData.insertIDs[segmentID]
 	timestamps := node.insertData.insertTimestamps[segmentID]
-	records := node.insertData.insertRecords[segmentID]
 	offsets := node.insertData.insertOffset[segmentID]
 
-	node.QueryLog(len(ids))
-
-	err = targetSegment.SegmentInsert(offsets, &ids, &timestamps, &records)
+	node.msgCounter.InsertCounter += int64(len(ids))
+	err = targetSegment.SegmentInsert(offsets, &ids, &timestamps, records)
 	if err != nil {
 		fmt.Println(err.Error())
 		return msgPb.Status{ErrorCode: 1}
 	}
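A note on the new `DoInsertAndDelete` loop above: `records` is the range variable, and `&records` is evaluated when the `go` statement executes, so before Go 1.22 every goroutine receives the address of one variable that the loop keeps overwriting. A minimal sketch of the conventional fix, a per-iteration shadow copy (the map contents here are illustrative, not from the patch):

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	insertRecords := map[int64][][]byte{
		1: {[]byte("a"), []byte("b")},
		2: {[]byte("c")},
	}

	var wg sync.WaitGroup
	for segmentID, records := range insertRecords {
		records := records // copy the range variable so &records is stable per iteration
		wg.Add(1)
		go func(id int64, recs *[][]byte) {
			defer wg.Done()
			fmt.Println("segment", id, "rows:", len(*recs))
		}(segmentID, &records)
	}
	wg.Wait()
}
```

With Go 1.22 or newer, range variables are scoped per iteration and the copy is redundant but harmless.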
@@ -592,7 +585,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 		// Here, we manually make searchTimestamp's logic time minus `conf.Config.Timesync.Interval` milliseconds.
 		// Which means `searchTimestamp.logicTime = searchTimestamp.logicTime - conf.Config.Timesync.Interval`.
 		var logicTimestamp = searchTimestamp << 46 >> 46
-		searchTimestamp = (searchTimestamp>>18-uint64(conf.Config.Timesync.Interval+600))<<18 + logicTimestamp
+		searchTimestamp = (searchTimestamp >> 18 - uint64(conf.Config.Timesync.Interval + 600)) << 18 + logicTimestamp
 
 		var vector = msg.Records
 		// We now only the first Json is valid.
@@ -601,7 +594,7 @@ func (node *QueryNode) Search(searchMessages []*msgPb.SearchMsg) msgPb.Status {
 
 		// 1. Timestamp check
 		// TODO: return or wait? Or adding graceful time
 		if searchTimestamp > node.queryNodeTimeSync.ServiceTimeSync {
-			fmt.Println("Invalid query time, timestamp = ", searchTimestamp>>18, ", SearchTimeSync = ", node.queryNodeTimeSync.ServiceTimeSync>>18)
+			fmt.Println("Invalid query time, timestamp = ", searchTimestamp >> 18, ", SearchTimeSync = ", node.queryNodeTimeSync.ServiceTimeSync >> 18)
 			return msgPb.Status{ErrorCode: 1}
 		}
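For readers following the bit arithmetic in the `Search` hunk above: the hybrid timestamp packs physical time into the high 46 bits and a logical counter into the low 18 bits, so `ts << 46 >> 46` isolates the logical part and `ts >> 18` the physical part. A minimal sketch of that decomposition and of the subtraction the patch performs, where the interval value is a stand-in for `conf.Config.Timesync.Interval`:

```go
package main

import "fmt"

func main() {
	const logicalBits = 18
	const intervalMs = 400 // stand-in for conf.Config.Timesync.Interval

	// Pack a hybrid timestamp: physical milliseconds in the high 46 bits,
	// a logical counter in the low 18 bits.
	var ts uint64 = 1_600_000_000_000<<logicalBits | 1234

	logical := ts << 46 >> 46     // same trick as the patch: keep the low 18 bits
	physical := ts >> logicalBits // high 46 bits

	// Move the physical part back by interval+600 ms, keep the logical part.
	adjusted := (physical-uint64(intervalMs+600))<<logicalBits + logical

	fmt.Printf("physical=%d logical=%d adjusted=%d\n", physical, logical, adjusted)
}
```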
diff --git a/reader/read_node/segment.go b/reader/read_node/segment.go
index c369a13247ee4fff5d8eeb8dc25886f491f6ee5a..3da047a4e8263acaa9806452b8ce6cb1d65bd4c3 100644
--- a/reader/read_node/segment.go
+++ b/reader/read_node/segment.go
@@ -16,7 +16,6 @@ import (
 	"fmt"
 	"github.com/czs007/suvlim/errors"
 	msgPb "github.com/czs007/suvlim/pkg/master/grpc/message"
-	"github.com/stretchr/testify/assert"
 	"strconv"
 	"unsafe"
 )
@@ -144,13 +143,11 @@ func (s *Segment) SegmentInsert(offset int64, entityIDs *[]int64, timestamps *[]
 	var numOfRow = len(*entityIDs)
 	var sizeofPerRow = len((*records)[0])
 
-	assert.Equal(nil, numOfRow, len(*records))
-
-	var rawData = make([]byte, numOfRow * sizeofPerRow)
+	var rawData = make([]byte, numOfRow*sizeofPerRow)
 	var copyOffset = 0
 	for i := 0; i < len(*records); i++ {
 		copy(rawData[copyOffset:], (*records)[i])
-		copyOffset += sizeofPerRow
+		copyOffset += len((*records)[i])
 	}
 
 	var cOffset = C.long(offset)
diff --git a/reader/read_node/util_functions.go b/reader/read_node/util_functions.go
index 303d824d907888c567f88fe66990008b296860dd..c9071d4e04ca35ed25333823a4c26a115f721ac4 100644
--- a/reader/read_node/util_functions.go
+++ b/reader/read_node/util_functions.go
@@ -1,13 +1,8 @@
 package reader
 
 import (
-	"encoding/json"
 	"errors"
-	"fmt"
-	log "github.com/apache/pulsar/pulsar-client-go/logutil"
-	"os"
 	"strconv"
-	"time"
 )
 
 // Function `GetSegmentByEntityId` should return entityIDs, timestamps and segmentIDs
@@ -73,54 +68,3 @@ func (c *Collection) GetPartitionByName(partitionName string) (partition *Partit
 	return nil
 	// TODO: remove from c.Partitions
 }
-
-func (node *QueryNode) QueryLog(length int) {
-	node.msgCounter.InsertCounter += int64(length)
-	timeNow := time.Now()
-	duration := timeNow.Sub(node.msgCounter.InsertTime)
-	speed := float64(length) / duration.Seconds()
-
-	insertLog := InsertLog{
-		MsgLength:              length,
-		DurationInMilliseconds: duration.Milliseconds(),
-		InsertTime:             timeNow,
-		NumSince:               node.msgCounter.InsertCounter,
-		Speed:                  speed,
-	}
-
-	node.InsertLogs = append(node.InsertLogs, insertLog)
-	node.msgCounter.InsertTime = timeNow
-}
-
-func (node *QueryNode) WriteQueryLog() {
-	f, err := os.OpenFile("/tmp/query_node.txt", os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
-	if err != nil {
-		log.Fatal(err)
-	}
-
-	// write logs
-	for _, insertLog := range node.InsertLogs {
-		insertLogJson, err := json.Marshal(&insertLog)
-		if err != nil {
-			log.Fatal(err)
-		}
-
-		writeString := string(insertLogJson) + "\n"
-		fmt.Println(writeString)
-
-		_, err2 := f.WriteString(writeString)
-		if err2 != nil {
-			log.Fatal(err2)
-		}
-	}
-
-	// reset InsertLogs buffer
-	node.InsertLogs = make([]InsertLog, 0)
-
-	err = f.Close()
-	if err != nil {
-		log.Fatal(err)
-	}
-
-	fmt.Println("write log done")
-}
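The `SegmentInsert` hunk above advances the write offset by each record's own length instead of a fixed `sizeofPerRow`, which only holds when every row has the same width. A standalone sketch of the corrected packing; note the patch still sizes `rawData` as `numOfRow * sizeofPerRow`, while this sketch sizes it from the summed record lengths:

```go
package main

import "fmt"

func main() {
	// Variable-width rows: a fixed-stride copy using len(records[0])
	// for every row would misplace bytes in the packed buffer.
	records := [][]byte{[]byte("abcd"), []byte("ef"), []byte("ghij")}

	total := 0
	for _, r := range records {
		total += len(r)
	}

	rawData := make([]byte, total)
	offset := 0
	for _, r := range records {
		copy(rawData[offset:], r)
		offset += len(r) // was: offset += sizeofPerRow
	}

	fmt.Printf("%s\n", rawData) // abcdefghij
}
```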
"19530":parameters.port_ ; - client.Connect(connect_param); - std::vector ids_array; - std::vector partition_list; - partition_list.emplace_back("default"); +const milvus::VectorParam +get_vector_param() { milvus::VectorParam vectorParam; std::vector vector_records; @@ -54,7 +34,7 @@ int main(int argc , char**argv) { for (int j = 0; j < 1; ++j) { milvus::VectorData vectorData; std::vector float_data; - for (int i = 0; i < 16; ++i) { + for (int i = 0; i < DIM; ++i) { float_data.emplace_back(dis(eng)); } @@ -71,11 +51,103 @@ int main(int argc , char**argv) { vectorParam.json_param = vector_param_json_string; vectorParam.vector_records = vector_records; + return vectorParam; +} + +bool check_field(milvus::FieldPtr left, milvus::FieldPtr right){ + + if (left->field_name != right->field_name){ + std::cout<<"filed_name not match! want "<< left->field_name << " but get "<field_name << std::endl; + return false; + } + + if (left->field_type != right->field_type){ + std::cout<<"filed_type not match! want "<< int(left->field_type) << " but get "<< int(right->field_type) << std::endl; + return false; + } + + + if (left->dim != right->dim){ + std::cout<<"dim not match! want "<< left->dim << " but get "<dim << std::endl; + return false; + } + + return true; +} + + +bool check_schema(const milvus::Mapping & map){ + // Get Collection info + bool ret = false; + + milvus::FieldPtr field_ptr1 = std::make_shared(); + milvus::FieldPtr field_ptr2 = std::make_shared(); + + field_ptr1->field_name = "age"; + field_ptr1->field_type = milvus::DataType::INT32; + field_ptr1->dim = 1; + + field_ptr2->field_name = "field_vec"; + field_ptr2->field_type = milvus::DataType::VECTOR_FLOAT; + field_ptr2->dim = DIM; + + std::vector fields{field_ptr1, field_ptr2}; + + auto size_ = map.fields.size(); + for ( int i =0; i != size_; ++ i){ + auto ret = check_field(fields[i], map.fields[i]); + if (!ret){ + return false; + } + } + + for (auto &f : map.fields) { + std::cout << f->field_name << ":" << int(f->field_type) << ":" << f->dim << "DIM" << std::endl; + } + + return true; +} + +int main(int argc , char**argv) { + TestParameters parameters = milvus_sdk::Utils::ParseTestParameters(argc, argv); + if (!parameters.is_valid){ + return 0; + } + + if (parameters.collection_name_.empty()){ + std::cout<< "should specify collection name!" << std::endl; + milvus_sdk::Utils::PrintHelp(argc, argv); + return 0; + } + + + const std::string collection_name = parameters.collection_name_; + auto client = milvus::ConnectionImpl(); + milvus::ConnectParam connect_param; + connect_param.ip_address = parameters.address_.empty() ? "127.0.0.1":parameters.address_; + connect_param.port = parameters.port_.empty() ? "19530":parameters.port_ ; + client.Connect(connect_param); + + + milvus::Mapping map; + client.GetCollectionInfo(collection_name, map); + auto check_ret = check_schema(map); + if (!check_ret){ + std::cout<<" Schema is not right!"<< std::endl; + return 0; + } + + + std::vector partition_list; + partition_list.emplace_back("default"); + + auto vectorParam = get_vector_param(); + milvus::TopKQueryResult result; milvus_sdk::TimeRecorder test_search("search"); for (int k = 0; k < LOOP; ++k) { test_search.Start(); - auto status = client.Search("collection0", partition_list, "dsl", vectorParam, result); + auto status = client.Search(collection_name, partition_list, "dsl", vectorParam, result); test_search.End(); } test_search.Print(LOOP);