Commit c2f24060 authored by wangguibao

CTR prediction profiling

Parent b1fdac18
@@ -30,7 +30,6 @@ using baidu::paddle_serving::predictor::ctr_prediction::Response;
 using baidu::paddle_serving::predictor::ctr_prediction::CTRReqInstance;
 using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance;
-int batch_size = 16;
 int sparse_num = 26;
 int dense_num = 13;
 int hash_dim = 1000001;
@@ -157,67 +156,80 @@ void thread_worker(PredictorApi* api,
   Request req;
   Response res;
   std::string line;
-  int start_index = 0;
 
   api->thrd_initialize();
 
-  while (true) {
-    api->thrd_clear();
-
-    Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
-    if (!predictor) {
-      LOG(ERROR) << "Failed fetch predictor: ctr_prediction_service";
-      return;
-    }
-
-    req.Clear();
-    res.Clear();
-
-    // wait for other thread
-    while (g_concurrency.load() >= FLAGS_concurrency) {
-    }
-    g_concurrency++;
-    LOG(INFO) << "Current concurrency " << g_concurrency.load();
-
-    if (create_req(&req, data_list, start_index, FLAGS_batch_size) != 0) {
-      return;
-    }
-    start_index += FLAGS_batch_size;
-
-    timeval start;
-    gettimeofday(&start, NULL);
-
-    if (predictor->inference(&req, &res) != 0) {
-      LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
-      return;
-    }
-    g_concurrency--;
-
-    timeval end;
-    gettimeofday(&end, NULL);
-    uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
-                         (start.tv_sec * 1000 + start.tv_usec / 1000);
-
-    response_time[thread_id].push_back(elapse_ms);
-
-    if (!FLAGS_enable_profiling) {
-      print_res(req, res, predictor->tag(), elapse_ms);
-    }
-
-    LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
-  }
+  for (int i = 0; i < FLAGS_repeat; ++i) {
+    int start_index = 0;
+
+    while (true) {
+      if (start_index >= data_list.size()) {
+        break;
+      }
+
+      api->thrd_clear();
+
+      Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
+      if (!predictor) {
+        LOG(ERROR) << "Failed fetch predictor: ctr_prediction_service";
+        return;
+      }
+
+      req.Clear();
+      res.Clear();
+
+      // wait for other thread
+      while (g_concurrency.load() >= FLAGS_concurrency) {
+      }
+      g_concurrency++;
+      LOG(INFO) << "Current concurrency " << g_concurrency.load();
+
+      if (create_req(&req, data_list, start_index, FLAGS_batch_size) != 0) {
+        return;
+      }
+      start_index += FLAGS_batch_size;
+      LOG(INFO) << "start_index = " << start_index;
+
+      timeval start;
+      gettimeofday(&start, NULL);
+
+      if (predictor->inference(&req, &res) != 0) {
+        LOG(ERROR) << "failed call predictor with req:"
+                   << req.ShortDebugString();
+        return;
+      }
+      g_concurrency--;
+
+      timeval end;
+      gettimeofday(&end, NULL);
+      uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
+                           (start.tv_sec * 1000 + start.tv_usec / 1000);
+
+      response_time[thread_id].push_back(elapse_ms);
+
+      if (!FLAGS_enable_profiling) {
+        print_res(req, res, predictor->tag(), elapse_ms);
+      }
+
+      LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
+    }  // end while
+  }  // end for
 
   api->thrd_finalize();
 }
 
-void calc_time(int server_concurrency, int batch_size) {
+void calc_time() {
   std::vector<int> time_list;
   for (auto a : response_time) {
     time_list.insert(time_list.end(), a.begin(), a.end());
   }
   LOG(INFO) << "Total request : " << (time_list.size());
-  LOG(INFO) << "Batch size : " << batch_size;
-  LOG(INFO) << "Max concurrency : " << server_concurrency;
+  LOG(INFO) << "Batch size : " << FLAGS_batch_size;
+  LOG(INFO) << "Max concurrency : " << FLAGS_concurrency;
+  LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling;
+  LOG(INFO) << "repeat count: " << FLAGS_repeat;
   float total_time = 0;
   float max_time = 0;
   float min_time = 1000000;
@@ -226,21 +238,28 @@ void calc_time(int server_concurrency, int batch_size) {
     if (time_list[i] > max_time) max_time = time_list[i];
     if (time_list[i] < min_time) min_time = time_list[i];
   }
   float mean_time = total_time / (time_list.size());
   float var_time;
   for (int i = 0; i < time_list.size(); ++i) {
     var_time += (time_list[i] - mean_time) * (time_list[i] - mean_time);
   }
   var_time = var_time / time_list.size();
-  LOG(INFO) << "Total time : " << total_time / server_concurrency
-            << " Variance : " << var_time << " Max time : " << max_time
-            << " Min time : " << min_time;
+
+  LOG(INFO) << "Total time : " << total_time / FLAGS_concurrency << "ms";
+  LOG(INFO) << "Variance : " << var_time << "ms";
+  LOG(INFO) << "Max time : " << max_time << "ms";
+  LOG(INFO) << "Min time : " << min_time << "ms";
+
   float qps = 0.0;
-  if (total_time > 0)
-    qps = (time_list.size() * 1000) / (total_time / server_concurrency);
+  if (total_time > 0) {
+    qps = (time_list.size() * 1000) / (total_time / FLAGS_concurrency);
+  }
   LOG(INFO) << "QPS: " << qps << "/s";
   LOG(INFO) << "Latency statistics: ";
   sort(time_list.begin(), time_list.end());
   int percent_pos_50 = time_list.size() * 0.5;
   int percent_pos_80 = time_list.size() * 0.8;
   int percent_pos_90 = time_list.size() * 0.9;
@@ -299,7 +318,6 @@ int main(int argc, char** argv) {
   }
 
   LOG(INFO) << "data sample file: " << data_filename;
-  LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling;
 
   if (FLAGS_enable_profiling) {
     LOG(INFO) << "In profiling mode, lot of normal output will be supressed. "
@@ -330,7 +348,7 @@ int main(int argc, char** argv) {
     delete thread_pool[i];
   }
 
-  calc_time(FLAGS_concurrency, batch_size);
+  calc_time();
 
   api.destroy();
   return 0;
......
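Summary of the client changes above: the hard-coded batch_size global gives way to the FLAGS_batch_size gflag, thread_worker() gains a FLAGS_repeat outer loop so each worker replays the sample data several times and stops when the data is exhausted, and calc_time() reads FLAGS_concurrency and FLAGS_batch_size directly instead of taking parameters. The QPS and percentile figures it logs come from the merged per-thread latencies; the following is a minimal standalone sketch of that arithmetic (the function name summarize_latency is illustrative and not part of the patch):

#include <algorithm>
#include <iostream>
#include <vector>

// Sketch of the arithmetic behind calc_time(): total wall time is
// approximated by the summed latencies divided by the number of concurrent
// workers, and percentiles are read from the sorted latency list.
void summarize_latency(const std::vector<int>& latency_ms, int concurrency) {
  if (latency_ms.empty() || concurrency <= 0) return;
  std::vector<int> sorted(latency_ms);
  std::sort(sorted.begin(), sorted.end());
  double total_ms = 0;
  for (int v : sorted) total_ms += v;
  if (total_ms <= 0) return;
  double qps = sorted.size() * 1000.0 / (total_ms / concurrency);
  std::cout << "QPS: " << qps << "/s" << std::endl;
  std::cout << "50th percentile: " << sorted[sorted.size() / 2] << "ms" << std::endl;
  std::cout << "90th percentile: " << sorted[sorted.size() * 9 / 10] << "ms" << std::endl;
}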
@@ -148,31 +148,31 @@ int CTRPredictionOp::inference() {
   int ret;
-  if (FLAGS_enable_ctr_profiling) {
-    gettimeofday(&start, NULL);
-    ret = cube->seek(table_name, keys, &values);
-    gettimeofday(&end, NULL);
-    uint64_t usec =
-        end.tv_sec * 1e6 + end.tv_usec - start.tv_sec * 1e6 - start.tv_usec;
-
-    // Statistics
-    mutex_.lock();
-    cube_time_us_ += usec;
-    ++cube_req_num_;
-    cube_req_key_num_ += keys.size();
-
-    if (cube_req_num_ >= 1000) {
-      LOG(INFO) << "Cube request count: " << cube_req_num_;
-      LOG(INFO) << "Cube request key count: " << cube_req_key_num_;
-      LOG(INFO) << "Cube request total time: " << cube_time_us_ << "us";
-      LOG(INFO) << "Average " << cube_time_us_ / cube_req_num_ << "us/req";
-      LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_
-                << "us/key";
-    }
-    mutex_.unlock();
-  } else {
-    ret = cube->seek(table_name, keys, &values);
-  }
+  gettimeofday(&start, NULL);
+  ret = cube->seek(table_name, keys, &values);
+  gettimeofday(&end, NULL);
+  uint64_t usec =
+      end.tv_sec * 1e6 + end.tv_usec - start.tv_sec * 1e6 - start.tv_usec;
+
+  // Statistics
+  mutex_.lock();
+  cube_time_us_ += usec;
+  ++cube_req_num_;
+  cube_req_key_num_ += keys.size();
+
+  if (cube_req_num_ >= 1000) {
+    LOG(INFO) << "Cube request count: " << cube_req_num_;
+    LOG(INFO) << "Cube request key count: " << cube_req_key_num_;
+    LOG(INFO) << "Cube request total time: " << cube_time_us_ << "us";
+    LOG(INFO) << "Average " << cube_time_us_ / cube_req_num_ << "us/req";
+    LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_ << "us/key";
+
+    cube_time_us_ = 0;
+    cube_req_num_ = 0;
+    cube_req_key_num_ = 0;
+  }
+  mutex_.unlock();
+  // Statistics end
 
   if (ret != 0) {
     fill_response_with_message(res, -1, "Query cube for embeddings error");
......
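In CTRPredictionOp::inference(), the FLAGS_enable_ctr_profiling switch is gone: the cube->seek() call is now always timed, and the accumulated counters are reset after every 1000 requests, so the logged averages describe a rolling window rather than the whole process lifetime. Reduced to a self-contained sketch of the same pattern (the CubeStats type and the report_seek name are illustrative only):

#include <cstdint>
#include <iostream>
#include <mutex>

// Rolling statistics in the style of CTRPredictionOp: accumulate lookup time,
// request count and key count under a mutex, log the averages once 1000
// requests have been seen, then reset the counters for the next window.
struct CubeStats {
  std::mutex mu;
  uint64_t time_us = 0;
  uint64_t req_num = 0;
  uint64_t key_num = 0;

  void report_seek(uint64_t usec, uint64_t keys) {
    std::lock_guard<std::mutex> lock(mu);
    time_us += usec;
    ++req_num;
    key_num += keys;
    if (req_num >= 1000) {
      std::cout << "Average " << time_us / req_num << "us/req";
      if (key_num > 0) {
        std::cout << ", " << time_us / key_num << "us/key";
      }
      std::cout << " over " << req_num << " requests" << std::endl;
      time_us = req_num = key_num = 0;  // start a new window
    }
  }
};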
@@ -51,8 +51,6 @@ using baidu::paddle_serving::predictor::FLAGS_port;
 using baidu::paddle_serving::configure::InferServiceConf;
 using baidu::paddle_serving::configure::read_proto_conf;
-DECLARE_bool(logtostderr);
-
 void print_revision(std::ostream& os, void*) {
 #if defined(PDSERVING_VERSION)
   os << PDSERVING_VERSION;
@@ -69,9 +67,6 @@ static bvar::PassiveStatus<std::string> s_predictor_revision(
 DEFINE_bool(V, false, "print version, bool");
 DEFINE_bool(g, false, "user defined gflag path");
-DEFINE_bool(enable_ctr_profiling,
-            false,
-            "Enable profiling in CTR prediction demo");
 DECLARE_string(flagfile);
 
 namespace bthread {
@@ -220,7 +215,8 @@ int main(int argc, char** argv) {
   }
   LOG(INFO) << "Succ initialize cube";
 
-  FLAGS_logtostderr = false;
+  // FATAL messages are output to stderr
+  FLAGS_stderrthreshold = 3;
 
   if (ServerManager::instance().start_and_wait() != 0) {
     LOG(ERROR) << "Failed start server and wait!";
......
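On the serving side, the now-unused enable_ctr_profiling flag and the FLAGS_logtostderr override are removed; instead FLAGS_stderrthreshold is raised to 3. glog numbers its severities INFO=0, WARNING=1, ERROR=2, FATAL=3, so with this setting only FATAL messages are echoed to stderr while everything is still written to the log files. A minimal sketch of the same configuration in isolation (assuming only that glog is linked, as it is in the serving binary):

#include <glog/logging.h>

int main(int argc, char** argv) {
  google::InitGoogleLogging(argv[0]);

  // glog severities: INFO=0, WARNING=1, ERROR=2, FATAL=3. Raising the
  // threshold to 3 keeps INFO/WARNING/ERROR out of stderr; they still go
  // to the per-severity log files.
  FLAGS_stderrthreshold = 3;

  LOG(INFO) << "written to the INFO log file only";
  LOG(ERROR) << "written to the ERROR log file only";
  return 0;
}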