diff --git a/core/general-server/op/general_text_infer_op.cpp b/core/general-server/op/general_text_infer_op.cpp index d5acc83881b50da37c3b9526db475d4fe6e0453b..42be5ed056bbe325aeabe79bbbbd7892f2a1c275 100644 --- a/core/general-server/op/general_text_infer_op.cpp +++ b/core/general-server/op/general_text_infer_op.cpp @@ -56,7 +56,13 @@ int GeneralTextInferOp::inference() { const TensorVector *in = &reader_out->tensor_vector; TensorVector *out = butil::get_object(); - int batch_size = (*in)[0].shape[0]; + int batch_size = 0; + if (in->at(0).lod.size() == 1) { + batch_size = in->at(0).lod[0].size() - 1; + } else { + batch_size = in->at(0).shape[0]; + } + VLOG(2) << "infer batch size: " << batch_size; // infer Timer timeline; double infer_time = 0.0;