Commit df2b054b authored by tensor-tang

follow comments refine code

Parent 43606158
@@ -53,28 +53,19 @@ void MKLPackedRecurrentLayer::forwardBatch(int batchSize,
     REGISTER_TIMER_INFO("RecurrentFwBatch", getName().c_str());
     /* forward one batch */
     for (size_t n = 0; n < batchValue_->getNumBatch(); n++) {
-      MatrixPtr batch2 = batchValue_->getBatchValue(n);
+      MatrixPtr batchValue = batchValue_->getBatchValue(n);
 
       if (n != 0) {
-        MatrixPtr batch1 =
-            batchValue_->getBatchValue(n - 1, batch2->getHeight());
-        // batch2->mul(*batch1, *weight_->getW(), 1, 1);
-        packed_weight_->compute(batch2, batch1);
-      }
-
-#pragma omp parallel for collapse(2)
-      for (size_t i = 0; i < batch2->getHeight(); i++) {
-        for (size_t j = 0; j < batch2->getWidth(); j++) {
-          *(batch2->getData() + i * batch2->getWidth() + j) =
-              *(batch2->getData() + i * batch2->getWidth() + j) > 0
-                  ? *(batch2->getData() + i * batch2->getWidth() + j)
-                  : 0;
-        }
+        MatrixPtr preBatchValue =
+            batchValue_->getBatchValue(n - 1, batchValue->getHeight());
+        packed_weight_->compute(batchValue, preBatchValue);
       }
+
+      Argument arg;
+      arg.value = batchValue;
+      activation_->forward(arg).check();
     }
   }
   batchValue_->copyBackSeq(*output_.value);
 }
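Note: the key change in this hunk is dropping the hand-rolled OpenMP ReLU loop in favor of the layer's configured activation_. As a rough standalone illustration (not Paddle code; plain row-major float buffers stand in for Matrix, and ReLU stands in for whatever activation the layer is configured with), the per-batch step that forwardBatch evaluates is out_n = act(in_n + out_{n-1} * W):

// Stand-in for packed_weight_->compute(dst, src): dst(h x w) += src(h x w) * W(w x w).
#include <algorithm>
#include <cstddef>
#include <vector>

void addMulW(std::vector<float> &dst,
             const std::vector<float> &src,
             const std::vector<float> &W,
             size_t h,
             size_t w) {
  for (size_t i = 0; i < h; ++i)
    for (size_t k = 0; k < w; ++k)
      for (size_t j = 0; j < w; ++j)
        dst[i * w + j] += src[i * w + k] * W[k * w + j];
}

// One step of the batched recurrence; `cur` already holds in_n on entry,
// `prev` is empty for the first batch (n == 0).
void forwardStep(std::vector<float> &cur,
                 const std::vector<float> &prev,
                 const std::vector<float> &W,
                 size_t h,
                 size_t w) {
  if (!prev.empty()) addMulW(cur, prev, W, h, w);  // out_n += out_{n-1} * W
  for (float &v : cur) v = std::max(v, 0.0f);      // activation forward (ReLU here)
}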
@@ -94,25 +85,27 @@ void MKLPackedRecurrentLayer::backwardBatch(int batchSize,
     REGISTER_TIMER_INFO("RecurrentBwData", getName().c_str());
     /* backward one batch */
     for (int n = (int)numBatch - 1; n >= 0; n--) {
-      MatrixPtr batch2 = batchGrad_->getBatchValue(n);
-      MatrixPtr batch1 = batchValue_->getBatchValue(n, batch2->getHeight());
+      MatrixPtr batchGrad = batchGrad_->getBatchValue(n);
+      MatrixPtr batchValue =
+          batchValue_->getBatchValue(n, batchGrad->getHeight());
 
       Argument arg;
-      arg.value = batch1;
-      arg.grad = batch2;
+      arg.value = batchValue;
+      arg.grad = batchGrad;
       activation_->backward(arg).check();
 
       if (n != 0) {
-        batch1 = batchGrad_->getBatchValue(n - 1, batch2->getHeight());
-        // batch1->mul(*batch2, *weightT, 1, 1);
-        packed_weightT_->compute(batch1, batch2);
+        batchValue = batchGrad_->getBatchValue(n - 1, batchGrad->getHeight());
+        packed_weightT_->compute(batchValue, batchGrad);
       }
 
       if (backwardByBatch && weight_->getWGrad()) {
         if (n != 0) {
           /* backward weight */
-          batch1 = batchValue_->getBatchValue(n - 1, batch2->getHeight());
-          weight_->getWGrad()->mul(*batch1->getTranspose(), *batch2, 1, 1);
+          batchValue =
+              batchValue_->getBatchValue(n - 1, batchGrad->getHeight());
+          weight_->getWGrad()->mul(
+              *batchValue->getTranspose(), *batchGrad, 1, 1);
         }
       }
     }
@@ -124,19 +117,14 @@ void MKLPackedRecurrentLayer::backwardBatch(int batchSize,
     REGISTER_TIMER_INFO("RecurrentBwWeight", getName().c_str());
     for (size_t seq = 0; seq < numSequences; ++seq) {
       int len = starts[seq + 1] - starts[seq];
-      if (!reversed_) {
-        weight_->getWGrad()->mul(
-            *output_.value->subMatrix(starts[seq], len - 1)->getTranspose(),
-            *output_.grad->subMatrix(starts[seq] + 1, len - 1),
-            1,
-            1);
-      } else {
-        weight_->getWGrad()->mul(
-            *output_.value->subMatrix(starts[seq] + 1, len - 1)->getTranspose(),
-            *output_.grad->subMatrix(starts[seq], len - 1),
-            1,
-            1);
-      }
+      weight_->getWGrad()->mul(
+          *output_.value
+               ->subMatrix(reversed_ ? starts[seq] + 1 : starts[seq], len - 1)
+               ->getTranspose(),
+          *output_.grad->subMatrix(reversed_ ? starts[seq] : starts[seq] + 1,
+                                   len - 1),
+          1,
+          1);
     }
   }
 }
......
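Note: this refactor folds the if (!reversed_) / else branches into a single mul() call by moving the direction into the sub-matrix offsets. A hedged restatement of the index arithmetic (hypothetical helper names, not Paddle API):

// For one sequence covering rows [starts[seq], starts[seq] + len), the weight
// gradient above is
//   forward : dW += value.subMatrix(start,     len - 1)^T * grad.subMatrix(start + 1, len - 1)
//   reversed: dW += value.subMatrix(start + 1, len - 1)^T * grad.subMatrix(start,     len - 1)
// i.e. only the two row offsets swap, which is what the ternaries express.
size_t valueRowOffset(bool reversed, size_t start) { return reversed ? start + 1 : start; }
size_t gradRowOffset(bool reversed, size_t start) { return reversed ? start : start + 1; }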
@@ -14,36 +14,18 @@ limitations under the License. */
 
 #pragma once
 
-#include <gflags/gflags.h>
-#include "Layer.h"
 #include "MKLPackedWeight.h"
 #include "RecurrentLayer.h"
-#include "SequenceToBatch.h"
-#include "paddle/utils/Stat.h"
 
 DECLARE_bool(rnn_use_batch);
 
 namespace paddle {
 
 /**
- * @brief MKLPackedRecurrentLayer takes 1 input layer. The output size is the
- * same with input layer.
- * For each sequence [start, end] it performs the following computation:
- * \f[
- * out_{i} = act(in_{i}) \ \ \text{for} \ i = start \\
- * out_{i} = act(in_{i} + out_{i-1} * W) \ \ \text{for} \ start < i <= end
- * \f]
- * If reversed is true, the order is reversed:
- * \f[
- * out_{i} = act(in_{i}) \ \ \text{for} \ i = end \\
- * out_{i} = act(in_{i} + out_{i+1} * W) \ \ \text{for} \ start <= i < end
- * \f]
- * There are two methods to calculate rnn. One way is to compute rnn one
- * sequence by one sequence. The other way is to reorganize the input
- * into batches, then compute rnn one batch by one batch. Users can select
- * them by the rnn_use_batch flag.
+ * @brief MKLPackedRecurrentLayer is the same as RecurrentLayer but is
+ * optimized with MKL cblas packed gemm.
+ * More details:
+ * https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/mkl/mkl_packed.md
  */
 
 class MKLPackedRecurrentLayer : public RecurrentLayer {
@@ -66,7 +48,10 @@ protected:
                       const int* starts) override;
 
 protected:
+  /// packed_weight_ contains the same data as RecurrentLayer::weight_,
+  /// but in MKL packed format
   std::unique_ptr<MKLPackedWeight> packed_weight_;
+  /// packed_weightT_ is the packed transpose of packed_weight_
   std::unique_ptr<MKLPackedWeight> packed_weightT_;
 };
......
@@ -22,7 +22,9 @@ namespace paddle {
 class MKLPackedWeight {
 protected:
+  /// The pointer to the weight data
   real *weight_;
+  /// The pointer to the weight in cblas packed gemm format
   real *packedWeight_;
   size_t height_;
   size_t width_;
@@ -41,7 +43,7 @@ public:
   void pack() { pack_(weight_); }
 
-  void compute(MatrixPtr dst, MatrixPtr src) {
+  void compute(MatrixPtr dst, const MatrixPtr src) {
     cblas_sgemm_compute(CblasRowMajor,
                         CblasNoTrans,
                         CblasPacked,
@@ -57,22 +59,6 @@ public:
                         dst->getWidth());
   }
 
-  void compute(size_t M, real *A, size_t lda, real *C, size_t ldc) {
-    cblas_sgemm_compute(CblasRowMajor,
-                        CblasNoTrans,
-                        CblasPacked,
-                        M,
-                        width_,
-                        height_,
-                        A,
-                        lda,
-                        packedWeight_,
-                        width_,
-                        1.0,
-                        C,
-                        ldc);
-  }
-
 protected:
   void pack_(real *src) {
     if (!packedWeight_) {
......
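Note: for context on what pack_() and compute() wrap, MKL's packed GEMM API converts one operand into an internal layout once and then reuses it across many multiplications, which is why it suits the recurrent weight shared by every time step. A minimal standalone sketch of that lifecycle (illustrative sizes and values; requires linking against Intel MKL):

// Pack the weight once, reuse it for every compute() call (cf. MKLPackedWeight).
#include <mkl.h>
#include <vector>

int main() {
  const MKL_INT M = 4, N = 8, K = 8;  // M: rows in one batch, N = K: hidden size
  std::vector<float> A(M * K, 1.0f);  // previous-step output (the "src" operand)
  std::vector<float> B(K * N, 0.5f);  // recurrent weight W
  std::vector<float> C(M * N, 0.0f);  // current-step value, accumulated in place

  // Pack B (the weight) into MKL's internal format; done once, like pack_().
  float *packedB = cblas_sgemm_alloc(CblasBMatrix, M, N, K);
  cblas_sgemm_pack(
      CblasRowMajor, CblasBMatrix, CblasNoTrans, M, N, K, 1.0f, B.data(), N, packedB);

  // C = 1.0 * A * packedB + 1.0 * C, the same call shape as compute(dst, src).
  cblas_sgemm_compute(CblasRowMajor,
                      CblasNoTrans,
                      CblasPacked,
                      M,
                      N,
                      K,
                      A.data(),
                      K,
                      packedB,
                      N,
                      1.0f,
                      C.data(),
                      N);

  cblas_sgemm_free(packedB);
  return 0;
}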
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "RecurrentLayer.h"
-#include <gflags/gflags.h>
-#include "Layer.h"
-#include "SequenceToBatch.h"
-#include "paddle/utils/Stat.h"
 
 DEFINE_bool(rnn_use_batch, false, "Using the batch method for calculation.");
......