提交 aa28d046 编写于 作者: C caoying03

fix a bug of sequence_slice layer when batch_size=1

上级 ab6b3c48
...@@ -130,6 +130,8 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts, ...@@ -130,6 +130,8 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
CHECK(starts || ends) << "At least one of the start or end indices " CHECK(starts || ends) << "At least one of the start or end indices "
<< "should be given."; << "should be given.";
bool hasSubseq = getInput(0).hasSubseq();
outSeqStartPos_.resize(1, 0); outSeqStartPos_.resize(1, 0);
outSubSeqStartPos_.resize(1, 0); outSubSeqStartPos_.resize(1, 0);
selectedRows_.clear(); selectedRows_.clear();
...@@ -151,14 +153,13 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts, ...@@ -151,14 +153,13 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
int seqLen = endPos - begPos + 1; int seqLen = endPos - begPos + 1;
CHECK_GT(seqLen, 0U); CHECK_GT(seqLen, 0U);
for (int m = begPos; m <= endPos; ++m) selectedRows_.push_back(m); for (int m = begPos; m <= endPos; ++m) selectedRows_.push_back(m);
inputSeqInfoVec_.size() > 1 hasSubseq
? outSubSeqStartPos_.push_back(outSubSeqStartPos_.back() + seqLen) ? outSubSeqStartPos_.push_back(outSubSeqStartPos_.back() + seqLen)
: outSeqStartPos_.push_back(outSeqStartPos_.back() + seqLen); : outSeqStartPos_.push_back(outSeqStartPos_.back() + seqLen);
} }
rowIdx++; rowIdx++;
} }
if (inputSeqInfoVec_.size() > 1) if (hasSubseq) outSeqStartPos_.push_back(outSubSeqStartPos_.back());
outSeqStartPos_.push_back(outSubSeqStartPos_.back());
} }
if (useGpu_) { if (useGpu_) {
...@@ -175,7 +176,7 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts, ...@@ -175,7 +176,7 @@ void SequenceSliceLayer::calSelectedRows(const MatrixPtr starts,
output_.sequenceStartPositions->copyFrom( output_.sequenceStartPositions->copyFrom(
outSeqStartPos_.data(), outSeqStartPos_.size(), false); outSeqStartPos_.data(), outSeqStartPos_.size(), false);
if (inputSeqInfoVec_.size() > 1) { if (hasSubseq) {
ICpuGpuVector::resizeOrCreate( ICpuGpuVector::resizeOrCreate(
output_.subSequenceStartPositions, outSubSeqStartPos_.size(), false); output_.subSequenceStartPositions, outSubSeqStartPos_.size(), false);
output_.subSequenceStartPositions->copyFrom( output_.subSequenceStartPositions->copyFrom(
...@@ -203,10 +204,11 @@ void SequenceSliceLayer::forward(PassType passType) { ...@@ -203,10 +204,11 @@ void SequenceSliceLayer::forward(PassType passType) {
} else } else
copySliceIdsToCpu(); copySliceIdsToCpu();
// calculate the selected row indices in a batch, /*
// and build the output sequence information. * calculate the selected row indices in a batch, and build the output
calSelectedRows(startIdsOnCpu_ ? startIdsOnCpu_ : nullptr, * sequence information.
endIdsOnCpu_ ? endIdsOnCpu_ : nullptr); */
calSelectedRows(startIdsOnCpu_, endIdsOnCpu_);
resetOutput(selectedRows_.size(), getSize()); resetOutput(selectedRows_.size(), getSize());
......
...@@ -30,6 +30,8 @@ const int MAX_SEQ_NUM = 17; ...@@ -30,6 +30,8 @@ const int MAX_SEQ_NUM = 17;
const int MAX_SEQ_LEN = 23; const int MAX_SEQ_LEN = 23;
const int MAX_BEAM_SIZE = 13; const int MAX_BEAM_SIZE = 13;
const size_t SEED = (size_t)(time(NULL));
vector<real> randSampling(real range, int n) { vector<real> randSampling(real range, int n) {
CHECK_GE(range, n); CHECK_GE(range, n);
vector<real> num(range); vector<real> num(range);
...@@ -46,7 +48,7 @@ void genSeqInfo(vector<int>& seqStartPos, vector<int>& subSeqStartPos) { ...@@ -46,7 +48,7 @@ void genSeqInfo(vector<int>& seqStartPos, vector<int>& subSeqStartPos) {
seqStartPos.resize(1, 0); seqStartPos.resize(1, 0);
subSeqStartPos.resize(1, 0); subSeqStartPos.resize(1, 0);
srand((size_t)(time(NULL))); srand(SEED);
int seqNum = 1 + (rand() % MAX_SEQ_NUM); int seqNum = 1 + (rand() % MAX_SEQ_NUM);
for (int i = 0; i < seqNum; ++i) { for (int i = 0; i < seqNum; ++i) {
int subSeqNum = 1 + (rand() % MAX_SEQ_NUM); int subSeqNum = 1 + (rand() % MAX_SEQ_NUM);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册