From 997a5b8700353ff4293b40c7eb88b683e069d386 Mon Sep 17 00:00:00 2001 From: groot Date: Thu, 13 Aug 2020 00:08:26 +0800 Subject: [PATCH] implement loadcollection (#3229) * implement loadcollection Signed-off-by: groot * typo Signed-off-by: groot * refine code Signed-off-by: groot Co-authored-by: Wang Xiangyu --- core/src/cache/CpuCacheMgr.cpp | 4 +- core/src/cache/GpuCacheMgr.cpp | 13 +---- core/src/cache/GpuCacheMgr.h | 3 - core/src/db/Constants.h | 7 ++- core/src/db/DBImpl.cpp | 5 +- core/src/db/SnapshotHandlers.cpp | 90 ++++++++++++++---------------- core/src/db/SnapshotHandlers.h | 38 +++++-------- core/src/db/SnapshotUtils.cpp | 7 ++- core/src/db/SnapshotUtils.h | 3 + core/src/db/Types.h | 9 +-- core/src/db/insert/MemSegment.cpp | 1 - core/src/segment/SegmentReader.cpp | 10 +++- core/src/segment/SegmentReader.h | 2 +- core/unittest/db/test_db.cpp | 73 ++++++++++++++++++++---- 14 files changed, 152 insertions(+), 113 deletions(-) diff --git a/core/src/cache/CpuCacheMgr.cpp b/core/src/cache/CpuCacheMgr.cpp index b71e59a9..3c4b23a2 100644 --- a/core/src/cache/CpuCacheMgr.cpp +++ b/core/src/cache/CpuCacheMgr.cpp @@ -24,7 +24,9 @@ namespace cache { CpuCacheMgr::CpuCacheMgr() { cache_ = std::make_shared>(config.cache.cache_size(), 1UL << 32, "[CACHE CPU]"); - cache_->set_freemem_percent(config.cache.cpu_cache_threshold()); + if (config.cache.cpu_cache_threshold() > 0.0) { + cache_->set_freemem_percent(config.cache.cpu_cache_threshold()); + } ConfigMgr::GetInstance().Attach("cache.cache_size", this); } diff --git a/core/src/cache/GpuCacheMgr.cpp b/core/src/cache/GpuCacheMgr.cpp index f1486a57..3d020bc6 100644 --- a/core/src/cache/GpuCacheMgr.cpp +++ b/core/src/cache/GpuCacheMgr.cpp @@ -24,15 +24,13 @@ namespace cache { std::mutex GpuCacheMgr::global_mutex_; std::unordered_map GpuCacheMgr::instance_; -namespace { -constexpr int64_t G_BYTE = 1024 * 1024 * 1024; -} - GpuCacheMgr::GpuCacheMgr(int64_t gpu_id) : gpu_id_(gpu_id) { std::string header = "[CACHE GPU" + std::to_string(gpu_id) + "]"; cache_ = std::make_shared>(config.gpu.cache_size(), 1UL << 32, header); - cache_->set_freemem_percent(config.gpu.cache_threshold()); + if (config.gpu.cache_threshold() > 0.0) { + cache_->set_freemem_percent(config.gpu.cache_threshold()); + } ConfigMgr::GetInstance().Attach("gpu.cache_threshold", this); } @@ -51,11 +49,6 @@ GpuCacheMgr::GetInstance(int64_t gpu_id) { return instance_[gpu_id]; } -bool -GpuCacheMgr::Reserve(const int64_t size) { - return CacheMgr::Reserve(size); -} - void GpuCacheMgr::ConfigUpdate(const std::string& name) { std::lock_guard lock(global_mutex_); diff --git a/core/src/cache/GpuCacheMgr.h b/core/src/cache/GpuCacheMgr.h index f2dd4c89..8c648ea5 100644 --- a/core/src/cache/GpuCacheMgr.h +++ b/core/src/cache/GpuCacheMgr.h @@ -36,9 +36,6 @@ class GpuCacheMgr : public CacheMgr, public ConfigObserver { static GpuCacheMgrPtr GetInstance(int64_t gpu_id); - bool - Reserve(const int64_t size); - public: void ConfigUpdate(const std::string& name) override; diff --git a/core/src/db/Constants.h b/core/src/db/Constants.h index c8afeed1..7422298f 100644 --- a/core/src/db/Constants.h +++ b/core/src/db/Constants.h @@ -23,7 +23,12 @@ constexpr int64_t TB = 1LL << 40; constexpr int64_t MAX_TABLE_FILE_MEM = 128 * MB; -constexpr int FLOAT_TYPE_SIZE = sizeof(float); +constexpr int64_t BUILD_INDEX_THRESHOLD = 4096; // row count threshold when building index +constexpr int64_t MAX_NAME_LENGTH = 255; +constexpr int64_t MAX_DIMENSION = 32768; +constexpr int32_t MAX_SEGMENT_ROW_COUNT = 4 * 1024 * 1024; +constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000; // default row count per segment when creating collection +constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * MB; } // namespace engine } // namespace milvus diff --git a/core/src/db/DBImpl.cpp b/core/src/db/DBImpl.cpp index 62e9ae35..932416a4 100644 --- a/core/src/db/DBImpl.cpp +++ b/core/src/db/DBImpl.cpp @@ -653,10 +653,11 @@ DBImpl::LoadCollection(const server::ContextPtr& context, const std::string& col snapshot::ScopedSnapshotT ss; STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name)); - auto handler = std::make_shared(context, ss); + auto handler = std::make_shared(nullptr, ss, options_.meta_.path_, field_names, force); handler->Iterate(); + STATUS_CHECK(handler->GetStatus()); - return handler->GetStatus(); + return Status::OK(); } Status diff --git a/core/src/db/SnapshotHandlers.cpp b/core/src/db/SnapshotHandlers.cpp index 4e2b85ec..f399dc56 100644 --- a/core/src/db/SnapshotHandlers.cpp +++ b/core/src/db/SnapshotHandlers.cpp @@ -10,6 +10,7 @@ // or implied. See the License for the specific language governing permissions and limitations under the License. #include "db/SnapshotHandlers.h" +#include "db/SnapshotUtils.h" #include "db/SnapshotVisitor.h" #include "db/Types.h" #include "db/snapshot/ResourceHelper.h" @@ -24,53 +25,6 @@ namespace milvus { namespace engine { -LoadVectorFieldElementHandler::LoadVectorFieldElementHandler(const std::shared_ptr& context, - snapshot::ScopedSnapshotT ss, - const snapshot::FieldPtr& field) - : BaseT(ss), context_(context), field_(field) { -} - -Status -LoadVectorFieldElementHandler::Handle(const snapshot::FieldElementPtr& field_element) { - if (field_->GetFtype() != engine::DataType::VECTOR_FLOAT && field_->GetFtype() != engine::DataType::VECTOR_BINARY) { - return Status(DB_ERROR, "Should be VECTOR field"); - } - if (field_->GetID() != field_element->GetFieldId()) { - return Status::OK(); - } - // SS TODO - return Status::OK(); -} - -LoadVectorFieldHandler::LoadVectorFieldHandler(const std::shared_ptr& context, - snapshot::ScopedSnapshotT ss) - : BaseT(ss), context_(context) { -} - -Status -LoadVectorFieldHandler::Handle(const snapshot::FieldPtr& field) { - if (field->GetFtype() != engine::DataType::VECTOR_FLOAT && field->GetFtype() != engine::DataType::VECTOR_BINARY) { - return Status::OK(); - } - if (context_ && context_->IsConnectionBroken()) { - LOG_ENGINE_DEBUG_ << "Client connection broken, stop load collection"; - return Status(DB_ERROR, "Connection broken"); - } - - // SS TODO - auto element_handler = std::make_shared(context_, ss_, field); - element_handler->Iterate(); - - auto status = element_handler->GetStatus(); - if (!status.ok()) { - return status; - } - - // SS TODO: Do Load - - return status; -} - /////////////////////////////////////////////////////////////////////////////// SegmentsToSearchCollector::SegmentsToSearchCollector(snapshot::ScopedSnapshotT ss, snapshot::IDS_TYPE& segment_ids) : BaseT(ss), segment_ids_(segment_ids) { @@ -181,6 +135,48 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) { } /////////////////////////////////////////////////////////////////////////////// +LoadCollectionHandler::LoadCollectionHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, + const std::string& dir_root, const std::vector& field_names, + bool force) + : BaseT(ss), context_(context), dir_root_(dir_root), field_names_(field_names), force_(force) { +} + +Status +LoadCollectionHandler::Handle(const snapshot::SegmentPtr& segment) { + auto seg_visitor = engine::SegmentVisitor::Build(ss_, segment->GetID()); + segment::SegmentReaderPtr segment_reader = std::make_shared(dir_root_, seg_visitor); + + SegmentPtr segment_ptr; + segment_reader->GetSegment(segment_ptr); + + // if the input field_names is empty, will load all fields of this collection + if (field_names_.empty()) { + field_names_ = ss_->GetFieldNames(); + } + + // SegmentReader will load data into cache + for (auto& field_name : field_names_) { + DataType ftype = DataType::NONE; + segment_ptr->GetFieldType(field_name, ftype); + + knowhere::IndexPtr index_ptr; + if (IsVectorField(ftype)) { + knowhere::VecIndexPtr vec_index_ptr; + segment_reader->LoadVectorIndex(field_name, vec_index_ptr); + index_ptr = vec_index_ptr; + } else { + segment_reader->LoadStructuredIndex(field_name, index_ptr); + } + + // if index doesn't exist, load the raw file + if (index_ptr == nullptr) { + engine::BinaryDataPtr raw; + segment_reader->LoadField(field_name, raw); + } + } + + return Status::OK(); +} } // namespace engine } // namespace milvus diff --git a/core/src/db/SnapshotHandlers.h b/core/src/db/SnapshotHandlers.h index fe8602f8..4dcc355f 100644 --- a/core/src/db/SnapshotHandlers.h +++ b/core/src/db/SnapshotHandlers.h @@ -25,30 +25,6 @@ namespace milvus { namespace engine { -struct LoadVectorFieldElementHandler : public snapshot::FieldElementIterator { - using ResourceT = snapshot::FieldElement; - using BaseT = snapshot::IterateHandler; - LoadVectorFieldElementHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, - const snapshot::FieldPtr& field); - - Status - Handle(const typename ResourceT::Ptr&) override; - - const server::ContextPtr context_; - const snapshot::FieldPtr field_; -}; - -struct LoadVectorFieldHandler : public snapshot::FieldIterator { - using ResourceT = snapshot::Field; - using BaseT = snapshot::IterateHandler; - LoadVectorFieldHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss); - - Status - Handle(const typename ResourceT::Ptr&) override; - - const server::ContextPtr context_; -}; - struct SegmentsToSearchCollector : public snapshot::SegmentCommitIterator { using ResourceT = snapshot::SegmentCommit; using BaseT = snapshot::IterateHandler; @@ -93,6 +69,20 @@ struct GetEntityByIdSegmentHandler : public snapshot::SegmentIterator { }; /////////////////////////////////////////////////////////////////////////////// +struct LoadCollectionHandler : public snapshot::SegmentIterator { + using ResourceT = snapshot::Segment; + using BaseT = snapshot::IterateHandler; + LoadCollectionHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, const std::string& dir_root, + const std::vector& field_names, bool force); + + Status + Handle(const typename ResourceT::Ptr&) override; + + const server::ContextPtr context_; + const std::string dir_root_; + std::vector field_names_; + bool force_; +}; } // namespace engine } // namespace milvus diff --git a/core/src/db/SnapshotUtils.cpp b/core/src/db/SnapshotUtils.cpp index b1eb8e95..b87654ec 100644 --- a/core/src/db/SnapshotUtils.cpp +++ b/core/src/db/SnapshotUtils.cpp @@ -154,7 +154,12 @@ IsVectorField(const engine::snapshot::FieldPtr& field) { } engine::DataType ftype = static_cast(field->GetFtype()); - return ftype == engine::DataType::VECTOR_FLOAT || ftype == engine::DataType::VECTOR_BINARY; + return IsVectorField(ftype); +} + +bool +IsVectorField(engine::DataType type) { + return type == engine::DataType::VECTOR_FLOAT || type == engine::DataType::VECTOR_BINARY; } Status diff --git a/core/src/db/SnapshotUtils.h b/core/src/db/SnapshotUtils.h index a63b8d67..a2d9a47b 100644 --- a/core/src/db/SnapshotUtils.h +++ b/core/src/db/SnapshotUtils.h @@ -46,6 +46,9 @@ DeleteSnapshotIndex(const std::string& collection_name, const std::string& field bool IsVectorField(const engine::snapshot::FieldPtr& field); +bool +IsVectorField(engine::DataType type); + Status GetSnapshotInfo(const std::string& collection_name, milvus::json& json_info); diff --git a/core/src/db/Types.h b/core/src/db/Types.h index 467fd113..6581d7d5 100644 --- a/core/src/db/Types.h +++ b/core/src/db/Types.h @@ -23,6 +23,7 @@ #include #include "cache/DataObj.h" +#include "db/Constants.h" #include "knowhere/index/vector_index/VecIndex.h" #include "utils/Json.h" @@ -153,14 +154,6 @@ extern const char* PARAM_SEGMENT_ROW_COUNT; extern const char* DEFAULT_STRUCTURED_INDEX; -constexpr int64_t BUILD_INDEX_THRESHOLD = 4096; // row count threshold when building index -constexpr int64_t MAX_NAME_LENGTH = 255; -constexpr int64_t MAX_DIMENSION = 32768; -constexpr int32_t MAX_SEGMENT_ROW_COUNT = 4 * 1024 * 1024; -constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000; // default row count per segment when creating collection -constexpr int64_t M_BYTE = 1024 * 1024; -constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE; - enum FieldElementType { FET_NONE = 0, FET_RAW = 1, diff --git a/core/src/db/insert/MemSegment.cpp b/core/src/db/insert/MemSegment.cpp index 6a161fc5..640b7c45 100644 --- a/core/src/db/insert/MemSegment.cpp +++ b/core/src/db/insert/MemSegment.cpp @@ -18,7 +18,6 @@ #include #include "config/ServerConfig.h" -#include "db/Constants.h" #include "db/Types.h" #include "db/Utils.h" #include "db/snapshot/Operations.h" diff --git a/core/src/segment/SegmentReader.cpp b/core/src/segment/SegmentReader.cpp index cbaea086..7fa7f39a 100644 --- a/core/src/segment/SegmentReader.cpp +++ b/core/src/segment/SegmentReader.cpp @@ -102,7 +102,7 @@ SegmentReader::Load() { } Status -SegmentReader::LoadField(const std::string& field_name, engine::BinaryDataPtr& raw) { +SegmentReader::LoadField(const std::string& field_name, engine::BinaryDataPtr& raw, bool to_cache) { try { segment_ptr_->GetFixedFieldData(field_name, raw); if (raw != nullptr) { @@ -124,7 +124,9 @@ SegmentReader::LoadField(const std::string& field_name, engine::BinaryDataPtr& r auto& ss_codec = codec::Codec::instance(); ss_codec.GetBlockFormat()->Read(fs_ptr_, file_path, raw); - cache::CpuCacheMgr::GetInstance().InsertItem(file_path, raw); // put into cache + if (to_cache) { + cache::CpuCacheMgr::GetInstance().InsertItem(file_path, raw); // put into cache + } } else { raw = std::static_pointer_cast(data_obj); } @@ -300,7 +302,7 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex } int64_t dimension = json[knowhere::meta::DIM]; engine::BinaryDataPtr raw; - STATUS_CHECK(LoadField(field_name, raw)); + STATUS_CHECK(LoadField(field_name, raw, false)); auto dataset = knowhere::GenDataset(segment_commit->GetRowCount(), dimension, raw->data_.data()); @@ -319,6 +321,8 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex index_ptr->SetUids(uids); index_ptr->SetBlacklist(concurrent_bitset_ptr); segment_ptr_->SetVectorIndex(field_name, index_ptr); + + cache::CpuCacheMgr::GetInstance().InsertItem(temp_index_path, index_ptr); } return Status::OK(); diff --git a/core/src/segment/SegmentReader.h b/core/src/segment/SegmentReader.h index f26982a9..9496d6ea 100644 --- a/core/src/segment/SegmentReader.h +++ b/core/src/segment/SegmentReader.h @@ -37,7 +37,7 @@ class SegmentReader { Load(); Status - LoadField(const std::string& field_name, engine::BinaryDataPtr& raw); + LoadField(const std::string& field_name, engine::BinaryDataPtr& raw, bool to_cache = true); Status LoadFields(); diff --git a/core/unittest/db/test_db.cpp b/core/unittest/db/test_db.cpp index 4a8faa32..0aa40825 100644 --- a/core/unittest/db/test_db.cpp +++ b/core/unittest/db/test_db.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include "db/SnapshotUtils.h" #include "db/SnapshotVisitor.h" @@ -617,9 +618,9 @@ TEST_F(DBTest, MergeTest) { std::string res_path = milvus::engine::snapshot::GetResPath(root_path, segment_file); if (std::experimental::filesystem::is_regular_file(res_path) || std::experimental::filesystem::is_regular_file(res_path + - milvus::codec::IdBloomFilterFormat::FilePostfix()) || + milvus::codec::IdBloomFilterFormat::FilePostfix()) || std::experimental::filesystem::is_regular_file(res_path + - milvus::codec::DeletedDocsFormat::FilePostfix())) { + milvus::codec::DeletedDocsFormat::FilePostfix())) { segment_file_paths.insert(res_path); std::cout << res_path << std::endl; } @@ -660,7 +661,7 @@ TEST_F(DBTest, GetEntityTest) { }; auto fill_field_names = [&](const milvus::engine::snapshot::FieldElementMappings& field_mappings, - std::vector& field_names) -> void { + std::vector& field_names) -> void { if (field_names.empty()) { for (const auto& schema : field_mappings) { field_names.emplace_back(schema.first->GetName()); @@ -704,9 +705,6 @@ TEST_F(DBTest, GetEntityTest) { status = db_->GetCollectionInfo(collection_name, collection, field_mappings); ASSERT_TRUE(status.ok()) << status.ToString(); - - - { std::vector field_names; fill_field_names(field_mappings, field_names); @@ -717,13 +715,12 @@ TEST_F(DBTest, GetEntityTest) { ASSERT_TRUE(status.ok()) << status.ToString(); ASSERT_TRUE(get_data_chunk->count_ == get_row_size(valid_row)); - for (const auto &name : field_names) { + for (const auto& name : field_names) { ASSERT_TRUE(get_data_chunk->fixed_fields_[name]->Size() == dataChunkPtr->fixed_fields_[name]->Size()); ASSERT_TRUE(get_data_chunk->fixed_fields_[name]->data_ == dataChunkPtr->fixed_fields_[name]->data_); } } - { std::vector field_names; fill_field_names(field_mappings, field_names); @@ -750,7 +747,7 @@ TEST_F(DBTest, GetEntityTest) { ASSERT_TRUE(status.ok()) << status.ToString(); ASSERT_TRUE(get_data_chunk->count_ == get_row_size(valid_row)); - for (const auto &name : field_names) { + for (const auto& name : field_names) { ASSERT_TRUE(get_data_chunk->fixed_fields_[name]->Size() == dataChunkPtr->fixed_fields_[name]->Size()); ASSERT_TRUE(get_data_chunk->fixed_fields_[name]->data_ == dataChunkPtr->fixed_fields_[name]->data_); } @@ -1099,7 +1096,6 @@ TEST_F(DBTest, FetchTest) { milvus::engine::IDNumbers segment_entity_ids; status = db_->ListIDInSegment(collection_name, id, segment_entity_ids); - std::cout << status.message() << std::endl; ASSERT_TRUE(status.ok()); if (tag == partition_name) { @@ -1255,7 +1251,7 @@ TEST_F(DBTest, DeleteStaleTest) { status = db_->Flush(collection_name); ASSERT_TRUE(status.ok()) << status.ToString(); - for (size_t i = 0; i < del_id_pair; i ++) { + for (size_t i = 0; i < del_id_pair; i++) { del_ids.push_back(entity_ids[i]); del_ids.push_back(entity_ids2[i]); } @@ -1281,3 +1277,58 @@ TEST_F(DBTest, DeleteStaleTest) { // ASSERT_EQ(entity_data_chunk->count_, 0) << "[" << j << "] Delete id " << del_ids[j] << " failed."; // } } + +TEST_F(DBTest, LoadTest) { + std::string collection_name = "LOAD_TEST"; + auto status = CreateCollection2(db_, collection_name, 0); + ASSERT_TRUE(status.ok()); + + std::string partition_name = "p1"; + status = db_->CreatePartition(collection_name, partition_name); + ASSERT_TRUE(status.ok()); + + // insert 1000 entities into default partition + // insert 1000 entities into partition 'p1' + const uint64_t entity_count = 1000; + milvus::engine::DataChunkPtr data_chunk; + BuildEntities(entity_count, 0, data_chunk); + + status = db_->Insert(collection_name, "", data_chunk); + ASSERT_TRUE(status.ok()); + + data_chunk->fixed_fields_.erase(milvus::engine::FIELD_UID); // clear auto-generated id + + status = db_->Insert(collection_name, partition_name, data_chunk); + ASSERT_TRUE(status.ok()); + + status = db_->Flush(); + ASSERT_TRUE(status.ok()); + + auto& cache_mgr = milvus::cache::CpuCacheMgr::GetInstance(); + cache_mgr.ClearCache(); + + // load "vector", "field_1" + std::vector fields = {VECTOR_FIELD_NAME, "field_1"}; + status = db_->LoadCollection(dummy_context_, collection_name, fields); + ASSERT_TRUE(status.ok()); + + // 2 segments, 2 fields, at least 4 files loaded + ASSERT_GE(cache_mgr.ItemCount(), 4); + + int64_t total_size = entity_count * (COLLECTION_DIM * sizeof(float) + sizeof(int64_t)) * 2; + ASSERT_GE(cache_mgr.CacheUsage(), total_size); + + // load all fields + fields.clear(); + cache_mgr.ClearCache(); + + status = db_->LoadCollection(dummy_context_, collection_name, fields); + ASSERT_TRUE(status.ok()); + + // 2 segments, 4 fields, at least 8 files loaded + ASSERT_GE(cache_mgr.ItemCount(), 8); + + total_size = + entity_count * (COLLECTION_DIM * sizeof(float) + sizeof(int32_t) + sizeof(int64_t) + sizeof(double)) * 2; + ASSERT_GE(cache_mgr.CacheUsage(), total_size); +} -- GitLab