未验证 提交 cd6a910d 编写于 作者: Y yukun 提交者: GitHub

Fix search when there are multiple vector fields (#4420)

Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 cf7ce289
......@@ -21,6 +21,7 @@ Please mark all changes in change log and use the issue from GitHub
- \#4272 Program exit abnormally
- \#4302 Setting DSL fields is invalid in restful api, fields are not returned
- \#4329 C++ sdk sdk_binary needs to update
- \#4418 Fix search when there are multiple vector fields
## Feature
- \#4163 Update C++ sdk search interface
......
......@@ -73,18 +73,15 @@ SearchReq::OnExecute() {
// step 4: Get field info
std::unordered_map<std::string, engine::DataType> field_types;
auto vector_query = query_ptr_->vectors.begin()->second;
for (auto& schema : fields_schema) {
auto field = schema.first;
field_types.insert(std::make_pair(field->GetName(), field->GetFtype()));
if (field->GetFtype() == engine::DataType::VECTOR_FLOAT ||
field->GetFtype() == engine::DataType::VECTOR_BINARY) {
if (vector_query->field_name == field->GetName() &&
(field->GetFtype() == engine::DataType::VECTOR_FLOAT ||
field->GetFtype() == engine::DataType::VECTOR_BINARY)) {
// check dim
int64_t dimension = field->GetParams()[engine::PARAM_DIMENSION];
auto vector_query = query_ptr_->vectors.begin()->second;
if (vector_query->field_name != field->GetName()) {
return Status(SERVER_INVALID_ARGUMENT,
"DSL vector query field name: " + vector_query->field_name + " is wrong");
}
if (!vector_query->query_vector.binary_data.empty()) {
if (vector_query->query_vector.binary_data.size() !=
......
......@@ -76,6 +76,26 @@ ClientTest::ListCollections(std::vector<std::string>& collection_array) {
}
}
void
ClientTest::CreateMultiVecCollection() {
milvus::FieldPtr field1 = std::make_shared<milvus::Field>("release_year", milvus::DataType::INT32, "");
milvus::FieldPtr field2 = std::make_shared<milvus::Field>("duration", milvus::DataType::INT32, "");
nlohmann::json vector_param = {{"dim", COLLECTION_DIMENSION}};
milvus::FieldPtr field3 =
std::make_shared<milvus::Field>("embedding", milvus::DataType::VECTOR_FLOAT, vector_param.dump());
nlohmann::json vector_param_1 = {{"dim", COLLECTION_DIMENSION + COLLECTION_DIMENSION}};
milvus::FieldPtr field4 =
std::make_shared<milvus::Field>("vec", milvus::DataType::VECTOR_FLOAT, vector_param_1.dump());
nlohmann::json json_param;
json_param = {{"auto_id", false}, {"segment_row_limit", 4096}};
milvus::Mapping mapping = {COLLECTION_NAME, {field1, field2, field3, field4}, json_param.dump()};
milvus::Status status = conn_->CreateCollection(mapping);
std::cout << "CreateCollection function call status: " << status.message() << std::endl;
}
void
ClientTest::CreateCollection() {
milvus::FieldPtr field1 = std::make_shared<milvus::Field>("release_year", milvus::DataType::INT32, "");
......@@ -138,6 +158,31 @@ ClientTest::InsertEntities() {
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
}
void
ClientTest::InsertMultiEntities() {
std::vector<int32_t> duration{208, 226, 252};
std::vector<int32_t> release_year{2001, 2002, 2003};
std::vector<milvus::VectorData> embedding;
milvus_sdk::Utils::BuildVectors(COLLECTION_DIMENSION, 3, embedding);
std::vector<milvus::VectorData> vec;
milvus_sdk::Utils::BuildVectors(COLLECTION_DIMENSION * 2, 3, vec);
milvus::FieldValue field_value;
std::unordered_map<std::string, std::vector<int32_t>> int32_value = {{"duration", duration},
{"release_year", release_year}};
std::unordered_map<std::string, std::vector<milvus::VectorData>> vector_value = {{"embedding", embedding},
{"vec", vec}};
field_value.int32_value = int32_value;
field_value.vector_value = vector_value;
std::vector<int64_t> id_array = {1, 2, 3};
auto status = conn_->Insert(COLLECTION_NAME, PARTITION_TAG, field_value, id_array);
std::cout << "InsertEntities function call status: " << status.message() << std::endl;
}
void
ClientTest::CountEntities(int64_t& entity_count) {
auto status = conn_->CountEntities(COLLECTION_NAME, entity_count);
......@@ -215,8 +260,7 @@ ClientTest::SearchEntities() {
std::vector<std::string> partition_tags;
milvus::TopKQueryResult topk_query_result;
auto status = conn_->Search(COLLECTION_NAME, partition_tags, dsl_json, json_params.dump(),
topk_query_result);
auto status = conn_->Search(COLLECTION_NAME, partition_tags, dsl_json, json_params.dump(), topk_query_result);
std::cout << " Search function call result: " << std::endl;
milvus_sdk::Utils::PrintTopKQueryResult(topk_query_result);
......@@ -288,6 +332,7 @@ ClientTest::Test() {
}
CreateCollection();
// CreateMultiVecCollection();
CreatePartition();
std::cout << "--------get collection info--------" << std::endl;
......@@ -297,6 +342,7 @@ ClientTest::Test() {
std::cout << "\n----------insert----------" << std::endl;
InsertEntities();
// InsertMultiEntities();
int64_t before_flush_counts = 0;
int64_t after_flush_counts = 0;
......
......@@ -33,6 +33,12 @@ class ClientTest {
void
CreateCollection();
void
CreateMultiVecCollection();
void
InsertMultiEntities();
void
CreatePartition();
......
......@@ -496,6 +496,13 @@ Utils::PrintTopKQueryResult(milvus::TopKQueryResult& topk_query_result) {
}
std::cout << std::endl;
}
if (data.first == "vec") {
std::cout << "- " << data.first << ": ";
for (const auto& v : data.second.float_data) {
std::cout << v << " ";
}
std::cout << std::endl;
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册