Commit b33fc331 authored by barrierye

replacing server_pid with variant_tag

Parent e451d03e
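This commit makes the client report which serving variant answered a request (its variant tag) instead of the server's process id. A minimal sketch of the resulting client-side flow, assuming the paddle_serving_client package; the endpoint address, feed name "x", and fetch name "price" are illustrative, not taken from this commit:

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")
# Endpoints are now registered as a named, weighted variant before connect();
# the tag "var1" and weight 100 mirror the WebService defaults below.
client.add_variant("var1", ["127.0.0.1:9393"], 100)  # hypothetical address
client.connect()

# With need_variant_tag=True, predict returns [result_map, variant_tag]
# rather than the bare result_map (previously [result_map, server_pid]).
fetch_map, tag = client.predict(
    feed={"x": [0.0] * 13}, fetch=["price"], need_variant_tag=True)
print("served by variant:", tag)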
@@ -53,15 +53,17 @@ class PredictorRes {
       const std::string& name) {
     return _float_map[name];
   }
-  void set_server_pid(int pid) { _server_pid = pid; }
-  int server_pid() { return _server_pid; }
+  void set_variant_tag(const std::string& variant_tag) {
+    _variant_tag = variant_tag;
+  }
+  const std::string& variant_tag() { return _variant_tag; }

  public:
   std::map<std::string, std::vector<std::vector<int64_t>>> _int64_map;
   std::map<std::string, std::vector<std::vector<float>>> _float_map;

  private:
-  int _server_pid;
+  std::string _variant_tag;
 };

 class PredictorClient {
...
@@ -144,7 +144,9 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
   Timer timeline;
   int64_t preprocess_start = timeline.TimeStampUS();
   _api.thrd_clear();
-  _predictor = _api.fetch_predictor("general_model");
+  std::string variant_tag;
+  _predictor = _api.fetch_predictor("general_model", &variant_tag);
+  predict_res.set_variant_tag(variant_tag);

   Request req;
   for (auto &name : fetch_name) {
@@ -237,7 +239,6 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
       }
       postprocess_end = timeline.TimeStampUS();
     }
-    predict_res.set_server_pid(res.server_pid());
   }

   if (FLAGS_profile_client) {
@@ -283,7 +284,9 @@ int PredictorClient::batch_predict(
   int fetch_name_num = fetch_name.size();

   _api.thrd_clear();
-  _predictor = _api.fetch_predictor("general_model");
+  std::string variant_tag;
+  _predictor = _api.fetch_predictor("general_model", &variant_tag);
+  predict_res_batch.set_variant_tag(variant_tag);
   VLOG(2) << "fetch general model predictor done.";
   VLOG(2) << "float feed name size: " << float_feed_name.size();
   VLOG(2) << "int feed name size: " << int_feed_name.size();
@@ -401,7 +404,6 @@ int PredictorClient::batch_predict(
       }
     }
     postprocess_end = timeline.TimeStampUS();
-    predict_res_batch.set_server_pid(res.server_pid());
   }

   if (FLAGS_profile_client) {
...
@@ -40,7 +40,8 @@ PYBIND11_MODULE(serving_client, m) {
             return self.get_float_by_name(name);
           },
           py::return_value_policy::reference)
-      .def("server_pid", [](PredictorRes &self) { return self.server_pid(); });
+      .def("variant_tag",
+           [](PredictorRes &self) { return self.variant_tag(); });

   py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
       .def(py::init())
...
@@ -13,7 +13,6 @@
 // limitations under the License.

 #include "core/general-server/op/general_response_op.h"
-#include <unistd.h>
 #include <algorithm>
 #include <iostream>
 #include <memory>
@@ -75,11 +74,6 @@ int GeneralResponseOp::inference() {
   // response inst with only fetch_var_names
   Response *res = mutable_data<Response>();

-  // to let the client know which server the current response comes from
-  int server_pid = static_cast<int>(getpid());
-  VLOG(2) << "getpid: " << server_pid;
-  res->set_server_pid(server_pid);
-
   for (int i = 0; i < batch_size; ++i) {
     FetchInst *fetch_inst = res->add_insts();
     for (auto &idx : fetch_index) {
...
@@ -41,7 +41,6 @@ message Request {
 message Response {
   repeated FetchInst insts = 1;
   repeated int64 profile_time = 2;
-  optional int64 server_pid = 3;
 };

 service GeneralModelService {
...
@@ -43,9 +43,9 @@ class Endpoint {
   int thrd_finalize();

-  Predictor* get_predictor(const void* params);
+  Predictor* get_predictor(const void* params, std::string* variant_tag);

-  Predictor* get_predictor();
+  Predictor* get_predictor(std::string* variant_tag);

   int ret_predictor(Predictor* predictor);
...
@@ -48,24 +48,26 @@ class PredictorApi {
     return api;
   }

-  Predictor* fetch_predictor(std::string ep_name) {
+  Predictor* fetch_predictor(std::string ep_name, std::string* variant_tag) {
     std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
     if (it == _endpoints.end() || !it->second) {
       LOG(ERROR) << "Failed fetch predictor:"
                  << ", ep_name: " << ep_name;
       return NULL;
     }
-    return it->second->get_predictor();
+    return it->second->get_predictor(variant_tag);
   }

-  Predictor* fetch_predictor(std::string ep_name, const void* params) {
+  Predictor* fetch_predictor(std::string ep_name,
+                             const void* params,
+                             std::string* variant_tag) {
     std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
     if (it == _endpoints.end() || !it->second) {
       LOG(ERROR) << "Failed fetch predictor:"
                  << ", ep_name: " << ep_name;
       return NULL;
     }
-    return it->second->get_predictor(params);
+    return it->second->get_predictor(params, variant_tag);
   }

   int free_predictor(Predictor* predictor) {
...
@@ -41,7 +41,6 @@ message Request {
 message Response {
   repeated FetchInst insts = 1;
   repeated int64 profile_time = 2;
-  optional int64 server_pid = 3;
 };

 service GeneralModelService {
...
@@ -79,13 +79,15 @@ int Endpoint::thrd_finalize() {
   return 0;
 }

-Predictor* Endpoint::get_predictor() {
+Predictor* Endpoint::get_predictor(std::string* variant_tag) {
   if (_variant_list.size() == 1) {
     if (_variant_list[0] == NULL) {
       LOG(ERROR) << "Not valid variant info";
       return NULL;
     }
-    return _variant_list[0]->get_predictor();
+    Variant* var = _variant_list[0];
+    *variant_tag = var->variant_tag();
+    return var->get_predictor();
   }

   if (_abtest_router == NULL) {
@@ -99,6 +101,7 @@ Predictor* Endpoint::get_predictor() {
     return NULL;
   }

+  *variant_tag = var->variant_tag();
   return var->get_predictor();
 }
...
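The tag is most useful when an endpoint holds several variants and the A/B-test router (_abtest_router above) picks one per request; the hunks above now surface which variant was chosen in both the single-variant and routed paths. A hypothetical two-variant setup, assuming add_variant(tag, endpoint_list, weight) as used elsewhere in this commit (all names, addresses, and weights are illustrative):

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")
# Two variants behind one endpoint; the weights (90/10) steer traffic.
client.add_variant("baseline", ["127.0.0.1:9393"], 90)
client.add_variant("candidate", ["127.0.0.1:9494"], 10)
client.connect()

fetch_map, tag = client.predict(
    feed={"x": [0.0] * 13}, fetch=["price"], need_variant_tag=True)
# tag is "baseline" or "candidate": whichever variant the router picked.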
@@ -164,7 +164,7 @@ class Client(object):
                 raise SystemExit("The shape of feed tensor {} not match.".format(
                     key))

-    def predict(self, feed={}, fetch=[], need_server_pid=False):
+    def predict(self, feed={}, fetch=[], need_variant_tag=False):
         int_slot = []
         float_slot = []
         int_feed_names = []
@@ -199,10 +199,12 @@ class Client(object):
                 result_map[name] = self.result_handle_.get_float_by_name(name)[
                     0]

-        return [result_map, self.result_handle_.server_pid()
-                ] if need_server_pid else result_map
+        return [
+            result_map,
+            self.result_handle_.variant_tag(),
+        ] if need_variant_tag else result_map

-    def batch_predict(self, feed_batch=[], fetch=[], need_server_pid=False):
+    def batch_predict(self, feed_batch=[], fetch=[], need_variant_tag=False):
         int_slot_batch = []
         float_slot_batch = []
         int_feed_names = []
@@ -249,8 +251,10 @@ class Client(object):
                     index]
                 result_map_batch.append(result_map)

-        return [result_map, self.result_handle_.server_pid()
-                ] if need_server_pid else result_map
+        return [
+            result_map,
+            self.result_handle_.variant_tag(),
+        ] if need_variant_tag else result_map

     def release(self):
         self.client_handle_.destroy_predictor()
...
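On the Python side the tag stays opt-in: by default predict and batch_predict return only the fetch results, and only with need_variant_tag=True do they return a two-element list. A short sketch of the batch path (feed and fetch names are hypothetical):

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")
client.add_variant("var1", ["127.0.0.1:9393"], 100)  # hypothetical address
client.connect()

feed_batch = [{"x": [0.0] * 13}, {"x": [1.0] * 13}]

# Default behaviour is unchanged: only the fetch results come back.
fetch_map = client.batch_predict(feed_batch=feed_batch, fetch=["price"])

# Opting in yields [results, variant_tag].
results, tag = client.batch_predict(
    feed_batch=feed_batch, fetch=["price"], need_variant_tag=True)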
@@ -54,7 +54,9 @@ class WebService(object):
         client_service = Client()
         client_service.load_client_config(
             "{}/serving_server_conf.prototxt".format(self.model_config))
-        client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
+        client_service.add_variant("var1",
+                                   ["0.0.0.0:{}".format(self.port + 1)], 100)
+        client_service.connect()
         service_name = "/" + self.name + "/prediction"

         @app_instance.route(service_name, methods=['POST'])
@@ -91,7 +91,8 @@ class WebService(object):
         client = Client()
         client.load_client_config("{}/serving_server_conf.prototxt".format(
             self.model_config))
-        client.connect([endpoint])
+        client.add_variant("var1", [endpoint], 100)
+        client.connect()
         while True:
             request_json = inputqueue.get()
             feed, fetch = self.preprocess(request_json, request_json["fetch"])
@@ -126,7 +127,8 @@ class WebService(object):
         client = Client()
         client.load_client_config("{}/serving_server_conf.prototxt".format(
             self.model_config))
-        client.connect(["0.0.0.0:{}".format(self.port + 1)])
+        client.add_variant("var1", ["0.0.0.0:{}".format(self.port + 1)], 100)
+        client.connect()
         self.idx = 0
...
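For existing callers the WebService hunks above double as a migration guide: endpoints no longer go straight to connect() but are registered as a named, weighted variant first. A minimal before/after sketch (the address is illustrative):

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")

# Before this commit, endpoints were passed directly:
#   client.connect(["127.0.0.1:9393"])
# After it, register them as a variant, then connect with no arguments:
client.add_variant("var1", ["127.0.0.1:9393"], 100)
client.connect()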