未验证 提交 fbf5972f 编写于 作者: Y yukun 提交者: GitHub

C++ sdk sdk_binary needs to update (#4330)

Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 2909677d
......@@ -20,6 +20,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#4246 Fix 'Illegal instruction' bug when running index tests at GitHub action
- \#4272 Program exit abnormally
- \#4302 Setting DSL fields is invalid in restful api, fields are not returned
- \#4329 C++ sdk sdk_binary needs to update
## Feature
- \#4163 Update C++ sdk search interface
......
......@@ -185,7 +185,7 @@ if (DEFINED ENV{MILVUS_GRPC_URL})
set(GRPC_SOURCE_URL "$ENV{MILVUS_GRPC_URL}")
else ()
set(GRPC_SOURCE_URL
"https://github.com/youny626/grpc-milvus/archive/master.zip")
"https://github.com/milvus-io/grpc-milvus/archive/master.zip")
endif ()
if (DEFINED ENV{MILVUS_ZLIB_URL})
......
......@@ -57,7 +57,7 @@ BuildBinaryVectors(int64_t from, int64_t to, std::vector<milvus::VectorData>& en
}
void
TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mapping& mapping,
TestProcess(std::shared_ptr<milvus::Connection> connection, milvus::Mapping& mapping,
const milvus::IndexParam& index_param) {
milvus::Status stat;
......@@ -65,7 +65,8 @@ TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mappin
JSON extra_params;
extra_params["segment_row_limit"] = 1000000;
extra_params["auto_id"] = false;
stat = connection->CreateCollection(mapping, extra_params.dump());
mapping.extra_params = extra_params.dump();
stat = connection->CreateCollection(mapping);
std::cout << "CreateCollection function call status: " << stat.message() << std::endl;
milvus_sdk::Utils::PrintCollectionParam(mapping);
}
......@@ -79,7 +80,7 @@ TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mappin
std::vector<std::pair<int64_t, milvus::VectorData>> search_entity_array;
{ // insert vectors
for (int i = 0; i < ADD_ENTITY_LOOP; i++) {
milvus::FieldValue field_value;
milvus::FieldValue field_value = milvus::FieldValue();
std::vector<milvus::VectorData> entity_array;
std::vector<int64_t> entity_ids;
......@@ -90,7 +91,7 @@ TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mappin
}
if (search_entity_array.size() < NQ) {
search_entity_array.push_back(std::make_pair(entity_ids[0], entity_array[0]));
search_entity_array.emplace_back(entity_ids[0], entity_array[0]);
}
std::vector<int64_t> int64_data(BATCH_ENTITY_COUNT);
......@@ -115,44 +116,42 @@ TestProcess(std::shared_ptr<milvus::Connection> connection, const milvus::Mappin
}
{ // search vectors
// std::string metric_type = "HAMMING";
// std::string metric_type = "HAMMING";
std::string metric_type = "JACCARD";
// std::string metric_type = "TANIMOTO";
// std::string metric_type = "TANIMOTO";
nlohmann::json dsl_json, vector_param_json;
milvus_sdk::Utils::GenPureVecDSLJson(dsl_json, vector_param_json, metric_type);
std::vector<milvus::VectorData> temp_entity_array;
std::vector<std::vector<uint8_t>> temp_entity_array;
for (auto& pair : search_entity_array) {
temp_entity_array.push_back(pair.second);
temp_entity_array.push_back(pair.second.binary_data);
}
milvus::VectorParam vector_param = {vector_param_json.dump(), temp_entity_array};
milvus_sdk::Utils::GenBinDSLJson(dsl_json, TOP_K, metric_type, temp_entity_array);
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
auto status = connection->Search(mapping.collection_name, partition_tags, dsl_json.dump(), vector_param, "", topk_query_result);
auto status = connection->Search(mapping.collection_name, partition_tags, dsl_json, "", topk_query_result);
std::cout << metric_type << " Search function call result: " << std::endl;
milvus_sdk::Utils::PrintTopKQueryResult(topk_query_result);
std::cout << metric_type << " Search function call status: " << status.message() << std::endl;
}
/*
{ // wait unit build index finish
milvus_sdk::TimeRecorder rc("Create index");
std::cout << "Wait until create all index done" << std::endl;
milvus_sdk::Utils::PrintIndexParam(index_param);
stat = connection->CreateIndex(index_param);
std::cout << "CreateIndex function call status: " << stat.message() << std::endl;
}
/*
{ // wait unit build index finish
milvus_sdk::TimeRecorder rc("Create index");
std::cout << "Wait until create all index done" << std::endl;
milvus_sdk::Utils::PrintIndexParam(index_param);
stat = connection->CreateIndex(index_param);
std::cout << "CreateIndex function call status: " << stat.message() << std::endl;
}
{ // search vectors
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(connection, mapping.collection_name, partition_tags, TOP_K, NPROBE,
search_entity_array, topk_query_result);
}
*/
{ // search vectors
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
milvus_sdk::Utils::DoSearch(connection, mapping.collection_name, partition_tags, TOP_K, NPROBE,
search_entity_array, topk_query_result);
}
*/
{ // drop collection
stat = connection->DropCollection(mapping.collection_name);
std::cout << "DropCollection function call status: " << stat.message() << std::endl;
......@@ -177,20 +176,20 @@ ClientTest::Test(const std::string& address, const std::string& port) {
milvus::FieldPtr field_ptr1 = std::make_shared<milvus::Field>();
milvus::FieldPtr field_ptr2 = std::make_shared<milvus::Field>();
field_ptr1->field_name = "field_1";
field_ptr1->field_type = milvus::DataType::INT64;
field_ptr1->name = "field_1";
field_ptr1->type = milvus::DataType::INT64;
JSON index_param_1;
index_param_1["name"] = "index_1";
field_ptr1->index_params = index_param_1.dump();
field_ptr2->field_name = "field_vec";
field_ptr2->field_type = milvus::DataType::VECTOR_BINARY;
field_ptr2->name = "field_vec";
field_ptr2->type = milvus::DataType::VECTOR_BINARY;
JSON index_param_2;
index_param_2["name"] = "index_vec";
field_ptr2->index_params = index_param_2.dump();
JSON extra_params;
extra_params["dim"] = DIMENSION;
field_ptr2->extra_params = extra_params.dump();
field_ptr2->params = extra_params.dump();
milvus::Mapping mapping = {"collection_1", {field_ptr1, field_ptr2}};
......
......@@ -444,21 +444,32 @@ Utils::GenSpecificDSLJson(nlohmann::json& dsl_json, int64_t topk, const std::str
}
void
Utils::GenPureVecDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type) {
nlohmann::json bool_json, vector_json;
std::string placeholder = "placeholder_1";
vector_json["vector"] = placeholder;
bool_json["must"].push_back(vector_json);
dsl_json["bool"] = bool_json;
Utils::GenBinDSLJson(nlohmann::json& dsl_json, int64_t topk, const std::string metric_type,
std::vector<std::vector<uint8_t>>& vectors) {
auto dsl = R"({
"bool": {
"must": [
{
"vector": {
"field_vec": {
"topk": "topk",
"query": "placeholder",
"metric_type": "metric_type"
}
}
}
]
}
})";
dsl_json = nlohmann::json::parse(dsl);
nlohmann::json query_vector_json, vector_extra_params;
int64_t topk = 10;
nlohmann::json vector_extra_params;
nlohmann::json& query_vector_json = dsl_json["bool"]["must"][0]["vector"]["field_vec"];
query_vector_json["topk"] = topk;
query_vector_json["metric_type"] = metric_type;
query_vector_json["query"] = vectors;
vector_extra_params["nprobe"] = 32;
query_vector_json["params"] = vector_extra_params;
vector_param_json[placeholder]["field_vec"] = query_vector_json;
}
void
......@@ -469,6 +480,9 @@ Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) {
std::cout << "- id: " << topk_query_result[i].ids[j] << std::endl;
std::cout << "- distance: " << topk_query_result[i].distances[j] << std::endl;
if (entities.empty()) {
continue;
}
for (const auto& data : entities[j].scalar_data) {
if (data.first == "duration" || data.first == "release_year") {
std::cout << "- " << data.first << ": " << std::any_cast<int32_t>(data.second) << std::endl;
......
......@@ -90,10 +90,11 @@ class Utils {
static void
GenSpecificDSLJson(nlohmann::json& dsl_json, int64_t topk, const std::string& metric_type,
std::vector<milvus::VectorData>& vectors);
std::vector<milvus::VectorData>& vectors);
static void
GenPureVecDSLJson(nlohmann::json& dsl_json, nlohmann::json& vector_param_json, const std::string metric_type);
GenBinDSLJson(nlohmann::json& dsl_json, int64_t topk, const std::string metric_type,
std::vector<std::vector<uint8_t>>& vectors);
static void
PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册