diff --git a/paddle/operators/sequence_concat_op.cc b/paddle/operators/sequence_concat_op.cc index 5dc0b24e6af8ffc2afa43da3cdf40783b99557fa..c4d681bc852fb88d1a01471af88282cc1399bec8 100644 --- a/paddle/operators/sequence_concat_op.cc +++ b/paddle/operators/sequence_concat_op.cc @@ -75,17 +75,22 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { If the axis is other than 0(here, axis is 1 and level is 1), each input should have the same LoD information and the LoD information of the output keeps the same as the input. + LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4) LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4) + - Case2: If the axis is 0(here, leve is 0), the inputs are concatenated along time steps, the LoD information of the output need to re-compute. + LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x1) = {{0,3,5}, {0,1,2,3,5}}; Dims(x1) = (5,3,4) LoD(Out) = {{0,5,9}, {0,1,2,3,4,5,6,7,9}}; Dims(Out) = (9,3,4) + - Case3: If the axis is 0(here, level is 1). + LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4) LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4) diff --git a/paddle/operators/sequence_concat_op.h b/paddle/operators/sequence_concat_op.h index 91c952caf25d72dec4cf39783e7274b49167e43a..b08699e1a1566179c05df3f3ccb10646307590e3 100644 --- a/paddle/operators/sequence_concat_op.h +++ b/paddle/operators/sequence_concat_op.h @@ -29,22 +29,19 @@ LoD concatLoD(const std::vector ins, const size_t axis, auto out_lod = ins[0]->lod(); const size_t n = ins.size(); if (axis == 0UL) { - if (level == 0UL) { - for (size_t i = 1; i < n; ++i) { - for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) { - out_lod[0][j] += ins[i]->lod()[0][j]; - } + for (size_t i = 1; i < n; ++i) { + for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) { + out_lod[0][j] += ins[i]->lod()[0][j]; } - } else if (level == 1UL) { - PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), 2UL, - "If the level is 1, all of the inputs " - "should be the nested sequence."); - for (size_t i = 1; i < n; ++i) { - for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) { - out_lod[0].push_back(ins[i]->lod()[0][j]); - } - for (size_t j = 0; j < ins[i]->lod()[1].size(); ++j) { - out_lod[1][j] += ins[i]->lod()[1][j]; + + if (ins[0]->NumLevels() == 2) { + for (size_t j = 1; j < ins[i]->lod()[1].size(); ++j) { + if (level == 0UL) { + out_lod[1].push_back(out_lod[1].back() + ins[i]->lod()[1][j] - + ins[i]->lod()[1][j - 1]); + } else if (level == 1UL) { + out_lod[1][j] += ins[1]->lod()[1][j]; + } } } }