提交 933b1f6a 编写于 作者: W wangjiawei04

add dist kv with quant infer op test=serving

上级 c01a9d42
......@@ -23,8 +23,8 @@
#include "core/predictor/framework/infer.h"
#include "core/predictor/framework/memory.h"
#include "core/predictor/framework/resource.h"
#include "core/util/include/timer.h"
#include "core/predictor/tools/quant.h"
#include "core/util/include/timer.h"
namespace baidu {
namespace paddle_serving {
......@@ -93,7 +93,7 @@ int GeneralDistKVQuantInferOp::inference() {
if (values.size() != keys.size() || values[0].buff.size() == 0) {
LOG(ERROR) << "cube value return null";
}
TensorVector sparse_out;
sparse_out.resize(sparse_count);
TensorVector dense_out;
......@@ -106,8 +106,8 @@ int GeneralDistKVQuantInferOp::inference() {
baidu::paddle_serving::predictor::Resource::instance();
std::shared_ptr<PaddleGeneralModelConfig> model_config =
resource.get_general_model_config();
int cube_quant_bits = resource.get_cube_quant_bits();
size_t EMBEDDING_SIZE = 0;
int cube_quant_bits = resource.get_cube_quant_bits();
size_t EMBEDDING_SIZE = 0;
if (cube_quant_bits == 0) {
EMBEDDING_SIZE = values[0].buff.size() / sizeof(float);
} else {
......@@ -140,18 +140,26 @@ int GeneralDistKVQuantInferOp::inference() {
for (int x = 0; x < sparse_out[sparse_idx].lod[0].back(); ++x) {
float *data_ptr = dst_ptr + x * EMBEDDING_SIZE;
if (cube_quant_bits == 0) {
memcpy(data_ptr,
values[cube_val_idx].buff.data(),
values[cube_val_idx].buff.size());
memcpy(data_ptr,
values[cube_val_idx].buff.data(),
values[cube_val_idx].buff.size());
} else {
// min (float), max (float), num, num, num... (Byte)
size_t num_of_float = values[cube_val_idx].buff.size() - 2 * sizeof(float);
float* float_ptr = new float[num_of_float];
char* src_ptr = new char[values[cube_val_idx].buff.size()];
memcpy(src_ptr, values[cube_val_idx].buff.data(), values[cube_val_idx].buff.size());
float* minmax = reinterpret_cast<float*>(src_ptr);
dequant(src_ptr + 2*sizeof(float), float_ptr, minmax[0], minmax[1], num_of_float, cube_quant_bits);
memcpy(data_ptr, float_ptr, sizeof(float)*num_of_float);
size_t num_of_float =
values[cube_val_idx].buff.size() - 2 * sizeof(float);
float *float_ptr = new float[num_of_float];
char *src_ptr = new char[values[cube_val_idx].buff.size()];
memcpy(src_ptr,
values[cube_val_idx].buff.data(),
values[cube_val_idx].buff.size());
float *minmax = reinterpret_cast<float *>(src_ptr);
dequant(src_ptr + 2 * sizeof(float),
float_ptr,
minmax[0],
minmax[1],
num_of_float,
cube_quant_bits);
memcpy(data_ptr, float_ptr, sizeof(float) * num_of_float);
delete float_ptr;
delete src_ptr;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册