// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <sys/stat.h>
#include <sys/time.h>
#include <sys/types.h>
#include <unistd.h>
#include <algorithm>
#include <atomic>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <sstream>
#include <string>
#include <thread>  // NOLINT
#include <vector>
#include "sdk-cpp/ctr_prediction.pb.h"
#include "sdk-cpp/include/common.h"
#include "sdk-cpp/include/predictor_sdk.h"
using baidu::paddle_serving::sdk_cpp::Predictor;
using baidu::paddle_serving::sdk_cpp::PredictorApi;
using baidu::paddle_serving::predictor::ctr_prediction::Request;
using baidu::paddle_serving::predictor::ctr_prediction::Response;
using baidu::paddle_serving::predictor::ctr_prediction::CTRReqInstance;
using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance;

// Criteo-style CTR input: 13 dense (continuous) and 26 sparse (categorical)
// features per sample; sparse values are hashed into [0, hash_dim).
const int sparse_num = 26;
const int dense_num = 13;
const int hash_dim = 1000001;

DEFINE_int32(batch_size, 50, "Set the batch size of test file.");
DEFINE_int32(concurrency, 1, "Set the max concurrency of requests");
DEFINE_int32(repeat,
             1,
             "Number of times to iterate over the data samples. Default 1");
DEFINE_bool(enable_profiling,
            true,
            "Enable profiling. Will suppress a lot of normal output");

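// cont_min/cont_diff hold the per-feature minimum and value range used to
// min-max normalize the 13 dense features:
//   normalized = (raw - cont_min[i]) / cont_diff[i]
// e.g. a raw value of 7 for dense feature 0 maps to (7 - 0) / 20 = 0.35.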
std::vector<float> cont_min = {0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<float> cont_diff = {
    20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50};
const char* data_filename = "./data/ctr_prediction/data.txt";
std::atomic<int> g_concurrency(0);  // requests currently in flight
std::vector<std::vector<int>> response_time;  // per-thread latencies in ms

// Split |str| on every occurrence of |pattern|; empty fields are preserved.
std::vector<std::string> split(const std::string& str,
                               const std::string& pattern) {
  std::vector<std::string> res;
  if (str.empty()) return res;
  std::string strs = str + pattern;
  size_t pos = strs.find(pattern);
  while (pos != std::string::npos) {
    res.push_back(strs.substr(0, pos));
    // Skip the whole pattern, not just one character, so multi-character
    // delimiters split correctly.
    strs = strs.substr(pos + pattern.size());
    pos = strs.find(pattern);
  }
  return res;
}
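
// Example: split("1\t2\t3", "\t") yields {"1", "2", "3"}; empty fields are
// kept, e.g. split("a\t\tb", "\t") yields {"a", "", "b"}.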

/**
 * Simulate the CPython hash function on string objects.
 *
 * Our model training process uses this function to convert string objects to
 * unique ids; serving reproduces it here so the ids match.
 *
 * See string_hash() in
 * https://svn.python.org/projects/python/trunk/Objects/stringobject.c
 */
int64_t hash(const std::string& str) {
  int64_t len = str.size();
  // c_str() is NUL-terminated, so reading *p is safe even for an empty string.
  const unsigned char* p =
      reinterpret_cast<const unsigned char*>(str.c_str());
  int64_t x = *p << 7;
  while (--len >= 0) {
    x = (1000003 * x) ^ *p++;
  }
  x ^= static_cast<int64_t>(str.size());
  if (x == -1) {
    // CPython reserves -1 as an error marker, so it is never returned.
    x = -2;
  }
  return x;
}
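
// Sparse IDs are formed below as hash(std::to_string(field_index) +
// raw_value) % hash_dim, then shifted into [0, hash_dim) when negative so
// they match the non-negative IDs produced by Python's modulo in training.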

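// Build a batch of |batch_size| request instances starting at |start_index|,
// wrapping around the sample list if needed. Each input line is expected to
// hold 13 dense fields followed by 26 sparse fields, tab-separated.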
int create_req(Request* req,
               const std::vector<std::string>& data_list,
               int start_index,
               int batch_size) {
  for (int i = 0; i < batch_size; ++i) {
    CTRReqInstance* ins = req->add_instances();
    if (!ins) {
      LOG(ERROR) << "Failed to create req instance";
      return -1;
    }

    // Add data; wrap cur_index so it never runs past the sample list.
    int cur_index = start_index + i;
    if (cur_index >= static_cast<int>(data_list.size())) {
      cur_index = cur_index % data_list.size();
    }

    std::vector<std::string> feature_list = split(data_list[cur_index], "\t");
    // Guard against malformed lines with fewer than 39 (13 + 26) fields.
    if (feature_list.size() < static_cast<size_t>(dense_num + sparse_num)) {
      LOG(ERROR) << "Malformed sample line: " << data_list[cur_index];
      return -1;
    }
    // Dense features: min-max normalize; empty fields default to 0.0.
    for (int fi = 0; fi < dense_num; fi++) {
      if (feature_list[fi].empty()) {
        ins->add_dense_ids(0.0);
      } else {
        float dense_id = std::stof(feature_list[fi]);
        dense_id = (dense_id - cont_min[fi]) / cont_diff[fi];
        ins->add_dense_ids(dense_id);
      }
    }
    // Sparse features: hash "<field index><raw value>" into [0, hash_dim).
    for (int fi = dense_num; fi < (dense_num + sparse_num); fi++) {
      int64_t sparse_id =
          hash(std::to_string(fi) + feature_list[fi]) % hash_dim;
      if (sparse_id < 0) {
        // C++ '%' can be negative; Python's never is, so shift to match.
        sparse_id += hash_dim;
      }
      ins->add_sparse_ids(sparse_id);
    }
  }
  return 0;
}

void print_res(const Request& req,
               const Response& res,
               const std::string& route_tag,
               uint64_t elapse_ms) {
  if (res.err_code() != 0) {
    LOG(ERROR) << "Failed to get result: " << res.err_msg();
    return;
  }
  for (int i = 0; i < res.predictions_size(); ++i) {
    const CTRResInstance& res_ins = res.predictions(i);
    std::ostringstream oss;
    oss << "[" << res_ins.prob0() << " " << res_ins.prob1() << "]";
    LOG(INFO) << "Receive result " << oss.str();
  }
  LOG(INFO) << "Succ call predictor[ctr_prediction_service], the tag is: "
            << route_tag << ", elapse_ms: " << elapse_ms;
}

void thread_worker(PredictorApi* api,
                   int thread_id,
                   const std::vector<std::string>& data_list) {
  Request req;
  Response res;

  api->thrd_initialize();

  for (int i = 0; i < FLAGS_repeat; ++i) {
    int start_index = 0;

    while (start_index < static_cast<int>(data_list.size())) {

      api->thrd_clear();

      Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
      if (!predictor) {
        LOG(ERROR) << "Failed to fetch predictor: ctr_prediction_service";
        return;
      }

      req.Clear();
      res.Clear();

      // Busy-wait until a concurrency slot frees up; this burns CPU but
      // caps in-flight requests at FLAGS_concurrency.
      while (g_concurrency.load() >= FLAGS_concurrency) {
      }
      g_concurrency++;
      LOG(INFO) << "Current concurrency " << g_concurrency.load();

      if (create_req(&req, data_list, start_index, FLAGS_batch_size) != 0) {
        g_concurrency--;  // release the concurrency slot before bailing out
        return;
      }
      start_index += FLAGS_batch_size;
      LOG(INFO) << "start_index = " << start_index;

      timeval start;
      gettimeofday(&start, NULL);

      if (predictor->inference(&req, &res) != 0) {
        LOG(ERROR) << "Failed to call predictor with req: "
                   << req.ShortDebugString();
        g_concurrency--;  // release the concurrency slot before bailing out
        return;
      }
      g_concurrency--;

      timeval end;
      gettimeofday(&end, NULL);
      uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
                           (start.tv_sec * 1000 + start.tv_usec / 1000);

      response_time[thread_id].push_back(elapse_ms);

      if (!FLAGS_enable_profiling) {
        print_res(req, res, predictor->tag(), elapse_ms);
      }

      LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
    }  // end while
  }    // end for

  api->thrd_finalize();
}
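
// Per-thread SDK lifecycle used above: thrd_initialize() once at thread
// start, thrd_clear() before every request to reset per-call state, and
// thrd_finalize() once before the thread exits.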

void calc_time() {
  // Flatten per-thread latency lists into one vector.
  std::vector<int> time_list;
  for (const auto& a : response_time) {
    time_list.insert(time_list.end(), a.begin(), a.end());
  }

  LOG(INFO) << "Total request : " << (time_list.size());
  LOG(INFO) << "Batch size : " << FLAGS_batch_size;
  LOG(INFO) << "Max concurrency : " << FLAGS_concurrency;
  LOG(INFO) << "enable_profiling: " << FLAGS_enable_profiling;
  LOG(INFO) << "repeat count: " << FLAGS_repeat;

  float total_time = 0;
  float max_time = 0;
  float min_time = 1000000;
  for (size_t i = 0; i < time_list.size(); ++i) {
    total_time += time_list[i];
    if (time_list[i] > max_time) max_time = time_list[i];
    if (time_list[i] < min_time) min_time = time_list[i];
  }

  float mean_time = total_time / time_list.size();
  float var_time = 0;
  for (size_t i = 0; i < time_list.size(); ++i) {
    var_time += (time_list[i] - mean_time) * (time_list[i] - mean_time);
  }
  var_time = var_time / time_list.size();

  LOG(INFO) << "Total time : " << total_time / FLAGS_concurrency << "ms";
  LOG(INFO) << "Variance : " << var_time << "ms^2";
  LOG(INFO) << "Max time : " << max_time << "ms";
  LOG(INFO) << "Min time : " << min_time << "ms";

  float qps = 0.0;
  if (total_time > 0) {
    qps = (time_list.size() * 1000) / (total_time / FLAGS_concurrency);
  }
  LOG(INFO) << "QPS: " << qps << "/s";

  LOG(INFO) << "Latency statistics: ";
  std::sort(time_list.begin(), time_list.end());

  int percent_pos_50 = time_list.size() * 0.5;
  int percent_pos_80 = time_list.size() * 0.8;
  int percent_pos_90 = time_list.size() * 0.9;
  int percent_pos_99 = time_list.size() * 0.99;
  int percent_pos_999 = time_list.size() * 0.999;
  if (time_list.size() != 0) {
    LOG(INFO) << "Mean time : " << mean_time;
    LOG(INFO) << "50 percent ms: " << time_list[percent_pos_50];
    LOG(INFO) << "80 percent ms: " << time_list[percent_pos_80];
    LOG(INFO) << "90 percent ms: " << time_list[percent_pos_90];
    LOG(INFO) << "99 percent ms: " << time_list[percent_pos_99];
    LOG(INFO) << "99.9 percent ms: " << time_list[percent_pos_999];
  } else {
    LOG(INFO) << "N/A";
  }
}
int main(int argc, char** argv) {
  google::ParseCommandLineFlags(&argc, &argv, true);

  // initialize
  PredictorApi api;
  response_time.resize(FLAGS_concurrency);

#ifdef BCLOUD
  logging::LoggingSettings settings;
  settings.logging_dest = logging::LOG_TO_FILE;
  std::string log_filename(argv[0]);
  log_filename = log_filename.substr(log_filename.find_last_of('/') + 1);
  // Keep the path string alive; settings.log_file only stores the pointer.
  std::string log_file_path = std::string("./log/") + log_filename + ".log";
  settings.log_file = log_file_path.c_str();
  settings.delete_old = logging::DELETE_OLD_LOG_FILE;
  logging::InitLogging(settings);
  logging::ComlogSinkOptions cso;
  cso.process_name = log_filename;
  cso.enable_wf_device = true;
  logging::ComlogSink::GetInstance()->Setup(&cso);
#else
  struct stat st_buf;
  int ret = 0;
  if ((ret = stat("./log", &st_buf)) != 0) {
    mkdir("./log", 0777);
    ret = stat("./log", &st_buf);
    if (ret != 0) {
      LOG(WARNING) << "Log path ./log does not exist and could not be created";
      return -1;
    }
  }
  FLAGS_log_dir = "./log";
  google::InitGoogleLogging(strdup(argv[0]));
  FLAGS_logbufsecs = 0;
  FLAGS_logbuflevel = -1;
#endif
  // predictor conf
  if (api.create("./conf", "predictors.prototxt") != 0) {
    LOG(ERROR) << "Failed create predictors api!";
    return -1;
  }
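
  // api.create() loads the SDK client config from ./conf/predictors.prototxt,
  // which must define the "ctr_prediction_service" endpoint that
  // thread_worker() fetches.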

  LOG(INFO) << "data sample file: " << data_filename;

  if (FLAGS_enable_profiling) {
    LOG(INFO) << "In profiling mode, lot of normal output will be supressed. "
              << "Use --enable_profiling=false to turn off this mode";
  }

  // read data
  std::ifstream data_file(data_filename);
  if (!data_file) {
    LOG(ERROR) << "Failed to open data file: " << data_filename;
    return -1;
  }

  std::vector<std::string> data_list;
  std::string line;
  while (getline(data_file, line)) {
    data_list.push_back(line);
  }

  // create threads
  std::vector<std::thread> thread_pool;
  for (int i = 0; i < FLAGS_concurrency; ++i) {
    thread_pool.emplace_back(thread_worker, &api, i, data_list);
  }

  for (int i = 0; i < FLAGS_concurrency; ++i) {
    thread_pool[i].join();
  }

  calc_time();

  api.destroy();
  return 0;
}