提交 3c0932cc 编写于 作者: P peng.xu

Merge branch 'branch-0.4.0' into 'branch-0.4.0'

add vector_ids in Insert

See merge request megasearch/milvus!393

Former-commit-id: 3553b40681daea6d3cf42f29f7364aa00db65c19
......@@ -42,10 +42,12 @@ Status MemManagerImpl::InsertVectorsNoLock(const std::string &table_id,
MemTablePtr mem = GetMemByTable(table_id);
VectorSource::Ptr source = std::make_shared<VectorSource>(n, vectors);
auto status = mem->Add(source);
auto status = mem->Add(source, vector_ids);
if (status.ok()) {
if (vector_ids.empty()) {
vector_ids = source->GetVectorIds();
}
}
return status;
}
......
......@@ -15,7 +15,7 @@ MemTable::MemTable(const std::string &table_id,
}
Status MemTable::Add(VectorSource::Ptr &source) {
Status MemTable::Add(VectorSource::Ptr &source, IDNumbers &vector_ids) {
while (!source->AllAdded()) {
......@@ -27,12 +27,12 @@ Status MemTable::Add(VectorSource::Ptr &source) {
Status status;
if (mem_table_file_list_.empty() || current_mem_table_file->IsFull()) {
MemTableFile::Ptr new_mem_table_file = std::make_shared<MemTableFile>(table_id_, meta_, options_);
status = new_mem_table_file->Add(source);
status = new_mem_table_file->Add(source, vector_ids);
if (status.ok()) {
mem_table_file_list_.emplace_back(new_mem_table_file);
}
} else {
status = current_mem_table_file->Add(source);
status = current_mem_table_file->Add(source, vector_ids);
}
if (!status.ok()) {
......
......@@ -21,7 +21,7 @@ class MemTable {
MemTable(const std::string &table_id, const std::shared_ptr<meta::Meta> &meta, const Options &options);
Status Add(VectorSource::Ptr &source);
Status Add(VectorSource::Ptr &source, IDNumbers &vector_ids);
void GetCurrentMemTableFile(MemTableFile::Ptr &mem_table_file);
......
......@@ -41,7 +41,7 @@ Status MemTableFile::CreateTableFile() {
return status;
}
Status MemTableFile::Add(const VectorSource::Ptr &source) {
Status MemTableFile::Add(const VectorSource::Ptr &source, IDNumbers& vector_ids) {
if (table_file_schema_.dimension_ <= 0) {
std::string err_msg = "MemTableFile::Add: table_file_schema dimension = " +
......@@ -55,7 +55,7 @@ Status MemTableFile::Add(const VectorSource::Ptr &source) {
if (mem_left >= single_vector_mem_size) {
size_t num_vectors_to_add = std::ceil(mem_left / single_vector_mem_size);
size_t num_vectors_added;
auto status = source->Add(execution_engine_, table_file_schema_, num_vectors_to_add, num_vectors_added);
auto status = source->Add(execution_engine_, table_file_schema_, num_vectors_to_add, num_vectors_added, vector_ids);
if (status.ok()) {
current_mem_ += (num_vectors_added * single_vector_mem_size);
}
......
......@@ -19,7 +19,7 @@ class MemTableFile {
MemTableFile(const std::string &table_id, const std::shared_ptr<meta::Meta> &meta, const Options &options);
Status Add(const VectorSource::Ptr &source);
Status Add(const VectorSource::Ptr &source, IDNumbers& vector_ids);
size_t GetCurrentMem();
......
......@@ -14,21 +14,29 @@ VectorSource::VectorSource(const size_t &n,
const float *vectors) :
n_(n),
vectors_(vectors),
id_generator_(new SimpleIDGenerator()) {
id_generator_(std::make_shared<SimpleIDGenerator>()) {
current_num_vectors_added = 0;
}
Status VectorSource::Add(const ExecutionEnginePtr &execution_engine,
const meta::TableFileSchema &table_file_schema,
const size_t &num_vectors_to_add,
size_t &num_vectors_added) {
size_t &num_vectors_added,
IDNumbers &vector_ids) {
auto start_time = METRICS_NOW_TIME;
num_vectors_added = current_num_vectors_added + num_vectors_to_add <= n_ ?
num_vectors_to_add : n_ - current_num_vectors_added;
IDNumbers vector_ids_to_add;
if (vector_ids.empty()) {
id_generator_->GetNextIDNumbers(num_vectors_added, vector_ids_to_add);
} else {
vector_ids_to_add.resize(num_vectors_added);
for (int pos = current_num_vectors_added; pos < current_num_vectors_added + num_vectors_added; pos++) {
vector_ids_to_add[pos-current_num_vectors_added] = vector_ids[pos];
}
}
Status status = execution_engine->AddWithIds(num_vectors_added,
vectors_ + current_num_vectors_added * table_file_schema.dimension_,
vector_ids_to_add.data());
......
......@@ -21,7 +21,8 @@ class VectorSource {
Status Add(const ExecutionEnginePtr &execution_engine,
const meta::TableFileSchema &table_file_schema,
const size_t &num_vectors_to_add,
size_t &num_vectors_added);
size_t &num_vectors_added,
IDNumbers &vector_ids);
size_t GetNumVectorsAdded();
......@@ -37,7 +38,7 @@ class VectorSource {
size_t current_num_vectors_added;
IDGenerator *id_generator_;
std::shared_ptr<IDGenerator> id_generator_;
}; //VectorSource
......
......@@ -15,6 +15,8 @@
using namespace milvus;
//#define SET_VECTOR_IDS;
namespace {
std::string GetTableName();
......@@ -211,9 +213,9 @@ ClientTest::Test(const std::string& address, const std::string& port) {
std::cout << "All tables: " << std::endl;
for(auto& table : tables) {
int64_t row_count = 0;
// conn->DropTable(table);
stat = conn->CountTable(table, row_count);
std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl;
conn->DropTable(table);
// stat = conn->CountTable(table, row_count);
// std::cout << "\t" << table << "(" << row_count << " rows)" << std::endl;
}
}
......@@ -235,59 +237,21 @@ ClientTest::Test(const std::string& address, const std::string& port) {
std::cout << "DescribeTable function call status: " << stat.ToString() << std::endl;
PrintTableSchema(tb_schema);
}
//
// Connection::Destroy(conn);
// pid_t pid;
// for (int i = 0; i < 5; ++i) {
// pid = fork();
// if (pid == 0 || pid == -1) {
// break;
// }
// }
// if (pid == -1) {
// std::cout << "fail to fork!\n";
// exit(1);
// } else if (pid == 0) {
// std::shared_ptr<Connection> conn = Connection::Create();
//
// {//connect server
// ConnectParam param = {address, port};
// Status stat = conn->Connect(param);
// std::cout << "Connect function call status: " << stat.ToString() << std::endl;
// }
//
// {//server version
// std::string version = conn->ServerVersion();
// std::cout << "Server version: " << version << std::endl;
// }
// Connection::Destroy(conn);
// exit(0);
// } else {
// std::shared_ptr<Connection> conn = Connection::Create();
//
// {//connect server
// ConnectParam param = {address, port};
// Status stat = conn->Connect(param);
// std::cout << "Connect function call status: " << stat.ToString() << std::endl;
// }
//
// {//server version
// std::string version = conn->ServerVersion();
// std::cout << "Server version: " << version << std::endl;
// }
// Connection::Destroy(conn);
// std::cout << "in main process\n";
// exit(0);
// }
std::vector<std::pair<int64_t, RowRecord>> search_record_array;
{//insert vectors
std::vector<int64_t> record_ids;
for (int i = 0; i < ADD_VECTOR_LOOP; i++) {//add vectors
std::vector<RowRecord> record_array;
int64_t begin_index = i * BATCH_ROW_COUNT;
BuildVectors(begin_index, begin_index + BATCH_ROW_COUNT, record_array);
std::vector<int64_t> record_ids;
#ifdef SET_VECTOR_IDS
record_ids.resize(ADD_VECTOR_LOOP * BATCH_ROW_COUNT);
for (auto j = begin_index; j <begin_index + BATCH_ROW_COUNT; j++) {
record_ids[i * BATCH_ROW_COUNT + j] = i * BATCH_ROW_COUNT + j;
}
#endif
auto start = std::chrono::high_resolution_clock::now();
......
......@@ -187,15 +187,20 @@ ClientProxy::Insert(const std::string &table_name,
}
}
::milvus::grpc::VectorIds vector_ids;
//Single thread
::milvus::grpc::VectorIds vector_ids;
if (!id_array.empty()) {
for (auto i = 0; i < id_array.size(); i++) {
insert_param.add_row_id_array(id_array[i]);
}
client_ptr_->Insert(vector_ids, insert_param, status);
} else {
client_ptr_->Insert(vector_ids, insert_param, status);
auto finish = std::chrono::high_resolution_clock::now();
for (size_t i = 0; i < vector_ids.vector_id_array_size(); i++) {
id_array.push_back(vector_ids.vector_id_array(i));
}
}
#endif
} catch (std::exception &ex) {
......
......@@ -381,9 +381,9 @@ InsertTask::InsertTask(const ::milvus::grpc::InsertParam &insert_param,
}
BaseTaskPtr
InsertTask::Create(const ::milvus::grpc::InsertParam &insert_infos,
InsertTask::Create(const ::milvus::grpc::InsertParam &insert_param,
::milvus::grpc::VectorIds &record_ids) {
return std::shared_ptr<GrpcBaseTask>(new InsertTask(insert_infos, record_ids));
return std::shared_ptr<GrpcBaseTask>(new InsertTask(insert_param, record_ids));
}
ServerError
......@@ -400,6 +400,13 @@ InsertTask::OnExecute() {
return SetError(SERVER_INVALID_ROWRECORD_ARRAY, "Row record array is empty");
}
if (!record_ids_.vector_id_array().empty()) {
if (record_ids_.vector_id_array().size() != insert_param_.row_record_array_size()) {
return SetError(SERVER_ILLEGAL_VECTOR_ID,
"Size of vector ids is not equal to row record array size");
}
}
//step 2: check table existence
engine::meta::TableSchema table_info;
table_info.table_id_ = insert_param_.table_name();
......@@ -446,7 +453,10 @@ InsertTask::OnExecute() {
//step 4: insert vectors
auto vec_count = (uint64_t) insert_param_.row_record_array_size();
std::vector<int64_t> vec_ids(record_ids_.vector_id_array_size(), 0);
std::vector<int64_t> vec_ids(insert_param_.row_id_array_size(), 0);
for (auto i = 0; i < insert_param_.row_id_array_size(); i++) {
vec_ids[i] = insert_param_.row_id_array(i);
}
stat = DBWrapper::DB()->InsertVectors(insert_param_.table_name(), vec_count, vec_f.data(),
vec_ids);
......
......@@ -68,15 +68,15 @@ TEST_F(NewMemManagerTest, VECTOR_SOURCE_TEST) {
engine::ExecutionEnginePtr execution_engine_ = engine::EngineFactory::Build(table_file_schema.dimension_,
table_file_schema.location_,
(engine::EngineType) table_file_schema.engine_type_);
status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added);
engine::IDNumbers vector_ids;
status = source.Add(execution_engine_, table_file_schema, 50, num_vectors_added, vector_ids);
ASSERT_TRUE(status.ok());
ASSERT_EQ(num_vectors_added, 50);
engine::IDNumbers vector_ids = source.GetVectorIds();
vector_ids = source.GetVectorIds();
ASSERT_EQ(vector_ids.size(), 50);
ASSERT_EQ(num_vectors_added, 50);
status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added);
vector_ids.clear();
status = source.Add(execution_engine_, table_file_schema, 60, num_vectors_added, vector_ids);
ASSERT_TRUE(status.ok());
ASSERT_EQ(num_vectors_added, 50);
......@@ -105,12 +105,13 @@ TEST_F(NewMemManagerTest, MEM_TABLE_FILE_TEST) {
engine::VectorSource::Ptr source = std::make_shared<engine::VectorSource>(n_100, vectors_100.data());
status = mem_table_file.Add(source);
engine::IDNumbers vector_ids;
status = mem_table_file.Add(source, vector_ids);
ASSERT_TRUE(status.ok());
// std::cout << mem_table_file.GetCurrentMem() << " " << mem_table_file.GetMemLeft() << std::endl;
engine::IDNumbers vector_ids = source->GetVectorIds();
vector_ids = source->GetVectorIds();
ASSERT_EQ(vector_ids.size(), 100);
size_t singleVectorMem = sizeof(float) * TABLE_DIM;
......@@ -121,7 +122,8 @@ TEST_F(NewMemManagerTest, MEM_TABLE_FILE_TEST) {
BuildVectors(n_max, vectors_128M);
engine::VectorSource::Ptr source_128M = std::make_shared<engine::VectorSource>(n_max, vectors_128M.data());
status = mem_table_file.Add(source_128M);
vector_ids.clear();
status = mem_table_file.Add(source_128M, vector_ids);
vector_ids = source_128M->GetVectorIds();
ASSERT_EQ(vector_ids.size(), n_max - n_100);
......@@ -149,9 +151,10 @@ TEST_F(NewMemManagerTest, MEM_TABLE_TEST) {
engine::MemTable mem_table(TABLE_NAME, impl_, options);
status = mem_table.Add(source_100);
engine::IDNumbers vector_ids;
status = mem_table.Add(source_100, vector_ids);
ASSERT_TRUE(status.ok());
engine::IDNumbers vector_ids = source_100->GetVectorIds();
vector_ids = source_100->GetVectorIds();
ASSERT_EQ(vector_ids.size(), 100);
engine::MemTableFile::Ptr mem_table_file;
......@@ -163,8 +166,9 @@ TEST_F(NewMemManagerTest, MEM_TABLE_TEST) {
std::vector<float> vectors_128M;
BuildVectors(n_max, vectors_128M);
vector_ids.clear();
engine::VectorSource::Ptr source_128M = std::make_shared<engine::VectorSource>(n_max, vectors_128M.data());
status = mem_table.Add(source_128M);
status = mem_table.Add(source_128M, vector_ids);
ASSERT_TRUE(status.ok());
vector_ids = source_128M->GetVectorIds();
......@@ -181,7 +185,8 @@ TEST_F(NewMemManagerTest, MEM_TABLE_TEST) {
engine::VectorSource::Ptr source_1G = std::make_shared<engine::VectorSource>(n_1G, vectors_1G.data());
status = mem_table.Add(source_1G);
vector_ids.clear();
status = mem_table.Add(source_1G, vector_ids);
ASSERT_TRUE(status.ok());
vector_ids = source_1G->GetVectorIds();
......@@ -370,3 +375,74 @@ TEST_F(NewMemManagerTest, CONCURRENT_INSERT_SEARCH_TEST) {
};
TEST_F(DBTest, VECTOR_IDS_TEST)
{
engine::meta::TableSchema table_info = BuildTableSchema();
engine::Status stat = db_->CreateTable(table_info);
engine::meta::TableSchema table_info_get;
table_info_get.table_id_ = TABLE_NAME;
stat = db_->DescribeTable(table_info_get);
ASSERT_STATS(stat);
ASSERT_EQ(table_info_get.dimension_, TABLE_DIM);
engine::IDNumbers vector_ids;
int64_t nb = 100000;
std::vector<float> xb;
BuildVectors(nb, xb);
vector_ids.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i;
}
stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], 0);
ASSERT_STATS(stat);
nb = 25000;
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb;
}
stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], nb);
ASSERT_STATS(stat);
nb = 262144; //512M
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb / 2;
}
stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids);
ASSERT_EQ(vector_ids[0], nb/2);
ASSERT_STATS(stat);
nb = 65536; //128M
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids);
ASSERT_STATS(stat);
nb = 100;
xb.clear();
BuildVectors(nb, xb);
vector_ids.clear();
vector_ids.resize(nb);
for (auto i = 0; i < nb; i++) {
vector_ids[i] = i + nb;
}
stat = db_->InsertVectors(TABLE_NAME, nb, xb.data(), vector_ids);
for (auto i = 0; i < nb; i++) {
ASSERT_EQ(vector_ids[i], i + nb);
}
}
\ No newline at end of file
......@@ -54,4 +54,7 @@ TEST(PrometheusTest, PROMETHEUS_TEST){
instance.ConnectionGaugeDecrement();
instance.KeepingAliveCounterIncrement();
instance.OctetsSet();
instance.CPUCoreUsagePercentSet();
instance.GPUTemperature();
instance.CPUTemperature();
}
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册