提交 732f991e 编写于 作者: W wangguibao

CTR prediction profiling

上级 671f3257
...@@ -30,7 +30,6 @@ using baidu::paddle_serving::predictor::ctr_prediction::Response; ...@@ -30,7 +30,6 @@ using baidu::paddle_serving::predictor::ctr_prediction::Response;
using baidu::paddle_serving::predictor::ctr_prediction::CTRReqInstance; using baidu::paddle_serving::predictor::ctr_prediction::CTRReqInstance;
using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance; using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance;
int batch_size = 16;
int sparse_num = 26; int sparse_num = 26;
int dense_num = 13; int dense_num = 13;
int hash_dim = 1000001; int hash_dim = 1000001;
...@@ -157,11 +156,17 @@ void thread_worker(PredictorApi* api, ...@@ -157,11 +156,17 @@ void thread_worker(PredictorApi* api,
Request req; Request req;
Response res; Response res;
std::string line; std::string line;
int start_index = 0;
api->thrd_initialize(); api->thrd_initialize();
for (int i = 0; i < FLAGS_repeat; ++i) {
int start_index = 0;
while (true) { while (true) {
if (start_index >= data_list.size()) {
break;
}
api->thrd_clear(); api->thrd_clear();
Predictor* predictor = api->fetch_predictor("ctr_prediction_service"); Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
...@@ -183,12 +188,14 @@ void thread_worker(PredictorApi* api, ...@@ -183,12 +188,14 @@ void thread_worker(PredictorApi* api,
return; return;
} }
start_index += FLAGS_batch_size; start_index += FLAGS_batch_size;
LOG(INFO) << "start_index = " << start_index;
timeval start; timeval start;
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
if (predictor->inference(&req, &res) != 0) { if (predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString(); LOG(ERROR) << "failed call predictor with req:"
<< req.ShortDebugString();
return; return;
} }
g_concurrency--; g_concurrency--;
...@@ -205,19 +212,24 @@ void thread_worker(PredictorApi* api, ...@@ -205,19 +212,24 @@ void thread_worker(PredictorApi* api,
} }
LOG(INFO) << "Done. Current concurrency " << g_concurrency.load(); LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
} } // end while
} // end for
api->thrd_finalize(); api->thrd_finalize();
} }
void calc_time(int server_concurrency, int batch_size) { void calc_time() {
std::vector<int> time_list; std::vector<int> time_list;
for (auto a : response_time) { for (auto a : response_time) {
time_list.insert(time_list.end(), a.begin(), a.end()); time_list.insert(time_list.end(), a.begin(), a.end());
} }
LOG(INFO) << "Total request : " << (time_list.size()); LOG(INFO) << "Total request : " << (time_list.size());
LOG(INFO) << "Batch size : " << batch_size; LOG(INFO) << "Batch size : " << FLAGS_batch_size;
LOG(INFO) << "Max concurrency : " << server_concurrency; LOG(INFO) << "Max concurrency : " << FLAGS_concurrency;
LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling;
LOG(INFO) << "repeat count: " << FLAGS_repeat;
float total_time = 0; float total_time = 0;
float max_time = 0; float max_time = 0;
float min_time = 1000000; float min_time = 1000000;
...@@ -226,21 +238,28 @@ void calc_time(int server_concurrency, int batch_size) { ...@@ -226,21 +238,28 @@ void calc_time(int server_concurrency, int batch_size) {
if (time_list[i] > max_time) max_time = time_list[i]; if (time_list[i] > max_time) max_time = time_list[i];
if (time_list[i] < min_time) min_time = time_list[i]; if (time_list[i] < min_time) min_time = time_list[i];
} }
float mean_time = total_time / (time_list.size()); float mean_time = total_time / (time_list.size());
float var_time; float var_time;
for (int i = 0; i < time_list.size(); ++i) { for (int i = 0; i < time_list.size(); ++i) {
var_time += (time_list[i] - mean_time) * (time_list[i] - mean_time); var_time += (time_list[i] - mean_time) * (time_list[i] - mean_time);
} }
var_time = var_time / time_list.size(); var_time = var_time / time_list.size();
LOG(INFO) << "Total time : " << total_time / server_concurrency
<< " Variance : " << var_time << " Max time : " << max_time LOG(INFO) << "Total time : " << total_time / FLAGS_concurrency << "ms";
<< " Min time : " << min_time; LOG(INFO) << "Variance : " << var_time << "ms";
LOG(INFO) << "Max time : " << max_time << "ms";
LOG(INFO) << "Min time : " << min_time << "ms";
float qps = 0.0; float qps = 0.0;
if (total_time > 0) if (total_time > 0) {
qps = (time_list.size() * 1000) / (total_time / server_concurrency); qps = (time_list.size() * 1000) / (total_time / FLAGS_concurrency);
}
LOG(INFO) << "QPS: " << qps << "/s"; LOG(INFO) << "QPS: " << qps << "/s";
LOG(INFO) << "Latency statistics: "; LOG(INFO) << "Latency statistics: ";
sort(time_list.begin(), time_list.end()); sort(time_list.begin(), time_list.end());
int percent_pos_50 = time_list.size() * 0.5; int percent_pos_50 = time_list.size() * 0.5;
int percent_pos_80 = time_list.size() * 0.8; int percent_pos_80 = time_list.size() * 0.8;
int percent_pos_90 = time_list.size() * 0.9; int percent_pos_90 = time_list.size() * 0.9;
...@@ -299,7 +318,6 @@ int main(int argc, char** argv) { ...@@ -299,7 +318,6 @@ int main(int argc, char** argv) {
} }
LOG(INFO) << "data sample file: " << data_filename; LOG(INFO) << "data sample file: " << data_filename;
LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling;
if (FLAGS_enable_profiling) { if (FLAGS_enable_profiling) {
LOG(INFO) << "In profiling mode, lot of normal output will be supressed. " LOG(INFO) << "In profiling mode, lot of normal output will be supressed. "
...@@ -330,7 +348,7 @@ int main(int argc, char** argv) { ...@@ -330,7 +348,7 @@ int main(int argc, char** argv) {
delete thread_pool[i]; delete thread_pool[i];
} }
calc_time(FLAGS_concurrency, batch_size); calc_time();
api.destroy(); api.destroy();
return 0; return 0;
......
...@@ -148,7 +148,6 @@ int CTRPredictionOp::inference() { ...@@ -148,7 +148,6 @@ int CTRPredictionOp::inference() {
int ret; int ret;
if (FLAGS_enable_ctr_profiling) {
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
ret = cube->seek(table_name, keys, &values); ret = cube->seek(table_name, keys, &values);
gettimeofday(&end, NULL); gettimeofday(&end, NULL);
...@@ -166,13 +165,14 @@ int CTRPredictionOp::inference() { ...@@ -166,13 +165,14 @@ int CTRPredictionOp::inference() {
LOG(INFO) << "Cube request key count: " << cube_req_key_num_; LOG(INFO) << "Cube request key count: " << cube_req_key_num_;
LOG(INFO) << "Cube request total time: " << cube_time_us_ << "us"; LOG(INFO) << "Cube request total time: " << cube_time_us_ << "us";
LOG(INFO) << "Average " << cube_time_us_ / cube_req_num_ << "us/req"; LOG(INFO) << "Average " << cube_time_us_ / cube_req_num_ << "us/req";
LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_ LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_ << "us/key";
<< "us/key";
cube_time_us_ = 0;
cube_req_num_ = 0;
cube_req_key_num_ = 0;
} }
mutex_.unlock(); mutex_.unlock();
} else { // Statistics end
ret = cube->seek(table_name, keys, &values);
}
if (ret != 0) { if (ret != 0) {
fill_response_with_message(res, -1, "Query cube for embeddings error"); fill_response_with_message(res, -1, "Query cube for embeddings error");
......
...@@ -51,8 +51,6 @@ using baidu::paddle_serving::predictor::FLAGS_port; ...@@ -51,8 +51,6 @@ using baidu::paddle_serving::predictor::FLAGS_port;
using baidu::paddle_serving::configure::InferServiceConf; using baidu::paddle_serving::configure::InferServiceConf;
using baidu::paddle_serving::configure::read_proto_conf; using baidu::paddle_serving::configure::read_proto_conf;
DECLARE_bool(logtostderr);
void print_revision(std::ostream& os, void*) { void print_revision(std::ostream& os, void*) {
#if defined(PDSERVING_VERSION) #if defined(PDSERVING_VERSION)
os << PDSERVING_VERSION; os << PDSERVING_VERSION;
...@@ -69,9 +67,6 @@ static bvar::PassiveStatus<std::string> s_predictor_revision( ...@@ -69,9 +67,6 @@ static bvar::PassiveStatus<std::string> s_predictor_revision(
DEFINE_bool(V, false, "print version, bool"); DEFINE_bool(V, false, "print version, bool");
DEFINE_bool(g, false, "user defined gflag path"); DEFINE_bool(g, false, "user defined gflag path");
DEFINE_bool(enable_ctr_profiling,
false,
"Enable profiling in CTR prediction demo");
DECLARE_string(flagfile); DECLARE_string(flagfile);
namespace bthread { namespace bthread {
...@@ -220,7 +215,8 @@ int main(int argc, char** argv) { ...@@ -220,7 +215,8 @@ int main(int argc, char** argv) {
} }
LOG(INFO) << "Succ initialize cube"; LOG(INFO) << "Succ initialize cube";
FLAGS_logtostderr = false; // FATAL messages are output to stderr
FLAGS_stderrthreshold = 3;
if (ServerManager::instance().start_and_wait() != 0) { if (ServerManager::instance().start_and_wait() != 0) {
LOG(ERROR) << "Failed start server and wait!"; LOG(ERROR) << "Failed start server and wait!";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册