From d34780e1931c05b1ab98664be102b1d69b030729 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 7 Nov 2017 11:44:53 +0800 Subject: [PATCH] fix issue for resnet --- paddle/gserver/layers/MKLDNNFcLayer.cpp | 6 ++---- paddle/gserver/layers/MKLDNNLayer.cpp | 14 +++++--------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index d82063a7130..3429c53d239 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -60,18 +60,16 @@ void MKLDNNFcLayer::convertWeightsFromPaddle() { } CHECK(wgtVal_) << "should have been initialized"; - bool hasNoSpatial_ = ih_ == 1 && iw_ == 1; auto targetDim = wgtVal_->getDims(); - auto srcFmt = hasNoSpatial_ ? format::io : format::ihwo; + auto srcFmt = targetDim.size() == 2 ? format::io : format::ihwo; wgtVal_->reorderDataFrom(wgtVal_, srcFmt, targetDim); hasInitedWgt_ = true; } void MKLDNNFcLayer::convertWeightsToPaddle() { CHECK(wgtVal_) << "should have been initialized"; - bool hasNoSpatial_ = ih_ == 1 && iw_ == 1; auto targetDim = wgtVal_->getDims(); - auto dstFmt = hasNoSpatial_ ? format::io : format::ihwo; + auto dstFmt = targetDim.size() == 2 ? format::io : format::ihwo; wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim); } diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 5fd62f4f73b..82ef344c7b2 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -181,21 +181,17 @@ void MKLDNNLayer::resetInValue( auto extPD = MKLDNNMatrix::createPrimitiveDesc( {bs_, ic_, ih_, iw_}, format::nchw, engine_); const MatrixPtr& inMat = inputLayers_[inputIdx]->getOutputValue(); - in = std::dynamic_pointer_cast(inMat); - CHECK_EQ(inputIsOnlyMKLDNN(), in != nullptr); - if (in == nullptr || in->getFormat() == format::nc) { - in = MKLDNNMatrix::create(extPD, inMat); - } - extInVal_ = isPaddleFormat(in->getFormat()) ? in : nullptr; - if (in->getFormat() == format::nc) { - CHECK(ih_ == 1 && iw_ == 1); + extInVal_ = std::dynamic_pointer_cast(inMat); + CHECK_EQ(inputIsOnlyMKLDNN(), extInVal_ != nullptr); + if (extInVal_ == nullptr || extInVal_->getFormat() == format::nc) { + extInVal_ = MKLDNNMatrix::create(extPD, inMat); } + in = extInVal_; if (nullptr == intPD || in->getPrimitiveDesc() == *intPD) { return; } // need create reorder in = MKLDNNMatrix::create(*intPD); - extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(extPD, inMat); cvtInVal_ = MKLDNNMatrix::createReorder(extInVal_, in); CHECK(cvtInVal_) << "should not be emptry"; } -- GitLab