Commit 89cb3a24 authored by tensor-tang

follow comments, refine comment and function name

Parent adf79faa
@@ -59,7 +59,7 @@ void MKLPackedRecurrentLayer::forwardBatch(int batchSize,
       MatrixPtr preBatchValue =
           batchValue_->getBatchValue(n - 1, batchValue->getHeight());
-      packed_weight_->compute(batchValue, preBatchValue);
+      packed_weight_->gemm_compute(preBatchValue, batchValue);
     }
     Argument arg;
     arg.value = batchValue;
@@ -96,7 +96,7 @@ void MKLPackedRecurrentLayer::backwardBatch(int batchSize,
     if (n != 0) {
       batchValue = batchGrad_->getBatchValue(n - 1, batchGrad->getHeight());
-      packed_weightT_->compute(batchValue, batchGrad);
+      packed_weightT_->gemm_compute(batchGrad, batchValue);
     }
     if (backwardByBatch && weight_->getWGrad()) {
...
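Note that besides the rename, both call sites swap the argument order: compute(dst, src) becomes gemm_compute(src, dst), so calls now read source-first, destination-last. A sketch of the underlying MKL packed-GEMM API appears after the last hunk below.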
@@ -22,8 +22,8 @@ DECLARE_bool(rnn_use_batch);
 namespace paddle {

 /**
- * @brief MKLPackedRecurrentLayer is same with RecurrentLayer but is optimized
- * with MKL cblas packed gemm.
+ * @brief MKLPackedRecurrentLayer is almost the same with RecurrentLayer
+ * but is optimized with MKL cblas packed gemm.
  * More details:
  * https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/mkl/mkl_packed.md
  */
@@ -48,7 +48,7 @@ protected:
                      const int* starts) override;

 protected:
-  /// packed_weight_ is contains same data with
+  /// packed_weight_ contains same data with
   /// RecurrentLayer::weight_ but is packed
   std::unique_ptr<MKLPackedWeight> packed_weight_;
   /// packed_weightT_ is the transposition matrix of packed_weight_
...
@@ -22,9 +22,9 @@ namespace paddle {

 class MKLPackedWeight {
 protected:
-  /// The pointor of weight
+  /// The pointer of weight
   real *weight_;
-  /// The pointor of cblas packed gemm to weight
+  /// The pointer of cblas packed gemm to weight
   real *packedWeight_;
   size_t height_;
   size_t width_;
@@ -43,7 +43,7 @@ public:
   void pack() { pack_(weight_); }

-  void compute(MatrixPtr dst, const MatrixPtr src) {
+  void gemm_compute(const MatrixPtr src, MatrixPtr dst) {
     cblas_sgemm_compute(CblasRowMajor,
                         CblasNoTrans,
                         CblasPacked,
...
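For context on what the renamed gemm_compute wraps: MKL's packed GEMM splits a multiplication into a one-time pack of the weight matrix and a cheap compute call that reuses the packed buffer, which is what makes it attractive for an RNN where the same weight multiplies every time step. Below is a minimal standalone sketch of that API; the shapes and fill values are invented for illustration, and beta = 1.0f assumes the layer accumulates into its output (the hunk above is truncated before the trailing arguments).

```cpp
#include <mkl.h>

#include <cstdio>
#include <vector>

int main() {
  // Invented shapes: dst (m x n) += src (m x k) * W (k x n).
  const MKL_INT m = 2, n = 3, k = 4;
  std::vector<float> src(m * k, 1.0f);  // stand-in for a batch of hidden states
  std::vector<float> W(k * n, 0.5f);    // stand-in for the recurrent weight
  std::vector<float> dst(m * n, 0.0f);  // output, accumulated into (beta = 1)

  // Pack the weight once (the expensive, cache-friendly reordering);
  // this corresponds to MKLPackedWeight::pack().
  float *packedW = cblas_sgemm_alloc(CblasBMatrix, m, n, k);
  cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans,
                   m, n, k, /*alpha=*/1.0f, W.data(), n, packedW);

  // Reuse the packed buffer for every multiplication; the B operand is
  // flagged CblasPacked, matching the call in the diff above.
  cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked,
                      m, n, k, src.data(), k, packedW, n,
                      /*beta=*/1.0f, dst.data(), n);

  cblas_sgemm_free(packedW);
  std::printf("dst[0] = %.1f\n", dst[0]);  // 4 * (1.0 * 0.5) = 2.0
  return 0;
}
```

Packing pays off when the same weight is reused across many compute calls, as in a recurrent layer unrolled over time; for a one-shot multiplication, plain cblas_sgemm is usually the better choice.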