提交 ad477b91 编写于 作者: Y Yancey1989

update

上级 e880a356
...@@ -75,17 +75,22 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -75,17 +75,22 @@ class SequenceConcatOpMaker : public framework::OpProtoAndCheckerMaker {
If the axis is other than 0(here, axis is 1 and level is 1), If the axis is other than 0(here, axis is 1 and level is 1),
each input should have the same LoD information and the LoD each input should have the same LoD information and the LoD
information of the output keeps the same as the input. information of the output keeps the same as the input.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4) LoD(x1) = {{0,2,4}, {0,1,2,3,4}}; Dims(x1) = (4,4,4)
LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4) LoD(Out) = {{0,2,4}, {0,1,2,3,4}}; Dims(Out) = (4,7,4)
- Case2: - Case2:
If the axis is 0(here, leve is 0), the inputs are concatenated along If the axis is 0(here, leve is 0), the inputs are concatenated along
time steps, the LoD information of the output need to re-compute. time steps, the LoD information of the output need to re-compute.
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,3,5}, {0,1,2,3,5}}; Dims(x1) = (5,3,4) LoD(x1) = {{0,3,5}, {0,1,2,3,5}}; Dims(x1) = (5,3,4)
LoD(Out) = {{0,5,9}, {0,1,2,3,4,5,6,7,9}}; Dims(Out) = (9,3,4) LoD(Out) = {{0,5,9}, {0,1,2,3,4,5,6,7,9}}; Dims(Out) = (9,3,4)
- Case3: - Case3:
If the axis is 0(here, level is 1). If the axis is 0(here, level is 1).
LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4) LoD(x0) = {{0,2,4}, {0,1,2,3,4}}; Dims(x0) = (4,3,4)
LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4) LoD(x1) = {{0,3,5}, {0,1,3,4,5}}; Dims(x1) = (5,3,4)
LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4) LoD(Out) = {{0,5,9}, {0,2,5,7,9}}; Dims(Out) = (9,3,4)
......
...@@ -29,22 +29,19 @@ LoD concatLoD(const std::vector<const T*> ins, const size_t axis, ...@@ -29,22 +29,19 @@ LoD concatLoD(const std::vector<const T*> ins, const size_t axis,
auto out_lod = ins[0]->lod(); auto out_lod = ins[0]->lod();
const size_t n = ins.size(); const size_t n = ins.size();
if (axis == 0UL) { if (axis == 0UL) {
if (level == 0UL) { for (size_t i = 1; i < n; ++i) {
for (size_t i = 1; i < n; ++i) { for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) {
for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) { out_lod[0][j] += ins[i]->lod()[0][j];
out_lod[0][j] += ins[i]->lod()[0][j];
}
} }
} else if (level == 1UL) {
PADDLE_ENFORCE_EQ(ins[0]->NumLevels(), 2UL, if (ins[0]->NumLevels() == 2) {
"If the level is 1, all of the inputs " for (size_t j = 1; j < ins[i]->lod()[1].size(); ++j) {
"should be the nested sequence."); if (level == 0UL) {
for (size_t i = 1; i < n; ++i) { out_lod[1].push_back(out_lod[1].back() + ins[i]->lod()[1][j] -
for (size_t j = 0; j < ins[i]->lod()[0].size(); ++j) { ins[i]->lod()[1][j - 1]);
out_lod[0].push_back(ins[i]->lod()[0][j]); } else if (level == 1UL) {
} out_lod[1][j] += ins[1]->lod()[1][j];
for (size_t j = 0; j < ins[i]->lod()[1].size(); ++j) { }
out_lod[1][j] += ins[i]->lod()[1][j];
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册