提交 64d71041 编写于 作者: W wangguibao

CTR prediction serving

上级 4f871066
......@@ -26,6 +26,20 @@ message EngineDesc {
required int32 enable_batch_align = 8;
optional string version_file = 9;
optional string version_type = 10;
/*
* Sparse Parameter Service type. Valid types are:
* "None": not use sparse parameter service
* "Local": Use local kv service (rocksdb library & API)
* "Remote": Use remote kv service (cube)
*/
enum SparseParamServiceType {
NONE = 0;
LOCAL = 1;
REMOTE = 2;
}
optional SparseParamServiceType sparse_param_service_type = 11;
optional string sparse_param_service_table_name = 12;
};
// model_toolkit conf
......
......@@ -66,6 +66,8 @@ int test_write_conf() {
engine->set_runtime_thread_num(0);
engine->set_batch_infer_size(0);
engine->set_enable_batch_align(0);
engine->set_sparse_param_service_type(EngineDesc::LOCAL);
engine->set_sparse_param_service_table_name("local_kv");
int ret = baidu::paddle_serving::configure::write_proto_conf(
&model_toolkit_conf, output_dir, model_toolkit_conf_file);
......
......@@ -18,7 +18,7 @@ include(op/CMakeLists.txt)
include(proto/CMakeLists.txt)
add_executable(serving ${serving_srcs})
add_dependencies(serving pdcodegen fluid_cpu_engine pdserving paddle_fluid
opencv_imgcodecs)
opencv_imgcodecs cube-api)
if (WITH_GPU)
add_dependencies(serving fluid_gpu_engine)
endif()
......@@ -40,6 +40,7 @@ target_link_libraries(serving opencv_imgcodecs
${opencv_depend_libs})
target_link_libraries(serving pdserving)
target_link_libraries(serving cube-api)
target_link_libraries(serving kvdb rocksdb)
......
......@@ -15,7 +15,9 @@
#include "demo-serving/op/ctr_prediction_op.h"
#include <algorithm>
#include <string>
#include "cube/cube-api/include/cube_api.h"
#include "predictor/framework/infer.h"
#include "predictor/framework/kv_manager.h"
#include "predictor/framework/memory.h"
namespace baidu {
......@@ -41,13 +43,6 @@ const int CTR_PREDICTION_DENSE_SLOT_ID = 26;
const int CTR_PREDICTION_DENSE_DIM = 13;
const int CTR_PREDICTION_EMBEDDING_SIZE = 10;
#if 1
struct CubeValue {
int error;
std::string buff;
};
#endif
void fill_response_with_message(Response *response,
int err_code,
std::string err_msg) {
......@@ -61,6 +56,13 @@ void fill_response_with_message(Response *response,
return;
}
std::string str_tolower(std::string s) {
std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
return std::tolower(c);
});
return s;
}
int CTRPredictionOp::inference() {
const Request *req = dynamic_cast<const Request *>(get_request_message());
......@@ -83,8 +85,8 @@ int CTRPredictionOp::inference() {
}
// Query cube API for sparse embeddings
std::vector<int64_t> keys;
std::vector<CubeValue> values;
std::vector<uint64_t> keys;
std::vector<rec::mcube::CubeValue> values;
for (uint32_t si = 0; si < sample_size; ++si) {
const CTRReqInstance &req_instance = req->instances(si);
......@@ -100,13 +102,26 @@ int CTRPredictionOp::inference() {
}
}
#if 0
mCube::CubeAPI* cube = CubeAPI::instance();
int ret = cube->seek(keys, values);
if (ret != 0) {
fill_response_with_message(res, -1, "Query cube for embeddings error");
LOG(ERROR) << "Query cube for embeddings error";
return -1;
#if 1
rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance();
predictor::KVManager &kv_manager = predictor::KVManager::instance();
const predictor::KVInfo *kvinfo =
kv_manager.get_kv_info(CTR_PREDICTION_MODEL_NAME);
std::string table_name;
if (kvinfo->sparse_param_service_type != configure::EngineDesc::NONE) {
std::string table_name = kvinfo->sparse_param_service_table_name;
}
if (kvinfo->sparse_param_service_type == configure::EngineDesc::LOCAL) {
// Query local KV service
} else if (kvinfo->sparse_param_service_type ==
configure::EngineDesc::REMOTE) {
int ret = cube->seek(table_name, keys, &values);
if (ret != 0) {
fill_response_with_message(res, -1, "Query cube for embeddings error");
LOG(ERROR) << "Query cube for embeddings error";
return -1;
}
}
#else
float buff[CTR_PREDICTION_EMBEDDING_SIZE] = {
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include "predictor/common/inner_common.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {
using configure::ModelToolkitConf;
struct KVInfo {
std::string model_name;
uint32_t sparse_param_service_type;
std::string sparse_param_service_table_name;
};
class KVManager {
public:
static KVManager &instance() {
static KVManager ins;
return ins;
}
int proc_initialize(const char *path, const char *file) {
ModelToolkitConf model_toolkit_conf;
if (configure::read_proto_conf(path, file, &model_toolkit_conf) != 0) {
LOG(ERROR) << "failed load infer config, path: " << path << "/" << file;
return -1;
}
size_t engine_num = model_toolkit_conf.engines_size();
for (size_t ei = 0; ei < engine_num; ++ei) {
const configure::EngineDesc &conf = model_toolkit_conf.engines(ei);
std::string engine_name = conf.name();
KVInfo *kvinfo = new (std::nothrow) KVInfo();
kvinfo->model_name = engine_name;
if (conf.has_sparse_param_service_type()) {
kvinfo->sparse_param_service_type = conf.sparse_param_service_type();
} else {
kvinfo->sparse_param_service_type = configure::EngineDesc::NONE;
}
if (conf.has_sparse_param_service_table_name()) {
kvinfo->sparse_param_service_table_name =
conf.sparse_param_service_table_name();
} else {
kvinfo->sparse_param_service_table_name = "";
}
auto r = _map.insert(std::make_pair(engine_name, kvinfo));
if (!r.second) {
LOG(ERROR) << "Failed insert item: " << engine_name;
return -1;
}
LOG(WARNING) << "Succ proc initialize kvmanager for engine: "
<< engine_name;
}
return 0;
}
const KVInfo *get_kv_info(std::string model_name) {
auto it = _map.find(model_name);
if (it == _map.end()) {
LOG(WARNING) << "Cannot find kvinfo for model " << model_name;
return NULL;
}
return it->second;
}
private:
std::map<std::string, KVInfo *> _map;
};
} // namespace predictor
} // namespace paddle_serving
} // namespace baidu
......@@ -16,7 +16,7 @@
#include <string>
#include "predictor/common/inner_common.h"
#include "predictor/framework/infer.h"
#include "predictor/framework/kv_manager.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {
......@@ -76,6 +76,12 @@ int Resource::initialize(const std::string& path, const std::string& file) {
<< model_toolkit_path << "/" << model_toolkit_file;
return -1;
}
if (KVManager::instance().proc_initialize(
model_toolkit_path.c_str(), model_toolkit_file.c_str()) != 0) {
LOG(ERROR) << "Failed proc initialize kvmanager, config: "
<< model_toolkit_path << "/" << model_toolkit_file;
}
}
if (THREAD_KEY_CREATE(&_tls_bspec_key, dynamic_resource_deleter) != 0) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册