提交 33481bb2 编写于 作者: M MRXLT

fix bert_service_op

上级 3a7fd43f
......@@ -172,14 +172,14 @@ int BertServiceOp::inference() {
LOG(INFO) << "batch_size : " << out->at(0).shape[0]
<< " emb_size : " << out->at(0).shape[1];
uint32_t emb_size = out->at(0).shape[1] float *out_data =
reinterpret_cast<float *>(out->at(0).data.data());
uint32_t emb_size = out->at(0).shape[1];
float *out_data = reinterpret_cast<float *>(out->at(0).data.data());
for (uint32_t bi = 0; bi < batch_size; bi++) {
BertResInstance *res_instance = res->add_instances();
for (uint32_t si = 0; si < 1; si++) {
EmbeddingValues *emb_instance = res_instance->add_instances();
for (uint32_t ei = 0; ei < emb_size; ei++) {
uint32_t index = bi * EMB_SIZE + ei;
uint32_t index = bi * emb_size + ei;
emb_instance->add_values(out_data[index]);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册