diff --git a/paddle/gserver/layers/MKLPackedRecurrentLayer.cpp b/paddle/gserver/layers/MKLPackedRecurrentLayer.cpp index b4a6413048327c2e6d709e1a972354eb6dbd9aa7..dd75555fae134664d92ba9f8ffdea8af78166b7e 100644 --- a/paddle/gserver/layers/MKLPackedRecurrentLayer.cpp +++ b/paddle/gserver/layers/MKLPackedRecurrentLayer.cpp @@ -59,7 +59,7 @@ void MKLPackedRecurrentLayer::forwardBatch(int batchSize, MatrixPtr preBatchValue = batchValue_->getBatchValue(n - 1, batchValue->getHeight()); - packed_weight_->compute(batchValue, preBatchValue); + packed_weight_->gemm_compute(preBatchValue, batchValue); } Argument arg; arg.value = batchValue; @@ -96,7 +96,7 @@ void MKLPackedRecurrentLayer::backwardBatch(int batchSize, if (n != 0) { batchValue = batchGrad_->getBatchValue(n - 1, batchGrad->getHeight()); - packed_weightT_->compute(batchValue, batchGrad); + packed_weightT_->gemm_compute(batchGrad, batchValue); } if (backwardByBatch && weight_->getWGrad()) { diff --git a/paddle/gserver/layers/MKLPackedRecurrentLayer.h b/paddle/gserver/layers/MKLPackedRecurrentLayer.h index 19874d538e919ce484a444593fd831b2138c4938..bded523a8fbd6ff18f28859bd2a1bf3c1a25e2a0 100644 --- a/paddle/gserver/layers/MKLPackedRecurrentLayer.h +++ b/paddle/gserver/layers/MKLPackedRecurrentLayer.h @@ -22,8 +22,8 @@ DECLARE_bool(rnn_use_batch); namespace paddle { /** - * @brief MKLPackedRecurrentLayer is same with RecurrentLayer but is optimized - * with MKL cblas packed gemm. + * @brief MKLPackedRecurrentLayer is almost the same with RecurrentLayer + * but is optimized with MKL cblas packed gemm. * More details: * https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/mkl/mkl_packed.md */ @@ -48,7 +48,7 @@ protected: const int* starts) override; protected: - /// packed_weight_ is contains same data with + /// packed_weight_ contains same data with /// RecurrentLayer::weight_ but is packed std::unique_ptr packed_weight_; /// packed_weightT_ is the transposition matrix of packed_weight_ diff --git a/paddle/gserver/layers/MKLPackedWeight.h b/paddle/gserver/layers/MKLPackedWeight.h index f77aa4dbbf5ddb0534c91c06b08ae0c00d592883..15d5093beb43e2f086601c2616ace033da34f341 100644 --- a/paddle/gserver/layers/MKLPackedWeight.h +++ b/paddle/gserver/layers/MKLPackedWeight.h @@ -22,9 +22,9 @@ namespace paddle { class MKLPackedWeight { protected: - /// The pointor of weight + /// The pointer of weight real *weight_; - /// The pointor of cblas packed gemm to weight + /// The pointer of cblas packed gemm to weight real *packedWeight_; size_t height_; size_t width_; @@ -43,7 +43,7 @@ public: void pack() { pack_(weight_); } - void compute(MatrixPtr dst, const MatrixPtr src) { + void gemm_compute(const MatrixPtr src, MatrixPtr dst) { cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked,