From 82091035514c0ddeae2c18ff5f523a2647d59948 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Fri, 22 Dec 2017 13:43:25 +0800
Subject: [PATCH] follow comments and refine code

---
 paddle/gserver/layers/MKLPackedGemm.h         |  95 ---------
 .../layers/MKLPackedRecurrentLayer.cpp        | 191 ++----------------
 .../gserver/layers/MKLPackedRecurrentLayer.h  |  87 ++------
 paddle/gserver/layers/MKLPackedWeight.h       | 100 +++++++++
 4 files changed, 125 insertions(+), 348 deletions(-)
 delete mode 100644 paddle/gserver/layers/MKLPackedGemm.h
 create mode 100644 paddle/gserver/layers/MKLPackedWeight.h
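
Notes (this text sits between the "---" separator and the first "diff --git"
line, so it is skipped when the patch is applied with git am / git apply):

This change folds the one-off MKLPackedGemm helper into a reusable
MKLPackedWeight class and rebases MKLPackedRecurrentLayer onto
RecurrentLayer, so the sequence-mode forward/backward paths and the state
handling (resetState/setState/getState) are inherited instead of being
duplicated; only the batch-mode paths that benefit from MKL's packed GEMM
are overridden. MKLPackedGemm packed both W and W^T unconditionally, while
MKLPackedWeight packs a single matrix and takes transW in its constructor,
so the transposed pack is only built when needGradient_ is set. Re-packing
now also happens in place after each weight update instead of re-allocating
a whole new packer object in every backward pass.

For reference, below is a minimal standalone sketch of the MKL packed-GEMM
call sequence that MKLPackedWeight wraps. It assumes Intel MKL's cblas
packing extensions (cblas_sgemm_alloc/pack/compute/free); the function name
packed_gemm_sketch and its arguments are illustrative only, not part of the
patch:

    // C(m x n) += A(m x k) * B(k x n), with B packed once and reused.
    #include <mkl.h>

    void packed_gemm_sketch(const float *A, const float *B, float *C,
                            int m, int n, int k) {
      // Pack B once; the packed buffer can then serve many A operands
      // (e.g. one GEMM per RNN time step), amortizing the relayout cost.
      float *Bp = cblas_sgemm_alloc(CblasBMatrix, m, n, k);
      cblas_sgemm_pack(CblasRowMajor, CblasBMatrix, CblasNoTrans,
                       m, n, k, /*alpha=*/1.0f, B, /*ldb=*/n, Bp);

      // beta = 1.0f accumulates into C, matching the layer's
      // "output += prev_output * W" update.
      cblas_sgemm_compute(CblasRowMajor, CblasNoTrans, CblasPacked,
                          m, n, k, A, /*lda=*/k, Bp, /*ldb=*/n,
                          /*beta=*/1.0f, C, /*ldc=*/n);

      cblas_sgemm_free(Bp);
    }
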
diff --git a/paddle/gserver/layers/MKLPackedGemm.h b/paddle/gserver/layers/MKLPackedGemm.h
deleted file mode 100644
index 91e2515e3..000000000
--- a/paddle/gserver/layers/MKLPackedGemm.h
+++ /dev/null
@@ -1,95 +0,0 @@
-/* Copyright (c) 2017 PaddlePaddle Authors. All Rights Reserve.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#pragma once
-
-#include "paddle/math/MathFunctions.h"
-#include "paddle/math/Matrix.h"
-
-namespace paddle {
-
-class MKLPackedGemm {
-protected:
-  real* weightPacked_;
-  real* weightTPacked_;
-  size_t weightHeight_;
-  size_t weightWidth_;
-
-public:
-  explicit MKLPackedGemm(MatrixPtr weight) {
-    weightHeight_ = weight->getHeight();
-    weightWidth_ = weight->getWidth();
-    weightPacked_ =
-        cblas_sgemm_alloc(CblasBMatrix, 1, weightWidth_, weightHeight_);
-    weightTPacked_ =
-        cblas_sgemm_alloc(CblasBMatrix, 1, weightWidth_, weightHeight_);
-    cblas_sgemm_pack(CblasRowMajor,
-                     CblasBMatrix,
-                     CblasNoTrans,
-                     1,
-                     weightWidth_,
-                     weightHeight_,
-                     1.0,
-                     weight->getData(),
-                     weightWidth_,
-                     weightPacked_);
-    cblas_sgemm_pack(CblasRowMajor,
-                     CblasBMatrix,
-                     CblasTrans,
-                     1,
-                     weightWidth_,
-                     weightHeight_,
-                     1.0,
-                     weight->getData(),
-                     weightWidth_,
-                     weightTPacked_);
-  }
-  void compute(MatrixPtr batch2, MatrixPtr batch1, bool transW = false) {
-    if (transW) {
-      cblas_sgemm_compute(CblasRowMajor,
-                          CblasNoTrans,
-                          CblasPacked,
-                          batch2->getHeight(),
-                          weightWidth_,
-                          weightHeight_,
-                          batch1->getData(),
-                          weightHeight_,
-                          weightTPacked_,
-                          weightWidth_,
-                          1,
-                          batch2->getData(),
-                          weightWidth_);
-    } else {
-      cblas_sgemm_compute(CblasRowMajor,
-                          CblasNoTrans,
-                          CblasPacked,
-                          batch2->getHeight(),
-                          weightWidth_,
-                          weightHeight_,
-                          batch1->getData(),
-                          weightHeight_,
-                          weightPacked_,
-                          weightWidth_,
-                          1,
-                          batch2->getData(),
-                          weightWidth_);
-    }
-  }
-  ~MKLPackedGemm() {
-    cblas_sgemm_free(weightPacked_);
-    cblas_sgemm_free(weightTPacked_);
-  }
-};
-
-}  // namespace paddle
diff --git a/paddle/gserver/layers/MKLPackedRecurrentLayer.cpp b/paddle/gserver/layers/MKLPackedRecurrentLayer.cpp
index 6f455af91..bd3c4ceb5 100644
--- a/paddle/gserver/layers/MKLPackedRecurrentLayer.cpp
+++ b/paddle/gserver/layers/MKLPackedRecurrentLayer.cpp
@@ -20,188 +20,21 @@ REGISTER_LAYER(mkl_packed_recurrent, MKLPackedRecurrentLayer);
 
 bool MKLPackedRecurrentLayer::init(const LayerMap& layerMap,
                                    const ParameterMap& parameterMap) {
-  if (!Layer::init(layerMap, parameterMap)) return false;
-  CHECK_EQ(1U, inputLayers_.size());
-  CHECK_EQ(1U, parameters_.size());
-  CHECK_EQ(getSize() * getSize(), parameters_[0]->getSize());
-  weight_.reset(new Weight(getSize(), getSize(), parameters_[0]));
-  if (biasParameter_.get() != NULL) {
-    bias_.reset(new Weight(1, getSize(), biasParameter_));
+  if (!RecurrentLayer::init(layerMap, parameterMap)) return false;
+  packed_weight_.reset(new MKLPackedWeight(weight_->getW()));
+  packed_weight_->pack();
+  if (needGradient_) {
+    packed_weightT_.reset(new MKLPackedWeight(weight_->getW(), true));
+    packed_weightT_->pack();
   }
-  reversed_ = config_.reversed();
-
-  sgemm_packed_.reset(new MKLPackedGemm(weight_->getW()));
-
   return true;
 }
 
-void MKLPackedRecurrentLayer::resetState() {
-  CHECK(!reversed_) << "state is not allowed for reversed recurrent layer";
-  Matrix::resizeOrCreate(
-      prevOutput_, 1, getSize(), /* trans= */ false, useGpu_);
-  prevOutput_->zeroMem();
-}
-
-void MKLPackedRecurrentLayer::setState(LayerStatePtr state) {
-  CHECK(state->value.size() == 1) << "one matrix is expected for RNN state";
-  prevOutput_->copyFrom(*(state->value[0]));
-}
-
-LayerStatePtr MKLPackedRecurrentLayer::getState() {
-  LayerStatePtr res = std::make_shared<LayerState>();
-  res->value.push_back(prevOutput_->clone(0, 0, useGpu_));
-  res->value[0]->copyFrom(*prevOutput_);
-  return res;
-}
-
-void MKLPackedRecurrentLayer::forward(PassType passType) {
-  REGISTER_TIMER_INFO("RecurrentFwTimer", getName().c_str());
-  Layer::forward(passType);
-  const Argument& input = getInput(0);
-  CHECK(input.sequenceStartPositions);
-  int batchSize = input.getBatchSize();
-  size_t numSequences = input.getNumSequences();
-  resetOutput(batchSize, getSize());
-  CHECK_EQ(getSize(), input.value->getWidth());
-  const int* starts = input.sequenceStartPositions->getData(false);
-  CHECK_EQ(starts[numSequences], batchSize);
-
-  output_.value->assign(*input.value);
-  if (bias_) {
-    output_.value->addBias(*bias_->getW(), 1);
-  }
-  if (!FLAGS_rnn_use_batch) {
-    forwardSequence(batchSize, numSequences, starts);
-  } else {
-    forwardBatch(batchSize, numSequences, starts);
-  }
-}
-
-void MKLPackedRecurrentLayer::forwardSequence(int batchSize,
-                                              size_t numSequences,
-                                              const int* starts) {
-  REGISTER_TIMER_INFO("RecurrentFwSequence", getName().c_str());
-
-  frameOutput_.reserve(batchSize);
-  for (int i = frameOutput_.size(); i < batchSize; ++i) {
-    Argument arg;
-    arg.value = Matrix::create(nullptr,
-                               /* height= */ 1,
-                               getSize(),
-                               /* trans= */ false,
-                               useGpu_);
-    arg.grad = Matrix::create(nullptr,
-                              /* height= */ 1,
-                              getSize(),
-                              /* trans= */ false,
-                              useGpu_);
-    frameOutput_.push_back(arg);
-  }
-
-  for (int i = 0; i < batchSize; ++i) {
-    frameOutput_[i].value->setData(output_.value->getData() + i * getSize());
-  }
-
-  for (size_t i = 0; i < numSequences; ++i) {
-    forwardOneSequence(starts[i], starts[i + 1] - starts[i]);
-  }
-}
-
-void MKLPackedRecurrentLayer::forwardOneSequence(int start, int length) {
-  if (!reversed_) {
-    if (prevOutput_) {
-      frameOutput_[start].value->mul(*prevOutput_, *weight_->getW(), 1, 1);
-    }
-    activation_->forward(frameOutput_[start]).check();
-
-    for (int i = 1; i < length; ++i) {
-      frameOutput_[start + i].value->mul(
-          *frameOutput_[start + i - 1].value, *weight_->getW(), 1, 1);
-      activation_->forward(frameOutput_[start + i]).check();
-    }
-    if (prevOutput_) {
-      prevOutput_->assign(*frameOutput_[start + length - 1].value);
-    }
-  } else {
-    activation_->forward(frameOutput_[start + length - 1]).check();
-    for (int i = length - 2; i >= 0; --i) {
-      frameOutput_[start + i].value->mul(
-          *frameOutput_[start + i + 1].value, *weight_->getW(), 1, 1);
-      activation_->forward(frameOutput_[start + i]).check();
-    }
-  }
-}
-
 void MKLPackedRecurrentLayer::backward(const UpdateCallback& callback) {
-  REGISTER_TIMER_INFO("RecurrentBwTimer", getName().c_str());
-  const Argument& input = getInput(0);
-  CHECK(input.sequenceStartPositions);
-  int batchSize = input.getBatchSize();
-  const int* starts = input.sequenceStartPositions->getData(false);
-  size_t numSequences = input.getNumSequences();
-
-  if (!FLAGS_rnn_use_batch) {
-    backwardSequence(batchSize, numSequences, starts);
-  } else {
-    backwardBatch(batchSize, numSequences, starts);
-  }
-
-  if (input.grad) {
-    input.grad->add(*output_.grad);
-  }
-
-  if (bias_ && bias_->getWGrad()) {
-    bias_->getWGrad()->collectBias(*output_.grad, 1);
-    bias_->getParameterPtr()->incUpdate(callback);
-  }
-
-  weight_->getParameterPtr()->incUpdate(callback);
-  sgemm_packed_.reset(new MKLPackedGemm(weight_->getW()));
-}
-
-void MKLPackedRecurrentLayer::backwardSequence(int batchSize,
-                                               size_t numSequences,
-                                               const int* starts) {
-  REGISTER_TIMER_INFO("RecurrentBwSequence", getName().c_str());
-  for (int i = 0; i < batchSize; ++i) {
-    frameOutput_[i].grad->setData(output_.grad->getData() + i * getSize());
-  }
-
-  for (size_t i = 0; i < numSequences; ++i) {
-    backwardOneSequence(starts[i], starts[i + 1] - starts[i]);
-  }
-}
-
-void MKLPackedRecurrentLayer::backwardOneSequence(int start, int length) {
-  MatrixPtr weightT = weight_->getW()->getTranspose();
-  if (!reversed_) {
-    for (int i = length - 1; i > 0; --i) {
-      activation_->backward(frameOutput_[start + i]).check();
-      frameOutput_[start + i - 1].grad->mul(
-          *frameOutput_[start + i].grad, *weightT, 1, 1);
-    }
-    activation_->backward(frameOutput_[start]).check();
-    if (weight_->getWGrad()) {
-      weight_->getWGrad()->mul(
-          *output_.value->subMatrix(start, length - 1)->getTranspose(),
-          *output_.grad->subMatrix(start + 1, length - 1),
-          1,
-          1);
-    }
-  } else {
-    for (int i = 0; i < length - 1; ++i) {
-      activation_->backward(frameOutput_[start + i]).check();
-      frameOutput_[start + i + 1].grad->mul(
-          *frameOutput_[start + i].grad, *weightT, 1, 1);
-    }
-    activation_->backward(frameOutput_[start + length - 1]).check();
-    if (weight_->getWGrad()) {
-      weight_->getWGrad()->mul(
-          *output_.value->subMatrix(start + 1, length - 1)->getTranspose(),
-          *output_.grad->subMatrix(start, length - 1),
-          1,
-          1);
-    }
+  RecurrentLayer::backward(callback);
+  packed_weight_->pack();
+  if (needGradient_) {
+    packed_weightT_->pack();
   }
 }
 
@@ -227,7 +60,7 @@ void MKLPackedRecurrentLayer::forwardBatch(int batchSize,
             batchValue_->getBatchValue(n - 1, batch2->getHeight());
 
         // batch2->mul(*batch1, *weight_->getW(), 1, 1);
-        sgemm_packed_->compute(batch2, batch1);
+        packed_weight_->compute(batch2, batch1);
       }
 
 #pragma omp parallel for collapse(2)
@@ -272,7 +105,7 @@ void MKLPackedRecurrentLayer::backwardBatch(int batchSize,
     if (n != 0) {
       batch1 = batchGrad_->getBatchValue(n - 1, batch2->getHeight());
 
       // batch1->mul(*batch2, *weightT, 1, 1);
-      sgemm_packed_->compute(batch1, batch2, true);
+      packed_weightT_->compute(batch1, batch2);
     }
 
     if (backwardByBatch && weight_->getWGrad()) {
diff --git a/paddle/gserver/layers/MKLPackedRecurrentLayer.h b/paddle/gserver/layers/MKLPackedRecurrentLayer.h
index b8727e0ff..ba6487b11 100644
--- a/paddle/gserver/layers/MKLPackedRecurrentLayer.h
+++ b/paddle/gserver/layers/MKLPackedRecurrentLayer.h
@@ -16,7 +16,8 @@ limitations under the License. */
 
 #include <gflags/gflags.h>
 #include "Layer.h"
-#include "MKLPackedGemm.h"
+#include "MKLPackedWeight.h"
+#include "RecurrentLayer.h"
 #include "SequenceToBatch.h"
 #include "paddle/utils/Stat.h"
 
@@ -45,90 +46,28 @@ namespace paddle {
  * them by rnn_use_batch flag.
  */
 
-class MKLPackedRecurrentLayer : public Layer {
+class MKLPackedRecurrentLayer : public RecurrentLayer {
 public:
-  explicit MKLPackedRecurrentLayer(const LayerConfig& config) : Layer(config) {}
+  explicit MKLPackedRecurrentLayer(const LayerConfig& config)
+      : RecurrentLayer(config) {}
 
   bool init(const LayerMap& layerMap,
             const ParameterMap& parameterMap) override;
 
-  void forward(PassType passType) override;
-
   void backward(const UpdateCallback& callback) override;
 
-  void resetState() override;
-
-  void setState(LayerStatePtr state) override;
-
-  LayerStatePtr getState() override;
-
 protected:
-  /**
-   * @brief If user do not set --rnn_use_batch=true, it will
-   * compute rnn forward one sequence by one sequence in default.
-   * @param batchSize Total words number of all samples in this batch.
-   * @param numSequences The sample number.
-   * @param starts Each start position of each samples.
-   */
-  void forwardSequence(int batchSize, size_t numSequences, const int* starts);
-  /**
-   * @brief Compute rnn forward by one sequence.
-   * @param start The start position of this sequence (or sample).
-   * @param length The length of this sequence (or sample), namely the words
-   * number of this sequence.
-   */
-  void forwardOneSequence(int start, int length);
-  /**
-   * @brief Compute rnn backward one sequence by one sequence.
-   * @param batchSize Total words number of all samples in this batch.
-   * @param numSequences The sample number.
-   * @param starts Each start position of each samples.
-   */
-  void backwardSequence(int batchSize, size_t numSequences, const int* starts);
-  /**
-   * @brief Compute rnn backward by one sequence.
-   * @param start The start position of this sequence (or sample).
-   * @param length The length of this sequence (or sample), namely the words
-   * number of this sequence.
-   */
-  void backwardOneSequence(int start, int length);
+  void forwardBatch(int batchSize,
+                    size_t numSequences,
+                    const int* starts) override;
 
-  /**
-   * @brief Reorganize input into batches and compute rnn forward batch
-   * by batch. It will convert batch shape to sequence after finishing forward.
-   * The batch info can refer to SequenceToBatch class.
-   * @param batchSize Total words number of all samples in this batch.
-   * @param numSequences The sample number.
-   * @param starts Each start position of each samples.
-   */
-  void forwardBatch(int batchSize, size_t numSequences, const int* starts);
-
-  /**
-   * @brief Reorganize input into batches and compute rnn forward batch
-   * by batch.
-   * @param batchSize Total words number of all samples in this batch.
-   * @param numSequences The sample number.
-   * @param starts Each start position of each samples.
-   */
-  void backwardBatch(int batchSize, size_t numSequences, const int* starts);
+  void backwardBatch(int batchSize,
+                     size_t numSequences,
+                     const int* starts) override;
 
 protected:
-  std::unique_ptr<Weight> weight_;
-  std::unique_ptr<Weight> bias_;
-
-  /// frameOutput_[i] is used to hold the i-th sample of output_
-  std::vector<Argument> frameOutput_;
-  MatrixPtr prevOutput_;
-  /// Whether compute rnn by reverse.
-  bool reversed_;
-  /// If compute batch by batch, batchValue_ will be used to save the
-  /// reorganized input value.
-  std::unique_ptr<SequenceToBatch> batchValue_;
-  /// If compute batch by batch, batchGrad_ will be used to save the
-  /// gradient with respect to reorganized input value.
-  std::unique_ptr<SequenceToBatch> batchGrad_;
-
-  std::unique_ptr<MKLPackedGemm> sgemm_packed_;
+  std::unique_ptr<MKLPackedWeight> packed_weight_;
+  std::unique_ptr<MKLPackedWeight> packed_weightT_;
 };
 
 }  // namespace paddle
diff --git a/paddle/gserver/layers/MKLPackedWeight.h b/paddle/gserver/layers/MKLPackedWeight.h
new file mode 100644
index 000000000..a8dcfd561
--- /dev/null
+++ b/paddle/gserver/layers/MKLPackedWeight.h
@@ -0,0 +1,100 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/math/MathFunctions.h"
+#include "paddle/parameter/Parameter.h"
+#include "paddle/parameter/Weight.h"
+
+namespace paddle {
+
+class MKLPackedWeight {
+protected:
+  real *weight_;
+  real *packedWeight_;
+  size_t height_;
+  size_t width_;
+  bool transW_;
+
+public:
+  explicit MKLPackedWeight(MatrixPtr weight, bool transW = false) {
+    packedWeight_ = nullptr;
+    weight_ = weight->getData();
+    height_ = weight->getHeight();
+    width_ = weight->getWidth();
+    transW_ = transW;
+  }
+
+  ~MKLPackedWeight() { free_(); }
+
+  void pack() { pack_(weight_); }
+
+  void compute(MatrixPtr dst, MatrixPtr src) {
+    cblas_sgemm_compute(CblasRowMajor,
+                        CblasNoTrans,
+                        CblasPacked,
+                        src->getHeight(),
+                        transW_ ? height_ : width_,
+                        transW_ ? width_ : height_,
+                        src->getData(),
+                        src->getWidth(),
+                        packedWeight_,
+                        width_,
+                        1.0,
+                        dst->getData(),
+                        dst->getWidth());
+  }
+
+  void compute(size_t M, real *A, size_t lda, real *C, size_t ldc) {
+    cblas_sgemm_compute(CblasRowMajor,
+                        CblasNoTrans,
+                        CblasPacked,
+                        M,
+                        width_,
+                        height_,
+                        A,
+                        lda,
+                        packedWeight_,
+                        width_,
+                        1.0,
+                        C,
+                        ldc);
+  }
+
+protected:
+  void pack_(real *src) {
+    if (!packedWeight_) {
+      packedWeight_ = cblas_sgemm_alloc(CblasBMatrix, 1, width_, height_);
+    }
+    cblas_sgemm_pack(CblasRowMajor,
+                     CblasBMatrix,
+                     transW_ ? CblasTrans : CblasNoTrans,
+                     1,
+                     transW_ ? height_ : width_,
+                     transW_ ? width_ : height_,
+                     1.0,
+                     src,
+                     width_,
+                     packedWeight_);
+  }
+
+  void free_() {
+    if (packedWeight_) {
+      cblas_sgemm_free(packedWeight_);
+    }
+  }
+};
+
+}  // namespace paddle
-- 
GitLab
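
Usage note (not part of the patch): the intended calling pattern is pack
once per weight update, then reuse the packed buffer for one GEMM per
recurrent step, which is what MKLPackedRecurrentLayer::forwardBatch() does
above. A minimal sketch follows; the driver function packedForwardSketch
and the batchValues vector are hypothetical names for illustration, while
MKLPackedWeight and MatrixPtr come from the patch and the Paddle headers it
includes:

    #include <vector>

    #include "paddle/gserver/layers/MKLPackedWeight.h"

    namespace paddle {

    // Hypothetical driver: pack W once, then issue one packed GEMM per
    // time-step batch, accumulating each batch into the next.
    void packedForwardSketch(MatrixPtr weightW,
                             std::vector<MatrixPtr> &batchValues) {
      MKLPackedWeight packed(weightW);  // transW = false: multiply by W itself
      packed.pack();                    // re-run whenever W's data changes

      for (size_t n = 1; n < batchValues.size(); ++n) {
        // batchValues[n] += batchValues[n - 1] * W via the packed buffer
        packed.compute(batchValues[n], batchValues[n - 1]);
      }
    }

    }  // namespace paddle

Packing is a one-time relayout of W, so it pays off precisely because the
same W is multiplied against every time-step batch; repacking in place in
pack_() also avoids the alloc/free churn of the deleted MKLPackedGemm,
which was re-constructed after every weight update in backward().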