提交 f9967cd7 编写于 作者: G groot

modify thrift api


Former-commit-id: 744dec10d157280d9776a7507aaa428ca76aad66
上级 f7a2ac7b
......@@ -85,6 +85,12 @@ struct VecDateTime {
6: required i32 second;
}
/**
* time_begin; time range begin
* begine_closed; true means '[', false means '('
* time_end; set to true to return tensor double array
* end_closed; time range end
*/
struct VecTimeRange {
1: required VecDateTime time_begin;
2: required bool begine_closed;
......@@ -92,9 +98,15 @@ struct VecTimeRange {
4: required bool end_closed;
}
/**
* attrib_filter; search condition, for example: "color=red"
* time_ranges; search condition, for example: "date between 1999-02-12 and 2008-10-14"
* return_attribs; specify required attribute names
*/
struct VecSearchFilter {
1: optional map<string, string> attrib_filter;
2: optional list<VecTimeRange> time_ranges;
3: optional list<string> return_attribs;
}
service VecService {
......
......@@ -937,14 +937,14 @@ uint32_t VecService_add_vector_batch_result::read(::apache::thrift::protocol::TP
if (ftype == ::apache::thrift::protocol::T_LIST) {
{
this->success.clear();
uint32_t _size93;
::apache::thrift::protocol::TType _etype96;
xfer += iprot->readListBegin(_etype96, _size93);
this->success.resize(_size93);
uint32_t _i97;
for (_i97 = 0; _i97 < _size93; ++_i97)
uint32_t _size99;
::apache::thrift::protocol::TType _etype102;
xfer += iprot->readListBegin(_etype102, _size99);
this->success.resize(_size99);
uint32_t _i103;
for (_i103 = 0; _i103 < _size99; ++_i103)
{
xfer += iprot->readString(this->success[_i97]);
xfer += iprot->readString(this->success[_i103]);
}
xfer += iprot->readListEnd();
}
......@@ -983,10 +983,10 @@ uint32_t VecService_add_vector_batch_result::write(::apache::thrift::protocol::T
xfer += oprot->writeFieldBegin("success", ::apache::thrift::protocol::T_LIST, 0);
{
xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast<uint32_t>(this->success.size()));
std::vector<std::string> ::const_iterator _iter98;
for (_iter98 = this->success.begin(); _iter98 != this->success.end(); ++_iter98)
std::vector<std::string> ::const_iterator _iter104;
for (_iter104 = this->success.begin(); _iter104 != this->success.end(); ++_iter104)
{
xfer += oprot->writeString((*_iter98));
xfer += oprot->writeString((*_iter104));
}
xfer += oprot->writeListEnd();
}
......@@ -1031,14 +1031,14 @@ uint32_t VecService_add_vector_batch_presult::read(::apache::thrift::protocol::T
if (ftype == ::apache::thrift::protocol::T_LIST) {
{
(*(this->success)).clear();
uint32_t _size99;
::apache::thrift::protocol::TType _etype102;
xfer += iprot->readListBegin(_etype102, _size99);
(*(this->success)).resize(_size99);
uint32_t _i103;
for (_i103 = 0; _i103 < _size99; ++_i103)
uint32_t _size105;
::apache::thrift::protocol::TType _etype108;
xfer += iprot->readListBegin(_etype108, _size105);
(*(this->success)).resize(_size105);
uint32_t _i109;
for (_i109 = 0; _i109 < _size105; ++_i109)
{
xfer += iprot->readString((*(this->success))[_i103]);
xfer += iprot->readString((*(this->success))[_i109]);
}
xfer += iprot->readListEnd();
}
......@@ -1415,14 +1415,14 @@ uint32_t VecService_add_binary_vector_batch_result::read(::apache::thrift::proto
if (ftype == ::apache::thrift::protocol::T_LIST) {
{
this->success.clear();
uint32_t _size104;
::apache::thrift::protocol::TType _etype107;
xfer += iprot->readListBegin(_etype107, _size104);
this->success.resize(_size104);
uint32_t _i108;
for (_i108 = 0; _i108 < _size104; ++_i108)
uint32_t _size110;
::apache::thrift::protocol::TType _etype113;
xfer += iprot->readListBegin(_etype113, _size110);
this->success.resize(_size110);
uint32_t _i114;
for (_i114 = 0; _i114 < _size110; ++_i114)
{
xfer += iprot->readString(this->success[_i108]);
xfer += iprot->readString(this->success[_i114]);
}
xfer += iprot->readListEnd();
}
......@@ -1461,10 +1461,10 @@ uint32_t VecService_add_binary_vector_batch_result::write(::apache::thrift::prot
xfer += oprot->writeFieldBegin("success", ::apache::thrift::protocol::T_LIST, 0);
{
xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast<uint32_t>(this->success.size()));
std::vector<std::string> ::const_iterator _iter109;
for (_iter109 = this->success.begin(); _iter109 != this->success.end(); ++_iter109)
std::vector<std::string> ::const_iterator _iter115;
for (_iter115 = this->success.begin(); _iter115 != this->success.end(); ++_iter115)
{
xfer += oprot->writeString((*_iter109));
xfer += oprot->writeString((*_iter115));
}
xfer += oprot->writeListEnd();
}
......@@ -1509,14 +1509,14 @@ uint32_t VecService_add_binary_vector_batch_presult::read(::apache::thrift::prot
if (ftype == ::apache::thrift::protocol::T_LIST) {
{
(*(this->success)).clear();
uint32_t _size110;
::apache::thrift::protocol::TType _etype113;
xfer += iprot->readListBegin(_etype113, _size110);
(*(this->success)).resize(_size110);
uint32_t _i114;
for (_i114 = 0; _i114 < _size110; ++_i114)
uint32_t _size116;
::apache::thrift::protocol::TType _etype119;
xfer += iprot->readListBegin(_etype119, _size116);
(*(this->success)).resize(_size116);
uint32_t _i120;
for (_i120 = 0; _i120 < _size116; ++_i120)
{
xfer += iprot->readString((*(this->success))[_i114]);
xfer += iprot->readString((*(this->success))[_i120]);
}
xfer += iprot->readListEnd();
}
......
......@@ -1647,6 +1647,11 @@ void VecSearchFilter::__set_time_ranges(const std::vector<VecTimeRange> & val) {
this->time_ranges = val;
__isset.time_ranges = true;
}
void VecSearchFilter::__set_return_attribs(const std::vector<std::string> & val) {
this->return_attribs = val;
__isset.return_attribs = true;
}
std::ostream& operator<<(std::ostream& out, const VecSearchFilter& obj)
{
obj.printTo(out);
......@@ -1718,6 +1723,26 @@ uint32_t VecSearchFilter::read(::apache::thrift::protocol::TProtocol* iprot) {
xfer += iprot->skip(ftype);
}
break;
case 3:
if (ftype == ::apache::thrift::protocol::T_LIST) {
{
this->return_attribs.clear();
uint32_t _size89;
::apache::thrift::protocol::TType _etype92;
xfer += iprot->readListBegin(_etype92, _size89);
this->return_attribs.resize(_size89);
uint32_t _i93;
for (_i93 = 0; _i93 < _size89; ++_i93)
{
xfer += iprot->readString(this->return_attribs[_i93]);
}
xfer += iprot->readListEnd();
}
this->__isset.return_attribs = true;
} else {
xfer += iprot->skip(ftype);
}
break;
default:
xfer += iprot->skip(ftype);
break;
......@@ -1739,11 +1764,11 @@ uint32_t VecSearchFilter::write(::apache::thrift::protocol::TProtocol* oprot) co
xfer += oprot->writeFieldBegin("attrib_filter", ::apache::thrift::protocol::T_MAP, 1);
{
xfer += oprot->writeMapBegin(::apache::thrift::protocol::T_STRING, ::apache::thrift::protocol::T_STRING, static_cast<uint32_t>(this->attrib_filter.size()));
std::map<std::string, std::string> ::const_iterator _iter89;
for (_iter89 = this->attrib_filter.begin(); _iter89 != this->attrib_filter.end(); ++_iter89)
std::map<std::string, std::string> ::const_iterator _iter94;
for (_iter94 = this->attrib_filter.begin(); _iter94 != this->attrib_filter.end(); ++_iter94)
{
xfer += oprot->writeString(_iter89->first);
xfer += oprot->writeString(_iter89->second);
xfer += oprot->writeString(_iter94->first);
xfer += oprot->writeString(_iter94->second);
}
xfer += oprot->writeMapEnd();
}
......@@ -1753,10 +1778,23 @@ uint32_t VecSearchFilter::write(::apache::thrift::protocol::TProtocol* oprot) co
xfer += oprot->writeFieldBegin("time_ranges", ::apache::thrift::protocol::T_LIST, 2);
{
xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRUCT, static_cast<uint32_t>(this->time_ranges.size()));
std::vector<VecTimeRange> ::const_iterator _iter90;
for (_iter90 = this->time_ranges.begin(); _iter90 != this->time_ranges.end(); ++_iter90)
std::vector<VecTimeRange> ::const_iterator _iter95;
for (_iter95 = this->time_ranges.begin(); _iter95 != this->time_ranges.end(); ++_iter95)
{
xfer += (*_iter95).write(oprot);
}
xfer += oprot->writeListEnd();
}
xfer += oprot->writeFieldEnd();
}
if (this->__isset.return_attribs) {
xfer += oprot->writeFieldBegin("return_attribs", ::apache::thrift::protocol::T_LIST, 3);
{
xfer += oprot->writeListBegin(::apache::thrift::protocol::T_STRING, static_cast<uint32_t>(this->return_attribs.size()));
std::vector<std::string> ::const_iterator _iter96;
for (_iter96 = this->return_attribs.begin(); _iter96 != this->return_attribs.end(); ++_iter96)
{
xfer += (*_iter90).write(oprot);
xfer += oprot->writeString((*_iter96));
}
xfer += oprot->writeListEnd();
}
......@@ -1771,18 +1809,21 @@ void swap(VecSearchFilter &a, VecSearchFilter &b) {
using ::std::swap;
swap(a.attrib_filter, b.attrib_filter);
swap(a.time_ranges, b.time_ranges);
swap(a.return_attribs, b.return_attribs);
swap(a.__isset, b.__isset);
}
VecSearchFilter::VecSearchFilter(const VecSearchFilter& other91) {
attrib_filter = other91.attrib_filter;
time_ranges = other91.time_ranges;
__isset = other91.__isset;
VecSearchFilter::VecSearchFilter(const VecSearchFilter& other97) {
attrib_filter = other97.attrib_filter;
time_ranges = other97.time_ranges;
return_attribs = other97.return_attribs;
__isset = other97.__isset;
}
VecSearchFilter& VecSearchFilter::operator=(const VecSearchFilter& other92) {
attrib_filter = other92.attrib_filter;
time_ranges = other92.time_ranges;
__isset = other92.__isset;
VecSearchFilter& VecSearchFilter::operator=(const VecSearchFilter& other98) {
attrib_filter = other98.attrib_filter;
time_ranges = other98.time_ranges;
return_attribs = other98.return_attribs;
__isset = other98.__isset;
return *this;
}
void VecSearchFilter::printTo(std::ostream& out) const {
......@@ -1790,6 +1831,7 @@ void VecSearchFilter::printTo(std::ostream& out) const {
out << "VecSearchFilter(";
out << "attrib_filter="; (__isset.attrib_filter ? (out << to_string(attrib_filter)) : (out << "<null>"));
out << ", " << "time_ranges="; (__isset.time_ranges ? (out << to_string(time_ranges)) : (out << "<null>"));
out << ", " << "return_attribs="; (__isset.return_attribs ? (out << to_string(return_attribs)) : (out << "<null>"));
out << ")";
}
......
......@@ -597,9 +597,10 @@ void swap(VecTimeRange &a, VecTimeRange &b);
std::ostream& operator<<(std::ostream& out, const VecTimeRange& obj);
typedef struct _VecSearchFilter__isset {
_VecSearchFilter__isset() : attrib_filter(false), time_ranges(false) {}
_VecSearchFilter__isset() : attrib_filter(false), time_ranges(false), return_attribs(false) {}
bool attrib_filter :1;
bool time_ranges :1;
bool return_attribs :1;
} _VecSearchFilter__isset;
class VecSearchFilter : public virtual ::apache::thrift::TBase {
......@@ -613,6 +614,7 @@ class VecSearchFilter : public virtual ::apache::thrift::TBase {
virtual ~VecSearchFilter() throw();
std::map<std::string, std::string> attrib_filter;
std::vector<VecTimeRange> time_ranges;
std::vector<std::string> return_attribs;
_VecSearchFilter__isset __isset;
......@@ -620,6 +622,8 @@ class VecSearchFilter : public virtual ::apache::thrift::TBase {
void __set_time_ranges(const std::vector<VecTimeRange> & val);
void __set_return_attribs(const std::vector<std::string> & val);
bool operator == (const VecSearchFilter & rhs) const
{
if (__isset.attrib_filter != rhs.__isset.attrib_filter)
......@@ -630,6 +634,10 @@ class VecSearchFilter : public virtual ::apache::thrift::TBase {
return false;
else if (__isset.time_ranges && !(time_ranges == rhs.time_ranges))
return false;
if (__isset.return_attribs != rhs.__isset.return_attribs)
return false;
else if (__isset.return_attribs && !(return_attribs == rhs.return_attribs))
return false;
return true;
}
bool operator != (const VecSearchFilter &rhs) const {
......
......@@ -119,57 +119,6 @@ namespace {
}
}
//void ClientTest::LoopTest() {
// server::TimeRecorder rc("LoopTest");
//
// std::string address, protocol;
// int32_t port = 0;
// GetServerAddress(address, port, protocol);
// client::ClientSession session(address, port, protocol);
//
// rc.Record("connection");
//
// //add group
// VecGroup group;
// group.id = "loop_group";
// group.dimension = VEC_DIMENSION;
// group.index_type = 0;
// session.interface()->add_group(group);
// rc.Record("add group");
//
// const int64_t batch = 10000;
// for(int64_t i = 0; i < 1000; i++) {
// {
// VecBinaryTensorList bin_tensor_list;
// BuildVectors(i * batch, (i + 1) * batch, nullptr, &bin_tensor_list);
// rc.Record("build batch no." + std::to_string(i));
//
// std::vector<std::string> ids;
// session.interface()->add_binary_vector_batch(ids, group.id, bin_tensor_list);
// rc.Record("add batch no." + std::to_string(i));
// }
//
// sleep(1);
// rc.Record("sleep 1 second");
//
// VecTensor tensor;
// for (int32_t k = 0; k < VEC_DIMENSION; k++) {
// tensor.tensor.push_back((double) (k + i*666));
// }
//
// //do search
// VecSearchResult res;
// VecSearchFilter filter;
// session.interface()->search_vector(res, group.id, 10, tensor, filter);
// rc.Record("search finish");
//
// std::cout << "Search result: " << std::endl;
// for(VecSearchResultItem& item : res.result_list) {
// std::cout << "\t" << item.uid << std::endl;
// }
// }
//}
TEST(AddVector, CLIENT_TEST) {
try {
std::string address, protocol;
......@@ -301,23 +250,22 @@ TEST(SearchVector, CLIENT_TEST) {
ASSERT_TRUE(!res.result_list[0].uid.empty());
}
// //empty search
// date.day > 0 ? date.day -= 1 : date.day += 1;
// range.time_begin = date;
// range.time_end = date;
// time_ranges.clear();
// time_ranges.emplace_back(range);
// filter.__set_time_ranges(time_ranges);
// session.interface()->search_vector(res, GetGroupID(), top_k, tensor, filter);
//
// ASSERT_EQ(res.result_list.size(), 0);
//empty search
date.day > 0 ? date.day -= 1 : date.day += 1;
range.time_begin = date;
range.time_end = date;
time_ranges.clear();
time_ranges.emplace_back(range);
filter.__set_time_ranges(time_ranges);
session.interface()->search_vector(res, GetGroupID(), TOP_K, tensor, filter);
ASSERT_EQ(res.result_list.size(), 0);
}
//search binary vector
{
const int32_t anchor_index = BATCH_COUNT + 200;
const int32_t search_count = 10;
const int64_t top_k = 5;
server::TimeRecorder rc("Search binary batch top_k");
VecBinaryTensorList tensor_list;
for(int32_t k = anchor_index; k < anchor_index + search_count; k++) {
......@@ -333,7 +281,7 @@ TEST(SearchVector, CLIENT_TEST) {
VecSearchResultList res;
VecSearchFilter filter;
session.interface()->search_binary_vector_batch(res, GetGroupID(), top_k, tensor_list, filter);
session.interface()->search_binary_vector_batch(res, GetGroupID(), TOP_K, tensor_list, filter);
std::cout << "Search binary batch result: " << std::endl;
for(size_t i = 0 ; i < res.result_list.size(); i++) {
......@@ -350,7 +298,7 @@ TEST(SearchVector, CLIENT_TEST) {
ASSERT_EQ(res.result_list.size(), search_count);
for(size_t i = 0 ; i < res.result_list.size(); i++) {
ASSERT_EQ(res.result_list[i].result_list.size(), (uint64_t) top_k);
ASSERT_EQ(res.result_list[i].result_list.size(), (uint64_t) TOP_K);
ASSERT_TRUE(!res.result_list[i].result_list.empty());
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册