diff --git a/core/general-client/include/general_model.h b/core/general-client/include/general_model.h index e1820ac2cef6b85a731709e2bc2e3c63494995ce..75c1b53bae08878d530f69c25220d77eb8ba42df 100644 --- a/core/general-client/include/general_model.h +++ b/core/general-client/include/general_model.h @@ -31,6 +31,9 @@ using baidu::paddle_serving::sdk_cpp::Predictor; using baidu::paddle_serving::sdk_cpp::PredictorApi; +DECLARE_bool(profile_client); +DECLARE_bool(profile_server); + // given some input data, pack into pb, and send request namespace baidu { namespace paddle_serving { @@ -45,6 +48,8 @@ class PredictorClient { PredictorClient() {} ~PredictorClient() {} + void init_gflags(std::vector<std::string> argv); + int init(const std::string& client_conf); void set_predictor_conf(const std::string& conf_path, diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp index cf73083517673fd8e311f94888808c49d80734ff..1304288de92d98cbeff2ab41d5dce3649aa7b2be 100644 --- a/core/general-client/src/general_model.cpp +++ b/core/general-client/src/general_model.cpp @@ -19,6 +19,9 @@ #include "core/sdk-cpp/include/predictor_sdk.h" #include "core/util/include/timer.h" +DEFINE_bool(profile_client, false, ""); +DEFINE_bool(profile_server, false, ""); + using baidu::paddle_serving::Timer; using baidu::paddle_serving::predictor::general_model::Request; using baidu::paddle_serving::predictor::general_model::Response; @@ -26,11 +29,30 @@ using baidu::paddle_serving::predictor::general_model::Tensor; using baidu::paddle_serving::predictor::general_model::FeedInst; using baidu::paddle_serving::predictor::general_model::FetchInst; +std::once_flag gflags_init_flag; + namespace baidu { namespace paddle_serving { namespace general_model { using configure::GeneralModelConfig; +void PredictorClient::init_gflags(std::vector<std::string> argv) { + std::call_once(gflags_init_flag, [&]() { + FLAGS_logtostderr = true; + argv.insert(argv.begin(), "dummy"); + int argc = argv.size(); + 
char **arr = new char *[argv.size()]; + std::string line; + for (size_t i = 0; i < argv.size(); i++) { + arr[i] = &argv[i][0]; + line += argv[i]; + line += ' '; + } + google::ParseCommandLineFlags(&argc, &arr, true); + VLOG(2) << "Init commandline: " << line; + }); +} + int PredictorClient::init(const std::string &conf_file) { try { GeneralModelConfig model_config; @@ -190,6 +212,13 @@ std::vector<std::vector<float>> PredictorClient::predict( int64_t client_infer_end = 0; int64_t postprocess_start = 0; int64_t postprocess_end = 0; + + if (FLAGS_profile_client) { + if (FLAGS_profile_server) { + req.set_profile_server(true); + } + } + res.Clear(); if (_predictor->inference(&req, &res) != 0) { LOG(ERROR) << "failed call predictor with req: " << req.ShortDebugString(); @@ -211,20 +240,27 @@ std::vector<std::vector<float>> PredictorClient::predict( postprocess_end = timeline.TimeStampUS(); } - int op_num = res.profile_time_size() / 2; - - VLOG(2) << "preprocess start: " << preprocess_start; - VLOG(2) << "preprocess end: " << preprocess_end; - VLOG(2) << "client infer start: " << client_infer_start; - VLOG(2) << "op1 start: " << res.profile_time(0); - VLOG(2) << "op1 end: " << res.profile_time(1); - VLOG(2) << "op2 start: " << res.profile_time(2); - VLOG(2) << "op2 end: " << res.profile_time(3); - VLOG(2) << "op3 start: " << res.profile_time(4); - VLOG(2) << "op3 end: " << res.profile_time(5); - VLOG(2) << "client infer end: " << client_infer_end; - VLOG(2) << "client postprocess start: " << postprocess_start; - VLOG(2) << "client postprocess end: " << postprocess_end; + if (FLAGS_profile_client) { + std::ostringstream oss; + oss << "PROFILE\t" + << "prepro_0:" << preprocess_start << " " + << "prepro_1:" << preprocess_end << " " + << "client_infer_0:" << client_infer_start << " " + << "client_infer_1:" << client_infer_end << " "; + + if (FLAGS_profile_server) { + int op_num = res.profile_time_size() / 2; + for (int i = 0; i < op_num; ++i) { + oss << "op" << i << "_0:" << res.profile_time(i * 2) << " "; + oss 
<< "op" << i << "_1:" << res.profile_time(i * 2 + 1) << " "; + } + } + + oss << "postpro_0:" << postprocess_start << " "; + oss << "postpro_1:" << postprocess_end; + + fprintf(stderr, "%s\n", oss.str().c_str()); + } return fetch_result; } diff --git a/core/general-client/src/pybind_general_model.cpp b/core/general-client/src/pybind_general_model.cpp index 898b06356f19627945fa9a1e8115422d327091e3..1ce181c4d0c6207eefeaec77eeacfe63394ed839 100644 --- a/core/general-client/src/pybind_general_model.cpp +++ b/core/general-client/src/pybind_general_model.cpp @@ -31,6 +31,10 @@ PYBIND11_MODULE(serving_client, m) { )pddoc"; py::class_<PredictorClient>(m, "PredictorClient", py::buffer_protocol()) .def(py::init<>()) + .def("init_gflags", + [](PredictorClient &self, std::vector<std::string> argv) { + self.init_gflags(argv); + }) .def("init", [](PredictorClient &self, const std::string &conf) { return self.init(conf);