From f83df131641b958db3f807b0176f68cba5f48950 Mon Sep 17 00:00:00 2001 From: MRXLT Date: Wed, 20 May 2020 16:44:36 +0800 Subject: [PATCH] fix bug --- core/general-server/op/general_reader_op.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/core/general-server/op/general_reader_op.cpp b/core/general-server/op/general_reader_op.cpp index b529db12..f6e1a34a 100644 --- a/core/general-server/op/general_reader_op.cpp +++ b/core/general-server/op/general_reader_op.cpp @@ -168,14 +168,23 @@ int GeneralReaderOp::inference() { int cur_len = out->at(i).lod[0].back(); VLOG(2) << "current len: " << cur_len; - out->at(i).lod[0].push_back(cur_len + tensor.shape(0)); - VLOG(2) << "new len: " << cur_len + tensor.shape(0); + int sample_len; + if (tensor.shape_size() == 1) { + sample_len = data_len; + } else { + sample_len = tensor.shape(0); + } + out->at(i).lod[0].push_back(cur_len + sample_len); + VLOG(2) << "new len: " << cur_len + sample_len; } out->at(i).data.Resize(tensor_size * elem_size[i]); out->at(i).shape = {out->at(i).lod[0].back()}; for (int j = 1; j < req->insts(0).tensor_array(i).shape_size(); ++j) { out->at(i).shape.push_back(req->insts(0).tensor_array(i).shape(j)); } + if (out->at(i).shape.size() == 1) { + out->at(i).shape.push_back(1); + } VLOG(2) << "var[" << i << "] is lod_tensor and len=" << out->at(i).lod[0].back(); } else { -- GitLab