From b7937d2d45606d953b0089862ef8db6c7dcb2bf0 Mon Sep 17 00:00:00 2001 From: wawltor Date: Thu, 5 Mar 2020 14:28:20 +0800 Subject: [PATCH] [cherry pick] Fix the integer overflow problem of sequence2batch. (#22479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit cherry-pick from the branch develop,fix the overflow of sequence2batch --- paddle/fluid/operators/math/sequence2batch.h | 24 ++++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/paddle/fluid/operators/math/sequence2batch.h b/paddle/fluid/operators/math/sequence2batch.h index a3186f82d0..9d9f7ef00b 100644 --- a/paddle/fluid/operators/math/sequence2batch.h +++ b/paddle/fluid/operators/math/sequence2batch.h @@ -50,11 +50,11 @@ class LoDTensor2BatchFunctor { // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} // struct SeqInfo { - SeqInfo(int start, int length, int seq_idx) + SeqInfo(size_t start, size_t length, size_t seq_idx) : start(start), length(length), seq_idx(seq_idx) {} - int start; - int length; - int seq_idx; + size_t start; + size_t length; + size_t seq_idx; }; public: @@ -82,7 +82,7 @@ class LoDTensor2BatchFunctor { std::vector seq_info; for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { - int length = lod[seq_id + 1] - lod[seq_id]; + size_t length = lod[seq_id + 1] - lod[seq_id]; seq_info.emplace_back(lod[seq_id], length, seq_id); } @@ -118,8 +118,8 @@ class LoDTensor2BatchFunctor { batch_lods.emplace_back(std::vector{0}); // batch_lods[0] is the start positions for batch LoDTensor - int max_seqlen = seq_info[0].length; - batch_lods[0].resize(static_cast(max_seqlen + 1)); + size_t max_seqlen = seq_info[0].length; + batch_lods[0].resize(max_seqlen + 1); // batch_lods[1] is the raw index in the input LoDTensor batch_lods[1].resize(static_cast(lod_tensor.dims()[0])); // batch_lods[2] is the sort order for the input LoDTensor. @@ -128,11 +128,11 @@ class LoDTensor2BatchFunctor { size_t* batch_starts = batch_lods[0].data(); size_t* seq2batch_idx = batch_lods[1].data(); batch_starts[0] = 0; - for (int n = 0; n < max_seqlen; n++) { - auto batch_id = static_cast(batch_starts[n]); + for (size_t n = 0; n < max_seqlen; n++) { + size_t batch_id = batch_starts[n]; for (size_t i = 0; i < seq_info.size(); ++i) { - int seq_len = seq_info[i].length; - int start = seq_info[i].start; + size_t seq_len = seq_info[i].length; + size_t start = seq_info[i].start; if (n < seq_len) { seq2batch_idx[batch_id] = is_reverse ? start + seq_len - 1 - n : start + n; @@ -141,7 +141,7 @@ class LoDTensor2BatchFunctor { break; } } - batch_starts[n + 1] = static_cast(batch_id); + batch_starts[n + 1] = batch_id; } size_t* seq_order = batch_lods[2].data(); for (size_t i = 0; i < seq_info.size(); ++i) { -- GitLab