提交 4b244f9d 编写于 作者: G groot

redefine thrift api


Former-commit-id: 8d7d1c47202341521cf06a567e8f7cf3cfc55074
上级 ec579df5
......@@ -99,7 +99,6 @@ link_directories(${VECWISE_THIRD_PARTY_BUILD}/lib64)
add_subdirectory(src)
add_subdirectory(test_client)
if (BUILD_UNIT_TEST)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unittest)
......
......@@ -22,7 +22,7 @@ set(license_generator_src
)
set(service_files
thrift/gen-cpp/VecService.cpp
thrift/gen-cpp/MegasearchService.cpp
thrift/gen-cpp/megasearch_constants.cpp
thrift/gen-cpp/megasearch_types.cpp
)
......@@ -39,6 +39,7 @@ set(get_sys_info_src
include_directories(/usr/include)
include_directories(/usr/local/cuda/include)
include_directories(thrift/gen-cpp)
if (GPU_VERSION STREQUAL "ON")
link_directories(/usr/local/cuda/lib64)
......@@ -126,4 +127,6 @@ if (ENABLE_LICENSE STREQUAL "ON")
install(TARGETS get_sys_info DESTINATION bin)
endif ()
install(TARGETS vecwise_server DESTINATION bin)
\ No newline at end of file
install(TARGETS vecwise_server DESTINATION bin)
add_subdirectory(sdk)
\ No newline at end of file
#-------------------------------------------------------------------------------
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
# Unauthorized copying of this file, via any medium is strictly prohibited.
# Proprietary and confidential.
#-------------------------------------------------------------------------------
# Build script for the megasearch client SDK static library.

# Collect every source file found in the SDK sub-directories.
aux_source_directory(src/interface interface_files)
aux_source_directory(src/client client_files)
aux_source_directory(src/util util_files)
# Header search paths: SDK sources, the public SDK headers and the
# thrift-generated service code that lives in the top-level source tree.
include_directories(src)
include_directories(include)
include_directories(/usr/include)
include_directories(${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp)
# Thrift-generated sources compiled directly into the SDK.
set(service_files
${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/MegasearchService.cpp
${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/megasearch_constants.cpp
${CMAKE_SOURCE_DIR}/src/thrift/gen-cpp/megasearch_types.cpp
)
add_library(megasearch_sdk STATIC
${interface_files}
${client_files}
${util_files}
${service_files}
)
# NOTE(review): link_directories() only affects targets created after the call,
# so it does not apply to megasearch_sdk above (it does reach the examples
# sub-directory added below). The relative path is also worth confirming.
link_directories(../../third_party/build/lib)
target_link_libraries(megasearch_sdk
libthrift.a
pthread
)
# Example programs that consume the SDK.
add_subdirectory(examples)
#include "MegaSearch.h"
namespace megasearch {
/// Placeholder factory: nothing is constructed, an empty pointer is returned.
std::shared_ptr<Connection>
Create() {
    std::shared_ptr<Connection> empty_connection;
    return empty_connection;
}
/// Placeholder teardown: the connection is left untouched and success is reported.
Status
Destroy(std::shared_ptr<Connection> &connection_ptr) {
    (void) connection_ptr;  // intentionally unused by this stub
    Status result = Status::OK();
    return result;
}
/**
Status
Connection::Connect(const ConnectParam &param) {
return Status::NotSupported("Connect interface is not supported.");
}
Status
Connection::Connect(const std::string &uri) {
return Status::NotSupported("Connect interface is not supported.");
}
Status
Connection::Connected() const {
return Status::NotSupported("Connected interface is not supported.");
}
Status
Connection::Disconnect() {
return Status::NotSupported("Disconnect interface is not supported.");
}
std::string
Connection::ClientVersion() const {
return std::string("Current Version");
}
Status
Connection::CreateTable(const TableSchema &param) {
return Status::NotSupported("Create table interface is not supported.");
}
Status
Connection::CreateTablePartition(const CreateTablePartitionParam &param) {
return Status::NotSupported("Create table partition interface is not supported.");
}
Status
Connection::DeleteTablePartition(const DeleteTablePartitionParam &param) {
return Status::NotSupported("Delete table partition interface is not supported.");
}
Status
Connection::DeleteTable(const std::string &table_name) {
return Status::NotSupported("Delete table interface is not supported.");
}
Status
Connection::AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) {
return Status::NotSupported("Add vector array interface is not supported.");
}
Status
Connection::SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) {
return Status::NotSupported("Query vector array interface is not supported.");
}
Status
Connection::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
return Status::NotSupported("Show table interface is not supported.");
}
Status
Connection::ShowTables(std::vector<std::string> &table_array) {
return Status::NotSupported("List table array interface is not supported.");
}
std::string
Connection::ServerVersion() const {
return std::string("Server version.");
}
std::string
Connection::ServerStatus() const {
return std::string("Server status");
}
**/
}
\ No newline at end of file
#-------------------------------------------------------------------------------
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
# Unauthorized copying of this file, via any medium is strictly prohibited.
# Proprietary and confidential.
#-------------------------------------------------------------------------------
# Build the "simple" SDK example program.
add_subdirectory(simple)
#-------------------------------------------------------------------------------
# Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
# Unauthorized copying of this file, via any medium is strictly prohibited.
# Proprietary and confidential.
#-------------------------------------------------------------------------------
# Build script for the sdk_simple example executable.

# Collect the example's own sources from ./src.
aux_source_directory(src src_files)
include_directories(src)
# NOTE(review): relative path to the SDK public headers — confirm it resolves
# from this directory.
include_directories(../../megasearch_sdk/include)
include_directories(/usr/include)
link_directories(${CMAKE_BINARY_DIR}/megasearch_sdk)
# NOTE(review): service_files is not set in this file; presumably it is
# inherited from a parent directory scope — verify. The SDK library linked
# below is built from the same list in the parent script.
add_executable(sdk_simple
./main.cpp
${src_files}
${service_files}
)
target_link_libraries(sdk_simple
megasearch_sdk
pthread
)
......@@ -8,16 +8,8 @@
#include <libgen.h>
#include <cstring>
#include <string>
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <easylogging++.h>
#include "src/FaissTest.h"
#include "src/Log.h"
#include "src/ClientTest.h"
#include "server/ServerConfig.h"
INITIALIZE_EASYLOGGINGPP
void print_help(const std::string &app_name);
......@@ -26,58 +18,47 @@ int
main(int argc, char *argv[]) {
printf("Client start...\n");
// FaissTest::test();
// return 0;
std::string app_name = basename(argv[0]);
static struct option long_options[] = {{"conf_file", optional_argument, 0, 'c'},
{"help", no_argument, 0, 'h'},
{NULL, 0, 0, 0}};
int option_index = 0;
std::string config_filename = "../../conf/server_config.yaml";
std::string address = "127.0.0.1", port = "33001";
app_name = argv[0];
int value;
while ((value = getopt_long(argc, argv, "c:p:dh", long_options, &option_index)) != -1) {
switch (value) {
case 'c': {
char *config_filename_ptr = strdup(optarg);
config_filename = config_filename_ptr;
free(config_filename_ptr);
case 'h': {
char *address_ptr = strdup(optarg);
address = address_ptr;
free(address_ptr);
break;
}
case 'p': {
    // Fix: the value given with -p is the server port. It used to be written
    // into `address`, which clobbered the host and left `port` at its default.
    // std::string copies the text, so no strdup()/free() round trip is needed.
    port = optarg;
    break;
}
case 'h':
print_help(app_name);
return EXIT_SUCCESS;
case '?':
print_help(app_name);
return EXIT_FAILURE;
default:
print_help(app_name);
break;
}
}
zilliz::vecwise::server::ServerConfig& config = zilliz::vecwise::server::ServerConfig::GetInstance();
config.LoadConfigFile(config_filename);
CLIENT_LOG_INFO << "Load config file:" << config_filename;
ClientTest test;
test.Test(address, port);
#if 1
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
#else
zilliz::vecwise::client::ClientTest::LoopTest();
printf("Client stop...\n");
return 0;
#endif
}
// Print the command-line usage text to stdout.
//   app_name: program name substituted into the "Usage:" line.
void
print_help(const std::string &app_name) {
printf("\n Usage: %s [OPTIONS]\n\n", app_name.c_str());
printf(" Options:\n");
// NOTE(review): "-h" is listed twice below with different meanings (help and
// server address); confirm which option letters the parser really accepts.
printf(" -h --help Print this help\n");
printf(" -c --conf_file filename Read configuration from the file\n");
printf(" -h Megasearch server address\n");
printf(" -p Megasearch server port\n");
printf("\n");
}
\ No newline at end of file
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ClientTest.h"
#include "MegaSearch.h"
#include <iostream>
namespace {

// Name and vector dimension of the table used by the example.
const std::string TABLE_NAME = "human face";
const int64_t TABLE_DIMENSION = 512;

// Print the table name and the number of vector columns, attribute columns
// and partition columns of a schema, framed by separator lines.
void PrintTableSchema(const megasearch::TableSchema& tb_schema) {
    const char* const separator = "===========================================";
    const auto vector_count = tb_schema.vector_column_array.size();
    const auto attribute_count = tb_schema.attribute_column_array.size();
    const auto partition_count = tb_schema.partition_column_name_array.size();

    std::cout << separator << std::endl;
    std::cout << "Table name: " << tb_schema.table_name << std::endl;
    std::cout << "Table vectors: " << vector_count << std::endl;
    std::cout << "Table attributes: " << attribute_count << std::endl;
    std::cout << "Table partitions: " << partition_count << std::endl;
    std::cout << separator << std::endl;
}

}
// Walk through the basic SDK calls against the server at address:port:
// connect, create a table, describe it, delete it, then disconnect.
// The status of every step is printed; no step aborts the sequence.
void
ClientTest::Test(const std::string& address, const std::string& port) {
    std::shared_ptr<megasearch::Connection> conn = megasearch::Connection::Create();
    megasearch::ConnectParam param = { address, port };

    {//connect
        // The connect status used to be discarded; report it like every other step.
        megasearch::Status stat = conn->Connect(param);
        std::cout << "Connect result: " << stat.ToString() << std::endl;
    }

    {//create table
        megasearch::TableSchema tb_schema;

        megasearch::VectorColumn col1;
        col1.name = "face";
        col1.dimension = TABLE_DIMENSION;
        col1.store_raw_vector = true;
        tb_schema.vector_column_array.emplace_back(col1);

        megasearch::Column col2;
        col2.name = "age";
        tb_schema.attribute_column_array.emplace_back(col2);

        tb_schema.table_name = TABLE_NAME;
        PrintTableSchema(tb_schema);

        megasearch::Status stat = conn->CreateTable(tb_schema);
        std::cout << "Create table result: " << stat.ToString() << std::endl;
    }

    {//describe table
        megasearch::TableSchema tb_schema;
        megasearch::Status stat = conn->DescribeTable(TABLE_NAME, tb_schema);
        std::cout << "Describe table result: " << stat.ToString() << std::endl;
        PrintTableSchema(tb_schema);
    }

    {//add vectors
    }

    {//search vectors
    }

    {//delete table
        megasearch::Status stat = conn->DeleteTable(TABLE_NAME);
        std::cout << "Delete table result: " << stat.ToString() << std::endl;
    }

    {//disconnect
        megasearch::Status stat = megasearch::Connection::Destroy(conn);
        std::cout << "Disconnect result: " << stat.ToString() << std::endl;
    }
}
\ No newline at end of file
......@@ -5,16 +5,9 @@
******************************************************************************/
#pragma once
namespace zilliz {
namespace vecwise {
namespace client {
#include <string>
// Driver class for the SDK example client.
class ClientTest {
public:
// NOTE(review): no definition of LoopTest() is visible next to this header —
// confirm whether this declaration is still needed.
static void LoopTest();
// Runs the example call sequence against the server at address:port.
void Test(const std::string& address, const std::string& port);
};
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ClientProxy.h"
#include "util/ConvertUtil.h"
namespace megasearch {
/**
 * Access the thrift client, creating it on first use.
 * The member is mutable, so this also works from const methods.
 */
std::shared_ptr<ThriftClient>&
ClientProxy::ClientPtr() const {
    if(!client_ptr) {
        client_ptr = std::make_shared<ThriftClient>();
    }

    return client_ptr;
}
/**
 * Connect to the server described by param using the "json" protocol.
 * Any current connection is dropped first; the result of that disconnect is ignored.
 * The port string is parsed with atoi().
 */
Status
ClientProxy::Connect(const ConnectParam &param) {
    Disconnect();

    const int32_t port_number = atoi(param.port.c_str());
    Status connect_status = ClientPtr()->Connect(param.ip_address, port_number, "json");
    return connect_status;
}
/**
 * URI-based connect is not implemented.
 * Note that the current connection is still dropped before the
 * NotSupported status is returned.
 */
Status
ClientProxy::Connect(const std::string &uri) {
    (void) uri;  // not consulted
    Disconnect();

    Status unsupported = Status::NotSupported("Connect interface is not supported.");
    return unsupported;
}
/**
 * Check the connection by pinging the server.
 * Returns OK when the ping completes, UnknownError when no client exists
 * or the ping throws.
 */
Status
ClientProxy::Connected() const {
    if(!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        std::string reply;
        ClientPtr()->interface()->Ping(reply, "");
        return Status::OK();
    } catch (const std::exception& ex) {
        const std::string reason(ex.what());
        return Status(StatusCode::UnknownError, "connection lost: " + reason);
    }
}
/**
 * Close the connection held by the thrift client.
 * Returns UnknownError ("not connected") when no client has been created yet.
 */
Status
ClientProxy::Disconnect() {
    if(client_ptr != nullptr) {
        return ClientPtr()->Disconnect();
    }

    return Status(StatusCode::UnknownError, "not connected");
}
/** Return the client version text (currently a fixed placeholder string). */
std::string
ClientProxy::ClientVersion() const {
    std::string version_text("Current Version");
    return version_text;
}
/**
 * Create a table on the server from an SDK table schema.
 *
 * The schema is converted field by field into its thrift counterpart and sent
 * through the thrift client.
 *
 * @param param  table name, vector columns, attribute columns and partition
 *               column names of the table to create
 * @return OK on success; UnknownError when no client exists or the remote
 *         call throws
 */
Status
ClientProxy::CreateTable(const TableSchema &param) {
    if(client_ptr == nullptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        thrift::TableSchema schema;
        schema.__set_table_name(param.table_name);

        std::vector<thrift::VectorColumn> vector_column_array;
        for(const auto& column : param.vector_column_array) {
            thrift::VectorColumn col;

            // Fix: forward the column's name/type as well. They live in the
            // thrift `base` member (the same member DescribeTable reads back)
            // and were previously never filled in.
            decltype(col.base) base;
            base.__set_name(column.name);
            base.__set_type(static_cast<decltype(base.type)>(column.type));
            col.__set_base(base);

            col.__set_dimension(column.dimension);
            col.__set_index_type(ConvertUtil::IndexType2Str(column.index_type));
            col.__set_store_raw_vector(column.store_raw_vector);
            vector_column_array.emplace_back(col);
        }
        schema.__set_vector_column_array(vector_column_array);

        std::vector<thrift::Column> attribute_column_array;
        for(const auto& column : param.attribute_column_array) {
            thrift::Column col;
            // Fix: copy from the source `column`. The old code passed the
            // fresh thrift object's own (default) name/type back to itself.
            col.__set_name(column.name);
            col.__set_type(static_cast<decltype(col.type)>(column.type));
            attribute_column_array.emplace_back(col);
        }
        schema.__set_attribute_column_array(attribute_column_array);

        schema.__set_partition_column_name_array(param.partition_column_name_array);

        ClientPtr()->interface()->CreateTable(schema);

    } catch ( std::exception& ex) {
        return Status(StatusCode::UnknownError, "failed to create table: " + std::string(ex.what()));
    }

    return Status::OK();
}
/**
 * Create a table partition on the server.
 * Table name, partition name and each named range (start/end value) are
 * copied into the thrift request.
 * Returns OK on success, UnknownError when no client exists or the call throws.
 */
Status
ClientProxy::CreateTablePartition(const CreateTablePartitionParam &param) {
    if(!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        thrift::CreateTablePartitionParam request;
        request.__set_table_name(param.table_name);
        request.__set_partition_name(param.partition_name);

        std::map<std::string, thrift::Range> thrift_ranges;
        for(const auto& entry : param.range_map) {
            thrift::Range thrift_range;
            thrift_range.__set_start_value(entry.second.start_value);
            thrift_range.__set_end_value(entry.second.end_value);
            thrift_ranges.emplace(entry.first, thrift_range);
        }
        request.__set_range_map(thrift_ranges);

        ClientPtr()->interface()->CreateTablePartition(request);
        return Status::OK();
    } catch (const std::exception& ex) {
        const std::string reason(ex.what());
        return Status(StatusCode::UnknownError, "failed to create table partition: " + reason);
    }
}
/**
 * Delete the named partitions of a table on the server.
 * Returns OK on success, UnknownError when no client exists or the call throws.
 */
Status
ClientProxy::DeleteTablePartition(const DeleteTablePartitionParam &param) {
    if(!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        thrift::DeleteTablePartitionParam request;
        request.__set_table_name(param.table_name);
        request.__set_partition_name_array(param.partition_name_array);

        ClientPtr()->interface()->DeleteTablePartition(request);
        return Status::OK();
    } catch (const std::exception& ex) {
        const std::string reason(ex.what());
        return Status(StatusCode::UnknownError, "failed to delete table partition: " + reason);
    }
}
/**
 * Delete a table on the server by name.
 * Returns OK on success, UnknownError when no client exists or the call throws.
 */
Status
ClientProxy::DeleteTable(const std::string &table_name) {
    if(!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        ClientPtr()->interface()->DeleteTable(table_name);
        return Status::OK();
    } catch (const std::exception& ex) {
        const std::string reason(ex.what());
        return Status(StatusCode::UnknownError, "failed to delete table: " + reason);
    }
}
/**
 * Send a batch of row records to the server.
 *
 * Each record's attribute map is passed through as is. Every vector in the
 * record is packed into a std::string holding one double per element,
 * written in its in-memory representation.
 *
 * @param table_name    target table
 * @param record_array  records to add
 * @param id_array      handed to the remote call as its result container
 * @return OK on success; UnknownError when no client exists or the call throws
 */
Status
ClientProxy::AddVector(const std::string &table_name,
                       const std::vector<RowRecord> &record_array,
                       std::vector<int64_t> &id_array) {
    if(!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        std::vector<thrift::RowRecord> rows_to_send;
        for(const auto& source_row : record_array) {
            thrift::RowRecord row;
            row.__set_attribute_map(source_row.attribute_map);

            for(const auto& named_vector : source_row.vector_map) {
                const auto& values = named_vector.second;
                const size_t value_count = values.size();

                // Size the byte buffer first, then fill it element by element.
                std::string& packed = row.vector_map[named_vector.first];
                packed.resize(value_count * sizeof(double));
                double *slots = reinterpret_cast<double *>(const_cast<char *>(packed.data()));
                for (size_t idx = 0; idx < value_count; ++idx) {
                    slots[idx] = static_cast<double>(values[idx]);
                }
            }

            rows_to_send.emplace_back(row);
        }

        ClientPtr()->interface()->AddVector(id_array, table_name, rows_to_send);
        return Status::OK();
    } catch (const std::exception& ex) {
        const std::string reason(ex.what());
        return Status(StatusCode::UnknownError, "failed to add vector: " + reason);
    }
}
/**
 * Run a top-k vector search on the server.
 *
 * Each query record's selected columns are passed through; every query vector
 * is packed into a std::string holding one double per element. The thrift
 * results are converted back into SDK result objects.
 *
 * @param table_name               table to search
 * @param query_record_array       query records
 * @param topk_query_result_array  results are appended here (not cleared first)
 * @param topk                     number of results requested per query
 * @return OK on success; UnknownError when no client exists or the call throws
 */
Status
ClientProxy::SearchVector(const std::string &table_name,
                          const std::vector<QueryRecord> &query_record_array,
                          std::vector<TopKQueryResult> &topk_query_result_array,
                          int64_t topk) {
    if(client_ptr == nullptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        std::vector<thrift::QueryRecord> thrift_records;
        for(const auto& record : query_record_array) {
            thrift::QueryRecord thrift_record;
            thrift_record.__set_selected_column_array(record.selected_column_array);

            for(const auto& pair : record.vector_map) {
                const size_t dim = pair.second.size();
                std::string& thrift_vector = thrift_record.vector_map[pair.first];
                thrift_vector.resize(dim * sizeof(double));
                double *dbl = reinterpret_cast<double *>(const_cast<char *>(thrift_vector.data()));
                for (size_t i = 0; i < dim; i++) {
                    dbl[i] = static_cast<double>(pair.second[i]);
                }
            }

            thrift_records.emplace_back(thrift_record);
        }

        std::vector<thrift::TopKQueryResult> result_array;
        ClientPtr()->interface()->SearchVector(result_array, table_name, thrift_records, topk);

        for(const auto& thrift_topk_result : result_array) {
            TopKQueryResult result;

            for(const auto& thrift_query_result : thrift_topk_result.query_result_arrays) {
                QueryResult query_result;
                query_result.id = thrift_query_result.id;
                query_result.column_map = thrift_query_result.column_map;
                query_result.score = thrift_query_result.score;
                result.query_result_arrays.emplace_back(query_result);
            }

            topk_query_result_array.emplace_back(result);
        }

    } catch ( std::exception& ex) {
        // Fix: the message used to read "failed to create table partition"
        // (copied from another method), which misreported search failures.
        return Status(StatusCode::UnknownError, "failed to search vector: " + std::string(ex.what()));
    }

    return Status::OK();
}
/**
 * Fetch the schema of a table from the server.
 *
 * @param table_name    table to describe
 * @param table_schema  receives the table name and partition column names;
 *                      attribute and vector columns are appended to its arrays
 * @return OK on success; UnknownError when no client exists or the call throws
 */
Status
ClientProxy::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
    if(client_ptr == nullptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        thrift::TableSchema thrift_schema;
        ClientPtr()->interface()->DescribeTable(thrift_schema, table_name);

        table_schema.table_name = thrift_schema.table_name;
        table_schema.partition_column_name_array = thrift_schema.partition_column_name_array;

        for(const auto& thrift_col : thrift_schema.attribute_column_array) {
            Column col;
            // Fix: read from the thrift column. The old code assigned
            // col.name/col.type to themselves, so every attribute column
            // came back with default values.
            col.name = thrift_col.name;
            col.type = static_cast<ColumnType>(thrift_col.type);
            table_schema.attribute_column_array.emplace_back(col);
        }

        for(const auto& thrift_col : thrift_schema.vector_column_array) {
            VectorColumn col;
            col.store_raw_vector = thrift_col.store_raw_vector;
            col.index_type = ConvertUtil::Str2IndexType(thrift_col.index_type);
            col.dimension = thrift_col.dimension;
            col.name = thrift_col.base.name;
            col.type = static_cast<ColumnType>(thrift_col.base.type);
            table_schema.vector_column_array.emplace_back(col);
        }

    } catch ( std::exception& ex) {
        return Status(StatusCode::UnknownError, "failed to describe table: " + std::string(ex.what()));
    }

    return Status::OK();
}
/**
 * Ask the server for its tables; table_array is handed to the remote call
 * as its result container.
 * Returns OK on success, UnknownError when no client exists or the call throws.
 */
Status
ClientProxy::ShowTables(std::vector<std::string> &table_array) {
    if(!client_ptr) {
        return Status(StatusCode::UnknownError, "not connected");
    }

    try {
        ClientPtr()->interface()->ShowTables(table_array);
        return Status::OK();
    } catch (const std::exception& ex) {
        const std::string reason(ex.what());
        return Status(StatusCode::UnknownError, "failed to show tables: " + reason);
    }
}
/**
 * Query the server version by sending a "version" ping.
 * Returns an empty string when no client exists or the ping throws.
 */
std::string
ClientProxy::ServerVersion() const {
    if(!client_ptr) {
        return std::string();
    }

    std::string reported_version;
    try {
        ClientPtr()->interface()->Ping(reported_version, "version");
    } catch (const std::exception&) {
        return std::string();
    }

    return reported_version;
}
/**
 * Describe the connection state as text:
 * "not connected" without a client, "server alive" when an empty ping
 * completes, "connection lost" when the ping throws.
 */
std::string
ClientProxy::ServerStatus() const {
    if(!client_ptr) {
        return "not connected";
    }

    const char* state = "server alive";
    try {
        std::string ignored_reply;
        ClientPtr()->interface()->Ping(ignored_reply, "");
    } catch (const std::exception&) {
        state = "connection lost";
    }

    return state;
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "MegaSearch.h"
#include "ThriftClient.h"
namespace megasearch {
class ClientProxy : public Connection {
public:
// Implementations of the Connection interface
virtual Status Connect(const ConnectParam &param) override;
virtual Status Connect(const std::string &uri) override;
virtual Status Connected() const override;
virtual Status Disconnect() override;
virtual Status CreateTable(const TableSchema &param) override;
virtual Status DeleteTable(const std::string &table_name) override;
virtual Status CreateTablePartition(const CreateTablePartitionParam &param) override;
virtual Status DeleteTablePartition(const DeleteTablePartitionParam &param) override;
virtual Status AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual Status SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) override;
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status ShowTables(std::vector<std::string> &table_array) override;
virtual std::string ClientVersion() const override;
virtual std::string ServerVersion() const override;
virtual std::string ServerStatus() const override;
private:
std::shared_ptr<ThriftClient>& ClientPtr() const;
private:
mutable std::shared_ptr<ThriftClient> client_ptr;
};
}
......@@ -3,11 +3,10 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ClientSession.h"
#include "Log.h"
#include "ThriftClient.h"
#include "thrift/gen-cpp/megasearch_types.h"
#include "thrift/gen-cpp/megasearch_constants.h"
#include "megasearch_types.h"
#include "megasearch_constants.h"
#include <exception>
......@@ -22,19 +21,31 @@
#include <thrift/transport/TBufferTransports.h>
#include <thrift/concurrency/PosixThreadFactory.h>
namespace zilliz {
namespace vecwise {
namespace client {
using namespace megasearch;
namespace megasearch {
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
using namespace ::apache::thrift::concurrency;
ClientSession::ClientSession(const std::string &address, int32_t port, const std::string &protocol)
: client_(nullptr) {
// Nothing to set up: the client pointer stays empty until Connect() succeeds.
ThriftClient::ThriftClient() = default;
// No explicit teardown is performed here.
ThriftClient::~ThriftClient() = default;
// Return the generated service client.
// Throws std::exception when no client has been created.
MegasearchServiceClientPtr
ThriftClient::interface() {
    if(client_ != nullptr) {
        return client_;
    }

    throw std::exception();
}
Status
ThriftClient::Connect(const std::string& address, int32_t port, const std::string& protocol) {
try {
stdcxx::shared_ptr<TSocket> socket_ptr(new transport::TSocket(address, port));
stdcxx::shared_ptr<TTransport> transport_ptr(new TBufferedTransport(socket_ptr));
......@@ -48,19 +59,21 @@ ClientSession::ClientSession(const std::string &address, int32_t port, const std
} else if(protocol == "debug") {
protocol_ptr.reset(new TDebugProtocol(transport_ptr));
} else {
CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
return;
//CLIENT_LOG_ERROR << "Service protocol: " << protocol << " is not supported currently";
return Status(StatusCode::Invalid, "unsupported protocol");
}
transport_ptr->open();
client_ = std::make_shared<VecServiceClient>(protocol_ptr);
client_ = std::make_shared<thrift::MegasearchServiceClient>(protocol_ptr);
} catch ( std::exception& ex) {
CLIENT_LOG_ERROR << "connect encounter exception: " << ex.what();
//CLIENT_LOG_ERROR << "connect encounter exception: " << ex.what();
return Status(StatusCode::UnknownError, "failed to connect megasearch server" + std::string(ex.what()));
}
return Status::OK();
}
ClientSession::~ClientSession() {
Status
ThriftClient::Disconnect() {
try {
if(client_ != nullptr) {
auto protocol = client_->getInputProtocol();
......@@ -72,17 +85,20 @@ ClientSession::~ClientSession() {
}
}
} catch ( std::exception& ex) {
CLIENT_LOG_ERROR << "disconnect encounter exception: " << ex.what();
//CLIENT_LOG_ERROR << "disconnect encounter exception: " << ex.what();
return Status(StatusCode::UnknownError, "failed to disconnect: " + std::string(ex.what()));
}
}
VecServiceClientPtr ClientSession::interface() {
if(client_ == nullptr) {
throw std::exception();
}
return client_;
return Status::OK();
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////
// Connect as part of construction. The returned status is not inspected.
ThriftClientSession::ThriftClientSession(const std::string& address, int32_t port, const std::string& protocol) {
    (void) Connect(address, port, protocol);
}
// Disconnect as part of destruction. The returned status is not inspected.
ThriftClientSession::~ThriftClientSession() {
    (void) Disconnect();
}
}
\ No newline at end of file
......@@ -5,26 +5,34 @@
******************************************************************************/
#pragma once
#include "thrift/gen-cpp/VecService.h"
#include "MegasearchService.h"
#include "Status.h"
#include <memory>
namespace zilliz {
namespace vecwise {
namespace client {
namespace megasearch {
using VecServiceClientPtr = std::shared_ptr<megasearch::VecServiceClient>;
using MegasearchServiceClientPtr = std::shared_ptr<megasearch::thrift::MegasearchServiceClient>;
class ClientSession {
class ThriftClient {
public:
ClientSession(const std::string& address, int32_t port, const std::string& protocol);
~ClientSession();
ThriftClient();
virtual ~ThriftClient();
VecServiceClientPtr interface();
MegasearchServiceClientPtr interface();
VecServiceClientPtr client_;
Status Connect(const std::string& address, int32_t port, const std::string& protocol);
Status Disconnect();
private:
MegasearchServiceClientPtr client_;
};
// ThriftClient variant whose constructor takes the connection parameters
// (server address, port and protocol name).
class ThriftClientSession : public ThriftClient {
public:
// address/port/protocol: server endpoint and thrift protocol name.
ThriftClientSession(const std::string& address, int32_t port, const std::string& protocol);
~ThriftClientSession();
};
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ConnectionImpl.h"
namespace megasearch {
std::shared_ptr<Connection>
Connection::Create() {
return std::shared_ptr<Connection>(new ConnectionImpl());
}
Status
Connection::Destroy(std::shared_ptr<megasearch::Connection> connection_ptr) {
if(connection_ptr != nullptr) {
return connection_ptr->Disconnect();
}
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////////////////////////
ConnectionImpl::ConnectionImpl() {
client_proxy = std::make_shared<ClientProxy>();
}
Status
ConnectionImpl::Connect(const ConnectParam &param) {
return client_proxy->Connect(param);
}
Status
ConnectionImpl::Connect(const std::string &uri) {
return client_proxy->Connect(uri);
}
Status
ConnectionImpl::Connected() const {
return client_proxy->Connected();
}
Status
ConnectionImpl::Disconnect() {
return client_proxy->Disconnect();
}
std::string
ConnectionImpl::ClientVersion() const {
return client_proxy->ClientVersion();
}
Status
ConnectionImpl::CreateTable(const TableSchema &param) {
return client_proxy->CreateTable(param);
}
Status
ConnectionImpl::CreateTablePartition(const CreateTablePartitionParam &param) {
return client_proxy->CreateTablePartition(param);
}
Status
ConnectionImpl::DeleteTablePartition(const DeleteTablePartitionParam &param) {
return client_proxy->DeleteTablePartition(param);
}
Status
ConnectionImpl::DeleteTable(const std::string &table_name) {
return client_proxy->DeleteTable(table_name);
}
Status
ConnectionImpl::AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) {
return client_proxy->AddVector(table_name, record_array, id_array);
}
Status
ConnectionImpl::SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) {
return client_proxy->SearchVector(table_name, query_record_array, topk_query_result_array, topk);
}
Status
ConnectionImpl::DescribeTable(const std::string &table_name, TableSchema &table_schema) {
return client_proxy->DescribeTable(table_name, table_schema);
}
Status
ConnectionImpl::ShowTables(std::vector<std::string> &table_array) {
return client_proxy->ShowTables(table_array);
}
std::string
ConnectionImpl::ServerVersion() const {
return client_proxy->ServerVersion();
}
std::string
ConnectionImpl::ServerStatus() const {
return client_proxy->ServerStatus();
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "MegaSearch.h"
#include "client/ClientProxy.h"
namespace megasearch {
class ConnectionImpl : public Connection {
public:
ConnectionImpl();
// Implementations of the Connection interface
virtual Status Connect(const ConnectParam &param) override;
virtual Status Connect(const std::string &uri) override;
virtual Status Connected() const override;
virtual Status Disconnect() override;
virtual Status CreateTable(const TableSchema &param) override;
virtual Status DeleteTable(const std::string &table_name) override;
virtual Status CreateTablePartition(const CreateTablePartitionParam &param) override;
virtual Status DeleteTablePartition(const DeleteTablePartitionParam &param) override;
virtual Status AddVector(const std::string &table_name,
const std::vector<RowRecord> &record_array,
std::vector<int64_t> &id_array) override;
virtual Status SearchVector(const std::string &table_name,
const std::vector<QueryRecord> &query_record_array,
std::vector<TopKQueryResult> &topk_query_result_array,
int64_t topk) override;
virtual Status DescribeTable(const std::string &table_name, TableSchema &table_schema) override;
virtual Status ShowTables(std::vector<std::string> &table_array) override;
virtual std::string ClientVersion() const override;
virtual std::string ServerVersion() const override;
virtual std::string ServerStatus() const override;
private:
std::shared_ptr<ClientProxy> client_proxy;
};
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "Status.h"
......@@ -22,7 +27,7 @@ void Status::MoveFrom(Status &s) {
}
Status::Status(const Status &s)
: state_((s.state_ == nullptr) ? nullptr : new State(*s.state_)) {}
: state_((s.state_ == nullptr) ? nullptr : new State(*s.state_)) {}
Status &Status::operator=(const Status &s) {
if (state_ != s.state_) {
......@@ -112,4 +117,4 @@ std::string Status::ToString() const {
return result;
}
}
\ No newline at end of file
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "ConvertUtil.h"
#include <map>
namespace megasearch {

// Text names of the supported index types.
static const std::string INDEX_RAW = "raw";
static const std::string INDEX_IVFFLAT = "ivfflat";

// Map an index type to its text name; unknown values map to "raw".
std::string ConvertUtil::IndexType2Str(megasearch::IndexType index) {
    static const std::map<megasearch::IndexType, std::string> s_index2str = {
        {megasearch::IndexType::raw,     INDEX_RAW},
        {megasearch::IndexType::ivfflat, INDEX_IVFFLAT}
    };

    const auto found = s_index2str.find(index);
    return (found != s_index2str.end()) ? found->second : INDEX_RAW;
}

// Map a text name to its index type; unknown names map to IndexType::raw.
megasearch::IndexType ConvertUtil::Str2IndexType(const std::string& type) {
    static const std::map<std::string, megasearch::IndexType> s_str2index = {
        {INDEX_RAW,     megasearch::IndexType::raw},
        {INDEX_IVFFLAT, megasearch::IndexType::ivfflat}
    };

    const auto found = s_str2index.find(type);
    return (found != s_str2index.end()) ? found->second : megasearch::IndexType::raw;
}

}
\ No newline at end of file
......@@ -5,15 +5,14 @@
******************************************************************************/
#pragma once
namespace zilliz {
namespace vecwise {
namespace client {
#include "MegaSearch.h"
class FaissTest {
namespace megasearch {
class ConvertUtil {
public:
static void test();
static std::string IndexType2Str(megasearch::IndexType index);
static megasearch::IndexType Str2IndexType(const std::string& type);
};
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "MegasearchHandler.h"
#include "MegasearchTask.h"
#include "utils/TimeRecorder.h"
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
// Default constructor; there is nothing to initialize.
MegasearchServiceHandler::MegasearchServiceHandler() {
}
// Table creation entry point: wraps the schema in a CreateTableTask and runs
// it through the scheduler.
void
MegasearchServiceHandler::CreateTable(const thrift::TableSchema &param) {
    auto create_task = CreateTableTask::Create(param);
    MegasearchScheduler::ExecTask(create_task);
}
void
MegasearchServiceHandler::DeleteTable(const std::string &table_name) {
BaseTaskPtr task_ptr = DeleteTableTask::Create(table_name);
MegasearchScheduler::ExecTask(task_ptr);
}
// Placeholder: partition creation is not implemented. The parameter is
// ignored and only the method name is printed to stdout.
void
MegasearchServiceHandler::CreateTablePartition(const thrift::CreateTablePartitionParam &param) {
// Your implementation goes here
printf("CreateTablePartition\n");
}
// Placeholder: partition deletion is not implemented. The parameter is
// ignored and only the method name is printed to stdout.
void
MegasearchServiceHandler::DeleteTablePartition(const thrift::DeleteTablePartitionParam &param) {
// Your implementation goes here
printf("DeleteTablePartition\n");
}
void
MegasearchServiceHandler::AddVector(std::vector<int64_t> &_return,
const std::string &table_name,
const std::vector<thrift::RowRecord> &record_array) {
BaseTaskPtr task_ptr = AddVectorTask::Create(table_name, record_array, _return);
MegasearchScheduler::ExecTask(task_ptr);
}
// Query entry point: runs a SearchVectorTask for the query records. The task
// appends one TopKQueryResult per engine result to _return.
void
MegasearchServiceHandler::SearchVector(std::vector<thrift::TopKQueryResult> &_return,
                                       const std::string &table_name,
                                       const std::vector<thrift::QueryRecord> &query_record_array,
                                       const int64_t topk) {
    auto search_task = SearchVectorTask::Create(table_name, query_record_array, topk, _return);
    MegasearchScheduler::ExecTask(search_task);
}
// Describe entry point: runs a DescribeTableTask that is given _return as the
// schema object to work on.
void
MegasearchServiceHandler::DescribeTable(thrift::TableSchema &_return, const std::string &table_name) {
    auto describe_task = DescribeTableTask::Create(table_name, _return);
    MegasearchScheduler::ExecTask(describe_task);
}
// Placeholder: table listing is not implemented. _return is left untouched
// and only the method name is printed to stdout.
void
MegasearchServiceHandler::ShowTables(std::vector<std::string> &_return) {
// Your implementation goes here
printf("ShowTables\n");
}
// Placeholder: the status query is not implemented. cmd is ignored, _return is
// left untouched and only the method name is printed to stdout.
void
MegasearchServiceHandler::Ping(std::string& _return, const std::string& cmd) {
// Your implementation goes here
printf("Ping\n");
}
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include <cstdint>
#include <string>
#include "MegasearchService.h"
namespace zilliz {
namespace vecwise {
namespace server {
/**
 * @brief Server-side implementation of the thrift MegasearchServiceIf interface.
 *
 * Each method is one service call. Methods whose first parameter is named
 * _return deliver their result through that reference.
 */
class MegasearchServiceHandler : virtual public megasearch::thrift::MegasearchServiceIf {
public:
MegasearchServiceHandler();
/**
 * @brief Create table method
 *
 * This method is used to create a table.
 *
 * @param param table information of the table to be created.
 */
void CreateTable(const megasearch::thrift::TableSchema& param);
/**
 * @brief Delete table method
 *
 * This method is used to delete a table.
 *
 * @param table_name name of the table to be deleted.
 */
void DeleteTable(const std::string& table_name);
/**
 * @brief Create table partition
 *
 * This method is used to create a table partition.
 *
 * @param param partition information of the partition to be created.
 */
void CreateTablePartition(const megasearch::thrift::CreateTablePartitionParam& param);
/**
 * @brief Delete table partition
 *
 * This method is used to delete a table partition.
 *
 * @param param partition information of the partition to be deleted.
 */
void DeleteTablePartition(const megasearch::thrift::DeleteTablePartitionParam& param);
/**
 * @brief Add vector array to table
 *
 * This method is used to add a vector array to a table.
 *
 * @param _return receives the vector id array.
 * @param table_name name of the table the vectors are inserted into.
 * @param record_array vector array to insert.
 */
void AddVector(std::vector<int64_t> & _return,
const std::string& table_name,
const std::vector<megasearch::thrift::RowRecord> & record_array);
/**
 * @brief Query vector
 *
 * This method is used to query vectors in a table.
 *
 * @param _return receives the query result array.
 * @param table_name name of the table to query.
 * @param query_record_array all vectors to be queried.
 * @param topk how many similar vectors are searched for each query.
 */
void SearchVector(std::vector<megasearch::thrift::TopKQueryResult> & _return,
const std::string& table_name,
const std::vector<megasearch::thrift::QueryRecord> & query_record_array,
const int64_t topk);
/**
 * @brief Show table information
 *
 * This method is used to show table information.
 *
 * @param _return receives the table schema.
 * @param table_name name of the table to describe.
 */
void DescribeTable(megasearch::thrift::TableSchema& _return, const std::string& table_name);
/**
 * @brief List all tables in database
 *
 * This method is used to list all tables.
 *
 * @param _return receives the table names.
 */
void ShowTables(std::vector<std::string> & _return);
/**
 * @brief Give the server status
 *
 * This method is used to give the server status.
 *
 * @param _return receives the server status.
 * @param cmd command string sent by the client.
 */
void Ping(std::string& _return, const std::string& cmd);
};
}
}
}
......@@ -3,12 +3,51 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceScheduler.h"
#include "MegasearchScheduler.h"
#include "utils/Log.h"
#include "megasearch_types.h"
#include "megasearch_constants.h"
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
namespace {
const std::map<ServerError, thrift::ErrorCode::type> &ErrorMap() {
static const std::map<ServerError, thrift::ErrorCode::type> code_map = {
{SERVER_UNEXPECTED_ERROR, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_NULL_POINTER, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_INVALID_ARGUMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_FILE_NOT_FOUND, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_NOT_IMPLEMENT, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_BLOCKING_QUEUE_EMPTY, thrift::ErrorCode::ILLEGAL_ARGUMENT},
{SERVER_GROUP_NOT_EXIST, thrift::ErrorCode::TABLE_NOT_EXISTS},
{SERVER_INVALID_TIME_RANGE, thrift::ErrorCode::ILLEGAL_RANGE},
{SERVER_INVALID_VECTOR_DIMENSION, thrift::ErrorCode::ILLEGAL_DIMENSION},
};
return code_map;
}
const std::map<ServerError, std::string> &ErrorMessage() {
static const std::map<ServerError, std::string> msg_map = {
{SERVER_UNEXPECTED_ERROR, "unexpected error occurs"},
{SERVER_NULL_POINTER, "null pointer error"},
{SERVER_INVALID_ARGUMENT, "invalid argument"},
{SERVER_FILE_NOT_FOUND, "file not found"},
{SERVER_NOT_IMPLEMENT, "not implemented"},
{SERVER_BLOCKING_QUEUE_EMPTY, "queue empty"},
{SERVER_GROUP_NOT_EXIST, "group not exist"},
{SERVER_INVALID_TIME_RANGE, "invalid time range"},
{SERVER_INVALID_VECTOR_DIMENSION, "invalid vector dimension"},
};
return msg_map;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
BaseTask::BaseTask(const std::string& task_group, bool async)
......@@ -38,16 +77,40 @@ ServerError BaseTask::WaitToFinish() {
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
VecServiceScheduler::VecServiceScheduler()
MegasearchScheduler::MegasearchScheduler()
: stopped_(false) {
Start();
}
VecServiceScheduler::~VecServiceScheduler() {
MegasearchScheduler::~MegasearchScheduler() {
Stop();
}
void VecServiceScheduler::Start() {
// Runs a task through the singleton scheduler and, for synchronous tasks,
// reports a failure to the caller as a thrift::Exception.
//
// @param task_ptr task to run; a null pointer is ignored.
// @throws thrift::Exception when a synchronous task finishes with an error
//         code other than SERVER_SUCCESS.
void MegasearchScheduler::ExecTask(BaseTaskPtr& task_ptr) {
    if(task_ptr == nullptr) {
        return;
    }

    MegasearchScheduler& scheduler = MegasearchScheduler::GetInstance();
    scheduler.ExecuteTask(task_ptr);

    if(!task_ptr->IsAsync()) {
        task_ptr->WaitToFinish();
        ServerError err = task_ptr->ErrorCode();
        if (err != SERVER_SUCCESS) {
            thrift::Exception ex;

            // Look the code up with find() instead of at(): an error code that is
            // missing from the table must not replace the thrift exception with a
            // std::out_of_range. Unmapped codes are reported as ILLEGAL_ARGUMENT,
            // the same code the table uses for SERVER_UNEXPECTED_ERROR.
            const auto& code_map = ErrorMap();
            const auto code_iter = code_map.find(err);
            if(code_iter != code_map.end()) {
                ex.__set_code(code_iter->second);
            } else {
                ex.__set_code(thrift::ErrorCode::ILLEGAL_ARGUMENT);
            }

            // Prefer the message recorded by the task; fall back to the default
            // text for the error code, or to a generic text for unmapped codes.
            std::string msg = task_ptr->ErrorMsg();
            if(msg.empty()){
                const auto& msg_map = ErrorMessage();
                const auto msg_iter = msg_map.find(err);
                if(msg_iter != msg_map.end()) {
                    msg = msg_iter->second;
                } else {
                    msg = "unexpected error occurs";
                }
            }
            ex.__set_reason(msg);
            throw ex;
        }
    }
}
void MegasearchScheduler::Start() {
if(!stopped_) {
return;
}
......@@ -55,7 +118,7 @@ void VecServiceScheduler::Start() {
stopped_ = false;
}
void VecServiceScheduler::Stop() {
void MegasearchScheduler::Stop() {
if(stopped_) {
return;
}
......@@ -80,7 +143,7 @@ void VecServiceScheduler::Stop() {
SERVER_LOG_INFO << "Scheduler stopped";
}
ServerError VecServiceScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
ServerError MegasearchScheduler::ExecuteTask(const BaseTaskPtr& task_ptr) {
if(task_ptr == nullptr) {
return SERVER_NULL_POINTER;
}
......@@ -121,7 +184,7 @@ namespace {
}
}
ServerError VecServiceScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
ServerError MegasearchScheduler::PutTaskToQueue(const BaseTaskPtr& task_ptr) {
std::lock_guard<std::mutex> lock(queue_mtx_);
std::string group_name = task_ptr->TaskGroup();
......
......@@ -50,10 +50,10 @@ using TaskQueue = BlockingQueue<BaseTaskPtr>;
using TaskQueuePtr = std::shared_ptr<TaskQueue>;
using ThreadPtr = std::shared_ptr<std::thread>;
class VecServiceScheduler {
class MegasearchScheduler {
public:
static VecServiceScheduler& GetInstance() {
static VecServiceScheduler scheduler;
static MegasearchScheduler& GetInstance() {
static MegasearchScheduler scheduler;
return scheduler;
}
......@@ -62,9 +62,11 @@ public:
ServerError ExecuteTask(const BaseTaskPtr& task_ptr);
static void ExecTask(BaseTaskPtr& task_ptr);
protected:
VecServiceScheduler();
virtual ~VecServiceScheduler();
MegasearchScheduler();
virtual ~MegasearchScheduler();
ServerError PutTaskToQueue(const BaseTaskPtr& task_ptr);
......
......@@ -3,22 +3,17 @@
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceWrapper.h"
#include "VecServiceHandler.h"
#include "VecServiceScheduler.h"
#include "MegasearchServer.h"
#include "MegasearchHandler.h"
#include "megasearch_types.h"
#include "megasearch_constants.h"
#include "ServerConfig.h"
#include "utils/Log.h"
#include "thrift/gen-cpp/megasearch_types.h"
#include "thrift/gen-cpp/megasearch_constants.h"
#include <thrift/protocol/TBinaryProtocol.h>
#include <thrift/protocol/TJSONProtocol.h>
#include <thrift/protocol/TDebugProtocol.h>
#include <thrift/protocol/TCompactProtocol.h>
#include <thrift/server/TSimpleServer.h>
//#include <thrift/server/TNonblockingServer.h>
#include <thrift/server/TThreadPoolServer.h>
#include <thrift/transport/TServerSocket.h>
#include <thrift/transport/TBufferTransports.h>
......@@ -30,6 +25,7 @@ namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch::thrift;
using namespace ::apache::thrift;
using namespace ::apache::thrift::protocol;
using namespace ::apache::thrift::transport;
......@@ -38,7 +34,8 @@ using namespace ::apache::thrift::concurrency;
static stdcxx::shared_ptr<TServer> s_server;
void VecServiceWrapper::StartService() {
void
MegasearchServer::StartService() {
if(s_server != nullptr){
StopService();
}
......@@ -52,11 +49,12 @@ void VecServiceWrapper::StartService() {
std::string mode = server_config.GetValue(CONFIG_SERVER_MODE, "thread_pool");
try {
stdcxx::shared_ptr<VecServiceHandler> handler(new VecServiceHandler());
stdcxx::shared_ptr<TProcessor> processor(new VecServiceProcessor(handler));
stdcxx::shared_ptr<MegasearchServiceHandler> handler(new MegasearchServiceHandler());
stdcxx::shared_ptr<TProcessor> processor(new MegasearchServiceProcessor(handler));
stdcxx::shared_ptr<TServerTransport> server_transport(new TServerSocket(address, port));
stdcxx::shared_ptr<TTransportFactory> transport_factory(new TBufferedTransportFactory());
std::string protocol = "json";
stdcxx::shared_ptr<TProtocolFactory> protocol_factory;
if (protocol == "binary") {
protocol_factory.reset(new TBinaryProtocolFactory());
......@@ -67,24 +65,14 @@ void VecServiceWrapper::StartService() {
} else if (protocol == "debug") {
protocol_factory.reset(new TDebugProtocolFactory());
} else {
SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
//SERVER_LOG_INFO << "Service protocol: " << protocol << " is not supported currently";
return;
}
std::string mode = "thread_pool";
if (mode == "simple") {
s_server.reset(new TSimpleServer(processor, server_transport, transport_factory, protocol_factory));
s_server->serve();
// } else if(mode == "non_blocking") {
// ::apache::thrift::stdcxx::shared_ptr<TNonblockingServerTransport> nb_server_transport(new TServerSocket(address, port));
// ::apache::thrift::stdcxx::shared_ptr<ThreadManager> threadManager(ThreadManager::newSimpleThreadManager());
// ::apache::thrift::stdcxx::shared_ptr<PosixThreadFactory> threadFactory(new PosixThreadFactory());
// threadManager->threadFactory(threadFactory);
// threadManager->start();
//
// s_server.reset(new TNonblockingServer(processor,
// protocol_factory,
// nb_server_transport,
// threadManager));
} else if (mode == "thread_pool") {
stdcxx::shared_ptr<ThreadManager> threadManager(ThreadManager::newSimpleThreadManager());
stdcxx::shared_ptr<PosixThreadFactory> threadFactory(new PosixThreadFactory());
......@@ -98,19 +86,17 @@ void VecServiceWrapper::StartService() {
threadManager));
s_server->serve();
} else {
SERVER_LOG_INFO << "Service mode: " << mode << " is not supported currently";
//SERVER_LOG_INFO << "Service mode: " << mode << " is not supported currently";
return;
}
} catch (apache::thrift::TException& ex) {
SERVER_LOG_ERROR << "Server encounter exception: " << ex.what();
//SERVER_LOG_ERROR << "Server encounter exception: " << ex.what();
}
}
void VecServiceWrapper::StopService() {
void
MegasearchServer::StopService() {
auto stop_server_worker = [&]{
VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
scheduler.Stop();
if(s_server != nullptr) {
s_server->stop();
}
......
......@@ -5,8 +5,6 @@
******************************************************************************/
#pragma once
#include "utils/Error.h"
#include <cstdint>
#include <string>
......@@ -14,13 +12,12 @@ namespace zilliz {
namespace vecwise {
namespace server {
class VecServiceWrapper {
class MegasearchServer {
public:
static void StartService();
static void StopService();
};
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "MegasearchTask.h"
#include "ServerConfig.h"
#include "VecIdMapper.h"
#include "utils/CommonUtil.h"
#include "utils/Log.h"
#include "utils/TimeRecorder.h"
#include "utils/ThreadPool.h"
#include "db/DB.h"
#include "db/Env.h"
#include "db/Meta.h"
namespace zilliz {
namespace vecwise {
namespace server {
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
static const std::string VECTOR_UID = "uid";
static const uint64_t USE_MT = 5000;
using DB_META = zilliz::vecwise::engine::meta::Meta;
using DB_DATE = zilliz::vecwise::engine::meta::DateT;
namespace {
class DBWrapper {
public:
DBWrapper() {
zilliz::vecwise::engine::Options opt;
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_DB);
opt.meta.backend_uri = config.GetValue(CONFIG_DB_URL);
std::string db_path = config.GetValue(CONFIG_DB_PATH);
opt.memory_sync_interval = (uint16_t)config.GetInt32Value(CONFIG_DB_FLUSH_INTERVAL, 10);
opt.meta.path = db_path + "/db";
CommonUtil::CreateDirectory(opt.meta.path);
zilliz::vecwise::engine::DB::Open(opt, &db_);
if(db_ == nullptr) {
SERVER_LOG_ERROR << "Failed to open db";
throw ServerException(SERVER_NULL_POINTER, "Failed to open db");
}
}
zilliz::vecwise::engine::DB* DB() { return db_; }
private:
zilliz::vecwise::engine::DB* db_ = nullptr;
};
zilliz::vecwise::engine::DB* DB() {
static DBWrapper db_wrapper;
return db_wrapper.DB();
}
ThreadPool& GetThreadPool() {
static ThreadPool pool(6);
return pool;
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Places the task in the DDL/DML task group and binds schema_ to the caller's
// schema object (a reference is stored, no copy is made).
CreateTableTask::CreateTableTask(const thrift::TableSchema& schema)
: BaseTask(DDL_DML_TASK_GROUP),
schema_(schema) {
}
// Factory for CreateTableTask; returns the new task as a BaseTask pointer.
BaseTaskPtr CreateTableTask::Create(const thrift::TableSchema& schema) {
    std::shared_ptr<BaseTask> task(new CreateTableTask(schema));
    return task;
}
// Registers the table with the id mapper and adds it to the engine as a group.
// The group dimension is taken from the first vector column of the schema.
//
// Returns SERVER_INVALID_ARGUMENT when the schema has no vector column and
// SERVER_UNEXPECTED_ERROR when a std::exception is caught. A failed add_group
// is logged but still reported as SERVER_SUCCESS (the group could exist).
ServerError CreateTableTask::OnExecute() {
    TimeRecorder rc("CreateTableTask");

    try {
        if(schema_.vector_column_array.empty()) {
            return SERVER_INVALID_ARGUMENT;
        }

        IVecIdMapper::GetInstance()->AddGroup(schema_.table_name);

        engine::meta::GroupSchema group_schema;
        group_schema.dimension = static_cast<uint16_t>(schema_.vector_column_array[0].dimension);
        group_schema.group_id = schema_.table_name;

        engine::Status add_status = DB()->add_group(group_schema);
        if(!add_status.ok()) {
            // The group could exist already, so this is not treated as a failure.
            error_msg_ = "Engine failed: " + add_status.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return SERVER_SUCCESS;
        }
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    rc.Record("done");
    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Places the task in the DDL/DML task group, copies the table name, binds
// schema_ to the caller's schema object (a reference, not a copy) and writes
// the table name into that schema right away.
DescribeTableTask::DescribeTableTask(const std::string &table_name, thrift::TableSchema &schema)
: BaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name),
schema_(schema) {
schema_.table_name = table_name_;
}
// Factory for DescribeTableTask; returns the new task as a BaseTask pointer.
BaseTaskPtr DescribeTableTask::Create(const std::string& table_name, thrift::TableSchema& schema) {
    std::shared_ptr<BaseTask> task(new DescribeTableTask(table_name, schema));
    return task;
}
// Looks the table up in the engine as a group.
//
// Returns SERVER_GROUP_NOT_EXIST when the lookup fails and
// SERVER_UNEXPECTED_ERROR when a std::exception is caught. On success the
// fetched group information is not copied into schema_ here.
ServerError DescribeTableTask::OnExecute() {
    TimeRecorder rc("DescribeTableTask");

    try {
        engine::meta::GroupSchema group_schema;
        group_schema.group_id = table_name_;

        engine::Status lookup_status = DB()->get_group(group_schema);
        if(!lookup_status.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + lookup_status.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }
    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return SERVER_UNEXPECTED_ERROR;
    }

    rc.Record("done");
    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Places the task in the DDL/DML task group and copies the table name.
DeleteTableTask::DeleteTableTask(const std::string& table_name)
: BaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name) {
}
// Factory for DeleteTableTask; returns the new task as a BaseTask pointer.
BaseTaskPtr DeleteTableTask::Create(const std::string& group_id) {
    std::shared_ptr<BaseTask> task(new DeleteTableTask(group_id));
    return task;
}
// Table deletion is not implemented: records and logs the error, then always
// returns SERVER_NOT_IMPLEMENT.
ServerError DeleteTableTask::OnExecute() {
    error_msg_ = "delete table not implemented";
    error_code_ = SERVER_NOT_IMPLEMENT;
    SERVER_LOG_ERROR << error_msg_;

    // Disabled call kept from the original:
    //IVecIdMapper::GetInstance()->DeleteGroup(table_name_);

    return SERVER_NOT_IMPLEMENT;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Places the task in the DDL/DML task group, copies the table name and binds
// record_array_ / record_ids_ to the caller's containers (references, not
// copies). The id output vector is emptied and then resized to one entry per
// input record.
AddVectorTask::AddVectorTask(const std::string& table_name,
const std::vector<thrift::RowRecord>& record_array,
std::vector<int64_t>& record_ids)
: BaseTask(DDL_DML_TASK_GROUP),
table_name_(table_name),
record_array_(record_array),
record_ids_(record_ids) {
record_ids_.clear();
record_ids_.resize(record_array.size());
}
// Factory for AddVectorTask; returns the new task as a BaseTask pointer.
BaseTaskPtr AddVectorTask::Create(const std::string& table_name,
                                  const std::vector<thrift::RowRecord>& record_array,
                                  std::vector<int64_t>& record_ids) {
    std::shared_ptr<BaseTask> task(new AddVectorTask(table_name, record_array, record_ids));
    return task;
}
// Converts the records' binary vectors to floats and adds them to the engine.
//
// For every record only the first entry of vector_map is used; its bytes are
// read as an array of double values and narrowed to float.
//
// Returns SERVER_SUCCESS (also for an empty record array),
// SERVER_GROUP_NOT_EXIST when the table lookup fails, SERVER_INVALID_ARGUMENT
// when a record has no vector, SERVER_INVALID_VECTOR_DIMENSION when a vector
// does not match the table dimension, and SERVER_UNEXPECTED_ERROR for engine
// failures, missing ids or a caught std::exception. Every failure path records
// error_code_ and error_msg_.
ServerError AddVectorTask::OnExecute() {
    try {
        TimeRecorder rc("AddVectorTask");

        if(record_array_.empty()) {
            return SERVER_SUCCESS;
        }

        engine::meta::GroupSchema group_info;
        group_info.group_id = table_name_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("get group info");

        uint64_t vec_count = (uint64_t)record_array_.size();
        uint64_t group_dim = group_info.dimension;
        std::vector<float> vec_f;
        vec_f.resize(vec_count*group_dim);//allocate enough memory
        for(uint64_t i = 0; i < vec_count; i++) {
            const auto& record = record_array_[i];
            if(record.vector_map.empty()) {
                error_code_ = SERVER_INVALID_ARGUMENT;
                error_msg_ = "No vector provided in record";
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            uint64_t vec_dim = record.vector_map.begin()->second.size()/sizeof(double);//how many double value?
            if(vec_dim != group_dim) {
                // Report the actual mismatch. The engine status is still the
                // successful get_group() result here, so it must not be used to
                // build the message.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            // NOTE(review): the byte buffer is read in place as double values;
            // this presumably relies on the buffer being suitably aligned - confirm.
            const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = (float)(d_p[d]);
            }
        }

        rc.Record("prepare vectors data");

        stat = DB()->add_vectors(table_name_, vec_count, vec_f.data(), record_ids_);
        rc.Record("add vectors to engine");
        if(!stat.ok()) {
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        if(record_ids_.size() < vec_count) {
            // Record the failure like every other error path does.
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Vector ID not returned";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        rc.Record("done");

    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Places the task in the DQL task group, copies the table name and top-k value
// and binds record_array_ / result_array_ to the caller's containers
// (references, not copies). Note that top_k comes before record_array here,
// while Create() takes them in the opposite order.
SearchVectorTask::SearchVectorTask(const std::string& table_name,
const int64_t top_k,
const std::vector<thrift::QueryRecord>& record_array,
std::vector<thrift::TopKQueryResult>& result_array)
: BaseTask(DQL_TASK_GROUP),
table_name_(table_name),
top_k_(top_k),
record_array_(record_array),
result_array_(result_array) {
}
// Factory for SearchVectorTask; returns the new task as a BaseTask pointer.
// The constructor takes top_k before record_array, so the two arguments are
// swapped when forwarding.
BaseTaskPtr SearchVectorTask::Create(const std::string& table_name,
                                     const std::vector<thrift::QueryRecord>& record_array,
                                     const int64_t top_k,
                                     std::vector<thrift::TopKQueryResult>& result_array) {
    std::shared_ptr<BaseTask> task(new SearchVectorTask(table_name, top_k, record_array, result_array));
    return task;
}
// Converts the query records' binary vectors to floats, runs the engine search
// and appends one TopKQueryResult per engine result to result_array_.
//
// For every record only the first entry of vector_map is used; its bytes are
// read as an array of double values and narrowed to float.
//
// Returns SERVER_SUCCESS, SERVER_INVALID_ARGUMENT for a non-positive top-k, an
// empty query array or a record without vector, SERVER_GROUP_NOT_EXIST when
// the table lookup fails, SERVER_INVALID_VECTOR_DIMENSION when a vector does
// not match the table dimension, and SERVER_UNEXPECTED_ERROR for engine search
// failures or a caught std::exception. Every failure path records error_code_
// and error_msg_.
ServerError SearchVectorTask::OnExecute() {
    try {
        TimeRecorder rc("SearchVectorTask");

        if(top_k_ <= 0 || record_array_.empty()) {
            error_code_ = SERVER_INVALID_ARGUMENT;
            error_msg_ = "Invalid topk value, or query record array is empty";
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        engine::meta::GroupSchema group_info;
        group_info.group_id = table_name_;
        engine::Status stat = DB()->get_group(group_info);
        if(!stat.ok()) {
            error_code_ = SERVER_GROUP_NOT_EXIST;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        }

        std::vector<float> vec_f;
        uint64_t record_count = (uint64_t)record_array_.size();
        const uint64_t group_dim = group_info.dimension;
        vec_f.resize(record_count*group_dim);

        for(uint64_t i = 0; i < record_array_.size(); i++) {
            const auto& record = record_array_[i];
            if (record.vector_map.empty()) {
                error_code_ = SERVER_INVALID_ARGUMENT;
                error_msg_ = "Query record has no vector";
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            uint64_t vec_dim = record.vector_map.begin()->second.size() / sizeof(double);//how many double value?
            if (vec_dim != group_dim) {
                // Report the actual mismatch. The engine status is still the
                // successful get_group() result here, so it must not be used to
                // build the message.
                error_code_ = SERVER_INVALID_VECTOR_DIMENSION;
                error_msg_ = "Invalid vector dimension: " + std::to_string(vec_dim)
                             + " vs. group dimension:" + std::to_string(group_dim);
                SERVER_LOG_ERROR << error_msg_;
                return error_code_;
            }

            // NOTE(review): the byte buffer is read in place as double values;
            // this presumably relies on the buffer being suitably aligned - confirm.
            const double* d_p = reinterpret_cast<const double*>(record.vector_map.begin()->second.data());
            for(uint64_t d = 0; d < vec_dim; d++) {
                vec_f[i*vec_dim + d] = (float)(d_p[d]);
            }
        }

        rc.Record("prepare vector data");

        std::vector<DB_DATE> dates;
        engine::QueryResults results;
        stat = DB()->search(table_name_, (size_t)top_k_, record_count, vec_f.data(), dates, results);
        if(!stat.ok()) {
            // Record the failure like every other error path does, so the client
            // receives the engine's message instead of a generic default.
            error_code_ = SERVER_UNEXPECTED_ERROR;
            error_msg_ = "Engine failed: " + stat.ToString();
            SERVER_LOG_ERROR << error_msg_;
            return error_code_;
        } else {
            rc.Record("do searching");
            for(engine::QueryResult& result : results){
                thrift::TopKQueryResult thrift_topk_result;
                for(auto id : result) {
                    thrift::QueryResult thrift_result;
                    thrift_result.__set_id(id);
                    thrift_topk_result.query_result_arrays.emplace_back(thrift_result);
                }

                result_array_.push_back(thrift_topk_result);
            }
            rc.Record("construct result");
        }

        rc.Record("done");

    } catch (std::exception& ex) {
        error_code_ = SERVER_UNEXPECTED_ERROR;
        error_msg_ = ex.what();
        SERVER_LOG_ERROR << error_msg_;
        return error_code_;
    }

    return SERVER_SUCCESS;
}
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "MegasearchScheduler.h"
#include "utils/Error.h"
#include "utils/AttributeSerializer.h"
#include "db/Types.h"
#include "megasearch_types.h"
#include <condition_variable>
#include <memory>
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that creates a table from a thrift table schema. Instances are obtained
// through Create(); the constructor is protected.
class CreateTableTask : public BaseTask {
public:
// Builds a task for the given schema and returns it as a BaseTaskPtr.
static BaseTaskPtr Create(const thrift::TableSchema& schema);
protected:
CreateTableTask(const thrift::TableSchema& schema);
// Task body; returns the resulting ServerError code.
ServerError OnExecute() override;
private:
// Stored as a reference, not a copy: the schema object passed to Create()
// must stay alive for as long as the task may run.
const thrift::TableSchema& schema_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that describes a table by name. Instances are obtained through
// Create(); the constructor is protected.
class DescribeTableTask : public BaseTask {
public:
// Builds a task for the given table name; schema is the output object.
static BaseTaskPtr Create(const std::string& table_name, thrift::TableSchema& schema);
protected:
DescribeTableTask(const std::string& table_name, thrift::TableSchema& schema);
// Task body; returns the resulting ServerError code.
ServerError OnExecute() override;
private:
std::string table_name_;
// Stored as a reference, not a copy: the schema object passed to Create()
// must stay alive for as long as the task may run.
thrift::TableSchema& schema_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that deletes a table by name. Instances are obtained through Create();
// the constructor is protected.
class DeleteTableTask : public BaseTask {
public:
// Builds a task for the given table name and returns it as a BaseTaskPtr.
static BaseTaskPtr Create(const std::string& table_name);
protected:
DeleteTableTask(const std::string& table_name);
// Task body; returns the resulting ServerError code.
ServerError OnExecute() override;
private:
// Own copy of the table name.
std::string table_name_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that adds a batch of row records to a table. Instances are obtained
// through Create(); the constructor is protected.
class AddVectorTask : public BaseTask {
public:
// Builds a task for the given table and records; record_ids_ is the output
// vector for the ids of the added vectors.
static BaseTaskPtr Create(const std::string& table_name,
const std::vector<thrift::RowRecord>& record_array,
std::vector<int64_t>& record_ids_);
protected:
AddVectorTask(const std::string& table_name,
const std::vector<thrift::RowRecord>& record_array,
std::vector<int64_t>& record_ids_);
// Task body; returns the resulting ServerError code.
ServerError OnExecute() override;
private:
std::string table_name_;
// Both containers are stored as references, not copies: the objects passed
// to Create() must stay alive for as long as the task may run.
const std::vector<thrift::RowRecord>& record_array_;
std::vector<int64_t>& record_ids_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Task that searches a table for the top-k results of each query record.
// Instances are obtained through Create(); the constructor is protected.
class SearchVectorTask : public BaseTask {
public:
// Builds a task for the given table, query records and top-k value;
// result_array is the output vector. Note the parameter order: record_array
// comes before top_k here, while the constructor takes them the other way round.
static BaseTaskPtr Create(const std::string& table_name,
const std::vector<thrift::QueryRecord>& record_array,
const int64_t top_k,
std::vector<thrift::TopKQueryResult>& result_array);
protected:
SearchVectorTask(const std::string& table_name,
const int64_t top_k,
const std::vector<thrift::QueryRecord>& record_array,
std::vector<thrift::TopKQueryResult>& result_array);
// Task body; returns the resulting ServerError code.
ServerError OnExecute() override;
private:
std::string table_name_;
int64_t top_k_;
// Both containers are stored as references, not copies: the objects passed
// to Create() must stay alive for as long as the task may run.
const std::vector<thrift::QueryRecord>& record_array_;
std::vector<thrift::TopKQueryResult>& result_array_;
};
}
}
}
\ No newline at end of file
......@@ -5,7 +5,7 @@
////////////////////////////////////////////////////////////////////////////////
#include "Server.h"
#include "ServerConfig.h"
#include "VecServiceWrapper.h"
#include "MegasearchServer.h"
#include "utils/Log.h"
#include "utils/SignalUtil.h"
#include "utils/TimeRecorder.h"
......@@ -225,12 +225,12 @@ Server::LoadConfig() {
void
Server::StartService() {
VecServiceWrapper::StartService();
MegasearchServer::StartService();
}
void
Server::StopService() {
VecServiceWrapper::StopService();
MegasearchServer::StopService();
}
}
......
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceHandler.h"
#include "VecServiceTask.h"
#include "ServerConfig.h"
#include "utils/Log.h"
#include "utils/CommonUtil.h"
#include "utils/TimeRecorder.h"
#include "db/DB.h"
#include "db/Env.h"
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
namespace {
class TimeRecordWrapper {
public:
TimeRecordWrapper(const std::string& func_name)
: recorder_(func_name), func_name_(func_name) {
//SERVER_LOG_TRACE << func_name << " called";
}
~TimeRecordWrapper() {
recorder_.Elapse("cost");
//SERVER_LOG_TRACE << func_name_ << " finished";
}
private:
TimeRecorder recorder_;
std::string func_name_;
};
void TimeRecord(const std::string& func_name) {
}
const std::map<ServerError, VecErrCode::type>& ErrorMap() {
static const std::map<ServerError, VecErrCode::type> code_map = {
{SERVER_UNEXPECTED_ERROR, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_NULL_POINTER, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_INVALID_ARGUMENT, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_FILE_NOT_FOUND, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_NOT_IMPLEMENT, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_BLOCKING_QUEUE_EMPTY, VecErrCode::ILLEGAL_ARGUMENT},
{SERVER_GROUP_NOT_EXIST, VecErrCode::GROUP_NOT_EXISTS},
{SERVER_INVALID_TIME_RANGE, VecErrCode::ILLEGAL_TIME_RANGE},
{SERVER_INVALID_VECTOR_DIMENSION, VecErrCode::ILLEGAL_VECTOR_DIMENSION},
};
return code_map;
}
const std::map<ServerError, std::string>& ErrorMessage() {
static const std::map<ServerError, std::string> msg_map = {
{SERVER_UNEXPECTED_ERROR, "unexpected error occurs"},
{SERVER_NULL_POINTER, "null pointer error"},
{SERVER_INVALID_ARGUMENT, "invalid argument"},
{SERVER_FILE_NOT_FOUND, "file not found"},
{SERVER_NOT_IMPLEMENT, "not implemented"},
{SERVER_BLOCKING_QUEUE_EMPTY, "queue empty"},
{SERVER_GROUP_NOT_EXIST, "group not exist"},
{SERVER_INVALID_TIME_RANGE, "invalid time range"},
{SERVER_INVALID_VECTOR_DIMENSION, "invalid vector dimension"},
};
return msg_map;
}
/**
 * Hands a task to the VecServiceScheduler and, for synchronous tasks,
 * waits for completion and converts a failure into a Thrift exception.
 *
 * A null task is ignored. Asynchronous tasks (IsAsync() == true) are only
 * scheduled; their outcome is not inspected here.
 *
 * @param task_ptr task to execute
 * @throws VecException if a synchronous task finishes with an error code
 *         other than SERVER_SUCCESS. The exception code comes from
 *         ErrorMap(); the reason is the task's own message or, when that
 *         is empty, the default from ErrorMessage().
 */
void ExecTask(BaseTaskPtr& task_ptr) {
    if(task_ptr == nullptr) {
        return;
    }

    VecServiceScheduler& scheduler = VecServiceScheduler::GetInstance();
    scheduler.ExecuteTask(task_ptr);

    if(task_ptr->IsAsync()) {
        return;
    }

    task_ptr->WaitToFinish();
    ServerError err = task_ptr->ErrorCode();
    if (err == SERVER_SUCCESS) {
        return;
    }

    VecException ex;

    // Look the code up with find() rather than at(): an error code missing
    // from the table must still reach the client as a VecException instead
    // of escaping as std::out_of_range.
    const auto& code_map = ErrorMap();
    const auto code_it = code_map.find(err);
    ex.__set_code(code_it != code_map.end() ? code_it->second : VecErrCode::ILLEGAL_ARGUMENT);

    std::string msg = task_ptr->ErrorMsg();
    if(msg.empty()) {
        const auto& msg_map = ErrorMessage();
        const auto msg_it = msg_map.find(err);
        if (msg_it != msg_map.end()) {
            msg = msg_it->second;
        } else {
            // Same wording as the SERVER_UNEXPECTED_ERROR default.
            msg = "unexpected error occurs";
        }
    }
    ex.__set_reason(msg);

    throw ex;
}
}
void
VecServiceHandler::add_group(const VecGroup &group) {
std::string info = "add_group() " + group.id + " dimension = " + std::to_string(group.dimension)
+ " index_type = " + std::to_string(group.index_type);
TimeRecordWrapper rc(info);
BaseTaskPtr task_ptr = AddGroupTask::Create(group.dimension, group.id);
ExecTask(task_ptr);
}
/**
 * Handler for the get_group request.
 *
 * Copies the requested id into the result and runs a GetGroupTask that is
 * given a reference to _return.dimension (presumably filled in by the
 * task - confirm in the task implementation).
 */
void
VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) {
    TimeRecordWrapper timer("get_group() " + group_id);

    _return.id = group_id;

    BaseTaskPtr task = GetGroupTask::Create(group_id, _return.dimension);
    ExecTask(task);
}
void
VecServiceHandler::del_group(const std::string &group_id) {
TimeRecordWrapper rc("del_group() " + group_id);
BaseTaskPtr task_ptr = DeleteGroupTask::Create(group_id);
ExecTask(task_ptr);
}
void
VecServiceHandler::add_vector(std::string& _return, const std::string &group_id, const VecTensor &tensor) {
TimeRecordWrapper rc("add_vector() to " + group_id);
BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, &tensor, _return);
ExecTask(task_ptr);
}
void
VecServiceHandler::add_vector_batch(std::vector<std::string> & _return,
const std::string &group_id,
const VecTensorList &tensor_list) {
TimeRecordWrapper rc("add_vector_batch() to " + group_id);
BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, &tensor_list, _return);
ExecTask(task_ptr);
}
void
VecServiceHandler::add_binary_vector(std::string& _return,
const std::string& group_id,
const VecBinaryTensor& tensor) {
TimeRecordWrapper rc("add_binary_vector() to " + group_id);
BaseTaskPtr task_ptr = AddVectorTask::Create(group_id, &tensor, _return);
ExecTask(task_ptr);
}
void
VecServiceHandler::add_binary_vector_batch(std::vector<std::string> & _return,
const std::string& group_id,
const VecBinaryTensorList& tensor_list) {
TimeRecordWrapper rc("add_binary_vector_batch() to " + group_id);
BaseTaskPtr task_ptr = AddBatchVectorTask::Create(group_id, &tensor_list, _return);
ExecTask(task_ptr);
}
/**
 * Handler for the search_vector request.
 *
 * Wraps the single query tensor in a one-element list, runs the batch
 * search task, and returns the first entry of the batch result. If the
 * task produced no entries, an error is logged and _return is left
 * untouched.
 */
void
VecServiceHandler::search_vector(VecSearchResult &_return,
                                 const std::string &group_id,
                                 const int64_t top_k,
                                 const VecTensor &tensor,
                                 const VecSearchFilter& filter) {
    TimeRecordWrapper timer("search_vector() in " + group_id);

    // The search task only accepts lists, so copy the query into one.
    VecTensorList single_query;
    single_query.tensor_list.push_back(tensor);

    VecSearchResultList batch_result;
    BaseTaskPtr task = SearchVectorTask::Create(group_id, top_k, &single_query, filter, batch_result);
    ExecTask(task);

    if(batch_result.result_list.empty()) {
        SERVER_LOG_ERROR << "No search result returned";
        return;
    }

    _return = batch_result.result_list[0];
}
/**
 * Handler for the search_vector_batch request.
 *
 * Runs a SearchVectorTask over the whole tensor list; the task is given a
 * reference to _return for its results.
 */
void
VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
                                       const std::string &group_id,
                                       const int64_t top_k,
                                       const VecTensorList &tensor_list,
                                       const VecSearchFilter& filter) {
    TimeRecordWrapper timer("search_vector_batch() in " + group_id);

    BaseTaskPtr task = SearchVectorTask::Create(group_id, top_k, &tensor_list, filter, _return);
    ExecTask(task);
}
/**
 * Handler for the search_binary_vector request.
 *
 * Wraps the single binary query tensor in a one-element list, runs the
 * batch search task, and returns the first entry of the batch result. If
 * the task produced no entries, an error is logged and _return is left
 * untouched.
 */
void
VecServiceHandler::search_binary_vector(VecSearchResult& _return,
                                        const std::string& group_id,
                                        const int64_t top_k,
                                        const VecBinaryTensor& tensor,
                                        const VecSearchFilter& filter) {
    TimeRecordWrapper timer("search_binary_vector() in " + group_id);

    // The search task only accepts lists, so copy the query into one.
    VecBinaryTensorList single_query;
    single_query.tensor_list.push_back(tensor);

    VecSearchResultList batch_result;
    BaseTaskPtr task = SearchVectorTask::Create(group_id, top_k, &single_query, filter, batch_result);
    ExecTask(task);

    if(batch_result.result_list.empty()) {
        SERVER_LOG_ERROR << "No search result returned";
        return;
    }

    _return = batch_result.result_list[0];
}
/**
 * Handler for the search_binary_vector_batch request.
 *
 * Same flow as search_vector_batch(), but selects the SearchVectorTask
 * overload that takes a VecBinaryTensorList.
 */
void
VecServiceHandler::search_binary_vector_batch(VecSearchResultList& _return,
                                              const std::string& group_id,
                                              const int64_t top_k,
                                              const VecBinaryTensorList& tensor_list,
                                              const VecSearchFilter& filter) {
    TimeRecordWrapper timer("search_binary_vector_batch() in " + group_id);

    BaseTaskPtr task = SearchVectorTask::Create(group_id, top_k, &tensor_list, filter, _return);
    ExecTask(task);
}
}
}
}
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "utils/Error.h"
#include "thrift/gen-cpp/VecService.h"
#include <cstdint>
#include <string>
namespace zilliz {
namespace vecwise {
namespace engine {
class DB;
}
}
}
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
/**
 * Server-side implementation of the VecServiceIf interface.
 *
 * Methods that produce a value write it into their leading `_return`
 * parameter.
 */
class VecServiceHandler : virtual public VecServiceIf {
public:
    VecServiceHandler() {
        // Your initialization goes here
    }

    /**
     * Group management.
     *
     * @param group     group to create (add_group)
     * @param group_id  id of the group to look up or delete
     * @param _return   receives the group (get_group)
     */
    void add_group(const VecGroup& group);

    void get_group(VecGroup& _return,
                   const std::string& group_id);

    void del_group(const std::string& group_id);

    /**
     * Vector insertion, for single tensors and lists, in plain and binary
     * form.
     *
     * @param _return      receives one string per inserted vector
     *                     (presumably the vector id - confirm)
     * @param group_id     group to insert into
     * @param tensor       single vector to insert
     * @param tensor_list  vectors to insert in one call
     */
    void add_vector(std::string& _return,
                    const std::string& group_id,
                    const VecTensor& tensor);

    void add_vector_batch(std::vector<std::string> & _return,
                          const std::string& group_id,
                          const VecTensorList& tensor_list);

    void add_binary_vector(std::string& _return,
                           const std::string& group_id,
                           const VecBinaryTensor& tensor);

    void add_binary_vector_batch(std::vector<std::string> & _return,
                                 const std::string& group_id,
                                 const VecBinaryTensorList& tensor_list);

    /**
     * Vector search, for single tensors and lists, in plain and binary
     * form.
     *
     * The filter narrows the result:
     *  - filter.attrib_filter selects attributes, e.g.
     *      {"color":""}    returns the "color" attribute of each hit
     *      {"color":"red"} returns only vectors whose "color" is "red"
     *  - an empty filter.time_range means the search has no time limit
     *
     * @param _return      receives the search result(s)
     * @param group_id     group to search in
     * @param top_k        number of results requested per query vector
     *                     (presumably - confirm)
     * @param tensor       single query vector
     * @param tensor_list  query vectors for a batch search
     * @param filter       attribute / time-range filter described above
     */
    void search_vector(VecSearchResult& _return,
                       const std::string& group_id,
                       const int64_t top_k,
                       const VecTensor& tensor,
                       const VecSearchFilter& filter);

    void search_vector_batch(VecSearchResultList& _return,
                             const std::string& group_id,
                             const int64_t top_k,
                             const VecTensorList& tensor_list,
                             const VecSearchFilter& filter);

    void search_binary_vector(VecSearchResult& _return,
                              const std::string& group_id,
                              const int64_t top_k,
                              const VecBinaryTensor& tensor,
                              const VecSearchFilter& filter);

    void search_binary_vector_batch(VecSearchResultList& _return,
                                    const std::string& group_id,
                                    const int64_t top_k,
                                    const VecBinaryTensorList& tensor_list,
                                    const VecSearchFilter& filter);
};
}
}
}
此差异已折叠。
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#pragma once
#include "VecServiceScheduler.h"
#include "utils/Error.h"
#include "utils/AttributeSerializer.h"
#include "db/Types.h"
#include "thrift/gen-cpp/megasearch_types.h"
#include <condition_variable>
#include <memory>
namespace zilliz {
namespace vecwise {
namespace server {
using namespace megasearch;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Task that creates a group with a given id and dimension.
 *
 * The constructor is protected; callers obtain instances through Create().
 */
class AddGroupTask : public BaseTask {
public:
    /** Factory: returns a task for the given dimension and group id. */
    static BaseTaskPtr Create(int32_t dimension, const std::string& group_id);

protected:
    AddGroupTask(int32_t dimension, const std::string& group_id);

    /** Task body invoked through the BaseTask interface. */
    ServerError OnExecute() override;

private:
    int32_t dimension_;
    std::string group_id_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Task that looks up a group and reports its dimension.
 *
 * dimension_ is a reference member (presumably bound to the `dimension`
 * argument - confirm in the .cpp), so the caller's variable must outlive
 * the task.
 */
class GetGroupTask : public BaseTask {
public:
    /** Factory: returns a task for the given group id and output slot. */
    static BaseTaskPtr Create(const std::string& group_id,
                              int32_t& dimension);

protected:
    GetGroupTask(const std::string& group_id,
                 int32_t& dimension);

    /** Task body invoked through the BaseTask interface. */
    ServerError OnExecute() override;

private:
    std::string group_id_;
    int32_t& dimension_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Task that deletes the group identified by group_id.
 *
 * The constructor is protected; callers obtain instances through Create().
 */
class DeleteGroupTask : public BaseTask {
public:
    /** Factory: returns a task for the given group id. */
    static BaseTaskPtr Create(const std::string& group_id);

protected:
    // explicit: a lone std::string must never convert implicitly to a task.
    explicit DeleteGroupTask(const std::string& group_id);

    /** Task body invoked through the BaseTask interface. */
    ServerError OnExecute() override;

private:
    std::string group_id_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Task that inserts a single vector into a group.
 *
 * Two overload sets exist: one for VecTensor and one for VecBinaryTensor.
 * The task keeps raw pointers to the tensor and a reference to the output
 * id string, so all of them must outlive the task (presumably exactly one
 * of tensor_ / bin_tensor_ is non-null - confirm in the .cpp).
 */
class AddVectorTask : public BaseTask {
public:
    /** Factory for a plain tensor. */
    static BaseTaskPtr Create(const std::string& group_id, const VecTensor* tensor, std::string& id);

    /** Factory for a binary tensor. */
    static BaseTaskPtr Create(const std::string& group_id, const VecBinaryTensor* tensor, std::string& id);

protected:
    AddVectorTask(const std::string& group_id, const VecTensor* tensor, std::string& id);

    AddVectorTask(const std::string& group_id, const VecBinaryTensor* tensor, std::string& id);

    // Accessors for the stored tensor.
    uint64_t GetVecDimension() const;
    const double* GetVecData() const;
    std::string GetVecID() const;
    const AttribMap& GetVecAttrib() const;

    /** Task body invoked through the BaseTask interface. */
    ServerError OnExecute() override;

private:
    std::string group_id_;
    const VecTensor* tensor_;
    const VecBinaryTensor* bin_tensor_;
    std::string& tensor_id_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Task that inserts a list of vectors into a group.
 *
 * Two overload sets exist: one for VecTensorList and one for
 * VecBinaryTensorList. The task keeps raw pointers to the list and a
 * reference to the output id vector, so all of them must outlive the task
 * (presumably exactly one of tensor_list_ / bin_tensor_list_ is non-null -
 * confirm in the .cpp).
 */
class AddBatchVectorTask : public BaseTask {
public:
    /** Factory for a list of plain tensors. */
    static BaseTaskPtr Create(const std::string& group_id, const VecTensorList* tensor_list, std::vector<std::string>& ids);

    /** Factory for a list of binary tensors. */
    static BaseTaskPtr Create(const std::string& group_id, const VecBinaryTensorList* tensor_list, std::vector<std::string>& ids);

protected:
    AddBatchVectorTask(const std::string& group_id, const VecTensorList* tensor_list, std::vector<std::string>& ids);

    AddBatchVectorTask(const std::string& group_id, const VecBinaryTensorList* tensor_list, std::vector<std::string>& ids);

    // Accessors for the stored list; `index` selects one tensor.
    uint64_t GetVecListCount() const;
    uint64_t GetVecDimension(uint64_t index) const;
    const double* GetVecData(uint64_t index) const;
    std::string GetVecID(uint64_t index) const;
    const AttribMap& GetVecAttrib(uint64_t index) const;

    // NOTE(review): presumably maps engine ids in [from, to) onto
    // tensor_ids - confirm range semantics in the .cpp.
    void ProcessIdMapping(engine::IDNumbers& vector_ids, uint64_t from, uint64_t to, std::vector<std::string>& tensor_ids);

    /** Task body invoked through the BaseTask interface. */
    ServerError OnExecute() override;

private:
    std::string group_id_;
    const VecTensorList* tensor_list_;
    const VecBinaryTensorList* bin_tensor_list_;
    std::vector<std::string>& tensor_ids_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
/**
 * Task that searches a group for the top_k matches of each query vector.
 *
 * Two overload sets exist: one for VecTensorList and one for
 * VecBinaryTensorList. The task keeps raw pointers to the query list and
 * references to the filter and the result list, so all of them must
 * outlive the task (presumably exactly one of tensor_list_ /
 * bin_tensor_list_ is non-null - confirm in the .cpp).
 */
class SearchVectorTask : public BaseTask {
public:
    /** Factory for a list of plain query tensors. */
    static BaseTaskPtr Create(const std::string& group_id, const int64_t top_k, const VecTensorList* tensor_list,
                              const VecSearchFilter& filter, VecSearchResultList& result);

    /** Factory for a list of binary query tensors. */
    static BaseTaskPtr Create(const std::string& group_id, const int64_t top_k, const VecBinaryTensorList* bin_tensor_list,
                              const VecSearchFilter& filter, VecSearchResultList& result);

protected:
    SearchVectorTask(const std::string& group_id, const int64_t top_k, const VecTensorList* tensor_list,
                     const VecSearchFilter& filter, VecSearchResultList& result);

    SearchVectorTask(const std::string& group_id, const int64_t top_k, const VecBinaryTensorList* bin_tensor_list,
                     const VecSearchFilter& filter, VecSearchResultList& result);

    // NOTE(review): presumably writes the query vectors into `data` as
    // floats - confirm layout in the .cpp.
    ServerError GetTargetData(std::vector<float>& data) const;
    uint64_t GetTargetDimension() const;
    uint64_t GetTargetCount() const;

    /** Task body invoked through the BaseTask interface. */
    ServerError OnExecute() override;

private:
    std::string group_id_;
    int64_t top_k_;
    const VecTensorList* tensor_list_;
    const VecBinaryTensorList* bin_tensor_list_;
    const VecSearchFilter& filter_;
    VecSearchResultList& result_;
};
}
}
}
\ No newline at end of file
此差异已折叠。
此差异已折叠。
/**
* Autogenerated by Thrift Compiler (0.11.0)
* Autogenerated by Thrift Compiler (0.12.0)
*
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
* @generated
*/
#include "megasearch_constants.h"
namespace megasearch {
namespace megasearch { namespace thrift {
const megasearchConstants g_megasearch_constants;
megasearchConstants::megasearchConstants() {
}
} // namespace
}} // namespace
/**
* Autogenerated by Thrift Compiler (0.11.0)
* Autogenerated by Thrift Compiler (0.12.0)
*
* DO NOT EDIT UNLESS YOU ARE SURE THAT YOU KNOW WHAT YOU ARE DOING
* @generated
......@@ -9,7 +9,7 @@
#include "megasearch_types.h"
namespace megasearch {
namespace megasearch { namespace thrift {
class megasearchConstants {
public:
......@@ -19,6 +19,6 @@ class megasearchConstants {
extern const megasearchConstants g_megasearch_constants;
} // namespace
}} // namespace
#endif
import time
import struct
from megasearch import VecService
#Note: pip install thrift
from thrift.transport import TSocket
from thrift.transport import TTransport
from thrift.protocol import TBinaryProtocol, TCompactProtocol, TJSONProtocol
def test_megasearch():
    """Smoke-test a megasearch server listening on localhost:33001.

    Creates a 256-dimension group, inserts 10000 binary vectors, waits for
    them to be persisted, runs one top-5 binary search and prints the uid of
    every hit. A VecException raised by the service is reported by printing
    its reason. The transport is closed on every exit path.
    """
    transport = TTransport.TBufferedTransport(TSocket.TSocket('localhost', 33001))
    try:
        # connect
        protocol = TJSONProtocol.TJSONProtocol(transport)
        client = VecService.Client(protocol)
        transport.open()
        print("connected")

        # add group
        group = VecService.VecGroup("test_" + time.strftime('%H%M%S'), 256)
        client.add_group(group)
        print("group added")

        # every vector is packed as `dimension` native doubles
        pack_format = str(group.dimension) + "d"

        # build binary vectors
        bin_vec_list = VecService.VecBinaryTensorList([])
        for i in range(10000):
            values = [i + k for k in range(group.dimension)]
            bin_vec = VecService.VecBinaryTensor("binary_" + str(i), bytes())
            bin_vec.tensor = struct.pack(pack_format, *values)
            bin_vec_list.tensor_list.append(bin_vec)

        # add vectors
        client.add_binary_vector_batch(group.id, bin_vec_list)

        wait_storage = 5
        print("wait {} seconds for persisting data".format(wait_storage))
        time.sleep(wait_storage)

        # search vector
        values = [300 + k for k in range(group.dimension)]
        bin_vec = VecService.VecBinaryTensor("binary_search", bytes())
        bin_vec.tensor = struct.pack(pack_format, *values)
        search_filter = VecService.VecSearchFilter()

        print("begin search ...")
        res = client.search_binary_vector(group.id, 5, bin_vec, search_filter)
        print('result count: ' + str(len(res.result_list)))
        for item in res.result_list:
            print(item.uid)
    except VecService.VecException as ex:
        print(ex.reason)
    finally:
        # Release the socket even when a request failed part-way through.
        if transport.isOpen():
            transport.close()
            print("disconnected")
# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    test_megasearch()
\ No newline at end of file
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册