diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index f4deb351f2bd6de912b6338d6bf47bf291a1acc2..53433cef35a377a73f87b041fdcfadd848dd2ec9 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -220,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() { pipelineBwd_.push_back(*bwdWgt_); /// backward data - device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; - const MatrixPtr& in = getInputGrad(0, device); + const MatrixPtr& in = inputLayers_[0]->getOutput().grad; if (in == nullptr) { return; } - if (getInput(0, device).getAllCount() > 1) { - // TODO(TJ): use outputMaps_ ways when merge outgrad done + if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) { + // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done } else { inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc()); } @@ -243,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() { pipelineBwd_.push_back(*bwdData_); } +void MKLDNNFcLayer::updateInputData() { + if (inputLayers_[0]->getType() != "data") { + return; + } + real* iData = getInputValue(0, CPU_DEVICE)->getData(); + inVal_->setData(iData); +} + void MKLDNNFcLayer::forward(PassType passType) { Layer::forward(passType); reshape(); { REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); - syncInputValue(); + updateInputData(); // just submit forward pipeline stream_->submit(pipelineFwd_); @@ -271,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) { REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); resetBwd(); - syncOutputGrad(); // just sumbmit backward pipeline stream_->submit(pipelineBwd_); } diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h index e2657a8d5e9d9c4722429150543eb96111ef51b2..4ad67a16e056a718c45a28babcf22a7cd571b15c 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.h +++ b/paddle/gserver/layers/MKLDNNFcLayer.h @@ -53,6 +53,8 @@ public: void backward(const UpdateCallback& callback) override; + void updateInputData() override; + protected: /** * reshape the input image sizes diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 1a3e949fb993e6f6a57416a4bf650764223ee388..543364edceff684bdcd002a8f4f10e7ce5e6953b 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -113,6 +113,12 @@ public: */ virtual void convertWeightsToPaddle() {} + /** + * Update input value data when input layer is "data" type. + * Since the input value data address might be changed. + */ + virtual void updateInputData() {} + /** * print info about sizes */ @@ -194,32 +200,6 @@ protected: return outputOtherDevice_.size() == 0; } - /** - * Sync input value data - */ - void syncInputValue() { - if (inputIsOnlyMKLDNN()) { - return; - } - real* iData = getInputValue(0, CPU_DEVICE)->getData(); - // update input data - // since it might be changed if this is after data layer - inVal_->setData(iData); - } - - /** - * Sync output grad data - */ - void syncOutputGrad() { - if (outputIsOnlyMKLDNN()) { - return; - } - - // update diff - real* oDiff = getOutput(CPU_DEVICE).grad->getData(); - outGrad_->setData(oDiff); - } - /** * Set deviceId of this layer. */