diff --git a/cpp/src/server/VecServiceHandler.h b/cpp/src/server/VecServiceHandler.h index a43aade6f9627ec06a91b935eb4cc38d6d77d5a2..b1bd4e1ad4b84967c11f1eebe5995b2ff091dfd4 100644 --- a/cpp/src/server/VecServiceHandler.h +++ b/cpp/src/server/VecServiceHandler.h @@ -55,15 +55,18 @@ public: void add_binary_vector_batch(const std::string& group_id, const VecBinaryTensorList& tensor_list); /** - * search interfaces - * if time_range_list is empty, engine will search without time limit - * - * - * @param group_id - * @param top_k - * @param tensor - * @param filter - */ + * Search interfaces. + * Use filter to narrow down the search result. + * filter.attrib_filter specifies which attributes you need, for example: + * attrib_filter = {"color":""} returns the "color" attribute with each result vector + * attrib_filter = {"color":"red"} returns only vectors whose "color" attribute equals "red" + * If filter.time_range is empty, the engine searches without a time limit. + * + * @param group_id group to search in + * @param top_k presumably the maximum number of results to return (TODO confirm) + * @param tensor query vector + * @param filter attribute / time-range filter as described above + */ void search_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecTensor& tensor, const VecSearchFilter& filter); void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecSearchFilter& filter); @@ -71,7 +74,7 @@ public: void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecSearchFilter& filter); void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecSearchFilter& filter); - + }; diff --git a/cpp/src/server/VecServiceTask.cpp b/cpp/src/server/VecServiceTask.cpp index f0c268acb7fe32f2dde9103beda86ef2d7666438..5af75723d01a48d1de7ac55e5d1626a13a33d06e 100644 --- 
a/cpp/src/server/VecServiceTask.cpp +++ b/cpp/src/server/VecServiceTask.cpp @@ -19,6 +19,8 @@ namespace server { static const std::string DQL_TASK_GROUP = "dql"; static const std::string DDL_DML_TASK_GROUP = "ddl_dml"; +static const std::string VECTOR_UID = "uid"; /* attribute key under which a vector's uid is stored alongside its other attributes */ + namespace { class DBWrapper { public: @@ -201,6 +203,14 @@ std::string AddVectorTask::GetVecID() const { } } +const AttribMap& AddVectorTask::GetVecAttrib() const { /* attributes of tensor_ when it is set, otherwise of bin_tensor_ (bin_tensor_ is not null-checked) */ + if(tensor_) { + return tensor_->attrib; + } else { + return bin_tensor_->attrib; + } +} + ServerError AddVectorTask::OnExecute() { try { engine::meta::GroupSchema group_info; @@ -238,8 +248,12 @@ ServerError AddVectorTask::OnExecute() { } else { std::string uid = GetVecID(); std::string nid = group_id_ + "_" + std::to_string(vector_ids[0]); - IVecIdMapper::GetInstance()->Put(nid, uid); - SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", sid = " << uid; + AttribMap attrib = GetVecAttrib(); /* copy, because the uid entry is added below */ + attrib[VECTOR_UID] = uid; /* the uid travels inside the attribute map; an existing "uid" entry is overwritten */ + std::string attrib_str; + AttributeSerializer::Encode(attrib, attrib_str); /* NOTE(review): the returned ServerError is not checked */ + IVecIdMapper::GetInstance()->Put(nid, attrib_str); /* nid now maps to the encoded attribute string instead of the bare uid */ + SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", uid = " << uid; } } @@ -339,6 +353,14 @@ std::string AddBatchVectorTask::GetVecID(uint64_t index) const { } } +const AttribMap& AddBatchVectorTask::GetVecAttrib(uint64_t index) const { /* attributes of element index in tensor_list_ when it is set, otherwise in bin_tensor_list_; index is not range-checked */ + if(tensor_list_) { + return tensor_list_->tensor_list[index].attrib; + } else { + return bin_tensor_list_->tensor_list[index].attrib; + } +} + ServerError AddBatchVectorTask::OnExecute() { try { TimeRecorder rc("AddBatchVectorTask"); @@ -387,7 +409,11 @@ ServerError AddBatchVectorTask::OnExecute() { for(size_t i = 0; i < vec_count; i++) { std::string uid = GetVecID(i); std::string nid = nid_prefix + std::to_string(vector_ids[i]); - IVecIdMapper::GetInstance()->Put(nid, uid); + AttribMap attrib = GetVecAttrib(i); /* copy, because the uid entry is added below */ + attrib[VECTOR_UID] = uid; /* an existing "uid" entry is overwritten */ + std::string attrib_str; + AttributeSerializer::Encode(attrib, attrib_str); /* NOTE(review): the returned ServerError is not checked */ + IVecIdMapper::GetInstance()->Put(nid, 
attrib_str); /* nid maps to the encoded attribute string */ } rc.Record("build id mapping"); } @@ -543,16 +569,20 @@ ServerError SearchVectorTask::OnExecute() { VecSearchResult v_res; std::string nid_prefix = group_id_ + "_"; for(auto id : res) { - std::string sid; + std::string attrib_str; std::string nid = nid_prefix + std::to_string(id); - IVecIdMapper::GetInstance()->Get(nid, sid); + IVecIdMapper::GetInstance()->Get(nid, attrib_str); /* the mapper value is the encoded attribute string written by the add tasks */ + + AttribMap attrib_map; + AttributeSerializer::Decode(attrib_str, attrib_map); /* NOTE(review): the returned ServerError is not checked */ + VecSearchResultItem item; - item.uid = sid; + item.__set_attrib(attrib_map); + item.uid = item.attrib[VECTOR_UID]; /* uid is read back from the decoded attributes; the "uid" entry stays in item.attrib */ item.distance = 0.0;////TODO: return distance v_res.result_list.emplace_back(item); - SERVER_LOG_TRACE << "nid = " << nid << ", string id = " << sid; - + SERVER_LOG_TRACE << "nid = " << nid << ", uid = " << item.uid; } result_.result_list.push_back(v_res); diff --git a/cpp/src/server/VecServiceTask.h b/cpp/src/server/VecServiceTask.h index ba5d8dbc376b0ed2eb549ce523bfc3f0ac4c4cc7..14afa3d026bd608550ff5158d5b84f9c868e2ecd 100644 --- a/cpp/src/server/VecServiceTask.h +++ b/cpp/src/server/VecServiceTask.h @@ -7,6 +7,7 @@ #include "VecServiceScheduler.h" #include "utils/Error.h" +#include "utils/AttributeSerializer.h" #include "db/Types.h" #include "thrift/gen-cpp/VectorService_types.h" @@ -85,6 +86,7 @@ protected: uint64_t GetVecDimension() const; const double* GetVecData() const; std::string GetVecID() const; + const AttribMap& GetVecAttrib() const; /* attribute map of the tensor held by this task */ ServerError OnExecute() override; @@ -115,6 +117,7 @@ protected: uint64_t GetVecDimension(uint64_t index) const; const double* GetVecData(uint64_t index) const; std::string GetVecID(uint64_t index) const; + const AttribMap& GetVecAttrib(uint64_t index) const; /* attribute map of the tensor at position index */ ServerError OnExecute() override; diff --git a/cpp/src/utils/AttributeSerializer.cpp b/cpp/src/utils/AttributeSerializer.cpp index d232caa9a01cd99f18a1f2aaa5135647d370c1ab..2c6fe3220b9ce4ae0cd658ed9d3d939e1a7b9ef3 100644 --- a/cpp/src/utils/AttributeSerializer.cpp +++ 
b/cpp/src/utils/AttributeSerializer.cpp @@ -5,17 +5,44 @@ ******************************************************************************/ #include "AttributeSerializer.h" +#include "StringHelpFunctions.h" namespace zilliz { namespace vecwise { namespace server { -void AttributeSerializer::Encode(const std::map& attrib, std::string& result) { +ServerError AttributeSerializer::Encode(const AttribMap& attrib_map, std::string& attrib_str) { + attrib_str = ""; + for(auto iter : attrib_map) { + attrib_str += iter.first; + attrib_str += ":\""; + attrib_str += iter.second; + attrib_str += "\";"; + } + + return SERVER_SUCCESS; } -void AttributeSerializer::Decode(const std::string& str, std::map& result) { +ServerError AttributeSerializer::Decode(const std::string& attrib_str, AttribMap& attrib_map) { + attrib_map.clear(); + + std::vector kv_pairs; + StringHelpFunctions::SplitStringByQuote(attrib_str, ";", "\"", kv_pairs); + for(std::string& str : kv_pairs) { + std::string key, val; + size_t index = str.find_first_of(":", 0); + if (index != std::string::npos) { + key = str.substr(0, index); + val = str.substr(index + 1); + } else { + key = str; + } + + attrib_map.insert(std::make_pair(key, val)); + } + return SERVER_SUCCESS; } } diff --git a/cpp/src/utils/AttributeSerializer.h b/cpp/src/utils/AttributeSerializer.h index 500b171145ac8d8d33619e447b2add2297ded403..0157d69e88d1464e400d27d857b9c0e20d37926b 100644 --- a/cpp/src/utils/AttributeSerializer.h +++ b/cpp/src/utils/AttributeSerializer.h @@ -7,14 +7,18 @@ #include +#include "Error.h" + namespace zilliz { namespace vecwise { namespace server { +using AttribMap = std::map; + class AttributeSerializer { public: - static void Encode(const std::map& attrib, std::string& result); - static void Decode(const std::string& str, std::map& result); + static ServerError Encode(const AttribMap& attrib_map, std::string& attrib_str); + static ServerError Decode(const std::string& attrib_str, AttribMap& attrib_map); }; diff --git 
a/cpp/test_client/src/ClientTest.cpp b/cpp/test_client/src/ClientTest.cpp index 2e38ce6c153a565bc8661ef9ef96e968e315b21d..7ef62a068012c70130176388dff4acdfd95c3f7e 100644 --- a/cpp/test_client/src/ClientTest.cpp +++ b/cpp/test_client/src/ClientTest.cpp @@ -4,7 +4,8 @@ // Proprietary and confidential. //////////////////////////////////////////////////////////////////////////////// #include -#include +#include "utils/TimeRecorder.h" +#include "utils/AttributeSerializer.h" #include "ClientSession.h" #include "server/ServerConfig.h" #include "Log.h" @@ -16,6 +17,9 @@ using namespace zilliz::vecwise; namespace { static const int32_t VEC_DIMENSION = 256; + static const std::string TEST_ATTRIB_NUM = "number"; /* attribute keys attached to every vector built by AddVector and checked by SearchVector */ + static const std::string TEST_ATTRIB_COMMENT = "comment"; + std::string CurrentTime() { time_t tt; time( &tt ); @@ -69,27 +73,39 @@ TEST(AddVector, CLIENT_TEST) { const int64_t count = 100000; VecTensorList tensor_list; VecBinaryTensorList bin_tensor_list; - for (int64_t k = 0; k < count; k++) { - VecTensor tensor; - tensor.tensor.reserve(VEC_DIMENSION); - VecBinaryTensor bin_tensor; - bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double)); - double *d_p = (double *) (const_cast(bin_tensor.tensor.data())); - for (int32_t i = 0; i < VEC_DIMENSION; i++) { - double val = (double) (i + k); - tensor.tensor.push_back(val); - d_p[i] = val; - } + { /* scope limits rc to the vector-building phase */ + server::TimeRecorder rc(std::to_string(count) + " vectors built"); + for (int64_t k = 0; k < count; k++) { + VecTensor tensor; + tensor.tensor.reserve(VEC_DIMENSION); + VecBinaryTensor bin_tensor; + bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double)); + double *d_p = (double *) (const_cast(bin_tensor.tensor.data())); /* write the doubles straight into the binary tensor's byte buffer */ + for (int32_t i = 0; i < VEC_DIMENSION; i++) { + double val = (double) (i + k); + tensor.tensor.push_back(val); + d_p[i] = val; + } + + server::AttribMap attrib_map; + attrib_map[TEST_ATTRIB_NUM] = "No." 
+ std::to_string(k); - tensor.uid = "normal_vec_" + std::to_string(k); - tensor_list.tensor_list.emplace_back(tensor); + tensor.uid = "normal_vec_" + std::to_string(k); + attrib_map[TEST_ATTRIB_COMMENT] = tensor.uid; /* the comment attribute carries the uid, which the search test looks for */ + tensor.__set_attrib(attrib_map); + tensor_list.tensor_list.emplace_back(tensor); - bin_tensor.uid = "binary_vec_" + std::to_string(k); - bin_tensor_list.tensor_list.emplace_back(bin_tensor); + bin_tensor.uid = "binary_vec_" + std::to_string(k); + attrib_map[TEST_ATTRIB_COMMENT] = bin_tensor.uid; /* same map reused: only the comment entry differs for the binary tensor */ + bin_tensor.__set_attrib(attrib_map); + bin_tensor_list.tensor_list.emplace_back(bin_tensor); - if((k+1)%10000 == 0) { - CLIENT_LOG_INFO << k+1 << " vectors built"; + if ((k + 1) % 10000 == 0) { + CLIENT_LOG_INFO << k + 1 << " vectors built"; + } } + + rc.Elapse("done"); } // //add vectors one by one @@ -164,6 +180,10 @@ TEST(SearchVector, CLIENT_TEST) { std::cout << "Search result: " << std::endl; for(VecSearchResultItem& item : res.result_list) { std::cout << "\t" << item.uid << std::endl; + + ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_NUM) != 0); /* both test attributes must come back with every result */ + ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0); + ASSERT_TRUE(item.attrib[TEST_ATTRIB_COMMENT].find(item.uid) != std::string::npos); /* AddVector set the comment attribute to the uid, so the uid must appear in it */ } rc.Elapse("done!"); @@ -200,6 +220,9 @@ TEST(SearchVector, CLIENT_TEST) { std::cout << "No " << i << ":" << std::endl; for(VecSearchResultItem& item : res.result_list[i].result_list) { std::cout << "\t" << item.uid << std::endl; + ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_NUM) != 0); /* same attribute checks for the batch search results */ + ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0); + ASSERT_TRUE(item.attrib[TEST_ATTRIB_COMMENT].find(item.uid) != std::string::npos); } } diff --git a/cpp/unittest/server/CMakeLists.txt b/cpp/unittest/server/CMakeLists.txt index 402eaf8ec85000eb8170f48ba507cca1b21e5f3a..8e8e116481419b287fb47b1b688a53fb4d3fb99f 100644 --- a/cpp/unittest/server/CMakeLists.txt +++ b/cpp/unittest/server/CMakeLists.txt @@ -19,6 +19,8 @@ set(require_files ../../src/server/ServerConfig.cpp 
../../src/utils/CommonUtil.cpp ../../src/utils/TimeRecorder.cpp + ../../src/utils/StringHelpFunctions.cpp + ../../src/utils/AttributeSerializer.cpp ) cuda_add_executable(server_test diff --git a/cpp/unittest/server/util_test.cpp b/cpp/unittest/server/util_test.cpp new file mode 100644 index 0000000000000000000000000000000000000000..bfb45c8e47d4c217d9efd0701440e7c04cccbe1b --- /dev/null +++ b/cpp/unittest/server/util_test.cpp @@ -0,0 +1,34 @@ +//////////////////////////////////////////////////////////////////////////////// +// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved +// Unauthorized copying of this file, via any medium is strictly prohibited. +// Proprietary and confidential. +//////////////////////////////////////////////////////////////////////////////// +#include <gtest/gtest.h> + +#include "utils/AttributeSerializer.h" +#include "utils/StringHelpFunctions.h" + +using namespace zilliz::vecwise; + +TEST(AttribSerializeTest, ATTRIBSERIAL_TEST) { /* round-trip: Encode a map, Decode the string, expect the same map back */ + std::map<std::string, std::string> attrib; + attrib["uid"] = "ABCDEF"; + attrib["color"] = "red"; + attrib["number"] = "9900"; + attrib["comment"] = "please note: it is a car, not a ship"; /* value containing ':' - only the first ':' of an entry may act as separator */ + attrib["address"] = " china;shanghai "; /* value containing ';' - must survive the quote-aware split */ + + std::string attri_str; + ASSERT_EQ(server::AttributeSerializer::Encode(attrib, attri_str), server::SERVER_SUCCESS); + + std::map<std::string, std::string> attrib_out; + server::ServerError err = server::AttributeSerializer::Decode(attri_str, attrib_out); + ASSERT_EQ(err, server::SERVER_SUCCESS); + + ASSERT_EQ(attrib_out.size(), attrib.size()); + for(const auto& iter : attrib) { + ASSERT_EQ(attrib_out[iter.first], iter.second); /* compare the decoded value with the encoded one, not attrib_out with itself */ + } + +} +