diff --git a/core/general-server/op/general_response_op.cpp b/core/general-server/op/general_response_op.cpp index f2fdcedd7d97cf401e426daab0c23c46324319e0..c478500260c025c7b22be0f8dba03544050029e5 100644 --- a/core/general-server/op/general_response_op.cpp +++ b/core/general-server/op/general_response_op.cpp @@ -106,10 +106,18 @@ int GeneralResponseOp::inference() { } } } else { - for (int j = 0; j < batch_size; ++j) { - for (int k = j * cap; k < (j + 1) * cap; ++k) { + int var_size = in->at(idx).shape[0]; + if (var_size == batch_size) { + for (int j = 0; j < batch_size; ++j) { + for (int k = j * cap; k < (j + 1) * cap; ++k) { + res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data( + reinterpret_cast(&(data_ptr[k])), sizeof(float)); + } + } + } else { + for (int j = 0; j < batch_size; ++j) { res->mutable_insts(j)->mutable_tensor_array(var_idx)->add_data( - reinterpret_cast(&(data_ptr[k])), sizeof(float)); + reinterpret_cast(&(data_ptr[0])), sizeof(float)); } } }