提交 f83df131 编写于 作者: M MRXLT

fix bug

上级 e808b498
......@@ -168,14 +168,23 @@ int GeneralReaderOp::inference() {
int cur_len = out->at(i).lod[0].back();
VLOG(2) << "current len: " << cur_len;
out->at(i).lod[0].push_back(cur_len + tensor.shape(0));
VLOG(2) << "new len: " << cur_len + tensor.shape(0);
int sample_len;
if (tensor.shape_size() == 1) {
sample_len = data_len;
} else {
sample_len = tensor.shape(0);
}
out->at(i).lod[0].push_back(cur_len + sample_len);
VLOG(2) << "new len: " << cur_len + sample_len;
}
out->at(i).data.Resize(tensor_size * elem_size[i]);
out->at(i).shape = {out->at(i).lod[0].back()};
for (int j = 1; j < req->insts(0).tensor_array(i).shape_size(); ++j) {
out->at(i).shape.push_back(req->insts(0).tensor_array(i).shape(j));
}
if (out->at(i).shape.size() == 1) {
out->at(i).shape.push_back(1);
}
VLOG(2) << "var[" << i
<< "] is lod_tensor and len=" << out->at(i).lod[0].back();
} else {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册