diff --git a/CHANGELOG.md b/CHANGELOG.md index e74ccd528d4721bbbcdfddd41334ccab80d41348..d7dd6768a45f064fb492d0d1116f4418ff828bee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ Please mark all change in change log and use the issue from GitHub - \#1067 - Add binary vectors support in http server - \#1075 - improve error message when page size or offset is illegal - \#1082 - check page_size or offset value to avoid float +- \#1115 - http server support load table into memory ## Feature - \#216 - Add CLI to get server info diff --git a/core/src/server/web_impl/controller/WebController.hpp b/core/src/server/web_impl/controller/WebController.hpp index 9221b75975b00a4f459bac6dd2866d2fb4eda105..bd8a7d4fa3e2422313d2a0a2793b2fd3568b80bf 100644 --- a/core/src/server/web_impl/controller/WebController.hpp +++ b/core/src/server/web_impl/controller/WebController.hpp @@ -783,14 +783,14 @@ class WebController : public oatpp::web::server::api::ApiController { ADD_CORS(SystemMsg) - ENDPOINT("GET", "/system/{msg}", SystemMsg, PATH(String, msg)) { + ENDPOINT("GET", "/system/{msg}", SystemMsg, PATH(String, msg), QUERIES(const QueryParams&, query_params)) { TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/system/" + msg->std_str() + "\'"); tr.RecordSection("Received request."); auto cmd_dto = CommandDto::createShared(); WebRequestHandler handler = WebRequestHandler(); - auto status_dto = handler.Cmd(msg, cmd_dto); + auto status_dto = handler.Cmd(msg, query_params, cmd_dto); std::shared_ptr response; switch (status_dto->code->getValue()) { case StatusCode::SUCCESS: diff --git a/core/src/server/web_impl/dto/CmdDto.hpp b/core/src/server/web_impl/dto/CmdDto.hpp index db4bd94b5d6552e4d538b788744cab23987ec23d..4a5e7f9182510a120bf54356dbcb4e6e8090f463 100644 --- a/core/src/server/web_impl/dto/CmdDto.hpp +++ b/core/src/server/web_impl/dto/CmdDto.hpp @@ -32,12 +32,6 @@ class CommandDto: public oatpp::data::mapping::type::Object { DTO_FIELD(String, reply, "reply"); }; -class CmdFieldsDto : public OObject { - DTO_INIT(CmdFieldsDto, Object) - - DTO_FIELD(Fields::ObjectWrapper, reply); -}; - #include OATPP_CODEGEN_END(DTO) } // namespace web diff --git a/core/src/server/web_impl/dto/VectorDto.hpp b/core/src/server/web_impl/dto/VectorDto.hpp index 96fb469cafe77b335d1a4e7b400cdb930aa035e3..8b3b5962033783d23e542b8067713e9e391342b8 100644 --- a/core/src/server/web_impl/dto/VectorDto.hpp +++ b/core/src/server/web_impl/dto/VectorDto.hpp @@ -26,18 +26,6 @@ namespace web { #include OATPP_CODEGEN_BEGIN(DTO) -class RowRecordDto : public oatpp::data::mapping::type::Object { - DTO_INIT(RowRecordDto, Object) - - DTO_FIELD(List::ObjectWrapper, record); -}; - -class RecordsDto : public oatpp::data::mapping::type::Object { - DTO_INIT(RecordsDto, Object) - - DTO_FIELD(List::ObjectWrapper, records); -}; - class SearchRequestDto : public OObject { DTO_INIT(SearchRequestDto, Object) diff --git a/core/src/server/web_impl/handler/WebRequestHandler.cpp b/core/src/server/web_impl/handler/WebRequestHandler.cpp index 3394f8388c4887a50ab8f3128ca621f6c125abb3..017f9757947b6a55b7c145159577860fedef08cd 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.cpp +++ b/core/src/server/web_impl/handler/WebRequestHandler.cpp @@ -746,15 +746,40 @@ WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::Obj } StatusDto::ObjectWrapper -WebRequestHandler::Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto) { +WebRequestHandler::Cmd(const OString& cmd, const OQueryParams& query_params, CommandDto::ObjectWrapper& cmd_dto) { std::string info = cmd->std_str(); + auto status = Status::OK(); + + // TODO: (yhz) now only support load table into memory, may remove in the future + if ("task" == info) { + auto action = query_params.get("action"); + if (nullptr == action.get()) { + RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'action\' is required in url \'/system/task\'"); + } + std::string action_str = action->std_str(); + + auto target = query_params.get("target"); + if (nullptr == target.get()) { + RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'target\' is required in url \'/system/task\'"); + } + std::string target_str = target->std_str(); + + if ("load" == action_str) { + status = request_handler_.PreloadTable(context_ptr_, target_str); + } else { + std::string error_msg = std::string("Unknown action value \'") + action_str + "\'"; + RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, error_msg.c_str()); + } + + ASSIGN_RETURN_STATUS_DTO(status) + } if ("info" == info) { info = "get_system_info"; } std::string reply_str; - auto status = CommandLine(info, reply_str); + status = CommandLine(info, reply_str); if (status.ok()) { cmd_dto->reply = reply_str.c_str(); diff --git a/core/src/server/web_impl/handler/WebRequestHandler.h b/core/src/server/web_impl/handler/WebRequestHandler.h index 90d45276e2d5d9f75ac6333871cd258336cf6576..23bfd720d1092bd6c0b0865817abc63cf89f1072 100644 --- a/core/src/server/web_impl/handler/WebRequestHandler.h +++ b/core/src/server/web_impl/handler/WebRequestHandler.h @@ -152,7 +152,7 @@ class WebRequestHandler { TopkResultsDto::ObjectWrapper& results_dto); StatusDto::ObjectWrapper - Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto); + Cmd(const OString& cmd, const OQueryParams& query_params, CommandDto::ObjectWrapper& cmd_dto); public: WebRequestHandler& diff --git a/core/unittest/server/test_web.cpp b/core/unittest/server/test_web.cpp index 6cf98be44182963d5782d3055b392bfbea44f045..d083c13edd81891caf7d8dd38b510d3bf4827014 100644 --- a/core/unittest/server/test_web.cpp +++ b/core/unittest/server/test_web.cpp @@ -413,12 +413,13 @@ TEST_F(WebHandlerTest, CMD) { auto cmd_dto = milvus::server::web::CommandDto::createShared(); cmd = "status"; - auto status_dto = handler->Cmd(cmd, cmd_dto); + OQueryParams query_params; + auto status_dto = handler->Cmd(cmd, query_params, cmd_dto); ASSERT_EQ(0, status_dto->code->getValue()); ASSERT_EQ("OK", cmd_dto->reply->std_str()); cmd = "version"; - status_dto = handler->Cmd(cmd, cmd_dto); + status_dto = handler->Cmd(cmd, query_params, cmd_dto); ASSERT_EQ(0, status_dto->code->getValue()); ASSERT_EQ("0.6.0", cmd_dto->reply->std_str()); } @@ -547,7 +548,7 @@ class TestClient : public oatpp::web::client::ApiClient { API_CALL("PUT", "/tables/{table_name}/vectors", search, PATH(String, table_name, "table_name"), BODY_DTO(milvus::server::web::SearchRequestDto::ObjectWrapper, body)) - API_CALL("GET", "/system/{msg}", cmd, PATH(String, cmd_str, "msg")) + API_CALL("GET", "/system/{msg}", cmd, PATH(String, cmd_str, "msg"), QUERY(String, action), QUERY(String, target)) #include OATPP_CODEGEN_END(ApiClient) }; @@ -1122,11 +1123,27 @@ TEST_F(WebControllerTest, SEARCH_BIN) { } TEST_F(WebControllerTest, CMD) { - auto response = client_ptr->cmd("status", conncetion_ptr); + auto response = client_ptr->cmd("status", "", "", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + response = client_ptr->cmd("version", "", "", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + response = client_ptr->cmd("mode", "", "", conncetion_ptr); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); - response = client_ptr->cmd("version", conncetion_ptr); + response = client_ptr->cmd("tasktable", "", "", conncetion_ptr); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + response = client_ptr->cmd("info", "", "", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + + GenTable("test_cmd", 16, 10, "L2"); + response = client_ptr->cmd("task", "load", "test_cmd", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); + // task without existing table + response = client_ptr->cmd("task", "load", "test_cmdXXXXXXXXXXXX", conncetion_ptr); + ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode()); } TEST_F(WebControllerTest, ADVANCED_CONFIG) { diff --git a/tests/milvus_python_test/test_add_vectors.py b/tests/milvus_python_test/test_add_vectors.py index e43a6b2cf6f2e923784b1dbb4d089fefe2da14c3..ad6d099d5fbbed73214070181a5a838801d501ab 100644 --- a/tests/milvus_python_test/test_add_vectors.py +++ b/tests/milvus_python_test/test_add_vectors.py @@ -426,9 +426,9 @@ class TestAddBase: nq = 5 vectors = gen_vectors(nq, dim) vector_id = get_vector_id - ids = [vector_id for i in range(nq)] - with pytest.raises(Exception) as e: - status, ids = connect.add_vectors(table, vectors, ids) + ids = [vector_id for _ in range(nq)] + with pytest.raises(Exception): + connect.add_vectors(table, vectors, ids) @pytest.mark.timeout(ADD_TIMEOUT) def test_add_vectors(self, connect, table): @@ -591,7 +591,7 @@ class TestAddBase: 'dimension': dim, 'index_file_size': index_file_size, 'metric_type': MetricType.L2} - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) milvus.create_table(param) vector = gen_single_vector(dim) @@ -599,7 +599,7 @@ class TestAddBase: loop_num = 5 processes = [] def add(): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) i = 0 while i < loop_num: diff --git a/tests/milvus_python_test/test_connect.py b/tests/milvus_python_test/test_connect.py index f3a99c7d437c11ca862564fc88b7c45e4bfcb482..ec26aa7247c8746a6d2ae992899a1b1bd7cff1c1 100644 --- a/tests/milvus_python_test/test_connect.py +++ b/tests/milvus_python_test/test_connect.py @@ -36,7 +36,7 @@ class TestConnect: expected: raise an error after disconnected ''' if not connect.connected(): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) milvus.connect(uri=uri_value) res = milvus.disconnect() @@ -53,7 +53,7 @@ class TestConnect: method: set correct ip and port expected: connected is True ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(host=args["ip"], port=args["port"]) assert milvus.connected() @@ -63,7 +63,7 @@ class TestConnect: method: set correct ip and port expected: connected is False ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(host=args["ip"], port=args["port"]) milvus.disconnect() assert not milvus.connected() @@ -75,7 +75,7 @@ class TestConnect: method: set host localhost expected: connected is True ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(host='localhost', port=args["port"]) assert milvus.connected() @@ -86,7 +86,7 @@ class TestConnect: method: set host null expected: not use default ip, connected is False ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) ip = "" with pytest.raises(Exception) as e: milvus.connect(host=ip, port=args["port"], timeout=1) @@ -98,7 +98,7 @@ class TestConnect: method: uri format and value are both correct expected: connected is True ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) milvus.connect(uri=uri_value) assert milvus.connected() @@ -109,7 +109,7 @@ class TestConnect: method: uri set null expected: connected is True ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "" if self.local_ip(args): @@ -128,7 +128,7 @@ class TestConnect: method: set uri port null expected: connected is True ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:" % args["ip"] with pytest.raises(Exception) as e: milvus.connect(uri=uri_value, timeout=1) @@ -141,7 +141,7 @@ class TestConnect: method: set uri ip null expected: connected is True ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://:%s" % args["port"] with pytest.raises(Exception) as e: @@ -166,7 +166,7 @@ class TestConnect: assert milvus.connected() for i in range(process_num): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) p = Process(target=connect, args=(milvus, )) processes.append(p) p.start() @@ -179,7 +179,7 @@ class TestConnect: method: connect again expected: status.code is 0, and status.message shows have connected already ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) milvus.connect(uri=uri_value) @@ -192,7 +192,7 @@ class TestConnect: method: disconnect, and then connect, assert connect status expected: status.code is 0 ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) milvus.connect(uri=uri_value) @@ -207,7 +207,7 @@ class TestConnect: expected: status.code is 0 ''' times = 10 - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) milvus.connect(uri=uri_value) for i in range(times): @@ -233,7 +233,7 @@ class TestConnect: assert milvus.connected() for i in range(process_num): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) p = Process(target=connect, args=(milvus, )) processes.append(p) p.start() @@ -246,7 +246,7 @@ class TestConnect: method: port set "", check if wrong uri connection is ok expected: connect raise an exception and connected is false ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:39540" % args["ip"] with pytest.raises(Exception) as e: milvus.connect(host=args["ip"], port="", uri=uri_value) @@ -257,7 +257,7 @@ class TestConnect: method: host set "", check if correct uri connection is ok expected: connected is False ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) with pytest.raises(Exception) as e: milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1) @@ -270,7 +270,7 @@ class TestConnect: method: check if wrong uri connection is ok expected: connect raise an exception and connected is false ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) with pytest.raises(Exception) as e: res = milvus.connect(host=args["ip"], port=39540, uri=uri_value, timeout=1) @@ -326,7 +326,7 @@ class TestConnectIPInvalid(object): @pytest.mark.level(2) @pytest.mark.timeout(CONNECT_TIMEOUT) def test_connect_with_invalid_ip(self, args, get_invalid_ip): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) ip = get_invalid_ip with pytest.raises(Exception) as e: milvus.connect(host=ip, port=args["port"], timeout=1) @@ -353,7 +353,7 @@ class TestConnectPortInvalid(object): method: set port in gen_invalid_ports expected: connected is False ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) port = get_invalid_port with pytest.raises(Exception) as e: milvus.connect(host=args["ip"], port=port, timeout=1) @@ -373,13 +373,13 @@ class TestConnectURIInvalid(object): @pytest.mark.level(2) @pytest.mark.timeout(CONNECT_TIMEOUT) - def test_connect_with_invalid_uri(self, get_invalid_uri): + def test_connect_with_invalid_uri(self, get_invalid_uri, args): ''' target: test uri connect with invalid uri value method: set port in gen_invalid_uris expected: connected is False ''' - milvus = get_milvus() + milvus = get_milvus(args["handler"]) uri_value = get_invalid_uri with pytest.raises(Exception) as e: milvus.connect(uri=uri_value, timeout=1) diff --git a/tests/milvus_python_test/test_index.py b/tests/milvus_python_test/test_index.py index cf03f8568649b4bd9fd3716188ec05152c5d2c9c..594cdb82ef3073393b58b245bb402be3bdc806a6 100644 --- a/tests/milvus_python_test/test_index.py +++ b/tests/milvus_python_test/test_index.py @@ -146,7 +146,7 @@ class TestIndexBase: uri = "tcp://%s:%s" % (args["ip"], args["port"]) for i in range(process_num): - m = get_milvus() + m = get_milvus(args["handler"]) m.connect(uri=uri) p = Process(target=build, args=(m,)) processes.append(p) @@ -205,7 +205,7 @@ class TestIndexBase: uri = "tcp://%s:%s" % (args["ip"], args["port"]) for i in range(process_num): - m = get_milvus() + m = get_milvus(args["handler"]) m.connect(uri=uri) ids = i p = Process(target=create_index, args=(m,ids)) @@ -669,7 +669,7 @@ class TestIndexIP: uri = "tcp://%s:%s" % (args["ip"], args["port"]) for i in range(process_num): - m = get_milvus() + m = get_milvus(args["handler"]) m.connect(uri=uri) p = Process(target=build, args=(m,)) processes.append(p) @@ -726,7 +726,7 @@ class TestIndexIP: uri = "tcp://%s:%s" % (args["ip"], args["port"]) for i in range(process_num): - m = get_milvus() + m = get_milvus(args["handler"]) m.connect(uri=uri) ids = i p = Process(target=create_index, args=(m,ids)) diff --git a/tests/milvus_python_test/test_mix.py b/tests/milvus_python_test/test_mix.py index f7baa8dd22cf96506e892b358d2a3c962db5465f..a5ac8c5821f5b482ea30d26055de86662828ff29 100644 --- a/tests/milvus_python_test/test_mix.py +++ b/tests/milvus_python_test/test_mix.py @@ -32,7 +32,7 @@ class TestMixBase: query_vecs = [vectors[0], vectors[1]] uri = "tcp://%s:%s" % (args["ip"], args["port"]) id_0 = 0; id_1 = 0 - milvus_instance = get_milvus() + milvus_instance = get_milvus(args["handler"]) milvus_instance.connect(uri=uri) milvus_instance.create_table({'table_name': table, 'dimension': dim, @@ -60,11 +60,11 @@ class TestMixBase: logging.getLogger().info(status) assert result[0][0].id == id_0 assert result[1][0].id == id_1 - milvus_instance = get_milvus() + milvus_instance = get_milvus(args["handler"]) milvus_instance.connect(uri=uri) p_search = Process(target=search, args=(milvus_instance, )) p_search.start() - milvus_instance = get_milvus() + milvus_instance = get_milvus(args["handler"]) milvus_instance.connect(uri=uri) p_create = Process(target=add_vectors, args=(milvus_instance, )) p_create.start() diff --git a/tests/milvus_python_test/test_search_vectors.py b/tests/milvus_python_test/test_search_vectors.py index 237e4521c6470d817edfa147facc288d88c968cb..6deede3db6a6b568ba77488316b13f8b9e37c00c 100644 --- a/tests/milvus_python_test/test_search_vectors.py +++ b/tests/milvus_python_test/test_search_vectors.py @@ -743,7 +743,7 @@ class TestSearchBase: 'index_type': IndexType.FLAT, 'store_raw_vector': False} # create table - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) milvus.create_table(param) vectors, ids = self.init_data(milvus, table, nb=nb) @@ -756,7 +756,7 @@ class TestSearchBase: assert result[i][0].distance == 0.0 for i in range(process_num): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) p = Process(target=search, args=(milvus, )) processes.append(p) @@ -784,7 +784,7 @@ class TestSearchBase: 'index_file_size': 10, 'metric_type': MetricType.L2} # create table - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) milvus.create_table(param) status, ids = milvus.add_vectors(table, vectors) @@ -826,7 +826,7 @@ class TestSearchBase: 'index_file_size': 10, 'metric_type': MetricType.L2} # create table - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) milvus.create_table(param) status, ids = milvus.add_vectors(table, vectors) diff --git a/tests/milvus_python_test/test_table.py b/tests/milvus_python_test/test_table.py index 84c333f520fb093a96177b9e2f0fa3bebaf62492..588a180c3a35fd2199f5e4572bfd666a4a198902 100644 --- a/tests/milvus_python_test/test_table.py +++ b/tests/milvus_python_test/test_table.py @@ -278,7 +278,7 @@ class TestTable: process_num = 4 processes = [] for i in range(process_num): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) p = Process(target=describetable, args=(milvus,)) processes.append(p) @@ -458,7 +458,7 @@ class TestTable: assert status.OK() for i in range(process_num): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) p = Process(target=deletetable, args=(milvus,)) processes.append(p) @@ -711,7 +711,7 @@ class TestTable: processes = [] for i in range(process_num): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) p = Process(target=showtables, args=(milvus,)) processes.append(p) diff --git a/tests/milvus_python_test/test_table_count.py b/tests/milvus_python_test/test_table_count.py index feb2156edc655dc0de57f4104e59f8ceeabd3a50..10bc4e5a3c4bb8119614c7488db1ba475b036195 100644 --- a/tests/milvus_python_test/test_table_count.py +++ b/tests/milvus_python_test/test_table_count.py @@ -206,7 +206,7 @@ class TestTableCount: process_num = 8 processes = [] for i in range(process_num): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) p = Process(target=rows_count, args=(milvus, )) processes.append(p) @@ -351,7 +351,7 @@ class TestTableCountIP: process_num = 8 processes = [] for i in range(process_num): - milvus = get_milvus() + milvus = get_milvus(args["handler"]) milvus.connect(uri=uri) p = Process(target=rows_count, args=(milvus,)) processes.append(p)