Unverified commit 7f6a8fbe, authored by groot, committed by GitHub

rewrite insert memmanager for wal (#3391)

* prepare change memmanager for wal
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix wal test case
Signed-off-by: groot <yihua.mo@zilliz.com>

* rewrite insert memmanager
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix unit test failures
Signed-off-by: groot <yihua.mo@zilliz.com>

* rewrite insert machinery
Signed-off-by: groot <yihua.mo@zilliz.com>

* insert fields validation
Signed-off-by: groot <yihua.mo@zilliz.com>

* code format
Signed-off-by: groot <yihua.mo@zilliz.com>

* avoid build hang
Signed-off-by: groot <yihua.mo@zilliz.com>

* wal path
Signed-off-by: groot <yihua.mo@zilliz.com>

* typo
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix get entity by id bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* typo
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix wal bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix a bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix wal bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix wal test bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix wal path bug
Signed-off-by: groot <yihua.mo@zilliz.com>

* fix test failure
Signed-off-by: groot <yihua.mo@zilliz.com>

* typo
Signed-off-by: groot <yihua.mo@zilliz.com>
Parent 1590b105
......@@ -21,7 +21,7 @@ constexpr int64_t MB = 1LL << 20;
constexpr int64_t GB = 1LL << 30;
constexpr int64_t TB = 1LL << 40;
constexpr int64_t MAX_TABLE_FILE_MEM = 128 * MB;
constexpr int64_t MAX_MEM_SEGMENT_SIZE = 128 * MB;
constexpr int64_t MAX_NAME_LENGTH = 255;
constexpr int64_t MAX_DIMENSION = 32768;
......@@ -30,5 +30,7 @@ constexpr int64_t DEFAULT_SEGMENT_ROW_COUNT = 100000; // default row count per
constexpr int64_t MAX_INSERT_DATA_SIZE = 256 * MB;
constexpr int64_t MAX_WAL_FILE_SIZE = 256 * MB;
constexpr int64_t BUILD_INDEX_RETRY_TIMES = 3;
} // namespace engine
} // namespace milvus
......@@ -42,6 +42,7 @@
#include <fiu/fiu-local.h>
#include <src/scheduler/job/BuildIndexJob.h>
#include <limits>
#include <unordered_set>
#include <utility>
namespace milvus {
......@@ -168,8 +169,8 @@ DBImpl::CreateCollection(const snapshot::CreateCollectionContext& context) {
auto params = ctx.collection->GetParams();
if (params.find(PARAM_UID_AUTOGEN) == params.end()) {
params[PARAM_UID_AUTOGEN] = true;
ctx.collection->SetParams(params);
}
ctx.collection->SetParams(params);
// check uid existence
snapshot::FieldPtr uid_field;
......@@ -367,7 +368,7 @@ DBImpl::CreateIndex(const std::shared_ptr<server::Context>& context, const std::
// step 5: start background build index thread
std::vector<std::string> collection_names = {collection_name};
WaitBuildIndexFinish();
StartBuildIndexTask(collection_names);
StartBuildIndexTask(collection_names, true);
// step 6: iterate segments need to be build index, wait until all segments are built
while (true) {
......@@ -375,7 +376,14 @@ DBImpl::CreateIndex(const std::shared_ptr<server::Context>& context, const std::
snapshot::IDS_TYPE segment_ids;
ss_visitor.SegmentsToIndex(field_name, segment_ids);
if (segment_ids.empty()) {
break;
break; // index build finished for all segments
}
snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name));
IgnoreIndexFailedSegments(ss->GetCollectionId(), segment_ids);
if (segment_ids.empty()) {
break; // the remaining segments failed to build index and are ignored
}
index_req_swn_.Wait_For(std::chrono::seconds(1));
......@@ -398,8 +406,10 @@ DBImpl::DropIndex(const std::string& collection_name, const std::string& field_n
STATUS_CHECK(DeleteSnapshotIndex(collection_name, field_name));
std::set<std::string> merge_collection_names = {collection_name};
StartMergeTask(merge_collection_names, true);
snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name));
std::set<int64_t> collection_ids = {ss->GetCollectionId()};
StartMergeTask(collection_ids, true);
return Status::OK();
}
......@@ -427,8 +437,8 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name));
auto partition_ptr = ss->GetPartition(partition_name);
if (partition_ptr == nullptr) {
auto partition = ss->GetPartition(partition_name);
if (partition == nullptr) {
return Status(DB_NOT_FOUND, "Fail to get partition " + partition_name);
}
......@@ -437,6 +447,37 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
return Status(DB_ERROR, "Field '_id' not found");
}
// check field names
auto field_names = ss->GetFieldNames();
std::unordered_set<std::string> collection_field_names;
for (auto& name : field_names) {
collection_field_names.insert(name);
}
collection_field_names.erase(engine::FIELD_UID);
std::unordered_set<std::string> chunk_field_names;
for (auto& pair : data_chunk->fixed_fields_) {
chunk_field_names.insert(pair.first);
}
for (auto& pair : data_chunk->variable_fields_) {
chunk_field_names.insert(pair.first);
}
chunk_field_names.erase(engine::FIELD_UID);
if (collection_field_names.size() != chunk_field_names.size()) {
std::string msg = "Collection has " + std::to_string(collection_field_names.size()) +
" fields while the insert data has " + std::to_string(chunk_field_names.size()) + " fields";
return Status(DB_ERROR, msg);
} else {
for (auto& name : chunk_field_names) {
if (collection_field_names.find(name) == collection_field_names.end()) {
std::string msg = "The field " + name + " is not defined in collection mapping";
return Status(DB_ERROR, msg);
}
}
}
// check id field existence
auto& params = ss->GetCollection()->GetParams();
bool auto_increment = true;
if (params.find(PARAM_UID_AUTOGEN) != params.end()) {
......@@ -446,39 +487,44 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
FIXEDX_FIELD_MAP& fields = data_chunk->fixed_fields_;
auto pair = fields.find(engine::FIELD_UID);
if (auto_increment) {
// id is auto increment, but client provides id, return error
// id is auto generated, but client provides id, return error
if (pair != fields.end() && pair->second != nullptr) {
return Status(DB_ERROR, "Field '_id' is auto increment, no need to provide id");
}
} else {
// id is not auto increment, but client doesn't provide id, return error
// id is not auto generated, but client doesn't provide id, return error
if (pair == fields.end() || pair->second == nullptr) {
return Status(DB_ERROR, "Field '_id' is user defined");
}
}
// consume the data chunk
DataChunkPtr consume_chunk = std::make_shared<DataChunk>();
consume_chunk->count_ = data_chunk->count_;
consume_chunk->fixed_fields_.swap(data_chunk->fixed_fields_);
consume_chunk->variable_fields_.swap(data_chunk->variable_fields_);
// generate id
if (auto_increment) {
SafeIDGenerator& id_generator = SafeIDGenerator::GetInstance();
IDNumbers ids;
STATUS_CHECK(id_generator.GetNextIDNumbers(data_chunk->count_, ids));
STATUS_CHECK(id_generator.GetNextIDNumbers(consume_chunk->count_, ids));
BinaryDataPtr id_data = std::make_shared<BinaryData>();
id_data->data_.resize(ids.size() * sizeof(int64_t));
memcpy(id_data->data_.data(), ids.data(), ids.size() * sizeof(int64_t));
data_chunk->fixed_fields_[engine::FIELD_UID] = id_data;
}
// insert entities: collection_name is field id
snapshot::PartitionPtr part = ss->GetPartition(partition_name);
if (part == nullptr) {
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ", "insert", 0) << "Get partition fail: " << partition_name;
return Status(DB_ERROR, "Invalid partiiton name");
consume_chunk->fixed_fields_[engine::FIELD_UID] = id_data;
data_chunk->fixed_fields_[engine::FIELD_UID] = id_data; // return the generated ids to the client
} else {
BinaryDataPtr id_data = std::make_shared<BinaryData>();
id_data->data_ = consume_chunk->fixed_fields_[engine::FIELD_UID]->data_;
data_chunk->fixed_fields_[engine::FIELD_UID] = id_data; // return the ids provided by the client
}
// do insert
int64_t collection_id = ss->GetCollectionId();
int64_t partition_id = part->GetID();
int64_t partition_id = partition->GetID();
auto status = mem_mgr_->InsertEntities(collection_id, partition_id, data_chunk, op_id);
auto status = mem_mgr_->InsertEntities(collection_id, partition_id, consume_chunk, op_id);
if (!status.ok()) {
return status;
}
......@@ -793,7 +839,7 @@ DBImpl::Compact(const std::shared_ptr<server::Context>& context, const std::stri
void
DBImpl::InternalFlush(const std::string& collection_name, bool merge) {
Status status;
std::set<std::string> flushed_collections;
std::set<int64_t> flushed_collection_ids;
if (!collection_name.empty()) {
// flush one collection
snapshot::ScopedSnapshotT ss;
......@@ -810,34 +856,21 @@ DBImpl::InternalFlush(const std::string& collection_name, bool merge) {
if (!status.ok()) {
return;
}
flushed_collection_ids.insert(collection_id);
}
flushed_collections.insert(collection_name);
} else {
// flush all collections
std::set<int64_t> collection_ids;
{
const std::lock_guard<std::mutex> lock(flush_merge_compact_mutex_);
status = mem_mgr_->Flush(collection_ids);
status = mem_mgr_->Flush(flushed_collection_ids);
if (!status.ok()) {
return;
}
}
for (auto id : collection_ids) {
snapshot::ScopedSnapshotT ss;
status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, id);
if (!status.ok()) {
LOG_WAL_ERROR_ << LogOut("[%s][%ld] ", "flush", 0) << "Get snapshot fail: " << status.message();
return;
}
flushed_collections.insert(ss->GetName());
}
}
if (merge) {
StartMergeTask(flushed_collections);
StartMergeTask(flushed_collection_ids);
}
}
......@@ -907,7 +940,7 @@ DBImpl::TimingMetricThread() {
}
void
DBImpl::StartBuildIndexTask(const std::vector<std::string>& collection_names) {
DBImpl::StartBuildIndexTask(const std::vector<std::string>& collection_names, bool reset_retry_times) {
// build index has been finished?
{
std::lock_guard<std::mutex> lck(index_result_mutex_);
......@@ -923,6 +956,11 @@ DBImpl::StartBuildIndexTask(const std::vector<std::string>& collection_names) {
{
std::lock_guard<std::mutex> lck(index_result_mutex_);
if (index_thread_results_.empty()) {
if (reset_retry_times) {
std::lock_guard<std::mutex> lock(index_retry_mutex_);
index_retry_map_.clear(); // reset index retry times
}
index_thread_results_.push_back(
index_thread_pool_.enqueue(&DBImpl::BackgroundBuildIndexTask, this, collection_names));
}
......@@ -949,6 +987,14 @@ DBImpl::BackgroundBuildIndexTask(std::vector<std::string> collection_names) {
continue;
}
// check index retry times
snapshot::ID_TYPE collection_id = latest_ss->GetCollectionId();
IgnoreIndexFailedSegments(collection_id, segment_ids);
if (segment_ids.empty()) {
continue;
}
// start build index job
LOG_ENGINE_DEBUG_ << "Create BuildIndexJob for " << segment_ids.size() << " segments of " << collection_name;
cache::CpuCacheMgr::GetInstance().PrintInfo(); // print cache info before build index
scheduler::BuildIndexJobPtr job = std::make_shared<scheduler::BuildIndexJob>(latest_ss, options_, segment_ids);
......@@ -956,9 +1002,12 @@ DBImpl::BackgroundBuildIndexTask(std::vector<std::string> collection_names) {
job->WaitFinish();
cache::CpuCacheMgr::GetInstance().PrintInfo(); // print cache info after build index
// record failed segments to avoid build-index hang
snapshot::IDS_TYPE& failed_ids = job->FailedSegments();
MarkIndexFailedSegments(collection_id, failed_ids);
if (!job->status().ok()) {
LOG_ENGINE_ERROR_ << job->status().message();
break;
}
}
}
......@@ -981,7 +1030,7 @@ DBImpl::TimingIndexThread() {
std::vector<std::string> collection_names;
snapshot::Snapshots::GetInstance().GetCollectionNames(collection_names);
WaitMergeFileFinish();
StartBuildIndexTask(collection_names);
StartBuildIndexTask(collection_names, false);
}
}
......@@ -996,8 +1045,7 @@ DBImpl::WaitBuildIndexFinish() {
}
void
DBImpl::StartMergeTask(const std::set<std::string>& collection_names, bool force_merge_all) {
// LOG_ENGINE_DEBUG_ << "Begin StartMergeTask";
DBImpl::StartMergeTask(const std::set<int64_t>& collection_ids, bool force_merge_all) {
// merge task has been finished?
{
std::lock_guard<std::mutex> lck(merge_result_mutex_);
......@@ -1015,28 +1063,26 @@ DBImpl::StartMergeTask(const std::set<std::string>& collection_names, bool force
if (merge_thread_results_.empty()) {
// start merge file thread
merge_thread_results_.push_back(
merge_thread_pool_.enqueue(&DBImpl::BackgroundMerge, this, collection_names, force_merge_all));
merge_thread_pool_.enqueue(&DBImpl::BackgroundMerge, this, collection_ids, force_merge_all));
}
}
// LOG_ENGINE_DEBUG_ << "End StartMergeTask";
}
void
DBImpl::BackgroundMerge(std::set<std::string> collection_names, bool force_merge_all) {
DBImpl::BackgroundMerge(std::set<int64_t> collection_ids, bool force_merge_all) {
SetThreadName("merge");
for (auto& collection_name : collection_names) {
for (auto& collection_id : collection_ids) {
const std::lock_guard<std::mutex> lock(flush_merge_compact_mutex_);
auto status = merge_mgr_ptr_->MergeFiles(collection_name);
auto status = merge_mgr_ptr_->MergeFiles(collection_id);
if (!status.ok()) {
LOG_ENGINE_ERROR_ << "Failed to get merge files for collection: " << collection_name
LOG_ENGINE_ERROR_ << "Failed to get merge files for collection id: " << collection_id
<< " reason:" << status.message();
}
if (!initialized_.load(std::memory_order_acquire)) {
LOG_ENGINE_DEBUG_ << "Server will shutdown, skip merge action for collection: " << collection_name;
LOG_ENGINE_DEBUG_ << "Server will shutdown, skip merge action for collection id: " << collection_id;
break;
}
}
......@@ -1077,5 +1123,27 @@ DBImpl::ConfigUpdate(const std::string& name) {
}
}
void
DBImpl::MarkIndexFailedSegments(snapshot::ID_TYPE collection_id, const snapshot::IDS_TYPE& failed_ids) {
std::lock_guard<std::mutex> lock(index_retry_mutex_);
SegmentIndexRetryMap& retry_map = index_retry_map_[collection_id];
for (auto& id : failed_ids) {
retry_map[id]++;
}
}
void
DBImpl::IgnoreIndexFailedSegments(snapshot::ID_TYPE collection_id, snapshot::IDS_TYPE& segment_ids) {
std::lock_guard<std::mutex> lock(index_retry_mutex_);
SegmentIndexRetryMap& retry_map = index_retry_map_[collection_id];
snapshot::IDS_TYPE segment_ids_to_build;
for (auto id : segment_ids) {
if (retry_map[id] < BUILD_INDEX_RETRY_TIMES) {
segment_ids_to_build.push_back(id);
}
}
segment_ids.swap(segment_ids_to_build);
}
} // namespace engine
} // namespace milvus
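
A note on the new field-name validation in DBImpl::Insert above: the insert chunk's field set (fixed plus variable fields, with the reserved _id field excluded on both sides) must now match the collection schema exactly, both in count and in names. A minimal caller-side sketch, assuming a hypothetical collection with one vector field "vec" and one scalar field "age" (the field names, data buffers, and db handle are illustrative, not from this patch):

// hedged sketch: field names, data buffers and the db handle are hypothetical
engine::DataChunkPtr chunk = std::make_shared<engine::DataChunk>();
chunk->count_ = 100;
chunk->fixed_fields_["vec"] = vec_data;  // must be defined in the collection mapping
chunk->fixed_fields_["age"] = age_data;  // every schema field except _id must be present
// adding an undeclared field here would fail with
// "The field xxx is not defined in collection mapping"
auto status = db->Insert("demo_collection", "_default", chunk, 0 /* op_id */);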
......@@ -85,6 +85,7 @@ class DBImpl : public DB, public ConfigObserver {
Status
DescribeIndex(const std::string& collection_name, const std::string& field_name, CollectionIndex& index) override;
// Note: this method consumes the data_chunk and returns only the id field to the client
Status
Insert(const std::string& collection_name, const std::string& partition_name, DataChunkPtr& data_chunk,
idx_t op_id) override;
......@@ -103,7 +104,7 @@ class DBImpl : public DB, public ConfigObserver {
Status
ListIDInSegment(const std::string& collection_name, int64_t segment_id, IDNumbers& entity_ids) override;
// if the input field_names is empty, will load all fields of this collection
// Note: if the input field_names is empty, all fields of this collection will be loaded
Status
LoadCollection(const server::ContextPtr& context, const std::string& collection_name,
const std::vector<std::string>& field_names, bool force) override;
......@@ -114,6 +115,8 @@ class DBImpl : public DB, public ConfigObserver {
Status
Flush() override;
// Note: the threshold is the percentage of deleted entities that triggers the compact action,
// the default is 0.0, which means compact will create a new segment even if only one entity is deleted
Status
Compact(const server::ContextPtr& context, const std::string& collection_name, double threshold) override;
......@@ -134,7 +137,7 @@ class DBImpl : public DB, public ConfigObserver {
TimingMetricThread();
void
StartBuildIndexTask(const std::vector<std::string>& collection_names);
StartBuildIndexTask(const std::vector<std::string>& collection_names, bool reset_retry_times);
void
BackgroundBuildIndexTask(std::vector<std::string> collection_names);
......@@ -146,10 +149,10 @@ class DBImpl : public DB, public ConfigObserver {
WaitBuildIndexFinish();
void
StartMergeTask(const std::set<std::string>& collection_names, bool force_merge_all = false);
StartMergeTask(const std::set<int64_t>& collection_ids, bool force_merge_all = false);
void
BackgroundMerge(std::set<std::string> collection_names, bool force_merge_all);
BackgroundMerge(std::set<int64_t> collection_ids, bool force_merge_all);
void
WaitMergeFileFinish();
......@@ -160,6 +163,12 @@ class DBImpl : public DB, public ConfigObserver {
void
ResumeIfLast();
void
MarkIndexFailedSegments(snapshot::ID_TYPE collection_id, const snapshot::IDS_TYPE& failed_ids);
void
IgnoreIndexFailedSegments(snapshot::ID_TYPE collection_id, snapshot::IDS_TYPE& segment_ids);
private:
DBOptions options_;
std::atomic<bool> initialized_;
......@@ -186,6 +195,11 @@ class DBImpl : public DB, public ConfigObserver {
std::mutex index_result_mutex_;
std::list<std::future<void>> index_thread_results_;
using SegmentIndexRetryMap = std::unordered_map<snapshot::ID_TYPE, int64_t>;
using CollectionIndexRetryMap = std::unordered_map<snapshot::ID_TYPE, SegmentIndexRetryMap>;
CollectionIndexRetryMap index_retry_map_;
std::mutex index_retry_mutex_;
std::mutex build_index_mutex_;
std::mutex flush_merge_compact_mutex_;
......
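
To make the retry bookkeeping declared above concrete: index_retry_map_ is a two-level map, collection id to (segment id to failure count), guarded by index_retry_mutex_. A condensed sketch of how MarkIndexFailedSegments and IgnoreIndexFailedSegments cooperate, simplified from the DBImpl.cpp changes above (locking omitted, names shortened):

// simplified sketch; the real methods lock index_retry_mutex_
std::unordered_map<int64_t, std::unordered_map<int64_t, int64_t>> retry_map;

void Mark(int64_t coll_id, const std::vector<int64_t>& failed_ids) {
    for (auto id : failed_ids) {
        retry_map[coll_id][id]++;  // count every failed build attempt
    }
}

void Ignore(int64_t coll_id, std::vector<int64_t>& segment_ids) {
    std::vector<int64_t> to_build;
    for (auto id : segment_ids) {
        if (retry_map[coll_id][id] < 3 /* BUILD_INDEX_RETRY_TIMES */) {
            to_build.push_back(id);
        }
    }
    segment_ids.swap(to_build);  // segments that already failed 3 times are skipped
}

With this gate, CreateIndex no longer loops forever on a segment whose index build keeps failing; after the retry budget is exhausted the segment is ignored.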
......@@ -81,6 +81,7 @@ GetEntityByIdSegmentHandler::GetEntityByIdSegmentHandler(const std::shared_ptr<m
const std::vector<std::string>& field_names,
std::vector<bool>& valid_row)
: BaseT(ss), context_(context), dir_root_(dir_root), ids_(ids), field_names_(field_names), valid_row_(valid_row) {
ids_left_ = ids_;
}
Status
......@@ -102,19 +103,20 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
segment::DeletedDocsPtr deleted_docs_ptr;
STATUS_CHECK(segment_reader.LoadDeletedDocs(deleted_docs_ptr));
std::vector<idx_t> ids_in_this_segment;
std::vector<int64_t> offsets;
int i = 0;
for (auto id : ids_) {
for (IDNumbers::iterator it = ids_left_.begin(); it != ids_left_.end();) {
idx_t id = *it;
// fast check using bloom filter
if (!id_bloom_filter_ptr->Check(id)) {
i++;
++it;
continue;
}
// check if id really exists in uids
auto found = std::find(uids.begin(), uids.end(), id);
if (found == uids.end()) {
i++;
++it;
continue;
}
......@@ -124,16 +126,69 @@ GetEntityByIdSegmentHandler::Handle(const snapshot::SegmentPtr& segment) {
auto& deleted_docs = deleted_docs_ptr->GetDeletedDocs();
auto deleted = std::find(deleted_docs.begin(), deleted_docs.end(), offset);
if (deleted != deleted_docs.end()) {
i++;
++it;
continue;
}
}
valid_row_[i] = true;
ids_in_this_segment.push_back(id);
offsets.push_back(offset);
i++;
ids_left_.erase(it);
}
if (offsets.empty()) {
return Status::OK();
}
engine::DataChunkPtr data_chunk;
STATUS_CHECK(segment_reader.LoadFieldsEntities(field_names_, offsets, data_chunk));
// record which chunk each id is in, and its position within the chunk
for (int64_t i = 0; i < ids_in_this_segment.size(); ++i) {
auto pair = std::make_pair(data_chunk, i);
result_map_.insert(std::make_pair(ids_in_this_segment[i], pair));
}
return Status::OK();
}
Status
GetEntityByIdSegmentHandler::PostIterate() {
// construct result
// Note: make sure the result sequence follows the input ids
// for example:
// ids No.1, No.3, No.5 are in segment_1
// ids No.2, No.4, No.6 are in segment_2
// after iteration we get two DataChunks:
// chunk_1 holds entities No.1, No.3, No.5; chunk_2 holds entities No.2, No.4, No.6
// now we combine chunk_1 and chunk_2 into one DataChunk whose entity sequence is 1,2,3,4,5,6
Segment temp_segment;
auto& fields = ss_->GetResources<snapshot::Field>();
for (auto& kv : fields) {
const snapshot::FieldPtr& field = kv.second.Get();
STATUS_CHECK(temp_segment.AddField(field));
}
temp_segment.Reserve(field_names_, result_map_.size());
valid_row_.clear();
valid_row_.reserve(ids_.size());
for (auto id : ids_) {
auto iter = result_map_.find(id);
if (iter == result_map_.end()) {
valid_row_.push_back(false);
} else {
valid_row_.push_back(true);
auto pair = iter->second;
temp_segment.AppendChunk(pair.first, pair.second, pair.second);
}
}
STATUS_CHECK(segment_reader.LoadFieldsEntities(field_names_, offsets, data_chunk_));
data_chunk_ = std::make_shared<engine::DataChunk>();
data_chunk_->count_ = temp_segment.GetRowCount();
data_chunk_->fixed_fields_.swap(temp_segment.GetFixedFields());
data_chunk_->variable_fields_.swap(temp_segment.GetVariableFields());
return Status::OK();
}
......
......@@ -20,6 +20,8 @@
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace milvus {
......@@ -61,12 +63,20 @@ struct GetEntityByIdSegmentHandler : public snapshot::SegmentIterator {
Status
Handle(const typename ResourceT::Ptr&) override;
Status
PostIterate() override;
const server::ContextPtr context_;
const std::string dir_root_;
const engine::IDNumbers ids_;
const std::vector<std::string> field_names_;
engine::DataChunkPtr data_chunk_;
std::vector<bool>& valid_row_;
private:
engine::IDNumbers ids_left_;
using IDChunkMap = std::unordered_map<idx_t, std::pair<engine::DataChunkPtr, int64_t>>;
IDChunkMap result_map_; // records which chunk each id is in, and its position within the chunk
};
///////////////////////////////////////////////////////////////////////////////
......
......@@ -168,7 +168,7 @@ using QueryResultPtr = std::shared_ptr<QueryResult>;
struct DBMetaOptions {
std::string path_;
std::string backend_uri_;
}; // DBMetaOptions
};
///////////////////////////////////////////////////////////////////////////////////////////////////
struct DBOptions {
......@@ -178,7 +178,6 @@ struct DBOptions {
int mode_ = MODE::SINGLE;
size_t insert_buffer_size_ = 4 * GB;
bool insert_cache_immediately_ = false;
int64_t auto_flush_interval_ = 1;
......@@ -186,13 +185,12 @@ struct DBOptions {
// wal relative configurations
bool wal_enable_ = false;
int64_t buffer_size_ = 256;
std::string mxlog_path_ = "/tmp/milvus/wal/";
std::string wal_path_;
// transcript configurations
bool transcript_enable_ = false;
std::string replay_script_path_; // for replay
}; // Options
};
} // namespace engine
} // namespace milvus
......@@ -804,20 +804,26 @@ ExecutionEngineImpl::BuildKnowhereIndex(const std::string& field_name, const Col
std::vector<idx_t> uids;
faiss::ConcurrentBitsetPtr blacklist;
knowhere::DatasetPtr dataset;
if (from_index) {
auto dataset =
dataset =
knowhere::GenDatasetWithIds(row_count, dimension, from_index->GetRawVectors(), from_index->GetRawIds());
new_index->BuildAll(dataset, conf);
uids = from_index->GetUids();
blacklist = from_index->GetBlacklist();
} else if (bin_from_index) {
auto dataset = knowhere::GenDatasetWithIds(row_count, dimension, bin_from_index->GetRawVectors(),
bin_from_index->GetRawIds());
new_index->BuildAll(dataset, conf);
dataset = knowhere::GenDatasetWithIds(row_count, dimension, bin_from_index->GetRawVectors(),
bin_from_index->GetRawIds());
uids = bin_from_index->GetUids();
blacklist = bin_from_index->GetBlacklist();
}
try {
new_index->BuildAll(dataset, conf);
} catch (std::exception& ex) {
std::string msg = "Knowhere failed to build index: " + std::string(ex.what());
return Status(DB_ERROR, msg);
}
#ifdef MILVUS_GPU_VERSION
/* for GPU index, need copy back to CPU */
if (new_index->index_mode() == knowhere::IndexMode::MODE_GPU) {
......
......@@ -19,6 +19,7 @@
#include <ctime>
#include <memory>
#include <string>
#include <utility>
#include <fiu/fiu-local.h>
......@@ -27,6 +28,7 @@
#include "db/snapshot/CompoundOperations.h"
#include "db/snapshot/IterateHandler.h"
#include "db/snapshot/Snapshots.h"
#include "db/wal/WalManager.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
......@@ -39,56 +41,58 @@ MemCollection::MemCollection(int64_t collection_id, const DBOptions& options)
}
Status
MemCollection::Add(int64_t partition_id, const milvus::engine::VectorSourcePtr& source) {
while (!source->AllAdded()) {
std::lock_guard<std::mutex> lock(mutex_);
MemSegmentPtr current_mem_segment;
auto pair = mem_segments_.find(partition_id);
if (pair != mem_segments_.end()) {
MemSegmentList& segments = pair->second;
if (!segments.empty()) {
current_mem_segment = segments.back();
}
MemCollection::Add(int64_t partition_id, const DataChunkPtr& chunk, idx_t op_id) {
std::lock_guard<std::mutex> lock(mem_mutex_);
MemSegmentPtr current_mem_segment;
auto pair = mem_segments_.find(partition_id);
if (pair != mem_segments_.end()) {
MemSegmentList& segments = pair->second;
if (!segments.empty()) {
current_mem_segment = segments.back();
}
}
Status status;
if (current_mem_segment == nullptr || current_mem_segment->IsFull()) {
MemSegmentPtr new_mem_segment = std::make_shared<MemSegment>(collection_id_, partition_id, options_);
STATUS_CHECK(new_mem_segment->CreateSegment());
status = new_mem_segment->Add(source);
if (status.ok()) {
mem_segments_[partition_id].emplace_back(new_mem_segment);
} else {
return status;
}
int64_t chunk_size = utils::GetSizeOfChunk(chunk);
Status status;
if (current_mem_segment == nullptr || current_mem_segment->GetCurrentMem() + chunk_size > MAX_MEM_SEGMENT_SIZE) {
MemSegmentPtr new_mem_segment = std::make_shared<MemSegment>(collection_id_, partition_id, options_);
status = new_mem_segment->Add(chunk, op_id);
if (status.ok()) {
mem_segments_[partition_id].emplace_back(new_mem_segment);
} else {
status = current_mem_segment->Add(source);
return status;
}
} else {
status = current_mem_segment->Add(chunk, op_id);
}
if (!status.ok()) {
std::string err_msg = "Insert failed: " + status.ToString();
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ", "insert", 0) << err_msg;
return Status(DB_ERROR, err_msg);
}
if (!status.ok()) {
std::string err_msg = "Insert failed: " + status.ToString();
LOG_ENGINE_ERROR_ << LogOut("[%s][%ld] ", "insert", 0) << err_msg;
return Status(DB_ERROR, err_msg);
}
return Status::OK();
}
Status
MemCollection::Delete(const std::vector<idx_t>& ids) {
// Locate which collection file the doc id lands in
{
std::lock_guard<std::mutex> lock(mutex_);
for (auto& partition_segments : mem_segments_) {
MemSegmentList& segments = partition_segments.second;
for (auto& segment : segments) {
segment->Delete(ids);
}
}
MemCollection::Delete(const std::vector<idx_t>& ids, idx_t op_id) {
if (ids.empty()) {
return Status::OK();
}
// Add the id to delete list so it can be applied to other segments on disk during the next flush
// Add the id so it can be applied to segment files during the next flush
for (auto& id : ids) {
doc_ids_to_delete_.insert(id);
ids_to_delete_.insert(id);
}
// Add the id to mem segments so it can be applied during the next flush
std::lock_guard<std::mutex> lock(mem_mutex_);
for (auto& partition_segments : mem_segments_) {
for (auto& segment : partition_segments.second) {
segment->Delete(ids, op_id);
}
}
return Status::OK();
......@@ -96,7 +100,7 @@ MemCollection::Delete(const std::vector<idx_t>& ids) {
Status
MemCollection::EraseMem(int64_t partition_id) {
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock(mem_mutex_);
auto pair = mem_segments_.find(partition_id);
if (pair != mem_segments_.end()) {
mem_segments_.erase(pair);
......@@ -109,26 +113,16 @@ Status
MemCollection::Serialize() {
TimeRecorder recorder("MemCollection::Serialize collection " + std::to_string(collection_id_));
if (!doc_ids_to_delete_.empty()) {
while (true) {
auto status = ApplyDeletes();
if (status.ok()) {
break;
} else if (status.code() == SS_STALE_ERROR) {
std::string err = "ApplyDeletes is stale, try again";
LOG_ENGINE_WARNING_ << err;
continue;
} else {
std::string err = "ApplyDeletes failed: " + status.ToString();
LOG_ENGINE_ERROR_ << err;
return status;
}
}
// apply deleted ids to existing segment files
auto status = ApplyDeleteToFile();
if (!status.ok()) {
LOG_ENGINE_DEBUG_ << "Failed to apply deleted ids to segment files" << status.message();
// Note: don't return here, continue serialize mem segments
}
doc_ids_to_delete_.clear();
std::lock_guard<std::mutex> lock(mutex_);
// serialize mem to new segment files
// delete ids will be applied in MemSegment::Serialize() method
std::lock_guard<std::mutex> lock(mem_mutex_);
for (auto& partition_segments : mem_segments_) {
MemSegmentList& segments = partition_segments.second;
for (auto& segment : segments) {
......@@ -136,7 +130,6 @@ MemCollection::Serialize() {
if (!status.ok()) {
return status;
}
LOG_ENGINE_DEBUG_ << "Flushed segment " << segment->GetSegmentId() << " of collection " << collection_id_;
}
}
......@@ -147,32 +140,18 @@ MemCollection::Serialize() {
return Status::OK();
}
int64_t
MemCollection::GetCollectionId() const {
return collection_id_;
}
size_t
MemCollection::GetCurrentMem() {
std::lock_guard<std::mutex> lock(mutex_);
size_t total_mem = 0;
for (auto& partition_segments : mem_segments_) {
MemSegmentList& segments = partition_segments.second;
for (auto& segment : segments) {
total_mem += segment->GetCurrentMem();
}
}
return total_mem;
}
Status
MemCollection::ApplyDeletes() {
MemCollection::ApplyDeleteToFile() {
// iterate each segment to delete entities
snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id_));
snapshot::OperationContext context;
auto segments_op = std::make_shared<snapshot::CompoundSegmentsOperation>(context, ss);
std::unordered_set<idx_t> ids_to_delete;
ids_to_delete.swap(ids_to_delete_);
int64_t segment_iterated = 0;
auto segment_executor = [&](const snapshot::SegmentPtr& segment, snapshot::SegmentIterator* iterator) -> Status {
segment_iterated++;
......@@ -181,27 +160,22 @@ MemCollection::ApplyDeletes() {
std::make_shared<segment::SegmentReader>(options_.meta_.path_, seg_visitor);
// Step 1: Check delete_id in mem
std::vector<idx_t> delete_ids;
{
segment::IdBloomFilterPtr pre_bloom_filter;
STATUS_CHECK(segment_reader->LoadBloomFilter(pre_bloom_filter));
for (auto& id : doc_ids_to_delete_) {
if (pre_bloom_filter->Check(id)) {
delete_ids.push_back(id);
}
std::set<idx_t> ids_to_check;
segment::IdBloomFilterPtr pre_bloom_filter;
STATUS_CHECK(segment_reader->LoadBloomFilter(pre_bloom_filter));
for (auto& id : ids_to_delete) {
if (pre_bloom_filter->Check(id)) {
ids_to_check.insert(id);
}
}
if (delete_ids.empty()) {
return Status::OK();
}
if (ids_to_check.empty()) {
return Status::OK();
}
std::vector<engine::idx_t> uids;
STATUS_CHECK(segment_reader->LoadUids(uids));
std::sort(delete_ids.begin(), delete_ids.end());
std::set<idx_t> ids_to_check(delete_ids.begin(), delete_ids.end());
// Step 2: Mark previous deleted docs file and bloom filter file stale
auto& field_visitors_map = seg_visitor->GetFieldVisitors();
auto uid_field_visitor = seg_visitor->GetFieldVisitor(engine::FIELD_UID);
......@@ -307,5 +281,23 @@ MemCollection::ApplyDeletes() {
return segments_op->Push();
}
int64_t
MemCollection::GetCollectionId() const {
return collection_id_;
}
size_t
MemCollection::GetCurrentMem() {
std::lock_guard<std::mutex> lock(mem_mutex_);
size_t total_mem = 0;
for (auto& partition_segments : mem_segments_) {
MemSegmentList& segments = partition_segments.second;
for (auto& segment : segments) {
total_mem += segment->GetCurrentMem();
}
}
return total_mem;
}
} // namespace engine
} // namespace milvus
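
For clarity on the rollover rule in MemCollection::Add above: a new in-memory segment is opened whenever the current one cannot absorb the incoming chunk without exceeding MAX_MEM_SEGMENT_SIZE (128 MB, see the Constants.h change at the top). A worked example with invented sizes:

// MAX_MEM_SEGMENT_SIZE = 128 MB
// current segment holds 100 MB, incoming chunk is 40 MB
//   100 + 40 > 128  -> create a new MemSegment; the 40 MB chunk starts it
// a following 20 MB chunk:
//   40 + 20 <= 128  -> appended to that same new segment

Note that a chunk is never split across segments; it lands whole in whichever segment is chosen.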
......@@ -12,16 +12,17 @@
#pragma once
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "config/ConfigMgr.h"
#include "db/insert/MemSegment.h"
#include "db/insert/VectorSource.h"
#include "utils/Status.h"
namespace milvus {
......@@ -37,10 +38,10 @@ class MemCollection {
~MemCollection() = default;
Status
Add(int64_t partition_id, const VectorSourcePtr& source);
Add(int64_t partition_id, const DataChunkPtr& chunk, idx_t op_id);
Status
Delete(const std::vector<idx_t>& ids);
Delete(const std::vector<idx_t>& ids, idx_t op_id);
Status
EraseMem(int64_t partition_id);
......@@ -56,18 +57,16 @@ class MemCollection {
private:
Status
ApplyDeletes();
ApplyDeleteToFile();
private:
int64_t collection_id_;
MemSegmentMap mem_segments_;
DBOptions options_;
std::mutex mutex_;
MemSegmentMap mem_segments_;
std::mutex mem_mutex_;
std::set<idx_t> doc_ids_to_delete_;
std::unordered_set<idx_t> ids_to_delete_;
};
using MemCollectionPtr = std::shared_ptr<MemCollection>;
......
......@@ -14,7 +14,6 @@
#include <fiu/fiu-local.h>
#include <thread>
#include "VectorSource.h"
#include "db/Constants.h"
#include "db/snapshot/Snapshots.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
......@@ -42,9 +41,8 @@ MemManagerImpl::InsertEntities(int64_t collection_id, int64_t partition_id, cons
return status;
}
VectorSourcePtr source = std::make_shared<VectorSource>(chunk, op_id);
std::unique_lock<std::mutex> lock(mutex_);
return InsertEntitiesNoLock(collection_id, partition_id, source);
return InsertEntitiesNoLock(collection_id, partition_id, chunk, op_id);
}
Status
......@@ -140,11 +138,11 @@ MemManagerImpl::ValidateChunk(int64_t collection_id, const DataChunkPtr& chunk)
}
Status
MemManagerImpl::InsertEntitiesNoLock(int64_t collection_id, int64_t partition_id,
const milvus::engine::VectorSourcePtr& source) {
MemManagerImpl::InsertEntitiesNoLock(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk,
idx_t op_id) {
MemCollectionPtr mem = GetMemByCollection(collection_id);
auto status = mem->Add(partition_id, source);
auto status = mem->Add(partition_id, chunk, op_id);
return status;
}
......@@ -153,7 +151,7 @@ MemManagerImpl::DeleteEntities(int64_t collection_id, const std::vector<idx_t>&
std::unique_lock<std::mutex> lock(mutex_);
MemCollectionPtr mem = GetMemByCollection(collection_id);
auto status = mem->Delete(entity_ids);
auto status = mem->Delete(entity_ids, op_id);
if (!status.ok()) {
return status;
}
......@@ -186,13 +184,15 @@ MemManagerImpl::InternalFlush(std::set<int64_t>& collection_ids) {
std::unique_lock<std::mutex> lock(serialization_mtx_);
for (auto& mem : temp_immutable_list) {
LOG_ENGINE_DEBUG_ << "Flushing collection: " << mem->GetCollectionId();
int64_t collection_id = mem->GetCollectionId();
LOG_ENGINE_DEBUG_ << "Flushing collection: " << collection_id;
auto status = mem->Serialize();
if (!status.ok()) {
LOG_ENGINE_ERROR_ << "Flush collection " << mem->GetCollectionId() << " failed";
LOG_ENGINE_ERROR_ << "Flush collection " << collection_id << " failed";
return status;
}
LOG_ENGINE_DEBUG_ << "Flushed collection: " << mem->GetCollectionId();
LOG_ENGINE_DEBUG_ << "Flushed collection: " << collection_id;
collection_ids.insert(collection_id);
}
return Status::OK();
......
......@@ -73,7 +73,7 @@ class MemManagerImpl : public MemManager {
ValidateChunk(int64_t collection_id, const DataChunkPtr& chunk);
Status
InsertEntitiesNoLock(int64_t collection_id, int64_t partition_id, const VectorSourcePtr& source);
InsertEntitiesNoLock(int64_t collection_id, int64_t partition_id, const DataChunkPtr& chunk, idx_t op_id);
Status
ToImmutable();
......
......@@ -15,6 +15,7 @@
#include <cmath>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
#include "config/ServerConfig.h"
......@@ -22,8 +23,10 @@
#include "db/Utils.h"
#include "db/snapshot/Operations.h"
#include "db/snapshot/Snapshots.h"
#include "db/wal/WalManager.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "metrics/Metrics.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
namespace milvus {
......@@ -31,25 +34,110 @@ namespace engine {
MemSegment::MemSegment(int64_t collection_id, int64_t partition_id, const DBOptions& options)
: collection_id_(collection_id), partition_id_(partition_id), options_(options) {
current_mem_ = 0;
// CreateSegment();
}
Status
MemSegment::CreateSegment() {
MemSegment::Add(const DataChunkPtr& chunk, idx_t op_id) {
if (chunk == nullptr) {
return Status::OK();
}
MemAction action;
action.op_id_ = op_id;
action.insert_data_ = chunk;
actions_.emplace_back(action);
current_mem_ += utils::GetSizeOfChunk(chunk);
return Status::OK();
}
Status
MemSegment::Delete(const std::vector<idx_t>& ids, idx_t op_id) {
if (ids.empty()) {
return Status::OK();
}
MemAction action;
action.op_id_ = op_id;
for (auto& id : ids) {
action.delete_ids_.insert(id);
}
actions_.emplace_back(action);
return Status::OK();
}
Status
MemSegment::Serialize() {
int64_t size = GetCurrentMem();
server::CollectSerializeMetrics metrics(size);
// delete in mem
STATUS_CHECK(ApplyDeleteToMem());
// create new segment and serialize
snapshot::ScopedSnapshotT ss;
auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id_);
if (!status.ok()) {
std::string err_msg = "MemSegment::CreateSegment failed: " + status.ToString();
std::string err_msg = "Failed to get latest snapshot: " + status.ToString();
LOG_ENGINE_ERROR_ << err_msg;
return status;
}
std::shared_ptr<snapshot::NewSegmentOperation> new_seg_operation;
segment::SegmentWriterPtr segment_writer;
status = CreateNewSegment(ss, new_seg_operation, segment_writer);
if (!status.ok()) {
LOG_ENGINE_ERROR_ << "Failed to create new segment";
return status;
}
status = PutChunksToWriter(segment_writer);
if (!status.ok()) {
LOG_ENGINE_ERROR_ << "Failed to copy data to segment writer";
return status;
}
// delete action could delete all entities of the segment
// no need to serialize empty segment
if (segment_writer->RowCount() == 0) {
return Status::OK();
}
int64_t seg_id = 0;
segment_writer->GetSegmentID(seg_id);
status = segment_writer->Serialize();
if (!status.ok()) {
LOG_ENGINE_ERROR_ << "Failed to serialize segment: " << seg_id;
return status;
}
STATUS_CHECK(new_seg_operation->CommitRowCount(segment_writer->RowCount()));
STATUS_CHECK(new_seg_operation->Push());
LOG_ENGINE_DEBUG_ << "New segment " << seg_id << " of collection " << collection_id_ << " serialized";
// notify wal the max operation id is done
idx_t max_op_id = 0;
for (auto& action : actions_) {
if (action.op_id_ > max_op_id) {
max_op_id = action.op_id_;
}
}
WalManager::GetInstance().OperationDone(ss->GetName(), max_op_id);
return Status::OK();
}
Status
MemSegment::CreateNewSegment(snapshot::ScopedSnapshotT& ss, std::shared_ptr<snapshot::NewSegmentOperation>& operation,
segment::SegmentWriterPtr& writer) {
// create segment
snapshot::SegmentPtr segment;
snapshot::OperationContext context;
context.prev_partition = ss->GetResource<snapshot::Partition>(partition_id_);
operation_ = std::make_shared<snapshot::NewSegmentOperation>(context, ss);
status = operation_->CommitNewSegment(segment_);
operation = std::make_shared<snapshot::NewSegmentOperation>(context, ss);
auto status = operation->CommitNewSegment(segment);
if (!status.ok()) {
std::string err_msg = "MemSegment::CreateSegment failed: " + status.ToString();
LOG_ENGINE_ERROR_ << err_msg;
......@@ -62,12 +150,12 @@ MemSegment::CreateSegment() {
snapshot::SegmentFileContext sf_context;
sf_context.collection_id = collection_id_;
sf_context.partition_id = partition_id_;
sf_context.segment_id = segment_->GetID();
sf_context.segment_id = segment->GetID();
sf_context.field_name = name;
sf_context.field_element_name = engine::ELEMENT_RAW_DATA;
snapshot::SegmentFilePtr seg_file;
status = operation_->CommitNewSegmentFile(sf_context, seg_file);
status = operation->CommitNewSegmentFile(sf_context, seg_file);
if (!status.ok()) {
std::string err_msg = "MemSegment::CreateSegment failed: " + status.ToString();
LOG_ENGINE_ERROR_ << err_msg;
......@@ -80,12 +168,12 @@ MemSegment::CreateSegment() {
snapshot::SegmentFileContext sf_context;
sf_context.collection_id = collection_id_;
sf_context.partition_id = partition_id_;
sf_context.segment_id = segment_->GetID();
sf_context.segment_id = segment->GetID();
sf_context.field_name = engine::FIELD_UID;
sf_context.field_element_name = engine::ELEMENT_DELETED_DOCS;
snapshot::SegmentFilePtr delete_doc_file, bloom_filter_file;
status = operation_->CommitNewSegmentFile(sf_context, delete_doc_file);
status = operation->CommitNewSegmentFile(sf_context, delete_doc_file);
if (!status.ok()) {
std::string err_msg = "MemSegment::CreateSegment failed: " + status.ToString();
LOG_ENGINE_ERROR_ << err_msg;
......@@ -93,7 +181,7 @@ MemSegment::CreateSegment() {
}
sf_context.field_element_name = engine::ELEMENT_BLOOM_FILTER;
status = operation_->CommitNewSegmentFile(sf_context, bloom_filter_file);
status = operation->CommitNewSegmentFile(sf_context, bloom_filter_file);
if (!status.ok()) {
std::string err_msg = "MemSegment::CreateSegment failed: " + status.ToString();
LOG_ENGINE_ERROR_ << err_msg;
......@@ -101,72 +189,61 @@ MemSegment::CreateSegment() {
}
}
auto ctx = operation_->GetContext();
auto ctx = operation->GetContext();
auto visitor = SegmentVisitor::Build(ss, ctx.new_segment, ctx.new_segment_files);
// create segment writer
segment_writer_ptr_ = std::make_shared<segment::SegmentWriter>(options_.meta_.path_, visitor);
writer = std::make_shared<segment::SegmentWriter>(options_.meta_.path_, visitor);
return Status::OK();
}
Status
MemSegment::GetSingleEntitySize(int64_t& single_size) {
snapshot::ScopedSnapshotT ss;
auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id_);
if (!status.ok()) {
std::string err_msg = "MemSegment::SingleEntitySize failed: " + status.ToString();
LOG_ENGINE_ERROR_ << err_msg;
return status;
}
MemSegment::ApplyDeleteToMem() {
auto outer_iter = actions_.begin();
for (; outer_iter != actions_.end(); ++outer_iter) {
MemAction& action = (*outer_iter);
if (action.delete_ids_.empty()) {
continue;
}
single_size = 0;
std::vector<std::string> field_names = ss->GetFieldNames();
for (auto& name : field_names) {
snapshot::FieldPtr field = ss->GetField(name);
auto ftype = static_cast<DataType>(field->GetFtype());
switch (ftype) {
case DataType::BOOL:
single_size += sizeof(bool);
break;
case DataType::DOUBLE:
single_size += sizeof(double);
break;
case DataType::FLOAT:
single_size += sizeof(float);
break;
case DataType::INT8:
single_size += sizeof(uint8_t);
break;
case DataType::INT16:
single_size += sizeof(uint16_t);
break;
case DataType::INT32:
single_size += sizeof(uint32_t);
break;
case DataType::INT64:
single_size += sizeof(uint64_t);
break;
case DataType::VECTOR_FLOAT:
case DataType::VECTOR_BINARY: {
json params = field->GetParams();
if (params.find(knowhere::meta::DIM) == params.end()) {
std::string msg = "Vector field params must contain: dimension";
LOG_SERVER_ERROR_ << msg;
return Status(DB_ERROR, msg);
}
auto inner_iter = actions_.begin();
for (; inner_iter != outer_iter; ++inner_iter) {
MemAction& insert_action = (*inner_iter);
if (insert_action.insert_data_ == nullptr) {
continue;
}
int64_t dimension = params[knowhere::meta::DIM];
if (ftype == DataType::VECTOR_BINARY) {
single_size += (dimension / 8);
} else {
single_size += (dimension * sizeof(float));
}
DataChunkPtr& chunk = insert_action.insert_data_;
// load chunk uids
auto iter = chunk->fixed_fields_.find(FIELD_UID);
if (iter == chunk->fixed_fields_.end()) {
continue; // no uid field?
}
BinaryDataPtr& uid_data = iter->second;
if (uid_data == nullptr) {
continue; // no uid data?
}
if (uid_data->data_.size() / sizeof(idx_t) != chunk->count_) {
continue; // invalid uid data?
}
idx_t* uid = (idx_t*)(uid_data->data_.data());
break;
// calculate delete offsets
std::vector<offset_t> offsets;
for (int64_t i = 0; i < chunk->count_; ++i) {
if (action.delete_ids_.find(uid[i]) != action.delete_ids_.end()) {
offsets.push_back(i);
}
}
default:
break;
// delete entities from chunks
Segment temp_set;
STATUS_CHECK(temp_set.SetFields(collection_id_));
STATUS_CHECK(temp_set.AddChunk(chunk));
temp_set.DeleteEntity(offsets);
chunk->count_ = temp_set.GetRowCount();
}
}
......@@ -174,100 +251,23 @@ MemSegment::GetSingleEntitySize(int64_t& single_size) {
}
Status
MemSegment::Add(const VectorSourcePtr& source) {
int64_t single_entity_mem_size = 0;
auto status = GetSingleEntitySize(single_entity_mem_size);
if (!status.ok()) {
return status;
MemSegment::PutChunksToWriter(const segment::SegmentWriterPtr& writer) {
if (writer == nullptr) {
return Status(DB_ERROR, "Segment writer is null pointer");
}
size_t mem_left = GetMemLeft();
if (mem_left >= single_entity_mem_size && single_entity_mem_size != 0) {
int64_t num_entities_to_add = std::ceil(mem_left / single_entity_mem_size);
int64_t num_entities_added;
auto status = source->Add(segment_writer_ptr_, num_entities_to_add, num_entities_added);
if (status.ok()) {
current_mem_ += (num_entities_added * single_entity_mem_size);
}
return status;
}
return Status::OK();
}
Status
MemSegment::Delete(const std::vector<idx_t>& ids) {
engine::SegmentPtr segment_ptr;
segment_writer_ptr_->GetSegment(segment_ptr);
// Check whether the doc_id is present; if yes, delete its corresponding buffer
std::vector<idx_t> uids;
segment_writer_ptr_->LoadUids(uids);
std::vector<offset_t> offsets;
for (auto id : ids) {
auto found = std::find(uids.begin(), uids.end(), id);
if (found == uids.end()) {
for (auto& action : actions_) {
DataChunkPtr chunk = action.insert_data_;
if (chunk == nullptr || chunk->count_ == 0) {
continue;
}
auto offset = std::distance(uids.begin(), found);
offsets.push_back(offset);
}
segment_ptr->DeleteEntity(offsets);
return Status::OK();
}
int64_t
MemSegment::GetCurrentMem() {
return current_mem_;
}
int64_t
MemSegment::GetMemLeft() {
return (MAX_TABLE_FILE_MEM - current_mem_);
}
bool
MemSegment::IsFull() {
int64_t single_entity_mem_size = 0;
auto status = GetSingleEntitySize(single_entity_mem_size);
if (!status.ok()) {
return true;
}
return (GetMemLeft() < single_entity_mem_size);
}
Status
MemSegment::Serialize() {
int64_t size = GetCurrentMem();
server::CollectSerializeMetrics metrics(size);
// delete action could delete all entities of the segment
// no need to serialize empty segment
if (segment_writer_ptr_->RowCount() == 0) {
return Status::OK();
// copy data to writer
writer->AddChunk(chunk);
}
auto status = segment_writer_ptr_->Serialize();
if (!status.ok()) {
LOG_ENGINE_ERROR_ << "Failed to serialize segment: " << segment_->GetID();
return status;
}
STATUS_CHECK(operation_->CommitRowCount(segment_writer_ptr_->RowCount()));
STATUS_CHECK(operation_->Push());
LOG_ENGINE_DEBUG_ << "New segment " << segment_->GetID() << " serialized";
return Status::OK();
}
int64_t
MemSegment::GetSegmentId() const {
return segment_->GetID();
}
} // namespace engine
} // namespace milvus
......@@ -11,12 +11,14 @@
#pragma once
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <vector>
#include "config/ConfigMgr.h"
#include "db/insert/VectorSource.h"
#include "db/snapshot/CompoundOperations.h"
#include "db/snapshot/Resources.h"
#include "segment/SegmentWriter.h"
......@@ -25,6 +27,13 @@
namespace milvus {
namespace engine {
class MemAction {
public:
idx_t op_id_ = 0;
std::unordered_set<idx_t> delete_ids_;
DataChunkPtr insert_data_;
};
class MemSegment {
public:
MemSegment(int64_t collection_id, int64_t partition_id, const DBOptions& options);
......@@ -33,43 +42,39 @@ class MemSegment {
public:
Status
CreateSegment();
Add(const DataChunkPtr& chunk, idx_t op_id);
Status
Add(const VectorSourcePtr& source);
Status
Delete(const std::vector<idx_t>& ids);
int64_t
GetCurrentMem();
Delete(const std::vector<idx_t>& ids, idx_t op_id);
int64_t
GetMemLeft();
bool
IsFull();
GetCurrentMem() const {
return current_mem_;
}
Status
Serialize();
int64_t
GetSegmentId() const;
private:
Status
GetSingleEntitySize(int64_t& single_size);
CreateNewSegment(snapshot::ScopedSnapshotT& ss, std::shared_ptr<snapshot::NewSegmentOperation>& operation,
segment::SegmentWriterPtr& writer);
Status
ApplyDeleteToMem();
Status
PutChunksToWriter(const segment::SegmentWriterPtr& writer);
private:
int64_t collection_id_;
int64_t partition_id_;
std::shared_ptr<snapshot::NewSegmentOperation> operation_;
snapshot::SegmentPtr segment_;
DBOptions options_;
int64_t current_mem_;
int64_t current_mem_ = 0;
segment::SegmentWriterPtr segment_writer_ptr_;
using ActionArray = std::vector<MemAction>;
ActionArray actions_; // the actions array makes sure insert/delete actions are executed one by one
};
using MemSegmentPtr = std::shared_ptr<MemSegment>;
......
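
An illustration of the action ordering that MemSegment preserves (the ids and op ids below are invented): each Add/Delete call becomes one MemAction, and ApplyDeleteToMem applies every delete action only to the insert actions recorded before it.

// actions_ after three calls:
//   [0] insert_data_ = chunk with ids {1, 2, 3}   (op_id 10)
//   [1] delete_ids_  = {2}                        (op_id 11)
//   [2] insert_data_ = chunk with ids {2, 4}      (op_id 12)
// ApplyDeleteToMem scans each delete action against only the inserts
// *before* it: action[1] erases id 2 from action[0]'s chunk, while the
// re-inserted id 2 in action[2] survives, matching WAL replay order.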
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#include "db/insert/VectorSource.h"
#include <utility>
#include <vector>
#include "metrics/Metrics.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
namespace milvus {
namespace engine {
VectorSource::VectorSource(const DataChunkPtr& chunk, idx_t op_id) : chunk_(chunk), op_id_(op_id) {
}
Status
VectorSource::Add(const segment::SegmentWriterPtr& segment_writer_ptr, const int64_t& num_entities_to_add,
int64_t& num_entities_added) {
// TODO: n = vectors_.vector_count_;???
int64_t n = chunk_->count_;
num_entities_added = current_num_added_ + num_entities_to_add <= n ? num_entities_to_add : n - current_num_added_;
auto status = segment_writer_ptr->AddChunk(chunk_, current_num_added_, num_entities_added);
if (!status.ok()) {
return status;
}
current_num_added_ += num_entities_added;
return status;
}
bool
VectorSource::AllAdded() {
return (current_num_added_ >= chunk_->count_);
}
} // namespace engine
} // namespace milvus
// Copyright (C) 2019-2020 Zilliz. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software distributed under the License
// is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
// or implied. See the License for the specific language governing permissions and limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "db/IDGenerator.h"
#include "db/insert/MemManager.h"
#include "segment/Segment.h"
#include "segment/SegmentWriter.h"
#include "utils/Status.h"
namespace milvus {
namespace engine {
class VectorSource {
public:
explicit VectorSource(const DataChunkPtr& chunk, idx_t op_id);
Status
Add(const segment::SegmentWriterPtr& segment_writer_ptr, const int64_t& num_attrs_to_add, int64_t& num_attrs_added);
bool
AllAdded();
idx_t
OperationID() const {
return op_id_;
}
private:
DataChunkPtr chunk_;
idx_t op_id_ = 0;
int64_t current_num_added_ = 0;
};
using VectorSourcePtr = std::shared_ptr<VectorSource>;
} // namespace engine
} // namespace milvus
......@@ -46,7 +46,7 @@ enum class MergeStrategyType {
class MergeManager {
public:
virtual Status
MergeFiles(const std::string& collection_id, MergeStrategyType type = MergeStrategyType::SIMPLE) = 0;
MergeFiles(int64_t collection_id, MergeStrategyType type = MergeStrategyType::SIMPLE) = 0;
}; // MergeManager
using MergeManagerPtr = std::shared_ptr<MergeManager>;
......
......@@ -44,7 +44,7 @@ MergeManagerImpl::CreateStrategy(MergeStrategyType type, MergeStrategyPtr& strat
}
Status
MergeManagerImpl::MergeFiles(const std::string& collection_name, MergeStrategyType type) {
MergeManagerImpl::MergeFiles(int64_t collection_id, MergeStrategyType type) {
MergeStrategyPtr strategy;
auto status = CreateStrategy(type, strategy);
if (!status.ok()) {
......@@ -53,7 +53,7 @@ MergeManagerImpl::MergeFiles(const std::string& collection_name, MergeStrategyTy
while (true) {
snapshot::ScopedSnapshotT latest_ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(latest_ss, collection_name));
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(latest_ss, collection_id));
// collect all segments
Partition2SegmentsMap part2seg;
......@@ -66,7 +66,7 @@ MergeManagerImpl::MergeFiles(const std::string& collection_name, MergeStrategyTy
SegmentGroups segment_groups;
auto status = strategy->RegroupSegments(latest_ss, part2seg, segment_groups);
if (!status.ok()) {
LOG_ENGINE_ERROR_ << "Failed to regroup segments for: " << collection_name
LOG_ENGINE_ERROR_ << "Failed to regroup segments for collection: " << latest_ss->GetName()
<< ", continue to merge all files into one";
return status;
}
......
......@@ -32,7 +32,7 @@ class MergeManagerImpl : public MergeManager {
explicit MergeManagerImpl(const DBOptions& options);
Status
MergeFiles(const std::string& collection_name, MergeStrategyType type) override;
MergeFiles(int64_t collection_id, MergeStrategyType type) override;
private:
Status
......
......@@ -50,7 +50,7 @@ class WalFile {
template <typename T>
inline int64_t
Write(T* value) {
if (file_ == nullptr) {
if (file_ == nullptr || value == nullptr) {
return 0;
}
......@@ -61,7 +61,7 @@ class WalFile {
inline int64_t
Write(const void* data, int64_t length) {
if (file_ == nullptr) {
if (file_ == nullptr || data == nullptr || length <= 0) {
return 0;
}
......@@ -83,7 +83,7 @@ class WalFile {
inline int64_t
Read(void* data, int64_t length) {
if (file_ == nullptr) {
if (file_ == nullptr || length <= 0) {
return 0;
}
......
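
A small usage sketch for the hardened Write/Read guards above; with the new null-pointer and length checks, a missing file or bad buffer now yields a 0 return instead of undefined behavior (the file path below is illustrative):

WalFile file;
file.OpenFile("/tmp/milvus/wal/demo/max_op", WalFile::READ);  // path is hypothetical
idx_t max_op = 0;
int64_t read_bytes = file.Read<idx_t>(&max_op);
if (read_bytes <= 0) {
    // file missing, not opened, or length <= 0: callers treat this as
    // end-of-file, as WalOperationCodec::IterateOperation does
}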
......@@ -11,9 +11,6 @@
#include "db/wal/WalManager.h"
#include "db/Utils.h"
#include "db/snapshot/ResourceHelper.h"
#include "db/snapshot/ResourceTypes.h"
#include "db/snapshot/Snapshots.h"
#include "db/wal/WalOperationCodec.h"
#include "utils/CommonUtil.h"
......@@ -26,7 +23,6 @@
namespace milvus {
namespace engine {
const char* WAL_DATA_FOLDER = "wal";
const char* WAL_MAX_OP_FILE_NAME = "max_op";
const char* WAL_DEL_FILE_NAME = "del";
......@@ -44,8 +40,7 @@ WalManager::Start(const DBOptions& options) {
enable_ = options.wal_enable_;
insert_buffer_size_ = options.insert_buffer_size_;
std::experimental::filesystem::path wal_path(options.meta_.path_);
wal_path.append((WAL_DATA_FOLDER));
std::experimental::filesystem::path wal_path(options.wal_path_);
wal_path_ = wal_path.c_str();
CommonUtil::CreateDirectory(wal_path_);
......@@ -235,7 +230,7 @@ WalManager::Init() {
file_path.append(WAL_MAX_OP_FILE_NAME);
if (std::experimental::filesystem::is_regular_file(file_path)) {
WalFile file;
file.OpenFile(path.c_str(), WalFile::READ);
file.OpenFile(file_path.c_str(), WalFile::READ);
idx_t max_op = 0;
file.Read(&max_op);
......@@ -369,29 +364,14 @@ WalManager::RecordDeleteOperation(const DeleteEntityOperationPtr& operation, con
std::string
WalManager::ConstructFilePath(const std::string& collection_name, const std::string& file_name) {
// use snapshot to construct wal path
// typically, the wal file path is like: /xxx/xxx/wal/C_1/xxxxxxxxxx
// if the snapshot not work, use collection name to construct path
snapshot::ScopedSnapshotT ss;
auto status = snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name);
if (status.ok() && ss->GetCollection() != nullptr) {
std::string col_path = snapshot::GetResPath<snapshot::Collection>(wal_path_, ss->GetCollection());
std::experimental::filesystem::path full_path(col_path);
std::experimental::filesystem::create_directory(full_path);
full_path.append(file_name);
std::string path(full_path.c_str());
return path;
} else {
std::experimental::filesystem::path full_path(wal_path_);
full_path.append(collection_name);
std::experimental::filesystem::create_directory(full_path);
full_path.append(file_name);
std::string path(full_path.c_str());
return path;
}
// typically, the wal file path is like: /xxx/milvus/wal/[collection_name]/xxxxxxxxxx
std::experimental::filesystem::path full_path(wal_path_);
full_path.append(collection_name);
std::experimental::filesystem::create_directory(full_path);
full_path.append(file_name);
std::string path(full_path.c_str());
return path;
}
void
......
......@@ -29,7 +29,6 @@
namespace milvus {
namespace engine {
extern const char* WAL_DATA_FOLDER;
extern const char* WAL_MAX_OP_FILE_NAME;
extern const char* WAL_DEL_FILE_NAME;
......
......@@ -33,6 +33,7 @@ WalOperationCodec::WriteInsertOperation(const WalFilePtr& file, const std::strin
calculate_total_bytes += sizeof(int64_t); // calculated total bytes
calculate_total_bytes += sizeof(int32_t); // partition name length
calculate_total_bytes += partition_name.size(); // partition name
calculate_total_bytes += sizeof(int64_t); // chunk entity count
calculate_total_bytes += sizeof(int32_t); // fixed field count
for (auto& pair : chunk->fixed_fields_) {
calculate_total_bytes += sizeof(int32_t); // field name length
......@@ -61,6 +62,9 @@ WalOperationCodec::WriteInsertOperation(const WalFilePtr& file, const std::strin
total_bytes += file->Write(partition_name.data(), part_name_length);
}
// write chunk entity count
total_bytes += file->Write<int64_t>(&(chunk->count_));
// write fixed data
int32_t field_count = chunk->fixed_fields_.size();
total_bytes += file->Write<int32_t>(&field_count);
......@@ -197,6 +201,13 @@ WalOperationCodec::IterateOperation(const WalFilePtr& file, WalOperationPtr& ope
}
}
// read chunk entity count
DataChunkPtr chunk = std::make_shared<DataChunk>();
read_bytes = file->Read<int64_t>(&(chunk->count_));
if (read_bytes <= 0) {
return Status(DB_ERROR, "End of file");
}
// read fixed data
int32_t field_count = 0;
read_bytes = file->Read<int32_t>(&field_count);
......@@ -204,7 +215,6 @@ WalOperationCodec::IterateOperation(const WalFilePtr& file, WalOperationPtr& ope
return Status(DB_ERROR, "End of file");
}
DataChunkPtr chunk = std::make_shared<DataChunk>();
for (int32_t i = 0; i < field_count; i++) {
int32_t field_name_length = 0;
read_bytes = file->Read<int32_t>(&field_name_length);
......
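Editor's note: piecing together the size calculation and the Write/Read sequences in these hunks, the insert record now persists the chunk entity count between the partition name and the field list. Below is a stand-alone reading sketch under that assumption — plain `std::ifstream` in place of `WalFile`, covering only the header fields visible in this diff (surrounding fields such as the operation id are elided).

#include <cstdint>
#include <fstream>
#include <string>

bool ReadInsertHeader(std::ifstream& in, std::string& partition_name, int64_t& entity_count,
                      int32_t& fixed_field_count) {
    int32_t name_len = 0;
    if (!in.read(reinterpret_cast<char*>(&name_len), sizeof(name_len)) || name_len < 0) {
        return false;
    }
    partition_name.resize(static_cast<size_t>(name_len));
    if (name_len > 0 && !in.read(&partition_name[0], name_len)) {
        return false;
    }
    // new in this commit: chunk->count_ is persisted right after the partition
    // name, so replay can restore the entity count without deriving it
    if (!in.read(reinterpret_cast<char*>(&entity_count), sizeof(entity_count))) {
        return false;
    }
    if (!in.read(reinterpret_cast<char*>(&fixed_field_count), sizeof(fixed_field_count))) {
        return false;
    }
    return true;  // per-field name/data records follow
}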
......@@ -45,9 +45,9 @@ class BuildIndexJob : public Job {
return options_;
}
const engine::snapshot::IDS_TYPE&
segment_ids() {
return segment_ids_;
engine::snapshot::IDS_TYPE&
FailedSegments() {
return failed_segment_ids_;
}
protected:
......@@ -58,6 +58,7 @@ class BuildIndexJob : public Job {
engine::snapshot::ScopedSnapshotT snapshot_;
engine::DBOptions options_;
engine::snapshot::IDS_TYPE segment_ids_;
engine::snapshot::IDS_TYPE failed_segment_ids_;
};
using BuildIndexJobPtr = std::shared_ptr<BuildIndexJob>;
......
......@@ -80,6 +80,10 @@ BuildIndexTask::OnLoad(milvus::scheduler::LoadType type, uint8_t device_id) {
}
LOG_ENGINE_ERROR_ << s.message();
auto build_job = static_cast<scheduler::BuildIndexJob*>(job_);
build_job->FailedSegments().push_back(segment_id_);
return s;
}
......@@ -100,9 +104,14 @@ BuildIndexTask::OnExecute() {
} catch (std::exception& e) {
status = Status(DB_ERROR, e.what());
}
if (!status.ok()) {
LOG_ENGINE_ERROR_ << "Failed to build index: " << status.ToString();
execution_engine_ = nullptr;
auto build_job = static_cast<scheduler::BuildIndexJob*>(job_);
build_job->FailedSegments().push_back(segment_id_);
return status;
}
......
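Editor's note: the two task hunks above feed the new `failed_segment_ids_` list on the job. A reduced sketch of the pattern, with simplified stand-in types rather than the scheduler's real interfaces: a task that cannot load or build its segment appends the segment id, and the caller subtracts those ids from the "segments still to index" set so the wait loop terminates instead of retrying the same segment forever.

#include <cstdint>
#include <vector>

using SegmentID = int64_t;

class BuildIndexJobSketch {
 public:
    std::vector<SegmentID>& FailedSegments() {
        return failed_segment_ids_;
    }

 private:
    std::vector<SegmentID> failed_segment_ids_;
};

// on a load or build failure, the task records its segment id on the job
void ReportBuildFailure(BuildIndexJobSketch* job, SegmentID segment_id) {
    job->FailedSegments().push_back(segment_id);
}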
......@@ -16,6 +16,9 @@
// under the License.
#include "segment/Segment.h"
#include "db/SnapshotUtils.h"
#include "db/snapshot/Snapshots.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "utils/Log.h"
#include <algorithm>
......@@ -27,6 +30,51 @@ namespace engine {
const char* COLLECTIONS_FOLDER = "/collections";
Status
Segment::SetFields(int64_t collection_id) {
snapshot::ScopedSnapshotT ss;
STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id));
auto& fields = ss->GetResources<snapshot::Field>();
for (auto& kv : fields) {
const snapshot::FieldPtr& field = kv.second.Get();
STATUS_CHECK(AddField(field));
}
return Status::OK();
}
Status
Segment::AddField(const snapshot::FieldPtr& field) {
if (field == nullptr) {
return Status(DB_ERROR, "Field is null pointer");
}
std::string name = field->GetName();
auto ftype = static_cast<DataType>(field->GetFtype());
if (IsVectorField(field)) {
json params = field->GetParams();
if (params.find(knowhere::meta::DIM) == params.end()) {
std::string msg = "Vector field params must contain: dimension";
LOG_SERVER_ERROR_ << msg;
return Status(DB_ERROR, msg);
}
int64_t field_width = 0;
int64_t dimension = params[knowhere::meta::DIM];
if (ftype == DataType::VECTOR_BINARY) {
field_width += (dimension / 8);
} else {
field_width += (dimension * sizeof(float));
}
AddField(name, ftype, field_width);
} else {
AddField(name, ftype);
}
return Status::OK();
}
Status
Segment::AddField(const std::string& field_name, DataType field_type, int64_t field_width) {
if (field_types_.find(field_name) != field_types_.end()) {
......@@ -110,9 +158,62 @@ Segment::AddChunk(const DataChunkPtr& chunk_ptr, int64_t from, int64_t to) {
}
// consume
AppendChunk(chunk_ptr, from, to);
return Status::OK();
}
Status
Segment::Reserve(const std::vector<std::string>& field_names, int64_t count) {
if (count <= 0) {
return Status(DB_ERROR, "Invalid input fot segment resize");
}
if (field_names.empty()) {
for (auto& width_iter : fixed_fields_width_) {
int64_t resize_bytes = count * width_iter.second;
auto& data = fixed_fields_[width_iter.first];
if (data == nullptr) {
data = std::make_shared<BinaryData>();
}
data->data_.resize(resize_bytes);
}
} else {
for (const auto& name : field_names) {
auto iter_width = fixed_fields_width_.find(name);
if (iter_width == fixed_fields_width_.end()) {
return Status(DB_ERROR, "Invalid input fot segment resize");
}
int64_t resize_bytes = count * iter_width->second;
auto& data = fixed_fields_[name];
if (data == nullptr) {
data = std::make_shared<BinaryData>();
}
data->data_.resize(resize_bytes);
}
}
return Status::OK();
}
Status
Segment::AppendChunk(const DataChunkPtr& chunk_ptr, int64_t from, int64_t to) {
if (chunk_ptr == nullptr || from < 0 || to < 0 || from > to) {
return Status(DB_ERROR, "Invalid input fot segment append");
}
int64_t add_count = to - from;
if (add_count == 0) {
add_count = 1; // a range of n ~ n means: append row No.n
}
for (auto& width_iter : fixed_fields_width_) {
auto input = chunk_ptr->fixed_fields_.find(width_iter.first);
if (input == chunk_ptr->fixed_fields_.end()) {
continue;
}
auto& data = fixed_fields_[width_iter.first];
if (data == nullptr) {
fixed_fields_[width_iter.first] = input->second;
......@@ -123,7 +224,9 @@ Segment::AddChunk(const DataChunkPtr& chunk_ptr, int64_t from, int64_t to) {
int64_t add_bytes = add_count * width_iter.second;
int64_t previous_bytes = row_count_ * width_iter.second;
int64_t target_bytes = previous_bytes + add_bytes;
data->data_.resize(target_bytes);
if (data->data_.size() < target_bytes) {
data->data_.resize(target_bytes);
}
if (input == chunk_ptr->fixed_fields_.end()) {
// this field is not provided in the chunk, pad the gap with zeros
memset(data->data_.data() + previous_bytes, 0, target_bytes - previous_bytes);
......
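Editor's note: a usage sketch of the new insert path based on the methods added above — `Segment::Reserve` pre-sizes the fixed-width buffers and `Segment::AppendChunk` copies a row range to the tail. The field widths follow the rule in `AddField`: with dimension 128, a binary vector field is 128/8 = 16 bytes per row and a float vector field 128 * 4 = 512 bytes per row. The exact from/to call shape is an assumption inferred from the signatures, and `STATUS_CHECK` is the macro used elsewhere in this code base.

#include "segment/Segment.h"

milvus::Status
BufferChunk(milvus::engine::Segment& segment, const milvus::engine::DataChunkPtr& chunk) {
    // pre-size every fixed-width field buffer once (an empty name list means
    // all fields), so repeated appends do not reallocate per chunk
    STATUS_CHECK(segment.Reserve({}, chunk->count_));

    // copy rows [0, count_) of the chunk to the segment tail
    STATUS_CHECK(segment.AppendChunk(chunk, 0, chunk->count_));
    return milvus::Status::OK();
}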
......@@ -23,6 +23,7 @@
#include <vector>
#include "db/Types.h"
#include "db/snapshot/Resources.h"
#include "segment/DeletedDocs.h"
#include "segment/IdBloomFilter.h"
......@@ -33,6 +34,12 @@ extern const char* COLLECTIONS_FOLDER;
class Segment {
public:
Status
SetFields(int64_t collection_id);
Status
AddField(const snapshot::FieldPtr& field);
Status
AddField(const std::string& field_name, DataType field_type, int64_t field_width = 0);
......@@ -42,6 +49,15 @@ class Segment {
Status
AddChunk(const DataChunkPtr& chunk_ptr, int64_t from, int64_t to);
// reserve chunk data capacity to the specified count
// this method should only be used on an empty segment
Status
Reserve(const std::vector<std::string>& field_names, int64_t count);
// copy part of chunk data into this segment and append to tail
Status
AppendChunk(const DataChunkPtr& chunk_ptr, int64_t from, int64_t to);
Status
DeleteEntity(std::vector<offset_t>& offsets);
......
......@@ -26,7 +26,6 @@
#include "db/SnapshotUtils.h"
#include "db/Utils.h"
#include "db/snapshot/ResourceHelper.h"
#include "knowhere/index/vector_index/helpers/IndexParameter.h"
#include "storage/disk/DiskIOReader.h"
#include "storage/disk/DiskIOWriter.h"
#include "storage/disk/DiskOperation.h"
......@@ -61,27 +60,7 @@ SegmentWriter::Initialize() {
const engine::SegmentVisitor::IdMapT& field_map = segment_visitor_->GetFieldVisitors();
for (auto& iter : field_map) {
const engine::snapshot::FieldPtr& field = iter.second->GetField();
std::string name = field->GetName();
auto ftype = static_cast<engine::DataType>(field->GetFtype());
if (engine::IsVectorField(field)) {
json params = field->GetParams();
if (params.find(knowhere::meta::DIM) == params.end()) {
std::string msg = "Vector field params must contain: dimension";
LOG_SERVER_ERROR_ << msg;
return Status(DB_ERROR, msg);
}
int64_t field_width = 0;
int64_t dimension = params[knowhere::meta::DIM];
if (ftype == engine::DataType::VECTOR_BINARY) {
field_width += (dimension / 8);
} else {
field_width += (dimension * sizeof(float));
}
segment_ptr_->AddField(name, ftype, field_width);
} else {
segment_ptr_->AddField(name, ftype);
}
STATUS_CHECK(segment_ptr_->AddField(field));
}
return Status::OK();
......
......@@ -42,7 +42,6 @@ DBWrapper::StartService() {
opt.auto_flush_interval_ = config.storage.auto_flush_interval();
opt.metric_enable_ = config.metric.enable();
opt.insert_cache_immediately_ = config.cache.cache_insert_data();
opt.insert_buffer_size_ = config.cache.insert_buffer_size();
if (not config.cluster.enable()) {
......@@ -57,15 +56,8 @@ DBWrapper::StartService() {
}
opt.wal_enable_ = config.wal.enable();
// disable wal for ci devtest
opt.wal_enable_ = false;
if (opt.wal_enable_) {
int64_t wal_buffer_size = config.wal.buffer_size();
wal_buffer_size /= (1024 * 1024);
opt.buffer_size_ = wal_buffer_size;
opt.mxlog_path_ = config.wal.path();
opt.wal_path_ = config.wal.path();
}
// engine config
......
......@@ -97,9 +97,10 @@ ValidateCollectionName(const std::string& collection_name) {
}
std::string invalid_msg = "Invalid collection name: " + collection_name + ". ";
// Collection name size shouldn't exceed 255.
// Collection name size shouldn't exceed engine::MAX_NAME_LENGTH.
if (collection_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a collection name must be less than 255 characters.";
std::string msg = invalid_msg + "The length of a collection name must be less than " +
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_COLLECTION_NAME, msg);
}
......@@ -135,9 +136,10 @@ ValidateFieldName(const std::string& field_name) {
}
std::string invalid_msg = "Invalid field name: " + field_name + ". ";
// Field name size shouldn't exceed 255.
// Field name size shouldn't exceed engine::MAX_NAME_LENGTH.
if (field_name.size() > engine::MAX_NAME_LENGTH) {
std::string msg = invalid_msg + "The length of a field name must be less than 255 characters.";
std::string msg = invalid_msg + "The length of a field name must be less than " +
std::to_string(engine::MAX_NAME_LENGTH) + " characters.";
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_FIELD_NAME, msg);
}
......@@ -438,8 +440,9 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
}
// max length of partition tag
if (valid_tag.length() > 255) {
std::string msg = "Invalid partition tag: " + valid_tag + ". " + "Partition tag exceed max length(255).";
if (valid_tag.length() > engine::MAX_NAME_LENGTH) {
std::string msg = "Invalid partition tag: " + valid_tag + ". " +
"Partition tag exceed max length: " + std::to_string(engine::MAX_NAME_LENGTH);
LOG_SERVER_ERROR_ << msg;
return Status(SERVER_INVALID_PARTITION_TAG, msg);
}
......@@ -450,24 +453,8 @@ ValidatePartitionTags(const std::vector<std::string>& partition_tags) {
Status
ValidateInsertDataSize(const engine::DataChunkPtr& data) {
int64_t total_size = 0;
for (auto& pair : data->fixed_fields_) {
if (pair.second == nullptr) {
continue;
}
total_size += pair.second->Size();
}
for (auto& pair : data->variable_fields_) {
if (pair.second == nullptr) {
continue;
}
total_size += pair.second->Size();
}
if (total_size > engine::MAX_INSERT_DATA_SIZE) {
int64_t chunk_size = engine::utils::GetSizeOfChunk(data);
if (chunk_size > engine::MAX_INSERT_DATA_SIZE) {
std::string msg = "The amount of data inserted each time cannot exceed " +
std::to_string(engine::MAX_INSERT_DATA_SIZE / engine::MB) + " MB";
return Status(SERVER_INVALID_ROWRECORD_ARRAY, msg);
......
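Editor's note: the inline size loops removed above are now assumed to live in `engine::utils::GetSizeOfChunk`. A stand-alone sketch of the same accounting, with local mirror types — the real DataChunk also carries a separate VaribleData type for variable-length fields, collapsed into BinaryData here for brevity.

#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct BinaryData {
    std::vector<uint8_t> data_;
    int64_t Size() const { return static_cast<int64_t>(data_.size()); }
};
using BinaryDataPtr = std::shared_ptr<BinaryData>;

struct DataChunk {
    int64_t count_ = 0;
    std::map<std::string, BinaryDataPtr> fixed_fields_;
    std::map<std::string, BinaryDataPtr> variable_fields_;
};

// sums the payload of every non-null field buffer, exactly as the removed
// inline loops did; the result is compared against MAX_INSERT_DATA_SIZE
int64_t GetSizeOfChunkSketch(const DataChunk& chunk) {
    int64_t total = 0;
    for (const auto& pair : chunk.fixed_fields_) {
        if (pair.second != nullptr) {
            total += pair.second->Size();
        }
    }
    for (const auto& pair : chunk.variable_fields_) {
        if (pair.second != nullptr) {
            total += pair.second->Size();
        }
    }
    return total;
}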
......@@ -48,7 +48,7 @@ CreateCollection(const std::shared_ptr<DB>& db, const std::string& collection_na
return db->CreateCollection(context);
}
static constexpr int64_t COLLECTION_DIM = 128;
static constexpr int64_t COLLECTION_DIM = 10;
milvus::Status
CreateCollection2(std::shared_ptr<DB> db, const std::string& collection_name, const LSN_TYPE& lsn) {
......@@ -163,6 +163,22 @@ BuildEntities(uint64_t n, uint64_t batch_index, milvus::engine::DataChunkPtr& da
}
}
void
CopyChunkData(const milvus::engine::DataChunkPtr& src_chunk, milvus::engine::DataChunkPtr& target_chunk) {
target_chunk = std::make_shared<milvus::engine::DataChunk>();
target_chunk->count_ = src_chunk->count_;
for (auto& pair : src_chunk->fixed_fields_) {
milvus::engine::BinaryDataPtr raw = std::make_shared<milvus::engine::BinaryData>();
raw->data_ = pair.second->data_;
target_chunk->fixed_fields_.insert(std::make_pair(pair.first, raw));
}
for (auto& pair : src_chunk->variable_fields_) {
milvus::engine::VaribleDataPtr raw = std::make_shared<milvus::engine::VaribleData>();
raw->data_ = pair.second->data_;
target_chunk->variable_fields_.insert(std::make_pair(pair.first, raw));
}
}
void
BuildQueryPtr(const std::string& collection_name, int64_t n, int64_t topk, std::vector<std::string>& field_names,
std::vector<std::string>& partitions, milvus::query::QueryPtr& query_ptr) {
......@@ -509,7 +525,7 @@ TEST_F(DBTest, InsertTest) {
milvus::engine::BinaryDataPtr raw = std::make_shared<milvus::engine::BinaryData>();
raw->data_.resize(100 * sizeof(int64_t));
int64_t* p = (int64_t*)raw->data_.data();
for (auto i = 0; i < data_chunk->count_; ++i) {
for (int64_t i = 0; i < data_chunk->count_; ++i) {
p[i] = i;
}
data_chunk->fixed_fields_[milvus::engine::FIELD_UID] = raw;
......@@ -518,7 +534,7 @@ TEST_F(DBTest, InsertTest) {
milvus::engine::BinaryDataPtr raw = std::make_shared<milvus::engine::BinaryData>();
raw->data_.resize(100 * sizeof(int32_t));
int32_t* p = (int32_t*)raw->data_.data();
for (auto i = 0; i < data_chunk->count_; ++i) {
for (int64_t i = 0; i < data_chunk->count_; ++i) {
p[i] = i + 5000;
}
data_chunk->fixed_fields_[field_name] = raw;
......@@ -567,16 +583,14 @@ TEST_F(DBTest, MergeTest) {
const uint64_t entity_count = 100;
milvus::engine::DataChunkPtr data_chunk;
BuildEntities(entity_count, 0, data_chunk);
// insert entities into collection multiple times
int64_t repeat = 2;
for (int32_t i = 0; i < repeat; i++) {
BuildEntities(entity_count, 0, data_chunk);
status = db_->Insert(collection_name, "", data_chunk);
ASSERT_TRUE(status.ok());
data_chunk->fixed_fields_.erase(milvus::engine::FIELD_UID); // clear auto-generated id
status = db_->Flush();
ASSERT_TRUE(status.ok());
}
......@@ -646,17 +660,25 @@ TEST_F(DBTest, GetEntityTest) {
auto insert_entities = [&](const std::string& collection, const std::string& partition,
uint64_t count, uint64_t batch_index, milvus::engine::IDNumbers& ids,
milvus::engine::DataChunkPtr& data_chunk) -> Status {
BuildEntities(count, batch_index, data_chunk);
STATUS_CHECK(db_->Insert(collection, partition, data_chunk));
milvus::engine::DataChunkPtr consume_chunk;
BuildEntities(count, batch_index, consume_chunk);
CopyChunkData(consume_chunk, data_chunk);
// Note: consume_chunk is consumed by insert()
STATUS_CHECK(db_->Insert(collection, partition, consume_chunk));
STATUS_CHECK(db_->Flush(collection));
auto iter = data_chunk->fixed_fields_.find(milvus::engine::FIELD_UID);
if (iter == data_chunk->fixed_fields_.end()) {
auto iter = consume_chunk->fixed_fields_.find(milvus::engine::FIELD_UID);
if (iter == consume_chunk->fixed_fields_.end()) {
return Status(1, "Cannot find uid field");
}
auto& ids_buffer = iter->second;
ids.resize(data_chunk->count_);
ids.resize(consume_chunk->count_);
memcpy(ids.data(), ids_buffer->data_.data(), ids_buffer->Size());
milvus::engine::BinaryDataPtr raw = std::make_shared<milvus::engine::BinaryData>();
raw->data_ = ids_buffer->data_;
data_chunk->fixed_fields_[milvus::engine::FIELD_UID] = raw;
return Status::OK();
};
......@@ -760,7 +782,7 @@ TEST_F(DBTest, CompactTest) {
ASSERT_TRUE(status.ok());
// insert 100 entities into default partition
const uint64_t entity_count = 1000;
const uint64_t entity_count = 100;
milvus::engine::DataChunkPtr data_chunk;
BuildEntities(entity_count, 0, data_chunk);
......@@ -785,8 +807,8 @@ TEST_F(DBTest, CompactTest) {
};
// delete entities from 10 to 30
int64_t delete_count_1 = 200;
delete_entity(100, 100 + delete_count_1);
int64_t delete_count_1 = 20;
delete_entity(10, 10 + delete_count_1);
status = db_->Flush();
ASSERT_TRUE(status.ok());
......@@ -799,6 +821,7 @@ TEST_F(DBTest, CompactTest) {
ASSERT_TRUE(status.ok());
ASSERT_EQ(valid_row.size(), batch_entity_ids.size());
auto& chunk = fetch_chunk->fixed_fields_["field_0"];
ASSERT_NE(chunk, nullptr);
int32_t* p = (int32_t*)(chunk->data_.data());
int64_t index = 0;
for (uint64_t i = 0; i < valid_row.size(); ++i) {
......@@ -812,34 +835,34 @@ TEST_F(DBTest, CompactTest) {
// validate the left data is correct after deletion
validate_entity_data();
// delete entities from 700 to 800
int64_t delete_count_2 = 100;
delete_entity(700, 700 + delete_count_2);
status = db_->Flush();
ASSERT_TRUE(status.ok());
auto validate_compact = [&](double threshold) -> void {
int64_t row_count = 0;
status = db_->CountEntities(collection_name, row_count);
ASSERT_TRUE(status.ok());
ASSERT_EQ(row_count, entity_count - delete_count_1 - delete_count_2);
status = db_->Compact(dummy_context_, collection_name, threshold);
ASSERT_TRUE(status.ok());
validate_entity_data();
status = db_->CountEntities(collection_name, row_count);
ASSERT_TRUE(status.ok());
ASSERT_EQ(row_count, entity_count - delete_count_1 - delete_count_2);
validate_entity_data();
};
// compact the collection, when threshold = 0.001, the compact do nothing
validate_compact(0.001); // compact skip
validate_compact(0.5); // do compact
// // delete entities from 700 to 800
// int64_t delete_count_2 = 100;
// delete_entity(700, 700 + delete_count_2);
//
// status = db_->Flush();
// ASSERT_TRUE(status.ok());
//
// auto validate_compact = [&](double threshold) -> void {
// int64_t row_count = 0;
// status = db_->CountEntities(collection_name, row_count);
// ASSERT_TRUE(status.ok());
// ASSERT_EQ(row_count, entity_count - delete_count_1 - delete_count_2);
//
// status = db_->Compact(dummy_context_, collection_name, threshold);
// ASSERT_TRUE(status.ok());
//
// validate_entity_data();
//
// status = db_->CountEntities(collection_name, row_count);
// ASSERT_TRUE(status.ok());
// ASSERT_EQ(row_count, entity_count - delete_count_1 - delete_count_2);
//
// validate_entity_data();
// };
//
// // compact the collection, when threshold = 0.001, the compact do nothing
// validate_compact(0.001); // compact skip
// validate_compact(0.5); // do compact
}
TEST_F(DBTest, IndexTest) {
......@@ -937,8 +960,7 @@ TEST_F(DBTest, StatsTest) {
status = db_->Insert(collection_name, "", data_chunk);
ASSERT_TRUE(status.ok());
data_chunk->fixed_fields_.erase(milvus::engine::FIELD_UID); // clear auto-generated id
BuildEntities(entity_count, 0, data_chunk);
status = db_->Insert(collection_name, partition_name, data_chunk);
ASSERT_TRUE(status.ok());
......@@ -1013,7 +1035,133 @@ TEST_F(DBTest, StatsTest) {
}
}
TEST_F(DBTest, FetchTest) {
TEST_F(DBTest, FetchTest1) {
std::string collection_name = "STATS_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
ASSERT_TRUE(status.ok());
std::string partition_name1 = "p1";
status = db_->CreatePartition(collection_name, partition_name1);
ASSERT_TRUE(status.ok());
std::string partition_name2 = "p2";
status = db_->CreatePartition(collection_name, partition_name2);
ASSERT_TRUE(status.ok());
milvus::engine::IDNumbers ids_1, ids_2;
std::vector<float> fetch_vectors;
{
// insert 100 entities into partition 'p1'
const uint64_t entity_count = 100;
milvus::engine::DataChunkPtr data_chunk;
BuildEntities(entity_count, 0, data_chunk);
float* p = (float*)(data_chunk->fixed_fields_[VECTOR_FIELD_NAME]->data_.data());
for (int64_t i = 0; i < COLLECTION_DIM; ++i) {
fetch_vectors.push_back(p[i]);
}
status = db_->Insert(collection_name, partition_name1, data_chunk);
ASSERT_TRUE(status.ok());
milvus::engine::utils::GetIDFromChunk(data_chunk, ids_1);
ASSERT_EQ(ids_1.size(), entity_count);
}
{
// insert 101 entities into partition 'p2'
const uint64_t entity_count = 101;
milvus::engine::DataChunkPtr data_chunk;
BuildEntities(entity_count, 0, data_chunk);
float* p = (float*)(data_chunk->fixed_fields_[VECTOR_FIELD_NAME]->data_.data());
for (int64_t i = 0; i < COLLECTION_DIM; ++i) {
fetch_vectors.push_back(p[i]);
}
status = db_->Insert(collection_name, partition_name2, data_chunk);
ASSERT_TRUE(status.ok());
milvus::engine::utils::GetIDFromChunk(data_chunk, ids_2);
ASSERT_EQ(ids_2.size(), entity_count);
}
status = db_->Flush();
ASSERT_TRUE(status.ok());
// fetch the first entity from partition 'p1'
// and the first entity from partition 'p2'
std::vector<std::string> field_names = {milvus::engine::FIELD_UID, VECTOR_FIELD_NAME};
std::vector<bool> valid_row;
milvus::engine::DataChunkPtr fetch_chunk;
milvus::engine::IDNumbers fetch_ids = {ids_1[0], ids_2[0]};
status = db_->GetEntityByID(collection_name, fetch_ids, field_names, valid_row, fetch_chunk);
ASSERT_TRUE(status.ok());
ASSERT_EQ(fetch_chunk->count_, fetch_ids.size());
ASSERT_EQ(fetch_chunk->fixed_fields_[VECTOR_FIELD_NAME]->data_.size(),
fetch_ids.size() * COLLECTION_DIM * sizeof(float));
// compare result
std::vector<float> result_vectors;
float* p = (float*)(fetch_chunk->fixed_fields_[VECTOR_FIELD_NAME]->data_.data());
for (int64_t i = 0; i < COLLECTION_DIM * fetch_ids.size(); i++) {
result_vectors.push_back(p[i]);
}
ASSERT_EQ(fetch_vectors, result_vectors);
// std::string collection_name = "STATS_TEST";
// auto status = CreateCollection2(db_, collection_name, 0);
// ASSERT_TRUE(status.ok());
//
// std::string partition_name1 = "p1";
// status = db_->CreatePartition(collection_name, partition_name1);
// ASSERT_TRUE(status.ok());
//
// milvus::engine::IDNumbers ids_1;
// std::vector<float> fetch_vectors;
// {
// // insert 100 entities into partition 'p1'
// const uint64_t entity_count = 100;
// milvus::engine::DataChunkPtr data_chunk;
// BuildEntities(entity_count, 0, data_chunk);
//
// float* p = (float*)(data_chunk->fixed_fields_[VECTOR_FIELD_NAME]->data_.data());
// for (int64_t i = 0; i < COLLECTION_DIM; ++i) {
// fetch_vectors.push_back(p[i]);
// }
//
// status = db_->Insert(collection_name, partition_name1, data_chunk);
// ASSERT_TRUE(status.ok());
//
// milvus::engine::utils::GetIDFromChunk(data_chunk, ids_1);
// ASSERT_EQ(ids_1.size(), entity_count);
// }
//
// status = db_->Flush();
// ASSERT_TRUE(status.ok());
//
// // fetch no.1 entity from partition 'p1'
// // fetch no.2 entity from partition 'p2'
// std::vector<std::string> field_names = {milvus::engine::FIELD_UID, VECTOR_FIELD_NAME};
// std::vector<bool> valid_row;
// milvus::engine::DataChunkPtr fetch_chunk;
// milvus::engine::IDNumbers fetch_ids = {ids_1[0]};
// status = db_->GetEntityByID(collection_name, fetch_ids, field_names, valid_row, fetch_chunk);
// ASSERT_TRUE(status.ok());
// ASSERT_EQ(fetch_chunk->count_, fetch_ids.size());
// ASSERT_EQ(fetch_chunk->fixed_fields_[VECTOR_FIELD_NAME]->data_.size(),
// fetch_ids.size() * COLLECTION_DIM * sizeof(float));
//
// // compare result
// std::vector<float> result_vectors;
// float* p = (float*)(fetch_chunk->fixed_fields_[VECTOR_FIELD_NAME]->data_.data());
// for (int64_t i = 0; i < COLLECTION_DIM; i++) {
// result_vectors.push_back(p[i]);
// }
// ASSERT_EQ(fetch_vectors, result_vectors);
}
TEST_F(DBTest, FetchTest2) {
std::string collection_name = "STATS_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
ASSERT_TRUE(status.ok());
......@@ -1031,8 +1179,7 @@ TEST_F(DBTest, FetchTest) {
status = db_->Insert(collection_name, "", data_chunk);
ASSERT_TRUE(status.ok());
data_chunk->fixed_fields_.erase(milvus::engine::FIELD_UID); // clear auto-generated id
BuildEntities(entity_count, 0, data_chunk);
status = db_->Insert(collection_name, partition_name, data_chunk);
ASSERT_TRUE(status.ok());
......@@ -1297,8 +1444,7 @@ TEST_F(DBTest, LoadTest) {
status = db_->Insert(collection_name, "", data_chunk);
ASSERT_TRUE(status.ok());
data_chunk->fixed_fields_.erase(milvus::engine::FIELD_UID); // clear auto-generated id
BuildEntities(entity_count, 0, data_chunk);
status = db_->Insert(collection_name, partition_name, data_chunk);
ASSERT_TRUE(status.ok());
......
......@@ -35,7 +35,9 @@ using WalOperationPtr = milvus::engine::WalOperationPtr;
using WalOperationType = milvus::engine::WalOperationType;
using WalOperationCodec = milvus::engine::WalOperationCodec;
using InsertEntityOperation = milvus::engine::InsertEntityOperation;
using InsertEntityOperationPtr = milvus::engine::InsertEntityOperationPtr;
using DeleteEntityOperation = milvus::engine::DeleteEntityOperation;
using DeleteEntityOperationPtr = milvus::engine::DeleteEntityOperationPtr;
using WalProxy = milvus::engine::WalProxy;
void CreateChunk(DataChunkPtr& chunk, int64_t row_count, int64_t& chunk_size) {
......@@ -145,6 +147,9 @@ TEST_F(WalTest, WalFileTest) {
ASSERT_TRUE(file.ExceedMaxSize(max_size));
bytes = file.Write(path.data(), 0);
ASSERT_EQ(bytes, 0);
bytes = file.Write(path.data(), len);
ASSERT_EQ(bytes, len);
total_bytes += bytes;
......@@ -174,6 +179,9 @@ TEST_F(WalTest, WalFileTest) {
ASSERT_EQ(bytes, sizeof(int8_t));
std::string str;
bytes = file.ReadStr(str, 0);
ASSERT_EQ(bytes, 0);
bytes = file.ReadStr(str, len);
ASSERT_EQ(bytes, len);
ASSERT_EQ(str, path);
......@@ -191,65 +199,76 @@ TEST_F(WalTest, WalFileTest) {
}
TEST_F(WalTest, WalFileCodecTest) {
std::string path = "/tmp/milvus_wal/test_file";
std::string collection_name = "c1";
std::string partition_name = "p1";
std::string file_path = "/tmp/milvus_wal/test_file";
auto file = std::make_shared<WalFile>();
IDNumbers op_ids;
std::vector<WalOperationType> op_types;
// insert operation
{
auto status = file->OpenFile(path, WalFile::APPEND_WRITE);
ASSERT_TRUE(status.ok());
DataChunkPtr chunk;
int64_t chunk_size = 0;
CreateChunk(chunk, 1000, chunk_size);
// record 100 operations
std::vector<WalOperationPtr> operations;
for (int64_t i = 1; i <= 100; ++i) {
if (i % 5 == 0) {
// delete operation
auto status = file->OpenFile(file_path, WalFile::APPEND_WRITE);
ASSERT_TRUE(status.ok());
std::string partition_name = "p1";
idx_t op_id = 100;
op_ids.push_back(op_id);
op_types.push_back(WalOperationType::INSERT_ENTITY);
WalOperationCodec::WriteInsertOperation(file, partition_name, chunk, op_id);
auto pre_size = file->Size();
ASSERT_GE(file->Size(), chunk_size);
DeleteEntityOperationPtr operation = std::make_shared<DeleteEntityOperation>();
operation->collection_name_ = collection_name;
IDNumbers ids = {i + 1, i + 2, i + 3};
operation->entity_ids_ = ids;
idx_t op_id = i + 10000;
operation->SetID(op_id);
operations.emplace_back(operation);
file->CloseFile();
status = WalOperationCodec::WriteDeleteOperation(file, ids, op_id);
ASSERT_TRUE(status.ok());
WalFile file_read;
file_read.OpenFile(path, WalFile::READ);
idx_t last_id = 0;
file_read.ReadLastOpId(last_id);
ASSERT_EQ(last_id, op_id);
}
auto post_size = file->Size();
ASSERT_GE(post_size - pre_size, ids.size() * sizeof(idx_t));
// delete operation
{
auto status = file->OpenFile(path, WalFile::APPEND_WRITE);
ASSERT_TRUE(status.ok());
file->CloseFile();
auto pre_size = file->Size();
WalFile file_read;
file_read.OpenFile(file_path, WalFile::READ);
idx_t last_id = 0;
file_read.ReadLastOpId(last_id);
ASSERT_EQ(last_id, op_id);
} else {
// insert operation
auto status = file->OpenFile(file_path, WalFile::APPEND_WRITE);
ASSERT_TRUE(status.ok());
IDNumbers ids = {1, 2, 3};
idx_t op_id = 200;
op_ids.push_back(op_id);
op_types.push_back(WalOperationType::DELETE_ENTITY);
WalOperationCodec::WriteDeleteOperation(file, ids, op_id);
InsertEntityOperationPtr operation = std::make_shared<InsertEntityOperation>();
operation->collection_name_ = collection_name;
operation->partition_name = partition_name;
auto post_size = file->Size();
ASSERT_GE(post_size - pre_size, ids.size() * sizeof(idx_t));
DataChunkPtr chunk;
int64_t chunk_size = 0;
CreateChunk(chunk, 100, chunk_size);
operation->data_chunk_ = chunk;
file->CloseFile();
idx_t op_id = i + 10000;
operation->SetID(op_id);
operations.emplace_back(operation);
WalFile file_read;
file_read.OpenFile(path, WalFile::READ);
idx_t last_id = 0;
file_read.ReadLastOpId(last_id);
ASSERT_EQ(last_id, op_id);
status = WalOperationCodec::WriteInsertOperation(file, partition_name, chunk, op_id);
ASSERT_TRUE(status.ok());
ASSERT_GE(file->Size(), chunk_size);
file->CloseFile();
WalFile file_read;
file_read.OpenFile(file_path, WalFile::READ);
idx_t last_id = 0;
file_read.ReadLastOpId(last_id);
ASSERT_EQ(last_id, op_id);
}
}
// iterate operations
{
auto status = file->OpenFile(path, WalFile::READ);
auto status = file->OpenFile(file_path, WalFile::READ);
ASSERT_TRUE(status.ok());
Status iter_status;
......@@ -261,11 +280,48 @@ TEST_F(WalTest, WalFileCodecTest) {
continue;
}
ASSERT_EQ(operation->ID(), op_ids[op_index]);
ASSERT_EQ(operation->Type(), op_types[op_index]);
if (op_index >= operations.size()) {
ASSERT_TRUE(false);
}
// validate operation data is correct
WalOperationPtr compare_operation = operations[op_index];
ASSERT_EQ(operation->ID(), compare_operation->ID());
ASSERT_EQ(operation->Type(), compare_operation->Type());
if (operation->Type() == WalOperationType::INSERT_ENTITY) {
InsertEntityOperationPtr op_1 = std::static_pointer_cast<InsertEntityOperation>(operation);
InsertEntityOperationPtr op_2 = std::static_pointer_cast<InsertEntityOperation>(compare_operation);
ASSERT_EQ(op_1->partition_name, op_2->partition_name);
DataChunkPtr chunk_1 = op_1->data_chunk_;
DataChunkPtr chunk_2 = op_2->data_chunk_;
ASSERT_NE(chunk_1, nullptr);
ASSERT_NE(chunk_2, nullptr);
ASSERT_EQ(chunk_1->count_, chunk_2->count_);
for (auto& pair : chunk_1->fixed_fields_) {
auto iter = chunk_2->fixed_fields_.find(pair.first);
ASSERT_NE(iter, chunk_2->fixed_fields_.end());
ASSERT_NE(pair.second, nullptr);
ASSERT_NE(iter->second, nullptr);
ASSERT_EQ(pair.second->data_, iter->second->data_);
}
for (auto& pair : chunk_1->variable_fields_) {
auto iter = chunk_2->variable_fields_.find(pair.first);
ASSERT_NE(iter, chunk_2->variable_fields_.end());
ASSERT_NE(pair.second, nullptr);
ASSERT_NE(iter->second, nullptr);
ASSERT_EQ(pair.second->data_, iter->second->data_);
}
} else if (operation->Type() == WalOperationType::DELETE_ENTITY) {
DeleteEntityOperationPtr op_1 = std::static_pointer_cast<DeleteEntityOperation>(operation);
DeleteEntityOperationPtr op_2 = std::static_pointer_cast<DeleteEntityOperation>(compare_operation);
ASSERT_EQ(op_1->entity_ids_, op_2->entity_ids_);
}
++op_index;
}
ASSERT_EQ(op_index, op_ids.size());
ASSERT_EQ(op_index, operations.size());
}
}
......@@ -291,8 +347,7 @@ TEST_F(WalTest, WalProxyTest) {
// find out the wal files
DBOptions opt = GetOptions();
std::experimental::filesystem::path collection_path = opt.meta_.path_;
collection_path.append(milvus::engine::WAL_DATA_FOLDER);
std::experimental::filesystem::path collection_path = opt.wal_path_;
collection_path.append(collection_name);
using DirectoryIterator = std::experimental::filesystem::recursive_directory_iterator;
......@@ -354,7 +409,7 @@ TEST_F(WalTest, WalManagerTest) {
// construct mock db
DBOptions options;
options.meta_.path_ = "/tmp/milvus_wal";
options.wal_path_ = "/tmp/milvus_wal";
options.wal_enable_ = true;
DummyDBPtr db_1 = std::make_shared<DummyDB>(options);
......
......@@ -158,6 +158,7 @@ DBTest::GetOptions() {
options.meta_.path_ = "/tmp/milvus_ss";
options.meta_.backend_uri_ = "mock://:@:/";
options.wal_enable_ = false;
options.auto_flush_interval_ = 1;
return options;
}
......@@ -312,16 +313,17 @@ EventTest::TearDown() {
DBOptions
WalTest::GetOptions() {
DBOptions options;
options.meta_.path_ = "/tmp/milvus_wal";
options.meta_.backend_uri_ = "mock://:@:/";
options.wal_path_ = "/tmp/milvus_wal";
options.wal_enable_ = true;
return options;
}
void
WalTest::SetUp() {
auto options = GetOptions();
std::experimental::filesystem::create_directory(options.wal_path_);
milvus::engine::DBPtr db = std::make_shared<milvus::engine::DBProxy>(nullptr, GetOptions());
db_ = std::make_shared<milvus::engine::WalProxy>(db, GetOptions());
db_ = std::make_shared<milvus::engine::WalProxy>(db, options);
db_->Start();
}
......@@ -329,7 +331,7 @@ void
WalTest::TearDown() {
db_->Stop();
db_ = nullptr;
std::experimental::filesystem::remove_all(GetOptions().meta_.path_);
std::experimental::filesystem::remove_all(GetOptions().wal_path_);
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
......