未验证 提交 bcdedecb 编写于 作者: H Haonan 提交者: GitHub

handle non-sequence data in sequenceReshapeLayer (#5188)

上级 0318f47e
...@@ -70,11 +70,23 @@ void SequenceReshapeLayer::forward(PassType passType) { ...@@ -70,11 +70,23 @@ void SequenceReshapeLayer::forward(PassType passType) {
size_t outDim = getSize(); size_t outDim = getSize();
size_t numSequences = input.getNumSequences(); size_t numSequences = input.getNumSequences();
auto startPositions = input.sequenceStartPositions->getVector(false);
const int* starts = startPositions->getData();
// by default, we assume each instance as a sequence
IVectorPtr seqStarts;
IVector::resizeOrCreate(seqStarts, input.getBatchSize() + 1, false);
int* startsData = seqStarts->getData();
for (int i = 0; i < input.getBatchSize() + 1; i++) {
startsData[i] = i;
}
const int* starts = startsData;
// if there is sequence, then use start positions
if (input.sequenceStartPositions) {
auto startPositions = input.sequenceStartPositions->getVector(false);
starts = startPositions->getData();
CHECK_EQ(starts[numSequences], input.getBatchSize()); CHECK_EQ(starts[numSequences], input.getBatchSize());
CHECK_EQ(numSequences, startPositions->getSize() - 1); CHECK_EQ(numSequences, startPositions->getSize() - 1);
}
for (size_t seqID = 0; seqID < numSequences; seqID++) { for (size_t seqID = 0; seqID < numSequences; seqID++) {
size_t inNumIns = starts[seqID + 1] - starts[seqID]; size_t inNumIns = starts[seqID + 1] - starts[seqID];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册