Commit 8daf85c8 authored by hjchen2

Fix crash while loading uint8 quantized model

The quantized branch used the local `char **data_buf` as if it were the byte buffer: `memory::Copy` read from the address of the pointer variable itself, and the cursor advances moved the local double pointer rather than the caller's buffer position, leading to out-of-bounds reads. Dereference the double pointer (`*data_buf`) for every read and advance, and rename the function's template parameter from `Device` to `T`, since it denotes the element type rather than the device.

Parent b6e709e2
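For illustration, a minimal, self-contained repro of the crash pattern the diff below fixes (a hypothetical example, not code from the repo): with a `char **` cursor, copying from the double pointer itself reads the pointer variable's own bytes, and advancing the double pointer walks past the pointer variable into invalid memory instead of moving through the buffer.

#include <cstdio>
#include <cstring>

int main() {
  // A buffer holding one serialized float.
  char blob[sizeof(float)];
  float expected = 0.5f;
  std::memcpy(blob, &expected, sizeof(float));

  char *p = blob;
  char **cursor = &p;  // double-pointer cursor, as in LoadMemInternal

  float wrong, right;
  std::memcpy(&wrong, cursor, sizeof(float));   // buggy: bytes of the pointer p itself
  std::memcpy(&right, *cursor, sizeof(float));  // fixed: bytes of the buffer, 0.5f
  std::printf("wrong=%f right=%f\n", wrong, right);
  return 0;
}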
@@ -90,28 +90,28 @@ Executor<Device, T>::Executor(const Program<Device> &program, int batch_size,
   }
 }
 
-template <typename Device>
+template <typename T>
 static void LoadMemInternal(void **data, LoDTensor *tensor,
                             bool quant_uint8 = false) {
   char **data_buf = reinterpret_cast<char **>(data);
   int64_t size = tensor->numel();
-  Device *tensor_data = tensor->mutable_data<Device>();
+  T *tensor_data = tensor->mutable_data<T>();
   if (quant_uint8) {
     // should be moved into operator init function
     float min_value;
     float max_value;
-    memory::Copy(&min_value, data_buf, sizeof(float));
-    memory::Copy(&max_value, data_buf + sizeof(float), sizeof(float));
-    data_buf += 2 * sizeof(float);
+    memory::Copy(&min_value, *data_buf, sizeof(float));
+    memory::Copy(&max_value, *data_buf + sizeof(float), sizeof(float));
+    *data_buf += 2 * sizeof(float);
     const float factor = (max_value - min_value) / 255.0;
-    const uint8_t *uint8_data = reinterpret_cast<uint8_t *>(data_buf);
+    const uint8_t *uint8_data = reinterpret_cast<uint8_t *>(*data_buf);
     for (int k = 0; k < size; ++k) {
       tensor_data[k] = uint8_data[k] * factor + min_value;
     }
-    data_buf += size * sizeof(uint8_t);
+    *data_buf += size * sizeof(uint8_t);
   } else {
-    memory::Copy(tensor_data, *data_buf, size * sizeof(Device));
-    *data_buf += size * sizeof(Device);
+    memory::Copy(tensor_data, *data_buf, size * sizeof(T));
+    *data_buf += size * sizeof(T);
   }
 }
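The quantized branch above is a standard affine uint8 dequantization over a blob laid out as [min : float][max : float][size bytes of uint8]. The standalone sketch below mirrors that logic and the cursor invariant the fix restores; the function name DequantizeUint8 and the use of std::memcpy in place of paddle-mobile's memory::Copy are illustrative assumptions, not code from the repo.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Reads [min(float)][max(float)][size x uint8] from *data_buf, writes
// dequantized floats to out, and advances the caller's cursor so the next
// tensor is read from the correct offset.
static void DequantizeUint8(char **data_buf, float *out, int64_t size) {
  float min_value = 0.f, max_value = 0.f;
  std::memcpy(&min_value, *data_buf, sizeof(float));
  std::memcpy(&max_value, *data_buf + sizeof(float), sizeof(float));
  *data_buf += 2 * sizeof(float);  // advance the shared cursor, not a local copy
  const float factor = (max_value - min_value) / 255.0f;
  const uint8_t *uint8_data = reinterpret_cast<const uint8_t *>(*data_buf);
  for (int64_t k = 0; k < size; ++k) {
    out[k] = uint8_data[k] * factor + min_value;  // affine map back to float
  }
  *data_buf += size * sizeof(uint8_t);
}

int main() {
  // Serialize min=-1, max=1 followed by three quantized values.
  float min_v = -1.f, max_v = 1.f;
  uint8_t q[3] = {0, 128, 255};
  std::vector<char> blob(2 * sizeof(float) + sizeof(q));
  std::memcpy(blob.data(), &min_v, sizeof(float));
  std::memcpy(blob.data() + sizeof(float), &max_v, sizeof(float));
  std::memcpy(blob.data() + 2 * sizeof(float), q, sizeof(q));

  char *cursor = blob.data();
  float out[3];
  DequantizeUint8(&cursor, out, 3);
  for (float v : out) std::cout << v << " ";  // prints roughly: -1 0.00392 1
  std::cout << "\n";
  return 0;
}

Passing `char **` lets consecutive tensors be unpacked from one contiguous weights blob: each call leaves *data_buf pointing at the next tensor's data, which is exactly the invariant the old code broke by advancing its local pointer instead.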