未验证 提交 f10f6cd5 编写于 作者: B BossZou 提交者: GitHub

Support run dev test with http handler in python SDK (#1116)

* refactoring(create_table done)

* refactoring

* refactor server delivery (insert done)

* refactoring server module (count_table done)

* server refactor done

* cmake pass

* refactor server module done.

* set grpc response status correctly

* format done.

* fix redefine ErrorMap()

* optimize insert reducing ids data copy

* optimize grpc request with reducing data copy

* clang format

* [skip ci] Refactor server module done. update changlog. prepare for PR

* remove explicit and change int32_t to int64_t

* add web server

* [skip ci] add license in web module

* modify header include & comment oatpp environment config

* add port configure & create table in handler

* modify web url

* simple url complation done & add swagger

* make sure web url

* web functionality done. debuging

* add web unittest

* web test pass

* add web server port

* add web server port in template

* update unittest cmake file

* change web server default port to 19121

* rename method in web module & unittest pass

* add search case in unittest for web module

* rename some variables

* fix bug

* unittest pass

* web prepare

* fix cmd bug(check server status)

* update changlog

* add web port validate & default set

* clang-format pass

* add web port test in unittest

* add CORS & redirect root to swagger ui

* add web status

* web table method func cascade test pass

* add config url in web module

* modify thirdparty cmake to avoid building oatpp test

* clang format

* update changlog

* add constants in web module

* reserve Config.cpp

* fix constants reference bug

* replace web server with async module

* modify component to support async

* format

* developing controller & add test clent into unittest

* add web port into demo/server_config

* modify thirdparty cmake to allow build test

* remove  unnecessary comment

* add endpoint info in controller

* finish web test(bug here)

* clang format

* add web test cpp to lint exclusions

* check null field in GetConfig

* add macro RETURN STATUS DTo

* fix cmake conflict

* fix crash when exit server

* remove surplus comments & add http param check

* add uri /docs to direct swagger

* format

* change cmd to system

* add default value & unittest in web module

* add macros to judge if GPU supported

* add macros in unit & add default in index dto & print error message when bind http port fail

* format (fix #788)

* fix cors bug (not completed)

* comment cors

* change web framework to simple api

* comments optimize

* change to simple API

* remove comments in controller.hpp

* remove EP_COMMON_CMAKE_ARGS in oatpp and oatpp-swagger

* add ep cmake args to sqlite

* clang-format

* change a format

* test pass

* change name to

* fix compiler issue(oatpp-swagger depend on oatpp)

* add & in start_server.h

* specify lib location with oatpp and oatpp-swagger

* add comments

* add swagger definition

* [skip ci] change http method options status code

* remove oatpp swagger(fix #970)

* remove comments

* check Start web behavior

* add default to cpu_cache_capacity

* remove swagger component.hpp & /docs url

* remove /docs info

* remove /docs in unittest

* remove space in test rpc

* remove repeate info in CHANGLOG

* change cache_insert_data default value as a constant

* [skip ci] Fix some broken links (#960)

* [skip ci] Fix broken link

* [skip ci] Fix broken link

* [skip ci] Fix broken link

* [skip ci] Fix broken links

* fix issue 373 (#964)

* fix issue 373

* Adjustment format

* Adjustment format

* Adjustment format

* change readme

* #966 update NOTICE.md (#967)

* remove comments

* check Start web behavior

* add default to cpu_cache_capacity

* remove swagger component.hpp & /docs url

* remove /docs info

* remove /docs in unittest

* remove space in test rpc

* remove repeate info in CHANGLOG

* change cache_insert_data default value as a constant

* adjust web port cofig place

* rename web_port variable

* change gpu resources invoke way to cmd()

* set advanced config name add DEFAULT

* change config setting to cmd

* modify ..

* optimize code

* assign TableDto' count default value 0 (fix #995)

* check if table exists when show partitions (fix #1028)

* check table exists when drop partition (fix #1029)

* check if partition name is legal (fix #1022)

* modify status code when partition tag is illegal

* update changlog

* add info to /system url

* add binary index and add bin uri & handler method(not completed)

* optimize http insert and search time(fix #1066) | add binary vectors support(fix #1067)

* fix test partition bug

* fix test bug when check insert records

* add binary vectors test

* add default for offset and page_size

* fix uinttest bug

* [skip ci] remove comments

* optimize web code for PR comments

* add new folder named utils

* check offset and pagesize (fix #1082)

* improve error message if offset or page_size is not legal (fix #1075)

* add log into web module

* update changlog

* check gpu sources setting when assign repeated value (fix #990)

* update changlog

* clang-format pass

* add default handler in http handler

* [skip ci] improve error msg when check gpu resources

* change check offset way

* remove func IsIntStr

* add case

* change int32 to int64 when check number str

* add log in we module(doing)

* update test case

* add log in web controller

* remove surplus dot

* add preload into /system/

* change get_milvus() to get_milvus(args['handler'])

* support load table into memory with http server (fix #1115)

* [skip ci] comment surplus dto in VectorDto
Co-authored-by: Njielinxu <52057195+jielinxu@users.noreply.github.com>
Co-authored-by: NJackLCL <53512883+JackLCL@users.noreply.github.com>
Co-authored-by: NCai Yudong <yudong.cai@zilliz.com>
上级 4dee7dfa
...@@ -21,6 +21,7 @@ Please mark all change in change log and use the issue from GitHub ...@@ -21,6 +21,7 @@ Please mark all change in change log and use the issue from GitHub
- \#1067 - Add binary vectors support in http server - \#1067 - Add binary vectors support in http server
- \#1075 - improve error message when page size or offset is illegal - \#1075 - improve error message when page size or offset is illegal
- \#1082 - check page_size or offset value to avoid float - \#1082 - check page_size or offset value to avoid float
- \#1115 - http server support load table into memory
## Feature ## Feature
- \#216 - Add CLI to get server info - \#216 - Add CLI to get server info
......
...@@ -783,14 +783,14 @@ class WebController : public oatpp::web::server::api::ApiController { ...@@ -783,14 +783,14 @@ class WebController : public oatpp::web::server::api::ApiController {
ADD_CORS(SystemMsg) ADD_CORS(SystemMsg)
ENDPOINT("GET", "/system/{msg}", SystemMsg, PATH(String, msg)) { ENDPOINT("GET", "/system/{msg}", SystemMsg, PATH(String, msg), QUERIES(const QueryParams&, query_params)) {
TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/system/" + msg->std_str() + "\'"); TimeRecorder tr(std::string(WEB_LOG_PREFIX) + "GET \'/system/" + msg->std_str() + "\'");
tr.RecordSection("Received request."); tr.RecordSection("Received request.");
auto cmd_dto = CommandDto::createShared(); auto cmd_dto = CommandDto::createShared();
WebRequestHandler handler = WebRequestHandler(); WebRequestHandler handler = WebRequestHandler();
auto status_dto = handler.Cmd(msg, cmd_dto); auto status_dto = handler.Cmd(msg, query_params, cmd_dto);
std::shared_ptr<OutgoingResponse> response; std::shared_ptr<OutgoingResponse> response;
switch (status_dto->code->getValue()) { switch (status_dto->code->getValue()) {
case StatusCode::SUCCESS: case StatusCode::SUCCESS:
......
...@@ -32,12 +32,6 @@ class CommandDto: public oatpp::data::mapping::type::Object { ...@@ -32,12 +32,6 @@ class CommandDto: public oatpp::data::mapping::type::Object {
DTO_FIELD(String, reply, "reply"); DTO_FIELD(String, reply, "reply");
}; };
class CmdFieldsDto : public OObject {
DTO_INIT(CmdFieldsDto, Object)
DTO_FIELD(Fields<String>::ObjectWrapper, reply);
};
#include OATPP_CODEGEN_END(DTO) #include OATPP_CODEGEN_END(DTO)
} // namespace web } // namespace web
......
...@@ -26,18 +26,6 @@ namespace web { ...@@ -26,18 +26,6 @@ namespace web {
#include OATPP_CODEGEN_BEGIN(DTO) #include OATPP_CODEGEN_BEGIN(DTO)
class RowRecordDto : public oatpp::data::mapping::type::Object {
DTO_INIT(RowRecordDto, Object)
DTO_FIELD(List<Float32>::ObjectWrapper, record);
};
class RecordsDto : public oatpp::data::mapping::type::Object {
DTO_INIT(RecordsDto, Object)
DTO_FIELD(List<RowRecordDto::ObjectWrapper>::ObjectWrapper, records);
};
class SearchRequestDto : public OObject { class SearchRequestDto : public OObject {
DTO_INIT(SearchRequestDto, Object) DTO_INIT(SearchRequestDto, Object)
......
...@@ -746,15 +746,40 @@ WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::Obj ...@@ -746,15 +746,40 @@ WebRequestHandler::Search(const OString& table_name, const SearchRequestDto::Obj
} }
StatusDto::ObjectWrapper StatusDto::ObjectWrapper
WebRequestHandler::Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto) { WebRequestHandler::Cmd(const OString& cmd, const OQueryParams& query_params, CommandDto::ObjectWrapper& cmd_dto) {
std::string info = cmd->std_str(); std::string info = cmd->std_str();
auto status = Status::OK();
// TODO: (yhz) now only support load table into memory, may remove in the future
if ("task" == info) {
auto action = query_params.get("action");
if (nullptr == action.get()) {
RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'action\' is required in url \'/system/task\'");
}
std::string action_str = action->std_str();
auto target = query_params.get("target");
if (nullptr == target.get()) {
RETURN_STATUS_DTO(QUERY_PARAM_LOSS, "Query param \'target\' is required in url \'/system/task\'");
}
std::string target_str = target->std_str();
if ("load" == action_str) {
status = request_handler_.PreloadTable(context_ptr_, target_str);
} else {
std::string error_msg = std::string("Unknown action value \'") + action_str + "\'";
RETURN_STATUS_DTO(ILLEGAL_QUERY_PARAM, error_msg.c_str());
}
ASSIGN_RETURN_STATUS_DTO(status)
}
if ("info" == info) { if ("info" == info) {
info = "get_system_info"; info = "get_system_info";
} }
std::string reply_str; std::string reply_str;
auto status = CommandLine(info, reply_str); status = CommandLine(info, reply_str);
if (status.ok()) { if (status.ok()) {
cmd_dto->reply = reply_str.c_str(); cmd_dto->reply = reply_str.c_str();
......
...@@ -152,7 +152,7 @@ class WebRequestHandler { ...@@ -152,7 +152,7 @@ class WebRequestHandler {
TopkResultsDto::ObjectWrapper& results_dto); TopkResultsDto::ObjectWrapper& results_dto);
StatusDto::ObjectWrapper StatusDto::ObjectWrapper
Cmd(const OString& cmd, CommandDto::ObjectWrapper& cmd_dto); Cmd(const OString& cmd, const OQueryParams& query_params, CommandDto::ObjectWrapper& cmd_dto);
public: public:
WebRequestHandler& WebRequestHandler&
......
...@@ -413,12 +413,13 @@ TEST_F(WebHandlerTest, CMD) { ...@@ -413,12 +413,13 @@ TEST_F(WebHandlerTest, CMD) {
auto cmd_dto = milvus::server::web::CommandDto::createShared(); auto cmd_dto = milvus::server::web::CommandDto::createShared();
cmd = "status"; cmd = "status";
auto status_dto = handler->Cmd(cmd, cmd_dto); OQueryParams query_params;
auto status_dto = handler->Cmd(cmd, query_params, cmd_dto);
ASSERT_EQ(0, status_dto->code->getValue()); ASSERT_EQ(0, status_dto->code->getValue());
ASSERT_EQ("OK", cmd_dto->reply->std_str()); ASSERT_EQ("OK", cmd_dto->reply->std_str());
cmd = "version"; cmd = "version";
status_dto = handler->Cmd(cmd, cmd_dto); status_dto = handler->Cmd(cmd, query_params, cmd_dto);
ASSERT_EQ(0, status_dto->code->getValue()); ASSERT_EQ(0, status_dto->code->getValue());
ASSERT_EQ("0.6.0", cmd_dto->reply->std_str()); ASSERT_EQ("0.6.0", cmd_dto->reply->std_str());
} }
...@@ -547,7 +548,7 @@ class TestClient : public oatpp::web::client::ApiClient { ...@@ -547,7 +548,7 @@ class TestClient : public oatpp::web::client::ApiClient {
API_CALL("PUT", "/tables/{table_name}/vectors", search, API_CALL("PUT", "/tables/{table_name}/vectors", search,
PATH(String, table_name, "table_name"), BODY_DTO(milvus::server::web::SearchRequestDto::ObjectWrapper, body)) PATH(String, table_name, "table_name"), BODY_DTO(milvus::server::web::SearchRequestDto::ObjectWrapper, body))
API_CALL("GET", "/system/{msg}", cmd, PATH(String, cmd_str, "msg")) API_CALL("GET", "/system/{msg}", cmd, PATH(String, cmd_str, "msg"), QUERY(String, action), QUERY(String, target))
#include OATPP_CODEGEN_END(ApiClient) #include OATPP_CODEGEN_END(ApiClient)
}; };
...@@ -1122,11 +1123,27 @@ TEST_F(WebControllerTest, SEARCH_BIN) { ...@@ -1122,11 +1123,27 @@ TEST_F(WebControllerTest, SEARCH_BIN) {
} }
TEST_F(WebControllerTest, CMD) { TEST_F(WebControllerTest, CMD) {
auto response = client_ptr->cmd("status", conncetion_ptr); auto response = client_ptr->cmd("status", "", "", conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
response = client_ptr->cmd("version", "", "", conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
response = client_ptr->cmd("mode", "", "", conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
response = client_ptr->cmd("version", conncetion_ptr); response = client_ptr->cmd("tasktable", "", "", conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode()); ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
response = client_ptr->cmd("info", "", "", conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
GenTable("test_cmd", 16, 10, "L2");
response = client_ptr->cmd("task", "load", "test_cmd", conncetion_ptr);
ASSERT_EQ(OStatus::CODE_200.code, response->getStatusCode());
// task without existing table
response = client_ptr->cmd("task", "load", "test_cmdXXXXXXXXXXXX", conncetion_ptr);
ASSERT_EQ(OStatus::CODE_400.code, response->getStatusCode());
} }
TEST_F(WebControllerTest, ADVANCED_CONFIG) { TEST_F(WebControllerTest, ADVANCED_CONFIG) {
......
...@@ -426,9 +426,9 @@ class TestAddBase: ...@@ -426,9 +426,9 @@ class TestAddBase:
nq = 5 nq = 5
vectors = gen_vectors(nq, dim) vectors = gen_vectors(nq, dim)
vector_id = get_vector_id vector_id = get_vector_id
ids = [vector_id for i in range(nq)] ids = [vector_id for _ in range(nq)]
with pytest.raises(Exception) as e: with pytest.raises(Exception):
status, ids = connect.add_vectors(table, vectors, ids) connect.add_vectors(table, vectors, ids)
@pytest.mark.timeout(ADD_TIMEOUT) @pytest.mark.timeout(ADD_TIMEOUT)
def test_add_vectors(self, connect, table): def test_add_vectors(self, connect, table):
...@@ -591,7 +591,7 @@ class TestAddBase: ...@@ -591,7 +591,7 @@ class TestAddBase:
'dimension': dim, 'dimension': dim,
'index_file_size': index_file_size, 'index_file_size': index_file_size,
'metric_type': MetricType.L2} 'metric_type': MetricType.L2}
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
milvus.create_table(param) milvus.create_table(param)
vector = gen_single_vector(dim) vector = gen_single_vector(dim)
...@@ -599,7 +599,7 @@ class TestAddBase: ...@@ -599,7 +599,7 @@ class TestAddBase:
loop_num = 5 loop_num = 5
processes = [] processes = []
def add(): def add():
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
i = 0 i = 0
while i < loop_num: while i < loop_num:
......
...@@ -36,7 +36,7 @@ class TestConnect: ...@@ -36,7 +36,7 @@ class TestConnect:
expected: raise an error after disconnected expected: raise an error after disconnected
''' '''
if not connect.connected(): if not connect.connected():
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value) milvus.connect(uri=uri_value)
res = milvus.disconnect() res = milvus.disconnect()
...@@ -53,7 +53,7 @@ class TestConnect: ...@@ -53,7 +53,7 @@ class TestConnect:
method: set correct ip and port method: set correct ip and port
expected: connected is True expected: connected is True
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(host=args["ip"], port=args["port"]) milvus.connect(host=args["ip"], port=args["port"])
assert milvus.connected() assert milvus.connected()
...@@ -63,7 +63,7 @@ class TestConnect: ...@@ -63,7 +63,7 @@ class TestConnect:
method: set correct ip and port method: set correct ip and port
expected: connected is False expected: connected is False
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(host=args["ip"], port=args["port"]) milvus.connect(host=args["ip"], port=args["port"])
milvus.disconnect() milvus.disconnect()
assert not milvus.connected() assert not milvus.connected()
...@@ -75,7 +75,7 @@ class TestConnect: ...@@ -75,7 +75,7 @@ class TestConnect:
method: set host localhost method: set host localhost
expected: connected is True expected: connected is True
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(host='localhost', port=args["port"]) milvus.connect(host='localhost', port=args["port"])
assert milvus.connected() assert milvus.connected()
...@@ -86,7 +86,7 @@ class TestConnect: ...@@ -86,7 +86,7 @@ class TestConnect:
method: set host null method: set host null
expected: not use default ip, connected is False expected: not use default ip, connected is False
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
ip = "" ip = ""
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
milvus.connect(host=ip, port=args["port"], timeout=1) milvus.connect(host=ip, port=args["port"], timeout=1)
...@@ -98,7 +98,7 @@ class TestConnect: ...@@ -98,7 +98,7 @@ class TestConnect:
method: uri format and value are both correct method: uri format and value are both correct
expected: connected is True expected: connected is True
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value) milvus.connect(uri=uri_value)
assert milvus.connected() assert milvus.connected()
...@@ -109,7 +109,7 @@ class TestConnect: ...@@ -109,7 +109,7 @@ class TestConnect:
method: uri set null method: uri set null
expected: connected is True expected: connected is True
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "" uri_value = ""
if self.local_ip(args): if self.local_ip(args):
...@@ -128,7 +128,7 @@ class TestConnect: ...@@ -128,7 +128,7 @@ class TestConnect:
method: set uri port null method: set uri port null
expected: connected is True expected: connected is True
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:" % args["ip"] uri_value = "tcp://%s:" % args["ip"]
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
milvus.connect(uri=uri_value, timeout=1) milvus.connect(uri=uri_value, timeout=1)
...@@ -141,7 +141,7 @@ class TestConnect: ...@@ -141,7 +141,7 @@ class TestConnect:
method: set uri ip null method: set uri ip null
expected: connected is True expected: connected is True
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://:%s" % args["port"] uri_value = "tcp://:%s" % args["port"]
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
...@@ -166,7 +166,7 @@ class TestConnect: ...@@ -166,7 +166,7 @@ class TestConnect:
assert milvus.connected() assert milvus.connected()
for i in range(process_num): for i in range(process_num):
milvus = get_milvus() milvus = get_milvus(args["handler"])
p = Process(target=connect, args=(milvus, )) p = Process(target=connect, args=(milvus, ))
processes.append(p) processes.append(p)
p.start() p.start()
...@@ -179,7 +179,7 @@ class TestConnect: ...@@ -179,7 +179,7 @@ class TestConnect:
method: connect again method: connect again
expected: status.code is 0, and status.message shows have connected already expected: status.code is 0, and status.message shows have connected already
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value) milvus.connect(uri=uri_value)
...@@ -192,7 +192,7 @@ class TestConnect: ...@@ -192,7 +192,7 @@ class TestConnect:
method: disconnect, and then connect, assert connect status method: disconnect, and then connect, assert connect status
expected: status.code is 0 expected: status.code is 0
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value) milvus.connect(uri=uri_value)
...@@ -207,7 +207,7 @@ class TestConnect: ...@@ -207,7 +207,7 @@ class TestConnect:
expected: status.code is 0 expected: status.code is 0
''' '''
times = 10 times = 10
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
milvus.connect(uri=uri_value) milvus.connect(uri=uri_value)
for i in range(times): for i in range(times):
...@@ -233,7 +233,7 @@ class TestConnect: ...@@ -233,7 +233,7 @@ class TestConnect:
assert milvus.connected() assert milvus.connected()
for i in range(process_num): for i in range(process_num):
milvus = get_milvus() milvus = get_milvus(args["handler"])
p = Process(target=connect, args=(milvus, )) p = Process(target=connect, args=(milvus, ))
processes.append(p) processes.append(p)
p.start() p.start()
...@@ -246,7 +246,7 @@ class TestConnect: ...@@ -246,7 +246,7 @@ class TestConnect:
method: port set "", check if wrong uri connection is ok method: port set "", check if wrong uri connection is ok
expected: connect raise an exception and connected is false expected: connect raise an exception and connected is false
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:39540" % args["ip"] uri_value = "tcp://%s:39540" % args["ip"]
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
milvus.connect(host=args["ip"], port="", uri=uri_value) milvus.connect(host=args["ip"], port="", uri=uri_value)
...@@ -257,7 +257,7 @@ class TestConnect: ...@@ -257,7 +257,7 @@ class TestConnect:
method: host set "", check if correct uri connection is ok method: host set "", check if correct uri connection is ok
expected: connected is False expected: connected is False
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1) milvus.connect(host="", port=args["port"], uri=uri_value, timeout=1)
...@@ -270,7 +270,7 @@ class TestConnect: ...@@ -270,7 +270,7 @@ class TestConnect:
method: check if wrong uri connection is ok method: check if wrong uri connection is ok
expected: connect raise an exception and connected is false expected: connect raise an exception and connected is false
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = "tcp://%s:%s" % (args["ip"], args["port"]) uri_value = "tcp://%s:%s" % (args["ip"], args["port"])
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
res = milvus.connect(host=args["ip"], port=39540, uri=uri_value, timeout=1) res = milvus.connect(host=args["ip"], port=39540, uri=uri_value, timeout=1)
...@@ -326,7 +326,7 @@ class TestConnectIPInvalid(object): ...@@ -326,7 +326,7 @@ class TestConnectIPInvalid(object):
@pytest.mark.level(2) @pytest.mark.level(2)
@pytest.mark.timeout(CONNECT_TIMEOUT) @pytest.mark.timeout(CONNECT_TIMEOUT)
def test_connect_with_invalid_ip(self, args, get_invalid_ip): def test_connect_with_invalid_ip(self, args, get_invalid_ip):
milvus = get_milvus() milvus = get_milvus(args["handler"])
ip = get_invalid_ip ip = get_invalid_ip
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
milvus.connect(host=ip, port=args["port"], timeout=1) milvus.connect(host=ip, port=args["port"], timeout=1)
...@@ -353,7 +353,7 @@ class TestConnectPortInvalid(object): ...@@ -353,7 +353,7 @@ class TestConnectPortInvalid(object):
method: set port in gen_invalid_ports method: set port in gen_invalid_ports
expected: connected is False expected: connected is False
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
port = get_invalid_port port = get_invalid_port
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
milvus.connect(host=args["ip"], port=port, timeout=1) milvus.connect(host=args["ip"], port=port, timeout=1)
...@@ -373,13 +373,13 @@ class TestConnectURIInvalid(object): ...@@ -373,13 +373,13 @@ class TestConnectURIInvalid(object):
@pytest.mark.level(2) @pytest.mark.level(2)
@pytest.mark.timeout(CONNECT_TIMEOUT) @pytest.mark.timeout(CONNECT_TIMEOUT)
def test_connect_with_invalid_uri(self, get_invalid_uri): def test_connect_with_invalid_uri(self, get_invalid_uri, args):
''' '''
target: test uri connect with invalid uri value target: test uri connect with invalid uri value
method: set port in gen_invalid_uris method: set port in gen_invalid_uris
expected: connected is False expected: connected is False
''' '''
milvus = get_milvus() milvus = get_milvus(args["handler"])
uri_value = get_invalid_uri uri_value = get_invalid_uri
with pytest.raises(Exception) as e: with pytest.raises(Exception) as e:
milvus.connect(uri=uri_value, timeout=1) milvus.connect(uri=uri_value, timeout=1)
......
...@@ -146,7 +146,7 @@ class TestIndexBase: ...@@ -146,7 +146,7 @@ class TestIndexBase:
uri = "tcp://%s:%s" % (args["ip"], args["port"]) uri = "tcp://%s:%s" % (args["ip"], args["port"])
for i in range(process_num): for i in range(process_num):
m = get_milvus() m = get_milvus(args["handler"])
m.connect(uri=uri) m.connect(uri=uri)
p = Process(target=build, args=(m,)) p = Process(target=build, args=(m,))
processes.append(p) processes.append(p)
...@@ -205,7 +205,7 @@ class TestIndexBase: ...@@ -205,7 +205,7 @@ class TestIndexBase:
uri = "tcp://%s:%s" % (args["ip"], args["port"]) uri = "tcp://%s:%s" % (args["ip"], args["port"])
for i in range(process_num): for i in range(process_num):
m = get_milvus() m = get_milvus(args["handler"])
m.connect(uri=uri) m.connect(uri=uri)
ids = i ids = i
p = Process(target=create_index, args=(m,ids)) p = Process(target=create_index, args=(m,ids))
...@@ -669,7 +669,7 @@ class TestIndexIP: ...@@ -669,7 +669,7 @@ class TestIndexIP:
uri = "tcp://%s:%s" % (args["ip"], args["port"]) uri = "tcp://%s:%s" % (args["ip"], args["port"])
for i in range(process_num): for i in range(process_num):
m = get_milvus() m = get_milvus(args["handler"])
m.connect(uri=uri) m.connect(uri=uri)
p = Process(target=build, args=(m,)) p = Process(target=build, args=(m,))
processes.append(p) processes.append(p)
...@@ -726,7 +726,7 @@ class TestIndexIP: ...@@ -726,7 +726,7 @@ class TestIndexIP:
uri = "tcp://%s:%s" % (args["ip"], args["port"]) uri = "tcp://%s:%s" % (args["ip"], args["port"])
for i in range(process_num): for i in range(process_num):
m = get_milvus() m = get_milvus(args["handler"])
m.connect(uri=uri) m.connect(uri=uri)
ids = i ids = i
p = Process(target=create_index, args=(m,ids)) p = Process(target=create_index, args=(m,ids))
......
...@@ -32,7 +32,7 @@ class TestMixBase: ...@@ -32,7 +32,7 @@ class TestMixBase:
query_vecs = [vectors[0], vectors[1]] query_vecs = [vectors[0], vectors[1]]
uri = "tcp://%s:%s" % (args["ip"], args["port"]) uri = "tcp://%s:%s" % (args["ip"], args["port"])
id_0 = 0; id_1 = 0 id_0 = 0; id_1 = 0
milvus_instance = get_milvus() milvus_instance = get_milvus(args["handler"])
milvus_instance.connect(uri=uri) milvus_instance.connect(uri=uri)
milvus_instance.create_table({'table_name': table, milvus_instance.create_table({'table_name': table,
'dimension': dim, 'dimension': dim,
...@@ -60,11 +60,11 @@ class TestMixBase: ...@@ -60,11 +60,11 @@ class TestMixBase:
logging.getLogger().info(status) logging.getLogger().info(status)
assert result[0][0].id == id_0 assert result[0][0].id == id_0
assert result[1][0].id == id_1 assert result[1][0].id == id_1
milvus_instance = get_milvus() milvus_instance = get_milvus(args["handler"])
milvus_instance.connect(uri=uri) milvus_instance.connect(uri=uri)
p_search = Process(target=search, args=(milvus_instance, )) p_search = Process(target=search, args=(milvus_instance, ))
p_search.start() p_search.start()
milvus_instance = get_milvus() milvus_instance = get_milvus(args["handler"])
milvus_instance.connect(uri=uri) milvus_instance.connect(uri=uri)
p_create = Process(target=add_vectors, args=(milvus_instance, )) p_create = Process(target=add_vectors, args=(milvus_instance, ))
p_create.start() p_create.start()
......
...@@ -743,7 +743,7 @@ class TestSearchBase: ...@@ -743,7 +743,7 @@ class TestSearchBase:
'index_type': IndexType.FLAT, 'index_type': IndexType.FLAT,
'store_raw_vector': False} 'store_raw_vector': False}
# create table # create table
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
milvus.create_table(param) milvus.create_table(param)
vectors, ids = self.init_data(milvus, table, nb=nb) vectors, ids = self.init_data(milvus, table, nb=nb)
...@@ -756,7 +756,7 @@ class TestSearchBase: ...@@ -756,7 +756,7 @@ class TestSearchBase:
assert result[i][0].distance == 0.0 assert result[i][0].distance == 0.0
for i in range(process_num): for i in range(process_num):
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
p = Process(target=search, args=(milvus, )) p = Process(target=search, args=(milvus, ))
processes.append(p) processes.append(p)
...@@ -784,7 +784,7 @@ class TestSearchBase: ...@@ -784,7 +784,7 @@ class TestSearchBase:
'index_file_size': 10, 'index_file_size': 10,
'metric_type': MetricType.L2} 'metric_type': MetricType.L2}
# create table # create table
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
milvus.create_table(param) milvus.create_table(param)
status, ids = milvus.add_vectors(table, vectors) status, ids = milvus.add_vectors(table, vectors)
...@@ -826,7 +826,7 @@ class TestSearchBase: ...@@ -826,7 +826,7 @@ class TestSearchBase:
'index_file_size': 10, 'index_file_size': 10,
'metric_type': MetricType.L2} 'metric_type': MetricType.L2}
# create table # create table
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
milvus.create_table(param) milvus.create_table(param)
status, ids = milvus.add_vectors(table, vectors) status, ids = milvus.add_vectors(table, vectors)
......
...@@ -278,7 +278,7 @@ class TestTable: ...@@ -278,7 +278,7 @@ class TestTable:
process_num = 4 process_num = 4
processes = [] processes = []
for i in range(process_num): for i in range(process_num):
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
p = Process(target=describetable, args=(milvus,)) p = Process(target=describetable, args=(milvus,))
processes.append(p) processes.append(p)
...@@ -458,7 +458,7 @@ class TestTable: ...@@ -458,7 +458,7 @@ class TestTable:
assert status.OK() assert status.OK()
for i in range(process_num): for i in range(process_num):
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
p = Process(target=deletetable, args=(milvus,)) p = Process(target=deletetable, args=(milvus,))
processes.append(p) processes.append(p)
...@@ -711,7 +711,7 @@ class TestTable: ...@@ -711,7 +711,7 @@ class TestTable:
processes = [] processes = []
for i in range(process_num): for i in range(process_num):
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
p = Process(target=showtables, args=(milvus,)) p = Process(target=showtables, args=(milvus,))
processes.append(p) processes.append(p)
......
...@@ -206,7 +206,7 @@ class TestTableCount: ...@@ -206,7 +206,7 @@ class TestTableCount:
process_num = 8 process_num = 8
processes = [] processes = []
for i in range(process_num): for i in range(process_num):
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
p = Process(target=rows_count, args=(milvus, )) p = Process(target=rows_count, args=(milvus, ))
processes.append(p) processes.append(p)
...@@ -351,7 +351,7 @@ class TestTableCountIP: ...@@ -351,7 +351,7 @@ class TestTableCountIP:
process_num = 8 process_num = 8
processes = [] processes = []
for i in range(process_num): for i in range(process_num):
milvus = get_milvus() milvus = get_milvus(args["handler"])
milvus.connect(uri=uri) milvus.connect(uri=uri)
p = Process(target=rows_count, args=(milvus,)) p = Process(target=rows_count, args=(milvus,))
processes.append(p) processes.append(p)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册