From b33fc331bd21a21b95c847904f4048857f3a4ed4 Mon Sep 17 00:00:00 2001 From: barrierye Date: Mon, 23 Mar 2020 12:52:35 +0800 Subject: [PATCH] replacing server_pid with variant_tag --- core/general-client/include/general_model.h | 8 +++++--- core/general-client/src/general_model.cpp | 10 ++++++---- core/general-client/src/pybind_general_model.cpp | 3 ++- core/general-server/op/general_response_op.cpp | 6 ------ .../proto/general_model_service.proto | 1 - core/sdk-cpp/include/endpoint.h | 4 ++-- core/sdk-cpp/include/predictor_sdk.h | 10 ++++++---- core/sdk-cpp/proto/general_model_service.proto | 1 - core/sdk-cpp/src/endpoint.cpp | 7 +++++-- python/paddle_serving_client/__init__.py | 16 ++++++++++------ python/paddle_serving_server/web_service.py | 4 +++- python/paddle_serving_server_gpu/web_service.py | 6 ++++-- 12 files changed, 43 insertions(+), 33 deletions(-) diff --git a/core/general-client/include/general_model.h b/core/general-client/include/general_model.h index 28ddca70..d2c63efd 100644 --- a/core/general-client/include/general_model.h +++ b/core/general-client/include/general_model.h @@ -53,15 +53,17 @@ class PredictorRes { const std::string& name) { return _float_map[name]; } - void set_server_pid(int pid) { _server_pid = pid; } - int server_pid() { return _server_pid; } + void set_variant_tag(const std::string& variant_tag) { + _variant_tag = variant_tag; + } + const std::string& variant_tag() { return _variant_tag; } public: std::map>> _int64_map; std::map>> _float_map; private: - int _server_pid; + std::string _variant_tag; }; class PredictorClient { diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp index e8267e1b..04bc4169 100644 --- a/core/general-client/src/general_model.cpp +++ b/core/general-client/src/general_model.cpp @@ -144,7 +144,9 @@ int PredictorClient::predict(const std::vector> &float_feed, Timer timeline; int64_t preprocess_start = timeline.TimeStampUS(); _api.thrd_clear(); - _predictor = _api.fetch_predictor("general_model"); + std::string variant_tag; + _predictor = _api.fetch_predictor("general_model", &variant_tag); + predict_res.set_variant_tag(variant_tag); Request req; for (auto &name : fetch_name) { @@ -237,7 +239,6 @@ int PredictorClient::predict(const std::vector> &float_feed, } postprocess_end = timeline.TimeStampUS(); } - predict_res.set_server_pid(res.server_pid()); } if (FLAGS_profile_client) { @@ -283,7 +284,9 @@ int PredictorClient::batch_predict( int fetch_name_num = fetch_name.size(); _api.thrd_clear(); - _predictor = _api.fetch_predictor("general_model"); + std::string variant_tag; + _predictor = _api.fetch_predictor("general_model", &variant_tag); + predict_res_batch.set_variant_tag(variant_tag); VLOG(2) << "fetch general model predictor done."; VLOG(2) << "float feed name size: " << float_feed_name.size(); VLOG(2) << "int feed name size: " << int_feed_name.size(); @@ -401,7 +404,6 @@ int PredictorClient::batch_predict( } } postprocess_end = timeline.TimeStampUS(); - predict_res_batch.set_server_pid(res.server_pid()); } if (FLAGS_profile_client) { diff --git a/core/general-client/src/pybind_general_model.cpp b/core/general-client/src/pybind_general_model.cpp index 3431bfc8..47bc6bd3 100644 --- a/core/general-client/src/pybind_general_model.cpp +++ b/core/general-client/src/pybind_general_model.cpp @@ -40,7 +40,8 @@ PYBIND11_MODULE(serving_client, m) { return self.get_float_by_name(name); }, py::return_value_policy::reference) - .def("server_pid", [](PredictorRes &self) { return self.server_pid(); }); + .def("variant_tag", + [](PredictorRes &self) { return self.variant_tag(); }); py::class_(m, "PredictorClient", py::buffer_protocol()) .def(py::init()) diff --git a/core/general-server/op/general_response_op.cpp b/core/general-server/op/general_response_op.cpp index 271278c9..c5248227 100644 --- a/core/general-server/op/general_response_op.cpp +++ b/core/general-server/op/general_response_op.cpp @@ -13,7 +13,6 @@ // limitations under the License. #include "core/general-server/op/general_response_op.h" -#include #include #include #include @@ -75,11 +74,6 @@ int GeneralResponseOp::inference() { // response inst with only fetch_var_names Response *res = mutable_data(); - // to let the client know which server the current response comes from - int server_pid = static_cast(getpid()); - VLOG(2) << "getpid: " << server_pid; - res->set_server_pid(server_pid); - for (int i = 0; i < batch_size; ++i) { FetchInst *fetch_inst = res->add_insts(); for (auto &idx : fetch_index) { diff --git a/core/general-server/proto/general_model_service.proto b/core/general-server/proto/general_model_service.proto index 8ede69c0..09e2854d 100644 --- a/core/general-server/proto/general_model_service.proto +++ b/core/general-server/proto/general_model_service.proto @@ -41,7 +41,6 @@ message Request { message Response { repeated FetchInst insts = 1; repeated int64 profile_time = 2; - optional int64 server_pid = 3; }; service GeneralModelService { diff --git a/core/sdk-cpp/include/endpoint.h b/core/sdk-cpp/include/endpoint.h index 52926ecf..37fb582b 100644 --- a/core/sdk-cpp/include/endpoint.h +++ b/core/sdk-cpp/include/endpoint.h @@ -43,9 +43,9 @@ class Endpoint { int thrd_finalize(); - Predictor* get_predictor(const void* params); + Predictor* get_predictor(const void* params, std::string* variant_tag); - Predictor* get_predictor(); + Predictor* get_predictor(std::string* variant_tag); int ret_predictor(Predictor* predictor); diff --git a/core/sdk-cpp/include/predictor_sdk.h b/core/sdk-cpp/include/predictor_sdk.h index 65d80672..0cf5a84e 100644 --- a/core/sdk-cpp/include/predictor_sdk.h +++ b/core/sdk-cpp/include/predictor_sdk.h @@ -48,24 +48,26 @@ class PredictorApi { return api; } - Predictor* fetch_predictor(std::string ep_name) { + Predictor* fetch_predictor(std::string ep_name, std::string* variant_tag) { std::map::iterator it = _endpoints.find(ep_name); if (it == _endpoints.end() || !it->second) { LOG(ERROR) << "Failed fetch predictor:" << ", ep_name: " << ep_name; return NULL; } - return it->second->get_predictor(); + return it->second->get_predictor(variant_tag); } - Predictor* fetch_predictor(std::string ep_name, const void* params) { + Predictor* fetch_predictor(std::string ep_name, + const void* params, + std::string* variant_tag) { std::map::iterator it = _endpoints.find(ep_name); if (it == _endpoints.end() || !it->second) { LOG(ERROR) << "Failed fetch predictor:" << ", ep_name: " << ep_name; return NULL; } - return it->second->get_predictor(params); + return it->second->get_predictor(params, variant_tag); } int free_predictor(Predictor* predictor) { diff --git a/core/sdk-cpp/proto/general_model_service.proto b/core/sdk-cpp/proto/general_model_service.proto index fd54ca28..827bb880 100644 --- a/core/sdk-cpp/proto/general_model_service.proto +++ b/core/sdk-cpp/proto/general_model_service.proto @@ -41,7 +41,6 @@ message Request { message Response { repeated FetchInst insts = 1; repeated int64 profile_time = 2; - optional int64 server_pid = 3; }; service GeneralModelService { diff --git a/core/sdk-cpp/src/endpoint.cpp b/core/sdk-cpp/src/endpoint.cpp index 517fe6dd..2d4ed65d 100644 --- a/core/sdk-cpp/src/endpoint.cpp +++ b/core/sdk-cpp/src/endpoint.cpp @@ -79,13 +79,15 @@ int Endpoint::thrd_finalize() { return 0; } -Predictor* Endpoint::get_predictor() { +Predictor* Endpoint::get_predictor(std::string* variant_tag) { if (_variant_list.size() == 1) { if (_variant_list[0] == NULL) { LOG(ERROR) << "Not valid variant info"; return NULL; } - return _variant_list[0]->get_predictor(); + Variant* var = _variant_list[0]; + *variant_tag = var->variant_tag(); + return var->get_predictor(); } if (_abtest_router == NULL) { @@ -99,6 +101,7 @@ Predictor* Endpoint::get_predictor() { return NULL; } + *variant_tag = var->variant_tag(); return var->get_predictor(); } diff --git a/python/paddle_serving_client/__init__.py b/python/paddle_serving_client/__init__.py index 66f00a36..8f33817a 100644 --- a/python/paddle_serving_client/__init__.py +++ b/python/paddle_serving_client/__init__.py @@ -164,7 +164,7 @@ class Client(object): raise SystemExit("The shape of feed tensor {} not match.".format( key)) - def predict(self, feed={}, fetch=[], need_server_pid=False): + def predict(self, feed={}, fetch=[], need_variant_tag=False): int_slot = [] float_slot = [] int_feed_names = [] @@ -199,10 +199,12 @@ class Client(object): result_map[name] = self.result_handle_.get_float_by_name(name)[ 0] - return [result_map, self.result_handle_.server_pid() - ] if need_server_pid else result_map + return [ + result_map, + self.result_handle_.variant_tag(), + ] if need_variant_tag else result_map - def batch_predict(self, feed_batch=[], fetch=[], need_server_pid=False): + def batch_predict(self, feed_batch=[], fetch=[], need_variant_tag=False): int_slot_batch = [] float_slot_batch = [] int_feed_names = [] @@ -249,8 +251,10 @@ class Client(object): index] result_map_batch.append(result_map) - return [result_map, self.result_handle_.server_pid() - ] if need_server_pid else result_map + return [ + result_map, + self.result_handle_.variant_tag(), + ] if need_variant_tag else result_map def release(self): self.client_handle_.destroy_predictor() diff --git a/python/paddle_serving_server/web_service.py b/python/paddle_serving_server/web_service.py index 71614129..1d6fa5f6 100755 --- a/python/paddle_serving_server/web_service.py +++ b/python/paddle_serving_server/web_service.py @@ -54,7 +54,9 @@ class WebService(object): client_service = Client() client_service.load_client_config( "{}/serving_server_conf.prototxt".format(self.model_config)) - client_service.connect(["0.0.0.0:{}".format(self.port + 1)]) + client_service.add_variant("var1", + ["0.0.0.0:{}".format(self.port + 1)], 100) + client_service.connect() service_name = "/" + self.name + "/prediction" @app_instance.route(service_name, methods=['POST']) diff --git a/python/paddle_serving_server_gpu/web_service.py b/python/paddle_serving_server_gpu/web_service.py index 4d88994c..41c1a2b9 100755 --- a/python/paddle_serving_server_gpu/web_service.py +++ b/python/paddle_serving_server_gpu/web_service.py @@ -91,7 +91,8 @@ class WebService(object): client = Client() client.load_client_config("{}/serving_server_conf.prototxt".format( self.model_config)) - client.connect([endpoint]) + client.add_variant("var1", [endpoint], 100) + client.connect() while True: request_json = inputqueue.get() feed, fetch = self.preprocess(request_json, request_json["fetch"]) @@ -126,7 +127,8 @@ class WebService(object): client = Client() client.load_client_config("{}/serving_server_conf.prototxt".format( self.model_config)) - client.connect(["0.0.0.0:{}".format(self.port + 1)]) + client.add_variant("var1", ["0.0.0.0:{}".format(self.port + 1)], 100) + client.connect() self.idx = 0 -- GitLab