未验证 提交 997a5b87 编写于 作者: G groot 提交者: GitHub

implement loadcollection (#3229)

* implement loadcollection
Signed-off-by: groot <yihua.mo@zilliz.com>

* typo
Signed-off-by: groot <yihua.mo@zilliz.com>

* refine code
Signed-off-by: groot <yihua.mo@zilliz.com>
Co-authored-by: Wang Xiangyu <xy.wang@zilliz.com>
上级 a2e1f923
......@@ -24,7 +24,9 @@ namespace cache {
// Constructor: create the CPU cache with the configured capacity and register
// this manager for "cache.cache_size" config-change notifications.
CpuCacheMgr::CpuCacheMgr() {
    cache_ = std::make_shared<Cache<DataObjPtr>>(config.cache.cache_size(), 1UL << 32, "[CACHE CPU]");

    // Apply the free-memory threshold only when it is a meaningful (positive)
    // value; a non-positive threshold keeps the Cache's built-in default.
    // (The former unconditional set_freemem_percent() call is removed — it
    // bypassed this guard and applied invalid thresholds.)
    if (config.cache.cpu_cache_threshold() > 0.0) {
        cache_->set_freemem_percent(config.cache.cpu_cache_threshold());
    }

    ConfigMgr::GetInstance().Attach("cache.cache_size", this);
}
......
......@@ -24,15 +24,13 @@ namespace cache {
std::mutex GpuCacheMgr::global_mutex_;
std::unordered_map<int64_t, GpuCacheMgrPtr> GpuCacheMgr::instance_;
namespace {
constexpr int64_t G_BYTE = 1024 * 1024 * 1024;
}
// Constructor: create the per-GPU cache with the configured capacity and
// register for "gpu.cache_threshold" config-change notifications.
GpuCacheMgr::GpuCacheMgr(int64_t gpu_id) : gpu_id_(gpu_id) {
    std::string header = "[CACHE GPU" + std::to_string(gpu_id) + "]";
    cache_ = std::make_shared<Cache<DataObjPtr>>(config.gpu.cache_size(), 1UL << 32, header);

    // Apply the free-memory threshold only when positive; a non-positive
    // value keeps the Cache's built-in default. (The former unconditional
    // set_freemem_percent() call is removed — it bypassed this guard.)
    if (config.gpu.cache_threshold() > 0.0) {
        cache_->set_freemem_percent(config.gpu.cache_threshold());
    }

    ConfigMgr::GetInstance().Attach("gpu.cache_threshold", this);
}
......@@ -51,11 +49,6 @@ GpuCacheMgr::GetInstance(int64_t gpu_id) {
return instance_[gpu_id];
}
// Try to reserve `size` bytes of capacity in this GPU cache.
// Thin forwarder to the generic CacheMgr implementation; returns whether the
// reservation succeeded.
bool
GpuCacheMgr::Reserve(const int64_t size) {
    return CacheMgr<DataObjPtr>::Reserve(size);
}
void
GpuCacheMgr::ConfigUpdate(const std::string& name) {
std::lock_guard<std::mutex> lock(global_mutex_);
......
......@@ -36,9 +36,6 @@ class GpuCacheMgr : public CacheMgr<DataObjPtr>, public ConfigObserver {
static GpuCacheMgrPtr
GetInstance(int64_t gpu_id);
bool
Reserve(const int64_t size);
public:
void
ConfigUpdate(const std::string& name) override;
......
......@@ -23,7 +23,12 @@ constexpr int64_t TB = 1LL << 40;
constexpr int64_t MAX_TABLE_FILE_MEM = 128 * MB;
constexpr int FLOAT_TYPE_SIZE = sizeof(float);
constexpr int64_t BUILD_INDEX_THRESHOLD = 4096; // row count threshold when building index
constexpr int64_t MAX_NAME_LENGTH = 255;
constexpr int64_t MAX_DIMENSION = 32768;
constexpr int32_t MAX_SEGMENT_ROW_COUNT = 4 * 1024 * 1024;
constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000; // default row count per segment when creating collection
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * MB;
} // namespace engine
} // namespace milvus
......@@ -653,10 +653,11 @@ DBImpl::LoadCollection(const server::ContextPtr& context, const std::string& col
snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name));
auto handler = std::make_shared<LoadVectorFieldHandler>(context, ss);
auto handler = std::make_shared<LoadCollectionHandler>(nullptr, ss, options_.meta_.path_, field_names, force);
handler->Iterate();
STATUS_CHECK(handler->GetStatus());
return handler->GetStatus();
return Status::OK();
}
Status
......
......@@ -10,6 +10,7 @@
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "db/SnapshotHandlers.h"
#include "db/SnapshotUtils.h"
#include "db/SnapshotVisitor.h"
#include "db/Types.h"
#include "db/snapshot/ResourceHelper.h"
......@@ -24,53 +25,6 @@
namespace milvus {
namespace engine {
// Constructor: remember the request context and the target vector field;
// iteration machinery lives in the BaseT (IterateHandler) base class.
LoadVectorFieldElementHandler::LoadVectorFieldElementHandler(const std::shared_ptr<server::Context>& context,
                                                             snapshot::ScopedSnapshotT ss,
                                                             const snapshot::FieldPtr& field)
    : BaseT(ss), context_(context), field_(field) {
}
// Handle one field element of the target field.
// Errors out if the bound field is not vector-typed; silently skips elements
// belonging to other fields. The actual load is not implemented yet (SS TODO).
Status
LoadVectorFieldElementHandler::Handle(const snapshot::FieldElementPtr& field_element) {
    // This handler is only valid for vector fields.
    if (field_->GetFtype() != engine::DataType::VECTOR_FLOAT && field_->GetFtype() != engine::DataType::VECTOR_BINARY) {
        return Status(DB_ERROR, "Should be VECTOR field");
    }
    // Ignore elements that belong to a different field.
    if (field_->GetID() != field_element->GetFieldId()) {
        return Status::OK();
    }
    // SS TODO
    return Status::OK();
}
// Constructor: remember the request context; iteration machinery lives in
// the BaseT (IterateHandler) base class.
LoadVectorFieldHandler::LoadVectorFieldHandler(const std::shared_ptr<server::Context>& context,
                                               snapshot::ScopedSnapshotT ss)
    : BaseT(ss), context_(context) {
}
// Handle one field of the collection: skip non-vector fields, abort if the
// client connection has dropped, then delegate to the element-level handler.
// The actual data load is not implemented yet (SS TODO).
Status
LoadVectorFieldHandler::Handle(const snapshot::FieldPtr& field) {
    // Non-vector fields are silently skipped (not an error).
    if (field->GetFtype() != engine::DataType::VECTOR_FLOAT && field->GetFtype() != engine::DataType::VECTOR_BINARY) {
        return Status::OK();
    }
    // Bail out early when the client has gone away — loading is expensive.
    if (context_ && context_->IsConnectionBroken()) {
        LOG_ENGINE_DEBUG_ << "Client connection broken, stop load collection";
        return Status(DB_ERROR, "Connection broken");
    }

    // SS TODO
    auto element_handler = std::make_shared<LoadVectorFieldElementHandler>(context_, ss_, field);
    element_handler->Iterate();

    auto status = element_handler->GetStatus();
    if (!status.ok()) {
        return status;
    }

    // SS TODO: Do Load
    return status;
}
///////////////////////////////////////////////////////////////////////////////
SegmentsToSearchCollector::SegmentsToSearchCollector(snapshot::ScopedSnapshotT ss, snapshot::IDS_TYPE& segment_ids)
: BaseT(ss), segment_ids_(segment_ids) {
......@@ -181,6 +135,48 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
}
///////////////////////////////////////////////////////////////////////////////
// Constructor: capture the request context, the on-disk root directory of
// segment files, the field names to load (empty = all fields), and the
// force-reload flag. Iteration machinery lives in BaseT (IterateHandler).
LoadCollectionHandler::LoadCollectionHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss,
                                             const std::string& dir_root, const std::vector<std::string>& field_names,
                                             bool force)
    : BaseT(ss), context_(context), dir_root_(dir_root), field_names_(field_names), force_(force) {
}
// Load one segment's data into the CPU cache. For every requested field the
// handler first tries to load an index; when no index exists it falls back to
// loading the raw data file. SegmentReader inserts whatever it reads into the
// CPU cache as a side effect, so nothing needs to be returned here.
// NOTE(review): force_ is currently unused in this method — confirm its
// intended semantics.
Status
LoadCollectionHandler::Handle(const snapshot::SegmentPtr& segment) {
    auto seg_visitor = engine::SegmentVisitor::Build(ss_, segment->GetID());
    segment::SegmentReaderPtr segment_reader = std::make_shared<segment::SegmentReader>(dir_root_, seg_visitor);

    SegmentPtr segment_ptr;
    auto status = segment_reader->GetSegment(segment_ptr);
    if (!status.ok() || segment_ptr == nullptr) {
        // Without the segment object we cannot resolve field types below;
        // previously a failure here led to a null-pointer dereference.
        return status;
    }

    // if the input field_names is empty, will load all fields of this collection
    if (field_names_.empty()) {
        field_names_ = ss_->GetFieldNames();
    }

    // SegmentReader will load data into cache
    for (auto& field_name : field_names_) {
        DataType ftype = DataType::NONE;
        segment_ptr->GetFieldType(field_name, ftype);

        knowhere::IndexPtr index_ptr;
        if (IsVectorField(ftype)) {
            knowhere::VecIndexPtr vec_index_ptr;
            // Failure is tolerated on purpose: a missing index simply leaves
            // vec_index_ptr null and we fall back to the raw file below.
            segment_reader->LoadVectorIndex(field_name, vec_index_ptr);
            index_ptr = vec_index_ptr;
        } else {
            // Same tolerance for structured (scalar) indexes.
            segment_reader->LoadStructuredIndex(field_name, index_ptr);
        }

        // if index doesn't exist, load the raw file
        if (index_ptr == nullptr) {
            engine::BinaryDataPtr raw;
            segment_reader->LoadField(field_name, raw);
        }
    }

    return Status::OK();
}
} // namespace engine
} // namespace milvus
......@@ -25,30 +25,6 @@
namespace milvus {
namespace engine {
// Iterates the field elements of a snapshot field and (eventually) loads the
// vector-typed ones. Element-level counterpart of LoadVectorFieldHandler.
struct LoadVectorFieldElementHandler : public snapshot::FieldElementIterator {
    using ResourceT = snapshot::FieldElement;
    using BaseT = snapshot::IterateHandler<ResourceT>;
    LoadVectorFieldElementHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss,
                                  const snapshot::FieldPtr& field);

    Status
    Handle(const typename ResourceT::Ptr&) override;

    const server::ContextPtr context_;  // request context; may be null
    const snapshot::FieldPtr field_;    // the field whose elements are handled
};
// Iterates all fields of a snapshot and delegates vector fields to
// LoadVectorFieldElementHandler; non-vector fields are skipped.
struct LoadVectorFieldHandler : public snapshot::FieldIterator {
    using ResourceT = snapshot::Field;
    using BaseT = snapshot::IterateHandler<ResourceT>;
    LoadVectorFieldHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss);

    Status
    Handle(const typename ResourceT::Ptr&) override;

    const server::ContextPtr context_;  // request context; may be null
};
struct SegmentsToSearchCollector : public snapshot::SegmentCommitIterator {
using ResourceT = snapshot::SegmentCommit;
using BaseT = snapshot::IterateHandler<ResourceT>;
......@@ -93,6 +69,20 @@ struct GetEntityByIdSegmentHandler : public snapshot::SegmentIterator {
};
///////////////////////////////////////////////////////////////////////////////
// Segment iterator that loads collection data (indexes, or raw files when no
// index exists) into the CPU cache. An empty field_names_ means "all fields".
struct LoadCollectionHandler : public snapshot::SegmentIterator {
    using ResourceT = snapshot::Segment;
    using BaseT = snapshot::IterateHandler<ResourceT>;
    LoadCollectionHandler(const server::ContextPtr& context, snapshot::ScopedSnapshotT ss, const std::string& dir_root,
                          const std::vector<std::string>& field_names, bool force);

    Status
    Handle(const typename ResourceT::Ptr&) override;

    const server::ContextPtr context_;      // request context; may be null
    const std::string dir_root_;            // root directory of segment files on disk
    std::vector<std::string> field_names_;  // fields to load; empty = load all fields
    bool force_;                            // force reload flag — TODO confirm semantics
};
} // namespace engine
} // namespace milvus
......@@ -154,7 +154,12 @@ IsVectorField(const engine::snapshot::FieldPtr& field) {
}
engine::DataType ftype = static_cast<engine::DataType>(field->GetFtype());
return ftype == engine::DataType::VECTOR_FLOAT || ftype == engine::DataType::VECTOR_BINARY;
return IsVectorField(ftype);
}
// A data type denotes a vector field iff it is one of the two vector types.
bool
IsVectorField(engine::DataType type) {
    switch (type) {
        case engine::DataType::VECTOR_FLOAT:
        case engine::DataType::VECTOR_BINARY:
            return true;
        default:
            return false;
    }
}
Status
......
......@@ -46,6 +46,9 @@ DeleteSnapshotIndex(const std::string& collection_name, const std::string& field
bool
IsVectorField(const engine::snapshot::FieldPtr& field);
bool
IsVectorField(engine::DataType type);
Status
GetSnapshotInfo(const std::string& collection_name, milvus::json& json_info);
......
......@@ -23,6 +23,7 @@
#include <vector>
#include "cache/DataObj.h"
#include "db/Constants.h"
#include "knowhere/index/vector_index/VecIndex.h"
#include "utils/Json.h"
......@@ -153,14 +154,6 @@ extern const char* PARAM_SEGMENT_ROW_COUNT;
extern const char* DEFAULT_STRUCTURED_INDEX;
constexpr int64_t BUILD_INDEX_THRESHOLD = 4096; // row count threshold when building index
constexpr int64_t MAX_NAME_LENGTH = 255;
constexpr int64_t MAX_DIMENSION = 32768;
constexpr int32_t MAX_SEGMENT_ROW_COUNT = 4 * 1024 * 1024;
constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000; // default row count per segment when creating collection
constexpr int64_t M_BYTE = 1024 * 1024;
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * M_BYTE;
enum FieldElementType {
FET_NONE = 0,
FET_RAW = 1,
......
......@@ -18,7 +18,6 @@
#include <vector>
#include "config/ServerConfig.h"
#include "db/Constants.h"
#include "db/Types.h"
#include "db/Utils.h"
#include "db/snapshot/Operations.h"
......
......@@ -102,7 +102,7 @@ SegmentReader::Load() {
}
Status
SegmentReader::LoadField(const std::string& field_name, engine::BinaryDataPtr& raw) {
SegmentReader::LoadField(const std::string& field_name, engine::BinaryDataPtr& raw, bool to_cache) {
try {
segment_ptr_->GetFixedFieldData(field_name, raw);
if (raw != nullptr) {
......@@ -124,7 +124,9 @@ SegmentReader::LoadField(const std::string& field_name, engine::BinaryDataPtr& r
auto& ss_codec = codec::Codec::instance();
ss_codec.GetBlockFormat()->Read(fs_ptr_, file_path, raw);
cache::CpuCacheMgr::GetInstance().InsertItem(file_path, raw); // put into cache
if (to_cache) {
cache::CpuCacheMgr::GetInstance().InsertItem(file_path, raw); // put into cache
}
} else {
raw = std::static_pointer_cast<engine::BinaryData>(data_obj);
}
......@@ -300,7 +302,7 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex
}
int64_t dimension = json[knowhere::meta::DIM];
engine::BinaryDataPtr raw;
STATUS_CHECK(LoadField(field_name, raw));
STATUS_CHECK(LoadField(field_name, raw, false));
auto dataset = knowhere::GenDataset(segment_commit->GetRowCount(), dimension, raw->data_.data());
......@@ -319,6 +321,8 @@ SegmentReader::LoadVectorIndex(const std::string& field_name, knowhere::VecIndex
index_ptr->SetUids(uids);
index_ptr->SetBlacklist(concurrent_bitset_ptr);
segment_ptr_->SetVectorIndex(field_name, index_ptr);
cache::CpuCacheMgr::GetInstance().InsertItem(temp_index_path, index_ptr);
}
return Status::OK();
......
......@@ -37,7 +37,7 @@ class SegmentReader {
Load();
Status
LoadField(const std::string& field_name, engine::BinaryDataPtr& raw);
LoadField(const std::string& field_name, engine::BinaryDataPtr& raw, bool to_cache = true);
Status
LoadFields();
......
......@@ -17,6 +17,7 @@
#include <set>
#include <string>
#include <experimental/filesystem>
#include <src/cache/CpuCacheMgr.h>
#include "db/SnapshotUtils.h"
#include "db/SnapshotVisitor.h"
......@@ -617,9 +618,9 @@ TEST_F(DBTest, MergeTest) {
std::string res_path = milvus::engine::snapshot::GetResPath<SegmentFile>(root_path, segment_file);
if (std::experimental::filesystem::is_regular_file(res_path) ||
std::experimental::filesystem::is_regular_file(res_path +
milvus::codec::IdBloomFilterFormat::FilePostfix()) ||
milvus::codec::IdBloomFilterFormat::FilePostfix()) ||
std::experimental::filesystem::is_regular_file(res_path +
milvus::codec::DeletedDocsFormat::FilePostfix())) {
milvus::codec::DeletedDocsFormat::FilePostfix())) {
segment_file_paths.insert(res_path);
std::cout << res_path << std::endl;
}
......@@ -660,7 +661,7 @@ TEST_F(DBTest, GetEntityTest) {
};
auto fill_field_names = [&](const milvus::engine::snapshot::FieldElementMappings& field_mappings,
std::vector<std::string>& field_names) -> void {
std::vector<std::string>& field_names) -> void {
if (field_names.empty()) {
for (const auto& schema : field_mappings) {
field_names.emplace_back(schema.first->GetName());
......@@ -704,9 +705,6 @@ TEST_F(DBTest, GetEntityTest) {
status = db_->GetCollectionInfo(collection_name, collection, field_mappings);
ASSERT_TRUE(status.ok()) << status.ToString();
{
std::vector<std::string> field_names;
fill_field_names(field_mappings, field_names);
......@@ -717,13 +715,12 @@ TEST_F(DBTest, GetEntityTest) {
ASSERT_TRUE(status.ok()) << status.ToString();
ASSERT_TRUE(get_data_chunk->count_ == get_row_size(valid_row));
for (const auto &name : field_names) {
for (const auto& name : field_names) {
ASSERT_TRUE(get_data_chunk->fixed_fields_[name]->Size() == dataChunkPtr->fixed_fields_[name]->Size());
ASSERT_TRUE(get_data_chunk->fixed_fields_[name]->data_ == dataChunkPtr->fixed_fields_[name]->data_);
}
}
{
std::vector<std::string> field_names;
fill_field_names(field_mappings, field_names);
......@@ -750,7 +747,7 @@ TEST_F(DBTest, GetEntityTest) {
ASSERT_TRUE(status.ok()) << status.ToString();
ASSERT_TRUE(get_data_chunk->count_ == get_row_size(valid_row));
for (const auto &name : field_names) {
for (const auto& name : field_names) {
ASSERT_TRUE(get_data_chunk->fixed_fields_[name]->Size() == dataChunkPtr->fixed_fields_[name]->Size());
ASSERT_TRUE(get_data_chunk->fixed_fields_[name]->data_ == dataChunkPtr->fixed_fields_[name]->data_);
}
......@@ -1099,7 +1096,6 @@ TEST_F(DBTest, FetchTest) {
milvus::engine::IDNumbers segment_entity_ids;
status = db_->ListIDInSegment(collection_name, id, segment_entity_ids);
std::cout << status.message() << std::endl;
ASSERT_TRUE(status.ok());
if (tag == partition_name) {
......@@ -1255,7 +1251,7 @@ TEST_F(DBTest, DeleteStaleTest) {
status = db_->Flush(collection_name);
ASSERT_TRUE(status.ok()) << status.ToString();
for (size_t i = 0; i < del_id_pair; i ++) {
for (size_t i = 0; i < del_id_pair; i++) {
del_ids.push_back(entity_ids[i]);
del_ids.push_back(entity_ids2[i]);
}
......@@ -1281,3 +1277,58 @@ TEST_F(DBTest, DeleteStaleTest) {
// ASSERT_EQ(entity_data_chunk->count_, 0) << "[" << j << "] Delete id " << del_ids[j] << " failed.";
// }
}
// End-to-end test for DBImpl::LoadCollection: insert entities into two
// partitions, then verify that loading pushes the expected number of files
// (and at least the expected number of bytes) into the CPU cache.
TEST_F(DBTest, LoadTest) {
    std::string collection_name = "LOAD_TEST";
    auto status = CreateCollection2(db_, collection_name, 0);
    ASSERT_TRUE(status.ok());

    std::string partition_name = "p1";
    status = db_->CreatePartition(collection_name, partition_name);
    ASSERT_TRUE(status.ok());

    // insert 1000 entities into default partition
    // insert 1000 entities into partition 'p1'
    const uint64_t entity_count = 1000;
    milvus::engine::DataChunkPtr data_chunk;
    BuildEntities(entity_count, 0, data_chunk);

    status = db_->Insert(collection_name, "", data_chunk);
    ASSERT_TRUE(status.ok());

    data_chunk->fixed_fields_.erase(milvus::engine::FIELD_UID);  // clear auto-generated id
    status = db_->Insert(collection_name, partition_name, data_chunk);
    ASSERT_TRUE(status.ok());

    status = db_->Flush();
    ASSERT_TRUE(status.ok());

    // start from an empty cache so ItemCount/CacheUsage reflect only this load
    auto& cache_mgr = milvus::cache::CpuCacheMgr::GetInstance();
    cache_mgr.ClearCache();

    // load "vector", "field_1"
    std::vector<std::string> fields = {VECTOR_FIELD_NAME, "field_1"};
    status = db_->LoadCollection(dummy_context_, collection_name, fields);
    ASSERT_TRUE(status.ok());

    // 2 segments, 2 fields, at least 4 files loaded
    ASSERT_GE(cache_mgr.ItemCount(), 4);

    // lower bound: raw vector data plus one int64 uid per entity, x2 segments
    int64_t total_size = entity_count * (COLLECTION_DIM * sizeof(float) + sizeof(int64_t)) * 2;
    ASSERT_GE(cache_mgr.CacheUsage(), total_size);

    // load all fields
    fields.clear();
    cache_mgr.ClearCache();
    status = db_->LoadCollection(dummy_context_, collection_name, fields);
    ASSERT_TRUE(status.ok());

    // 2 segments, 4 fields, at least 8 files loaded
    ASSERT_GE(cache_mgr.ItemCount(), 8);

    // lower bound now also counts the int32 and double scalar fields
    total_size =
        entity_count * (COLLECTION_DIM * sizeof(float) + sizeof(int32_t) + sizeof(int64_t) + sizeof(double)) * 2;
    ASSERT_GE(cache_mgr.CacheUsage(), total_size);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册