From d0c66c8c9cf1fa18c13f00b10f98645d52c08a00 Mon Sep 17 00:00:00 2001 From: bjjwwang Date: Mon, 23 Aug 2021 09:27:51 +0000 Subject: [PATCH] msn dist_kv --- .../op/general_dist_kv_infer_op.cpp | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/core/general-server/op/general_dist_kv_infer_op.cpp b/core/general-server/op/general_dist_kv_infer_op.cpp index 7b10a57d..4bf6aba1 100644 --- a/core/general-server/op/general_dist_kv_infer_op.cpp +++ b/core/general-server/op/general_dist_kv_infer_op.cpp @@ -101,6 +101,7 @@ int GeneralDistKVInferOp::inference() { keys.begin() + key_idx); key_idx += dataptr_size_pairs[i].second; } + rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance(); std::vector table_names = cube->get_table_names(); if (table_names.size() == 0) { @@ -109,9 +110,6 @@ int GeneralDistKVInferOp::inference() { } int ret = cube->seek(table_names[0], keys, &values); VLOG(2) << "(logid=" << log_id << ") cube seek status: " << ret; - if (values.size() != keys.size() || values[0].buff.size() == 0) { - LOG(ERROR) << "cube value return null"; - } //size_t EMBEDDING_SIZE = values[0].buff.size() / sizeof(float); size_t EMBEDDING_SIZE = 9; TensorVector sparse_out; @@ -125,6 +123,8 @@ int GeneralDistKVInferOp::inference() { baidu::paddle_serving::predictor::Resource &resource = baidu::paddle_serving::predictor::Resource::instance(); std::shared_ptr model_config = resource.get_general_model_config().front(); + int cube_key_found = 0; + int cube_key_miss = 0; for (size_t i = 0; i < in->size(); ++i) { if (in->at(i).dtype != paddle::PaddleDType::INT64) { dense_out[dense_idx] = in->at(i); @@ -149,13 +149,26 @@ int GeneralDistKVInferOp::inference() { float *data_ptr = dst_ptr + x * EMBEDDING_SIZE; if (values[cube_val_idx].buff.size() == 0) { memset(data_ptr, (float)0.0, sizeof(float) * EMBEDDING_SIZE); + VLOG(3) << "(logid=" << log_id << ") cube key not found: " << keys[cube_val_idx]; + ++cube_key_miss; + ++cube_val_idx; continue; } + //VLOG(3) << "(logid=" << log_id << ") cube key found: " << keys[cube_val_idx]; memcpy(data_ptr, values[cube_val_idx].buff.data()+10, values[cube_val_idx].buff.size()-10); - cube_val_idx++; + //VLOG(3) << keys[cube_val_idx] << ":" << data_ptr[0] << ", " << data_ptr[1] << ", " <(out->at(0).data.data()); + out_ptr[0] = 0.0; + } CopyBlobInfo(input_blob, output_blob); AddBlobInfo(output_blob, start); AddBlobInfo(output_blob, end); -- GitLab