diff --git a/paddle/gserver/activations/MKLDNNActivation.cpp b/paddle/gserver/activations/MKLDNNActivation.cpp index 18c5638100065109fba1f0647a1c5f91256f7b9d..f3ccd68160859795f28a40f8d0d4032adb289ccf 100644 --- a/paddle/gserver/activations/MKLDNNActivation.cpp +++ b/paddle/gserver/activations/MKLDNNActivation.cpp @@ -126,7 +126,7 @@ void MKLDNNEltwiseActivation::resetFwd(Argument& act) { copyInVal_ = nullptr; if (act.grad && algo == algorithm::eltwise_tanh) { // tanh need save src input for backward - inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc()); + inVal_ = MKLDNNMatrix::create(val_->getPrimitiveDesc()); copyInVal_ = std::make_shared(*val_, *inVal_); CHECK(copyInVal_) << "should not be emptry"; pipelineFwd_.push_back(*copyInVal_); @@ -145,7 +145,7 @@ void MKLDNNEltwiseActivation::resetBwd(Argument& act) { algorithm algo = getAlgo(this->getName()); float alpha = getBwdAlpha(); float beta = getBeta(); - grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc()); + grad_ = MKLDNNMatrix::create(val_->getPrimitiveDesc(), act.grad); auto eng = CPUEngine::Instance().getEngine(); auto bwdDesc = eltwise_bwd::desc( algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta); @@ -230,7 +230,7 @@ void MKLDNNActivation::resetFwd(Argument& act) { int ic = cnt_ / bs / ih / iw; CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw); val_ = MKLDNNMatrix::create( - act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_); + {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_, act.value); CHECK(val_); val_->downSpatial(); } diff --git a/paddle/gserver/layers/MKLDNNConvLayer.cpp b/paddle/gserver/layers/MKLDNNConvLayer.cpp index 463e6ad0ed73578f90babf07a068291c82801971..3fbfb1ab1f6da0d405b117b2bb6c55239f58aa88 100644 --- a/paddle/gserver/layers/MKLDNNConvLayer.cpp +++ b/paddle/gserver/layers/MKLDNNConvLayer.cpp @@ -370,8 +370,7 @@ void MKLDNNConvLayer::resetWgtValBwdData( // since the primitive_desc would be different with wgtVal_ CHECK(wgtVal_) << "should have weight value"; if (dataPD->weights_primitive_desc() != wgtVal_->getPrimitiveDesc()) { - wgtValBwdData_ = - MKLDNNMatrix::create(nullptr, dataPD->weights_primitive_desc()); + wgtValBwdData_ = MKLDNNMatrix::create(dataPD->weights_primitive_desc()); cvtWgtVal_ = MKLDNNMatrix::createReorder(wgtVal_, wgtValBwdData_); CHECK(cvtWgtVal_); } else { diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index ab59357ad01c1bd2a9d378460387f1953a9c0c74..80c67529dafe4d08be4271e8a57489d1770683eb 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -323,7 +323,7 @@ protected: if (mat == nullptr) { return; } - dnn = MKLDNNMatrix::create(mat, pd); + dnn = MKLDNNMatrix::create(pd, mat); } /** @@ -343,7 +343,7 @@ protected: in = std::dynamic_pointer_cast(inMat); CHECK_EQ(inputIsOnlyMKLDNN(), in != nullptr); if (in == nullptr || in->getFormat() == mkldnn::memory::format::nc) { - in = MKLDNNMatrix::create(inMat, extPD); + in = MKLDNNMatrix::create(extPD, inMat); } extInVal_ = isPaddleFormat(in->getFormat()) ? in : nullptr; if (in->getFormat() == mkldnn::memory::format::nc) { @@ -353,8 +353,8 @@ protected: return; } // need create reorder - in = MKLDNNMatrix::create(nullptr, *intPD); - extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(inMat, extPD); + in = MKLDNNMatrix::create(*intPD); + extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(extPD, inMat); cvtInVal_ = MKLDNNMatrix::createReorder(extInVal_, in); CHECK(cvtInVal_) << "should not be emptry"; } @@ -366,18 +366,18 @@ protected: void resetOutValue(MKLDNNMatrixPtr& out, mkldnn::memory::primitive_desc intPD) { cvtOutVal_ = nullptr; - out = MKLDNNMatrix::create(output_.value, intPD); + out = MKLDNNMatrix::create(intPD, output_.value); extOutVal_ = out; if (outputIsOnlyMKLDNN() || isPaddleFormat(extOutVal_->getFormat())) { return; } // need create reorder CHECK_GT(bs_ * oc_ * oh_ * ow_, 0); - extOutVal_ = MKLDNNMatrix::create(output_.value, - {bs_, oc_, oh_, ow_}, + extOutVal_ = MKLDNNMatrix::create(mkldnn::memory::dims{bs_, oc_, oh_, ow_}, mkldnn::memory::format::nchw, - engine_); - out = MKLDNNMatrix::create(nullptr, intPD); + engine_, + output_.value); + out = MKLDNNMatrix::create(intPD); cvtOutVal_ = MKLDNNMatrix::createReorder(out, extOutVal_); CHECK(cvtOutVal_) << "should not be empty"; } @@ -402,7 +402,7 @@ protected: // and the mkldnn input layer will merge them to actual prev->output_.grad const MatrixPtr& inMat = input->getOutputMapSize() <= 1 ? input->getOutputGrad() : nullptr; - in = MKLDNNMatrix::create(inMat, intPD); + in = MKLDNNMatrix::create(intPD, inMat); Argument& arg = input->getOutput(this->getName()); arg.grad = std::dynamic_pointer_cast(in); CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD) @@ -418,10 +418,10 @@ protected: // need create reorder CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat())) << "should have external input value and the format must be nchw(nc)"; - extInGrad_ = MKLDNNMatrix::create(inMat, extInVal_->getPrimitiveDesc()); + extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat); CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD) << "should have internal input value and primitive desc must equal"; - in = MKLDNNMatrix::create(nullptr, intPD); + in = MKLDNNMatrix::create(intPD); cvtInGrad_ = MKLDNNMatrix::createReorder(in, extInGrad_); CHECK(cvtInGrad_); } @@ -440,7 +440,7 @@ protected: extOutGrad_ = nullptr; out = nullptr; MatrixPtr& outMat = output_.grad; - out = MKLDNNMatrix::create(outMat, intPD); + out = MKLDNNMatrix::create(intPD, outMat); resetMergeGrad(out); if (outputIsOnlyMKLDNN()) { return; @@ -453,10 +453,10 @@ protected: // need create reorder CHECK(extOutVal_ != nullptr && isPaddleFormat(extOutVal_->getFormat())) << "should have external output value and the format must be nchw(nc)"; - extOutGrad_ = MKLDNNMatrix::create(outMat, extOutVal_->getPrimitiveDesc()); + extOutGrad_ = MKLDNNMatrix::create(extOutVal_->getPrimitiveDesc(), outMat); CHECK(outVal_ != nullptr && outVal_->getPrimitiveDesc() == intPD) << "should have internal output value and primitive desc must equal"; - out = MKLDNNMatrix::create(nullptr, intPD); + out = MKLDNNMatrix::create(intPD); cvtOutGrad_ = MKLDNNMatrix::createReorder(extOutGrad_, out); CHECK(cvtOutGrad_); } @@ -499,7 +499,7 @@ protected: tmpOutGrad_ = out; tmpCvt_ = nullptr; if (out->getPrimitiveDesc() != srcPDs[0]) { - tmpOutGrad_ = MKLDNNMatrix::create(nullptr, srcPDs[0]); + tmpOutGrad_ = MKLDNNMatrix::create(srcPDs[0]); tmpCvt_ = MKLDNNMatrix::createReorder(tmpOutGrad_, out); CHECK(tmpCvt_); pipelineMergeGrad_.push_back(*tmpCvt_); diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp index c606560473a5910309369b38e5e79ee189dfc9d7..21a8f73c3e650d4b3c3b86247594cd965f4ead35 100644 --- a/paddle/math/MKLDNNMatrix.cpp +++ b/paddle/math/MKLDNNMatrix.cpp @@ -18,7 +18,7 @@ using namespace mkldnn; // NOLINT namespace paddle { -MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) { +MKLDNNMatrixPtr MKLDNNMatrix::create(memory::primitive_desc pd, MatrixPtr m) { memory::desc md = pd.desc(); size_t ndims = md.data.ndims; int* dims = md.data.dims; @@ -41,12 +41,12 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) { return std::make_shared(cpuMatrix, pd); } -MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, - memory::dims dims, +MKLDNNMatrixPtr MKLDNNMatrix::create(memory::dims dims, memory::format fmt, engine& eg, + MatrixPtr m, mkldnn::memory::data_type dtype) { - return create(m, createPrimitiveDesc(dims, fmt, eg, dtype)); + return create(createPrimitiveDesc(dims, fmt, eg, dtype), m); } std::shared_ptr MKLDNNMatrix::createReorder(const MKLDNNMatrixPtr& src, diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index 9e3f29eb57550cd7c3f25b25ee7c5cb726b0a21d..fe755d096da9713e39581a909e5d21aa93d69f0f 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -40,16 +40,17 @@ public: /** * Create MKLDNNMatrix from a MatrixPtr and memory primitive_desc */ - static MKLDNNMatrixPtr create(MatrixPtr m, mkldnn::memory::primitive_desc pd); + static MKLDNNMatrixPtr create(mkldnn::memory::primitive_desc pd, + MatrixPtr m = nullptr); /** * Create MKLDNNMatrix from a MatrixPtr and memory details info */ static MKLDNNMatrixPtr create( - MatrixPtr m, mkldnn::memory::dims dims, mkldnn::memory::format fmt, mkldnn::engine& eg, + MatrixPtr m = nullptr, mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32); /**