未验证 提交 a87596fb 编写于 作者: G groot 提交者: GitHub

split insert data according to segment row count (#3529)

* split insert data according to segment row count
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* fix bug
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* fix python test failure
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* typo
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* fix python test
Signed-off-by: Ngroot <yihua.mo@zilliz.com>

* avoid tiny segment row count
Signed-off-by: Ngroot <yihua.mo@zilliz.com>
上级 b4bddc08
......@@ -481,14 +481,14 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
// check id field existence
auto& params = ss->GetCollection()->GetParams();
bool auto_increment = true;
bool auto_genid = true;
if (params.find(PARAM_UID_AUTOGEN) != params.end()) {
auto_increment = params[PARAM_UID_AUTOGEN];
auto_genid = params[PARAM_UID_AUTOGEN];
}
FIXEDX_FIELD_MAP& fields = data_chunk->fixed_fields_;
auto pair = fields.find(engine::FIELD_UID);
if (auto_increment) {
if (auto_genid) {
// id is auto generated, but client provides id, return error
if (pair != fields.end() && pair->second != nullptr) {
return Status(DB_ERROR, "Field '_id' is auto increment, no need to provide id");
......@@ -507,7 +507,7 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
consume_chunk->variable_fields_.swap(data_chunk->variable_fields_);
// generate id
if (auto_increment) {
if (auto_genid) {
SafeIDGenerator& id_generator = SafeIDGenerator::GetInstance();
IDNumbers ids;
STATUS_CHECK(id_generator.GetNextIDNumbers(consume_chunk->count_, ids));
......@@ -523,19 +523,30 @@ DBImpl::Insert(const std::string& collection_name, const std::string& partition_
}
// do insert
int64_t segment_row_count = DEFAULT_SEGMENT_ROW_COUNT;
if (params.find(PARAM_SEGMENT_ROW_COUNT) != params.end()) {
segment_row_count = params[PARAM_SEGMENT_ROW_COUNT];
}
int64_t collection_id = ss->GetCollectionId();
int64_t partition_id = partition->GetID();
auto status = mem_mgr_->InsertEntities(collection_id, partition_id, consume_chunk, op_id);
if (!status.ok()) {
return status;
}
if (mem_mgr_->GetCurrentMem() > options_.insert_buffer_size_) {
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] ", "insert", 0) << "Insert buffer size exceeds limit. Force flush";
InternalFlush();
std::vector<DataChunkPtr> chunks;
STATUS_CHECK(utils::SplitChunk(consume_chunk, segment_row_count, chunks));
for (auto& chunk : chunks) {
auto status = mem_mgr_->InsertEntities(collection_id, partition_id, chunk, op_id);
if (!status.ok()) {
return status;
}
if (mem_mgr_->GetCurrentMem() > options_.insert_buffer_size_) {
LOG_ENGINE_DEBUG_ << LogOut("[%s][%ld] ", "insert", 0) << "Insert buffer size exceeds limit. Force flush";
InternalFlush();
}
}
// metrics
Status status = Status::OK();
milvus::server::CollectInsertMetrics metrics(data_chunk->count_, status);
return Status::OK();
......
......@@ -215,9 +215,8 @@ GetSnapshotInfo(const std::string& collection_name, milvus::json& json_info) {
continue;
}
milvus::json json_file;
auto element = pair.second->GetElement();
if (pair.second->GetFile()) {
milvus::json json_file;
json_file[JSON_DATA_SIZE] = pair.second->GetFile()->GetSize();
json_file[JSON_PATH] =
engine::snapshot::GetResPath<engine::snapshot::SegmentFile>("", pair.second->GetFile());
......@@ -225,14 +224,15 @@ GetSnapshotInfo(const std::string& collection_name, milvus::json& json_info) {
// if the element is index, print index name/type
// else print element name
auto element = pair.second->GetElement();
if (element->GetFEtype() == engine::FieldElementType::FET_INDEX) {
json_file[JSON_NAME] = element->GetName();
json_file[JSON_INDEX_TYPE] = element->GetTypeName();
} else {
json_file[JSON_NAME] = element->GetName();
}
json_files.push_back(json_file);
}
json_files.push_back(json_file);
}
}
......@@ -263,5 +263,37 @@ GetSnapshotInfo(const std::string& collection_name, milvus::json& json_info) {
return Status::OK();
}
Status
GetSegmentRowCount(const std::string& collection_name, int64_t& segment_row_count) {
    // Start from the default; it is kept when the collection does not override it
    // (and also when the snapshot lookup below fails and we return early).
    segment_row_count = DEFAULT_SEGMENT_ROW_COUNT;

    snapshot::ScopedSnapshotT ss;
    STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_name));

    // Read the per-segment row count from the collection parameters, if present.
    const json params = ss->GetCollection()->GetParams();
    auto iter = params.find(PARAM_SEGMENT_ROW_COUNT);
    if (iter != params.end()) {
        segment_row_count = iter.value();
    }

    return Status::OK();
}
Status
GetSegmentRowCount(int64_t collection_id, int64_t& segment_row_count) {
    // Start from the default; it is kept when the collection does not override it
    // (and also when the snapshot lookup below fails and we return early).
    segment_row_count = DEFAULT_SEGMENT_ROW_COUNT;

    snapshot::ScopedSnapshotT ss;
    STATUS_CHECK(snapshot::Snapshots::GetInstance().GetSnapshot(ss, collection_id));

    // Read the per-segment row count from the collection parameters, if present.
    const json params = ss->GetCollection()->GetParams();
    auto iter = params.find(PARAM_SEGMENT_ROW_COUNT);
    if (iter != params.end()) {
        segment_row_count = iter.value();
    }

    return Status::OK();
}
} // namespace engine
} // namespace milvus
......@@ -52,5 +52,11 @@ IsVectorField(engine::DataType type);
Status
GetSnapshotInfo(const std::string& collection_name, milvus::json& json_info);
Status
GetSegmentRowCount(const std::string& collection_name, int64_t& segment_row_count);
Status
GetSegmentRowCount(int64_t collection_id, int64_t& segment_row_count);
} // namespace engine
} // namespace milvus
......@@ -19,6 +19,7 @@
#include <memory>
#include <mutex>
#include <regex>
#include <utility>
#include <vector>
#include "cache/CpuCacheMgr.h"
......@@ -172,6 +173,84 @@ GetSizeOfChunk(const engine::DataChunkPtr& chunk) {
return total_size;
}
// Split a data chunk into several chunks of at most segment_row_count rows each,
// so each resulting chunk can be flushed as its own segment.
// On success the resulting chunks are appended to 'chunks'. If no split is needed
// the input chunk itself is appended (callers can detect this by pointer equality).
// NOTE: variable-length field data is not copied into the split chunks yet (TODO below).
Status
SplitChunk(const DataChunkPtr& chunk, int64_t segment_row_count, std::vector<DataChunkPtr>& chunks) {
    // Nothing to do for a null chunk or a non-positive split size.
    if (chunk == nullptr || segment_row_count <= 0) {
        return Status::OK();
    }

    // No need to split the chunk if its row count is less than segment_row_count.
    // If the user specified a tiny segment_row_count (such as 1), also don't split;
    // use build_index_threshold (default is 4096) to avoid producing tiny segments.
    if (chunk->count_ <= segment_row_count || chunk->count_ <= config.engine.build_index_threshold.value) {
        chunks.push_back(chunk);
        return Status::OK();
    }

    const int64_t chunk_count = chunk->count_;

    // Split the chunk according to the segment row count.
    // Firstly, validate each field size and record the per-row byte width of fixed fields.
    FIELD_WIDTH_MAP fields_width;
    for (auto& pair : chunk->fixed_fields_) {
        if (pair.second == nullptr) {
            continue;
        }
        // each fixed field must contain exactly chunk_count rows of equal width
        int64_t data_size = static_cast<int64_t>(pair.second->data_.size());
        if (data_size % chunk_count != 0) {
            return Status(DB_ERROR, "Invalid chunk fixed field size");
        }
        fields_width.insert(std::make_pair(pair.first, data_size / chunk_count));
    }
    for (auto& pair : chunk->variable_fields_) {
        if (pair.second == nullptr) {
            continue;
        }
        // each variable field must carry one offset per row
        if (static_cast<int64_t>(pair.second->offset_.size()) != chunk_count) {
            return Status(DB_ERROR, "Invalid chunk variable field size");
        }
    }

    // Secondly, copy rows into new chunks, at most segment_row_count rows per chunk.
    int64_t copied_count = 0;
    while (copied_count < chunk_count) {
        int64_t count_to_copy = segment_row_count;
        if (chunk_count - copied_count < segment_row_count) {
            count_to_copy = chunk_count - copied_count;  // last, possibly smaller, chunk
        }

        DataChunkPtr new_chunk = std::make_shared<DataChunk>();
        for (auto& pair : chunk->fixed_fields_) {
            if (pair.second == nullptr) {
                continue;
            }
            int64_t field_width = fields_width[pair.first];
            int64_t data_length = field_width * count_to_copy;
            int64_t offset = field_width * copied_count;
            BinaryDataPtr data = std::make_shared<BinaryData>();
            data->data_.resize(data_length);
            memcpy(data->data_.data(), pair.second->data_.data() + offset, data_length);
            new_chunk->fixed_fields_.insert(std::make_pair(pair.first, data));
        }

        // TODO: copy variable-length data; for now variable fields are dropped from
        // the split chunks (the loop intentionally does nothing yet).
        for (auto& pair : chunk->variable_fields_) {
            if (pair.second == nullptr) {
                continue;
            }
        }

        new_chunk->count_ = count_to_copy;
        copied_count += count_to_copy;
        chunks.emplace_back(new_chunk);
    }

    return Status::OK();
}
bool
RequireRawFile(const std::string& index_type) {
return index_type == knowhere::IndexEnum::INDEX_FAISS_IVFFLAT || index_type == knowhere::IndexEnum::INDEX_NSG ||
......
......@@ -13,6 +13,7 @@
#include <ctime>
#include <string>
#include <vector>
#include "db/Types.h"
#include "utils/Status.h"
......@@ -58,6 +59,9 @@ GetIDFromChunk(const engine::DataChunkPtr& chunk, engine::IDNumbers& ids);
int64_t
GetSizeOfChunk(const engine::DataChunkPtr& chunk);
Status
SplitChunk(const DataChunkPtr& chunk, int64_t segment_row_count, std::vector<DataChunkPtr>& chunks);
bool
RequireRawFile(const std::string& index_type);
......
......@@ -24,6 +24,7 @@
#include <fiu/fiu-local.h>
#include "config/ServerConfig.h"
#include "db/SnapshotUtils.h"
#include "db/Utils.h"
#include "db/snapshot/CompoundOperations.h"
#include "db/snapshot/IterateHandler.h"
......@@ -38,6 +39,7 @@ namespace engine {
// Construct a mem-collection bound to one collection id.
// The segment row count is loaded from the collection's snapshot parameters so
// that Add() can decide when to roll over to a new in-memory segment.
// NOTE(review): the Status returned by GetSegmentRowCount() is ignored here;
// on failure segment_row_count_ keeps the default value that helper assigns first.
MemCollection::MemCollection(int64_t collection_id, const DBOptions& options)
    : collection_id_(collection_id), options_(options) {
    GetSegmentRowCount(collection_id_, segment_row_count_);
}
Status
......@@ -55,7 +57,9 @@ MemCollection::Add(int64_t partition_id, const DataChunkPtr& chunk, idx_t op_id)
int64_t chunk_size = utils::GetSizeOfChunk(chunk);
Status status;
if (current_mem_segment == nullptr || current_mem_segment->GetCurrentMem() + chunk_size > MAX_MEM_SEGMENT_SIZE) {
if (current_mem_segment == nullptr || chunk->count_ >= segment_row_count_ ||
current_mem_segment->GetCurrentRowCount() >= segment_row_count_ ||
current_mem_segment->GetCurrentMem() + chunk_size > MAX_MEM_SEGMENT_SIZE) {
MemSegmentPtr new_mem_segment = std::make_shared<MemSegment>(collection_id_, partition_id, options_);
status = new_mem_segment->Add(chunk, op_id);
if (status.ok()) {
......
......@@ -60,13 +60,15 @@ class MemCollection {
ApplyDeleteToFile();
private:
int64_t collection_id_;
int64_t collection_id_ = 0;
DBOptions options_;
MemSegmentMap mem_segments_;
std::mutex mem_mutex_;
std::unordered_set<idx_t> ids_to_delete_;
int64_t segment_row_count_ = 0;
};
using MemCollectionPtr = std::shared_ptr<MemCollection>;
......
......@@ -48,6 +48,7 @@ MemSegment::Add(const DataChunkPtr& chunk, idx_t op_id) {
actions_.emplace_back(action);
current_mem_ += utils::GetSizeOfChunk(chunk);
total_row_count_ += chunk->count_;
return Status::OK();
}
......
......@@ -52,6 +52,11 @@ class MemSegment {
return current_mem_;
}
// Total number of rows added to this mem-segment so far
// (Add() accumulates each inserted chunk's count_ into total_row_count_).
int64_t
GetCurrentRowCount() const {
    return total_row_count_;
}
Status
Serialize();
......@@ -75,6 +80,8 @@ class MemSegment {
using ActionArray = std::vector<MemAction>;
ActionArray actions_;  // the actions array makes sure insert/delete actions are executed one by one
int64_t total_row_count_ = 0;
};
using MemSegmentPtr = std::shared_ptr<MemSegment>;
......
......@@ -257,69 +257,43 @@ WalManager::Init() {
Status
WalManager::RecordInsertOperation(const InsertEntityOperationPtr& operation, const DBPtr& db) {
std::vector<DataChunkPtr> chunks;
SplitChunk(operation->data_chunk_, chunks);
IDNumbers op_ids;
auto status = id_gen_.GetNextIDNumbers(chunks.size(), op_ids);
if (!status.ok()) {
return status;
}
for (size_t i = 0; i < chunks.size(); ++i) {
idx_t op_id = op_ids[i];
DataChunkPtr& chunk = chunks[i];
int64_t chunk_size = utils::GetSizeOfChunk(chunk);
idx_t op_id = id_gen_.GetNextIDNumber();
try {
// open wal file
std::string path = ConstructFilePath(operation->collection_name_, std::to_string(op_id));
if (!path.empty()) {
std::lock_guard<std::mutex> lock(file_map_mutex_);
WalFilePtr file = file_map_[operation->collection_name_];
if (file == nullptr) {
file = std::make_shared<WalFile>();
file_map_[operation->collection_name_] = file;
file->OpenFile(path, WalFile::APPEND_WRITE);
} else if (!file->IsOpened() || file->ExceedMaxSize(chunk_size)) {
file->OpenFile(path, WalFile::APPEND_WRITE);
}
DataChunkPtr& chunk = operation->data_chunk_;
int64_t chunk_size = utils::GetSizeOfChunk(chunk);
// write to wal file
status = WalOperationCodec::WriteInsertOperation(file, operation->partition_name, chunk, op_id);
if (!status.ok()) {
return status;
}
try {
// open wal file
std::string path = ConstructFilePath(operation->collection_name_, std::to_string(op_id));
if (!path.empty()) {
std::lock_guard<std::mutex> lock(file_map_mutex_);
WalFilePtr file = file_map_[operation->collection_name_];
if (file == nullptr) {
file = std::make_shared<WalFile>();
file_map_[operation->collection_name_] = file;
file->OpenFile(path, WalFile::APPEND_WRITE);
} else if (!file->IsOpened() || file->ExceedMaxSize(chunk_size)) {
file->OpenFile(path, WalFile::APPEND_WRITE);
}
} catch (std::exception& ex) {
std::string msg = "Failed to record insert operation, reason: " + std::string(ex.what());
return Status(DB_ERROR, msg);
}
// insert action to db
if (db) {
status = db->Insert(operation->collection_name_, operation->partition_name, operation->data_chunk_, op_id);
// write to wal file
auto status = WalOperationCodec::WriteInsertOperation(file, operation->partition_name, chunk, op_id);
if (!status.ok()) {
return status;
}
}
} catch (std::exception& ex) {
std::string msg = "Failed to record insert operation, reason: " + std::string(ex.what());
return Status(DB_ERROR, msg);
}
return Status::OK();
}
Status
WalManager::SplitChunk(const DataChunkPtr& chunk, std::vector<DataChunkPtr>& chunks) {
// int64_t chunk_size = utils::GetSizeOfChunk(chunk);
// if (chunk_size > insert_buffer_size_) {
// int64_t batch = chunk_size / insert_buffer_size_;
// int64_t batch_count = chunk->count_ / batch;
// for (int64_t i = 0; i <= batch; ++i) {
// }
// } else {
// chunks.push_back(chunk);
// }
chunks.push_back(chunk);
// insert action to db
if (db) {
auto status = db->Insert(operation->collection_name_, operation->partition_name, operation->data_chunk_, op_id);
if (!status.ok()) {
return status;
}
}
return Status::OK();
}
......
......@@ -67,9 +67,6 @@ class WalManager {
Status
RecordDeleteOperation(const DeleteEntityOperationPtr& operation, const DBPtr& db);
Status
SplitChunk(const DataChunkPtr& chunk, std::vector<DataChunkPtr>& chunks);
std::string
ConstructFilePath(const std::string& collection_name, const std::string& file_name);
......
......@@ -11,6 +11,8 @@
#include "db/wal/WalProxy.h"
#include "config/ServerConfig.h"
#include "db/SnapshotUtils.h"
#include "db/Utils.h"
#include "db/wal/WalManager.h"
#include "db/wal/WalOperation.h"
#include "utils/Exception.h"
......@@ -61,13 +63,45 @@ WalProxy::DropCollection(const std::string& collection_name) {
Status
WalProxy::Insert(const std::string& collection_name, const std::string& partition_name, DataChunkPtr& data_chunk,
idx_t op_id) {
// write operation into disk
InsertEntityOperationPtr op = std::make_shared<InsertEntityOperation>();
op->collection_name_ = collection_name;
op->partition_name = partition_name;
op->data_chunk_ = data_chunk;
// get segment row count of this collection
int64_t row_count_per_segment = DEFAULT_SEGMENT_ROW_COUNT;
GetSegmentRowCount(collection_name, row_count_per_segment);
return WalManager::GetInstance().RecordOperation(op, db_);
// split chunk according to segment row count
std::vector<DataChunkPtr> chunks;
STATUS_CHECK(utils::SplitChunk(data_chunk, row_count_per_segment, chunks));
if (chunks.size() > 0 && data_chunk != chunks[0]) {
// data has been copied to new chunk, do this to free memory
data_chunk->fixed_fields_.clear();
data_chunk->variable_fields_.clear();
data_chunk->count_ = 0;
}
// write operation into wal file, and insert to memory
for (auto& chunk : chunks) {
InsertEntityOperationPtr op = std::make_shared<InsertEntityOperation>();
op->collection_name_ = collection_name;
op->partition_name = partition_name;
op->data_chunk_ = chunk;
STATUS_CHECK(WalManager::GetInstance().RecordOperation(op, db_));
}
// return id field
if (chunks.size() > 0 && data_chunk != chunks[0]) {
int64_t row_count = 0;
BinaryDataPtr id_data = std::make_shared<BinaryData>();
for (auto& chunk : chunks) {
auto iter = chunk->fixed_fields_.find(engine::FIELD_UID);
if (iter != chunk->fixed_fields_.end()) {
id_data->data_.insert(id_data->data_.end(), iter->second->data_.begin(), iter->second->data_.end());
row_count += chunk->count_;
}
}
data_chunk->count_ = row_count;
data_chunk->fixed_fields_[engine::FIELD_UID] = id_data;
}
return Status::OK();
}
Status
......
......@@ -173,7 +173,7 @@ Server::Start() {
if (is_read_only) {
STATUS_CHECK(Directory::Access("", "", config.logs.path()));
} else {
STATUS_CHECK(Directory::Access(config.storage.path(), config.wal.path(), config.logs.path()));
STATUS_CHECK(Directory::Access(config.storage.path(), wal_path, config.logs.path()));
if (config.system.lock.enable()) {
STATUS_CHECK(Directory::Lock(config.storage.path(), wal_path));
......
......@@ -90,7 +90,11 @@ InsertReq::OnExecute() {
}
// step 5: return entity id to client
chunk_data_[engine::FIELD_UID] = data_chunk->fixed_fields_[engine::FIELD_UID]->data_;
auto iter = data_chunk->fixed_fields_.find(engine::FIELD_UID);
if (iter == data_chunk->fixed_fields_.end() || iter->second == nullptr) {
return Status(SERVER_UNEXPECTED_ERROR, "Insert action return empty id array");
}
chunk_data_[engine::FIELD_UID] = iter->second->data_;
rc.ElapseFromBegin("done");
} catch (std::exception& ex) {
......
......@@ -53,10 +53,12 @@ CreateCollection(const std::shared_ptr<DB>& db, const std::string& collection_na
static constexpr int64_t COLLECTION_DIM = 10;
milvus::Status
CreateCollection2(std::shared_ptr<DB> db, const std::string& collection_name, const LSN_TYPE& lsn) {
CreateCollection2(std::shared_ptr<DB> db, const std::string& collection_name, bool auto_genid = true) {
CreateCollectionContext context;
context.lsn = lsn;
auto collection_schema = std::make_shared<Collection>(collection_name);
milvus::json collection_params;
collection_params[milvus::engine::PARAM_UID_AUTOGEN] = auto_genid;
auto collection_schema = std::make_shared<Collection>(collection_name, collection_params);
context.collection = collection_schema;
milvus::json params;
......@@ -107,7 +109,7 @@ CreateCollection3(std::shared_ptr<DB> db, const std::string& collection_name, co
}
void
BuildEntities(uint64_t n, uint64_t batch_index, milvus::engine::DataChunkPtr& data_chunk) {
BuildEntities(uint64_t n, uint64_t batch_index, milvus::engine::DataChunkPtr& data_chunk, bool gen_id = false) {
data_chunk = std::make_shared<milvus::engine::DataChunk>();
data_chunk->count_ = n;
......@@ -123,10 +125,19 @@ BuildEntities(uint64_t n, uint64_t batch_index, milvus::engine::DataChunkPtr& da
vectors.id_array_.push_back(n * batch_index + i);
}
milvus::engine::BinaryDataPtr raw = std::make_shared<milvus::engine::BinaryData>();
raw->data_.resize(vectors.float_data_.size() * sizeof(float));
memcpy(raw->data_.data(), vectors.float_data_.data(), vectors.float_data_.size() * sizeof(float));
data_chunk->fixed_fields_[VECTOR_FIELD_NAME] = raw;
if (gen_id) {
milvus::engine::BinaryDataPtr raw = std::make_shared<milvus::engine::BinaryData>();
raw->data_.resize(vectors.id_array_.size() * sizeof(int64_t));
memcpy(raw->data_.data(), vectors.id_array_.data(), vectors.id_array_.size() * sizeof(int64_t));
data_chunk->fixed_fields_[milvus::engine::FIELD_UID] = raw;
}
{
milvus::engine::BinaryDataPtr raw = std::make_shared<milvus::engine::BinaryData>();
raw->data_.resize(vectors.float_data_.size() * sizeof(float));
memcpy(raw->data_.data(), vectors.float_data_.data(), vectors.float_data_.size() * sizeof(float));
data_chunk->fixed_fields_[VECTOR_FIELD_NAME] = raw;
}
std::vector<int32_t> value_0;
std::vector<int64_t> value_1;
......@@ -642,7 +653,7 @@ TEST(MergeTest, MergeStrategyTest) {
TEST_F(DBTest, MergeTest) {
std::string collection_name = "MERGE_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok());
const uint64_t entity_count = 100;
......@@ -778,7 +789,7 @@ TEST_F(DBTest, GetEntityTest) {
};
std::string collection_name = "GET_ENTITY_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok()) << status.ToString();
milvus::engine::IDNumbers entity_ids;
......@@ -842,7 +853,7 @@ TEST_F(DBTest, GetEntityTest) {
TEST_F(DBTest, CompactTest) {
std::string collection_name = "COMPACT_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok());
// insert 1000 entities into default partition
......@@ -931,7 +942,7 @@ TEST_F(DBTest, CompactTest) {
TEST_F(DBTest, IndexTest) {
std::string collection_name = "INDEX_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok());
// insert 10000 entities into default partition
......@@ -1008,7 +1019,7 @@ TEST_F(DBTest, IndexTest) {
TEST_F(DBTest, StatsTest) {
std::string collection_name = "STATS_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok());
std::string partition_name = "p1";
......@@ -1101,7 +1112,7 @@ TEST_F(DBTest, StatsTest) {
TEST_F(DBTest, FetchTest1) {
std::string collection_name = "STATS_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok());
std::string partition_name1 = "p1";
......@@ -1174,7 +1185,7 @@ TEST_F(DBTest, FetchTest1) {
ASSERT_EQ(fetch_vectors, result_vectors);
// std::string collection_name = "STATS_TEST";
// auto status = CreateCollection2(db_, collection_name, 0);
// auto status = CreateCollection2(db_, collection_name);
// ASSERT_TRUE(status.ok());
//
// std::string partition_name1 = "p1";
......@@ -1227,7 +1238,7 @@ TEST_F(DBTest, FetchTest1) {
TEST_F(DBTest, FetchTest2) {
std::string collection_name = "STATS_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok());
std::string partition_name = "p1";
......@@ -1320,12 +1331,12 @@ TEST_F(DBTest, FetchTest2) {
TEST_F(DBTest, DeleteEntitiesTest) {
std::string collection_name = "test_collection_delete_";
CreateCollection2(db_, collection_name, 0);
CreateCollection2(db_, collection_name, false);
// insert 100 entities into default partition without flush
milvus::engine::IDNumbers entity_ids;
milvus::engine::DataChunkPtr data_chunk;
BuildEntities(100, 0, data_chunk);
BuildEntities(100, 0, data_chunk, true);
auto status = db_->Insert(collection_name, "", data_chunk);
milvus::engine::utils::GetIDFromChunk(data_chunk, entity_ids);
......@@ -1339,7 +1350,7 @@ TEST_F(DBTest, DeleteEntitiesTest) {
auto insert_entities = [&](const std::string& collection, const std::string& partition,
uint64_t count, uint64_t batch_index, milvus::engine::IDNumbers& ids) -> Status {
milvus::engine::DataChunkPtr data_chunk;
BuildEntities(count, batch_index, data_chunk);
BuildEntities(count, batch_index, data_chunk, true);
STATUS_CHECK(db_->Insert(collection, partition, data_chunk));
STATUS_CHECK(db_->Flush(collection));
......@@ -1402,6 +1413,7 @@ TEST_F(DBTest, DeleteEntitiesTest) {
std::vector<bool> valid_row;
milvus::engine::DataChunkPtr entity_data_chunk;
for (auto& id : whole_delete_ids) {
std::cout << "get entity: " << id << std::endl;
status = db_->GetEntityByID(collection_name, {id}, {}, valid_row, entity_data_chunk);
ASSERT_TRUE(status.ok()) << status.ToString();
ASSERT_EQ(entity_data_chunk->count_, 0);
......@@ -1448,7 +1460,7 @@ TEST_F(DBTest, DeleteStaleTest) {
};
const std::string collection_name = "test_delete_stale_";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok()) << status.ToString();
milvus::engine::IDNumbers del_ids;
milvus::engine::IDNumbers entity_ids;
......@@ -1492,7 +1504,7 @@ TEST_F(DBTest, DeleteStaleTest) {
TEST_F(DBTest, LoadTest) {
std::string collection_name = "LOAD_TEST";
auto status = CreateCollection2(db_, collection_name, 0);
auto status = CreateCollection2(db_, collection_name);
ASSERT_TRUE(status.ok());
std::string partition_name = "p1";
......
......@@ -40,13 +40,38 @@ using DeleteEntityOperation = milvus::engine::DeleteEntityOperation;
using DeleteEntityOperationPtr = milvus::engine::DeleteEntityOperationPtr;
using WalProxy = milvus::engine::WalProxy;
void CreateChunk(DataChunkPtr& chunk, int64_t row_count, int64_t& chunk_size) {
const char* COLLECTION_NAME = "wal_tbl";
const char* VECTOR_FIELD_NAME = "vector";
const char* INT_FIELD_NAME = "int";
// Create the test collection COLLECTION_NAME with one float-vector field and one
// int32 field. Ids are auto-generated and the segment row count is fixed at 1000
// so the insert-splitting path can be exercised by the WAL tests.
milvus::Status
CreateCollection() {
    CreateCollectionContext context;
    auto collection_schema = std::make_shared<Collection>(COLLECTION_NAME);
    context.collection = collection_schema;
    auto vector_field = std::make_shared<Field>(VECTOR_FIELD_NAME, 0, milvus::engine::DataType::VECTOR_FLOAT);
    auto int_field = std::make_shared<Field>(INT_FIELD_NAME, 0, milvus::engine::DataType::INT32);
    context.fields_schema[vector_field] = {};
    context.fields_schema[int_field] = {};

    // default id is auto-generated; pin segment row count for deterministic splits
    auto params = context.collection->GetParams();
    params[milvus::engine::PARAM_UID_AUTOGEN] = true;
    params[milvus::engine::PARAM_SEGMENT_ROW_COUNT] = 1000;
    context.collection->SetParams(params);

    auto op = std::make_shared<milvus::engine::snapshot::CreateCollectionOperation>(context);
    return op->Push();
}
void
CreateChunk(DataChunkPtr& chunk, int64_t row_count, int64_t& chunk_size) {
chunk = std::make_shared<DataChunk>();
chunk->count_ = row_count;
chunk_size = 0;
{
// int32 type field
std::string field_name = "f1";
std::string field_name = INT_FIELD_NAME;
auto bin = std::make_shared<BinaryData>();
bin->data_.resize(chunk->count_ * sizeof(int32_t));
int32_t* p = (int32_t*)(bin->data_.data());
......@@ -59,7 +84,7 @@ void CreateChunk(DataChunkPtr& chunk, int64_t row_count, int64_t& chunk_size) {
{
// vector type field
int64_t dimension = 128;
std::string field_name = "f2";
std::string field_name = VECTOR_FIELD_NAME;
auto bin = std::make_shared<BinaryData>();
bin->data_.resize(chunk->count_ * sizeof(float) * dimension);
float* p = (float*)(bin->data_.data());
......@@ -76,7 +101,7 @@ void CreateChunk(DataChunkPtr& chunk, int64_t row_count, int64_t& chunk_size) {
class DummyDB : public DBProxy {
public:
DummyDB(const DBOptions& options)
: DBProxy(nullptr, options) {
: DBProxy(nullptr, options) {
}
Status
......@@ -99,6 +124,7 @@ class DummyDB : public DBProxy {
}
int64_t InsertCount() const { return insert_count_; }
int64_t DeleteCount() const { return delete_count_; }
private:
......@@ -199,7 +225,6 @@ TEST_F(WalTest, WalFileTest) {
}
TEST_F(WalTest, WalFileCodecTest) {
std::string collection_name = "c1";
std::string partition_name = "p1";
std::string file_path = "/tmp/milvus_wal/test_file";
auto file = std::make_shared<WalFile>();
......@@ -215,7 +240,7 @@ TEST_F(WalTest, WalFileCodecTest) {
auto pre_size = file->Size();
DeleteEntityOperationPtr operation = std::make_shared<DeleteEntityOperation>();
operation->collection_name_ = collection_name;
operation->collection_name_ = COLLECTION_NAME;
IDNumbers ids = {i + 1, i + 2, i + 3};
operation->entity_ids_ = ids;
idx_t op_id = i + 10000;
......@@ -241,7 +266,7 @@ TEST_F(WalTest, WalFileCodecTest) {
ASSERT_TRUE(status.ok());
InsertEntityOperationPtr operation = std::make_shared<InsertEntityOperation>();
operation->collection_name_ = collection_name;
operation->collection_name_ = COLLECTION_NAME;
operation->partition_name = partition_name;
DataChunkPtr chunk;
......@@ -273,7 +298,7 @@ TEST_F(WalTest, WalFileCodecTest) {
Status iter_status;
int32_t op_index = 0;
while(iter_status.ok()) {
while (iter_status.ok()) {
WalOperationPtr operation;
iter_status = WalOperationCodec::IterateOperation(file, operation, 0);
if (operation == nullptr) {
......@@ -313,7 +338,7 @@ TEST_F(WalTest, WalFileCodecTest) {
ASSERT_NE(iter->second, nullptr);
ASSERT_EQ(pair.second->data_, iter->second->data_);
}
} else if(operation->Type() == WalOperationType::DELETE_ENTITY) {
} else if (operation->Type() == WalOperationType::DELETE_ENTITY) {
DeleteEntityOperationPtr op_1 = std::static_pointer_cast<DeleteEntityOperation>(operation);
DeleteEntityOperationPtr op_2 = std::static_pointer_cast<DeleteEntityOperation>(compare_operation);
ASSERT_EQ(op_1->entity_ids_, op_2->entity_ids_);
......@@ -326,21 +351,23 @@ TEST_F(WalTest, WalFileCodecTest) {
}
TEST_F(WalTest, WalProxyTest) {
std::string collection_name = "col_1";
auto status = CreateCollection();
ASSERT_TRUE(status.ok());
std::string partition_name = "part_1";
// write over more than 400MB data, 2 wal files
for (int64_t i = 1; i <= 1000; i++) {
if (i % 10 == 0) {
IDNumbers ids = {1, 2, 3};
auto status = db_->DeleteEntityByID(collection_name, ids, 0);
status = db_->DeleteEntityByID(COLLECTION_NAME, ids, 0);
ASSERT_TRUE(status.ok());
} else {
DataChunkPtr chunk;
int64_t chunk_size = 0;
CreateChunk(chunk, 1000, chunk_size);
CreateChunk(chunk, (i % 20) * 100, chunk_size);
auto status = db_->Insert(collection_name, partition_name, chunk, 0);
status = db_->Insert(COLLECTION_NAME, partition_name, chunk, 0);
ASSERT_TRUE(status.ok());
}
}
......@@ -348,7 +375,7 @@ TEST_F(WalTest, WalProxyTest) {
// find out the wal files
DBOptions opt = GetOptions();
std::experimental::filesystem::path collection_path = opt.wal_path_;
collection_path.append(collection_name);
collection_path.append(COLLECTION_NAME);
using DirectoryIterator = std::experimental::filesystem::recursive_directory_iterator;
std::set<idx_t> op_ids;
......@@ -364,7 +391,7 @@ TEST_F(WalTest, WalProxyTest) {
// read all operation ids
auto file = std::make_shared<WalFile>();
auto status = file->OpenFile(file_path, WalFile::READ);
status = file->OpenFile(file_path, WalFile::READ);
ASSERT_TRUE(status.ok());
Status iter_status;
......@@ -380,7 +407,7 @@ TEST_F(WalTest, WalProxyTest) {
// notify operation done, the wal files will be removed after all operations done
for (auto id : op_ids) {
auto status = WalManager::GetInstance().OperationDone(collection_name, id);
status = WalManager::GetInstance().OperationDone(COLLECTION_NAME, id);
ASSERT_TRUE(status.ok());
}
......@@ -405,8 +432,6 @@ TEST_F(WalTest, WalProxyTest) {
}
TEST_F(WalTest, WalManagerTest) {
std::string collection_name = "collection";
// construct mock db
DBOptions options;
options.wal_path_ = "/tmp/milvus_wal";
......@@ -422,16 +447,16 @@ TEST_F(WalTest, WalManagerTest) {
int64_t delete_count = 0;
for (int64_t i = 1; i <= 1000; i++) {
if (i % 100 == 0) {
auto status = WalManager::GetInstance().DropCollection(collection_name);
auto status = WalManager::GetInstance().DropCollection(COLLECTION_NAME);
ASSERT_TRUE(status.ok());
} else if (i % 10 == 0) {
IDNumbers ids = {1, 2, 3};
auto op = std::make_shared<DeleteEntityOperation>();
op->collection_name_ = collection_name;
op->collection_name_ = COLLECTION_NAME;
op->entity_ids_ = ids;
auto status = WalManager::GetInstance().RecordOperation(op, db_1);
auto status = WalManager::GetInstance().RecordOperation(op, db_1);
ASSERT_TRUE(status.ok());
delete_count++;
......@@ -441,11 +466,11 @@ TEST_F(WalTest, WalManagerTest) {
CreateChunk(chunk, 1000, chunk_size);
auto op = std::make_shared<InsertEntityOperation>();
op->collection_name_ = collection_name;
op->collection_name_ = COLLECTION_NAME;
op->partition_name = "";
op->data_chunk_ = chunk;
auto status = WalManager::GetInstance().RecordOperation(op, db_1);
auto status = WalManager::GetInstance().RecordOperation(op, db_1);
ASSERT_TRUE(status.ok());
insert_count++;
......@@ -464,10 +489,10 @@ TEST_F(WalTest, WalManagerTest) {
IDNumbers ids = {1, 2, 3};
auto op = std::make_shared<DeleteEntityOperation>();
op->collection_name_ = collection_name;
op->collection_name_ = COLLECTION_NAME;
op->entity_ids_ = ids;
auto status = WalManager::GetInstance().RecordOperation(op, nullptr);
auto status = WalManager::GetInstance().RecordOperation(op, nullptr);
ASSERT_TRUE(status.ok());
delete_count++;
......@@ -477,11 +502,11 @@ TEST_F(WalTest, WalManagerTest) {
CreateChunk(chunk, 1000, chunk_size);
auto op = std::make_shared<InsertEntityOperation>();
op->collection_name_ = collection_name;
op->collection_name_ = COLLECTION_NAME;
op->partition_name = "";
op->data_chunk_ = chunk;
auto status = WalManager::GetInstance().RecordOperation(op, nullptr);
auto status = WalManager::GetInstance().RecordOperation(op, nullptr);
ASSERT_TRUE(status.ok());
insert_count++;
......
......@@ -313,6 +313,7 @@ EventTest::TearDown() {
DBOptions
WalTest::GetOptions() {
DBOptions options;
options.meta_.backend_uri_ = "mock://:@:/";
options.wal_path_ = "/tmp/milvus_wal";
options.wal_enable_ = true;
return options;
......@@ -320,18 +321,22 @@ WalTest::GetOptions() {
void
WalTest::SetUp() {
BaseTest::SetUp();
auto options = GetOptions();
std::experimental::filesystem::create_directory(options.wal_path_);
milvus::engine::DBPtr db = std::make_shared<milvus::engine::DBProxy>(nullptr, GetOptions());
db_ = std::make_shared<milvus::engine::WalProxy>(db, options);
db_->Start();
BaseTest::SnapshotStart(true, options);
}
void
WalTest::TearDown() {
BaseTest::SnapshotStop();
db_->Stop();
db_ = nullptr;
std::experimental::filesystem::remove_all(GetOptions().wal_path_);
BaseTest::TearDown();
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
......
......@@ -411,7 +411,7 @@ class EventTest : public BaseTest {
};
///////////////////////////////////////////////////////////////////////////////
class WalTest : public ::testing::Test {
class WalTest : public BaseTest {
protected:
std::shared_ptr<DB> db_;
......
......@@ -171,11 +171,11 @@ class TestStatsBase:
connect.flush([collection])
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
compact_before = stats["partitions"][0]["segments"][0]["data_size"]
compact_before = stats["partitions"][0]["row_count"]
connect.compact(collection)
stats = connect.get_collection_stats(collection)
logging.getLogger().info(stats)
compact_after = stats["partitions"][0]["segments"][0]["data_size"]
compact_after = stats["partitions"][0]["row_count"]
# pdb.set_trace()
assert compact_before == compact_after
......@@ -301,7 +301,7 @@ class TestStatsBase:
connect.flush(collection_list)
for i in range(collection_num):
stats = connect.get_collection_stats(collection_list[i])
assert stats["partitions"][0]["segments"][0]["row_count"] == nb
assert stats["partitions"][0]["row_count"] == nb
connect.drop_collection(collection_list[i])
@pytest.mark.level(2)
......
......@@ -157,14 +157,14 @@ class TestCompactBase:
# get collection info before compact
info = connect.get_collection_stats(collection)
logging.getLogger().info(info["partitions"])
size_before = info["partitions"][0]["segments"][0]["data_size"]
size_before = info["partitions"][0]["data_size"]
logging.getLogger().info(size_before)
status = connect.compact(collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(collection)
logging.getLogger().info(info["partitions"])
size_after = info["partitions"][0]["segments"][0]["data_size"]
size_after = info["partitions"][0]["data_size"]
logging.getLogger().info(size_after)
assert(size_before >= size_after)
......@@ -301,18 +301,18 @@ class TestCompactBase:
connect.flush([collection])
# get collection info before compact
info = connect.get_collection_stats(collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
size_before = info["partitions"][0]["data_size"]
status = connect.compact(collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
size_after = info["partitions"][0]["data_size"]
assert(size_before >= size_after)
status = connect.compact(collection)
assert status.OK()
# get collection info after compact twice
info = connect.get_collection_stats(collection)
size_after_twice = info["partitions"][0]["segments"][0]["data_size"]
size_after_twice = info["partitions"][0]["data_size"]
assert(size_after == size_after_twice)
@pytest.mark.timeout(COMPACT_TIMEOUT)
......@@ -482,14 +482,14 @@ class TestCompactBinary:
# get collection info before compact
info = connect.get_collection_stats(binary_collection)
logging.getLogger().info(info["partitions"])
size_before = info["partitions"][0]["segments"][0]["data_size"]
size_before = info["partitions"][0]["data_size"]
logging.getLogger().info(size_before)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(binary_collection)
logging.getLogger().info(info["partitions"])
size_after = info["partitions"][0]["segments"][0]["data_size"]
size_after = info["partitions"][0]["data_size"]
logging.getLogger().info(size_after)
assert(size_before >= size_after)
......@@ -559,18 +559,18 @@ class TestCompactBinary:
connect.flush([binary_collection])
# get collection info before compact
info = connect.get_collection_stats(binary_collection)
size_before = info["partitions"][0]["segments"][0]["data_size"]
size_before = info["partitions"][0]["data_size"]
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact
info = connect.get_collection_stats(binary_collection)
size_after = info["partitions"][0]["segments"][0]["data_size"]
size_after = info["partitions"][0]["data_size"]
assert(size_before >= size_after)
status = connect.compact(binary_collection)
assert status.OK()
# get collection info after compact twice
info = connect.get_collection_stats(binary_collection)
size_after_twice = info["partitions"][0]["segments"][0]["data_size"]
size_after_twice = info["partitions"][0]["data_size"]
assert(size_after == size_after_twice)
@pytest.mark.timeout(COMPACT_TIMEOUT)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册