diff --git a/cpp/src/db/DBImpl.cpp b/cpp/src/db/DBImpl.cpp index 46101dcf93cae4c9bf53258bbce38ce4e7a5d762..62d6b9c3ccf712a427889787925af4dbd6501692 100644 --- a/cpp/src/db/DBImpl.cpp +++ b/cpp/src/db/DBImpl.cpp @@ -87,8 +87,7 @@ DBImpl::DBImpl(const Options& options) compact_thread_pool_(1, 1), index_thread_pool_(1, 1) { meta_ptr_ = DBMetaImplFactory::Build(options.meta, options.mode); - mem_mgr_ = std::make_shared(meta_ptr_, options_); - // mem_mgr_ = (MemManagerPtr)(new MemManager(meta_ptr_, options_)); + mem_mgr_ = MemManagerFactory::Build(meta_ptr_, options_); if (options.mode != Options::MODE::READ_ONLY) { StartTimerTasks(); } diff --git a/cpp/src/db/DBImpl.h b/cpp/src/db/DBImpl.h index cc632847c128399d772ac9024d481bd1e41a02f7..4853a0fb4da2c95f107656cfb46ad4ecb63fbfca 100644 --- a/cpp/src/db/DBImpl.h +++ b/cpp/src/db/DBImpl.h @@ -9,6 +9,7 @@ #include "MemManager.h" #include "Types.h" #include "utils/ThreadPool.h" +#include "MemManagerAbstract.h" #include #include @@ -33,7 +34,6 @@ class Meta; class DBImpl : public DB { public: using MetaPtr = meta::Meta::Ptr; - using MemManagerPtr = typename MemManager::Ptr; explicit DBImpl(const Options &options); @@ -131,7 +131,7 @@ class DBImpl : public DB { std::thread bg_timer_thread_; MetaPtr meta_ptr_; - MemManagerPtr mem_mgr_; + MemManagerAbstractPtr mem_mgr_; server::ThreadPool compact_thread_pool_; std::list> compact_thread_results_; diff --git a/cpp/src/db/Factories.cpp b/cpp/src/db/Factories.cpp index 231c3cce4fb26e24bf22236d027fbb9edf2502ac..abcc0821ab74aa227b884970acbdbce37587d1de 100644 --- a/cpp/src/db/Factories.cpp +++ b/cpp/src/db/Factories.cpp @@ -6,6 +6,8 @@ #include #include "Factories.h" #include "DBImpl.h" +#include "MemManager.h" +#include "NewMemManager.h" #include #include @@ -98,6 +100,15 @@ DB* DBFactory::Build(const Options& options) { return new DBImpl(options); } +MemManagerAbstractPtr MemManagerFactory::Build(const std::shared_ptr& meta, + const Options& options) { + bool useNew = true; + if (useNew) { + return std::make_shared(meta, options); + } + return std::make_shared(meta, options); +} + } // namespace engine } // namespace milvus } // namespace zilliz diff --git a/cpp/src/db/Factories.h b/cpp/src/db/Factories.h index 889922b17a59cb9bb96df73971a0952abe6079e7..567bc0a8bcd01e394e345feca7e6eeb22e85cb58 100644 --- a/cpp/src/db/Factories.h +++ b/cpp/src/db/Factories.h @@ -10,6 +10,7 @@ #include "MySQLMetaImpl.h" #include "Options.h" #include "ExecutionEngine.h" +#include "MemManagerAbstract.h" #include #include @@ -36,6 +37,10 @@ struct DBFactory { static DB* Build(const Options&); }; +struct MemManagerFactory { + static MemManagerAbstractPtr Build(const std::shared_ptr& meta, const Options& options); +}; + } // namespace engine } // namespace milvus } // namespace zilliz diff --git a/cpp/src/db/MemManager.h b/cpp/src/db/MemManager.h index 0ce88d504dcc3b3b69a620bf9da21cc2aa35d0e9..95303889dbf2f2c069f6c10f95b84a8c4f2bdd93 100644 --- a/cpp/src/db/MemManager.h +++ b/cpp/src/db/MemManager.h @@ -9,13 +9,13 @@ #include "IDGenerator.h" #include "Status.h" #include "Meta.h" +#include "MemManagerAbstract.h" #include #include #include #include #include -#include namespace zilliz { namespace milvus { @@ -62,7 +62,7 @@ private: -class MemManager { +class MemManager : public MemManagerAbstract { public: using MetaPtr = meta::Meta::Ptr; using MemVectorsPtr = typename MemVectors::Ptr; @@ -71,16 +71,16 @@ public: MemManager(const std::shared_ptr& meta, const Options& options) : meta_(meta), options_(options) {} - MemVectorsPtr GetMemByTable(const std::string& table_id); - Status InsertVectors(const std::string& table_id, - size_t n, const float* vectors, IDNumbers& vector_ids); + size_t n, const float* vectors, IDNumbers& vector_ids) override; - Status Serialize(std::set& table_ids); + Status Serialize(std::set& table_ids) override; - Status EraseMemVector(const std::string& table_id); + Status EraseMemVector(const std::string& table_id) override; private: + MemVectorsPtr GetMemByTable(const std::string& table_id); + Status InsertVectorsNoLock(const std::string& table_id, size_t n, const float* vectors, IDNumbers& vector_ids); Status ToImmutable(); diff --git a/cpp/src/db/MemManagerAbstract.h b/cpp/src/db/MemManagerAbstract.h new file mode 100644 index 0000000000000000000000000000000000000000..74222df1e844155e70bc51732c3a330146ceb592 --- /dev/null +++ b/cpp/src/db/MemManagerAbstract.h @@ -0,0 +1,25 @@ +#pragma once + +#include + +namespace zilliz { +namespace milvus { +namespace engine { + +class MemManagerAbstract { +public: + + virtual Status InsertVectors(const std::string& table_id, + size_t n, const float* vectors, IDNumbers& vector_ids) = 0; + + virtual Status Serialize(std::set& table_ids) = 0; + + virtual Status EraseMemVector(const std::string& table_id) = 0; + +}; // MemManagerAbstract + +using MemManagerAbstractPtr = std::shared_ptr; + +} // namespace engine +} // namespace milvus +} // namespace zilliz \ No newline at end of file diff --git a/cpp/src/db/MemTable.cpp b/cpp/src/db/MemTable.cpp index 86554695c8ee61a38c74eb1d81ddb7bb65b91c2d..b282ad375a638661875c1fae0b7079178f36c12b 100644 --- a/cpp/src/db/MemTable.cpp +++ b/cpp/src/db/MemTable.cpp @@ -44,7 +44,7 @@ void MemTable::GetCurrentMemTableFile(MemTableFile::Ptr& mem_table_file) { mem_table_file = mem_table_file_list_.back(); } -size_t MemTable::GetStackSize() { +size_t MemTable::GetTableFileCount() { return mem_table_file_list_.size(); } @@ -60,6 +60,14 @@ Status MemTable::Serialize() { return Status::OK(); } +bool MemTable::Empty() { + return mem_table_file_list_.empty(); +} + +std::string MemTable::GetTableId() { + return table_id_; +} + } // namespace engine } // namespace milvus } // namespace zilliz \ No newline at end of file diff --git a/cpp/src/db/MemTable.h b/cpp/src/db/MemTable.h index d5c7cc9e85a642f541915b36b9e0c8cd0545fcfc..e09d6ddac17faa5aebd2006d0c4a29e320769026 100644 --- a/cpp/src/db/MemTable.h +++ b/cpp/src/db/MemTable.h @@ -24,10 +24,14 @@ public: void GetCurrentMemTableFile(MemTableFile::Ptr& mem_table_file); - size_t GetStackSize(); + size_t GetTableFileCount(); Status Serialize(); + bool Empty(); + + std::string GetTableId(); + private: const std::string table_id_; diff --git a/cpp/src/db/NewMemManager.cpp b/cpp/src/db/NewMemManager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..19aba68eb79675bf06eb2f0bb348aa89cf4ab2ae --- /dev/null +++ b/cpp/src/db/NewMemManager.cpp @@ -0,0 +1,92 @@ +#include "NewMemManager.h" +#include "VectorSource.h" + +namespace zilliz { +namespace milvus { +namespace engine { + +NewMemManager::MemTablePtr NewMemManager::GetMemByTable(const std::string& table_id) { + auto memIt = mem_id_map_.find(table_id); + if (memIt != mem_id_map_.end()) { + return memIt->second; + } + + mem_id_map_[table_id] = std::make_shared(table_id, meta_, options_); + return mem_id_map_[table_id]; +} + +Status NewMemManager::InsertVectors(const std::string& table_id_, + size_t n_, + const float* vectors_, + IDNumbers& vector_ids_) { + + + std::unique_lock lock(mutex_); + + return InsertVectorsNoLock(table_id_, n_, vectors_, vector_ids_); +} + +Status NewMemManager::InsertVectorsNoLock(const std::string& table_id, + size_t n, + const float* vectors, + IDNumbers& vector_ids) { + MemTablePtr mem = GetMemByTable(table_id); + VectorSource::Ptr source = std::make_shared(n, vectors); + + auto status = mem->Add(source); + if (status.ok()) { + vector_ids = source->GetVectorIds(); + } + return status; +} + +Status NewMemManager::ToImmutable() { + std::unique_lock lock(mutex_); + MemIdMap temp_map; + for (auto& kv: mem_id_map_) { + if(kv.second->Empty()) { + temp_map.insert(kv); + continue;//empty table, no need to serialize + } + immu_mem_list_.push_back(kv.second); + } + + mem_id_map_.swap(temp_map); + return Status::OK(); +} + +Status NewMemManager::Serialize(std::set& table_ids) { + ToImmutable(); + std::unique_lock lock(serialization_mtx_); + table_ids.clear(); + for (auto& mem : immu_mem_list_) { + mem->Serialize(); + table_ids.insert(mem->GetTableId()); + } + immu_mem_list_.clear(); + return Status::OK(); +} + +Status NewMemManager::EraseMemVector(const std::string& table_id) { + {//erase MemVector from rapid-insert cache + std::unique_lock lock(mutex_); + mem_id_map_.erase(table_id); + } + + {//erase MemVector from serialize cache + std::unique_lock lock(serialization_mtx_); + MemList temp_list; + for (auto& mem : immu_mem_list_) { + if(mem->GetTableId() != table_id) { + temp_list.push_back(mem); + } + } + immu_mem_list_.swap(temp_list); + } + + return Status::OK(); +} + +} // namespace engine +} // namespace milvus +} // namespace zilliz \ No newline at end of file diff --git a/cpp/src/db/NewMemManager.h b/cpp/src/db/NewMemManager.h new file mode 100644 index 0000000000000000000000000000000000000000..a5f5a9ca13b8fa86758a76bc3263bf1c6812b3d8 --- /dev/null +++ b/cpp/src/db/NewMemManager.h @@ -0,0 +1,54 @@ +#pragma once + +#include "Meta.h" +#include "MemTable.h" +#include "Status.h" +#include "MemManagerAbstract.h" + +#include +#include +#include +#include +#include + +namespace zilliz { +namespace milvus { +namespace engine { + +class NewMemManager : public MemManagerAbstract { +public: + using MetaPtr = meta::Meta::Ptr; + using Ptr = std::shared_ptr; + using MemTablePtr = typename MemTable::Ptr; + + NewMemManager(const std::shared_ptr& meta, const Options& options) + : meta_(meta), options_(options) {} + + Status InsertVectors(const std::string& table_id, + size_t n, const float* vectors, IDNumbers& vector_ids) override; + + Status Serialize(std::set& table_ids) override; + + Status EraseMemVector(const std::string& table_id) override; + +private: + MemTablePtr GetMemByTable(const std::string& table_id); + + Status InsertVectorsNoLock(const std::string& table_id, + size_t n, const float* vectors, IDNumbers& vector_ids); + Status ToImmutable(); + + using MemIdMap = std::map; + using MemList = std::vector; + MemIdMap mem_id_map_; + MemList immu_mem_list_; + MetaPtr meta_; + Options options_; + std::mutex mutex_; + std::mutex serialization_mtx_; +}; // NewMemManager + + +} // namespace engine +} // namespace milvus +} // namespace zilliz \ No newline at end of file diff --git a/cpp/src/db/VectorSource.cpp b/cpp/src/db/VectorSource.cpp index b113b9ad5e1c538d3b1c425c83707a6bf1ba6e79..d032be51f68f7b5950f01b25e1f400a6173d06d3 100644 --- a/cpp/src/db/VectorSource.cpp +++ b/cpp/src/db/VectorSource.cpp @@ -24,13 +24,18 @@ Status VectorSource::Add(const ExecutionEnginePtr& execution_engine, auto start_time = METRICS_NOW_TIME; - num_vectors_added = current_num_vectors_added + num_vectors_to_add <= n_ ? num_vectors_to_add : n_ - current_num_vectors_added; + num_vectors_added = current_num_vectors_added + num_vectors_to_add <= n_ ? + num_vectors_to_add : n_ - current_num_vectors_added; IDNumbers vector_ids_to_add; id_generator_->GetNextIDNumbers(num_vectors_added, vector_ids_to_add); - Status status = execution_engine->AddWithIds(num_vectors_added, vectors_ + current_num_vectors_added, vector_ids_to_add.data()); + Status status = execution_engine->AddWithIds(num_vectors_added, + vectors_ + current_num_vectors_added * table_file_schema.dimension_, + vector_ids_to_add.data()); if (status.ok()) { current_num_vectors_added += num_vectors_added; - vector_ids_.insert(vector_ids_.end(), vector_ids_to_add.begin(), vector_ids_to_add.end()); + vector_ids_.insert(vector_ids_.end(), + std::make_move_iterator(vector_ids_to_add.begin()), + std::make_move_iterator(vector_ids_to_add.end())); } else { ENGINE_LOG_ERROR << "VectorSource::Add failed: " + status.ToString(); @@ -38,7 +43,9 @@ Status VectorSource::Add(const ExecutionEnginePtr& execution_engine, auto end_time = METRICS_NOW_TIME; auto total_time = METRICS_MICROSECONDS(start_time, end_time); - server::Metrics::GetInstance().AddVectorsPerSecondGaugeSet(static_cast(n_), static_cast(table_file_schema.dimension_), total_time); + server::Metrics::GetInstance().AddVectorsPerSecondGaugeSet(static_cast(n_), + static_cast(table_file_schema.dimension_), + total_time); return status; } diff --git a/cpp/unittest/db/mem_test.cpp b/cpp/unittest/db/mem_test.cpp index f68d1eb8e38b59735958515b1f46957ef8d55993..915610adccfc6a92474827ccad8c2b5ac78f97a4 100644 --- a/cpp/unittest/db/mem_test.cpp +++ b/cpp/unittest/db/mem_test.cpp @@ -7,6 +7,11 @@ #include "db/Factories.h" #include "db/Constants.h" #include "db/EngineFactory.h" +#include "metrics/Metrics.h" + +#include +#include +#include using namespace zilliz::milvus; @@ -29,6 +34,9 @@ namespace { vectors.clear(); vectors.resize(n*TABLE_DIM); float* data = vectors.data(); +// std::random_device rd; +// std::mt19937 gen(rd()); +// std::uniform_real_distribution<> dis(0.0, 1.0); for(int i = 0; i < n; i++) { for(int j = 0; j < TABLE_DIM; j++) data[TABLE_DIM * i + j] = drand48(); data[TABLE_DIM * i] += i / 2000.; @@ -169,7 +177,7 @@ TEST(MEM_TEST, MEM_TABLE_TEST) { memTable.GetCurrentMemTableFile(memTableFile); ASSERT_EQ(memTableFile->GetCurrentMem(), n_100 * singleVectorMem); - ASSERT_EQ(memTable.GetStackSize(), 2); + ASSERT_EQ(memTable.GetTableFileCount(), 2); int64_t n_1G = 1024000; std::vector vectors_1G; @@ -183,8 +191,8 @@ TEST(MEM_TEST, MEM_TABLE_TEST) { vector_ids = source_1G->GetVectorIds(); ASSERT_EQ(vector_ids.size(), n_1G); - int expectedStackSize = 2 + std::ceil((n_1G - n_100) * singleVectorMem / engine::MAX_TABLE_FILE_MEM); - ASSERT_EQ(memTable.GetStackSize(), expectedStackSize); + int expectedTableFileCount = 2 + std::ceil((n_1G - n_100) * singleVectorMem / engine::MAX_TABLE_FILE_MEM); + ASSERT_EQ(memTable.GetTableFileCount(), expectedTableFileCount); status = memTable.Serialize(); ASSERT_TRUE(status.ok()); @@ -193,4 +201,127 @@ TEST(MEM_TEST, MEM_TABLE_TEST) { ASSERT_TRUE(status.ok()); } +TEST(MEM_TEST, MEM_MANAGER_TEST) { + + auto options = engine::OptionsFactory::Build(); + options.meta.path = "/tmp/milvus_test"; + options.meta.backend_uri = "sqlite://:@:/"; + auto db_ = engine::DBFactory::Build(options); + + engine::meta::TableSchema table_info = BuildTableSchema(); + engine::Status stat = db_->CreateTable(table_info); + + engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = TABLE_NAME; + stat = db_->DescribeTable(table_info_get); + ASSERT_STATS(stat); + ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); + + std::map> search_vectors; +// std::map> vectors_ids_map; + { + engine::IDNumbers vector_ids; + int64_t nb = 1024000; + std::vector xb; + BuildVectors(nb, xb); + engine::Status status = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + ASSERT_TRUE(status.ok()); + +// std::ofstream myfile("mem_test.txt"); +// for (int64_t i = 0; i < nb; ++i) { +// int64_t vector_id = vector_ids[i]; +// std::vector vectors; +// for (int64_t j = 0; j < TABLE_DIM; j++) { +// vectors.emplace_back(xb[i*TABLE_DIM + j]); +//// std::cout << xb[i*TABLE_DIM + j] << std::endl; +// } +// vectors_ids_map[vector_id] = vectors; +// } + + std::this_thread::sleep_for(std::chrono::seconds(3)); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution dis(0, nb - 1); + + int64_t numQuery = 1000; + for (int64_t i = 0; i < numQuery; ++i) { + int64_t index = dis(gen); + std::vector search; + for (int64_t j = 0; j < TABLE_DIM; j++) { + search.push_back(xb[index * TABLE_DIM + j]); + } + search_vectors.insert(std::make_pair(vector_ids[index], search)); +// std::cout << "index: " << index << " vector_ids[index]: " << vector_ids[index] << std::endl; + } + +// for (int64_t i = 0; i < nb; i += 100000) { +// std::vector search; +// for (int64_t j = 0; j < TABLE_DIM; j++) { +// search.push_back(xb[i * TABLE_DIM + j]); +// } +// search_vectors.insert(std::make_pair(vector_ids[i], search)); +// } + + } + + int k = 10; + for(auto& pair : search_vectors) { + auto& search = pair.second; + engine::QueryResults results; + stat = db_->Query(TABLE_NAME, k, 1, search.data(), results); + for(int t = 0; t < k; t++) { +// std::cout << "ID=" << results[0][t].first << " DISTANCE=" << results[0][t].second << std::endl; + +// std::cout << vectors_ids_map[results[0][t].first].size() << std::endl; +// for (auto& data : vectors_ids_map[results[0][t].first]) { +// std::cout << data << " "; +// } +// std::cout << std::endl; + } + // std::cout << "results[0][0].first: " << results[0][0].first << " pair.first: " << pair.first << " results[0][0].second: " << results[0][0].second << std::endl; + ASSERT_EQ(results[0][0].first, pair.first); + ASSERT_LT(results[0][0].second, 0.00001); + } + + stat = db_->DropAll(); + ASSERT_TRUE(stat.ok()); + +} + +TEST(MEM_TEST, INSERT_TEST) { + + auto options = engine::OptionsFactory::Build(); + options.meta.path = "/tmp/milvus_test"; + options.meta.backend_uri = "sqlite://:@:/"; + auto db_ = engine::DBFactory::Build(options); + + engine::meta::TableSchema table_info = BuildTableSchema(); + engine::Status stat = db_->CreateTable(table_info); + + engine::meta::TableSchema table_info_get; + table_info_get.table_id_ = TABLE_NAME; + stat = db_->DescribeTable(table_info_get); + ASSERT_STATS(stat); + ASSERT_EQ(table_info_get.dimension_, TABLE_DIM); + + auto start_time = METRICS_NOW_TIME; + + int insert_loop = 1000; + for (int i = 0; i < insert_loop; ++i) { + int64_t nb = 204800; + std::vector xb; + BuildVectors(nb, xb); + engine::IDNumbers vector_ids; + engine::Status status = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids); + ASSERT_TRUE(status.ok()); + } + auto end_time = METRICS_NOW_TIME; + auto total_time = METRICS_MICROSECONDS(start_time, end_time); + std::cout << "total_time(ms) : " << total_time << std::endl; + + stat = db_->DropAll(); + ASSERT_TRUE(stat.ok()); + +}