diff --git a/core/src/db/engine/ExecutionEngineImpl.cpp b/core/src/db/engine/ExecutionEngineImpl.cpp index a9cde3af04a1a0836405b92cae9f828dcd1345d4..5fe51a1591a46be95e1b054834f09a2e0c034acb 100644 --- a/core/src/db/engine/ExecutionEngineImpl.cpp +++ b/core/src/db/engine/ExecutionEngineImpl.cpp @@ -415,7 +415,6 @@ ExecutionEngineImpl::ExecBinaryQuery(const milvus::query::GeneralQueryPtr& gener return status; } else { if (general_query->leaf->term_query != nullptr) { - // process attrs_data bitset = std::make_shared(entity_count_); STATUS_CHECK(ProcessTermQuery(bitset, general_query->leaf->term_query, attr_type)); } @@ -512,6 +511,10 @@ ExecutionEngineImpl::ProcessTermQuery(faiss::ConcurrentBitsetPtr& bitset, const STATUS_CHECK(IndexedTermQuery(bitset, field_name, attr_type.at(field_name), term_it.value())); } } + term_it++; + if (term_it != term_query_json.end()) { + return Status(SERVER_INVALID_DSL_PARAMETER, "Term query does not support multiple fields"); + } } catch (std::exception& ex) { return Status{SERVER_INVALID_DSL_PARAMETER, ex.what()}; } @@ -528,7 +531,7 @@ ProcessIndexedRangeQuery(faiss::ConcurrentBitsetPtr& bitset, knowhere::IndexPtr& bool flag = false; for (auto& range_value_it : range_values_json.items()) { const std::string& comp_op = range_value_it.key(); - T value = range_value_it.value().get(); + T value = range_value_it.value(); if (not flag) { bitset = (*bitset) | T_index->Range(value, knowhere::s_map_operator_type.at(comp_op)); flag = true; @@ -582,15 +585,22 @@ ExecutionEngineImpl::ProcessRangeQuery(const std::unordered_mapGetSegment(segment_ptr); - - auto range_query_json = range_query->json_obj; - JSON_NULL_CHECK(range_query_json); - auto range_it = range_query_json.begin(); - if (range_it != range_query_json.end()) { - const std::string& field_name = range_it.key(); - knowhere::IndexPtr index_ptr = nullptr; - segment_ptr->GetStructuredIndex(field_name, index_ptr); - IndexedRangeQuery(bitset, attr_type.at(field_name), index_ptr, range_it.value()); + try { + auto range_query_json = range_query->json_obj; + JSON_NULL_CHECK(range_query_json); + auto range_it = range_query_json.begin(); + if (range_it != range_query_json.end()) { + const std::string& field_name = range_it.key(); + knowhere::IndexPtr index_ptr = nullptr; + segment_ptr->GetStructuredIndex(field_name, index_ptr); + STATUS_CHECK(IndexedRangeQuery(bitset, attr_type.at(field_name), index_ptr, range_it.value())); + } + range_it++; + if (range_it != range_query_json.end()) { + return Status(SERVER_INVALID_DSL_PARAMETER, "Range query does not support multiple fields"); + } + } catch (std::exception& ex) { + return Status{SERVER_INVALID_DSL_PARAMETER, ex.what()}; } return Status::OK(); } diff --git a/core/unittest/server/test_web.cpp b/core/unittest/server/test_web.cpp index e0c146c8f69f8562c83235cee9136c275155bafb..d2455731699d91f32e39b66405a6901916212080 100644 --- a/core/unittest/server/test_web.cpp +++ b/core/unittest/server/test_web.cpp @@ -46,6 +46,7 @@ INITIALIZE_EASYLOGGINGPP static const char* COLLECTION_NAME = "test_milvus_web_collection"; static int64_t DIM = 128; +static int64_t NB = 100; using OStatus = oatpp::web::protocol::http::Status; using OString = milvus::server::web::OString; @@ -247,20 +248,20 @@ class TestClient : public oatpp::web::client::ApiClient { API_CALL("OPTIONS", "/collections/{collection_name}", optionsCollection, PATH(String, collection_name, "collection_name")) - API_CALL("GET", "/collections/{collection_name}", getCollection, PATH(String, collection_name, "collection_name"), - QUERY(String, info)) + API_CALL("GET", "/collections/{collection_name}", getCollection, PATH(String, collection_name, "collection_name")) API_CALL("DELETE", "/collections/{collection_name}", dropCollection, PATH(String, collection_name, "collection_name")) API_CALL("OPTIONS", "/collections/{collection_name}/fields/{field_name}/indexes", optionsIndexes, - PATH(String, collection_name, "collection_name")) + PATH(String, collection_name, "collection_name"), PATH(String, field_name, "field_name")) API_CALL("POST", "/collections/{collection_name}/fields/{field_name}/indexes", createIndex, - PATH(String, collection_name, "collection_name"), BODY_STRING(OString, body)) + PATH(String, collection_name, "collection_name"), PATH(String, field_name, "field_name"), + BODY_STRING(OString, body)) API_CALL("DELETE", "/collections/{collection_name}/fields/{field_name}/indexes", dropIndex, - PATH(String, collection_name, "collection_name")) + PATH(String, collection_name, "collection_name"), PATH(String, field_name, "field_name")) API_CALL("OPTIONS", "/collections/{collection_name}/partitions", optionsPartitions, PATH(String, collection_name, "collection_name")) @@ -418,7 +419,7 @@ CreateCollection(const TestClientP& client_ptr, const TestConnP& connection_ptr, mapping_json = nlohmann::json::parse(mapping_str); mapping_json["collection_name"] = collection_name; mapping_json["auto_id"] = auto_id; - auto response = client_ptr->getCollection(collection_name.c_str(), "", connection_ptr); + auto response = client_ptr->getCollection(collection_name.c_str(), connection_ptr); if (OStatus::CODE_200.code == response->getStatusCode()) { return; } @@ -516,7 +517,7 @@ TEST_F(WebControllerTest, GET_COLLECTION_INFO) { OQueryParams params; - auto response = client_ptr->getCollection(collection_name.c_str(), "", connection_ptr); + auto response = client_ptr->getCollection(collection_name.c_str(), connection_ptr); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); auto json_response = nlohmann::json::parse(response->readBodyToString()->c_str()); ASSERT_EQ(collection_name, json_response["collection_name"]); @@ -524,7 +525,7 @@ TEST_F(WebControllerTest, GET_COLLECTION_INFO) { // invalid collection name collection_name = "57474dgdfhdfhdh dgd"; - response = client_ptr->getCollection(collection_name.c_str(), "", connection_ptr); + response = client_ptr->getCollection(collection_name.c_str(), connection_ptr); ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); auto status_sto = response->readBodyToDto(object_mapper.get()); ASSERT_EQ(milvus::server::web::StatusCode::ILLEGAL_COLLECTION_NAME, status_sto->code); @@ -802,14 +803,46 @@ TEST_F(WebControllerTest, SEARCH) { } TEST_F(WebControllerTest, INDEX) { - // auto collection_name = "test_insert_collection_test" + RandomName(); - // nlohmann::json mapping_json; - // CreateCollection(client_ptr, connection_ptr, collection_name, mapping_json); - // - // // test index with imcomplete param - // nlohmann::json index_json; - // auto response = client_ptr->createIndex(collection_name.c_str(), index_json.dump().c_str(), connection_ptr); - // ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + auto collection_name = "test_index_collection_test" + RandomName(); + nlohmann::json mapping_json; + CreateCollection(client_ptr, connection_ptr, collection_name, mapping_json); + + nlohmann::json insert_json; + GenEntities(NB, DIM, insert_json); + auto response = client_ptr->insert(collection_name.c_str(), insert_json.dump().c_str(), connection_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + auto result_dto = response->readBodyToDto(object_mapper.get()); + ASSERT_EQ(NB, result_dto->ids->size()); + + // test index with imcomplete param + std::string field_name = "field_vec"; + std::string index_str = R"({ + "metric_type": "L2", + "index_type": "IVF_FLAT", + "params": { + "nlist": 1024 + } + })"; + nlohmann::json index_json; + response = + client_ptr->createIndex(collection_name.c_str(), field_name.c_str(), index_json.dump().c_str(), connection_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); + + index_json = nlohmann::json::parse(index_str); + response = + client_ptr->createIndex(collection_name.c_str(), field_name.c_str(), index_json.dump().c_str(), connection_ptr); + ASSERT_EQ(OStatus::CODE_201.code, response->getStatusCode()); + + response = client_ptr->getCollection(collection_name.c_str(), connection_ptr); + nlohmann::json collection_json = nlohmann::json::parse(response->readBodyToString()->std_str()); + std::cout << collection_json.dump() << std::endl; + for (const auto& field_json : collection_json["fields"]) { + if (field_json["field_name"] == "field_vec") { + nlohmann::json index_params = field_json["index_params"]; + ASSERT_EQ(index_params["index_type"].get(), milvus::knowhere::IndexEnum::INDEX_FAISS_IVFFLAT); + break; + } + } // // // index_json["index_type"] = milvus::server::web::IndexMap.at(milvus::engine::FAISS_IDMAP); //