Commit 1c15da9f authored by: B bjjwwang

shenzhen intl

Parent 0c5bb75b
@@ -112,7 +112,8 @@ int GeneralDistKVInferOp::inference() {
   if (values.size() != keys.size() || values[0].buff.size() == 0) {
     LOG(ERROR) << "cube value return null";
   }
-  size_t EMBEDDING_SIZE = values[0].buff.size() / sizeof(float);
+  //size_t EMBEDDING_SIZE = values[0].buff.size() / sizeof(float);
+  size_t EMBEDDING_SIZE = 9;
   TensorVector sparse_out;
   sparse_out.resize(sparse_count);
   TensorVector dense_out;
@@ -146,9 +147,11 @@ int GeneralDistKVInferOp::inference() {
     float *dst_ptr = static_cast<float *>(sparse_out[sparse_idx].data.data());
     for (int x = 0; x < sparse_out[sparse_idx].lod[0].back(); ++x) {
       float *data_ptr = dst_ptr + x * EMBEDDING_SIZE;
-      memcpy(data_ptr,
-             values[cube_val_idx].buff.data(),
-             values[cube_val_idx].buff.size());
+      if (values[cube_val_idx].buff.size() == 0) {
+        memset(data_ptr, (float)0.0, sizeof(float) * EMBEDDING_SIZE);
+        continue;
+      }
+      memcpy(data_ptr, values[cube_val_idx].buff.data() + 10, values[cube_val_idx].buff.size() - 10);
       cube_val_idx++;
     }
     ++sparse_idx;
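For reference: the change pins EMBEDDING_SIZE to 9 floats instead of deriving it from the first cube value, zero-fills any slot whose cube lookup returned an empty buffer, and copies hits starting 10 bytes into the value buffer, presumably to skip a fixed-size prefix. Below is a minimal, standalone C++ sketch of that copy loop, not the serving code itself: CubeValue, PREFIX_BYTES, and the sample data are stand-ins, and unlike the patch it also advances cube_val_idx on the miss branch (the patch's continue skips that increment).

// Standalone sketch of the embedding-copy logic introduced by this commit.
// CubeValue is a stand-in for the real cube client type; the 10-byte offset is
// assumed to skip a fixed prefix in the cube value buffer; misses are zero-filled.
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

struct CubeValue {
  std::string buff;  // raw bytes returned by the cube lookup (may be empty)
};

int main() {
  const size_t EMBEDDING_SIZE = 9;  // fixed embedding width, as hard-coded in the commit
  const size_t PREFIX_BYTES = 10;   // assumed prefix skipped by buff.data() + 10

  // Fake cube responses: one hit (prefix + 9 floats) and one miss (empty buffer).
  std::vector<CubeValue> values(2);
  std::string payload(PREFIX_BYTES, '\0');
  for (int i = 0; i < 9; ++i) {
    float v = 0.1f * (i + 1);
    payload.append(reinterpret_cast<const char *>(&v), sizeof(float));
  }
  values[0].buff = payload;  // hit
  // values[1].buff stays empty -> treated as a miss

  std::vector<float> dst(values.size() * EMBEDDING_SIZE);
  size_t cube_val_idx = 0;
  for (size_t x = 0; x < values.size(); ++x) {
    float *data_ptr = dst.data() + x * EMBEDDING_SIZE;
    if (values[cube_val_idx].buff.size() == 0) {
      // Miss: fill the slot with zeros instead of copying from an empty buffer.
      std::memset(data_ptr, 0, sizeof(float) * EMBEDDING_SIZE);
      ++cube_val_idx;
      continue;
    }
    // Hit: skip the assumed prefix and copy the embedding payload.
    std::memcpy(data_ptr,
                values[cube_val_idx].buff.data() + PREFIX_BYTES,
                values[cube_val_idx].buff.size() - PREFIX_BYTES);
    ++cube_val_idx;
  }

  for (float f : dst) std::cout << f << ' ';
  std::cout << '\n';
  return 0;
}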