提交 9b2d6263 编写于 作者: Z zhiru

resolve merge conflicts


Former-commit-id: 31d8d0d07a1fb4d8299f694db44787ceacf8c2d3
......@@ -20,3 +20,4 @@ Please mark all change in change log and use the ticket from JIRA.
- MS-4 - Refactor the vecwise_engine code structure
- MS-6 - Implement SDK interface part 1
- MS-20 - Clean Code Part 1
- MS-6 - Implement SDK interface part 2
......@@ -101,7 +101,6 @@ link_directories(${CMAKE_CURRRENT_BINARY_DIR})
#execute_process(COMMAND bash build.sh
# WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/third_party)
add_subdirectory(src)
if (BUILD_UNIT_TEST)
......
server_config:
address: 0.0.0.0
port: 33001
transfer_protocol: json #optional: binary, compact, json, debug
transfer_protocol: binary #optional: binary, compact, json
server_mode: thread_pool #optional: simple, thread_pool
gpu_index: 0 #which gpu to be used
......
server_config:
address: 0.0.0.0
port: 33001
transfer_protocol: json #optional: binary, compact, json, debug
transfer_protocol: binary #optional: binary, compact, json
server_mode: thread_pool #optional: simple, thread_pool
gpu_index: 0 #which gpu to be used
......
......@@ -13,13 +13,49 @@
using namespace megasearch;
namespace {
#define BLOCK_SPLITER std::cout << "===========================================" << std::endl;
void PrintTableSchema(const megasearch::TableSchema& tb_schema) {
std::cout << "===========================================" << std::endl;
BLOCK_SPLITER
std::cout << "Table name: " << tb_schema.table_name << std::endl;
std::cout << "Table vectors: " << tb_schema.vector_column_array.size() << std::endl;
std::cout << "Table attributes: " << tb_schema.attribute_column_array.size() << std::endl;
std::cout << "Table partitions: " << tb_schema.partition_column_name_array.size() << std::endl;
std::cout << "===========================================" << std::endl;
BLOCK_SPLITER
}
void PrintRecordIdArray(const std::vector<int64_t>& record_ids) {
BLOCK_SPLITER
std::cout << "Returned id array count: " << record_ids.size() << std::endl;
#if 0
for(auto id : record_ids) {
std::cout << std::to_string(id) << std::endl;
}
#endif
BLOCK_SPLITER
}
void PrintSearchResult(const std::vector<TopKQueryResult>& topk_query_result_array) {
BLOCK_SPLITER
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
int32_t index = 0;
for(auto& result : topk_query_result_array) {
index++;
std::cout << "No." << std::to_string(index) << " vector top "
<< std::to_string(result.query_result_arrays.size())
<< " search result:" << std::endl;
for(auto& item : result.query_result_arrays) {
std::cout << "\t" << std::to_string(item.id) << "\tscore:" << std::to_string(item.score);
for(auto& attribute : item.column_map) {
std::cout << "\t" << attribute.first << ":" << attribute.second;
}
std::cout << std::endl;
}
}
BLOCK_SPLITER
}
std::string CurrentTime() {
......@@ -42,8 +78,29 @@ namespace {
static const std::string TABLE_NAME = GetTableName();
static const std::string VECTOR_COLUMN_NAME = "face_vector";
static const std::string AGE_COLUMN_NAME = "age";
static const std::string CITY_COLUMN_NAME = "city";
static const int64_t TABLE_DIMENSION = 512;
TableSchema BuildTableSchema() {
TableSchema tb_schema;
VectorColumn col1;
col1.name = VECTOR_COLUMN_NAME;
col1.dimension = TABLE_DIMENSION;
col1.store_raw_vector = true;
tb_schema.vector_column_array.emplace_back(col1);
Column col2 = {ColumnType::int8, AGE_COLUMN_NAME};
tb_schema.attribute_column_array.emplace_back(col2);
Column col3 = {ColumnType::int16, CITY_COLUMN_NAME};
tb_schema.attribute_column_array.emplace_back(col3);
tb_schema.table_name = TABLE_NAME;
return tb_schema;
}
void BuildVectors(int64_t from, int64_t to,
std::vector<RowRecord>* vector_record_array,
std::vector<QueryRecord>* query_record_array) {
......@@ -58,6 +115,19 @@ namespace {
query_record_array->clear();
}
static const std::map<int64_t , std::string> CITY_MAP = {
{0, "Beijing"},
{1, "Shanhai"},
{2, "Hangzhou"},
{3, "Guangzhou"},
{4, "Shenzheng"},
{5, "Wuhan"},
{6, "Chengdu"},
{7, "Chongqin"},
{8, "Tianjing"},
{9, "Hongkong"},
};
for (int64_t k = from; k < to; k++) {
std::vector<float> f_p;
......@@ -69,12 +139,16 @@ namespace {
if(vector_record_array) {
RowRecord record;
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
record.attribute_map[AGE_COLUMN_NAME] = std::to_string(k%100);
record.attribute_map[CITY_COLUMN_NAME] = CITY_MAP.at(k%CITY_MAP.size());
vector_record_array->emplace_back(record);
}
if(query_record_array) {
QueryRecord record;
record.vector_map.insert(std::make_pair(VECTOR_COLUMN_NAME, f_p));
record.selected_column_array.push_back(AGE_COLUMN_NAME);
record.selected_column_array.push_back(CITY_COLUMN_NAME);
query_record_array->emplace_back(record);
}
}
......@@ -87,29 +161,35 @@ ClientTest::Test(const std::string& address, const std::string& port) {
ConnectParam param = { address, port };
conn->Connect(param);
{//create table
TableSchema tb_schema;
VectorColumn col1;
col1.name = VECTOR_COLUMN_NAME;
col1.dimension = TABLE_DIMENSION;
col1.store_raw_vector = true;
tb_schema.vector_column_array.emplace_back(col1);
Column col2;
col2.name = "age";
tb_schema.attribute_column_array.emplace_back(col2);
{//get server version
std::string version = conn->ServerVersion();
std::cout << "MegaSearch server version: " << version << std::endl;
}
tb_schema.table_name = TABLE_NAME;
{
std::cout << "ShowTables" << std::endl;
std::vector<std::string> tables;
Status stat = conn->ShowTables(tables);
std::cout << "Function call status: " << stat.ToString() << std::endl;
std::cout << "All tables: " << std::endl;
for(auto& table : tables) {
std::cout << "\t" << table << std::endl;
}
}
{//create table
TableSchema tb_schema = BuildTableSchema();
PrintTableSchema(tb_schema);
std::cout << "CreateTable" << std::endl;
Status stat = conn->CreateTable(tb_schema);
std::cout << "Create table result: " << stat.ToString() << std::endl;
std::cout << "Function call status: " << stat.ToString() << std::endl;
}
{//describe table
TableSchema tb_schema;
std::cout << "DescribeTable" << std::endl;
Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
std::cout << "Describe table result: " << stat.ToString() << std::endl;
std::cout << "Function call status: " << stat.ToString() << std::endl;
PrintTableSchema(tb_schema);
}
......@@ -117,22 +197,23 @@ ClientTest::Test(const std::string& address, const std::string& port) {
std::vector<RowRecord> record_array;
BuildVectors(0, 10000, &record_array, nullptr);
std::vector<int64_t> record_ids;
std::cout << "Begin add vectors" << std::endl;
std::cout << "AddVector" << std::endl;
Status stat = conn->AddVector(TABLE_NAME, record_array, record_ids);
std::cout << "Add vector result: " << stat.ToString() << std::endl;
std::cout << "Returned vector ids: " << record_ids.size() << std::endl;
std::cout << "Function call status: " << stat.ToString() << std::endl;
PrintRecordIdArray(record_ids);
}
{//search vectors
std::cout << "Waiting data persist. Sleep 10 seconds ..." << std::endl;
sleep(10);
std::vector<QueryRecord> record_array;
BuildVectors(500, 510, nullptr, &record_array);
std::vector<TopKQueryResult> topk_query_result_array;
std::cout << "Begin search vectors" << std::endl;
std::cout << "SearchVector" << std::endl;
Status stat = conn->SearchVector(TABLE_NAME, record_array, topk_query_result_array, 10);
std::cout << "Search vector result: " << stat.ToString() << std::endl;
std::cout << "Returned result count: " << topk_query_result_array.size() << std::endl;
std::cout << "Function call status: " << stat.ToString() << std::endl;
PrintSearchResult(topk_query_result_array);
}
// {//delete table
......@@ -140,5 +221,13 @@ ClientTest::Test(const std::string& address, const std::string& port) {
// std::cout << "Delete table result: " << stat.ToString() << std::endl;
// }
{//server status
std::string status = conn->ServerStatus();
std::cout << "Server status before disconnect: " << status << std::endl;
}
Connection::Destroy(conn);
{//server status
std::string status = conn->ServerStatus();
std::cout << "Server status after disconnect: " << status << std::endl;
}
}
\ No newline at end of file
......@@ -21,7 +21,7 @@ ClientProxy::Connect(const ConnectParam &param) {
Disconnect();
int32_t port = atoi(param.port.c_str());
return ClientPtr()->Connect(param.ip_address, port, "json");
return ClientPtr()->Connect(param.ip_address, port, THRIFT_PROTOCOL_BINARY);
}
Status
......@@ -58,7 +58,7 @@ ClientProxy::Disconnect() {
std::string
ClientProxy::ClientVersion() const {
return std::string("Current Version");
return std::string("v1.0");
}
Status
......
......@@ -50,14 +50,12 @@ ThriftClient::Connect(const std::string& address, int32_t port, const std::strin
stdcxx::shared_ptr<TSocket> socket_ptr(new transport::TSocket(address, port));
stdcxx::shared_ptr<TTransport> transport_ptr(new TBufferedTransport(socket_ptr));
stdcxx::shared_ptr<TProtocol> protocol_ptr;
if(protocol == "binary") {
if(protocol == THRIFT_PROTOCOL_BINARY) {
protocol_ptr.reset(new TBinaryProtocol(transport_ptr));
} else if(protocol == "json") {
} else if(protocol == THRIFT_PROTOCOL_JSON) {
protocol_ptr.reset(new TJSONProtocol(transport_ptr));
} else if(protocol == "compact") {
} else if(protocol == THRIFT_PROTOCOL_COMPACT) {
protocol_ptr.reset(new TCompactProtocol(transport_ptr));
} else if(protocol == "debug") {
protocol_ptr.reset(new TDebugProtocol(transport_ptr));
} else {
//CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
return Status(StatusCode::Invalid, "unsupported protocol");
......
......@@ -14,6 +14,10 @@ namespace megasearch {
using MegasearchServiceClientPtr = std::shared_ptr<megasearch::thrift::MegasearchServiceClient>;
static const std::string THRIFT_PROTOCOL_JSON = "json";
static const std::string THRIFT_PROTOCOL_BINARY = "binary";
static const std::string THRIFT_PROTOCOL_COMPACT = "compact";
class ThriftClient {
public:
ThriftClient();
......
......@@ -32,14 +32,14 @@ MegasearchServiceHandler::DeleteTable(const std::string &table_name) {
void
MegasearchServiceHandler::CreateTablePartition(const thrift::CreateTablePartitionParam &param) {
// Your implementation goes here
printf("CreateTablePartition\n");
BaseTaskPtr task_ptr = CreateTablePartitionTask::Create(param);
MegasearchScheduler::ExecTask(task_ptr);
}
void
MegasearchServiceHandler::DeleteTablePartition(const thrift::DeleteTablePartitionParam &param) {
// Your implementation goes here
printf("DeleteTablePartition\n");
BaseTaskPtr task_ptr = DeleteTablePartitionTask::Create(param);
MegasearchScheduler::ExecTask(task_ptr);
}
void
......@@ -67,14 +67,14 @@ MegasearchServiceHandler::DescribeTable(thrift::TableSchema &_return, const std:
void
MegasearchServiceHandler::ShowTables(std::vector<std::string> &_return) {
// Your implementation goes here
printf("ShowTables\n");
BaseTaskPtr task_ptr = ShowTablesTask::Create(_return);
MegasearchScheduler::ExecTask(task_ptr);
}
void
MegasearchServiceHandler::Ping(std::string& _return, const std::string& cmd) {
// Your implementation goes here
printf("Ping\n");
BaseTaskPtr task_ptr = PingTask::Create(cmd, _return);
MegasearchScheduler::ExecTask(task_ptr);
}
}
......
......@@ -54,7 +54,6 @@ MegasearchServer::StartService() {
stdcxx::shared_ptr<TServerTransport> server_transport(new TServerSocket(address, port));
stdcxx::shared_ptr<TTransportFactory> transport_factory(new TBufferedTransportFactory());
std::string protocol = "json";
stdcxx::shared_ptr<TProtocolFactory> protocol_factory;
if (protocol == "binary") {
protocol_factory.reset(new TBinaryProtocolFactory());
......@@ -62,8 +61,6 @@ MegasearchServer::StartService() {
protocol_factory.reset(new TJSONProtocolFactory());
} else if (protocol == "compact") {
protocol_factory.reset(new TCompactProtocolFactory());
} else if (protocol == "debug") {
protocol_factory.reset(new TDebugProtocolFactory());
} else {
//SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
return;
......
......@@ -21,6 +21,7 @@ namespace server {
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
static const std::string PING_TASK_GROUP = "ping";
static const std::string VECTOR_UID = "uid";
static const uint64_t USE_MT = 5000;
......@@ -48,6 +49,10 @@ namespace {
}
}
~DBWrapper() {
delete db_;
}
zilliz::vecwise::engine::DB* DB() { return db_; }
private:
......@@ -78,17 +83,17 @@ BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
ServerError CreateTableTask::OnExecute() {
TimeRecorder rc("CreateTableTask");
try {
if(schema_.vector_column_array.empty()) {
return SERVER_INVALID_ARGUMENT;
}
IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);
engine::meta::TableSchema table_schema;
table_schema.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
table_schema.table_id = schema_.table_name;
engine::Status stat = DB()->CreateTable(table_schema);
engine::meta::TableSchema table_info;
table_info.dimension = (uint16_t)schema_.vector_column_array[0].dimension;
table_info.table_id = schema_.table_name;
engine::Status stat = DB()->CreateTable(table_info);
if(!stat.ok()) {//could exist
error_msg_ = "Engine failed: " + stat.ToString();
SERVER_LOG_ERROR << error_msg_;
......@@ -109,7 +114,7 @@ ServerError CreateTableTask::OnExecute() {
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
: BaseTask(DDL_DML_TASK_GROUP),
: BaseTask(PING_TASK_GROUP),
table_name_(table_name),
schema_(schema) {
schema_.table_name = table_name_;
......@@ -123,9 +128,9 @@ ServerError DescribeTableTask::OnExecute() {
TimeRecorder rc("DescribeTableTask");
try {
engine::meta::TableSchema table_schema;
table_schema.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_schema);
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
......@@ -154,8 +159,8 @@ DeleteTableTask::DeleteTableTask(const std::string& table_name)
}
BaseTaskPtr DeleteTableTask::Create(const std::string& table_id) {
return std::shared_ptr<BaseTask>(new DeleteTableTask(table_id));
BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) {
return std::shared_ptr<BaseTask>(new DeleteTableTask(group_id));
}
ServerError DeleteTableTask::OnExecute() {
......@@ -168,6 +173,60 @@ ServerError DeleteTableTask::OnExecute() {
return SERVER_NOT_IMPLEMENT;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
CreateTablePartitionTask::CreateTablePartitionTask(const thrift::CreateTablePartitionParam &param)
: BaseTask(DDL_DML_TASK_GROUP),
param_(param) {
}
BaseTaskPtr CreateTablePartitionTask::Create(const thrift::CreateTablePartitionParam &param) {
return std::shared_ptr<BaseTask>(new CreateTablePartitionTask(param));
}
ServerError CreateTablePartitionTask::OnExecute() {
error_code_ = SERVER_NOT_IMPLEMENT;
error_msg_ = "create table partition not implemented";
SERVER_LOG_ERROR << error_msg_;
return SERVER_NOT_IMPLEMENT;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
DeleteTablePartitionTask::DeleteTablePartitionTask(const thrift::DeleteTablePartitionParam &param)
: BaseTask(DDL_DML_TASK_GROUP),
param_(param) {
}
BaseTaskPtr DeleteTablePartitionTask::Create(const thrift::DeleteTablePartitionParam &param) {
return std::shared_ptr<BaseTask>(new DeleteTablePartitionTask(param));
}
ServerError DeleteTablePartitionTask::OnExecute() {
error_code_ = SERVER_NOT_IMPLEMENT;
error_msg_ = "delete table partition not implemented";
SERVER_LOG_ERROR << error_msg_;
return SERVER_NOT_IMPLEMENT;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
ShowTablesTask::ShowTablesTask(std::vector<std::string>& tables)
: BaseTask(PING_TASK_GROUP),
tables_(tables) {
}
BaseTaskPtr ShowTablesTask::Create(std::vector<std::string>& tables) {
return std::shared_ptr<BaseTask>(new ShowTablesTask(tables));
}
ServerError ShowTablesTask::OnExecute() {
IVecIdMapper::GetInstance()->AllGroups(tables_);
return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
AddVectorTask::AddVectorTask(const std::string& table_name,
......@@ -195,9 +254,9 @@ ServerError AddVectorTask::OnExecute() {
return SERVER_SUCCESS;
}
engine::meta::TableSchema table_schema;
table_schema.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_schema);
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
......@@ -208,7 +267,7 @@ ServerError AddVectorTask::OnExecute() {
rc.Record("get group info");
uint64_t vec_count = (uint64_t)record_array_.size();
uint64_t group_dim = table_schema.dimension;
uint64_t group_dim = table_info.dimension;
std::vector<float> vec_f;
vec_f.resize(vec_count*group_dim);//allocate enough memory
for(uint64_t i = 0; i < vec_count; i++) {
......@@ -228,6 +287,7 @@ ServerError AddVectorTask::OnExecute() {
return error_code_;
}
//convert double array to float array(thrift has no float type)
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
for(uint64_t d = 0; d < vec_dim; d++) {
vec_f[i*vec_dim + d] = (float)(d_p[d]);
......@@ -245,12 +305,27 @@ ServerError AddVectorTask::OnExecute() {
return error_code_;
}
if(record_ids_.size() < vec_count) {
if(record_ids_.size() != vec_count) {
SERVER_LOG_ERROR << "Vector ID not returned";
return SERVER_UNEXPECTED_ERROR;
}
rc.Record("done");
//persist attributes
for(uint64_t i = 0; i < vec_count; i++) {
const auto &record = record_array_[i];
//any attributes?
if(record.attribute_map.empty()) {
continue;
}
std::string nid = std::to_string(record_ids_[i]);
std::string attrib_str;
AttributeSerializer::Encode(record.attribute_map, attrib_str);
IVecIdMapper::GetInstance()->Put(nid, attrib_str, table_name_);
}
rc.Record("persist vector attributes");
} catch (std::exception& ex) {
error_code_ = SERVER_UNEXPECTED_ERROR;
......@@ -293,9 +368,9 @@ ServerError SearchVectorTask::OnExecute() {
return error_code_;
}
engine::meta::TableSchema table_schema;
table_schema.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_schema);
engine::meta::TableSchema table_info;
table_info.table_id = table_name_;
engine::Status stat = DB()->DescribeTable(table_info);
if(!stat.ok()) {
error_code_ = SERVER_GROUP_NOT_EXIST;
error_msg_ = "Engine failed: " + stat.ToString();
......@@ -305,7 +380,7 @@ ServerError SearchVectorTask::OnExecute() {
std::vector<float> vec_f;
uint64_t record_count = (uint64_t)record_array_.size();
vec_f.resize(record_count*table_schema.dimension);
vec_f.resize(record_count*table_info.dimension);
for(uint64_t i = 0; i < record_array_.size(); i++) {
const auto& record = record_array_[i];
......@@ -317,14 +392,15 @@ ServerError SearchVectorTask::OnExecute() {
}
uint64_t vec_dim = record.vector_map.begin()->second.size() / sizeof(double);//how many double value?
if (vec_dim != table_schema.dimension) {
if (vec_dim != table_info.dimension) {
SERVER_LOG_ERROR << "Invalid vector dimension: " << vec_dim
<< " vs. group dimension:" << table_schema.dimension;
<< " vs. group dimension:" << table_info.dimension;
error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
error_msg_ = "Engine failed: " + stat.ToString();
return error_code_;
}
//convert double array to float array(thrift has no float type)
const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
for(uint64_t d = 0; d < vec_dim; d++) {
vec_f[i*vec_dim + d] = (float)(d_p[d]);
......@@ -336,25 +412,50 @@ ServerError SearchVectorTask::OnExecute() {
std::vector<DB_DATE> dates;
engine::QueryResults results;
stat = DB()->Query(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results);
rc.Record("search vectors from engine");
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
return SERVER_UNEXPECTED_ERROR;
} else {
rc.Record("do searching");
for(engine::QueryResult& result : results){
thrift::TopKQueryResult thrift_topk_result;
for(auto id : result) {
thrift::QueryResult thrift_result;
thrift_result.__set_id(id);
}
if(results.size() != record_count) {
SERVER_LOG_ERROR << "Search result not returned";
return SERVER_UNEXPECTED_ERROR;
}
//construct result array
for(uint64_t i = 0; i < record_count; i++) {
auto& result = results[i];
const auto& record = record_array_[i];
thrift::TopKQueryResult thrift_topk_result;
for(auto id : result) {
thrift::QueryResult thrift_result;
thrift_result.__set_id(id);
//need get attributes?
if(record.selected_column_array.empty()) {
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
continue;
}
result_array_.emplace_back(thrift_topk_result);
std::string nid = std::to_string(id);
std::string attrib_str;
IVecIdMapper::GetInstance()->Get(nid, attrib_str, table_name_);
AttribMap attrib_map;
AttributeSerializer::Decode(attrib_str, attrib_map);
for(auto& attribute : record.selected_column_array) {
thrift_result.column_map[attribute] = attrib_map[attribute];
}
thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
}
rc.Record("construct result");
}
rc.Record("done");
result_array_.emplace_back(thrift_topk_result);
}
rc.Record("construct result");
} catch (std::exception& ex) {
error_code_ = SERVER_UNEXPECTED_ERROR;
......@@ -366,6 +467,26 @@ ServerError SearchVectorTask::OnExecute() {
return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
PingTask::PingTask(const std::string& cmd, std::string& result)
: BaseTask(PING_TASK_GROUP),
cmd_(cmd),
result_(result) {
}
BaseTaskPtr PingTask::Create(const std::string& cmd, std::string& result) {
return std::shared_ptr<BaseTask>(new PingTask(cmd, result));
}
ServerError PingTask::OnExecute() {
if(cmd_ == "version") {
result_ = "v1.2.0";//currently hardcode
}
return SERVER_SUCCESS;
}
}
}
}
......@@ -65,6 +65,50 @@ private:
std::string table_name_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class CreateTablePartitionTask : public BaseTask {
public:
static BaseTaskPtr Create(const thrift::CreateTablePartitionParam &param);
protected:
CreateTablePartitionTask(const thrift::CreateTablePartitionParam &param);
ServerError OnExecute() override;
private:
const thrift::CreateTablePartitionParam &param_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class DeleteTablePartitionTask : public BaseTask {
public:
static BaseTaskPtr Create(const thrift::DeleteTablePartitionParam &param);
protected:
DeleteTablePartitionTask(const thrift::DeleteTablePartitionParam &param);
ServerError OnExecute() override;
private:
const thrift::DeleteTablePartitionParam &param_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class ShowTablesTask : public BaseTask {
public:
static BaseTaskPtr Create(std::vector<std::string>& tables);
protected:
ShowTablesTask(std::vector<std::string>& tables);
ServerError OnExecute() override;
private:
std::vector<std::string>& tables_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class AddVectorTask : public BaseTask {
public:
......@@ -108,6 +152,21 @@ private:
std::vector<thrift::TopKQueryResult>& result_array_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
class PingTask : public BaseTask {
public:
static BaseTaskPtr Create(const std::string& cmd, std::string& result);
protected:
PingTask(const std::string& cmd, std::string& result);
ServerError OnExecute() override;
private:
std::string cmd_;
std::string& result_;
};
}
}
}
\ No newline at end of file
......@@ -108,6 +108,19 @@ bool RocksIdMapper::IsGroupExist(const std::string& group) const {
return IsGroupExistInternal(group);
}
ServerError RocksIdMapper::AllGroups(std::vector<std::string>& groups) const {
groups.clear();
std::lock_guard<std::mutex> lck(db_mutex_);
for(auto& pair : column_handles_) {
if(pair.first == ROCKSDB_DEFAULT_GROUP) {
continue;
}
groups.push_back(pair.first);
}
return SERVER_SUCCESS;
}
ServerError RocksIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) {
std::lock_guard<std::mutex> lck(db_mutex_);
......
......@@ -26,6 +26,7 @@ class RocksIdMapper : public IVecIdMapper{
ServerError AddGroup(const std::string& group) override;
bool IsGroupExist(const std::string& group) const override;
ServerError AllGroups(std::vector<std::string>& groups) const override;
ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override;
ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") override;
......
......@@ -52,6 +52,15 @@ SimpleIdMapper::IsGroupExist(const std::string& group) const {
return id_groups_.count(group) > 0;
}
ServerError SimpleIdMapper::AllGroups(std::vector<std::string>& groups) const {
groups.clear();
for(auto& pair : id_groups_) {
groups.push_back(pair.first);
}
return SERVER_SUCCESS;
}
//not thread-safe
ServerError SimpleIdMapper::Put(const std::string& nid, const std::string& sid, const std::string& group) {
......
......@@ -27,6 +27,7 @@ public:
virtual ServerError AddGroup(const std::string& group) = 0;
virtual bool IsGroupExist(const std::string& group) const = 0;
virtual ServerError AllGroups(std::vector<std::string>& groups) const = 0;
virtual ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") = 0;
virtual ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") = 0;
......@@ -46,6 +47,7 @@ public:
ServerError AddGroup(const std::string& group) override;
bool IsGroupExist(const std::string& group) const override;
ServerError AllGroups(std::vector<std::string>& groups) const override;
ServerError Put(const std::string& nid, const std::string& sid, const std::string& group = "") override;
ServerError Put(const std::vector<std::string>& nid, const std::vector<std::string>& sid, const std::string& group = "") override;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册