From c961fbf09a6d3b3ef61a426d725116f3ef510069 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Mon, 20 Nov 2017 17:18:01 +0800 Subject: [PATCH] change the condition to reset the forward in MKLDNNLayer --- paddle/gserver/layers/MKLDNNAddtoLayer.cpp | 3 ++- .../gserver/layers/MKLDNNBatchNormLayer.cpp | 3 ++- paddle/gserver/layers/MKLDNNConcatLayer.cpp | 3 ++- paddle/gserver/layers/MKLDNNConcatLayer.h | 9 ++++++++ paddle/gserver/layers/MKLDNNLayer.cpp | 7 ++---- paddle/gserver/layers/MKLDNNLayer.h | 22 ++++++++++--------- paddle/gserver/layers/MKLDNNPoolLayer.cpp | 3 ++- 7 files changed, 31 insertions(+), 19 deletions(-) diff --git a/paddle/gserver/layers/MKLDNNAddtoLayer.cpp b/paddle/gserver/layers/MKLDNNAddtoLayer.cpp index 0eeea821d21..39bffc26f7d 100644 --- a/paddle/gserver/layers/MKLDNNAddtoLayer.cpp +++ b/paddle/gserver/layers/MKLDNNAddtoLayer.cpp @@ -43,7 +43,8 @@ void MKLDNNAddtoLayer::reshape( reshapeInput(bs, ih, iw); ic = inputLayers_[0]->getSize() / ih / iw; CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize()); - CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw); + CHECK_EQ(inputLayers_[0]->getOutputValue()->getElementCnt(), + (size_t)bs * ic * ih * iw); for (size_t i = 0; i < inputLayers_.size(); i++) { CHECK_EQ(int64_t(bs), inputLayers_[i]->getOutput().getBatchSize()); CHECK_EQ(layerSize_, inputLayers_[i]->getSize()); diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp index 63f9bb27956..d66c361ae05 100644 --- a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp @@ -121,7 +121,8 @@ void MKLDNNBatchNormLayer::reshape( oh = ih; ow = iw; // ic_ and oc can not be changed - CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) + CHECK_EQ((size_t)ic, + inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw) << "Input channel can not be changed"; reshapeOutput(oh, ow); resizeOutput(bs, oc * oh * ow); diff --git a/paddle/gserver/layers/MKLDNNConcatLayer.cpp b/paddle/gserver/layers/MKLDNNConcatLayer.cpp index 8311fe61ae9..44bb0883b89 100644 --- a/paddle/gserver/layers/MKLDNNConcatLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConcatLayer.cpp @@ -36,7 +36,8 @@ void MKLDNNConcatLayer::reshape( reshapeInput(bs, ih, iw); ic = inputLayers_[0]->getSize() / ih / iw; CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize()); - CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw); + CHECK_EQ(inputLayers_[0]->getOutputValue()->getElementCnt(), + (size_t)bs * ic * ih * iw); CHECK_GT(inputLayers_.size(), 1UL); channels_.resize(inputLayers_.size()); channels_[0] = ic; diff --git a/paddle/gserver/layers/MKLDNNConcatLayer.h b/paddle/gserver/layers/MKLDNNConcatLayer.h index f9357a161a3..37f3a26c5ed 100644 --- a/paddle/gserver/layers/MKLDNNConcatLayer.h +++ b/paddle/gserver/layers/MKLDNNConcatLayer.h @@ -66,6 +66,15 @@ public: << ", " << ow_; } + size_t keepCondition() { + // reset when the total element size of all inputs changed + size_t totalSize = inputLayers_[0]->getOutputValue()->getElementCnt(); + for (size_t i = 1; i < inputLayers_.size(); ++i) { + totalSize += inputLayers_[i]->getOutputValue()->getElementCnt(); + } + return totalSize; + } + protected: void resetFwdBuffers(std::vector& inputs, MKLDNNMatrixPtr& out); diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 3c783e7e72e..28969d01a13 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -48,16 +48,13 @@ void MKLDNNLayer::forward(PassType passType) { REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); CHECK(!inputLayers_.empty()); copySeqInfoToOutputs(); - size_t elemenCnt = inputLayers_[0]->getOutputValue()->getElementCnt(); - if (inputElemenCnt_ != elemenCnt) { + if (condition_ != keepCondition()) { VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; - // reset when input total sizes changed, not only the batchsize - inputElemenCnt_ = elemenCnt; + condition_ = keepCondition(); reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); printSizeInfo(); // the output_.value and output_.grad are shared with CPU device shareCPUDevice(); - pipelineFwd_.clear(); inVals_.resize(inputLayers_.size(), nullptr); extInVals_.resize(inputLayers_.size(), nullptr); diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 532e66d9788..907927f984f 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -34,8 +34,6 @@ typedef std::shared_ptr MKLDNNLayerPtr; */ class MKLDNNLayer : public Layer { protected: - // input value element count - size_t inputElemenCnt_; // batch size int bs_; // they sizes are always from the first input layer @@ -44,6 +42,8 @@ protected: // output image channel, height and width int oc_, oh_, ow_; + // the condition that forward need be reset + size_t condition_; // backward also need reset after reset forward handle bool needResetBwd_; @@ -103,14 +103,7 @@ protected: public: explicit MKLDNNLayer(const LayerConfig& config) : Layer(config), - inputElemenCnt_(0), - bs_(0), - ic_(0), - ih_(0), - iw_(0), - oc_(0), - oh_(0), - ow_(0), + condition_(0), needResetBwd_(true), outputOnlyMKLDNN_(false), engine_(mkldnn::engine::cpu, 0), @@ -173,6 +166,15 @@ public: void addOutputArgument(int deviceId) { Layer::addOutputArgument(deviceId); } protected: + /** + * Some layers may have different condition to reset the forward. + * The function returns the condition that do not need reset forward. + */ + inline virtual size_t keepCondition() { + // reset when the first input element size changed, not only the batchsize + return inputLayers_[0]->getOutputValue()->getElementCnt(); + } + /** * reshape the input image sizes and input batchsize */ diff --git a/paddle/gserver/layers/MKLDNNPoolLayer.cpp b/paddle/gserver/layers/MKLDNNPoolLayer.cpp index 86122f93c5c..a8252593c8f 100644 --- a/paddle/gserver/layers/MKLDNNPoolLayer.cpp +++ b/paddle/gserver/layers/MKLDNNPoolLayer.cpp @@ -61,7 +61,8 @@ void MKLDNNPoolLayer::reshape( int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) { reshapeInput(bs, ih, iw); // ic_ and oc can not be changed - CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) + CHECK_EQ((size_t)ic, + inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw) << "Input channel can not be changed"; // cal output sizes -- GitLab