未验证 提交 9fcc3834 编写于 作者: Y yukun 提交者: GitHub

Fix TestSearchDSL multi fields bug (#3411)

Signed-off-by: Nfishpenguin <kun.yu@zilliz.com>
上级 2e5ff884
......@@ -415,7 +415,6 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener
return status;
} else {
if (general_query->leaf->term_query != nullptr) {
// process attrs_data
bitset = std::make_shared<faiss::ConcurrentBitset>(entity_count_);
STATUS_CHECK(ProcessTermQuery(bitset, general_query->leaf->term_query, attr_type));
}
......@@ -512,6 +511,10 @@ ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const
STATUS_CHECK(IndexedTermQuery(bitset, field_name, attr_type.at(field_name), term_it.value()));
}
}
term_it++;
if (term_it != term_query_json.end()) {
return Status(SERVER_INVALID_DSL_PARAMETER, "Term query does not support multiple fields");
}
} catch (std::exception& ex) {
return Status{SERVER_INVALID_DSL_PARAMETER, ex.what()};
}
......@@ -528,7 +531,7 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr&
bool flag = false;
for (auto& range_value_it : range_values_json.items()) {
const std::string& comp_op = range_value_it.key();
T value = range_value_it.value().get<T>();
T value = range_value_it.value();
if (not flag) {
bitset = (*bitset) | T_index->Range(value, knowhere::s_map_operator_type.at(comp_op));
flag = true;
......@@ -582,15 +585,22 @@ ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_map<std::string, Dat
faiss::ConcurrentBitsetPtr& bitset, const query::RangeQueryPtr& range_query) {
SegmentPtr segment_ptr;
segment_reader_->GetSegment(segment_ptr);
auto range_query_json = range_query->json_obj;
JSON_NULL_CHECK(range_query_json);
auto range_it = range_query_json.begin();
if (range_it != range_query_json.end()) {
const std::string& field_name = range_it.key();
knowhere::IndexPtr index_ptr = nullptr;
segment_ptr->GetStructuredIndex(field_name, index_ptr);
IndexedRangeQuery(bitset, attr_type.at(field_name), index_ptr, range_it.value());
try {
auto range_query_json = range_query->json_obj;
JSON_NULL_CHECK(range_query_json);
auto range_it = range_query_json.begin();
if (range_it != range_query_json.end()) {
const std::string& field_name = range_it.key();
knowhere::IndexPtr index_ptr = nullptr;
segment_ptr->GetStructuredIndex(field_name, index_ptr);
STATUS_CHECK(IndexedRangeQuery(bitset, attr_type.at(field_name), index_ptr, range_it.value()));
}
range_it++;
if (range_it != range_query_json.end()) {
return Status(SERVER_INVALID_DSL_PARAMETER, "Range query does not support multiple fields");
}
} catch (std::exception& ex) {
return Status{SERVER_INVALID_DSL_PARAMETER, ex.what()};
}
return Status::OK();
}
......
......@@ -46,6 +46,7 @@ INITIALIZE_EASYLOGGINGPP
static const char* COLLECTION_NAME = "test_milvus_web_collection";
static int64_t DIM = 128;
static int64_t NB = 100;
using OStatus = oatpp::web::protocol::http::Status;
using OString = milvus::server::web::OString;
......@@ -247,20 +248,20 @@ class TestClient : public oatpp::web::client::ApiClient {
API_CALL("OPTIONS", "/collections/{collection_name}", optionsCollection,
PATH(String, collection_name, "collection_name"))
API_CALL("GET", "/collections/{collection_name}", getCollection, PATH(String, collection_name, "collection_name"),
QUERY(String, info))
API_CALL("GET", "/collections/{collection_name}", getCollection, PATH(String, collection_name, "collection_name"))
API_CALL("DELETE", "/collections/{collection_name}", dropCollection,
PATH(String, collection_name, "collection_name"))
API_CALL("OPTIONS", "/collections/{collection_name}/fields/{field_name}/indexes", optionsIndexes,
PATH(String, collection_name, "collection_name"))
PATH(String, collection_name, "collection_name"), PATH(String, field_name, "field_name"))
API_CALL("POST", "/collections/{collection_name}/fields/{field_name}/indexes", createIndex,
PATH(String, collection_name, "collection_name"), BODY_STRING(OString, body))
PATH(String, collection_name, "collection_name"), PATH(String, field_name, "field_name"),
BODY_STRING(OString, body))
API_CALL("DELETE", "/collections/{collection_name}/fields/{field_name}/indexes", dropIndex,
PATH(String, collection_name, "collection_name"))
PATH(String, collection_name, "collection_name"), PATH(String, field_name, "field_name"))
API_CALL("OPTIONS", "/collections/{collection_name}/partitions", optionsPartitions,
PATH(String, collection_name, "collection_name"))
......@@ -418,7 +419,7 @@ CreateCollection(const TestClientP& client_ptr, const TestConnP& connection_ptr,
mapping_json = nlohmann::json::parse(mapping_str);
mapping_json["collection_name"] = collection_name;
mapping_json["auto_id"] = auto_id;
auto response = client_ptr->getCollection(collection_name.c_str(), "", connection_ptr);
auto response = client_ptr->getCollection(collection_name.c_str(), connection_ptr);
if (OStatus::CODE_200.code == response->getStatusCode()) {
return;
}
......@@ -516,7 +517,7 @@ TEST_F(WebControllerTest, GET_COLLECTION_INFO) {
OQueryParams params;
auto response = client_ptr->getCollection(collection_name.c_str(), "", connection_ptr);
auto response = client_ptr->getCollection(collection_name.c_str(), connection_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
auto json_response = nlohmann::json::parse(response->readBodyToString()->c_str());
ASSERT_EQ(collection_name, json_response["collection_name"]);
......@@ -524,7 +525,7 @@ TEST_F(WebControllerTest, GET_COLLECTION_INFO) {
// invalid collection name
collection_name = "57474dgdfhdfhdh dgd";
response = client_ptr->getCollection(collection_name.c_str(), "", connection_ptr);
response = client_ptr->getCollection(collection_name.c_str(), connection_ptr);
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
auto status_sto = response->readBodyToDto<milvus::server::web::StatusDtoT>(object_mapper.get());
ASSERT_EQ(milvus::server::web::StatusCode::ILLEGAL_COLLECTION_NAME, status_sto->code);
......@@ -802,14 +803,46 @@ TEST_F(WebControllerTest, SEARCH) {
}
TEST_F(WebControllerTest, INDEX) {
// auto collection_name = "test_insert_collection_test" + RandomName();
// nlohmann::json mapping_json;
// CreateCollection(client_ptr, connection_ptr, collection_name, mapping_json);
//
// // test index with imcomplete param
// nlohmann::json index_json;
// auto response = client_ptr->createIndex(collection_name.c_str(), index_json.dump().c_str(), connection_ptr);
// ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
auto collection_name = "test_index_collection_test" + RandomName();
nlohmann::json mapping_json;
CreateCollection(client_ptr, connection_ptr, collection_name, mapping_json);
nlohmann::json insert_json;
GenEntities(NB, DIM, insert_json);
auto response = client_ptr->insert(collection_name.c_str(), insert_json.dump().c_str(), connection_ptr);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
auto result_dto = response->readBodyToDto<milvus::server::web::EntityIdsDtoT>(object_mapper.get());
ASSERT_EQ(NB, result_dto->ids->size());
// test index with imcomplete param
std::string field_name = "field_vec";
std::string index_str = R"({
"metric_type": "L2",
"index_type": "IVF_FLAT",
"params": {
"nlist": 1024
}
})";
nlohmann::json index_json;
response =
client_ptr->createIndex(collection_name.c_str(), field_name.c_str(), index_json.dump().c_str(), connection_ptr);
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
index_json = nlohmann::json::parse(index_str);
response =
client_ptr->createIndex(collection_name.c_str(), field_name.c_str(), index_json.dump().c_str(), connection_ptr);
ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode());
response = client_ptr->getCollection(collection_name.c_str(), connection_ptr);
nlohmann::json collection_json = nlohmann::json::parse(response->readBodyToString()->std_str());
std::cout << collection_json.dump() << std::endl;
for (const auto& field_json : collection_json["fields"]) {
if (field_json["field_name"] == "field_vec") {
nlohmann::json index_params = field_json["index_params"];
ASSERT_EQ(index_params["index_type"].get<std::string>(), milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT);
break;
}
}
//
// // index_json["index_type"] = milvus::server::web::IndexMap.at(milvus::engine::FAISS_IDMAP);
//
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册