提交 6ec80f49 编写于 作者: B barriery

add log_id to proto; TODO: recompile pdcodegen

上级 f38a527d
...@@ -227,7 +227,8 @@ class PredictorClient { ...@@ -227,7 +227,8 @@ class PredictorClient {
const std::vector<std::vector<int>>& int_shape, const std::vector<std::vector<int>>& int_shape,
const std::vector<std::string>& fetch_name, const std::vector<std::string>& fetch_name,
PredictorRes& predict_res_batch, // NOLINT PredictorRes& predict_res_batch, // NOLINT
const int& pid); const int& pid,
const uint64_t log_id);
int numpy_predict( int numpy_predict(
const std::vector<std::vector<py::array_t<float>>>& float_feed_batch, const std::vector<std::vector<py::array_t<float>>>& float_feed_batch,
...@@ -238,7 +239,8 @@ class PredictorClient { ...@@ -238,7 +239,8 @@ class PredictorClient {
const std::vector<std::vector<int>>& int_shape, const std::vector<std::vector<int>>& int_shape,
const std::vector<std::string>& fetch_name, const std::vector<std::string>& fetch_name,
PredictorRes& predict_res_batch, // NOLINT PredictorRes& predict_res_batch, // NOLINT
const int& pid); const int& pid,
const uint64_t log_id);
private: private:
PredictorApi _api; PredictorApi _api;
......
...@@ -144,7 +144,8 @@ int PredictorClient::batch_predict( ...@@ -144,7 +144,8 @@ int PredictorClient::batch_predict(
const std::vector<std::vector<int>> &int_shape, const std::vector<std::vector<int>> &int_shape,
const std::vector<std::string> &fetch_name, const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch, PredictorRes &predict_res_batch,
const int &pid) { const int &pid,
const uint64_t log_id) {
int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size()); int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
predict_res_batch.clear(); predict_res_batch.clear();
...@@ -162,6 +163,8 @@ int PredictorClient::batch_predict( ...@@ -162,6 +163,8 @@ int PredictorClient::batch_predict(
VLOG(2) << "int feed name size: " << int_feed_name.size(); VLOG(2) << "int feed name size: " << int_feed_name.size();
VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size; VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size;
Request req; Request req;
req.set_log_id(log_id);
VLOG(2) << "(logid=" << req.log_id() << ")";
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
req.add_fetch_var_names(name); req.add_fetch_var_names(name);
} }
...@@ -356,7 +359,8 @@ int PredictorClient::numpy_predict( ...@@ -356,7 +359,8 @@ int PredictorClient::numpy_predict(
const std::vector<std::vector<int>> &int_shape, const std::vector<std::vector<int>> &int_shape,
const std::vector<std::string> &fetch_name, const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch, PredictorRes &predict_res_batch,
const int &pid) { const int &pid,
const uint64_t log_id) {
int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size()); int batch_size = std::max(float_feed_batch.size(), int_feed_batch.size());
VLOG(2) << "batch size: " << batch_size; VLOG(2) << "batch size: " << batch_size;
predict_res_batch.clear(); predict_res_batch.clear();
...@@ -374,6 +378,8 @@ int PredictorClient::numpy_predict( ...@@ -374,6 +378,8 @@ int PredictorClient::numpy_predict(
VLOG(2) << "int feed name size: " << int_feed_name.size(); VLOG(2) << "int feed name size: " << int_feed_name.size();
VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size; VLOG(2) << "max body size : " << brpc::fLU64::FLAGS_max_body_size;
Request req; Request req;
req.set_log_id(log_id);
VLOG(2) << "(logid=" << req.log_id() << ")";
for (auto &name : fetch_name) { for (auto &name : fetch_name) {
req.add_fetch_var_names(name); req.add_fetch_var_names(name);
} }
......
...@@ -107,7 +107,8 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -107,7 +107,8 @@ PYBIND11_MODULE(serving_client, m) {
const std::vector<std::vector<int>> &int_shape, const std::vector<std::vector<int>> &int_shape,
const std::vector<std::string> &fetch_name, const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch, PredictorRes &predict_res_batch,
const int &pid) { const int &pid,
const uint64_t log_id) {
return self.batch_predict(float_feed_batch, return self.batch_predict(float_feed_batch,
float_feed_name, float_feed_name,
float_shape, float_shape,
...@@ -116,7 +117,8 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -116,7 +117,8 @@ PYBIND11_MODULE(serving_client, m) {
int_shape, int_shape,
fetch_name, fetch_name,
predict_res_batch, predict_res_batch,
pid); pid,
log_id);
}, },
py::call_guard<py::gil_scoped_release>()) py::call_guard<py::gil_scoped_release>())
.def("numpy_predict", .def("numpy_predict",
...@@ -131,7 +133,8 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -131,7 +133,8 @@ PYBIND11_MODULE(serving_client, m) {
const std::vector<std::vector<int>> &int_shape, const std::vector<std::vector<int>> &int_shape,
const std::vector<std::string> &fetch_name, const std::vector<std::string> &fetch_name,
PredictorRes &predict_res_batch, PredictorRes &predict_res_batch,
const int &pid) { const int &pid,
const uint64_t log_id) {
return self.numpy_predict(float_feed_batch, return self.numpy_predict(float_feed_batch,
float_feed_name, float_feed_name,
float_shape, float_shape,
...@@ -140,7 +143,8 @@ PYBIND11_MODULE(serving_client, m) { ...@@ -140,7 +143,8 @@ PYBIND11_MODULE(serving_client, m) {
int_shape, int_shape,
fetch_name, fetch_name,
predict_res_batch, predict_res_batch,
pid); pid,
log_id);
}, },
py::call_guard<py::gil_scoped_release>()); py::call_guard<py::gil_scoped_release>());
} }
......
...@@ -37,6 +37,7 @@ message Request { ...@@ -37,6 +37,7 @@ message Request {
repeated FeedInst insts = 1; repeated FeedInst insts = 1;
repeated string fetch_var_names = 2; repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ]; optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
}; };
message Response { message Response {
......
...@@ -280,6 +280,7 @@ class PdsCodeGenerator : public CodeGenerator { ...@@ -280,6 +280,7 @@ class PdsCodeGenerator : public CodeGenerator {
" baidu::rpc::ClosureGuard done_guard(done);\n" " baidu::rpc::ClosureGuard done_guard(done);\n"
" baidu::rpc::Controller* cntl = \n" " baidu::rpc::Controller* cntl = \n"
" static_cast<baidu::rpc::Controller*>(cntl_base);\n" " static_cast<baidu::rpc::Controller*>(cntl_base);\n"
" cntl->set_log_id(request->log_id());\n"
" ::baidu::paddle_serving::predictor::InferService* svr = \n" " ::baidu::paddle_serving::predictor::InferService* svr = \n"
" " " "
"::baidu::paddle_serving::predictor::InferServiceManager::instance(" "::baidu::paddle_serving::predictor::InferServiceManager::instance("
...@@ -317,6 +318,7 @@ class PdsCodeGenerator : public CodeGenerator { ...@@ -317,6 +318,7 @@ class PdsCodeGenerator : public CodeGenerator {
" baidu::rpc::ClosureGuard done_guard(done);\n" " baidu::rpc::ClosureGuard done_guard(done);\n"
" baidu::rpc::Controller* cntl = \n" " baidu::rpc::Controller* cntl = \n"
" static_cast<baidu::rpc::Controller*>(cntl_base);\n" " static_cast<baidu::rpc::Controller*>(cntl_base);\n"
" cntl->set_log_id(request->log_id());\n"
" ::baidu::paddle_serving::predictor::InferService* svr = \n" " ::baidu::paddle_serving::predictor::InferService* svr = \n"
" " " "
"::baidu::paddle_serving::predictor::InferServiceManager::instance(" "::baidu::paddle_serving::predictor::InferServiceManager::instance("
......
...@@ -37,6 +37,7 @@ message Request { ...@@ -37,6 +37,7 @@ message Request {
repeated FeedInst insts = 1; repeated FeedInst insts = 1;
repeated string fetch_var_names = 2; repeated string fetch_var_names = 2;
optional bool profile_server = 3 [ default = false ]; optional bool profile_server = 3 [ default = false ];
required uint64 log_id = 4 [ default = 0 ];
}; };
message Response { message Response {
......
...@@ -233,7 +233,7 @@ class Client(object): ...@@ -233,7 +233,7 @@ class Client(object):
# key)) # key))
pass pass
def predict(self, feed=None, fetch=None, need_variant_tag=False): def predict(self, feed=None, fetch=None, need_variant_tag=False, log_id=0):
self.profile_.record('py_prepro_0') self.profile_.record('py_prepro_0')
if feed is None or fetch is None: if feed is None or fetch is None:
...@@ -319,12 +319,12 @@ class Client(object): ...@@ -319,12 +319,12 @@ class Client(object):
res = self.client_handle_.numpy_predict( res = self.client_handle_.numpy_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch_handle, int_feed_names, int_shape, fetch_names, result_batch_handle,
self.pid) self.pid, log_id)
elif self.has_numpy_input == False: elif self.has_numpy_input == False:
res = self.client_handle_.batch_predict( res = self.client_handle_.batch_predict(
float_slot_batch, float_feed_names, float_shape, int_slot_batch, float_slot_batch, float_feed_names, float_shape, int_slot_batch,
int_feed_names, int_shape, fetch_names, result_batch_handle, int_feed_names, int_shape, fetch_names, result_batch_handle,
self.pid) self.pid, log_id)
else: else:
raise ValueError( raise ValueError(
"Please make sure the inputs are all in list type or all in numpy.array type" "Please make sure the inputs are all in list type or all in numpy.array type"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册