diff --git a/paddle/gserver/layers/MKLDNNAddtoLayer.cpp b/paddle/gserver/layers/MKLDNNAddtoLayer.cpp index 0eeea821d21d2824f904b9fa82c1c1b8ebd96df9..39bffc26f7ddcd159130c492115b41080e32ce7f 100644 --- a/paddle/gserver/layers/MKLDNNAddtoLayer.cpp +++ b/paddle/gserver/layers/MKLDNNAddtoLayer.cpp @@ -43,7 +43,8 @@ void MKLDNNAddtoLayer::reshape( reshapeInput(bs, ih, iw); ic = inputLayers_[0]->getSize() / ih / iw; CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize()); - CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw); + CHECK_EQ(inputLayers_[0]->getOutputValue()->getElementCnt(), + (size_t)bs * ic * ih * iw); for (size_t i = 0; i < inputLayers_.size(); i++) { CHECK_EQ(int64_t(bs), inputLayers_[i]->getOutput().getBatchSize()); CHECK_EQ(layerSize_, inputLayers_[i]->getSize()); diff --git a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp index 63f9bb27956aefcede234d5af55a46d8cd7e8686..d66c361ae05e4a1089786e4620d2eb2218a8a29c 100644 --- a/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp +++ b/paddle/gserver/layers/MKLDNNBatchNormLayer.cpp @@ -121,7 +121,8 @@ void MKLDNNBatchNormLayer::reshape( oh = ih; ow = iw; // ic_ and oc can not be changed - CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) + CHECK_EQ((size_t)ic, + inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw) << "Input channel can not be changed"; reshapeOutput(oh, ow); resizeOutput(bs, oc * oh * ow); diff --git a/paddle/gserver/layers/MKLDNNConcatLayer.cpp b/paddle/gserver/layers/MKLDNNConcatLayer.cpp index 8311fe61ae96ec669e5f930c614f0a150b9ad30e..44bb0883b89c712d70e2d4fdfe16bdfde86f81b7 100644 --- a/paddle/gserver/layers/MKLDNNConcatLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConcatLayer.cpp @@ -36,7 +36,8 @@ void MKLDNNConcatLayer::reshape( reshapeInput(bs, ih, iw); ic = inputLayers_[0]->getSize() / ih / iw; CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize()); - CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw); + 
CHECK_EQ(inputLayers_[0]->getOutputValue()->getElementCnt(), + (size_t)bs * ic * ih * iw); CHECK_GT(inputLayers_.size(), 1UL); channels_.resize(inputLayers_.size()); channels_[0] = ic; diff --git a/paddle/gserver/layers/MKLDNNConcatLayer.h b/paddle/gserver/layers/MKLDNNConcatLayer.h index f9357a161a34166d3fbb9724375134d95626e1e1..37f3a26c5ed5db10cdba507368874c9557fb75ef 100644 --- a/paddle/gserver/layers/MKLDNNConcatLayer.h +++ b/paddle/gserver/layers/MKLDNNConcatLayer.h @@ -66,6 +66,15 @@ public: << ", " << ow_; } + size_t keepCondition() { + // reset when the total element size of all inputs changed + size_t totalSize = inputLayers_[0]->getOutputValue()->getElementCnt(); + for (size_t i = 1; i < inputLayers_.size(); ++i) { + totalSize += inputLayers_[i]->getOutputValue()->getElementCnt(); + } + return totalSize; + } + protected: void resetFwdBuffers(std::vector& inputs, MKLDNNMatrixPtr& out); diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 3c783e7e72ec95e2b81b1277992569d2130f226e..28969d01a13b7831794cef856af11ad2ec01c31e 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -48,16 +48,13 @@ void MKLDNNLayer::forward(PassType passType) { REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); CHECK(!inputLayers_.empty()); copySeqInfoToOutputs(); - size_t elemenCnt = inputLayers_[0]->getOutputValue()->getElementCnt(); - if (inputElemenCnt_ != elemenCnt) { + if (condition_ != keepCondition()) { VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; - // reset when input total sizes changed, not only the batchsize - inputElemenCnt_ = elemenCnt; + condition_ = keepCondition(); reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); printSizeInfo(); // the output_.value and output_.grad are shared with CPU device shareCPUDevice(); - pipelineFwd_.clear(); inVals_.resize(inputLayers_.size(), nullptr); extInVals_.resize(inputLayers_.size(), nullptr); diff --git 
a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 532e66d9788e4ef24def05d9fb85d668820d575b..907927f984f1a7cd4a72038515569251df48d56f 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -34,8 +34,6 @@ typedef std::shared_ptr MKLDNNLayerPtr; */ class MKLDNNLayer : public Layer { protected: - // input value element count - size_t inputElemenCnt_; // batch size int bs_; // they sizes are always from the first input layer @@ -44,6 +42,8 @@ protected: // output image channel, height and width int oc_, oh_, ow_; + // the condition value that decides whether forward needs to be reset + size_t condition_; // backward also need reset after reset forward handle bool needResetBwd_; @@ -103,14 +103,7 @@ protected: public: explicit MKLDNNLayer(const LayerConfig& config) : Layer(config), - inputElemenCnt_(0), - bs_(0), - ic_(0), - ih_(0), - iw_(0), - oc_(0), - oh_(0), - ow_(0), + condition_(0), needResetBwd_(true), outputOnlyMKLDNN_(false), engine_(mkldnn::engine::cpu, 0), @@ -173,6 +166,15 @@ public: void addOutputArgument(int deviceId) { Layer::addOutputArgument(deviceId); } protected: + /** + * Some layers may have different conditions for resetting the forward. + * The function returns the condition value; forward does not need to be reset while this value is unchanged. 
+ */ + inline virtual size_t keepCondition() { + // reset when the first input element size changed, not only the batchsize + return inputLayers_[0]->getOutputValue()->getElementCnt(); + } + /** * reshape the input image sizes and input batchsize */ diff --git a/paddle/gserver/layers/MKLDNNPoolLayer.cpp b/paddle/gserver/layers/MKLDNNPoolLayer.cpp index 86122f93c5c3d7ee97530621cdd26c1bf94f88f4..a8252593c8fbb8013ab909e74a057850ba54bcaa 100644 --- a/paddle/gserver/layers/MKLDNNPoolLayer.cpp +++ b/paddle/gserver/layers/MKLDNNPoolLayer.cpp @@ -61,7 +61,8 @@ void MKLDNNPoolLayer::reshape( int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) { reshapeInput(bs, ih, iw); // ic_ and oc can not be changed - CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) + CHECK_EQ((size_t)ic, + inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw) << "Input channel can not be changed"; // cal output sizes