提交 32b5d7c4 编写于 作者: G groot

client call api


Former-commit-id: 1abe8528051e0fb01dc2922f3ac71c739ae6df57
上级 a7cefc6a
......@@ -3,6 +3,8 @@ server_config:
port: 33001
transfer_protocol: json #optional: binary, compact, json, debug
server_mode: thread_pool #optional: simple, thread_pool
db_backend_url: http://127.0.0.1
db_name: vecwise
log_config:
global:
......
......@@ -8,6 +8,7 @@ aux_source_directory(cache cache_files)
aux_source_directory(config config_files)
aux_source_directory(server server_files)
aux_source_directory(utils utils_files)
aux_source_directory(db db_files)
aux_source_directory(wrapper wrapper_files)
set(service_files
......@@ -18,6 +19,7 @@ set(service_files
set(vecwise_engine_src
${CMAKE_CURRENT_SOURCE_DIR}/main.cpp
${cache_files}
${db_files}
${wrapper_files})
include_directories(/usr/local/cuda/include)
......@@ -42,5 +44,6 @@ set(dependency_libs
thrift
pthread
faiss
vecwise_engine
)
target_link_libraries(vecwise_engine_server ${dependency_libs} ${cuda_library})
\ No newline at end of file
......@@ -19,6 +19,8 @@ static const std::string CONFIG_SERVER_ADDRESS = "address";
static const std::string CONFIG_SERVER_PORT = "port";
static const std::string CONFIG_SERVER_PROTOCOL = "transfer_protocol";
static const std::string CONFIG_SERVER_MODE = "server_mode";
static const std::string CONFIG_SERVER_DB_URL = "db_backend_url";
static const std::string CONFIG_SERVER_DB_NAME = "db_name";
static const std::string CONFIG_LOG = "log_config";
......
//
// Created by yhmo on 19-4-16.
//
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceHandler.h"
#include "ServerConfig.h"
#include "utils/Log.h"
#include "utils/CommonUtil.h"
#include "db/DB.h"
#include "db/Env.h"
namespace zilliz {
namespace vecwise {
namespace server {
namespace {
static engine::DB* DB() {
static engine::DB* s_db = nullptr;
if(s_db == nullptr) {
engine::Options opt;
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_SERVER);
opt.meta.backend_uri = config.GetValue(CONFIG_SERVER_DB_URL);
opt.meta.dbname = config.GetValue(CONFIG_SERVER_DB_NAME);
std::string db_path = "/tmp/test";
CommonUtil::CreateDirectory(db_path);
s_db = engine::DB::Open(opt, db_path);
}
return s_db;
}
}
VecServiceHandler::VecServiceHandler() {
}
void
......@@ -19,7 +44,16 @@ VecServiceHandler::add_group(const VecGroup &group) {
<< ", group.index_type = " << group.index_type;
try {
engine::GroupOptions gopt;
gopt.dimension = (size_t)group.dimension;
gopt.has_id = !group.id.empty();
engine::meta::GroupSchema group_info;
engine::Status stat = DB()->add_group(gopt, group.id, group_info);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
}
SERVER_LOG_INFO << "add_group() finished";
} catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what();
}
......@@ -31,7 +65,16 @@ VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) {
SERVER_LOG_TRACE << "group_id = " << group_id;
try {
engine::meta::GroupSchema group_info;
engine::Status stat = DB()->get_group(group_id, group_info);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
_return.id = group_info.group_id;
_return.dimension = (int32_t)group_info.dimension;
}
SERVER_LOG_INFO << "get_group() finished";
} catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what();
}
......@@ -44,19 +87,29 @@ VecServiceHandler::del_group(const std::string &group_id) {
try {
SERVER_LOG_INFO << "del_group() not implemented";
} catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what();
}
}
int64_t
VecServiceHandler::add_vector(const std::string &group_id, const VecTensor &tensor) {
void
VecServiceHandler::add_vector(VecTensorIdList& _return, const std::string &group_id, const VecTensor &tensor) {
SERVER_LOG_INFO << "add_vector() called";
SERVER_LOG_TRACE << "group_id = " << group_id << ", vector size = " << tensor.tensor.size();
try {
engine::IDNumbers vector_ids;
std::vector<float> vec_f(tensor.tensor.begin(), tensor.tensor.end());
engine::Status stat = DB()->add_vectors(group_id, 1, vec_f.data(), vector_ids);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
_return.id_list.swap(vector_ids);
}
SERVER_LOG_INFO << "add_vector() finished";
} catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what();
}
......@@ -71,7 +124,20 @@ VecServiceHandler::add_vector_batch(VecTensorIdList &_return,
<< tensor_list.tensor_list.size();
try {
std::vector<float> vec_f;
for(const VecTensor& tensor : tensor_list.tensor_list) {
vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
}
engine::IDNumbers vector_ids;
engine::Status stat = DB()->add_vectors(group_id, tensor_list.tensor_list.size(), vec_f.data(), vector_ids);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
_return.id_list.swap(vector_ids);
}
SERVER_LOG_INFO << "add_vector_batch() finished";
} catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what();
}
......@@ -90,7 +156,18 @@ VecServiceHandler::search_vector(VecSearchResult &_return,
<< ", time range list size = " << time_range_list.range_list.size();
try {
engine::QueryResults results;
std::vector<float> vec_f(tensor.tensor.begin(), tensor.tensor.end());
engine::Status stat = DB()->search(group_id, (size_t)top_k, 1, vec_f.data(), results);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
if(results.size() > 0) {
_return.id_list.swap(results[0]);
}
}
SERVER_LOG_INFO << "search_vector() finished";
} catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what();
}
......@@ -108,7 +185,24 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
<< ", time range list size = " << time_range_list.range_list.size();
try {
std::vector<float> vec_f;
for(const VecTensor& tensor : tensor_list.tensor_list) {
vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
}
engine::QueryResults results;
engine::Status stat = DB()->search(group_id, (size_t)top_k, tensor_list.tensor_list.size(), vec_f.data(), results);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
for(engine::QueryResult& res : results){
VecSearchResult v_res;
v_res.id_list.swap(res);
_return.result_list.push_back(v_res);
}
}
SERVER_LOG_INFO << "search_vector_batch() finished";
} catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what();
}
......
......@@ -4,7 +4,11 @@
* Proprietary and confidential.
******************************************************************************/
#pragma once
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "utils/Error.h"
#include "thrift/gen-cpp/VecService.h"
......@@ -37,7 +41,7 @@ public:
* @param group_id
* @param tensor
*/
int64_t add_vector(const std::string& group_id, const VecTensor& tensor);
void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor);
void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list);
......
......@@ -78,7 +78,7 @@ service VecService {
* vector interfaces
*
*/
i64 add_vector(2: string group_id, 3: VecTensor tensor) throws(1: VecException e);
VecTensorIdList add_vector(2: string group_id, 3: VecTensor tensor) throws(1: VecException e);
VecTensorIdList add_vector_batch(2: string group_id, 3: VecTensorList tensor_list) throws(1: VecException e);
/**
......
......@@ -711,8 +711,8 @@ uint32_t VecService_add_vector_result::read(::apache::thrift::protocol::TProtoco
switch (fid)
{
case 0:
if (ftype == ::apache::thrift::protocol::T_I64) {
xfer += iprot->readI64(this->success);
if (ftype == ::apache::thrift::protocol::T_STRUCT) {
xfer += this->success.read(iprot);
this->__isset.success = true;
} else {
xfer += iprot->skip(ftype);
......@@ -745,8 +745,8 @@ uint32_t VecService_add_vector_result::write(::apache::thrift::protocol::TProtoc
xfer += oprot->writeStructBegin("VecService_add_vector_result");
if (this->__isset.success) {
xfer += oprot->writeFieldBegin("success", ::apache::thrift::protocol::T_I64, 0);
xfer += oprot->writeI64(this->success);
xfer += oprot->writeFieldBegin("success", ::apache::thrift::protocol::T_STRUCT, 0);
xfer += this->success.write(oprot);
xfer += oprot->writeFieldEnd();
} else if (this->__isset.e) {
xfer += oprot->writeFieldBegin("e", ::apache::thrift::protocol::T_STRUCT, 1);
......@@ -785,8 +785,8 @@ uint32_t VecService_add_vector_presult::read(::apache::thrift::protocol::TProtoc
switch (fid)
{
case 0:
if (ftype == ::apache::thrift::protocol::T_I64) {
xfer += iprot->readI64((*(this->success)));
if (ftype == ::apache::thrift::protocol::T_STRUCT) {
xfer += (*(this->success)).read(iprot);
this->__isset.success = true;
} else {
xfer += iprot->skip(ftype);
......@@ -1718,10 +1718,10 @@ void VecServiceClient::recv_del_group()
return;
}
int64_t VecServiceClient::add_vector(const std::string& group_id, const VecTensor& tensor)
void VecServiceClient::add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor)
{
send_add_vector(group_id, tensor);
return recv_add_vector();
recv_add_vector(_return);
}
void VecServiceClient::send_add_vector(const std::string& group_id, const VecTensor& tensor)
......@@ -1739,7 +1739,7 @@ void VecServiceClient::send_add_vector(const std::string& group_id, const VecTen
oprot_->getTransport()->flush();
}
int64_t VecServiceClient::recv_add_vector()
void VecServiceClient::recv_add_vector(VecTensorIdList& _return)
{
int32_t rseqid = 0;
......@@ -1764,7 +1764,6 @@ int64_t VecServiceClient::recv_add_vector()
iprot_->readMessageEnd();
iprot_->getTransport()->readEnd();
}
int64_t _return;
VecService_add_vector_presult result;
result.success = &_return;
result.read(iprot_);
......@@ -1772,7 +1771,8 @@ int64_t VecServiceClient::recv_add_vector()
iprot_->getTransport()->readEnd();
if (result.__isset.success) {
return _return;
// _return pointer has now been filled
return;
}
if (result.__isset.e) {
throw result.e;
......@@ -2181,7 +2181,7 @@ void VecServiceProcessor::process_add_vector(int32_t seqid, ::apache::thrift::pr
VecService_add_vector_result result;
try {
result.success = iface_->add_vector(args.group_id, args.tensor);
iface_->add_vector(result.success, args.group_id, args.tensor);
result.__isset.success = true;
} catch (VecException &e) {
result.e = e;
......@@ -2645,10 +2645,10 @@ void VecServiceConcurrentClient::recv_del_group(const int32_t seqid)
} // end while(true)
}
int64_t VecServiceConcurrentClient::add_vector(const std::string& group_id, const VecTensor& tensor)
void VecServiceConcurrentClient::add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor)
{
int32_t seqid = send_add_vector(group_id, tensor);
return recv_add_vector(seqid);
recv_add_vector(_return, seqid);
}
int32_t VecServiceConcurrentClient::send_add_vector(const std::string& group_id, const VecTensor& tensor)
......@@ -2670,7 +2670,7 @@ int32_t VecServiceConcurrentClient::send_add_vector(const std::string& group_id,
return cseqid;
}
int64_t VecServiceConcurrentClient::recv_add_vector(const int32_t seqid)
void VecServiceConcurrentClient::recv_add_vector(VecTensorIdList& _return, const int32_t seqid)
{
int32_t rseqid = 0;
......@@ -2708,7 +2708,6 @@ int64_t VecServiceConcurrentClient::recv_add_vector(const int32_t seqid)
using ::apache::thrift::protocol::TProtocolException;
throw TProtocolException(TProtocolException::INVALID_DATA);
}
int64_t _return;
VecService_add_vector_presult result;
result.success = &_return;
result.read(iprot_);
......@@ -2716,8 +2715,9 @@ int64_t VecServiceConcurrentClient::recv_add_vector(const int32_t seqid)
iprot_->getTransport()->readEnd();
if (result.__isset.success) {
// _return pointer has now been filled
sentry.commit();
return _return;
return;
}
if (result.__isset.e) {
sentry.commit();
......
......@@ -38,7 +38,7 @@ class VecServiceIf {
* @param group_id
* @param tensor
*/
virtual int64_t add_vector(const std::string& group_id, const VecTensor& tensor) = 0;
virtual void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor) = 0;
virtual void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list) = 0;
/**
......@@ -90,9 +90,8 @@ class VecServiceNull : virtual public VecServiceIf {
void del_group(const std::string& /* group_id */) {
return;
}
int64_t add_vector(const std::string& /* group_id */, const VecTensor& /* tensor */) {
int64_t _return = 0;
return _return;
void add_vector(VecTensorIdList& /* _return */, const std::string& /* group_id */, const VecTensor& /* tensor */) {
return;
}
void add_vector_batch(VecTensorIdList& /* _return */, const std::string& /* group_id */, const VecTensorList& /* tensor_list */) {
return;
......@@ -492,16 +491,16 @@ class VecService_add_vector_result {
VecService_add_vector_result(const VecService_add_vector_result&);
VecService_add_vector_result& operator=(const VecService_add_vector_result&);
VecService_add_vector_result() : success(0) {
VecService_add_vector_result() {
}
virtual ~VecService_add_vector_result() throw();
int64_t success;
VecTensorIdList success;
VecException e;
_VecService_add_vector_result__isset __isset;
void __set_success(const int64_t val);
void __set_success(const VecTensorIdList& val);
void __set_e(const VecException& val);
......@@ -535,7 +534,7 @@ class VecService_add_vector_presult {
virtual ~VecService_add_vector_presult() throw();
int64_t* success;
VecTensorIdList* success;
VecException e;
_VecService_add_vector_presult__isset __isset;
......@@ -963,9 +962,9 @@ class VecServiceClient : virtual public VecServiceIf {
void del_group(const std::string& group_id);
void send_del_group(const std::string& group_id);
void recv_del_group();
int64_t add_vector(const std::string& group_id, const VecTensor& tensor);
void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor);
void send_add_vector(const std::string& group_id, const VecTensor& tensor);
int64_t recv_add_vector();
void recv_add_vector(VecTensorIdList& _return);
void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list);
void send_add_vector_batch(const std::string& group_id, const VecTensorList& tensor_list);
void recv_add_vector_batch(VecTensorIdList& _return);
......@@ -1063,13 +1062,14 @@ class VecServiceMultiface : virtual public VecServiceIf {
ifaces_[i]->del_group(group_id);
}
int64_t add_vector(const std::string& group_id, const VecTensor& tensor) {
void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor) {
size_t sz = ifaces_.size();
size_t i = 0;
for (; i < (sz - 1); ++i) {
ifaces_[i]->add_vector(group_id, tensor);
ifaces_[i]->add_vector(_return, group_id, tensor);
}
return ifaces_[i]->add_vector(group_id, tensor);
ifaces_[i]->add_vector(_return, group_id, tensor);
return;
}
void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list) {
......@@ -1141,9 +1141,9 @@ class VecServiceConcurrentClient : virtual public VecServiceIf {
void del_group(const std::string& group_id);
int32_t send_del_group(const std::string& group_id);
void recv_del_group(const int32_t seqid);
int64_t add_vector(const std::string& group_id, const VecTensor& tensor);
void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor);
int32_t send_add_vector(const std::string& group_id, const VecTensor& tensor);
int64_t recv_add_vector(const int32_t seqid);
void recv_add_vector(VecTensorIdList& _return, const int32_t seqid);
void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list);
int32_t send_add_vector_batch(const std::string& group_id, const VecTensorList& tensor_list);
void recv_add_vector_batch(VecTensorIdList& _return, const int32_t seqid);
......
......@@ -45,7 +45,7 @@ class VecServiceHandler : virtual public VecServiceIf {
* @param group_id
* @param tensor
*/
int64_t add_vector(const std::string& group_id, const VecTensor& tensor) {
void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor) {
// Your implementation goes here
printf("add_vector\n");
}
......
......@@ -68,13 +68,19 @@ void ClientApp::Run(const std::string &config_file) {
transport_ptr->open();
VecServiceClient client(protocol_ptr);
try {
const int32_t dim = 256;
VecGroup group;
group.id = "test_group";
group.dimension = 256;
group.dimension = dim;
group.index_type = 0;
client.add_group(group);
VecTensor tensor;
for(int32_t i = 0; i < dim; i++) {
tensor.tensor.push_back((double)i);
}
VecTensorIdList result;
client.add_vector(result, group.id, tensor);
} catch (apache::thrift::TException& ex) {
printf("%s", ex.what());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册