From bcdedecb5755df1b42e4fa822498224d6d1baccd Mon Sep 17 00:00:00 2001 From: Haonan Date: Tue, 31 Oct 2017 16:23:13 -0700 Subject: [PATCH] handle non-sequence data in sequenceReshapeLayer (#5188) --- .../gserver/layers/SequenceReshapeLayer.cpp | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/paddle/gserver/layers/SequenceReshapeLayer.cpp b/paddle/gserver/layers/SequenceReshapeLayer.cpp index 433592953b2..82297440728 100644 --- a/paddle/gserver/layers/SequenceReshapeLayer.cpp +++ b/paddle/gserver/layers/SequenceReshapeLayer.cpp @@ -70,11 +70,23 @@ void SequenceReshapeLayer::forward(PassType passType) { size_t outDim = getSize(); size_t numSequences = input.getNumSequences(); - auto startPositions = input.sequenceStartPositions->getVector(false); - const int* starts = startPositions->getData(); - CHECK_EQ(starts[numSequences], input.getBatchSize()); - CHECK_EQ(numSequences, startPositions->getSize() - 1); + // by default, we assume each instance as a sequence + IVectorPtr seqStarts; + IVector::resizeOrCreate(seqStarts, input.getBatchSize() + 1, false); + int* startsData = seqStarts->getData(); + for (int i = 0; i < input.getBatchSize() + 1; i++) { + startsData[i] = i; + } + const int* starts = startsData; + + // if there is sequence, then use start positions + if (input.sequenceStartPositions) { + auto startPositions = input.sequenceStartPositions->getVector(false); + starts = startPositions->getData(); + CHECK_EQ(starts[numSequences], input.getBatchSize()); + CHECK_EQ(numSequences, startPositions->getSize() - 1); + } for (size_t seqID = 0; seqID < numSequences; seqID++) { size_t inNumIns = starts[seqID + 1] - starts[seqID]; -- GitLab