提交 eea6a33d 编写于 作者: B bjjwwang

add remove dup

上级 1c15da9f
...@@ -69,10 +69,13 @@ int GeneralDistKVInferOp::inference() { ...@@ -69,10 +69,13 @@ int GeneralDistKVInferOp::inference() {
<< ") Failed mutable depended argument, op:" << pre_name; << ") Failed mutable depended argument, op:" << pre_name;
return -1; return -1;
} }
Timer timeline;
timeline.Start();
const TensorVector *in = &input_blob->tensor_vector; const TensorVector *in = &input_blob->tensor_vector;
TensorVector *out = &output_blob->tensor_vector; TensorVector *out = &output_blob->tensor_vector;
std::vector<uint64_t> keys; std::vector<uint64_t> keys;
std::vector<uint64_t> rm_dup_keys;
std::unordered_map<uint64_t, rec::mcube::CubeValue*> key_map;
std::vector<rec::mcube::CubeValue> values; std::vector<rec::mcube::CubeValue> values;
int sparse_count = 0; int sparse_count = 0;
int dense_count = 0; int dense_count = 0;
...@@ -93,7 +96,7 @@ int GeneralDistKVInferOp::inference() { ...@@ -93,7 +96,7 @@ int GeneralDistKVInferOp::inference() {
dataptr_size_pairs.push_back(std::make_pair(data_ptr, elem_num)); dataptr_size_pairs.push_back(std::make_pair(data_ptr, elem_num));
} }
keys.resize(key_len); keys.resize(key_len);
VLOG(2) << "(logid=" << log_id << ") cube number of keys to look up: " << key_len; rm_dup_keys.resize(key_len);
int key_idx = 0; int key_idx = 0;
for (size_t i = 0; i < dataptr_size_pairs.size(); ++i) { for (size_t i = 0; i < dataptr_size_pairs.size(); ++i) {
std::copy(dataptr_size_pairs[i].first, std::copy(dataptr_size_pairs[i].first,
...@@ -101,14 +104,28 @@ int GeneralDistKVInferOp::inference() { ...@@ -101,14 +104,28 @@ int GeneralDistKVInferOp::inference() {
keys.begin() + key_idx); keys.begin() + key_idx);
key_idx += dataptr_size_pairs[i].second; key_idx += dataptr_size_pairs[i].second;
} }
int rm_dup_keys_count = 0;
for (size_t i = 0; i < keys.size(); ++i) {
if (key_map.find(keys[i]) == key_map.end()) {
key_map[keys[i]] = nullptr;
rm_dup_keys[rm_dup_keys_count++] = keys[i];
}
}
rm_dup_keys.resize(rm_dup_keys_count);
VLOG(2) << "(logid=" << log_id << ") cube number of keys to look up: " << key_len << " after rm dup keys: "<< rm_dup_keys_count;
rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance(); rec::mcube::CubeAPI *cube = rec::mcube::CubeAPI::instance();
std::vector<std::string> table_names = cube->get_table_names(); std::vector<std::string> table_names = cube->get_table_names();
if (table_names.size() == 0) { if (table_names.size() == 0) {
LOG(ERROR) << "cube init error or cube config not given."; LOG(ERROR) << "cube init error or cube config not given.";
return -1; return -1;
} }
int ret = cube->seek(table_names[0], keys, &values); int64_t seek_start = timeline.TimeStampUS();
VLOG(2) << "(logid=" << log_id << ") cube seek status: " << ret; int ret = cube->seek(table_names[0], rm_dup_keys, &values);
int64_t seek_end = timeline.TimeStampUS();
VLOG(2) << "(logid=" << log_id << ") cube seek status: " << ret << " seek_time: " << seek_end - seek_start;
for (size_t i = 0; i < rm_dup_keys.size(); ++i) {
key_map[rm_dup_keys[i]] = &values[i];
}
if (values.size() != keys.size() || values[0].buff.size() == 0) { if (values.size() != keys.size() || values[0].buff.size() == 0) {
LOG(ERROR) << "cube value return null"; LOG(ERROR) << "cube value return null";
} }
...@@ -147,22 +164,25 @@ int GeneralDistKVInferOp::inference() { ...@@ -147,22 +164,25 @@ int GeneralDistKVInferOp::inference() {
float *dst_ptr = static_cast<float *>(sparse_out[sparse_idx].data.data()); float *dst_ptr = static_cast<float *>(sparse_out[sparse_idx].data.data());
for (int x = 0; x < sparse_out[sparse_idx].lod[0].back(); ++x) { for (int x = 0; x < sparse_out[sparse_idx].lod[0].back(); ++x) {
float *data_ptr = dst_ptr + x * EMBEDDING_SIZE; float *data_ptr = dst_ptr + x * EMBEDDING_SIZE;
if (values[cube_val_idx].buff.size() == 0) { uint64_t cur_key = keys[cube_val_idx];
rec::mcube::CubeValue* cur_val = key_map[cur_key];
if (cur_val->buff.size() == 0) {
memset(data_ptr, (float)0.0, sizeof(float) * EMBEDDING_SIZE); memset(data_ptr, (float)0.0, sizeof(float) * EMBEDDING_SIZE);
continue; continue;
} }
memcpy(data_ptr, values[cube_val_idx].buff.data()+10, values[cube_val_idx].buff.size()-10); memcpy(data_ptr, cur_val->buff.data()+10, cur_val->buff.size()-10);
cube_val_idx++; cube_val_idx++;
} }
++sparse_idx; ++sparse_idx;
} }
VLOG(2) << "(logid=" << log_id << ") sparse tensor load success."; VLOG(2) << "(logid=" << log_id << ") sparse tensor load success.";
timeline.Pause();
VLOG(2) << "dist kv, cube and datacopy time: " << timeline.ElapsedUS();
TensorVector infer_in; TensorVector infer_in;
infer_in.insert(infer_in.end(), dense_out.begin(), dense_out.end()); infer_in.insert(infer_in.end(), dense_out.begin(), dense_out.end());
infer_in.insert(infer_in.end(), sparse_out.begin(), sparse_out.end()); infer_in.insert(infer_in.end(), sparse_out.begin(), sparse_out.end());
int batch_size = input_blob->_batch_size; int batch_size = input_blob->_batch_size;
output_blob->_batch_size = batch_size; output_blob->_batch_size = batch_size;
Timer timeline;
int64_t start = timeline.TimeStampUS(); int64_t start = timeline.TimeStampUS();
timeline.Start(); timeline.Start();
...@@ -172,7 +192,8 @@ int GeneralDistKVInferOp::inference() { ...@@ -172,7 +192,8 @@ int GeneralDistKVInferOp::inference() {
return -1; return -1;
} }
int64_t end = timeline.TimeStampUS(); int64_t end = timeline.TimeStampUS();
timeline.Pause();
VLOG(2) << "dist kv, pure paddle infer time: " << timeline.ElapsedUS();
CopyBlobInfo(input_blob, output_blob); CopyBlobInfo(input_blob, output_blob);
AddBlobInfo(output_blob, start); AddBlobInfo(output_blob, start);
AddBlobInfo(output_blob, end); AddBlobInfo(output_blob, end);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册