diff --git a/paddle/gserver/layers/MkldnnLayer.cpp b/paddle/gserver/layers/MkldnnLayer.cpp index 0e1e1c306178760540db8fbf21cc2a94aa473cd7..c909fe274d4e9a87f07e23fdeaefa3aab2029836 100644 --- a/paddle/gserver/layers/MkldnnLayer.cpp +++ b/paddle/gserver/layers/MkldnnLayer.cpp @@ -49,7 +49,6 @@ void MkldnnLayer::resetForwardFC(int bs, real* wgtData, real* biasData) { bool hasSpatial = ih == 1 && iw == 1 ? false : true; - mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw) : createMD({bs, ic}, format::nc); mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw) @@ -58,7 +57,12 @@ void MkldnnLayer::resetForwardFC(int bs, : createMD({}, format::format_undef); mem::desc topMD = createMD({bs, oc}, format::nc); - inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData)); + mem::primitive_desc botPD = mem::primitive_desc(botMD, engine_); + if (inVal_ && inVal_->get_primitive_desc() == botPD) { + return; + } + + inVal_.reset(new mem(botPD, botData)); wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData)); outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData)); @@ -111,7 +115,6 @@ void MkldnnLayer::resetBackwardFC(int bs, real* wgtData, real* biasDiff) { bool hasSpatial = ih == 1 && iw == 1 ? false : true; - engine_ = CpuEngine::Instance().getEngine(); // backward weight mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw) @@ -122,9 +125,19 @@ void MkldnnLayer::resetBackwardFC(int bs, mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x) : createMD({}, format::format_undef); - inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData)); + mem::primitive_desc topPD = mem::primitive_desc(botMD, engine_); + if (outGrad_ && outGrad_->get_primitive_desc() == topPD) { + return; + } + + if (inVal_) { + // update data + inVal_->set_data_handle(botData); + } else { + inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData)); + } wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff)); - outGrad_.reset(new mem(mem::primitive_desc(topMD, engine_), topDiff)); + outGrad_.reset(new mem(topPD, topDiff)); fc_fwd::desc fwdDesc = fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD); @@ -154,7 +167,12 @@ void MkldnnLayer::resetBackwardFC(int bs, fc_bwdData::primitive_desc bwdDataPD = fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD); inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff)); - wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData)); + if (wgtVal_) { + // update data + wgtVal_->set_data_handle(wgtData); + } else { + wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData)); + } bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_)); pipelineBwd_.push_back(*bwdData_); }