提交 d0c66c8c 编写于 作者: B bjjwwang

msn dist_kv

上级 1c15da9f
...@@ -101,6 +101,7 @@ int GeneralDistKVInferOp::inference() { ...@@ -101,6 +101,7 @@ int GeneralDistKVInferOp::inference() {
keys.begin() + key_idx); keys.begin() + key_idx);
key_idx += dataptr_size_pairs[i].second; key_idx += dataptr_size_pairs[i].second;
} }
rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance(); rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance();
std::vector<std::string> table_names = cube->get_table_names(); std::vector<std::string> table_names = cube->get_table_names();
if (table_names.size() == 0) { if (table_names.size() == 0) {
...@@ -109,9 +110,6 @@ int GeneralDistKVInferOp::inference() { ...@@ -109,9 +110,6 @@ int GeneralDistKVInferOp::inference() {
} }
int ret = cube->seek(table_names[0], keys, &values); int ret = cube->seek(table_names[0], keys, &values);
VLOG(2) << "(logid=" << log_id << ") cube seek status: " << ret; VLOG(2) << "(logid=" << log_id << ") cube seek status: " << ret;
if (values.size() != keys.size() || values[0].buff.size() == 0) {
LOG(ERROR) << "cube value return null";
}
//size_t EMBEDDING_SIZE = values[0].buff.size() / sizeof(float); //size_t EMBEDDING_SIZE = values[0].buff.size() / sizeof(float);
size_t EMBEDDING_SIZE = 9; size_t EMBEDDING_SIZE = 9;
TensorVector sparse_out; TensorVector sparse_out;
...@@ -125,6 +123,8 @@ int GeneralDistKVInferOp::inference() { ...@@ -125,6 +123,8 @@ int GeneralDistKVInferOp::inference() {
baidu::paddle_serving::predictor::Resource &resource = baidu::paddle_serving::predictor::Resource &resource =
baidu::paddle_serving::predictor::Resource::instance(); baidu::paddle_serving::predictor::Resource::instance();
std::shared_ptr<PaddleGeneralModelConfig> model_config = resource.get_general_model_config().front(); std::shared_ptr<PaddleGeneralModelConfig> model_config = resource.get_general_model_config().front();
int cube_key_found = 0;
int cube_key_miss = 0;
for (size_t i = 0; i < in->size(); ++i) { for (size_t i = 0; i < in->size(); ++i) {
if (in->at(i).dtype != paddle::PaddleDType::INT64) { if (in->at(i).dtype != paddle::PaddleDType::INT64) {
dense_out[dense_idx] = in->at(i); dense_out[dense_idx] = in->at(i);
...@@ -149,13 +149,26 @@ int GeneralDistKVInferOp::inference() { ...@@ -149,13 +149,26 @@ int GeneralDistKVInferOp::inference() {
float *data_ptr = dst_ptr + x * EMBEDDING_SIZE; float *data_ptr = dst_ptr + x * EMBEDDING_SIZE;
if (values[cube_val_idx].buff.size() == 0) { if (values[cube_val_idx].buff.size() == 0) {
memset(data_ptr, (float)0.0, sizeof(float) * EMBEDDING_SIZE); memset(data_ptr, (float)0.0, sizeof(float) * EMBEDDING_SIZE);
VLOG(3) << "(logid=" << log_id << ") cube key not found: " << keys[cube_val_idx];
++cube_key_miss;
++cube_val_idx;
continue; continue;
} }
//VLOG(3) << "(logid=" << log_id << ") cube key found: " << keys[cube_val_idx];
memcpy(data_ptr, values[cube_val_idx].buff.data()+10, values[cube_val_idx].buff.size()-10); memcpy(data_ptr, values[cube_val_idx].buff.data()+10, values[cube_val_idx].buff.size()-10);
cube_val_idx++; //VLOG(3) << keys[cube_val_idx] << ":" << data_ptr[0] << ", " << data_ptr[1] << ", " <<data_ptr[2] << ", " <<data_ptr[3] << ", " <<data_ptr[4] << ", " <<data_ptr[5] << ", " <<data_ptr[6] << ", " <<data_ptr[7] << ", " <<data_ptr[8];
++cube_key_found;
++cube_val_idx;
} }
++sparse_idx; ++sparse_idx;
} }
bool cube_fail = (cube_key_found == 0);
if (cube_fail) {
LOG(WARNING) << "(logid=" << log_id << ") cube seek fail";
//CopyBlobInfo(input_blob, output_blob);
//return 0;
}
VLOG(2) << "(logid=" << log_id << ") cube key found: " << cube_key_found << " , cube key miss: " << cube_key_miss;
VLOG(2) << "(logid=" << log_id << ") sparse tensor load success."; VLOG(2) << "(logid=" << log_id << ") sparse tensor load success.";
TensorVector infer_in; TensorVector infer_in;
infer_in.insert(infer_in.end(), dense_out.begin(), dense_out.end()); infer_in.insert(infer_in.end(), dense_out.begin(), dense_out.end());
...@@ -172,7 +185,10 @@ int GeneralDistKVInferOp::inference() { ...@@ -172,7 +185,10 @@ int GeneralDistKVInferOp::inference() {
return -1; return -1;
} }
int64_t end = timeline.TimeStampUS(); int64_t end = timeline.TimeStampUS();
if (cube_fail) {
float *out_ptr = static_cast<float*>(out->at(0).data.data());
out_ptr[0] = 0.0;
}
CopyBlobInfo(input_blob, output_blob); CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start); AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end); AddBlobInfo(output_blob, end);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册