未验证 提交 b7937d2d 编写于 作者: W wawltor 提交者: GitHub

[cherry pick] Fix the integer overflow problem of sequence2batch. (#22479)

cherry-pick from the branch develop,fix the overflow of sequence2batch
上级 61ec75c5
......@@ -50,11 +50,11 @@ class LoDTensor2BatchFunctor {
// seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
//
struct SeqInfo {
SeqInfo(int start, int length, int seq_idx)
SeqInfo(size_t start, size_t length, size_t seq_idx)
: start(start), length(length), seq_idx(seq_idx) {}
int start;
int length;
int seq_idx;
size_t start;
size_t length;
size_t seq_idx;
};
public:
......@@ -82,7 +82,7 @@ class LoDTensor2BatchFunctor {
std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
int length = lod[seq_id + 1] - lod[seq_id];
size_t length = lod[seq_id + 1] - lod[seq_id];
seq_info.emplace_back(lod[seq_id], length, seq_id);
}
......@@ -118,8 +118,8 @@ class LoDTensor2BatchFunctor {
batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
int max_seqlen = seq_info[0].length;
batch_lods[0].resize(static_cast<size_t>(max_seqlen + 1));
size_t max_seqlen = seq_info[0].length;
batch_lods[0].resize(max_seqlen + 1);
// batch_lods[1] is the raw index in the input LoDTensor
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
// batch_lods[2] is the sort order for the input LoDTensor.
......@@ -128,11 +128,11 @@ class LoDTensor2BatchFunctor {
size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (int n = 0; n < max_seqlen; n++) {
auto batch_id = static_cast<int>(batch_starts[n]);
for (size_t n = 0; n < max_seqlen; n++) {
size_t batch_id = batch_starts[n];
for (size_t i = 0; i < seq_info.size(); ++i) {
int seq_len = seq_info[i].length;
int start = seq_info[i].start;
size_t seq_len = seq_info[i].length;
size_t start = seq_info[i].start;
if (n < seq_len) {
seq2batch_idx[batch_id] =
is_reverse ? start + seq_len - 1 - n : start + n;
......@@ -141,7 +141,7 @@ class LoDTensor2BatchFunctor {
break;
}
}
batch_starts[n + 1] = static_cast<size_t>(batch_id);
batch_starts[n + 1] = batch_id;
}
size_t* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册