提交 32b5d7c4 编写于 作者: G groot

client call api


Former-commit-id: 1abe8528051e0fb01dc2922f3ac71c739ae6df57
上级 a7cefc6a
...@@ -3,6 +3,8 @@ server_config: ...@@ -3,6 +3,8 @@ server_config:
port: 33001 port: 33001
transfer_protocol: json #optional: binary, compact, json, debug transfer_protocol: json #optional: binary, compact, json, debug
server_mode: thread_pool #optional: simple, thread_pool server_mode: thread_pool #optional: simple, thread_pool
db_backend_url: http://127.0.0.1
db_name: vecwise
log_config: log_config:
global: global:
......
...@@ -8,6 +8,7 @@ aux_source_directory(cache cache_files) ...@@ -8,6 +8,7 @@ aux_source_directory(cache cache_files)
aux_source_directory(config config_files) aux_source_directory(config config_files)
aux_source_directory(server server_files) aux_source_directory(server server_files)
aux_source_directory(utils utils_files) aux_source_directory(utils utils_files)
aux_source_directory(db db_files)
aux_source_directory(wrapper wrapper_files) aux_source_directory(wrapper wrapper_files)
set(service_files set(service_files
...@@ -18,6 +19,7 @@ set(service_files ...@@ -18,6 +19,7 @@ set(service_files
set(vecwise_engine_src set(vecwise_engine_src
${CMAKE_CURRENT_SOURCE_DIR}/main.cpp ${CMAKE_CURRENT_SOURCE_DIR}/main.cpp
${cache_files} ${cache_files}
${db_files}
${wrapper_files}) ${wrapper_files})
include_directories(/usr/local/cuda/include) include_directories(/usr/local/cuda/include)
...@@ -42,5 +44,6 @@ set(dependency_libs ...@@ -42,5 +44,6 @@ set(dependency_libs
thrift thrift
pthread pthread
faiss faiss
vecwise_engine
) )
target_link_libraries(vecwise_engine_server ${dependency_libs} ${cuda_library}) target_link_libraries(vecwise_engine_server ${dependency_libs} ${cuda_library})
\ No newline at end of file
...@@ -19,6 +19,8 @@ static const std::string CONFIG_SERVER_ADDRESS = "address"; ...@@ -19,6 +19,8 @@ static const std::string CONFIG_SERVER_ADDRESS = "address";
static const std::string CONFIG_SERVER_PORT = "port"; static const std::string CONFIG_SERVER_PORT = "port";
static const std::string CONFIG_SERVER_PROTOCOL = "transfer_protocol"; static const std::string CONFIG_SERVER_PROTOCOL = "transfer_protocol";
static const std::string CONFIG_SERVER_MODE = "server_mode"; static const std::string CONFIG_SERVER_MODE = "server_mode";
static const std::string CONFIG_SERVER_DB_URL = "db_backend_url";
static const std::string CONFIG_SERVER_DB_NAME = "db_name";
static const std::string CONFIG_LOG = "log_config"; static const std::string CONFIG_LOG = "log_config";
......
// /*******************************************************************************
// Created by yhmo on 19-4-16. * Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// * Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "VecServiceHandler.h" #include "VecServiceHandler.h"
#include "ServerConfig.h"
#include "utils/Log.h" #include "utils/Log.h"
#include "utils/CommonUtil.h"
#include "db/DB.h"
#include "db/Env.h"
namespace zilliz { namespace zilliz {
namespace vecwise { namespace vecwise {
namespace server { namespace server {
namespace {
static engine::DB* DB() {
static engine::DB* s_db = nullptr;
if(s_db == nullptr) {
engine::Options opt;
ConfigNode& config = ServerConfig::GetInstance().GetConfig(CONFIG_SERVER);
opt.meta.backend_uri = config.GetValue(CONFIG_SERVER_DB_URL);
opt.meta.dbname = config.GetValue(CONFIG_SERVER_DB_NAME);
std::string db_path = "/tmp/test";
CommonUtil::CreateDirectory(db_path);
s_db = engine::DB::Open(opt, db_path);
}
return s_db;
}
}
VecServiceHandler::VecServiceHandler() { VecServiceHandler::VecServiceHandler() {
} }
void void
...@@ -19,7 +44,16 @@ VecServiceHandler::add_group(const VecGroup &group) { ...@@ -19,7 +44,16 @@ VecServiceHandler::add_group(const VecGroup &group) {
<< ", group.index_type = " << group.index_type; << ", group.index_type = " << group.index_type;
try { try {
engine::GroupOptions gopt;
gopt.dimension = (size_t)group.dimension;
gopt.has_id = !group.id.empty();
engine::meta::GroupSchema group_info;
engine::Status stat = DB()->add_group(gopt, group.id, group_info);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
}
SERVER_LOG_INFO << "add_group() finished";
} catch (std::exception& ex) { } catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what(); SERVER_LOG_ERROR << ex.what();
} }
...@@ -31,7 +65,16 @@ VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) { ...@@ -31,7 +65,16 @@ VecServiceHandler::get_group(VecGroup &_return, const std::string &group_id) {
SERVER_LOG_TRACE << "group_id = " << group_id; SERVER_LOG_TRACE << "group_id = " << group_id;
try { try {
engine::meta::GroupSchema group_info;
engine::Status stat = DB()->get_group(group_id, group_info);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
_return.id = group_info.group_id;
_return.dimension = (int32_t)group_info.dimension;
}
SERVER_LOG_INFO << "get_group() finished";
} catch (std::exception& ex) { } catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what(); SERVER_LOG_ERROR << ex.what();
} }
...@@ -44,19 +87,29 @@ VecServiceHandler::del_group(const std::string &group_id) { ...@@ -44,19 +87,29 @@ VecServiceHandler::del_group(const std::string &group_id) {
try { try {
SERVER_LOG_INFO << "del_group() not implemented";
} catch (std::exception& ex) { } catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what(); SERVER_LOG_ERROR << ex.what();
} }
} }
int64_t void
VecServiceHandler::add_vector(const std::string &group_id, const VecTensor &tensor) { VecServiceHandler::add_vector(VecTensorIdList& _return, const std::string &group_id, const VecTensor &tensor) {
SERVER_LOG_INFO << "add_vector() called"; SERVER_LOG_INFO << "add_vector() called";
SERVER_LOG_TRACE << "group_id = " << group_id << ", vector size = " << tensor.tensor.size(); SERVER_LOG_TRACE << "group_id = " << group_id << ", vector size = " << tensor.tensor.size();
try { try {
engine::IDNumbers vector_ids;
std::vector<float> vec_f(tensor.tensor.begin(), tensor.tensor.end());
engine::Status stat = DB()->add_vectors(group_id, 1, vec_f.data(), vector_ids);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
_return.id_list.swap(vector_ids);
}
SERVER_LOG_INFO << "add_vector() finished";
} catch (std::exception& ex) { } catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what(); SERVER_LOG_ERROR << ex.what();
} }
...@@ -71,7 +124,20 @@ VecServiceHandler::add_vector_batch(VecTensorIdList &_return, ...@@ -71,7 +124,20 @@ VecServiceHandler::add_vector_batch(VecTensorIdList &_return,
<< tensor_list.tensor_list.size(); << tensor_list.tensor_list.size();
try { try {
std::vector<float> vec_f;
for(const VecTensor& tensor : tensor_list.tensor_list) {
vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
}
engine::IDNumbers vector_ids;
engine::Status stat = DB()->add_vectors(group_id, tensor_list.tensor_list.size(), vec_f.data(), vector_ids);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
_return.id_list.swap(vector_ids);
}
SERVER_LOG_INFO << "add_vector_batch() finished";
} catch (std::exception& ex) { } catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what(); SERVER_LOG_ERROR << ex.what();
} }
...@@ -90,7 +156,18 @@ VecServiceHandler::search_vector(VecSearchResult &_return, ...@@ -90,7 +156,18 @@ VecServiceHandler::search_vector(VecSearchResult &_return,
<< ", time range list size = " << time_range_list.range_list.size(); << ", time range list size = " << time_range_list.range_list.size();
try { try {
engine::QueryResults results;
std::vector<float> vec_f(tensor.tensor.begin(), tensor.tensor.end());
engine::Status stat = DB()->search(group_id, (size_t)top_k, 1, vec_f.data(), results);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
if(results.size() > 0) {
_return.id_list.swap(results[0]);
}
}
SERVER_LOG_INFO << "search_vector() finished";
} catch (std::exception& ex) { } catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what(); SERVER_LOG_ERROR << ex.what();
} }
...@@ -108,7 +185,24 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return, ...@@ -108,7 +185,24 @@ VecServiceHandler::search_vector_batch(VecSearchResultList &_return,
<< ", time range list size = " << time_range_list.range_list.size(); << ", time range list size = " << time_range_list.range_list.size();
try { try {
std::vector<float> vec_f;
for(const VecTensor& tensor : tensor_list.tensor_list) {
vec_f.insert(vec_f.begin(), tensor.tensor.begin(), tensor.tensor.end());
}
engine::QueryResults results;
engine::Status stat = DB()->search(group_id, (size_t)top_k, tensor_list.tensor_list.size(), vec_f.data(), results);
if(!stat.ok()) {
SERVER_LOG_ERROR << "Engine failed: " << stat.ToString();
} else {
for(engine::QueryResult& res : results){
VecSearchResult v_res;
v_res.id_list.swap(res);
_return.result_list.push_back(v_res);
}
}
SERVER_LOG_INFO << "search_vector_batch() finished";
} catch (std::exception& ex) { } catch (std::exception& ex) {
SERVER_LOG_ERROR << ex.what(); SERVER_LOG_ERROR << ex.what();
} }
......
...@@ -4,7 +4,11 @@ ...@@ -4,7 +4,11 @@
* Proprietary and confidential. * Proprietary and confidential.
******************************************************************************/ ******************************************************************************/
#pragma once #pragma once
/*******************************************************************************
* Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
* Unauthorized copying of this file, via any medium is strictly prohibited.
* Proprietary and confidential.
******************************************************************************/
#include "utils/Error.h" #include "utils/Error.h"
#include "thrift/gen-cpp/VecService.h" #include "thrift/gen-cpp/VecService.h"
...@@ -37,7 +41,7 @@ public: ...@@ -37,7 +41,7 @@ public:
* @param group_id * @param group_id
* @param tensor * @param tensor
*/ */
int64_t add_vector(const std::string& group_id, const VecTensor& tensor); void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor);
void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list); void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list);
......
...@@ -78,7 +78,7 @@ service VecService { ...@@ -78,7 +78,7 @@ service VecService {
* vector interfaces * vector interfaces
* *
*/ */
i64 add_vector(2: string group_id, 3: VecTensor tensor) throws(1: VecException e); VecTensorIdList add_vector(2: string group_id, 3: VecTensor tensor) throws(1: VecException e);
VecTensorIdList add_vector_batch(2: string group_id, 3: VecTensorList tensor_list) throws(1: VecException e); VecTensorIdList add_vector_batch(2: string group_id, 3: VecTensorList tensor_list) throws(1: VecException e);
/** /**
......
...@@ -711,8 +711,8 @@ uint32_t VecService_add_vector_result::read(::apache::thrift::protocol::TProtoco ...@@ -711,8 +711,8 @@ uint32_t VecService_add_vector_result::read(::apache::thrift::protocol::TProtoco
switch (fid) switch (fid)
{ {
case 0: case 0:
if (ftype == ::apache::thrift::protocol::T_I64) { if (ftype == ::apache::thrift::protocol::T_STRUCT) {
xfer += iprot->readI64(this->success); xfer += this->success.read(iprot);
this->__isset.success = true; this->__isset.success = true;
} else { } else {
xfer += iprot->skip(ftype); xfer += iprot->skip(ftype);
...@@ -745,8 +745,8 @@ uint32_t VecService_add_vector_result::write(::apache::thrift::protocol::TProtoc ...@@ -745,8 +745,8 @@ uint32_t VecService_add_vector_result::write(::apache::thrift::protocol::TProtoc
xfer += oprot->writeStructBegin("VecService_add_vector_result"); xfer += oprot->writeStructBegin("VecService_add_vector_result");
if (this->__isset.success) { if (this->__isset.success) {
xfer += oprot->writeFieldBegin("success", ::apache::thrift::protocol::T_I64, 0); xfer += oprot->writeFieldBegin("success", ::apache::thrift::protocol::T_STRUCT, 0);
xfer += oprot->writeI64(this->success); xfer += this->success.write(oprot);
xfer += oprot->writeFieldEnd(); xfer += oprot->writeFieldEnd();
} else if (this->__isset.e) { } else if (this->__isset.e) {
xfer += oprot->writeFieldBegin("e", ::apache::thrift::protocol::T_STRUCT, 1); xfer += oprot->writeFieldBegin("e", ::apache::thrift::protocol::T_STRUCT, 1);
...@@ -785,8 +785,8 @@ uint32_t VecService_add_vector_presult::read(::apache::thrift::protocol::TProtoc ...@@ -785,8 +785,8 @@ uint32_t VecService_add_vector_presult::read(::apache::thrift::protocol::TProtoc
switch (fid) switch (fid)
{ {
case 0: case 0:
if (ftype == ::apache::thrift::protocol::T_I64) { if (ftype == ::apache::thrift::protocol::T_STRUCT) {
xfer += iprot->readI64((*(this->success))); xfer += (*(this->success)).read(iprot);
this->__isset.success = true; this->__isset.success = true;
} else { } else {
xfer += iprot->skip(ftype); xfer += iprot->skip(ftype);
...@@ -1718,10 +1718,10 @@ void VecServiceClient::recv_del_group() ...@@ -1718,10 +1718,10 @@ void VecServiceClient::recv_del_group()
return; return;
} }
int64_t VecServiceClient::add_vector(const std::string& group_id, const VecTensor& tensor) void VecServiceClient::add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor)
{ {
send_add_vector(group_id, tensor); send_add_vector(group_id, tensor);
return recv_add_vector(); recv_add_vector(_return);
} }
void VecServiceClient::send_add_vector(const std::string& group_id, const VecTensor& tensor) void VecServiceClient::send_add_vector(const std::string& group_id, const VecTensor& tensor)
...@@ -1739,7 +1739,7 @@ void VecServiceClient::send_add_vector(const std::string& group_id, const VecTen ...@@ -1739,7 +1739,7 @@ void VecServiceClient::send_add_vector(const std::string& group_id, const VecTen
oprot_->getTransport()->flush(); oprot_->getTransport()->flush();
} }
int64_t VecServiceClient::recv_add_vector() void VecServiceClient::recv_add_vector(VecTensorIdList& _return)
{ {
int32_t rseqid = 0; int32_t rseqid = 0;
...@@ -1764,7 +1764,6 @@ int64_t VecServiceClient::recv_add_vector() ...@@ -1764,7 +1764,6 @@ int64_t VecServiceClient::recv_add_vector()
iprot_->readMessageEnd(); iprot_->readMessageEnd();
iprot_->getTransport()->readEnd(); iprot_->getTransport()->readEnd();
} }
int64_t _return;
VecService_add_vector_presult result; VecService_add_vector_presult result;
result.success = &_return; result.success = &_return;
result.read(iprot_); result.read(iprot_);
...@@ -1772,7 +1771,8 @@ int64_t VecServiceClient::recv_add_vector() ...@@ -1772,7 +1771,8 @@ int64_t VecServiceClient::recv_add_vector()
iprot_->getTransport()->readEnd(); iprot_->getTransport()->readEnd();
if (result.__isset.success) { if (result.__isset.success) {
return _return; // _return pointer has now been filled
return;
} }
if (result.__isset.e) { if (result.__isset.e) {
throw result.e; throw result.e;
...@@ -2181,7 +2181,7 @@ void VecServiceProcessor::process_add_vector(int32_t seqid, ::apache::thrift::pr ...@@ -2181,7 +2181,7 @@ void VecServiceProcessor::process_add_vector(int32_t seqid, ::apache::thrift::pr
VecService_add_vector_result result; VecService_add_vector_result result;
try { try {
result.success = iface_->add_vector(args.group_id, args.tensor); iface_->add_vector(result.success, args.group_id, args.tensor);
result.__isset.success = true; result.__isset.success = true;
} catch (VecException &e) { } catch (VecException &e) {
result.e = e; result.e = e;
...@@ -2645,10 +2645,10 @@ void VecServiceConcurrentClient::recv_del_group(const int32_t seqid) ...@@ -2645,10 +2645,10 @@ void VecServiceConcurrentClient::recv_del_group(const int32_t seqid)
} // end while(true) } // end while(true)
} }
int64_t VecServiceConcurrentClient::add_vector(const std::string& group_id, const VecTensor& tensor) void VecServiceConcurrentClient::add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor)
{ {
int32_t seqid = send_add_vector(group_id, tensor); int32_t seqid = send_add_vector(group_id, tensor);
return recv_add_vector(seqid); recv_add_vector(_return, seqid);
} }
int32_t VecServiceConcurrentClient::send_add_vector(const std::string& group_id, const VecTensor& tensor) int32_t VecServiceConcurrentClient::send_add_vector(const std::string& group_id, const VecTensor& tensor)
...@@ -2670,7 +2670,7 @@ int32_t VecServiceConcurrentClient::send_add_vector(const std::string& group_id, ...@@ -2670,7 +2670,7 @@ int32_t VecServiceConcurrentClient::send_add_vector(const std::string& group_id,
return cseqid; return cseqid;
} }
int64_t VecServiceConcurrentClient::recv_add_vector(const int32_t seqid) void VecServiceConcurrentClient::recv_add_vector(VecTensorIdList& _return, const int32_t seqid)
{ {
int32_t rseqid = 0; int32_t rseqid = 0;
...@@ -2708,7 +2708,6 @@ int64_t VecServiceConcurrentClient::recv_add_vector(const int32_t seqid) ...@@ -2708,7 +2708,6 @@ int64_t VecServiceConcurrentClient::recv_add_vector(const int32_t seqid)
using ::apache::thrift::protocol::TProtocolException; using ::apache::thrift::protocol::TProtocolException;
throw TProtocolException(TProtocolException::INVALID_DATA); throw TProtocolException(TProtocolException::INVALID_DATA);
} }
int64_t _return;
VecService_add_vector_presult result; VecService_add_vector_presult result;
result.success = &_return; result.success = &_return;
result.read(iprot_); result.read(iprot_);
...@@ -2716,8 +2715,9 @@ int64_t VecServiceConcurrentClient::recv_add_vector(const int32_t seqid) ...@@ -2716,8 +2715,9 @@ int64_t VecServiceConcurrentClient::recv_add_vector(const int32_t seqid)
iprot_->getTransport()->readEnd(); iprot_->getTransport()->readEnd();
if (result.__isset.success) { if (result.__isset.success) {
// _return pointer has now been filled
sentry.commit(); sentry.commit();
return _return; return;
} }
if (result.__isset.e) { if (result.__isset.e) {
sentry.commit(); sentry.commit();
......
...@@ -38,7 +38,7 @@ class VecServiceIf { ...@@ -38,7 +38,7 @@ class VecServiceIf {
* @param group_id * @param group_id
* @param tensor * @param tensor
*/ */
virtual int64_t add_vector(const std::string& group_id, const VecTensor& tensor) = 0; virtual void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor) = 0;
virtual void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list) = 0; virtual void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list) = 0;
/** /**
...@@ -90,9 +90,8 @@ class VecServiceNull : virtual public VecServiceIf { ...@@ -90,9 +90,8 @@ class VecServiceNull : virtual public VecServiceIf {
void del_group(const std::string& /* group_id */) { void del_group(const std::string& /* group_id */) {
return; return;
} }
int64_t add_vector(const std::string& /* group_id */, const VecTensor& /* tensor */) { void add_vector(VecTensorIdList& /* _return */, const std::string& /* group_id */, const VecTensor& /* tensor */) {
int64_t _return = 0; return;
return _return;
} }
void add_vector_batch(VecTensorIdList& /* _return */, const std::string& /* group_id */, const VecTensorList& /* tensor_list */) { void add_vector_batch(VecTensorIdList& /* _return */, const std::string& /* group_id */, const VecTensorList& /* tensor_list */) {
return; return;
...@@ -492,16 +491,16 @@ class VecService_add_vector_result { ...@@ -492,16 +491,16 @@ class VecService_add_vector_result {
VecService_add_vector_result(const VecService_add_vector_result&); VecService_add_vector_result(const VecService_add_vector_result&);
VecService_add_vector_result& operator=(const VecService_add_vector_result&); VecService_add_vector_result& operator=(const VecService_add_vector_result&);
VecService_add_vector_result() : success(0) { VecService_add_vector_result() {
} }
virtual ~VecService_add_vector_result() throw(); virtual ~VecService_add_vector_result() throw();
int64_t success; VecTensorIdList success;
VecException e; VecException e;
_VecService_add_vector_result__isset __isset; _VecService_add_vector_result__isset __isset;
void __set_success(const int64_t val); void __set_success(const VecTensorIdList& val);
void __set_e(const VecException& val); void __set_e(const VecException& val);
...@@ -535,7 +534,7 @@ class VecService_add_vector_presult { ...@@ -535,7 +534,7 @@ class VecService_add_vector_presult {
virtual ~VecService_add_vector_presult() throw(); virtual ~VecService_add_vector_presult() throw();
int64_t* success; VecTensorIdList* success;
VecException e; VecException e;
_VecService_add_vector_presult__isset __isset; _VecService_add_vector_presult__isset __isset;
...@@ -963,9 +962,9 @@ class VecServiceClient : virtual public VecServiceIf { ...@@ -963,9 +962,9 @@ class VecServiceClient : virtual public VecServiceIf {
void del_group(const std::string& group_id); void del_group(const std::string& group_id);
void send_del_group(const std::string& group_id); void send_del_group(const std::string& group_id);
void recv_del_group(); void recv_del_group();
int64_t add_vector(const std::string& group_id, const VecTensor& tensor); void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor);
void send_add_vector(const std::string& group_id, const VecTensor& tensor); void send_add_vector(const std::string& group_id, const VecTensor& tensor);
int64_t recv_add_vector(); void recv_add_vector(VecTensorIdList& _return);
void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list); void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list);
void send_add_vector_batch(const std::string& group_id, const VecTensorList& tensor_list); void send_add_vector_batch(const std::string& group_id, const VecTensorList& tensor_list);
void recv_add_vector_batch(VecTensorIdList& _return); void recv_add_vector_batch(VecTensorIdList& _return);
...@@ -1063,13 +1062,14 @@ class VecServiceMultiface : virtual public VecServiceIf { ...@@ -1063,13 +1062,14 @@ class VecServiceMultiface : virtual public VecServiceIf {
ifaces_[i]->del_group(group_id); ifaces_[i]->del_group(group_id);
} }
int64_t add_vector(const std::string& group_id, const VecTensor& tensor) { void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor) {
size_t sz = ifaces_.size(); size_t sz = ifaces_.size();
size_t i = 0; size_t i = 0;
for (; i < (sz - 1); ++i) { for (; i < (sz - 1); ++i) {
ifaces_[i]->add_vector(group_id, tensor); ifaces_[i]->add_vector(_return, group_id, tensor);
} }
return ifaces_[i]->add_vector(group_id, tensor); ifaces_[i]->add_vector(_return, group_id, tensor);
return;
} }
void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list) { void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list) {
...@@ -1141,9 +1141,9 @@ class VecServiceConcurrentClient : virtual public VecServiceIf { ...@@ -1141,9 +1141,9 @@ class VecServiceConcurrentClient : virtual public VecServiceIf {
void del_group(const std::string& group_id); void del_group(const std::string& group_id);
int32_t send_del_group(const std::string& group_id); int32_t send_del_group(const std::string& group_id);
void recv_del_group(const int32_t seqid); void recv_del_group(const int32_t seqid);
int64_t add_vector(const std::string& group_id, const VecTensor& tensor); void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor);
int32_t send_add_vector(const std::string& group_id, const VecTensor& tensor); int32_t send_add_vector(const std::string& group_id, const VecTensor& tensor);
int64_t recv_add_vector(const int32_t seqid); void recv_add_vector(VecTensorIdList& _return, const int32_t seqid);
void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list); void add_vector_batch(VecTensorIdList& _return, const std::string& group_id, const VecTensorList& tensor_list);
int32_t send_add_vector_batch(const std::string& group_id, const VecTensorList& tensor_list); int32_t send_add_vector_batch(const std::string& group_id, const VecTensorList& tensor_list);
void recv_add_vector_batch(VecTensorIdList& _return, const int32_t seqid); void recv_add_vector_batch(VecTensorIdList& _return, const int32_t seqid);
......
...@@ -45,7 +45,7 @@ class VecServiceHandler : virtual public VecServiceIf { ...@@ -45,7 +45,7 @@ class VecServiceHandler : virtual public VecServiceIf {
* @param group_id * @param group_id
* @param tensor * @param tensor
*/ */
int64_t add_vector(const std::string& group_id, const VecTensor& tensor) { void add_vector(VecTensorIdList& _return, const std::string& group_id, const VecTensor& tensor) {
// Your implementation goes here // Your implementation goes here
printf("add_vector\n"); printf("add_vector\n");
} }
......
...@@ -68,13 +68,19 @@ void ClientApp::Run(const std::string &config_file) { ...@@ -68,13 +68,19 @@ void ClientApp::Run(const std::string &config_file) {
transport_ptr->open(); transport_ptr->open();
VecServiceClient client(protocol_ptr); VecServiceClient client(protocol_ptr);
try { try {
const int32_t dim = 256;
VecGroup group; VecGroup group;
group.id = "test_group"; group.id = "test_group";
group.dimension = 256; group.dimension = dim;
group.index_type = 0; group.index_type = 0;
client.add_group(group); client.add_group(group);
VecTensor tensor;
for(int32_t i = 0; i < dim; i++) {
tensor.tensor.push_back((double)i);
}
VecTensorIdList result;
client.add_vector(result, group.id, tensor);
} catch (apache::thrift::TException& ex) { } catch (apache::thrift::TException& ex) {
printf("%s", ex.what()); printf("%s", ex.what());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册