Commit c0176fd3 authored by: barrierye

make client to know the process ID of the server

Parent 4d4b9e89
......@@ -53,10 +53,15 @@ class PredictorRes {
const std::string& name) {
return _float_map[name];
}
void set_server_pid(int pid) { _server_pid = pid; }
int server_pid() { return _server_pid; }
public:
std::map<std::string, std::vector<std::vector<int64_t>>> _int64_map;
std::map<std::string, std::vector<std::vector<float>>> _float_map;
private:
int _server_pid;
};
class PredictorClient {
......
......@@ -237,6 +237,7 @@ int PredictorClient::predict(const std::vector<std::vector<float>> &float_feed,
}
postprocess_end = timeline.TimeStampUS();
}
predict_res.set_server_pid(res.server_pid());
}
if (FLAGS_profile_client) {
......@@ -400,6 +401,7 @@ int PredictorClient::batch_predict(
}
}
postprocess_end = timeline.TimeStampUS();
predict_res_batch.set_server_pid(res.server_pid());
}
if (FLAGS_profile_client) {
......
......@@ -39,7 +39,8 @@ PYBIND11_MODULE(serving_client, m) {
[](PredictorRes &self, std::string &name) {
return self.get_float_by_name(name);
},
py::return_value_policy::reference);
py::return_value_policy::reference)
.def("server_pid", [](PredictorRes &self) { return self.server_pid(); });
py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol())
.def(py::init())
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "core/general-server/op/general_response_op.h"
#include <unistd.h>
#include <algorithm>
#include <iostream>
#include <memory>
......@@ -74,6 +75,10 @@ int GeneralResponseOp::inference() {
// response inst with only fetch_var_names
Response *res = mutable_data<Response>();
// to let the client know which server the current response comes from
VLOG(2) << "getpid: " << getpid();
res->set_server_pid(static_cast<int>(getpid()));
for (int i = 0; i < batch_size; ++i) {
FetchInst *fetch_inst = res->add_insts();
for (auto &idx : fetch_index) {
......
......@@ -41,6 +41,7 @@ message Request {
// Inference response returned by the server to the client.
message Response {
// One FetchInst per sample in the request batch.
repeated FetchInst insts = 1;
// Server-side timing values, presumably microsecond timestamps used by
// the client-side profiler -- confirm against the profiling code path.
repeated int64 profile_time = 2;
// OS process id of the serving worker that produced this response, so
// the client can tell which server process handled the request.
optional int64 server_pid = 3;
};
service GeneralModelService {
......
......@@ -41,6 +41,7 @@ message Request {
// Inference response returned by the server to the client.
message Response {
// One FetchInst per sample in the request batch.
repeated FetchInst insts = 1;
// Server-side timing values, presumably microsecond timestamps used by
// the client-side profiler -- confirm against the profiling code path.
repeated int64 profile_time = 2;
// OS process id of the serving worker that produced this response, so
// the client can tell which server process handled the request.
optional int64 server_pid = 3;
};
service GeneralModelService {
......
......@@ -18,7 +18,8 @@ import sys
client = Client()
client.load_client_config(sys.argv[1])
# Endpoints are now attached to a named, weighted variant instead of being
# passed to connect(); connect() therefore takes no arguments.
# (The pre-change call `client.connect(["127.0.0.1:9393"])` was stale diff
# residue and has been removed.)
client.add_variant("var1", ["127.0.0.1:9393"], 50)
client.connect()

import paddle
test_reader = paddle.batch(
......@@ -27,5 +28,6 @@ test_reader = paddle.batch(
batch_size=1)
for data in test_reader():
    # Request the server's pid alongside the prediction so the output can
    # show which server worker handled each sample.
    [fetch_map, server_pid] = client.predict(
        feed={"x": data[0][0]}, fetch=["price"], need_server_pid=True)
    print("[{}] {} {}".format(server_pid, fetch_map["price"][0], data[0][1][0]))
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=doc-string-missing
import paddle_serving_client
import os
......@@ -27,10 +28,14 @@ float_type = 1
class SDKConfig(object):
def __init__(self):
    # Raw protobuf SDK configuration that gen_desc() fills in.
    self.sdk_desc = sdk.SDKConf()
    # Server endpoints ("ip:port" strings), set via set_server_endpoints().
    self.endpoints = []
    # Parallel lists, one entry per variant added via add_server_variant():
    self.tag_list = []             # variant tag names
    self.cluster_list = []         # endpoint list for each variant
    self.variant_weight_list = []  # weights as strings, "|"-joined in gen_desc()
def set_server_endpoints(self, endpoints):
    # Replace the endpoint list wholesale ("ip:port" strings).
    self.endpoints = endpoints
def add_server_variant(self, tag, cluster, variant_weight):
self.tag_list.append(tag)
self.cluster_list.append(cluster)
self.variant_weight_list.append(variant_weight)
def gen_desc(self):
predictor_desc = sdk.Predictor()
......@@ -38,14 +43,15 @@ class SDKConfig(object):
predictor_desc.service_name = \
"baidu.paddle_serving.predictor.general_model.GeneralModelService"
predictor_desc.endpoint_router = "WeightedRandomRender"
predictor_desc.weighted_random_render_conf.variant_weight_list = "100"
predictor_desc.weighted_random_render_conf.variant_weight_list = "|".join(
self.variant_weight_list)
variant_desc = sdk.VariantConf()
variant_desc.tag = "var1"
variant_desc.naming_conf.cluster = "list://{}".format(":".join(
self.endpoints))
predictor_desc.variants.extend([variant_desc])
for idx, tag in enumerate(self.tag_list):
variant_desc = sdk.VariantConf()
variant_desc.tag = tag
variant_desc.naming_conf.cluster = "list://{}".format(",".join(
self.cluster_list[idx]))
predictor_desc.variants.extend([variant_desc])
self.sdk_desc.predictors.extend([predictor_desc])
self.sdk_desc.default_variant_conf.tag = "default"
......@@ -79,6 +85,7 @@ class Client(object):
self.feed_names_to_idx_ = {}
self.rpath()
self.pid = os.getpid()
self.predictor_sdk_ = SDKConfig()
def rpath(self):
lib_path = os.path.dirname(paddle_serving_client.__file__)
......@@ -130,13 +137,15 @@ class Client(object):
return
def connect(self, endpoints):
def add_variant(self, tag, cluster, variant_weight):
    # Forward to the SDK config; the weight is stringified because
    # SDKConfig stores weights as strings and "|"-joins them into a
    # single proto string field in gen_desc().
    self.predictor_sdk_.add_server_variant(tag, cluster,
                                           str(variant_weight))
def connect(self):
    # check whether current endpoint is available
    # init from client config
    # create predictor here
    # Endpoints now come from the variants registered via add_variant(),
    # so the SDK config lives on self and connect() takes no arguments.
    # (The stale pre-change statements that built a local SDKConfig and
    # referenced a removed `endpoints` parameter were diff residue and
    # have been dropped.)
    sdk_desc = self.predictor_sdk_.gen_desc()
    print(sdk_desc)
    self.client_handle_.create_predictor_by_desc(sdk_desc.SerializeToString(
    ))
......@@ -155,7 +164,7 @@ class Client(object):
raise SystemExit("The shape of feed tensor {} not match.".format(
key))
def predict(self, feed={}, fetch=[]):
def predict(self, feed={}, fetch=[], need_server_pid=False):
int_slot = []
float_slot = []
int_feed_names = []
......@@ -190,9 +199,10 @@ class Client(object):
result_map[name] = self.result_handle_.get_float_by_name(name)[
0]
return result_map
return [result_map, self.result_handle_.server_pid()
] if need_server_pid else result_map
def batch_predict(self, feed_batch=[], fetch=[]):
def batch_predict(self, feed_batch=[], fetch=[], need_server_pid=False):
int_slot_batch = []
float_slot_batch = []
int_feed_names = []
......@@ -239,7 +249,8 @@ class Client(object):
index]
result_map_batch.append(result_map)
# NOTE(review): the batch path accumulates per-sample results in
# result_map_batch, but the new return below hands back result_map (the
# LAST sample only). This looks like a bug introduced by this commit --
# both branches should presumably return result_map_batch; confirm and fix.
return result_map_batch
return [result_map, self.result_handle_.server_pid()
] if need_server_pid else result_map
def release(self):
self.client_handle_.destroy_predictor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register