Commit b33fc331 authored by barrierye

replacing server_pid with variant_tag

Parent e451d03e
@@ -53,15 +53,17 @@ class PredictorRes {
const std::string& name) {
return _float_map[name];
}
- void set_server_pid(int pid) { _server_pid = pid; }
- int server_pid() { return _server_pid; }
+ void set_variant_tag(const std::string& variant_tag) {
+ _variant_tag = variant_tag;
+ }
+ const std::string& variant_tag() { return _variant_tag; }
public:
std::map<std::string, std::vector<std::vector<int64_t>>> _int64_map;
std::map<std::string, std::vector<std::vector<float>>> _float_map;
private:
- int _server_pid;
+ std::string _variant_tag;
};
class PredictorClient {
......
@@ -144,7 +144,9 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
Timer timeline;
int64_t preprocess_start = timeline.TimeStampUS();
_api.thrd_clear();
- _predictor = _api.fetch_predictor("general_model");
+ std::string variant_tag;
+ _predictor = _api.fetch_predictor("general_model", &variant_tag);
+ predict_res.set_variant_tag(variant_tag);
Request req;
for (auto &name : fetch_name) {
@@ -237,7 +239,6 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
}
postprocess_end = timeline.TimeStampUS();
}
- predict_res.set_server_pid(res.server_pid());
}
if (FLAGS_profile_client) {
@@ -283,7 +284,9 @@ int PredictorClient::batch_predict(
int fetch_name_num = fetch_name.size();
_api.thrd_clear();
- _predictor = _api.fetch_predictor("general_model");
+ std::string variant_tag;
+ _predictor = _api.fetch_predictor("general_model", &variant_tag);
+ predict_res_batch.set_variant_tag(variant_tag);
VLOG(2) << "fetch general model predictor done.";
VLOG(2) << "float feed name size: " << float_feed_name.size();
VLOG(2) << "int feed name size: " << int_feed_name.size();
@@ -401,7 +404,6 @@ int PredictorClient::batch_predict(
}
}
postprocess_end = timeline.TimeStampUS();
- predict_res_batch.set_server_pid(res.server_pid());
}
if (FLAGS_profile_client) {
......
@@ -40,7 +40,8 @@ PYBIND11_MODULE(serving_client, m) {
return self.get_float_by_name(name);
},
py::return_value_policy::reference)
- .def("server_pid", [](PredictorRes &self) { return self.server_pid(); });
+ .def("variant_tag",
+ [](PredictorRes &self) { return self.variant_tag(); });
py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
.def(py::init())
......
@@ -13,7 +13,6 @@
// limitations under the License.
#include "core/general-server/op/general_response_op.h"
- #include <unistd.h>
#include <algorithm>
#include <iostream>
#include <memory>
@@ -75,11 +74,6 @@ int GeneralResponseOp::inference() {
// response inst with only fetch_var_names
Response *res = mutable_data<Response>();
- // to let the client know which server the current response comes from
- int server_pid = static_cast<int>(getpid());
- VLOG(2) << "getpid: " << server_pid;
- res->set_server_pid(server_pid);
for (int i = 0; i < batch_size; ++i) {
FetchInst *fetch_inst = res->add_insts();
for (auto &idx : fetch_index) {
......
@@ -41,7 +41,6 @@ message Request {
message Response {
repeated FetchInst insts = 1;
repeated int64 profile_time = 2;
- optional int64 server_pid = 3;
};
service GeneralModelService {
......
@@ -43,9 +43,9 @@ class Endpoint {
int thrd_finalize();
- Predictor* get_predictor(const void* params);
+ Predictor* get_predictor(const void* params, std::string* variant_tag);
- Predictor* get_predictor();
+ Predictor* get_predictor(std::string* variant_tag);
int ret_predictor(Predictor* predictor);
......
@@ -48,24 +48,26 @@ class PredictorApi {
return api;
}
- Predictor* fetch_predictor(std::string ep_name) {
+ Predictor* fetch_predictor(std::string ep_name, std::string* variant_tag) {
std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
if (it == _endpoints.end() || !it->second) {
LOG(ERROR) << "Failed fetch predictor:"
<< ", ep_name: " << ep_name;
return NULL;
}
- return it->second->get_predictor();
+ return it->second->get_predictor(variant_tag);
}
- Predictor* fetch_predictor(std::string ep_name, const void* params) {
+ Predictor* fetch_predictor(std::string ep_name,
+ const void* params,
+ std::string* variant_tag) {
std::map<std::string, Endpoint*>::iterator it = _endpoints.find(ep_name);
if (it == _endpoints.end() || !it->second) {
LOG(ERROR) << "Failed fetch predictor:"
<< ", ep_name: " << ep_name;
return NULL;
}
- return it->second->get_predictor(params);
+ return it->second->get_predictor(params, variant_tag);
}
int free_predictor(Predictor* predictor) {
......
@@ -41,7 +41,6 @@ message Request {
message Response {
repeated FetchInst insts = 1;
repeated int64 profile_time = 2;
- optional int64 server_pid = 3;
};
service GeneralModelService {
......
@@ -79,13 +79,15 @@ int Endpoint::thrd_finalize() {
return 0;
}
- Predictor* Endpoint::get_predictor() {
+ Predictor* Endpoint::get_predictor(std::string* variant_tag) {
if (_variant_list.size() == 1) {
if (_variant_list[0] == NULL) {
LOG(ERROR) << "Not valid variant info";
return NULL;
}
- return _variant_list[0]->get_predictor();
+ Variant* var = _variant_list[0];
+ *variant_tag = var->variant_tag();
+ return var->get_predictor();
}
if (_abtest_router == NULL) {
@@ -99,6 +101,7 @@ Predictor* Endpoint::get_predictor() {
return NULL;
}
+ *variant_tag = var->variant_tag();
return var->get_predictor();
}
......
@@ -164,7 +164,7 @@ class Client(object):
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
- def predict(self, feed={}, fetch=[], need_server_pid=False):
+ def predict(self, feed={}, fetch=[], need_variant_tag=False):
int_slot = []
float_slot = []
int_feed_names = []
@@ -199,10 +199,12 @@ class Client(object):
result_map[name] = self.result_handle_.get_float_by_name(name)[
0]
- return [result_map, self.result_handle_.server_pid()
- ] if need_server_pid else result_map
+ return [
+ result_map,
+ self.result_handle_.variant_tag(),
+ ] if need_variant_tag else result_map
- def batch_predict(self, feed_batch=[], fetch=[], need_server_pid=False):
+ def batch_predict(self, feed_batch=[], fetch=[], need_variant_tag=False):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
@@ -249,8 +251,10 @@ class Client(object):
index]
result_map_batch.append(result_map)
- return [result_map, self.result_handle_.server_pid()
- ] if need_server_pid else result_map
+ return [
+ result_map,
+ self.result_handle_.variant_tag(),
+ ] if need_variant_tag else result_map
def release(self):
self.client_handle_.destroy_predictor()
......
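The user-visible effect on the Python client: predict and batch_predict now take need_variant_tag instead of need_server_pid, and when it is set the call returns the tag of the variant that served the request alongside the fetch results. A minimal usage sketch under assumed names (the config path, endpoint address, feed key "x" and fetch name "price" are placeholders, not part of this commit):

from paddle_serving_client import Client

client = Client()
# Placeholder client config path; use the prototxt generated for your model.
client.load_client_config("serving_client_conf/serving_client_conf.prototxt")
# Register a single variant before connect(), mirroring the WebService changes below.
client.add_variant("var1", ["127.0.0.1:9292"], 100)  # placeholder endpoint
client.connect()

# With need_variant_tag=True the result is [fetch_map, variant_tag] instead of just fetch_map.
fetch_map, tag = client.predict(
    feed={"x": [0.1] * 13},   # placeholder feed var and values
    fetch=["price"],          # placeholder fetch var
    need_variant_tag=True)
print("served by variant:", tag)  # e.g. "var1"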
@@ -54,7 +54,9 @@ class WebService(object):
client_service = Client()
client_service.load_client_config(
"{}/serving_server_conf.prototxt".format(self.model_config))
- client_service.connect(["0.0.0.0:{}".format(self.port + 1)])
+ client_service.add_variant("var1",
+ ["0.0.0.0:{}".format(self.port + 1)], 100)
+ client_service.connect()
service_name = "/" + self.name + "/prediction"
@app_instance.route(service_name, methods=['POST'])
......
@@ -91,7 +91,8 @@ class WebService(object):
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
- client.connect([endpoint])
+ client.add_variant("var1", [endpoint], 100)
+ client.connect()
while True:
request_json = inputqueue.get()
feed, fetch = self.preprocess(request_json, request_json["fetch"])
@@ -126,7 +127,8 @@ class WebService(object):
client = Client()
client.load_client_config("{}/serving_server_conf.prototxt".format(
self.model_config))
- client.connect(["0.0.0.0:{}".format(self.port + 1)])
+ client.add_variant("var1", ["0.0.0.0:{}".format(self.port + 1)], 100)
+ client.connect()
self.idx = 0
......
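Because Endpoint::get_predictor() now reports which variant the A/B router picked, the returned tag can be used to observe how traffic is shared across several variants. A hedged sketch, assuming add_variant(tag, endpoint_list, weight) splits requests by weight as suggested by the ABTest routing code above; the addresses, weights, feed key and fetch name are illustrative only:

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf/serving_client_conf.prototxt")  # placeholder path
client.add_variant("var1", ["127.0.0.1:9292"], 50)  # assumed 50/50 traffic split
client.add_variant("var2", ["127.0.0.1:9393"], 50)
client.connect()

counts = {}
for _ in range(100):
    fetch_map, tag = client.predict(
        feed={"x": [0.1] * 13},  # placeholder feed
        fetch=["price"],         # placeholder fetch name
        need_variant_tag=True)
    counts[tag] = counts.get(tag, 0) + 1
# Expect requests to be divided between "var1" and "var2" roughly by weight.
print(counts)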