提交 d7ee239f 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #1187 from qingqing01/gru_bug

Bug fix in GatedRecurrentLayer which only occurs in predicting or testing.
...@@ -314,13 +314,13 @@ void GatedRecurrentLayer::forwardBatch(int batchSize, ...@@ -314,13 +314,13 @@ void GatedRecurrentLayer::forwardBatch(int batchSize,
batchValue_->resizeOrCreate(*output_.value); batchValue_->resizeOrCreate(*output_.value);
batchValue_->copy(*inputValue, *gate_.value, /* seq2batch */ true); batchValue_->copy(*inputValue, *gate_.value, /* seq2batch */ true);
if (bias_ && bias_->getWGrad()) { if (bias_) {
gate_.value->addBias(*(bias_->getW()), 1); gate_.value->addBias(*(bias_->getW()), 1);
} }
{ {
int numBatch = batchValue_->getNumBatch(); int numBatch = batchValue_->getNumBatch();
int batchSize = 0; int curBatchSize = 0;
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
for (int n = 0; n < numBatch; n++) { for (int n = 0; n < numBatch; n++) {
MatrixPtr outputValueTmp = batchValue_->getBatchValue(n); MatrixPtr outputValueTmp = batchValue_->getBatchValue(n);
...@@ -330,16 +330,17 @@ void GatedRecurrentLayer::forwardBatch(int batchSize, ...@@ -330,16 +330,17 @@ void GatedRecurrentLayer::forwardBatch(int batchSize,
gruValue.resetOutputValue = gruValue.resetOutputValue =
(batchValue_->getBatchValue(*resetOutput_.value, n))->getData(); (batchValue_->getBatchValue(*resetOutput_.value, n))->getData();
batchSize = outputValueTmp->getHeight(); curBatchSize = outputValueTmp->getHeight();
gruValue.prevOutValue = gruValue.prevOutValue =
(n == 0 ? nullptr (n == 0
: (batchValue_->getBatchValue(n - 1, batchSize))->getData()); ? nullptr
: (batchValue_->getBatchValue(n - 1, curBatchSize))->getData());
{ {
if (useGpu_) { if (useGpu_) {
GruCompute::forward<1>(gruValue, getSize(), batchSize); GruCompute::forward<1>(gruValue, getSize(), curBatchSize);
} else { } else {
GruCompute::forward<0>(gruValue, getSize(), batchSize); GruCompute::forward<0>(gruValue, getSize(), curBatchSize);
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册