From c2f240605750c548ae9794aecfd830076cb95778 Mon Sep 17 00:00:00 2001
From: wangguibao
Date: Fri, 27 Sep 2019 15:37:30 +0800
Subject: [PATCH] CTR prediction profiling

---
 demo-client/src/ctr_prediction.cpp    | 112 +++++++++++++++-----------
 demo-serving/op/ctr_prediction_op.cpp |  48 +++++------
 predictor/src/pdserving.cpp           |   8 +-
 3 files changed, 91 insertions(+), 77 deletions(-)

diff --git a/demo-client/src/ctr_prediction.cpp b/demo-client/src/ctr_prediction.cpp
index 579ec5cc..73bf029f 100644
--- a/demo-client/src/ctr_prediction.cpp
+++ b/demo-client/src/ctr_prediction.cpp
@@ -30,7 +30,6 @@ using baidu::paddle_serving::predictor::ctr_prediction::Response;
 using baidu::paddle_serving::predictor::ctr_prediction::CTRReqInstance;
 using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance;
 
-int batch_size = 16;
 int sparse_num = 26;
 int dense_num = 13;
 int hash_dim = 1000001;
@@ -157,67 +156,80 @@ void thread_worker(PredictorApi* api,
   Request req;
   Response res;
   std::string line;
-  int start_index = 0;
 
   api->thrd_initialize();
 
-  while (true) {
-    api->thrd_clear();
+  for (int i = 0; i < FLAGS_repeat; ++i) {
+    int start_index = 0;
 
-    Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
-    if (!predictor) {
-      LOG(ERROR) << "Failed fetch predictor: ctr_prediction_service";
-      return;
-    }
+    while (true) {
+      if (start_index >= data_list.size()) {
+        break;
+      }
 
-    req.Clear();
-    res.Clear();
+      api->thrd_clear();
 
-    // wait for other thread
-    while (g_concurrency.load() >= FLAGS_concurrency) {
-    }
-    g_concurrency++;
-    LOG(INFO) << "Current concurrency " << g_concurrency.load();
+      Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
+      if (!predictor) {
+        LOG(ERROR) << "Failed to fetch predictor: ctr_prediction_service";
+        return;
+      }
 
-    if (create_req(&req, data_list, start_index, FLAGS_batch_size) != 0) {
-      return;
-    }
-    start_index += FLAGS_batch_size;
+      req.Clear();
+      res.Clear();
+
+      // Wait until a concurrency slot is free
+      while (g_concurrency.load() >= FLAGS_concurrency) {
+      }
+      g_concurrency++;
+      LOG(INFO) << "Current concurrency " << g_concurrency.load();
 
-    timeval start;
-    gettimeofday(&start, NULL);
+      if (create_req(&req, data_list, start_index, FLAGS_batch_size) != 0) {
+        return;
+      }
+      start_index += FLAGS_batch_size;
+      LOG(INFO) << "start_index = " << start_index;
 
-    if (predictor->inference(&req, &res) != 0) {
-      LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
-      return;
-    }
-    g_concurrency--;
+      timeval start;
+      gettimeofday(&start, NULL);
 
-    timeval end;
-    gettimeofday(&end, NULL);
-    uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
-                         (start.tv_sec * 1000 + start.tv_usec / 1000);
+      if (predictor->inference(&req, &res) != 0) {
+        LOG(ERROR) << "Failed to call predictor with req: "
+                   << req.ShortDebugString();
+        return;
+      }
+      g_concurrency--;
 
-    response_time[thread_id].push_back(elapse_ms);
+      timeval end;
+      gettimeofday(&end, NULL);
+      uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
+                           (start.tv_sec * 1000 + start.tv_usec / 1000);
 
-    if (!FLAGS_enable_profiling) {
-      print_res(req, res, predictor->tag(), elapse_ms);
-    }
+      response_time[thread_id].push_back(elapse_ms);
 
-    LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
-  }
+      if (!FLAGS_enable_profiling) {
+        print_res(req, res, predictor->tag(), elapse_ms);
+      }
+
+      LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
Current concurrency " << g_concurrency.load(); + } // end while + } // end for api->thrd_finalize(); } -void calc_time(int server_concurrency, int batch_size) { +void calc_time() { std::vector time_list; for (auto a : response_time) { time_list.insert(time_list.end(), a.begin(), a.end()); } + LOG(INFO) << "Total request : " << (time_list.size()); - LOG(INFO) << "Batch size : " << batch_size; - LOG(INFO) << "Max concurrency : " << server_concurrency; + LOG(INFO) << "Batch size : " << FLAGS_batch_size; + LOG(INFO) << "Max concurrency : " << FLAGS_concurrency; + LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling; + LOG(INFO) << "repeat count: " << FLAGS_repeat; + float total_time = 0; float max_time = 0; float min_time = 1000000; @@ -226,21 +238,28 @@ void calc_time(int server_concurrency, int batch_size) { if (time_list[i] > max_time) max_time = time_list[i]; if (time_list[i] < min_time) min_time = time_list[i]; } + float mean_time = total_time / (time_list.size()); float var_time; for (int i = 0; i < time_list.size(); ++i) { var_time += (time_list[i] - mean_time) * (time_list[i] - mean_time); } var_time = var_time / time_list.size(); - LOG(INFO) << "Total time : " << total_time / server_concurrency - << " Variance : " << var_time << " Max time : " << max_time - << " Min time : " << min_time; + + LOG(INFO) << "Total time : " << total_time / FLAGS_concurrency << "ms"; + LOG(INFO) << "Variance : " << var_time << "ms"; + LOG(INFO) << "Max time : " << max_time << "ms"; + LOG(INFO) << "Min time : " << min_time << "ms"; + float qps = 0.0; - if (total_time > 0) - qps = (time_list.size() * 1000) / (total_time / server_concurrency); + if (total_time > 0) { + qps = (time_list.size() * 1000) / (total_time / FLAGS_concurrency); + } LOG(INFO) << "QPS: " << qps << "/s"; + LOG(INFO) << "Latency statistics: "; sort(time_list.begin(), time_list.end()); + int percent_pos_50 = time_list.size() * 0.5; int percent_pos_80 = time_list.size() * 0.8; int percent_pos_90 = time_list.size() * 0.9; @@ -299,7 +318,6 @@ int main(int argc, char** argv) { } LOG(INFO) << "data sample file: " << data_filename; - LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling; if (FLAGS_enable_profiling) { LOG(INFO) << "In profiling mode, lot of normal output will be supressed. 
" @@ -330,7 +348,7 @@ int main(int argc, char** argv) { delete thread_pool[i]; } - calc_time(FLAGS_concurrency, batch_size); + calc_time(); api.destroy(); return 0; diff --git a/demo-serving/op/ctr_prediction_op.cpp b/demo-serving/op/ctr_prediction_op.cpp index db9571b0..274db260 100644 --- a/demo-serving/op/ctr_prediction_op.cpp +++ b/demo-serving/op/ctr_prediction_op.cpp @@ -148,31 +148,31 @@ int CTRPredictionOp::inference() { int ret; - if (FLAGS_enable_ctr_profiling) { - gettimeofday(&start, NULL); - ret = cube->seek(table_name, keys, &values); - gettimeofday(&end, NULL); - uint64_t usec = - end.tv_sec * 1e6 + end.tv_usec - start.tv_sec * 1e6 - start.tv_usec; - - // Statistics - mutex_.lock(); - cube_time_us_ += usec; - ++cube_req_num_; - cube_req_key_num_ += keys.size(); - - if (cube_req_num_ >= 1000) { - LOG(INFO) << "Cube request count: " << cube_req_num_; - LOG(INFO) << "Cube request key count: " << cube_req_key_num_; - LOG(INFO) << "Cube request total time: " << cube_time_us_ << "us"; - LOG(INFO) << "Average " << cube_time_us_ / cube_req_num_ << "us/req"; - LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_ - << "us/key"; - } - mutex_.unlock(); - } else { - ret = cube->seek(table_name, keys, &values); + gettimeofday(&start, NULL); + ret = cube->seek(table_name, keys, &values); + gettimeofday(&end, NULL); + uint64_t usec = + end.tv_sec * 1e6 + end.tv_usec - start.tv_sec * 1e6 - start.tv_usec; + + // Statistics + mutex_.lock(); + cube_time_us_ += usec; + ++cube_req_num_; + cube_req_key_num_ += keys.size(); + + if (cube_req_num_ >= 1000) { + LOG(INFO) << "Cube request count: " << cube_req_num_; + LOG(INFO) << "Cube request key count: " << cube_req_key_num_; + LOG(INFO) << "Cube request total time: " << cube_time_us_ << "us"; + LOG(INFO) << "Average " << cube_time_us_ / cube_req_num_ << "us/req"; + LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_ << "us/key"; + + cube_time_us_ = 0; + cube_req_num_ = 0; + cube_req_key_num_ = 0; } + mutex_.unlock(); + // Statistics end if (ret != 0) { fill_response_with_message(res, -1, "Query cube for embeddings error"); diff --git a/predictor/src/pdserving.cpp b/predictor/src/pdserving.cpp index 28247bce..56ffee84 100644 --- a/predictor/src/pdserving.cpp +++ b/predictor/src/pdserving.cpp @@ -51,8 +51,6 @@ using baidu::paddle_serving::predictor::FLAGS_port; using baidu::paddle_serving::configure::InferServiceConf; using baidu::paddle_serving::configure::read_proto_conf; -DECLARE_bool(logtostderr); - void print_revision(std::ostream& os, void*) { #if defined(PDSERVING_VERSION) os << PDSERVING_VERSION; @@ -69,9 +67,6 @@ static bvar::PassiveStatus s_predictor_revision( DEFINE_bool(V, false, "print version, bool"); DEFINE_bool(g, false, "user defined gflag path"); -DEFINE_bool(enable_ctr_profiling, - false, - "Enable profiling in CTR prediction demo"); DECLARE_string(flagfile); namespace bthread { @@ -220,7 +215,8 @@ int main(int argc, char** argv) { } LOG(INFO) << "Succ initialize cube"; - FLAGS_logtostderr = false; + // FATAL messages are output to stderr + FLAGS_stderrthreshold = 3; if (ServerManager::instance().start_and_wait() != 0) { LOG(ERROR) << "Failed start server and wait!"; -- GitLab