未验证 提交 3b305048 编写于 作者: S shengjun.li 提交者: GitHub

[skip ci] fix binary sdk (#3332)

Signed-off-by: Nshengjun.li <shengjun.li@zilliz.com>
上级 07ece74b
......@@ -15,5 +15,5 @@ aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/utils util_files)
add_subdirectory(simple)
#add_subdirectory(partition)
#add_subdirectory(binary_vector)
add_subdirectory(binary_vector)
#add_subdirectory(qps)
......@@ -86,10 +86,11 @@ TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mappin
{ // generate vectors
milvus_sdk::TimeRecorder rc("Build entities No." + std::to_string(i));
BuildBinaryVectors(begin_index, begin_index + BATCH_ENTITY_COUNT, entity_array, entity_ids, DIMENSION);
entity_ids.clear();
}
if (search_entity_array.size() < NQ) {
search_entity_array.push_back(std::make_pair(entity_ids[SEARCH_TARGET], entity_array[SEARCH_TARGET]));
search_entity_array.push_back(std::make_pair(entity_ids[0], entity_array[0]));
}
std::vector<int64_t> int64_data(BATCH_ENTITY_COUNT);
......@@ -114,12 +115,29 @@ TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mappin
}
{ // search vectors
// std::string metric_type = "HAMMING";
std::string metric_type = "JACCARD";
// std::string metric_type = "TANIMOTO";
nlohmann::json dsl_json, vector_param_json;
milvus_sdk::Utils::GenDSLJson(dsl_json, vector_param_json, metric_type);
std::vector<milvus::VectorData> temp_entity_array;
for (auto& pair : search_entity_array) {
temp_entity_array.push_back(pair.second);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(connection, mapping.collection_name, partition_tags, TOP_K, NPROBE,
search_entity_array, topk_query_result);
}
auto status = connection->Search(mapping.collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
std::cout << metric_type << " Search function call result: " << std::endl;
milvus_sdk::Utils::PrintTopKQueryResult(topk_query_result);
std::cout << metric_type << " Search function call status: " << status.message() << std::endl;
}
/*
{ // wait unit build index finish
milvus_sdk::TimeRecorder rc("Create index");
std::cout << "Wait until create all index done" << std::endl;
......@@ -134,7 +152,7 @@ TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mappin
milvus_sdk::Utils::DoSearch(connection, mapping.collection_name, partition_tags, TOP_K, NPROBE,
search_entity_array, topk_query_result);
}
*/
{ // drop collection
stat = connection->DropCollection(mapping.collection_name);
std::cout << "DropCollection function call status: " << stat.message() << std::endl;
......@@ -157,27 +175,27 @@ ClientTest::Test(const std::string& address, const std::string& port) {
{
milvus::FieldPtr field_ptr1 = std::make_shared<milvus::Field>();
milvus::FieldPtr field_ptr2 = std::make_shared<milvus::Field>();
field_ptr1->field_name = "field_1";
field_ptr1->field_type = milvus::DataType::INT64;
JSON index_param_1;
index_param_1["name"] = "index_1";
field_ptr1->index_params = index_param_1.dump();
milvus::FieldPtr field_ptr2 = std::make_shared<milvus::Field>();
field_ptr2->field_type = milvus::DataType::VECTOR_BINARY;
field_ptr2->field_name = "field_vec";
field_ptr2->field_type = milvus::DataType::VECTOR_BINARY;
JSON index_param_2;
index_param_2["name"] = "index_3";
index_param_2["name"] = "index_vec";
field_ptr2->index_params = index_param_2.dump();
JSON extra_params;
extra_params["dimension"] = 128;
extra_params["metric_type"] = "TANIMOTO";
extra_params["dim"] = DIMENSION;
field_ptr2->extra_params = extra_params.dump();
milvus::Mapping mapping = {"collection_1", {field_ptr1, field_ptr2}};
JSON json_params = {{"index_type", "IVF_FLAT"}, {"nlist", 1024}};
milvus::IndexParam index_param = {mapping.collection_name, "field_2", "index_3", json_params.dump()};
JSON json_params = {{"index_type", "BIN_IVF_FLAT"}, {"nlist", 1024}};
milvus::IndexParam index_param = {mapping.collection_name, "field_vec", json_params.dump()};
TestProcess(connection, mapping, index_param);
}
......
......@@ -255,33 +255,6 @@ Utils::CheckSearchResult(const std::vector<std::pair<int64_t, milvus::VectorData
BLOCK_SPLITER
}
void
Utils::DoSearch(std::shared_ptr<milvus::Connection> conn, const std::string& collection_name,
const std::vector<std::string>& partition_tags, int64_t top_k, int64_t nprobe,
std::vector<std::pair<int64_t, milvus::VectorData>> entity_array,
milvus::TopKQueryResult& topk_query_result) {
/*
topk_query_result.clear();
nlohmann::json dsl_json, vector_param_json;
GenDSLJson(dsl_json, vector_param_json);
std::vector<milvus::VectorData> temp_entity_array;
for (auto& pair : entity_array) {
temp_entity_array.push_back(pair.second);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
JSON json_params = {{"nprobe", nprobe}};
milvus_sdk::TimeRecorder rc("Search");
auto status = conn->Search(collection_name, partition_tags, dsl_json.dump(), vector_param, topk_query_result);
PrintTopKQueryResult(topk_query_result);
// PrintSearchResult(entity_array, topk_query_result);
*/
}
void
Utils::ConstructVectors(int64_t from, int64_t to, std::vector<milvus::VectorData>& query_vector,
std::vector<int64_t>& search_ids, int64_t dimension) {
......@@ -397,6 +370,31 @@ Utils::GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, c
vector_param_json[placeholder]["field_vec"] = query_vector_json;
}
void
Utils::GenBinaryDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type) {
uint64_t row_num = 10000;
std::vector<int64_t> term_value;
term_value.resize(row_num);
for (uint64_t i = 0; i < row_num; ++i) {
term_value[i] = i;
}
nlohmann::json bool_json, vector_json;
std::string placeholder = "placeholder_1";
vector_json["vector"] = placeholder;
bool_json["must"].push_back(vector_json);
dsl_json["bool"] = bool_json;
nlohmann::json query_vector_json, vector_extra_params;
int64_t topk = 10;
query_vector_json["topk"] = topk;
query_vector_json["metric_type"] = metric_type;
vector_extra_params["nprobe"] = 32;
query_vector_json["params"] = vector_extra_params;
vector_param_json[placeholder]["field_vec"] = query_vector_json;
}
void
Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) {
for (size_t i = 0; i < topk_query_result.size(); i++) {
......@@ -433,6 +431,7 @@ Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) {
std::cout << topk_query_result[i].ids[j] << " --------- " << topk_query_result[i].distances[j]
<< std::endl;
}
std::cout << std::endl;
}
}
......
......@@ -84,6 +84,9 @@ class Utils {
static void
GenDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type);
static void
GenBinaryDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type);
static void
PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result);
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册