提交 3e829228 编写于 作者: Y Yan Chunwei 提交者: GitHub

fix tensorarray unpack bug (#4614)

上级 5daf0575
...@@ -217,12 +217,11 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) { ...@@ -217,12 +217,11 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
// collect indice need to copy to the batch // collect indice need to copy to the batch
std::vector<size_t> indice; std::vector<size_t> indice;
for (size_t seq_id = 0; seq_id < meta.size(); seq_id++) { for (const auto& seq : meta) {
const auto& seq_meta = meta[seq_id]; size_t id = seq.begin + index;
if (index >= seq_meta.end) break; if (id >= seq.end) break;
indice.push_back(seq_meta.begin + index); indice.push_back(id);
} }
PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index); PADDLE_ENFORCE(!indice.empty(), "invalid batch at %d", index);
// copy the indice of records in LoDTensor // copy the indice of records in LoDTensor
...@@ -232,16 +231,18 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) { ...@@ -232,16 +231,18 @@ LoDTensor DynamicBatchUnpacker::GetBatch(size_t index) {
result.Resize(make_ddim(record_dims_vec)); result.Resize(make_ddim(record_dims_vec));
result.mutable_data<value_type>(platform::CPUPlace()); result.mutable_data<value_type>(platform::CPUPlace());
for (size_t i = 0; i < indice.size() - 1; i++) { for (size_t i = 0; i < indice.size(); i++) {
auto index = indice[i]; auto index = indice[i];
auto target = result.Slice<value_type>(i, i + 1); auto target = result.Slice<value_type>(i, i + 1);
auto source_ = source->Slice<value_type>(index, index + 1); auto source_ = source->Slice<value_type>(index, index + 1);
target.CopyFrom<value_type>(source_, platform::CPUPlace()); target.CopyFrom<value_type>(source_, platform::CPUPlace());
} }
return result; return result;
} }
// TODO(supejom) to cache lod if reasonable
LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source, LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
const std::vector<DySeqMeta>& meta, const LoD& lod, const std::vector<DySeqMeta>& meta, const LoD& lod,
size_t level) { size_t level) {
...@@ -273,7 +274,6 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source, ...@@ -273,7 +274,6 @@ LoDTensor PackDynamicBatch(const std::vector<LoDTensor>& source,
} }
result.set_lod(lod); result.set_lod(lod);
return result; return result;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册