提交 d7ee239f 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #1187 from qingqing01/gru_bug

Bug fix in GatedRecurrentLayer which only occurs in predicting or testing.
......@@ -314,13 +314,13 @@ void GatedRecurrentLayer::forwardBatch(int batchSize,
batchValue_->resizeOrCreate(*output_.value);
batchValue_->copy(*inputValue, *gate_.value, /* seq2batch */ true);
if (bias_ && bias_->getWGrad()) {
if (bias_) {
gate_.value->addBias(*(bias_->getW()), 1);
}
{
int numBatch = batchValue_->getNumBatch();
int batchSize = 0;
int curBatchSize = 0;
AsyncGpuBlock asyncGpuBlock;
for (int n = 0; n < numBatch; n++) {
MatrixPtr outputValueTmp = batchValue_->getBatchValue(n);
......@@ -330,16 +330,17 @@ void GatedRecurrentLayer::forwardBatch(int batchSize,
gruValue.resetOutputValue =
(batchValue_->getBatchValue(*resetOutput_.value, n))->getData();
batchSize = outputValueTmp->getHeight();
curBatchSize = outputValueTmp->getHeight();
gruValue.prevOutValue =
(n == 0 ? nullptr
: (batchValue_->getBatchValue(n - 1, batchSize))->getData());
(n == 0
? nullptr
: (batchValue_->getBatchValue(n - 1, curBatchSize))->getData());
{
if (useGpu_) {
GruCompute::forward<1>(gruValue, getSize(), batchSize);
GruCompute::forward<1>(gruValue, getSize(), curBatchSize);
} else {
GruCompute::forward<0>(gruValue, getSize(), batchSize);
GruCompute::forward<0>(gruValue, getSize(), curBatchSize);
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册