Commit 0e3f332f authored by wangguibao

Fix ctr_prediction model

Parent fda47639
@@ -28,4 +28,6 @@ engines {
   runtime_thread_num: 0
   batch_infer_size: 0
   enable_batch_align: 0
+  sparse_param_service_type: REMOTE
+  sparse_param_service_table_name: "dict"
 }
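
The two keys added above are what the reworked CTRPredictionOp below expects to find: sparse_param_service_type selects where sparse embedding parameters are looked up (REMOTE meaning a remote cube KV service), and sparse_param_service_table_name names the cube table to query. As a rough sketch, a configured engines entry in conf/model_toolkit.prototxt would then look like the following; every field other than the six lines visible in the hunk above is an illustrative placeholder assumption, not taken from this commit:

engines {
  # Illustrative placeholder -- the engine name field is an assumption,
  # not part of this commit.
  name: "ctr_prediction"
  runtime_thread_num: 0
  batch_infer_size: 0
  enable_batch_align: 0
  # Added by this commit: fetch sparse embeddings from the remote cube
  # service, reading from the table named "dict".
  sparse_param_service_type: REMOTE
  sparse_param_service_table_name: "dict"
}
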
@@ -102,42 +102,74 @@ int CTRPredictionOp::inference() {
   predictor::KVManager &kv_manager = predictor::KVManager::instance();
   const predictor::KVInfo *kvinfo =
       kv_manager.get_kv_info(CTR_PREDICTION_MODEL_NAME);
-  if (kvinfo != NULL) {
-    std::string table_name;
-    if (kvinfo->sparse_param_service_type != configure::EngineDesc::NONE) {
-      table_name = kvinfo->sparse_param_service_table_name;
+  if (kvinfo == NULL) {
+    LOG(ERROR) << "Sparse param service info not found for model "
+               << CTR_PREDICTION_MODEL_NAME
+               << ". Maybe forgot to specify sparse_param_service_type and "
+               << "sparse_param_service_table_name in "
+               << "conf/model_toolkit.prototxt";
+    fill_response_with_message(res, -1, "Sparse param service info not found");
+    return 0;
+  }
+  std::string table_name;
+  if (kvinfo->sparse_param_service_type != configure::EngineDesc::NONE) {
+    table_name = kvinfo->sparse_param_service_table_name;
+    if (table_name.empty()) {
+      LOG(ERROR) << "sparse_param_service_table_name not specified. "
+                 << "Please specify it in conf/model_toolkit.prototxt for model "
+                 << CTR_PREDICTION_MODEL_NAME;
+      fill_response_with_message(
+          res, -1, "sparse_param_service_table_name not specified");
+    }
+  }
-    if (kvinfo->sparse_param_service_type == configure::EngineDesc::LOCAL) {
-      // Query local KV service
-    } else if (kvinfo->sparse_param_service_type ==
-               configure::EngineDesc::REMOTE) {
-      int ret = cube->seek(table_name, keys, &values);
-      if (ret != 0) {
-        fill_response_with_message(res, -1, "Query cube for embeddings error");
-        LOG(ERROR) << "Query cube for embeddings error";
-        return -1;
-      }
+  if (kvinfo->sparse_param_service_type == configure::EngineDesc::LOCAL) {
+    // Query local KV service
+    LOG(ERROR) << "Local kv service not supported for model "
+               << CTR_PREDICTION_MODEL_NAME;
+    fill_response_with_message(
+        res, -1, "Local kv service not supported for this model");
+    return 0;
+  } else if (kvinfo->sparse_param_service_type ==
+             configure::EngineDesc::REMOTE) {
+    int ret = cube->seek(table_name, keys, &values);
+    if (ret != 0) {
+      fill_response_with_message(res, -1, "Query cube for embeddings error");
+      LOG(ERROR) << "Query cube for embeddings error";
+      return 0;
+    }
+  }
-#if 0
-  for (int i = 0; i < keys.size(); ++i) {
-    std::ostringstream oss;
-    oss << keys[i] << ": ";
-    const char *value = (values[i].buff.data());
-    if (values[i].buff.size() !=
-        sizeof(float) * CTR_PREDICTION_EMBEDDING_SIZE) {
-      LOG(WARNING) << "Key " << keys[i] << " has values less than "
-                   << CTR_PREDICTION_EMBEDDING_SIZE;
-    }
+  if (values.size() != keys.size()) {
+    LOG(ERROR) << "Sparse embeddings not ready; "
+               << "maybe forgot to set sparse_param_service_type and "
+               << "sparse_param_service_table_name for "
+               << CTR_PREDICTION_MODEL_NAME
+               << " in conf/model_toolkit.prototxt";
+    fill_response_with_message(
+        res, -1, "Sparse param service not configured properly");
+    return 0;
+  }
+  for (int i = 0; i < keys.size(); ++i) {
+    std::ostringstream oss;
+    oss << keys[i] << ": ";
+    const char *value = (values[i].buff.data());
+    if (values[i].buff.size() !=
+        sizeof(float) * CTR_PREDICTION_EMBEDDING_SIZE) {
+      LOG(WARNING) << "Key " << keys[i] << " has values less than "
+                   << CTR_PREDICTION_EMBEDDING_SIZE;
+    }
+#if 0
     for (int j = 0; j < values[i].buff.size(); ++j) {
       oss << std::hex << std::uppercase << std::setw(2) << std::setfill('0')
           << (static_cast<int>(value[j]) & 0xff);
     }
     LOG(INFO) << oss.str().c_str();
-  }
 #endif
+  }