提交 732f991e 编写于 作者: W wangguibao

CTR prediction profiling

上级 671f3257
......@@ -30,7 +30,6 @@ using baidu::paddle_serving::predictor::ctr_prediction::Response;
using baidu::paddle_serving::predictor::ctr_prediction::CTRReqInstance;
using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance;
int batch_size = 16;
int sparse_num = 26;
int dense_num = 13;
int hash_dim = 1000001;
......@@ -157,11 +156,17 @@ void thread_worker(PredictorApi* api,
Request req;
Response res;
std::string line;
int start_index = 0;
api->thrd_initialize();
for (int i = 0; i < FLAGS_repeat; ++i) {
int start_index = 0;
while (true) {
if (start_index >= data_list.size()) {
break;
}
api->thrd_clear();
Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
......@@ -183,12 +188,14 @@ void thread_worker(PredictorApi* api,
return;
}
start_index += FLAGS_batch_size;
LOG(INFO) << "start_index = " << start_index;
timeval start;
gettimeofday(&start, NULL);
if (predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
LOG(ERROR) << "failed call predictor with req:"
<< req.ShortDebugString();
return;
}
g_concurrency--;
......@@ -205,19 +212,24 @@ void thread_worker(PredictorApi* api,
}
LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
}
} // end while
} // end for
api->thrd_finalize();
}
void calc_time(int server_concurrency, int batch_size) {
void calc_time() {
std::vector<int> time_list;
for (auto a : response_time) {
time_list.insert(time_list.end(), a.begin(), a.end());
}
LOG(INFO) << "Total request : " << (time_list.size());
LOG(INFO) << "Batch size : " << batch_size;
LOG(INFO) << "Max concurrency : " << server_concurrency;
LOG(INFO) << "Batch size : " << FLAGS_batch_size;
LOG(INFO) << "Max concurrency : " << FLAGS_concurrency;
LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling;
LOG(INFO) << "repeat count: " << FLAGS_repeat;
float total_time = 0;
float max_time = 0;
float min_time = 1000000;
......@@ -226,21 +238,28 @@ void calc_time(int server_concurrency, int batch_size) {
if (time_list[i] > max_time) max_time = time_list[i];
if (time_list[i] < min_time) min_time = time_list[i];
}
float mean_time = total_time / (time_list.size());
float var_time;
for (int i = 0; i < time_list.size(); ++i) {
var_time += (time_list[i] - mean_time) * (time_list[i] - mean_time);
}
var_time = var_time / time_list.size();
LOG(INFO) << "Total time : " << total_time / server_concurrency
<< " Variance : " << var_time << " Max time : " << max_time
<< " Min time : " << min_time;
LOG(INFO) << "Total time : " << total_time / FLAGS_concurrency << "ms";
LOG(INFO) << "Variance : " << var_time << "ms";
LOG(INFO) << "Max time : " << max_time << "ms";
LOG(INFO) << "Min time : " << min_time << "ms";
float qps = 0.0;
if (total_time > 0)
qps = (time_list.size() * 1000) / (total_time / server_concurrency);
if (total_time > 0) {
qps = (time_list.size() * 1000) / (total_time / FLAGS_concurrency);
}
LOG(INFO) << "QPS: " << qps << "/s";
LOG(INFO) << "Latency statistics: ";
sort(time_list.begin(), time_list.end());
int percent_pos_50 = time_list.size() * 0.5;
int percent_pos_80 = time_list.size() * 0.8;
int percent_pos_90 = time_list.size() * 0.9;
......@@ -299,7 +318,6 @@ int main(int argc, char** argv) {
}
LOG(INFO) << "data sample file: " << data_filename;
LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling;
if (FLAGS_enable_profiling) {
LOG(INFO) << "In profiling mode, lot of normal output will be supressed. "
......@@ -330,7 +348,7 @@ int main(int argc, char** argv) {
delete thread_pool[i];
}
calc_time(FLAGS_concurrency, batch_size);
calc_time();
api.destroy();
return 0;
......
......@@ -148,7 +148,6 @@ int CTRPredictionOp::inference() {
int ret;
if (FLAGS_enable_ctr_profiling) {
gettimeofday(&start, NULL);
ret = cube->seek(table_name, keys, &values);
gettimeofday(&end, NULL);
......@@ -166,13 +165,14 @@ int CTRPredictionOp::inference() {
LOG(INFO) << "Cube request key count: " << cube_req_key_num_;
LOG(INFO) << "Cube request total time: " << cube_time_us_ << "us";
LOG(INFO) << "Average " << cube_time_us_ / cube_req_num_ << "us/req";
LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_
<< "us/key";
LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_ << "us/key";
cube_time_us_ = 0;
cube_req_num_ = 0;
cube_req_key_num_ = 0;
}
mutex_.unlock();
} else {
ret = cube->seek(table_name, keys, &values);
}
// Statistics end
if (ret != 0) {
fill_response_with_message(res, -1, "Query cube for embeddings error");
......
......@@ -51,8 +51,6 @@ using baidu::paddle_serving::predictor::FLAGS_port;
using baidu::paddle_serving::configure::InferServiceConf;
using baidu::paddle_serving::configure::read_proto_conf;
DECLARE_bool(logtostderr);
void print_revision(std::ostream& os, void*) {
#if defined(PDSERVING_VERSION)
os << PDSERVING_VERSION;
......@@ -69,9 +67,6 @@ static bvar::PassiveStatus<std::string> s_predictor_revision(
DEFINE_bool(V, false, "print version, bool");
DEFINE_bool(g, false, "user defined gflag path");
DEFINE_bool(enable_ctr_profiling,
false,
"Enable profiling in CTR prediction demo");
DECLARE_string(flagfile);
namespace bthread {
......@@ -220,7 +215,8 @@ int main(int argc, char** argv) {
}
LOG(INFO) << "Succ initialize cube";
FLAGS_logtostderr = false;
// FATAL messages are output to stderr
FLAGS_stderrthreshold = 3;
if (ServerManager::instance().start_and_wait() != 0) {
LOG(ERROR) << "Failed start server and wait!";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册