未验证 提交 4e9a2948 编写于 作者: W Wang Guibao 提交者: GitHub

Merge pull request #35 from wangguibao/ctr_model_serving

CTR PREDICTION model serving
......@@ -26,6 +26,20 @@ message EngineDesc {
required int32 enable_batch_align = 8;
optional string version_file = 9;
optional string version_type = 10;
/*
* Sparse Parameter Service type. Valid types are:
* "None": not use sparse parameter service
* "Local": Use local kv service (rocksdb library & API)
* "Remote": Use remote kv service (cube)
*/
enum SparseParamServiceType {
NONE = 0;
LOCAL = 1;
REMOTE = 2;
}
optional SparseParamServiceType sparse_param_service_type = 11;
optional string sparse_param_service_table_name = 12;
};
// model_toolkit conf
......@@ -35,6 +49,8 @@ message ModelToolkitConf { repeated EngineDesc engines = 1; };
message ResourceConf {
required string model_toolkit_path = 1;
required string model_toolkit_file = 2;
optional string cube_config_path = 3;
optional string cube_config_file = 4;
};
// DAG node depency info
......
......@@ -66,6 +66,8 @@ int test_write_conf() {
engine->set_runtime_thread_num(0);
engine->set_batch_infer_size(0);
engine->set_enable_batch_align(0);
engine->set_sparse_param_service_type(EngineDesc::LOCAL);
engine->set_sparse_param_service_table_name("local_kv");
int ret = baidu::paddle_serving::configure::write_proto_conf(
&model_toolkit_conf, output_dir, model_toolkit_conf_file);
......
......@@ -57,6 +57,10 @@ add_executable(text_classification_press
target_link_libraries(text_classification_press -Wl,--whole-archive sdk-cpp -Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl
-lz)
add_executable(ctr_prediction
${CMAKE_CURRENT_LIST_DIR}/src/ctr_prediction.cpp)
target_link_libraries(ctr_prediction -Wl,--whole-archive sdk-cpp
-Wl,--no-whole-archive -lpthread -lcrypto -lm -lrt -lssl -ldl -lz)
# install
install(TARGETS ximage
RUNTIME DESTINATION
......@@ -104,3 +108,11 @@ install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/conf DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/text_classification/)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/data/text_classification DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/text_classification/data)
install(TARGETS ctr_prediction
RUNTIME DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/ctr_prediction/bin)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/conf DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/ctr_prediction/)
install(DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/data/ctr_prediction DESTINATION
${PADDLE_SERVING_INSTALL_DIR}/demo/client/ctr_prediction/data)
......@@ -124,3 +124,18 @@ predictors {
}
}
}
predictors {
name: "ctr_prediction_service"
service_name: "baidu.paddle_serving.predictor.ctr_prediction.CTRPredictionService"
endpoint_router: "WeightedRandomRender"
weighted_random_render_conf {
variant_weight_list: "50"
}
variants {
tag: "var1"
naming_conf {
cluster: "list://127.0.0.1:8010"
}
}
}
因为 它太大了无法显示 source diff 。你可以改为 查看blob
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>
#include <cstdlib>
#include <fstream>
#include <sstream>
#include <string>
#include <thread> // NOLINT
#include "sdk-cpp/ctr_prediction.pb.h"
#include "sdk-cpp/include/common.h"
#include "sdk-cpp/include/predictor_sdk.h"
using baidu::paddle_serving::sdk_cpp::Predictor;
using baidu::paddle_serving::sdk_cpp::PredictorApi;
using baidu::paddle_serving::predictor::ctr_prediction::Request;
using baidu::paddle_serving::predictor::ctr_prediction::Response;
using baidu::paddle_serving::predictor::ctr_prediction::CTRReqInstance;
using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance;
int batch_size = 1;
int sparse_num = 26;
int dense_num = 13;
int thread_num = 1;
int hash_dim = 1000001;
std::vector<float> cont_min = {0, -3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
std::vector<float> cont_diff = {
20, 603, 100, 50, 64000, 500, 100, 50, 500, 10, 10, 10, 50};
char* data_filename = "./data/ctr_prediction/data.txt";
std::atomic<int> g_concurrency(0);
std::vector<std::vector<int>> response_time;
std::vector<std::string> split(const std::string& str,
const std::string& pattern) {
std::vector<std::string> res;
if (str == "") return res;
std::string strs = str + pattern;
size_t pos = strs.find(pattern);
while (pos != strs.npos) {
std::string temp = strs.substr(0, pos);
res.push_back(temp);
strs = strs.substr(pos + 1, strs.size());
pos = strs.find(pattern);
}
return res;
}
/**
* Simulate CPython hash function on string objects
*
* Our model training process use this function to convert string objects to
* unique ids.
*
* See string_hash() in
* https://svn.python.org/projects/python/trunk/Objects/stringobject.c
*/
int64_t hash(std::string str) {
int64_t len;
unsigned char* p;
int64_t x;
len = str.size();
p = (unsigned char*)str.c_str();
x = *p << 7;
while (--len >= 0) {
x = (1000003 * x) ^ *p++;
}
x ^= str.size();
if (x == -1) {
x = -2;
}
return x;
}
int create_req(Request* req,
const std::vector<std::string>& data_list,
int data_index,
int batch_size) {
for (int i = 0; i < batch_size; ++i) {
CTRReqInstance* ins = req->add_instances();
if (!ins) {
LOG(ERROR) << "Failed create req instance";
return -1;
}
// add data
std::vector<std::string> feature_list =
split(data_list[data_index + i], "\t");
for (int fi = 0; fi < dense_num; fi++) {
if (feature_list[fi] == "") {
ins->add_dense_ids(0.0);
} else {
float dense_id = std::stof(feature_list[fi]);
dense_id = (dense_id - cont_min[fi]) / cont_diff[fi];
ins->add_dense_ids(dense_id);
}
}
for (int fi = dense_num; fi < (dense_num + sparse_num); fi++) {
int64_t sparse_id =
hash(std::to_string(fi) + feature_list[fi]) % hash_dim;
if (sparse_id < 0) {
// diff between c++ and python
sparse_id += hash_dim;
}
ins->add_sparse_ids(sparse_id);
}
}
return 0;
}
void print_res(const Request& req,
const Response& res,
std::string route_tag,
uint64_t mid_ms,
uint64_t elapse_ms) {
if (res.err_code() != 0) {
LOG(ERROR) << "Get result fail :" << res.err_msg();
return;
}
for (uint32_t i = 0; i < res.predictions_size(); ++i) {
const CTRResInstance& res_ins = res.predictions(i);
std::ostringstream oss;
oss << "[" << res_ins.prob0() << " " << res_ins.prob1() << "]";
LOG(INFO) << "Receive result " << oss.str();
}
LOG(INFO) << "Succ call predictor[ctr_prediction_service], the tag is: "
<< route_tag << ", mid_ms: " << mid_ms
<< ", elapse_ms: " << elapse_ms;
}
void thread_worker(PredictorApi* api,
int thread_id,
int batch_size,
int server_concurrency,
const std::vector<std::string>& data_list) {
// init
Request req;
Response res;
api->thrd_initialize();
std::string line;
int turns = 0;
while (turns < 1000) {
timeval start;
gettimeofday(&start, NULL);
api->thrd_clear();
Predictor* predictor = api->fetch_predictor("ctr_prediction_service");
if (!predictor) {
LOG(ERROR) << "Failed fetch predictor: ctr_prediction_service";
return;
}
req.Clear();
res.Clear();
timeval mid;
gettimeofday(&mid, NULL);
uint64_t mid_ms = (mid.tv_sec * 1000 + mid.tv_usec / 1000) -
(start.tv_sec * 1000 + start.tv_usec / 1000);
// wait for other thread
while (g_concurrency.load() >= server_concurrency) {
}
g_concurrency++;
LOG(INFO) << "Current concurrency " << g_concurrency.load();
int data_index = turns * batch_size;
if (create_req(&req, data_list, data_index, batch_size) != 0) {
return;
}
timeval start_run;
gettimeofday(&start_run, NULL);
if (predictor->inference(&req, &res) != 0) {
LOG(ERROR) << "failed call predictor with req:" << req.ShortDebugString();
return;
}
timeval end;
gettimeofday(&end, NULL);
uint64_t elapse_ms = (end.tv_sec * 1000 + end.tv_usec / 1000) -
(start_run.tv_sec * 1000 + start_run.tv_usec / 1000);
response_time[thread_id].push_back(elapse_ms);
print_res(req, res, predictor->tag(), mid_ms, elapse_ms);
g_concurrency--;
LOG(INFO) << "Done. Current concurrency " << g_concurrency.load();
turns++;
}
//
api->thrd_finalize();
}
void calc_time(int server_concurrency, int batch_size) {
std::vector<int> time_list;
for (auto a : response_time) {
time_list.insert(time_list.end(), a.begin(), a.end());
}
LOG(INFO) << "Total request : " << (time_list.size());
LOG(INFO) << "Batch size : " << batch_size;
LOG(INFO) << "Max concurrency : " << server_concurrency;
float total_time = 0;
float max_time = 0;
float min_time = 1000000;
for (int i = 0; i < time_list.size(); ++i) {
total_time += time_list[i];
if (time_list[i] > max_time) max_time = time_list[i];
if (time_list[i] < min_time) min_time = time_list[i];
}
float mean_time = total_time / (time_list.size());
float var_time;
for (int i = 0; i < time_list.size(); ++i) {
var_time += (time_list[i] - mean_time) * (time_list[i] - mean_time);
}
var_time = var_time / time_list.size();
LOG(INFO) << "Total time : " << total_time / server_concurrency
<< " Variance : " << var_time << " Max time : " << max_time
<< " Min time : " << min_time;
float qps = 0.0;
if (total_time > 0)
qps = (time_list.size() * 1000) / (total_time / server_concurrency);
LOG(INFO) << "QPS: " << qps << "/s";
LOG(INFO) << "Latency statistics: ";
sort(time_list.begin(), time_list.end());
int percent_pos_50 = time_list.size() * 0.5;
int percent_pos_80 = time_list.size() * 0.8;
int percent_pos_90 = time_list.size() * 0.9;
int percent_pos_99 = time_list.size() * 0.99;
int percent_pos_999 = time_list.size() * 0.999;
if (time_list.size() != 0) {
LOG(INFO) << "Mean time : " << mean_time;
LOG(INFO) << "50 percent ms: " << time_list[percent_pos_50];
LOG(INFO) << "80 percent ms: " << time_list[percent_pos_80];
LOG(INFO) << "90 percent ms: " << time_list[percent_pos_90];
LOG(INFO) << "99 percent ms: " << time_list[percent_pos_99];
LOG(INFO) << "99.9 percent ms: " << time_list[percent_pos_999];
} else {
LOG(INFO) << "N/A";
}
}
int main(int argc, char** argv) {
// initialize
PredictorApi api;
response_time.resize(thread_num);
int server_concurrency = thread_num;
// log set
#ifdef BCLOUD
logging::LoggingSettings settings;
settings.logging_dest = logging::LOG_TO_FILE;
std::string log_filename(argv[0]);
log_filename = log_filename.substr(log_filename.find_last_of('/') + 1);
settings.log_file = (std::string("./log/") + log_filename + ".log").c_str();
settings.delete_old = logging::DELETE_OLD_LOG_FILE;
logging::InitLogging(settings);
logging::ComlogSinkOptions cso;
cso.process_name = log_filename;
cso.enable_wf_device = true;
logging::ComlogSink::GetInstance()->Setup(&cso);
#else
struct stat st_buf;
int ret = 0;
if ((ret = stat("./log", &st_buf)) != 0) {
mkdir("./log", 0777);
ret = stat("./log", &st_buf);
if (ret != 0) {
LOG(WARNING) << "Log path ./log not exist, and create fail";
return -1;
}
}
FLAGS_log_dir = "./log";
google::InitGoogleLogging(strdup(argv[0]));
FLAGS_logbufsecs = 0;
FLAGS_logbuflevel = -1;
#endif
// predictor conf
if (api.create("./conf", "predictors.prototxt") != 0) {
LOG(ERROR) << "Failed create predictors api!";
return -1;
}
// read data
std::ifstream data_file(data_filename);
if (!data_file) {
std::cout << "read file error \n" << std::endl;
return -1;
}
std::vector<std::string> data_list;
std::string line;
while (getline(data_file, line)) {
data_list.push_back(line);
}
// create threads
std::vector<std::thread*> thread_pool;
for (int i = 0; i < server_concurrency; ++i) {
thread_pool.push_back(new std::thread(thread_worker,
&api,
i,
batch_size,
server_concurrency,
std::ref(data_list)));
}
for (int i = 0; i < server_concurrency; ++i) {
thread_pool[i]->join();
delete thread_pool[i];
}
calc_time(server_concurrency, batch_size);
api.destroy();
return 0;
}
......@@ -18,7 +18,7 @@ include(op/CMakeLists.txt)
include(proto/CMakeLists.txt)
add_executable(serving ${serving_srcs})
add_dependencies(serving pdcodegen fluid_cpu_engine pdserving paddle_fluid
opencv_imgcodecs)
opencv_imgcodecs cube-api)
if (WITH_GPU)
add_dependencies(serving fluid_gpu_engine)
endif()
......@@ -40,6 +40,7 @@ target_link_libraries(serving opencv_imgcodecs
${opencv_depend_libs})
target_link_libraries(serving pdserving)
target_link_libraries(serving cube-api)
target_link_libraries(serving kvdb rocksdb)
......
[{
"dict_name": "dict",
"shard": 2,
"dup": 1,
"timeout": 200,
"retry": 3,
"backup_request": 100,
"type": "ipport_list",
"load_balancer": "rr",
"nodes": [{
"ipport_list": "list://xxx.xxx.xxx.xxx:8000"
},{
"ipport_list": "list://xxx.xxx.xxx.xxx:8000"
}]
}]
--enable_model_toolkit
--enable_cube=false
......@@ -18,3 +18,14 @@ engines {
batch_infer_size: 0
enable_batch_align: 0
}
engines {
name: "ctr_prediction"
type: "FLUID_CPU_ANALYSIS_DIR"
reloadable_meta: "./data/model/paddle/fluid_time_file"
reloadable_type: "timestamp_ne"
model_data_path: "./data/model/paddle/fluid/ctr_prediction"
runtime_thread_num: 0
batch_infer_size: 0
enable_batch_align: 0
}
model_toolkit_path: "./conf/"
model_toolkit_file: "model_toolkit.prototxt"
cube_config_file: "./conf/cube.conf"
......@@ -31,3 +31,8 @@ services {
name: "EchoKVDBService"
workflows: "workflow7"
}
services {
name: "CTRPredictionService"
workflows: "workflow8"
}
......@@ -75,3 +75,11 @@ workflows {
type: "KVDBEchoOp"
}
}
workflows {
name: "workflow8"
workflow_type: "Sequence"
nodes {
name: "ctr_prediction_service_op"
type: "CTRPredictionOp"
}
}
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "demo-serving/op/ctr_prediction_op.h"
#include <algorithm>
#include <string>
#if 0
#include <iomanip>
#endif
#include "cube/cube-api/include/cube_api.h"
#include "predictor/framework/infer.h"
#include "predictor/framework/kv_manager.h"
#include "predictor/framework/memory.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
using baidu::paddle_serving::predictor::MempoolWrapper;
using baidu::paddle_serving::predictor::ctr_prediction::CTRResInstance;
using baidu::paddle_serving::predictor::ctr_prediction::Response;
using baidu::paddle_serving::predictor::ctr_prediction::CTRReqInstance;
using baidu::paddle_serving::predictor::ctr_prediction::Request;
const int VARIABLE_NAME_LEN = 256;
// Total 26 sparse input + 1 dense input
const int CTR_PREDICTION_INPUT_SLOTS = 27;
// First 26: sparse input
const int CTR_PREDICTION_SPARSE_SLOTS = 26;
// Last 1: dense input
const int CTR_PREDICTION_DENSE_SLOT_ID = 26;
const int CTR_PREDICTION_DENSE_DIM = 13;
const int CTR_PREDICTION_EMBEDDING_SIZE = 10;
void fill_response_with_message(Response *response,
int err_code,
std::string err_msg) {
if (response == NULL) {
LOG(ERROR) << "response is NULL";
return;
}
response->set_err_code(err_code);
response->set_err_msg(err_msg);
return;
}
int CTRPredictionOp::inference() {
const Request *req = dynamic_cast<const Request *>(get_request_message());
TensorVector *in = butil::get_object<TensorVector>();
Response *res = mutable_data<Response>();
uint32_t sample_size = req->instances_size();
if (sample_size <= 0) {
LOG(WARNING) << "No instances need to inference!";
fill_response_with_message(res, -1, "Sample size invalid");
return 0;
}
paddle::PaddleTensor lod_tensors[CTR_PREDICTION_INPUT_SLOTS];
for (int i = 0; i < CTR_PREDICTION_INPUT_SLOTS; ++i) {
lod_tensors[i].dtype = paddle::PaddleDType::FLOAT32;
std::vector<std::vector<size_t>> &lod = lod_tensors[i].lod;
lod.resize(1);
lod[0].push_back(0);
}
// Query cube API for sparse embeddings
std::vector<uint64_t> keys;
std::vector<rec::mcube::CubeValue> values;
for (uint32_t si = 0; si < sample_size; ++si) {
const CTRReqInstance &req_instance = req->instances(si);
if (req_instance.sparse_ids_size() != CTR_PREDICTION_SPARSE_SLOTS) {
std::ostringstream iss;
iss << "Sparse input size != " << CTR_PREDICTION_SPARSE_SLOTS;
fill_response_with_message(res, -1, iss.str());
return 0;
}
for (int i = 0; i < req_instance.sparse_ids_size(); ++i) {
keys.push_back(req_instance.sparse_ids(i));
}
}
rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance();
predictor::KVManager &kv_manager = predictor::KVManager::instance();
const predictor::KVInfo *kvinfo =
kv_manager.get_kv_info(CTR_PREDICTION_MODEL_NAME);
if (kvinfo != NULL) {
std::string table_name;
if (kvinfo->sparse_param_service_type != configure::EngineDesc::NONE) {
table_name = kvinfo->sparse_param_service_table_name;
}
if (kvinfo->sparse_param_service_type == configure::EngineDesc::LOCAL) {
// Query local KV service
} else if (kvinfo->sparse_param_service_type ==
configure::EngineDesc::REMOTE) {
int ret = cube->seek(table_name, keys, &values);
if (ret != 0) {
fill_response_with_message(res, -1, "Query cube for embeddings error");
LOG(ERROR) << "Query cube for embeddings error";
return -1;
}
}
#if 0
for (int i = 0; i < keys.size(); ++i) {
std::ostringstream oss;
oss << keys[i] << ": ";
const char *value = (values[i].buff.data());
if (values[i].buff.size() !=
sizeof(float) * CTR_PREDICTION_EMBEDDING_SIZE) {
LOG(WARNING) << "Key " << keys[i] << " has values less than "
<< CTR_PREDICTION_EMBEDDING_SIZE;
}
for (int j = 0; j < values[i].buff.size(); ++j) {
oss << std::hex << std::uppercase << std::setw(2) << std::setfill('0')
<< (static_cast<int>(value[j]) & 0xff);
}
LOG(INFO) << oss.str().c_str();
}
#endif
}
// Sparse embeddings
for (int i = 0; i < CTR_PREDICTION_SPARSE_SLOTS; ++i) {
paddle::PaddleTensor &lod_tensor = lod_tensors[i];
std::vector<std::vector<size_t>> &lod = lod_tensor.lod;
char name[VARIABLE_NAME_LEN];
snprintf(name, VARIABLE_NAME_LEN, "embedding_%d.tmp_0", i);
lod_tensor.name = std::string(name);
for (uint32_t si = 0; si < sample_size; ++si) {
const CTRReqInstance &req_instance = req->instances(si);
lod[0].push_back(lod[0].back() + 1);
}
lod_tensor.shape = {lod[0].back(), CTR_PREDICTION_EMBEDDING_SIZE};
lod_tensor.data.Resize(lod[0].back() * sizeof(float) *
CTR_PREDICTION_EMBEDDING_SIZE);
int offset = 0;
for (uint32_t si = 0; si < sample_size; ++si) {
float *data_ptr = static_cast<float *>(lod_tensor.data.data()) + offset;
const CTRReqInstance &req_instance = req->instances(si);
int idx = si * CTR_PREDICTION_SPARSE_SLOTS + i;
if (values[idx].buff.size() !=
sizeof(float) * CTR_PREDICTION_EMBEDDING_SIZE) {
LOG(ERROR) << "Embedding vector size not expected";
fill_response_with_message(
res, -1, "Embedding vector size not expected");
return 0;
}
memcpy(data_ptr, values[idx].buff.data(), values[idx].buff.size());
offset += CTR_PREDICTION_EMBEDDING_SIZE;
}
in->push_back(lod_tensor);
}
// Dense features
paddle::PaddleTensor &lod_tensor = lod_tensors[CTR_PREDICTION_DENSE_SLOT_ID];
lod_tensor.dtype = paddle::PaddleDType::FLOAT32;
std::vector<std::vector<size_t>> &lod = lod_tensor.lod;
lod_tensor.name = std::string("dense_input");
for (uint32_t si = 0; si < sample_size; ++si) {
const CTRReqInstance &req_instance = req->instances(si);
if (req_instance.dense_ids_size() != CTR_PREDICTION_DENSE_DIM) {
std::ostringstream iss;
iss << "dense input size != " << CTR_PREDICTION_DENSE_DIM;
fill_response_with_message(res, -1, iss.str());
return 0;
}
lod[0].push_back(lod[0].back() + req_instance.dense_ids_size());
}
lod_tensor.shape = {lod[0].back() / CTR_PREDICTION_DENSE_DIM,
CTR_PREDICTION_DENSE_DIM};
lod_tensor.data.Resize(lod[0].back() * sizeof(float));
int offset = 0;
for (uint32_t si = 0; si < sample_size; ++si) {
float *data_ptr = static_cast<float *>(lod_tensor.data.data()) + offset;
const CTRReqInstance &req_instance = req->instances(si);
int id_count = req_instance.dense_ids_size();
memcpy(data_ptr,
req_instance.dense_ids().data(),
sizeof(float) * req_instance.dense_ids_size());
offset += req_instance.dense_ids_size();
}
in->push_back(lod_tensor);
TensorVector *out = butil::get_object<TensorVector>();
if (!out) {
LOG(ERROR) << "Failed get tls output object";
fill_response_with_message(res, -1, "Failed get thread local resource");
return 0;
}
// call paddle fluid model for inferencing
if (predictor::InferManager::instance().infer(
CTR_PREDICTION_MODEL_NAME, in, out, sample_size)) {
LOG(ERROR) << "Failed do infer in fluid model: "
<< CTR_PREDICTION_MODEL_NAME;
fill_response_with_message(res, -1, "Failed do infer in fluid model");
return 0;
}
if (out->size() != sample_size) {
LOG(ERROR) << "Output tensor size not equal that of input";
fill_response_with_message(res, -1, "Output size != input size");
return 0;
}
for (size_t i = 0; i < out->size(); ++i) {
int dim1 = out->at(i).shape[0];
int dim2 = out->at(i).shape[1];
if (out->at(i).dtype != paddle::PaddleDType::FLOAT32) {
LOG(ERROR) << "Expected data type float";
fill_response_with_message(res, -1, "Expected data type float");
return 0;
}
float *data = static_cast<float *>(out->at(i).data.data());
for (int j = 0; j < dim1; ++j) {
CTRResInstance *res_instance = res->add_predictions();
res_instance->set_prob0(data[j * dim2]);
res_instance->set_prob1(data[j * dim2 + 1]);
}
}
for (size_t i = 0; i < in->size(); ++i) {
(*in)[i].shape.clear();
}
in->clear();
butil::return_object<TensorVector>(in);
for (size_t i = 0; i < out->size(); ++i) {
(*out)[i].shape.clear();
}
out->clear();
butil::return_object<TensorVector>(out);
res->set_err_code(0);
res->set_err_msg(std::string(""));
return 0;
}
DEFINE_OP(CTRPredictionOp);
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#ifdef BCLOUD
#ifdef WITH_GPU
#include "paddle/paddle_inference_api.h"
#else
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#endif
#else
#include "paddle/fluid/inference/paddle_inference_api.h"
#endif
#include "demo-serving/ctr_prediction.pb.h"
namespace baidu {
namespace paddle_serving {
namespace serving {
static const char* CTR_PREDICTION_MODEL_NAME = "ctr_prediction";
/**
* CTRPredictionOp: Serve CTR prediction requests.
*
* Original model can be found here:
* https://github.com/PaddlePaddle/models/tree/develop/PaddleRec/ctr
*
* NOTE:
*
* The main purpose of this OP is to demonstrate usage of large-scale sparse
* parameter service (RocksDB for local, mCube for distributed service). To
* achieve this, we modified the orginal model slightly:
* 1) Function ctr_dnn_model() returns feed_vars and fetch_vars
* 2) Use fluid.io.save_inference_model using feed_vars and fetch_vars
* returned from ctr_dnn_model(), instead of fluid.io.save_persistables
* 3) Further, feed_vars were specified to be inputs of concat layer. Then in
* the process of save_inference_model(), the generated inference program will
* have the inputs of concat layer as feed targets.
* 4) Weight values for the embedding layer will be fetched from sparse param
* server for each sample
*
* Please refer to doc/CTR_PREDICTION.md for details on the original model
* and modifications we made
*
*/
class CTRPredictionOp
: public baidu::paddle_serving::predictor::OpWithChannel<
baidu::paddle_serving::predictor::ctr_prediction::Response> {
public:
typedef std::vector<paddle::PaddleTensor> TensorVector;
DECLARE_OP(CTRPredictionOp);
int inference();
};
} // namespace serving
} // namespace paddle_serving
} // namespace baidu
......@@ -6,6 +6,7 @@ LIST(APPEND protofiles
${CMAKE_CURRENT_LIST_DIR}/echo_kvdb_service.proto
${CMAKE_CURRENT_LIST_DIR}/int64tensor_service.proto
${CMAKE_CURRENT_LIST_DIR}/text_classification.proto
${CMAKE_CURRENT_LIST_DIR}/ctr_prediction.proto
)
PROTOBUF_GENERATE_SERVING_CPP(TRUE PROTO_SRCS PROTO_HDRS ${protofiles})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
import "pds_option.proto";
import "builtin_format.proto";
package baidu.paddle_serving.predictor.ctr_prediction;
option cc_generic_services = true;
message CTRReqInstance {
repeated int64 sparse_ids = 1;
repeated float dense_ids = 2;
};
message Request { repeated CTRReqInstance instances = 1; };
message CTRResInstance {
required float prob0 = 1;
required float prob1 = 2;
};
message Response {
repeated CTRResInstance predictions = 1;
required int64 err_code = 2;
optional string err_msg = 3;
};
service CTRPredictionService {
rpc inference(Request) returns (Response);
rpc debug(Request) returns (Response);
option (pds.options).generate_impl = true;
};
# CTR预估模型
## 1. 背景
在搜索、推荐、在线广告等业务场景中,embedding参数的规模常常非常庞大,达到数百GB甚至T级别;训练如此规模的模型需要用到多机分布式训练能力,将参数分片更新和保存;另一方面,训练好的模型,要应用于在线业务,也难以单机加载。Paddle Serving提供大规模稀疏参数读写服务,用户可以方便地将超大规模的稀疏参数以kv形式托管到参数服务,在线预测只需将所需要的参数子集从参数服务读取回来,再执行后续的预测流程。
我们以CTR预估模型为例,演示Paddle Serving中如何使用大规模稀疏参数服务。关于模型细节请参考[原始模型](https://github.com/PaddlePaddle/models/tree/v1.5/PaddleRec/ctr)
根据[对数据集的描述](https://www.kaggle.com/c/criteo-display-ad-challenge/data),该模型原始输入为13维integer features和26维categorical features。在我们的模型中,13维integer feature作为dense feature整体feed到一个data layer,而26维categorical features各自作为一个feature分别feed到一个data layer。除此之外,为计算auc指标,还将label作为一个feature输入。
若按缺省训练参数,本模型的embedding dim为100w,size为10,也就是参数矩阵为1000000 x 10的float型矩阵,实际占用内存共1000000 x 10 x sizeof(float) = 39MB;**实际场景中,embedding参数要大的多;因此该demo仅为演示使用**
## 2. 模型裁剪
在写本文档时([v1.5](https://github.com/PaddlePaddle/models/tree/v1.5)),训练脚本用PaddlePaddle py_reader加速样例读取速度,program中带有py_reader相关OP,且训练过程中只保存了模型参数,没有保存program,保存的参数没法直接用预测库加载;另外原始网络中最终输出的tensor是auc和batch_auc,而实际模型用于预测时只需要每个样例的predict,需要改掉模型的输出tensor为predict。再有,为了演示稀疏参数服务的使用,我们要有意将embedding layer包含的lookup_table OP从预测program中拿掉,以embedding layer的output variable作为网络的输入,然后再添加对应的feed OP,使得我们能够在预测时从稀疏参数服务获取到embedding向量后,将数据直接feed到各个embedding的output variable。
基于以上几方面考虑,我们需要对原始program进行裁剪。大致过程为:
1) 去掉py_reader相关代码,改为用fluid自带的reader和DataFeed
2) 修改原始网络配置,将predict变量作为fetch target
3) 修改原始网络配置,将26个稀疏参数的embedding layer的output作为feed target,以与后续稀疏参数服务配合使用
4) 修改后的网络,本地train 1个batch后,调用`fluid.io.save_inference_model()`,获得裁剪后的模型program
5) 裁剪后的program,用python再次处理,去掉embedding layer的lookup_table OP。这是因为,当前Paddle Fluid在第4步`save_inference_model()`时没有裁剪干净,还保留了embedding的lookup_table OP;如果这些OP不去除掉,那么embedding的output variable就会有2个输入OP:一个是feed OP(我们要添加的),一个是lookup_table;而lookup_table又没有输入,它的输出会与feed OP的输出互相覆盖,导致错乱。另外网络中还保留了SparseFeatFactors这个variable(全局共享的embedding矩阵对应的变量),这个variable也要去掉,否则网络加载时还会尝试从磁盘读取embedding参数,就失去了我们这个demo的意义。
6) 第4步拿到的program,与分布式训练保存的模型参数(除embedding之外)保存到一起,形成完整的预测模型
第1) - 第5)步裁剪完毕后的模型网络配置如下:
![Pruned CTR prediction network](doc/pruned-ctr-network.png)
整个裁剪过程具体说明如下:
### 2.1 网络配置中去除py_reader
Inference program调用ctr_dnn_model()函数时添加`user_py_reader=False`参数。这会在ctr_dnn_model定义中将py_reader相关的代码去掉
修改前:
```python
def train():
args = parse_args()
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
loss, auc_var, batch_auc_var, py_reader, _ = ctr_dnn_model(args.embedding_size, args.sparse_feature_dim)
...
```
修改后:
```python
def train():
args = parse_args()
if not os.path.isdir(args.model_output_dir):
os.mkdir(args.model_output_dir)
loss, auc_var, batch_auc_var, py_reader, _ = ctr_dnn_model(args.embedding_size, args.sparse_feature_dim, use_py_reader=False)
...
```
### 2.2 网络配置中修改feed targets和fetch targets
如第2节开头所述,为了使program适合于演示稀疏参数的使用,我们要裁剪program,将`ctr_dnn_model`中feed variable list和fetch variable分别改掉:
1) Inference program中26维稀疏特征的输入改为每个特征的embedding layer的output variable
2) fetch targets中返回的是predict,取代auc_var和batch_auc_var
截至写本文时,原始的网络配置 (network_conf.py中)`ctr_dnn_model`定义如下:
```python
def ctr_dnn_model(embedding_size, sparse_feature_dim, use_py_reader=True):
def embedding_layer(input):
emb = fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
return fluid.layers.sequence_pool(input=emb, pool_type='average') # 需修改1
dense_input = fluid.layers.data(
name="dense_input", shape=[dense_feature_dim], dtype='float32')
sparse_input_ids = [
fluid.layers.data(name="C" + str(i), shape=[1], lod_level=1, dtype='int64')
for i in range(1, 27)]
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
words = [dense_input] + sparse_input_ids + [label]
py_reader = None
if use_py_reader:
py_reader = fluid.layers.create_py_reader_by_data(capacity=64,
feed_list=words,
name='py_reader',
use_double_buffer=True)
words = fluid.layers.read_file(py_reader)
sparse_embed_seq = list(map(embedding_layer, words[1:-1])) # 需修改2
concated = fluid.layers.concat(sparse_embed_seq + words[0:1], axis=1)
fc1 = fluid.layers.fc(input=concated, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(concated.shape[1]))))
fc2 = fluid.layers.fc(input=fc1, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc1.shape[1]))))
fc3 = fluid.layers.fc(input=fc2, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc2.shape[1]))))
predict = fluid.layers.fc(input=fc3, size=2, act='softmax',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc3.shape[1]))))
cost = fluid.layers.cross_entropy(input=predict, label=words[-1])
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=words[-1])
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=words[-1], num_thresholds=2 ** 12, slide_steps=20)
return avg_cost, auc_var, batch_auc_var, py_reader, words # 需修改3
```
修改后
```python
def ctr_dnn_model(embedding_size, sparse_feature_dim, use_py_reader=True):
def embedding_layer(input):
emb = fluid.layers.embedding(
input=input,
is_sparse=True,
# you need to patch https://github.com/PaddlePaddle/Paddle/pull/14190
# if you want to set is_distributed to True
is_distributed=False,
size=[sparse_feature_dim, embedding_size],
param_attr=fluid.ParamAttr(name="SparseFeatFactors",
initializer=fluid.initializer.Uniform()))
seq = fluid.layers.sequence_pool(input=emb, pool_type='average')
return emb, seq # 对应上文修改处1
dense_input = fluid.layers.data(
name="dense_input", shape=[dense_feature_dim], dtype='float32')
sparse_input_ids = [
fluid.layers.data(name="C" + str(i), shape=[1], lod_level=1, dtype='int64')
for i in range(1, 27)]
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
words = [dense_input] + sparse_input_ids + [label]
sparse_embed_and_seq = list(map(embedding_layer, words[1:-1]))
emb_list = [x[0] for x in sparse_embed_and_seq] # 对应上文修改处2
sparse_embed_seq = [x[1] for x in sparse_embed_and_seq]
concated = fluid.layers.concat(sparse_embed_seq + words[0:1], axis=1)
train_feed_vars = words # 对应上文修改处2
inference_feed_vars = emb_list + words[0:1]
fc1 = fluid.layers.fc(input=concated, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(concated.shape[1]))))
fc2 = fluid.layers.fc(input=fc1, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc1.shape[1]))))
fc3 = fluid.layers.fc(input=fc2, size=400, act='relu',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc2.shape[1]))))
predict = fluid.layers.fc(input=fc3, size=2, act='softmax',
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Normal(
scale=1 / math.sqrt(fc3.shape[1]))))
cost = fluid.layers.cross_entropy(input=predict, label=words[-1])
avg_cost = fluid.layers.reduce_sum(cost)
accuracy = fluid.layers.accuracy(input=predict, label=words[-1])
auc_var, batch_auc_var, auc_states = \
fluid.layers.auc(input=predict, label=words[-1], num_thresholds=2 ** 12, slide_steps=20)
fetch_vars = [predict]
# 对应上文修改处3
return avg_cost, auc_var, batch_auc_var, train_feed_vars, inference_feed_vars, fetch_vars
```
说明:
1) 修改处1,我们将embedding layer的输出变量返回
2) 修改处2,我们将embedding layer的输出变量保存到`emb_list`,后者进一步保存到`inference_feed_vars`,用来将来在`save_inference_model()`时指定feed variable list。
3) 修改处3,我们将`words`变量作为训练时的feed variable list (`train_feed_vars`),将embedding layer的output variable作为infer时的feed variable list (`inference_feed_vars`),将`predict`作为fetch target (`fetch_vars`),分别返回。`inference_feed_vars``fetch_vars`用于`fluid.io.save_inference_model()`时指定feed variable list和fetch target list
### 2.3 fluid.io.save_inference_model()保存裁剪后的program
`fluid.io.save_inference_model()`不仅保存模型参数,还能够根据feed variable list和fetch target list参数,对program进行裁剪,形成适合inference用的program。大致原理是,根据前向网络配置,从fetch target list开始,反向查找其所依赖的OP列表,并将每个OP的输入加入目标variable list,再次递归地反向找到所有依赖OP和variable list。
在2.2节中我们已经拿到所需的`inference_feed_vars``fetch_vars`,接下来只要在训练过程中每次保存模型参数时改为调用`fluid.io.save_inference_model()`
修改前:
```python
def train_loop(args, train_program, py_reader, loss, auc_var, batch_auc_var,
trainer_num, trainer_id):
...省略
for pass_id in range(args.num_passes):
pass_start = time.time()
batch_id = 0
py_reader.start()
try:
while True:
loss_val, auc_val, batch_auc_val = pe.run(fetch_list=[loss.name, auc_var.name, batch_auc_var.name])
loss_val = np.mean(loss_val)
auc_val = np.mean(auc_val)
batch_auc_val = np.mean(batch_auc_val)
logger.info("TRAIN --> pass: {} batch: {} loss: {} auc: {}, batch_auc: {}"
.format(pass_id, batch_id, loss_val/args.batch_size, auc_val, batch_auc_val))
if batch_id % 1000 == 0 and batch_id != 0:
model_dir = args.model_output_dir + '/batch-' + str(batch_id)
if args.trainer_id == 0:
fluid.io.save_persistables(executor=exe, dirname=model_dir,
main_program=fluid.default_main_program())
batch_id += 1
except fluid.core.EOFException:
py_reader.reset()
print("pass_id: %d, pass_time_cost: %f" % (pass_id, time.time() - pass_start))
...省略
```
修改后
```python
def train_loop(args,
train_program,
train_feed_vars,
inference_feed_vars, # 裁剪program用的feed variable list
fetch_vars, # 裁剪program用的fetch variable list
loss,
auc_var,
batch_auc_var,
trainer_num,
trainer_id):
# 因为已经将py_reader去掉,这里用fluid自带的DataFeeder
dataset = reader.CriteoDataset(args.sparse_feature_dim)
train_reader = paddle.batch(
paddle.reader.shuffle(
dataset.train([args.train_data_path], trainer_num, trainer_id),
buf_size=args.batch_size * 100),
batch_size=args.batch_size)
inference_feed_var_names = [var.name for var in inference_feed_vars]
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
total_time = 0
pass_id = 0
batch_id = 0
feed_var_names = [var.name for var in feed_vars]
feeder = fluid.DataFeeder(feed_var_names, place)
for data in train_reader():
loss_val, auc_val, batch_auc_val = exe.run(fluid.default_main_program(),
feed = feeder.feed(data),
fetch_list=[loss.name, auc_var.name, batch_auc_var.name])
fluid.io.save_inference_model(model_dir,
inference_feed_var_names,
fetch_vars,
exe,
fluid.default_main_program())
break # 我们只要裁剪后的program,不需要模型参数,因此只train一个batch就停止了
loss_val = np.mean(loss_val)
auc_val = np.mean(auc_val)
batch_auc_val = np.mean(batch_auc_val)
logger.info("TRAIN --> pass: {} batch: {} loss: {} auc: {}, batch_auc: {}"
.format(pass_id, batch_id, loss_val/args.batch_size, auc_val, batch_auc_val))
```
### 2.4 用python再次处理inference program,去除lookup_table OP和SparseFeatFactors变量
这一步是因为`fluid.io.save_inference_model()`裁剪出的program没有将lookup_table OP去除。未来如果`save_inference_model`接口完善,本节可跳过
主要代码:
```python
def prune_program():
args = parse_args()
# 从磁盘打开网络配置文件并反序列化成protobuf message
model_dir = args.model_output_dir + "/inference_only"
model_file = model_dir + "/__model__"
with open(model_file, "rb") as f:
protostr = f.read()
f.close()
proto = framework_pb2.ProgramDesc.FromString(six.binary_type(protostr))
# 去除lookup_table OP
block = proto.blocks[0]
kept_ops = [op for op in block.ops if op.type != "lookup_table"]
del block.ops[:]
block.ops.extend(kept_ops)
# 去除SparseFeatFactors var
kept_vars = [var for var in block.vars if var.name != "SparseFeatFactors"]
del block.vars[:]
block.vars.extend(kept_vars)
# 写回磁盘文件
with open(model_file + ".pruned", "wb") as f:
f.write(proto.SerializePartialToString())
f.close()
with open(model_file + ".prototxt.pruned", "w") as f:
f.write(text_format.MessageToString(proto))
f.close()
```
### 2.5 裁剪过程串到一起
我们提供了完整的裁剪CTR预估模型的脚本文件save_program.py,同[CTR分布式训练任务](doc/DISTRIBUTED_TRAINING_AND_SERVING.md)一起发布,可以在trainer和pserver容器的训练脚本目录下找到
## 3. 整个预测计算流程
Client端:
1) Dense feature: 从dataset每条样例读取13个integer features,形成1个dense feature
2) Sparse feature: 从dataset每条样例读取26个categorical feature,分别经过hash(str(feature_index) + feature_string)签名,得到每个feature的id,形成26个sparse feature
Serving端:
1) Dense feature: dense feature共13个float型数字,一起feed到网络dense_input这个variable对应的LodTensor
2) Sparse feature: 26个sparse feature id,分别访问kv服务获取对应的embedding向量,feed到对应的26个embedding layer的output variable。在我们裁剪出来的网络中,这些variable分别对应的变量名为embedding_0.tmp_0, embedding_1.tmp_0, ... embedding_25.tmp_0
3) 执行预测,获取预测结果。
......@@ -150,7 +150,7 @@ type: 预测引擎的类型。可在inferencer-fluid-cpu/src/fluid_cpu_engine.cp
**fluid Analysis API和fluid Native API的区别**
Analysis API在模型加载过程中,会对模型计算逻辑进行多种优化,包括但不限于zero copy tensor,相邻OP的fuse等
Analysis API在模型加载过程中,会对模型计算逻辑进行多种优化,包括但不限于zero copy tensor,相邻OP的fuse等**但优化逻辑不是一定对所有模型都有加速作用,有时甚至会有反作用,请以实测结果为准**
reloadable_meta: 目前实际内容无意义,用来通过对该文件的mtime判断是否超过reload时间阈值
......
......@@ -40,6 +40,7 @@ DEFINE_int32(
DEFINE_int32(reload_interval_s, 10, "");
DEFINE_bool(enable_model_toolkit, false, "enable model toolkit");
DEFINE_string(enable_protocol_list, "baidu_std", "set protocol list");
DEFINE_bool(enable_cube, false, "enable cube");
const char* START_OP_NAME = "startup_op";
} // namespace predictor
......
......@@ -39,6 +39,9 @@ DECLARE_int32(num_threads);
DECLARE_int32(reload_interval_s);
DECLARE_bool(enable_model_toolkit);
DECLARE_string(enable_protocol_list);
DECLARE_bool(enable_cube);
DECLARE_string(cube_config_path);
DECLARE_string(cube_config_file);
// STATIC Variables
extern const char* START_OP_NAME;
......
......@@ -760,6 +760,7 @@ class InferManager {
}
LOG(WARNING) << "Succ proc finalize engine, name: " << it->first;
}
_map.clear();
return 0;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include "predictor/common/inner_common.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {
using configure::ModelToolkitConf;
struct KVInfo {
std::string model_name;
uint32_t sparse_param_service_type;
std::string sparse_param_service_table_name;
};
class KVManager {
public:
static KVManager &instance() {
static KVManager ins;
return ins;
}
int proc_initialize(const char *path, const char *file) {
ModelToolkitConf model_toolkit_conf;
if (configure::read_proto_conf(path, file, &model_toolkit_conf) != 0) {
LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
return -1;
}
size_t engine_num = model_toolkit_conf.engines_size();
for (size_t ei = 0; ei < engine_num; ++ei) {
const configure::EngineDesc &conf = model_toolkit_conf.engines(ei);
std::string engine_name = conf.name();
KVInfo *kvinfo = new (std::nothrow) KVInfo();
kvinfo->model_name = engine_name;
if (conf.has_sparse_param_service_type()) {
kvinfo->sparse_param_service_type = conf.sparse_param_service_type();
} else {
kvinfo->sparse_param_service_type = configure::EngineDesc::NONE;
}
if (conf.has_sparse_param_service_table_name()) {
kvinfo->sparse_param_service_table_name =
conf.sparse_param_service_table_name();
} else {
kvinfo->sparse_param_service_table_name = "";
}
auto r = _map.insert(std::make_pair(engine_name, kvinfo));
if (!r.second) {
LOG(ERROR) << "Failed insert item: " << engine_name;
return -1;
}
LOG(WARNING) << engine_name << ": "
<< kvinfo->sparse_param_service_table_name;
LOG(WARNING) << "Succ proc initialize kvmanager for engine: "
<< engine_name;
}
return 0;
}
const KVInfo *get_kv_info(std::string model_name) {
auto it = _map.find(model_name);
if (it == _map.end()) {
LOG(WARNING) << "Cannot find kvinfo for model " << model_name;
return NULL;
}
return it->second;
}
private:
std::map<std::string, KVInfo *> _map;
};
} // namespace predictor
} // namespace paddle_serving
} // namespace baidu
......@@ -15,14 +15,13 @@
#include "predictor/framework/resource.h"
#include <string>
#include "predictor/common/inner_common.h"
#include "predictor/framework/infer.h"
#include "predictor/framework/kv_manager.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {
using configure::ResourceConf;
using rec::mcube::CubeAPI;
// __thread bool p_thread_initialized = false;
static void dynamic_resource_deleter(void* d) {
......@@ -76,6 +75,12 @@ int Resource::initialize(const std::string& path, const std::string& file) {
<< model_toolkit_path << "/" << model_toolkit_file;
return -1;
}
if (KVManager::instance().proc_initialize(
model_toolkit_path.c_str(), model_toolkit_file.c_str()) != 0) {
LOG(ERROR) << "Failed proc initialize kvmanager, config: "
<< model_toolkit_path << "/" << model_toolkit_file;
}
}
if (THREAD_KEY_CREATE(&_tls_bspec_key, dynamic_resource_deleter) != 0) {
......@@ -91,6 +96,44 @@ int Resource::initialize(const std::string& path, const std::string& file) {
return 0;
}
int Resource::cube_initialize(const std::string& path,
const std::string& file) {
// cube
if (!FLAGS_enable_cube) {
return 0;
}
ResourceConf resource_conf;
if (configure::read_proto_conf(path, file, &resource_conf) != 0) {
LOG(ERROR) << "Failed initialize resource from: " << path << "/" << file;
return -1;
}
int err = 0;
std::string cube_config_path = resource_conf.cube_config_path();
if (err != 0) {
LOG(ERROR) << "reade cube_config_path failed, path[" << path << "], file["
<< cube_config_path << "]";
return -1;
}
std::string cube_config_file = resource_conf.cube_config_file();
if (err != 0) {
LOG(ERROR) << "reade cube_config_file failed, path[" << path << "], file["
<< cube_config_file << "]";
return -1;
}
err = CubeAPI::instance()->init(cube_config_file.c_str());
if (err != 0) {
LOG(ERROR) << "failed initialize cube, config: " << cube_config_path << "/"
<< cube_config_file << " error code : " << err;
return -1;
}
LOG(INFO) << "Successfully initialize cube";
return 0;
}
int Resource::thread_initialize() {
// mempool
if (MempoolWrapper::instance().thread_initialize() != 0) {
......@@ -192,7 +235,10 @@ int Resource::finalize() {
LOG(ERROR) << "Failed proc finalize infer manager";
return -1;
}
if (CubeAPI::instance()->destroy() != 0) {
LOG(ERROR) << "Destory cube api failed ";
return -1;
}
THREAD_KEY_DELETE(_tls_bspec_key);
return 0;
......
......@@ -13,9 +13,12 @@
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "cube/cube-api/include/cube_api.h"
#include "kvdb/paddle_rocksdb.h"
#include "predictor/common/inner_common.h"
#include "predictor/framework/infer.h"
#include "predictor/framework/memory.h"
namespace baidu {
......@@ -35,7 +38,13 @@ struct DynamicResource {
class Resource {
public:
Resource() {}
Resource() {
// Reference InferManager::instance() explicitly, to make sure static
// instance of InferManager is constructed before that of Resource, and
// destruct after that of Resource
// See https://stackoverflow.com/a/335746/1513460
InferManager::instance();
}
~Resource() { finalize(); }
......@@ -45,7 +54,7 @@ class Resource {
}
int initialize(const std::string& path, const std::string& file);
int cube_initialize(const std::string& path, const std::string& file);
int thread_initialize();
int thread_clear();
......
......@@ -209,6 +209,14 @@ int main(int argc, char** argv) {
}
LOG(INFO) << "Succ call pthread worker start function";
if (Resource::instance().cube_initialize(FLAGS_resource_path,
FLAGS_resource_file) != 0) {
LOG(ERROR) << "Failed initialize cube, conf: " << FLAGS_resource_path << "/"
<< FLAGS_resource_file;
return -1;
}
LOG(INFO) << "Succ initialize cube";
FLAGS_logtostderr = false;
if (ServerManager::instance().start_and_wait() != 0) {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
import "pds_option.proto";
import "builtin_format.proto";
package baidu.paddle_serving.predictor.ctr_prediction;
option cc_generic_services = true;
message CTRReqInstance {
repeated int64 sparse_ids = 1;
repeated float dense_ids = 2;
};
message Request { repeated CTRReqInstance instances = 1; };
message CTRResInstance {
required float prob0 = 1;
required float prob1 = 2;
};
message Response {
repeated CTRResInstance predictions = 1;
required int64 err_code = 2;
optional string err_msg = 3;
};
service CTRPredictionService {
rpc inference(Request) returns (Response);
rpc debug(Request) returns (Response);
option (pds.options).generate_stub = true;
};
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册