diff --git a/demo-serving/conf/model_toolkit.prototxt b/demo-serving/conf/model_toolkit.prototxt index 4fcb909ed8409af0943ee0a495e71b0068f7c040..58f6760b2bbc2b0a6e594c45ce18efc220ac9a34 100644 --- a/demo-serving/conf/model_toolkit.prototxt +++ b/demo-serving/conf/model_toolkit.prototxt @@ -28,4 +28,6 @@ engines { runtime_thread_num: 0 batch_infer_size: 0 enable_batch_align: 0 + sparse_param_service_type: REMOTE + sparse_param_service_table_name: "dict" } diff --git a/demo-serving/op/ctr_prediction_op.cpp b/demo-serving/op/ctr_prediction_op.cpp index 891288f8509e89657b682500fc2d3331a2ccf7f2..649d46990aef1b9a782ecfbf91bf25630cfafdf7 100644 --- a/demo-serving/op/ctr_prediction_op.cpp +++ b/demo-serving/op/ctr_prediction_op.cpp @@ -102,42 +102,75 @@ int CTRPredictionOp::inference() { predictor::KVManager &kv_manager = predictor::KVManager::instance(); const predictor::KVInfo *kvinfo = kv_manager.get_kv_info(CTR_PREDICTION_MODEL_NAME); - if (kvinfo != NULL) { - std::string table_name; - if (kvinfo->sparse_param_service_type != configure::EngineDesc::NONE) { - table_name = kvinfo->sparse_param_service_table_name; + if (kvinfo == NULL) { + LOG(ERROR) << "Sparse param service info not found for model " + << CTR_PREDICTION_MODEL_NAME + << ". Maybe forgot to specify sparse_param_service_type and " + << "sparse_param_service_table_name in " + << "conf/model_toolkit.prototxt"; + fill_response_with_message(res, -1, "Sparse param service info not found"); + return 0; + } + + std::string table_name; + if (kvinfo->sparse_param_service_type != configure::EngineDesc::NONE) { + table_name = kvinfo->sparse_param_service_table_name; + if (table_name.empty()) { + LOG(ERROR) << "sparse_param_service_table_name not specified. " + << "Please specify it in conf/model_toolkit.protxt for model " + << CTR_PREDICTION_MODEL_NAME; + fill_response_with_message( + res, -1, "sparse_param_service_table_name not specified"); + return 0; } + } - if (kvinfo->sparse_param_service_type == configure::EngineDesc::LOCAL) { - // Query local KV service - } else if (kvinfo->sparse_param_service_type == - configure::EngineDesc::REMOTE) { - int ret = cube->seek(table_name, keys, &values); - if (ret != 0) { - fill_response_with_message(res, -1, "Query cube for embeddings error"); - LOG(ERROR) << "Query cube for embeddings error"; - return -1; - } + if (kvinfo->sparse_param_service_type == configure::EngineDesc::LOCAL) { + // Query local KV service + LOG(ERROR) << "Local kv service not supported for model " + << CTR_PREDICTION_MODEL_NAME; + + fill_response_with_message( + res, -1, "Local kv service not supported for this model"); + return 0; + } else if (kvinfo->sparse_param_service_type == + configure::EngineDesc::REMOTE) { + int ret = cube->seek(table_name, keys, &values); + if (ret != 0) { + fill_response_with_message(res, -1, "Query cube for embeddings error"); + LOG(ERROR) << "Query cube for embeddings error"; + return 0; } + } -#if 0 - for (int i = 0; i < keys.size(); ++i) { - std::ostringstream oss; - oss << keys[i] << ": "; - const char *value = (values[i].buff.data()); - if (values[i].buff.size() != - sizeof(float) * CTR_PREDICTION_EMBEDDING_SIZE) { - LOG(WARNING) << "Key " << keys[i] << " has values less than " - << CTR_PREDICTION_EMBEDDING_SIZE; - } + if (values.size() != keys.size()) { + LOG(ERROR) << "Sparse embeddings not ready; " + << "maybe forgot to set sparse_param_service_type and " + << "sparse_param_sevice_table_name for " + << CTR_PREDICTION_MODEL_NAME + << " in conf/model_toolkit.prototxt"; + fill_response_with_message( + res, -1, "Sparse param service not configured properly"); + return 0; + } + + for (int i = 0; i < keys.size(); ++i) { + std::ostringstream oss; + oss << keys[i] << ": "; + const char *value = (values[i].buff.data()); + if (values[i].buff.size() != + sizeof(float) * CTR_PREDICTION_EMBEDDING_SIZE) { + LOG(WARNING) << "Key " << keys[i] << " has values less than " + << CTR_PREDICTION_EMBEDDING_SIZE; + } +#if 0 for (int j = 0; j < values[i].buff.size(); ++j) { oss << std::hex << std::uppercase << std::setw(2) << std::setfill('0') << (static_cast(value[j]) & 0xff); } LOG(INFO) << oss.str().c_str(); - } #endif }