diff --git a/cpp/src/db/insert/MemTable.cpp b/cpp/src/db/insert/MemTable.cpp index ca63c02ad9c5f2ccdfdd5a7001dbdd4aab73fab9..ff9c25e3e732154891ea148c4c692dbb51af9faf 100644 --- a/cpp/src/db/insert/MemTable.cpp +++ b/cpp/src/db/insert/MemTable.cpp @@ -27,12 +27,12 @@ Status MemTable::Add(VectorSource::Ptr &source, IDNumbers &vector_ids) { Status status; if (mem_table_file_list_.empty() || current_mem_table_file->IsFull()) { MemTableFile::Ptr new_mem_table_file = std::make_shared(table_id_, meta_, options_); - status = new_mem_table_file->Add(source); + status = new_mem_table_file->Add(source, vector_ids); if (status.ok()) { mem_table_file_list_.emplace_back(new_mem_table_file); } } else { - status = current_mem_table_file->Add(source); + status = current_mem_table_file->Add(source, vector_ids); } if (!status.ok()) { diff --git a/cpp/src/db/insert/MemTableFile.cpp b/cpp/src/db/insert/MemTableFile.cpp index 1d7053ab5a2c9b33e111dd74aaf52c8a6f9da84b..326658df5fe9bfb4988587845c647885c4984240 100644 --- a/cpp/src/db/insert/MemTableFile.cpp +++ b/cpp/src/db/insert/MemTableFile.cpp @@ -41,7 +41,7 @@ Status MemTableFile::CreateTableFile() { return status; } -Status MemTableFile::Add(const VectorSource::Ptr &source) { +Status MemTableFile::Add(const VectorSource::Ptr &source, IDNumbers& vector_ids) { if (table_file_schema_.dimension_ <= 0) { std::string err_msg = "MemTableFile::Add: table_file_schema dimension = " + @@ -55,7 +55,7 @@ Status MemTableFile::Add(const VectorSource::Ptr &source) { if (mem_left >= single_vector_mem_size) { size_t num_vectors_to_add = std::ceil(mem_left / single_vector_mem_size); size_t num_vectors_added; - auto status = source->Add(execution_engine_, table_file_schema_, num_vectors_to_add, num_vectors_added); + auto status = source->Add(execution_engine_, table_file_schema_, num_vectors_to_add, num_vectors_added, vector_ids); if (status.ok()) { current_mem_ += (num_vectors_added * single_vector_mem_size); } diff --git a/cpp/src/db/insert/MemTableFile.h b/cpp/src/db/insert/MemTableFile.h index b582152299fb90e64a75d10c03404ab941d966f7..d754b030713ba8a75a7369f7be5d5e6e87e41b7e 100644 --- a/cpp/src/db/insert/MemTableFile.h +++ b/cpp/src/db/insert/MemTableFile.h @@ -19,7 +19,7 @@ class MemTableFile { MemTableFile(const std::string &table_id, const std::shared_ptr &meta, const Options &options); - Status Add(const VectorSource::Ptr &source); + Status Add(const VectorSource::Ptr &source, IDNumbers& vector_ids); size_t GetCurrentMem(); diff --git a/cpp/src/db/insert/VectorSource.cpp b/cpp/src/db/insert/VectorSource.cpp index 5a24d261afafa1eb9f677115950f931e953355b7..27385b4b230303bdf9a8c7877648410ce71c4f4a 100644 --- a/cpp/src/db/insert/VectorSource.cpp +++ b/cpp/src/db/insert/VectorSource.cpp @@ -21,14 +21,22 @@ VectorSource::VectorSource(const size_t &n, Status VectorSource::Add(const ExecutionEnginePtr &execution_engine, const meta::TableFileSchema &table_file_schema, const size_t &num_vectors_to_add, - size_t &num_vectors_added) { + size_t &num_vectors_added, + IDNumbers &vector_ids) { auto start_time = METRICS_NOW_TIME; num_vectors_added = current_num_vectors_added + num_vectors_to_add <= n_ ? num_vectors_to_add : n_ - current_num_vectors_added; IDNumbers vector_ids_to_add; - id_generator_->GetNextIDNumbers(num_vectors_added, vector_ids_to_add); + if (vector_ids.empty()) { + id_generator_->GetNextIDNumbers(num_vectors_added, vector_ids_to_add); + } else { + vector_ids_to_add.resize(num_vectors_added); + for (int pos = current_num_vectors_added; pos < current_num_vectors_added + num_vectors_added; pos++) { + vector_ids_to_add[pos-current_num_vectors_added] = vector_ids[pos]; + } + } Status status = execution_engine->AddWithIds(num_vectors_added, vectors_ + current_num_vectors_added * table_file_schema.dimension_, vector_ids_to_add.data()); diff --git a/cpp/src/db/insert/VectorSource.h b/cpp/src/db/insert/VectorSource.h index 3f7e4e8f5e15f76be75b8d52de201681383ca0ee..4c350c78bcb789ace5a91f8232d06546c5ad4360 100644 --- a/cpp/src/db/insert/VectorSource.h +++ b/cpp/src/db/insert/VectorSource.h @@ -21,7 +21,8 @@ class VectorSource { Status Add(const ExecutionEnginePtr &execution_engine, const meta::TableFileSchema &table_file_schema, const size_t &num_vectors_to_add, - size_t &num_vectors_added); + size_t &num_vectors_added, + IDNumbers &vector_ids); size_t GetNumVectorsAdded(); diff --git a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp index 583a91789768d0f36ab00da1a4175ff6734651ee..5225f2a97e32fa848b22eaf5c59938b055b576fc 100644 --- a/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp +++ b/cpp/src/sdk/examples/grpcsimple/src/ClientTest.cpp @@ -15,6 +15,8 @@ using namespace milvus; +//#define SET_VECTOR_IDS; + namespace { std::string GetTableName(); @@ -211,9 +213,9 @@ ClientTest::Test(const std::string& address, const std::string& port) { std::cout << "All tables: " << std::endl; for(auto& table : tables) { int64_t row_count = 0; -// conn->DropTable(table); - stat = conn->CountTable(table, row_count); - std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl; + conn->DropTable(table); +// stat = conn->CountTable(table, row_count); +// std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl; } } @@ -235,59 +237,21 @@ ClientTest::Test(const std::string& address, const std::string& port) { std::cout << "DescribeTable function call status: " << stat.ToString() << std::endl; PrintTableSchema(tb_schema); } -// -// Connection::Destroy(conn); - -// pid_t pid; -// for (int i = 0; i < 5; ++i) { -// pid = fork(); -// if (pid == 0 || pid == -1) { -// break; -// } -// } -// if (pid == -1) { -// std::cout << "fail to fork!\n"; -// exit(1); -// } else if (pid == 0) { -// std::shared_ptr conn = Connection::Create(); -// -// {//connect server -// ConnectParam param = {address, port}; -// Status stat = conn->Connect(param); -// std::cout << "Connect function call status: " << stat.ToString() << std::endl; -// } -// -// {//server version -// std::string version = conn->ServerVersion(); -// std::cout << "Server version: " << version << std::endl; -// } -// Connection::Destroy(conn); -// exit(0); -// } else { -// std::shared_ptr conn = Connection::Create(); -// -// {//connect server -// ConnectParam param = {address, port}; -// Status stat = conn->Connect(param); -// std::cout << "Connect function call status: " << stat.ToString() << std::endl; -// } -// -// {//server version -// std::string version = conn->ServerVersion(); -// std::cout << "Server version: " << version << std::endl; -// } -// Connection::Destroy(conn); -// std::cout << "in main process\n"; -// exit(0); -// } std::vector> search_record_array; {//insert vectors + std::vector record_ids; for (int i = 0; i < ADD_VECTOR_LOOP; i++) {//add vectors std::vector record_array; int64_t begin_index = i * BATCH_ROW_COUNT; BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array); - std::vector record_ids; + +#ifdef SET_VECTOR_IDS + record_ids.resize(ADD_VECTOR_LOOP * BATCH_ROW_COUNT); + for (auto j = begin_index; j Insert(vector_ids, insert_param, status); - auto finish = std::chrono::high_resolution_clock::now(); - - for (size_t i = 0; i < vector_ids.vector_id_array_size(); i++) { - id_array.push_back(vector_ids.vector_id_array(i)); + ::milvus::grpc::VectorIds vector_ids; + if (!id_array.empty()) { + for (auto i = 0; i < id_array.size(); i++) { + insert_param.add_row_id_array(id_array[i]); + } + client_ptr_->Insert(vector_ids, insert_param, status); + } else { + client_ptr_->Insert(vector_ids, insert_param, status); + for (size_t i = 0; i < vector_ids.vector_id_array_size(); i++) { + id_array.push_back(vector_ids.vector_id_array(i)); + } } + #endif } catch (std::exception &ex) { diff --git a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp index 73f38528cf2f37c9419e7a6de9b7ac17793ccef1..8934045579a9705cbc9192ee25610e207d28c942 100644 --- a/cpp/src/server/grpc_impl/GrpcRequestTask.cpp +++ b/cpp/src/server/grpc_impl/GrpcRequestTask.cpp @@ -453,7 +453,10 @@ InsertTask::OnExecute() { //step 4: insert vectors auto vec_count = (uint64_t) insert_param_.row_record_array_size(); - std::vector vec_ids(record_ids_.vector_id_array_size(), 0); + std::vector vec_ids(insert_param_.row_id_array_size(), 0); + for (auto i = 0; i < insert_param_.row_id_array_size(); i++) { + vec_ids[i] = insert_param_.row_id_array(i); + } stat = DBWrapper::DB()->InsertVectors(insert_param_.table_name(), vec_count, vec_f.data(), vec_ids); diff --git a/cpp/unittest/db/mem_test.cpp b/cpp/unittest/db/mem_test.cpp index 1976822e761d20fc538d9cc6c0baecbe8e40fa56..0f8d2b65e05e2211562350cd29a318a3b6ea6264 100644 --- a/cpp/unittest/db/mem_test.cpp +++ b/cpp/unittest/db/mem_test.cpp @@ -68,15 +68,15 @@ TEST_F(NewMemManagerTest, VECTOR_SOURCE_TEST) { engine::ExecutionEnginePtr execution_engine_ = engine::EngineFactory::Build(table_file_schema.dimension_, table_file_schema.location_, (engine::EngineType) table_file_schema.engine_type_); - status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added); + engine::IDNumbers vector_ids; + status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added, vector_ids); ASSERT_TRUE(status.ok()); - - ASSERT_EQ(num_vectors_added, 50); - - engine::IDNumbers vector_ids = source.GetVectorIds(); + vector_ids = source.GetVectorIds(); ASSERT_EQ(vector_ids.size(), 50); + ASSERT_EQ(num_vectors_added, 50); - status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added); + vector_ids.clear(); + status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added, vector_ids); ASSERT_TRUE(status.ok()); ASSERT_EQ(num_vectors_added, 50); @@ -105,12 +105,13 @@ TEST_F(NewMemManagerTest, MEM_TABLE_FILE_TEST) { engine::VectorSource::Ptr source = std::make_shared(n_100, vectors_100.data()); - status = mem_table_file.Add(source); + engine::IDNumbers vector_ids; + status = mem_table_file.Add(source, vector_ids); ASSERT_TRUE(status.ok()); // std::cout << mem_table_file.GetCurrentMem() << " " << mem_table_file.GetMemLeft() << std::endl; - engine::IDNumbers vector_ids = source->GetVectorIds(); + vector_ids = source->GetVectorIds(); ASSERT_EQ(vector_ids.size(), 100); size_t singleVectorMem = sizeof(float) * TABLE_DIM; @@ -121,7 +122,8 @@ TEST_F(NewMemManagerTest, MEM_TABLE_FILE_TEST) { BuildVectors(n_max, vectors_128M); engine::VectorSource::Ptr source_128M = std::make_shared(n_max, vectors_128M.data()); - status = mem_table_file.Add(source_128M); + vector_ids.clear(); + status = mem_table_file.Add(source_128M, vector_ids); vector_ids = source_128M->GetVectorIds(); ASSERT_EQ(vector_ids.size(), n_max - n_100); @@ -149,9 +151,10 @@ TEST_F(NewMemManagerTest, MEM_TABLE_TEST) { engine::MemTable mem_table(TABLE_NAME, impl_, options); - status = mem_table.Add(source_100); + engine::IDNumbers vector_ids; + status = mem_table.Add(source_100, vector_ids); ASSERT_TRUE(status.ok()); - engine::IDNumbers vector_ids = source_100->GetVectorIds(); + vector_ids = source_100->GetVectorIds(); ASSERT_EQ(vector_ids.size(), 100); engine::MemTableFile::Ptr mem_table_file; @@ -163,8 +166,9 @@ TEST_F(NewMemManagerTest, MEM_TABLE_TEST) { std::vector vectors_128M; BuildVectors(n_max, vectors_128M); + vector_ids.clear(); engine::VectorSource::Ptr source_128M = std::make_shared(n_max, vectors_128M.data()); - status = mem_table.Add(source_128M); + status = mem_table.Add(source_128M, vector_ids); ASSERT_TRUE(status.ok()); vector_ids = source_128M->GetVectorIds(); @@ -181,7 +185,8 @@ TEST_F(NewMemManagerTest, MEM_TABLE_TEST) { engine::VectorSource::Ptr source_1G = std::make_shared(n_1G, vectors_1G.data()); - status = mem_table.Add(source_1G); + vector_ids.clear(); + status = mem_table.Add(source_1G, vector_ids); ASSERT_TRUE(status.ok()); vector_ids = source_1G->GetVectorIds(); @@ -370,3 +375,61 @@ TEST_F(NewMemManagerTest, CONCURRENT_INSERT_SEARCH_TEST) { }; +TEST_F(DBTest, VECTOR_IDS_TEST) +{ + engine::meta::TableSchema table_info = BuildTableSchema(); + engine::Status stat = db_->CreateTable(table_info); + + engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = TABLE_NAME; + stat = db_->DescribeTable(table_info_get); + ASSERT_STATS(stat); + ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); + + engine::IDNumbers vector_ids; + + + int64_t nb = 100000; + std::vector xb; + BuildVectors(nb, xb); + + vector_ids.resize(nb); + for (auto i = 0; i < nb; i++) { + vector_ids[i] = i; + } + + stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + ASSERT_EQ(vector_ids[0], 0); + ASSERT_STATS(stat); + + nb = 25000; + xb.clear(); + BuildVectors(nb, xb); + vector_ids.clear(); + vector_ids.resize(nb); + for (auto i = 0; i < nb; i++) { + vector_ids[i] = i + nb; + } + stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + ASSERT_EQ(vector_ids[0], nb); + ASSERT_STATS(stat); + + nb = 262144; //512M + xb.clear(); + BuildVectors(nb, xb); + vector_ids.clear(); + vector_ids.resize(nb); + for (auto i = 0; i < nb; i++) { + vector_ids[i] = i + nb / 2; + } + stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + ASSERT_EQ(vector_ids[0], nb/2); + ASSERT_STATS(stat); + + nb = 65536; //128M + xb.clear(); + BuildVectors(nb, xb); + vector_ids.clear(); + stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + ASSERT_STATS(stat); +} \ No newline at end of file