Commit c0176fd3 authored by barrierye

make the client know the process ID of the server

Parent 4d4b9e89
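In short: the server now stamps each Response with its operating-system process ID, and the client can opt in to receive it. A minimal sketch of the resulting client-side flow, assuming a built client config (the config path and input vector below are placeholders, not part of this commit):

from paddle_serving_client import Client

client = Client()
client.load_client_config("serving_client_conf.prototxt")  # placeholder path
client.add_variant("var1", ["127.0.0.1:9393"], 50)
client.connect()

x_data = [0.0] * 13  # placeholder feature vector
fetch_map, server_pid = client.predict(
    feed={"x": x_data}, fetch=["price"], need_server_pid=True)
print(server_pid, fetch_map["price"])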
@@ -53,10 +53,15 @@ class PredictorRes {
                          const std::string& name) {
     return _float_map[name];
   }
+  void set_server_pid(int pid) { _server_pid = pid; }
+  int server_pid() { return _server_pid; }
 
  public:
   std::map<std::string, std::vector<std::vector<int64_t>>> _int64_map;
   std::map<std::string, std::vector<std::vector<float>>> _float_map;
+
+ private:
+  int _server_pid;
 };
 
 class PredictorClient {
......
@@ -237,6 +237,7 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
       }
       postprocess_end = timeline.TimeStampUS();
     }
+    predict_res.set_server_pid(res.server_pid());
   }
 
   if (FLAGS_profile_client) {
@@ -400,6 +401,7 @@ int PredictorClient::batch_predict(
       }
     }
     postprocess_end = timeline.TimeStampUS();
+    predict_res_batch.set_server_pid(res.server_pid());
   }
 
   if (FLAGS_profile_client) {
......
@@ -39,7 +39,8 @@ PYBIND11_MODULE(serving_client, m) {
           [](PredictorRes &self, std::string &name) {
             return self.get_float_by_name(name);
           },
-          py::return_value_policy::reference);
+          py::return_value_policy::reference)
+      .def("server_pid", [](PredictorRes &self) { return self.server_pid(); });
 
   py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
       .def(py::init())
......
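With this binding, the PID set by set_server_pid on the C++ side becomes reachable from Python on the client's result handle; the Python wrapper later in this commit exposes it behind the need_server_pid flag. A sketch of the raw access, assuming a connected client that has already issued a request:

fetch_map = client.predict(feed={"x": x_data}, fetch=["price"])
pid = client.result_handle_.server_pid()  # PID stamped by the responding server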
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "core/general-server/op/general_response_op.h"
+#include <unistd.h>
 #include <algorithm>
 #include <iostream>
 #include <memory>
@@ -74,6 +75,10 @@ int GeneralResponseOp::inference() {
   // response inst with only fetch_var_names
   Response *res = mutable_data<Response>();
 
+  // to let the client know which server the current response comes from
+  VLOG(2) << "getpid: " << getpid();
+  res->set_server_pid(static_cast<int>(getpid()));
+
   for (int i = 0; i < batch_size; ++i) {
     FetchInst *fetch_inst = res->add_insts();
     for (auto &idx : fetch_index) {
......
@@ -41,6 +41,7 @@ message Request {
 message Response {
   repeated FetchInst insts = 1;
   repeated int64 profile_time = 2;
+  optional int64 server_pid = 3;
 };
 
 service GeneralModelService {
......
@@ -41,6 +41,7 @@ message Request {
 message Response {
   repeated FetchInst insts = 1;
   repeated int64 profile_time = 2;
+  optional int64 server_pid = 3;
 };
 
 service GeneralModelService {
......
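The same optional field is added to both the client-side and the server-side copy of the proto, keeping the two Response definitions in sync. A hypothetical way to read the field off a raw serialized Response in Python (the generated module name is assumed from the .proto file name):

from general_model_service_pb2 import Response  # assumed generated module

resp = Response()
resp.ParseFromString(raw_bytes)  # raw_bytes: a serialized Response from the wire
print(resp.server_pid)           # proto2 optional int64; reads 0 if never set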
@@ -18,7 +18,8 @@ import sys
 
 client = Client()
 client.load_client_config(sys.argv[1])
-client.connect(["127.0.0.1:9393"])
+client.add_variant("var1", ["127.0.0.1:9393"], 50)
+client.connect()
 
 import paddle
 test_reader = paddle.batch(
@@ -27,5 +28,6 @@ test_reader = paddle.batch(
     batch_size=1)
 
 for data in test_reader():
-    fetch_map = client.predict(feed={"x": data[0][0]}, fetch=["price"])
-    print("{} {}".format(fetch_map["price"][0], data[0][1][0]))
+    [fetch_map, server_pid] = client.predict(
+        feed={"x": data[0][0]}, fetch=["price"], need_server_pid=True)
+    print("[{}] {} {}".format(server_pid, fetch_map["price"][0], data[0][1][0]))
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=doc-string-missing
 
 import paddle_serving_client
 import os
@@ -27,10 +28,14 @@ float_type = 1
 class SDKConfig(object):
     def __init__(self):
         self.sdk_desc = sdk.SDKConf()
-        self.endpoints = []
+        self.tag_list = []
+        self.cluster_list = []
+        self.variant_weight_list = []
 
-    def set_server_endpoints(self, endpoints):
-        self.endpoints = endpoints
+    def add_server_variant(self, tag, cluster, variant_weight):
+        self.tag_list.append(tag)
+        self.cluster_list.append(cluster)
+        self.variant_weight_list.append(variant_weight)
 
     def gen_desc(self):
         predictor_desc = sdk.Predictor()
@@ -38,14 +43,15 @@ class SDKConfig(object):
         predictor_desc.service_name = \
             "baidu.paddle_serving.predictor.general_model.GeneralModelService"
         predictor_desc.endpoint_router = "WeightedRandomRender"
-        predictor_desc.weighted_random_render_conf.variant_weight_list = "100"
+        predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
+            self.variant_weight_list)
 
-        variant_desc = sdk.VariantConf()
-        variant_desc.tag = "var1"
-        variant_desc.naming_conf.cluster = "list://{}".format(":".join(
-            self.endpoints))
-
-        predictor_desc.variants.extend([variant_desc])
+        for idx, tag in enumerate(self.tag_list):
+            variant_desc = sdk.VariantConf()
+            variant_desc.tag = tag
+            variant_desc.naming_conf.cluster = "list://{}".format(",".join(
+                self.cluster_list[idx]))
+            predictor_desc.variants.extend([variant_desc])
 
         self.sdk_desc.predictors.extend([predictor_desc])
         self.sdk_desc.default_variant_conf.tag = "default"
@@ -79,6 +85,7 @@ class Client(object):
         self.feed_names_to_idx_ = {}
         self.rpath()
         self.pid = os.getpid()
+        self.predictor_sdk_ = SDKConfig()
 
     def rpath(self):
         lib_path = os.path.dirname(paddle_serving_client.__file__)
@@ -130,13 +137,15 @@ class Client(object):
         return
 
-    def connect(self, endpoints):
+    def add_variant(self, tag, cluster, variant_weight):
+        self.predictor_sdk_.add_server_variant(tag, cluster,
+                                               str(variant_weight))
+
+    def connect(self):
         # check whether current endpoint is available
         # init from client config
         # create predictor here
-        predictor_sdk = SDKConfig()
-        predictor_sdk.set_server_endpoints(endpoints)
-        sdk_desc = predictor_sdk.gen_desc()
+        sdk_desc = self.predictor_sdk_.gen_desc()
         print(sdk_desc)
         self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
         ))
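Since endpoints are now registered through variants before connect(), several clusters with different weights can be combined; a sketch with illustrative tags, ports, and weights:

client.add_variant("var1", ["127.0.0.1:9393"], 30)
client.add_variant("var2", ["127.0.0.1:9494", "127.0.0.1:9495"], 70)
client.connect()  # WeightedRandomRender splits traffic roughly 30/70 by weight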
@@ -155,7 +164,7 @@ class Client(object):
             raise SystemExit("The shape of feed tensor {} not match.".format(
                 key))
 
-    def predict(self, feed={}, fetch=[]):
+    def predict(self, feed={}, fetch=[], need_server_pid=False):
         int_slot = []
         float_slot = []
         int_feed_names = []
@@ -190,9 +199,10 @@ class Client(object):
                 result_map[name] = self.result_handle_.get_float_by_name(name)[
                     0]
 
-        return result_map
+        return [result_map, self.result_handle_.server_pid()
+                ] if need_server_pid else result_map
 
-    def batch_predict(self, feed_batch=[], fetch=[]):
+    def batch_predict(self, feed_batch=[], fetch=[], need_server_pid=False):
         int_slot_batch = []
         float_slot_batch = []
         int_feed_names = []
@@ -239,7 +249,8 @@ class Client(object):
                     index]
             result_map_batch.append(result_map)
 
-        return result_map_batch
+        return [result_map_batch, self.result_handle_.server_pid()
+                ] if need_server_pid else result_map_batch
 
     def release(self):
         self.client_handle_.destroy_predictor()
......
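batch_predict gains the same flag and return convention; a sketch with placeholder inputs:

feed_batch = [{"x": x0}, {"x": x1}]  # x0, x1: placeholder feature vectors
result_batch, server_pid = client.batch_predict(
    feed_batch=feed_batch, fetch=["price"], need_server_pid=True)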