diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 2647cb600653b4f43322016afb231a55f4db5642..88b047c89bd40aba1afc456c22a2870c62989c1c 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -449,13 +449,14 @@ void MKLDNNConvLayer::resetOutGrad( cvtOutGrad_ = nullptr; if (!outputIsOnlyMKLDNN()) { const MatrixPtr& cpuOut = getOutput(CPU_DEVICE).grad; + outMat->setData(cpuOut->getData()); // same PrimitiveDesc with cpuInVal_ CHECK(cpuOutVal_); cpuOutGrad_ = MKLDNNMatrix::create(cpuOut, cpuOutVal_->getPrimitiveDesc()); if (cpuOutGrad_->getPrimitiveDesc() == out->getPrimitiveDesc()) { - outMat->setData(cpuOut->getData()); out = cpuOutGrad_; } else { + out = MKLDNNMatrix::create(nullptr, wgtPD->diff_dst_primitive_desc()); cvtOutGrad_ = MKLDNNMatrix::createReorder(cpuOutGrad_, out); CHECK(cvtOutGrad_); } diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index 66b358bcea53f61ddcc15323704fa9f154fb2a73..afd092666bf8b8a3389b36aa1f0edb256a9968e6 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -232,6 +232,7 @@ void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, void MKLDNNFcLayer::resetOutGrad(MKLDNNMatrixPtr& out) { // TODO(TJ): merge outgrad int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; + output_.grad->setData(getOutput(device).grad->getData()); // for MKLDNN device: // can not directly cast outputgrad to mkldnnmatrix, // since each layer can not write the inputgrad to mkldnn inputgrad. diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index c4e4a6874e6fdb491c344c70dfea422dc0924cd9..d8555a833187ddf64b096135e920e5be2b3a8c2f 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -141,18 +141,16 @@ public: } void backward(const UpdateCallback& callback) override { - /* Do derivation */ { + if (needResetBwd_) { + resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); + needResetBwd_ = false; + } + { REGISTER_TIMER_INFO("BpActTimer", getName().c_str()); backwardActivation(); } - { REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); - if (needResetBwd_) { - resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); - needResetBwd_ = false; - } - stream_->submit(pipelineBwd_); }