提交 16df78ca 编写于 作者: G groot

add attribute for vector


Former-commit-id: e520d4e7aaceb1fedae62cc67326a550666efc6b
上级 fe5909a5
......@@ -55,15 +55,18 @@ public:
void add_binary_vector_batch(const std::string& group_id, const VecBinaryTensorList& tensor_list);
/**
* search interfaces
* if time_range_list is empty, engine will search without time limit
*
*
* @param group_id
* @param top_k
* @param tensor
* @param filter
*/
* search interfaces
* you can use filter to reduce search result
* filter.attrib_filter can specify which attribute you need, for example:
* set attrib_filter = {"color":""} means you want to get "color" attribute for result vector
* set attrib_filter = {"color":"red"} means you want to get vectors which has attribute "color" equals "red"
* if filter.time_range is empty, engine will search without time limit
*
* @param group_id
* @param top_k
* @param tensor
* @param filter
*/
void search_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecTensor& tensor, const VecSearchFilter& filter);
void search_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecTensorList& tensor_list, const VecSearchFilter& filter);
......@@ -71,7 +74,7 @@ public:
void search_binary_vector(VecSearchResult& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensor& tensor, const VecSearchFilter& filter);
void search_binary_vector_batch(VecSearchResultList& _return, const std::string& group_id, const int64_t top_k, const VecBinaryTensorList& tensor_list, const VecSearchFilter& filter);
};
......
......@@ -19,6 +19,8 @@ namespace server {
static const std::string DQL_TASK_GROUP = "dql";
static const std::string DDL_DML_TASK_GROUP = "ddl_dml";
static const std::string VECTOR_UID = "uid";
namespace {
class DBWrapper {
public:
......@@ -201,6 +203,14 @@ std::string AddVectorTask::GetVecID() const {
}
}
const AttribMap& AddVectorTask::GetVecAttrib() const {
if(tensor_) {
return tensor_->attrib;
} else {
return bin_tensor_->attrib;
}
}
ServerError AddVectorTask::OnExecute() {
try {
engine::meta::GroupSchema group_info;
......@@ -238,8 +248,12 @@ ServerError AddVectorTask::OnExecute() {
} else {
std::string uid = GetVecID();
std::string nid = group_id_ + "_" + std::to_string(vector_ids[0]);
IVecIdMapper::GetInstance()->Put(nid, uid);
SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", sid = " << uid;
AttribMap attrib = GetVecAttrib();
attrib[VECTOR_UID] = uid;
std::string attrib_str;
AttributeSerializer::Encode(attrib, attrib_str);
IVecIdMapper::GetInstance()->Put(nid, attrib_str);
SERVER_LOG_TRACE << "nid = " << vector_ids[0] << ", uid = " << uid;
}
}
......@@ -339,6 +353,14 @@ std::string AddBatchVectorTask::GetVecID(uint64_t index) const {
}
}
const AttribMap& AddBatchVectorTask::GetVecAttrib(uint64_t index) const {
if(tensor_list_) {
return tensor_list_->tensor_list[index].attrib;
} else {
return bin_tensor_list_->tensor_list[index].attrib;
}
}
ServerError AddBatchVectorTask::OnExecute() {
try {
TimeRecorder rc("AddBatchVectorTask");
......@@ -387,7 +409,11 @@ ServerError AddBatchVectorTask::OnExecute() {
for(size_t i = 0; i < vec_count; i++) {
std::string uid = GetVecID(i);
std::string nid = nid_prefix + std::to_string(vector_ids[i]);
IVecIdMapper::GetInstance()->Put(nid, uid);
AttribMap attrib = GetVecAttrib(i);
attrib[VECTOR_UID] = uid;
std::string attrib_str;
AttributeSerializer::Encode(attrib, attrib_str);
IVecIdMapper::GetInstance()->Put(nid, attrib_str);
}
rc.Record("build id mapping");
}
......@@ -543,16 +569,20 @@ ServerError SearchVectorTask::OnExecute() {
VecSearchResult v_res;
std::string nid_prefix = group_id_ + "_";
for(auto id : res) {
std::string sid;
std::string attrib_str;
std::string nid = nid_prefix + std::to_string(id);
IVecIdMapper::GetInstance()->Get(nid, sid);
IVecIdMapper::GetInstance()->Get(nid, attrib_str);
AttribMap attrib_map;
AttributeSerializer::Decode(attrib_str, attrib_map);
VecSearchResultItem item;
item.uid = sid;
item.__set_attrib(attrib_map);
item.uid = item.attrib[VECTOR_UID];
item.distance = 0.0;////TODO: return distance
v_res.result_list.emplace_back(item);
SERVER_LOG_TRACE << "nid = " << nid << ", string id = " << sid;
SERVER_LOG_TRACE << "nid = " << nid << ", uid = " << item.uid;
}
result_.result_list.push_back(v_res);
......
......@@ -7,6 +7,7 @@
#include "VecServiceScheduler.h"
#include "utils/Error.h"
#include "utils/AttributeSerializer.h"
#include "db/Types.h"
#include "thrift/gen-cpp/VectorService_types.h"
......@@ -85,6 +86,7 @@ protected:
uint64_t GetVecDimension() const;
const double* GetVecData() const;
std::string GetVecID() const;
const AttribMap& GetVecAttrib() const;
ServerError OnExecute() override;
......@@ -115,6 +117,7 @@ protected:
uint64_t GetVecDimension(uint64_t index) const;
const double* GetVecData(uint64_t index) const;
std::string GetVecID(uint64_t index) const;
const AttribMap& GetVecAttrib(uint64_t index) const;
ServerError OnExecute() override;
......
......@@ -5,17 +5,44 @@
******************************************************************************/
#include "AttributeSerializer.h"
#include "StringHelpFunctions.h"
namespace zilliz {
namespace vecwise {
namespace server {
void AttributeSerializer::Encode(const std::map<std::string, std::string>& attrib, std::string& result) {
ServerError AttributeSerializer::Encode(const AttribMap& attrib_map, std::string& attrib_str) {
attrib_str = "";
for(auto iter : attrib_map) {
attrib_str += iter.first;
attrib_str += ":\"";
attrib_str += iter.second;
attrib_str += "\";";
}
return SERVER_SUCCESS;
}
void AttributeSerializer::Decode(const std::string& str, std::map<std::string, std::string>& result) {
ServerError AttributeSerializer::Decode(const std::string& attrib_str, AttribMap& attrib_map) {
attrib_map.clear();
std::vector<std::string> kv_pairs;
StringHelpFunctions::SplitStringByQuote(attrib_str, ";", "\"", kv_pairs);
for(std::string& str : kv_pairs) {
std::string key, val;
size_t index = str.find_first_of(":", 0);
if (index != std::string::npos) {
key = str.substr(0, index);
val = str.substr(index + 1);
} else {
key = str;
}
attrib_map.insert(std::make_pair(key, val));
}
return SERVER_SUCCESS;
}
}
......
......@@ -7,14 +7,18 @@
#include <map>
#include "Error.h"
namespace zilliz {
namespace vecwise {
namespace server {
using AttribMap = std::map<std::string, std::string>;
class AttributeSerializer {
public:
static void Encode(const std::map<std::string, std::string>& attrib, std::string& result);
static void Decode(const std::string& str, std::map<std::string, std::string>& result);
static ServerError Encode(const AttribMap& attrib_map, std::string& attrib_str);
static ServerError Decode(const std::string& attrib_str, AttribMap& attrib_map);
};
......
......@@ -4,7 +4,8 @@
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>
#include <utils/TimeRecorder.h>
#include "utils/TimeRecorder.h"
#include "utils/AttributeSerializer.h"
#include "ClientSession.h"
#include "server/ServerConfig.h"
#include "Log.h"
......@@ -16,6 +17,9 @@ using namespace zilliz::vecwise;
namespace {
static const int32_t VEC_DIMENSION = 256;
static const std::string TEST_ATTRIB_NUM = "number";
static const std::string TEST_ATTRIB_COMMENT = "comment";
std::string CurrentTime() {
time_t tt;
time( &tt );
......@@ -69,27 +73,39 @@ TEST(AddVector, CLIENT_TEST) {
const int64_t count = 100000;
VecTensorList tensor_list;
VecBinaryTensorList bin_tensor_list;
for (int64_t k = 0; k < count; k++) {
VecTensor tensor;
tensor.tensor.reserve(VEC_DIMENSION);
VecBinaryTensor bin_tensor;
bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
double *d_p = (double *) (const_cast<char *>(bin_tensor.tensor.data()));
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
double val = (double) (i + k);
tensor.tensor.push_back(val);
d_p[i] = val;
}
{
server::TimeRecorder rc(std::to_string(count) + " vectors built");
for (int64_t k = 0; k < count; k++) {
VecTensor tensor;
tensor.tensor.reserve(VEC_DIMENSION);
VecBinaryTensor bin_tensor;
bin_tensor.tensor.resize(VEC_DIMENSION * sizeof(double));
double *d_p = (double *) (const_cast<char *>(bin_tensor.tensor.data()));
for (int32_t i = 0; i < VEC_DIMENSION; i++) {
double val = (double) (i + k);
tensor.tensor.push_back(val);
d_p[i] = val;
}
server::AttribMap attrib_map;
attrib_map[TEST_ATTRIB_NUM] = "No." + std::to_string(k);
tensor.uid = "normal_vec_" + std::to_string(k);
tensor_list.tensor_list.emplace_back(tensor);
tensor.uid = "normal_vec_" + std::to_string(k);
attrib_map[TEST_ATTRIB_COMMENT] = tensor.uid;
tensor.__set_attrib(attrib_map);
tensor_list.tensor_list.emplace_back(tensor);
bin_tensor.uid = "binary_vec_" + std::to_string(k);
bin_tensor_list.tensor_list.emplace_back(bin_tensor);
bin_tensor.uid = "binary_vec_" + std::to_string(k);
attrib_map[TEST_ATTRIB_COMMENT] = bin_tensor.uid;
bin_tensor.__set_attrib(attrib_map);
bin_tensor_list.tensor_list.emplace_back(bin_tensor);
if((k+1)%10000 == 0) {
CLIENT_LOG_INFO << k+1 << " vectors built";
if ((k + 1) % 10000 == 0) {
CLIENT_LOG_INFO << k + 1 << " vectors built";
}
}
rc.Elapse("done");
}
// //add vectors one by one
......@@ -164,6 +180,10 @@ TEST(SearchVector, CLIENT_TEST) {
std::cout << "Search result: " << std::endl;
for(VecSearchResultItem& item : res.result_list) {
std::cout << "\t" << item.uid << std::endl;
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_NUM) != 0);
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0);
ASSERT_TRUE(item.attrib[TEST_ATTRIB_COMMENT].find(item.uid) != std::string::npos);
}
rc.Elapse("done!");
......@@ -200,6 +220,9 @@ TEST(SearchVector, CLIENT_TEST) {
std::cout << "No " << i << ":" << std::endl;
for(VecSearchResultItem& item : res.result_list[i].result_list) {
std::cout << "\t" << item.uid << std::endl;
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_NUM) != 0);
ASSERT_TRUE(item.attrib.count(TEST_ATTRIB_COMMENT) != 0);
ASSERT_TRUE(item.attrib[TEST_ATTRIB_COMMENT].find(item.uid) != std::string::npos);
}
}
......
......@@ -19,6 +19,8 @@ set(require_files
../../src/server/ServerConfig.cpp
../../src/utils/CommonUtil.cpp
../../src/utils/TimeRecorder.cpp
../../src/utils/StringHelpFunctions.cpp
../../src/utils/AttributeSerializer.cpp
)
cuda_add_executable(server_test
......
////////////////////////////////////////////////////////////////////////////////
// Copyright 上海赜睿信息科技有限公司(Zilliz) - All Rights Reserved
// Unauthorized copying of this file, via any medium is strictly prohibited.
// Proprietary and confidential.
////////////////////////////////////////////////////////////////////////////////
#include <gtest/gtest.h>
#include "utils/AttributeSerializer.h"
#include "utils/StringHelpFunctions.h"
using namespace zilliz::vecwise;
TEST(AttribSerializeTest, ATTRIBSERIAL_TEST) {
std::map<std::string, std::string> attrib;
attrib["uid"] = "ABCDEF";
attrib["color"] = "red";
attrib["number"] = "9900";
attrib["comment"] = "please note: it is a car, not a ship";
attrib["address"] = " china;shanghai ";
std::string attri_str;
server::AttributeSerializer::Encode(attrib, attri_str);
std::map<std::string, std::string> attrib_out;
server::ServerError err = server::AttributeSerializer::Decode(attri_str, attrib_out);
ASSERT_EQ(err, server::SERVER_SUCCESS);
ASSERT_EQ(attrib_out.size(), attrib.size());
for(auto iter : attrib) {
ASSERT_EQ(attrib_out[iter.first], attrib_out[iter.first]);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册