diff --git a/demo-serving/op/bert_service_op.cpp b/demo-serving/op/bert_service_op.cpp index 4782821219ef2aec361ead471f65a034b90c41a8..1261da68bf1f93af8a2ea580a56352b348d3fea3 100644 --- a/demo-serving/op/bert_service_op.cpp +++ b/demo-serving/op/bert_service_op.cpp @@ -172,14 +172,14 @@ int BertServiceOp::inference() { LOG(INFO) << "batch_size : " << out->at(0).shape[0] << " emb_size : " << out->at(0).shape[1]; - uint32_t emb_size = out->at(0).shape[1] float *out_data = - reinterpret_cast(out->at(0).data.data()); + uint32_t emb_size = out->at(0).shape[1]; + float *out_data = reinterpret_cast(out->at(0).data.data()); for (uint32_t bi = 0; bi < batch_size; bi++) { BertResInstance *res_instance = res->add_instances(); for (uint32_t si = 0; si < 1; si++) { EmbeddingValues *emb_instance = res_instance->add_instances(); for (uint32_t ei = 0; ei < emb_size; ei++) { - uint32_t index = bi * EMB_SIZE + ei; + uint32_t index = bi * emb_size + ei; emb_instance->add_values(out_data[index]); } }