Commit 7d3df259 authored by guru4elephant

fix batch size problem for general text infer

Parent 995698f3
@@ -56,7 +56,13 @@ int GeneralTextInferOp::inference() {
   const TensorVector *in = &reader_out->tensor_vector;
   TensorVector *out = butil::get_object<TensorVector>();
-  int batch_size = (*in)[0].shape[0];
+  int batch_size = 0;
+  if (in->at(0).lod.size() == 1) {
+    batch_size = in->at(0).lod[0].size() - 1;
+  } else {
+    batch_size = in->at(0).shape[0];
+  }
   VLOG(2) << "infer batch size: " << batch_size;
   // infer
   Timer timeline;
   double infer_time = 0.0;
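Background for the change: in PaddlePaddle, variable-length text inputs are LoD tensors whose level-0 LoD stores cumulative sequence offsets, so a batch of N sequences carries N + 1 offsets, while shape[0] holds the total number of tokens rather than the number of sequences. The sketch below is a minimal standalone illustration of the same branch, using plain std::vector stand-ins for the tensor's lod and shape fields (these types are illustrative assumptions, not the actual Paddle Serving tensor API):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical level-0 LoD for three text sequences of lengths
  // 3, 2, and 4 tokens; offsets are cumulative: {0, 3, 5, 9}.
  std::vector<std::vector<size_t>> lod = {{0, 3, 5, 9}};
  std::vector<int> shape = {9, 1};  // shape[0] = total tokens, not batch

  int batch_size = 0;
  if (lod.size() == 1) {
    // LoD tensor: batch size is the number of sequences.
    batch_size = static_cast<int>(lod[0].size()) - 1;
  } else {
    // Dense tensor: the leading dimension is the batch size.
    batch_size = shape[0];
  }
  std::cout << "batch size: " << batch_size << std::endl;  // prints 3
  return 0;
}

For this example the old code would have reported shape[0] = 9 as the batch size instead of 3, which matches the batch-size problem the commit message describes.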