Commit 671f3257 authored by wangguibao

CTR prediction profiling

Parent 57858b93
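This change turns the CTR prediction demo client into a small profiling/load-test tool: batch size, client concurrency, and repeat count become command-line gflags, per-request latency is collected per worker thread, and the server-side op logs averaged cube lookup times after a threshold of cube requests. Below is a minimal, standalone C++ sketch of the client-side timing pattern this commit uses (gettimeofday around each call, elapsed milliseconds stored per request); run_one_request() is a hypothetical stand-in for the real predictor->inference() call and is not part of the Paddle Serving API.

    #include <sys/time.h>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for the RPC call issued by the real client.
    static int run_one_request() { return 0; }

    int main() {
      std::vector<uint64_t> response_time;  // per-request latency in ms, as in the client
      for (int i = 0; i < 10; ++i) {
        timeval start, end;
        gettimeofday(&start, NULL);
        if (run_one_request() != 0) return -1;
        gettimeofday(&end, NULL);
        uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
                             (start.tv_sec * 1000 + start.tv_usec / 1000);
        response_time.push_back(elapse_ms);
      }
      for (uint64_t ms : response_time) std::cout << ms << "ms\n";
      return 0;
    }

The real client is then driven with the new flags, e.g. --batch_size=50 --concurrency=4 --repeat=1 --enable_profiling=true.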
@@ -33,8 +33,15 @@ using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance;
int batch_size = 16;
int sparse_num = 26;
int dense_num = 13;
int thread_num = 1;
int hash_dim = 1000001;
DEFINE_int32(batch_size, 50, "Set the batch size of test file.");
DEFINE_int32(concurrency, 1, "Set the max concurrency of requests");
DEFINE_int32(repeat, 1, "Number of iterations over the data samples. Default 1");
DEFINE_bool(enable_profiling,
true,
"Enable profiling. Will supress a lot normal output");
std::vector<float> cont_min = {0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<float> cont_diff = {
20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50};
@@ -86,7 +93,7 @@ int64_t hash(std::string str) {
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
int start_index,
int batch_size) {
for (int i = 0; i < batch_size; ++i) {
CTRReqInstance* ins = req->add_instances();
@@ -94,12 +101,14 @@ int create_req(Request* req,
LOG(ERROR) << "Failed create req instance";
return -1;
}
// add data
// avoid out of boundary
int cur_index = data_index + i;
int cur_index = start_index + i;
if (cur_index >= data_list.size()) {
cur_index = cur_index % data_list.size();
}
std::vector<std::string> feature_list = split(data_list[cur_index], "\t");
for (int fi = 0; fi < dense_num; fi++) {
if (feature_list[fi] == "") {
@@ -122,10 +131,10 @@ int create_req(Request* req,
}
return 0;
}
void print_res(const Request& req,
const Response& res,
std::string route_tag,
uint64_t mid_ms,
uint64_t elapse_ms) {
if (res.err_code() != 0) {
LOG(ERROR) << "Get result fail :" << res.err_msg();
@@ -138,64 +147,69 @@ void print_res(const Request& req,
LOG(INFO) << "Receive result " << oss.str();
}
LOG(INFO) << "Succ call predictor[ctr_prediction_service], the tag is: "
<< route_tag << ", mid_ms: " << mid_ms
<< ", elapse_ms: " << elapse_ms;
<< route_tag << ", elapse_ms: " << elapse_ms;
}
void thread_worker(PredictorApi* api,
int thread_id,
int batch_size,
int server_concurrency,
const std::vector<std::string>& data_list) {
// init
Request req;
Response res;
api->thrd_initialize();
std::string line;
int turns = 0;
while (turns < 1000) {
timeval start;
gettimeofday(&start, NULL);
int start_index = 0;
api->thrd_initialize();
while (true) {
api->thrd_clear();
Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
if (!predictor) {
LOG(ERROR) << "Failed fetch predictor: ctr_prediction_service";
return;
}
req.Clear();
res.Clear();
timeval mid;
gettimeofday(&mid, NULL);
uint64_t mid_ms = (mid.tv_sec * 1000 + mid.tv_usec / 1000) -
(start.tv_sec * 1000 + start.tv_usec / 1000);
// wait for other thread
while (g_concurrency.load() >= server_concurrency) {
while (g_concurrency.load() >= FLAGS_concurrency) {
}
g_concurrency++;
LOG(INFO) << "Current concurrency " << g_concurrency.load();
int data_index = turns * batch_size;
if (create_req(&req, data_list, data_index, batch_size) != 0) {
if (create_req(&req, data_list, start_index, FLAGS_batch_size) != 0) {
return;
}
timeval start_run;
gettimeofday(&start_run, NULL);
start_index += FLAGS_batch_size;
timeval start;
gettimeofday(&start, NULL);
if (predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
return;
}
g_concurrency--;
timeval end;
gettimeofday(&end, NULL);
uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
(start_run.tv_sec * 1000 + start_run.tv_usec / 1000);
(start.tv_sec * 1000 + start.tv_usec / 1000);
response_time[thread_id].push_back(elapse_ms);
print_res(req, res, predictor->tag(), mid_ms, elapse_ms);
g_concurrency--;
if (!FLAGS_enable_profiling) {
print_res(req, res, predictor->tag(), elapse_ms);
}
LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
turns++;
}
//
api->thrd_finalize();
}
void calc_time(int server_concurrency, int batch_size) {
std::vector<int> time_list;
for (auto a : response_time) {
@@ -244,11 +258,12 @@ void calc_time(int server_concurrency, int batch_size) {
}
}
int main(int argc, char** argv) {
google::ParseCommandLineFlags(&argc, &argv, true);
// initialize
PredictorApi api;
response_time.resize(thread_num);
int server_concurrency = thread_num;
// log set
response_time.resize(FLAGS_concurrency);
#ifdef BCLOUD
logging::LoggingSettings settings;
settings.logging_dest = logging::LOG_TO_FILE;
@@ -282,32 +297,41 @@ int main(int argc, char** argv) {
LOG(ERROR) << "Failed create predictors api!";
return -1;
}
LOG(INFO) << "data sample file: " << data_filename;
LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling;
if (FLAGS_enable_profiling) {
LOG(INFO) << "In profiling mode, lot of normal output will be supressed. "
<< "Use --enable_profiling=false to turn off this mode";
}
// read data
std::ifstream data_file(data_filename);
if (!data_file) {
std::cout << "read file error \n" << std::endl;
return -1;
}
std::vector<std::string> data_list;
std::string line;
while (getline(data_file, line)) {
data_list.push_back(line);
}
// create threads
std::vector<std::thread*> thread_pool;
for (int i = 0; i < server_concurrency; ++i) {
thread_pool.push_back(new std::thread(thread_worker,
&api,
i,
batch_size,
server_concurrency,
std::ref(data_list)));
for (int i = 0; i < FLAGS_concurrency; ++i) {
thread_pool.push_back(new std::thread(thread_worker, &api, i, data_list));
}
for (int i = 0; i < server_concurrency; ++i) {
for (int i = 0; i < FLAGS_concurrency; ++i) {
thread_pool[i]->join();
delete thread_pool[i];
}
calc_time(server_concurrency, batch_size);
calc_time(FLAGS_concurrency, batch_size);
api.destroy();
return 0;
}
@@ -160,6 +160,15 @@ int CTRPredictionOp::inference() {
cube_time_us_ += usec;
++cube_req_num_;
cube_req_key_num_ += keys.size();
if (cube_req_num_ >= 1000) {
LOG(INFO) << "Cube request count: " << cube_req_num_;
LOG(INFO) << "Cube request key count: " << cube_req_key_num_;
LOG(INFO) << "Cube request total time: " << cube_time_us_ << "us";
LOG(INFO) << "Average " << cube_time_us_ / cube_req_num_ << "us/req";
LOG(INFO) << "Average " << cube_time_us_ / cube_req_key_num_
<< "us/key";
}
mutex_.unlock();
} else {
ret = cube->seek(table_name, keys, &values);
@@ -64,7 +64,6 @@ int Endpoint::thrd_clear() {
return -1;
}
}
LOG(INFO) << "Succ thrd clear all vars: " << var_size;
return 0;
}
@@ -94,8 +94,6 @@ int PredictorApi::thrd_clear() {
LOG(ERROR) << "Failed thrd clear endpoint:" << it->first;
return -1;
}
LOG(INFO) << "Succ thrd clear endpoint:" << it->first;
}
return 0;
}