diff --git a/paddle/gserver/layers/MkldnnFcLayer.cpp b/paddle/gserver/layers/MkldnnFcLayer.cpp
index 5584b43ff17156b6a84609fd8c4985bb9299358f..b62422da83134e71ef5f866f2f5fc1a204ee53bf 100644
--- a/paddle/gserver/layers/MkldnnFcLayer.cpp
+++ b/paddle/gserver/layers/MkldnnFcLayer.cpp
@@ -77,7 +77,6 @@ void MkldnnFcLayer::reshape() {
 
 void MkldnnFcLayer::forward(PassType passType) {
   Layer::forward(passType);
-
   reshape();
 
   {
@@ -97,6 +96,40 @@ void MkldnnFcLayer::forward(PassType passType) {
 }
 
 void MkldnnFcLayer::backward(const UpdateCallback& callback) {
-  ;  // bool hasBias = biases_ && biases_->getWGrad();
+  /* Do derivation */ {
+    REGISTER_TIMER_INFO("BpActTimer", getName().c_str());
+    backwardActivation();
+  }
+
+  bool hasBias = biases_ && biases_->getWGrad();
+  {
+    REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
+    real* inVal = getInputValue(0)->getData();
+    real* inGrad =
+        getInputGrad(0) != nullptr ? getInputGrad(0)->getData() : NULL;
+    real* outGrad = getOutputGrad()->getData();
+    real* wgtGrad = weight_->getWGrad()->getData();
+    real* wgtVal = weight_->getW()->getData();
+    real* biasGrad = hasBias ? biases_->getWGrad()->getData() : NULL;
+    mkldnnBackwardFC(bs_,
+                     ic_,
+                     ih_,
+                     iw_,
+                     inGrad,
+                     inVal,
+                     oc_,
+                     outGrad,
+                     wgtGrad,
+                     wgtVal,
+                     biasGrad);
+  }
+
+  {
+    REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
+    weight_->getParameterPtr()->incUpdate(callback);
+    if (hasBias) {
+      biases_->getParameterPtr()->incUpdate(callback);
+    }
+  }
 }
 }  // namespace paddle
diff --git a/paddle/gserver/layers/MkldnnLayer.cpp b/paddle/gserver/layers/MkldnnLayer.cpp
index d462e8694c268438431c5b66b36b48b7c7e13ef2..64bed5c8214080943ad8b353fe3855e385e3483e 100644
--- a/paddle/gserver/layers/MkldnnLayer.cpp
+++ b/paddle/gserver/layers/MkldnnLayer.cpp
@@ -88,6 +88,94 @@ void MkldnnLayer::mkldnnForwardFC(int bs,
   stream_->submit(pipelineFwd_);
 }
 
+void MkldnnLayer::resetBackwardFC(int bs,
+                                  int ic,
+                                  int ih,
+                                  int iw,
+                                  real* botDiff,
+                                  real* botData,
+                                  int oc,
+                                  real* topDiff,
+                                  real* wgtDiff,
+                                  real* wgtData,
+                                  real* biasDiff) {
+  bool hasSpatial = ih == 1 && iw == 1 ? false : true;
+  engine_ = CpuEngine::Instance().getEngine();
+
+  // backward weight
+  mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
+                               : createMD({bs, ic}, format::nc);
+  mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
+                               : createMD({oc, ic}, format::oi);
+  mem::desc topMD = createMD({bs, oc}, format::nc);
+  mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
+                                      : createMD({}, format::format_undef);
+
+  fc_fwd::desc fwdDesc =
+      fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
+  fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
+  fc_bwdWgt::desc bwdWgtDesc =
+      biasDiff != NULL ? fc_bwdWgt::desc(botMD, wgtMD, biasMD, topMD)
+                       : fc_bwdWgt::desc(botMD, wgtMD, topMD);
+  fc_bwdWgt::primitive_desc bwdWgtPD =
+      fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
+
+  mem botVal = mem(mem::primitive_desc(botMD, engine_), botData);
+  mem wgtGrad = mem(mem::primitive_desc(wgtMD, engine_), wgtDiff);
+  mem topGrad = mem(mem::primitive_desc(topMD, engine_), topDiff);
+
+  if (biasDiff != NULL) {
+    mem biasGrad = mem(mem::primitive_desc(biasMD, engine_), biasDiff);
+    bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, botVal, topGrad, wgtGrad, biasGrad));
+  } else {
+    bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, botVal, topGrad, wgtGrad));
+  }
+  pipelineBwd_.clear();
+  pipelineBwd_.push_back(*bwdWgt_);
+
+  // backward data
+  if (botDiff == NULL) {
+    return;
+  }
+
+  fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(botMD, wgtMD, topMD);
+  fc_bwdData::primitive_desc bwdDataPD =
+      fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
+  mem botGrad = mem(mem::primitive_desc(botMD, engine_), botDiff);
+  mem wgtVal = mem(mem::primitive_desc(wgtMD, engine_), wgtData);
+  bwdData_.reset(new fc_bwdData(bwdDataPD, topGrad, wgtVal, botGrad));
+  pipelineBwd_.push_back(*bwdData_);
+}
+
+void MkldnnLayer::mkldnnBackwardFC(int bs,
+                                   int ic,
+                                   int ih,
+                                   int iw,
+                                   real* botDiff,
+                                   real* botData,
+                                   int oc,
+                                   real* topDiff,
+                                   real* wgtDiff,
+                                   real* wgtData,
+                                   real* biasDiff) {
+  // rebuild the backward pipeline (TODO: skip when the input size is unchanged)
+  resetBackwardFC(bs,
+                  ic,
+                  ih,
+                  iw,
+                  botDiff,
+                  botData,
+                  oc,
+                  topDiff,
+                  wgtDiff,
+                  wgtData,
+                  biasDiff);
+
+  // submit the backward pipeline: bwdWgt_ computes wgtDiff/biasDiff,
+  // then bwdData_ (if present) computes botDiff
+  stream_->submit(pipelineBwd_);
+}
+
 mem::desc MkldnnLayer::createMD(mem::dims dims,
                                 mem::format fmt,
                                 mem::data_type type) {
diff --git a/paddle/gserver/layers/MkldnnLayer.h b/paddle/gserver/layers/MkldnnLayer.h
index 6e41ee4028e24e6b3c336abf3fbc10ad31aa0fef..5927bd6d52631fc5b5048a788d65420262432562 100644
--- a/paddle/gserver/layers/MkldnnLayer.h
+++ b/paddle/gserver/layers/MkldnnLayer.h
@@ -42,6 +42,8 @@ protected:
 
   std::shared_ptr<mkldnn::stream> stream_;
   std::shared_ptr<mkldnn::primitive> fwd_;
+  std::shared_ptr<mkldnn::primitive> bwdWgt_;
+  std::shared_ptr<mkldnn::primitive> bwdData_;
   std::vector<mkldnn::primitive> pipelineFwd_;
   std::vector<mkldnn::primitive> pipelineBwd_;
 
@@ -56,7 +58,10 @@ public:
         oh_(0),
         ow_(0),
         engine_(mkldnn::engine::cpu, 0),
-        stream_(nullptr) {}
+        stream_(nullptr),
+        fwd_(nullptr),
+        bwdWgt_(nullptr),
+        bwdData_(nullptr) {}
 
   ~MkldnnLayer() {}
 
@@ -82,6 +87,30 @@ public:
                        real* wgtData,
                        real* biasData);
 
+  void resetBackwardFC(int bs,
+                       int ic,
+                       int ih,
+                       int iw,
+                       real* botDiff,
+                       real* botData,
+                       int oc,
+                       real* topDiff,
+                       real* wgtDiff,
+                       real* wgtData,
+                       real* biasDiff);
+
+  void mkldnnBackwardFC(int bs,
+                        int ic,
+                        int ih,
+                        int iw,
+                        real* botDiff,
+                        real* botData,
+                        int oc,
+                        real* topDiff,
+                        real* wgtDiff,
+                        real* wgtData,
+                        real* biasDiff);
+
   // TODO(TJ): move to MkldnnMatrix
   // create memory desc
   inline mkldnn::memory::desc createMD(
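
Reviewer note (not part of the patch): resetBackwardFC builds the FC gradient path as two MKL-DNN primitives, inner-product backward-weights (input value + output grad -> weight/bias grads) and backward-data (output grad + weight value -> input grad), each created with the forward primitive_desc as a hint. Below is a minimal, self-contained sketch of the same sequence against the MKL-DNN v0.x C++ API. It assumes the patch's fc_fwd / fc_bwdWgt / fc_bwdData / mem aliases stand for mkldnn::inner_product_forward, inner_product_backward_weights, inner_product_backward_data, and mkldnn::memory (the typedefs are not shown in this diff); buffer names and sizes here are illustrative, not taken from Paddle.

// Hedged sketch: a standalone MKL-DNN v0.x inner-product backward pass,
// mirroring what resetBackwardFC/mkldnnBackwardFC do inside the layer.
#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  const int bs = 128, ic = 256, oc = 512;  // illustrative batch/input/output
  engine eng(engine::cpu, 0);

  // Plain f32 buffers standing in for the layer's value/gradient matrices.
  std::vector<float> botVal(bs * ic), botGrad(bs * ic);  // input value/grad
  std::vector<float> wgtVal(oc * ic), wgtGrad(oc * ic);  // weight value/grad
  std::vector<float> topGrad(bs * oc), biasGrad(oc);     // output/bias grads

  // Memory descriptors; nc/oi/x mirror the non-spatial branch of the patch.
  memory::desc botMD({bs, ic}, memory::data_type::f32, memory::format::nc);
  memory::desc wgtMD({oc, ic}, memory::data_type::f32, memory::format::oi);
  memory::desc topMD({bs, oc}, memory::data_type::f32, memory::format::nc);
  memory::desc biasMD({oc}, memory::data_type::f32, memory::format::x);

  // Backward primitive descriptors take the forward primitive_desc as a
  // hint, exactly as resetBackwardFC passes fwdPD.
  auto fwdDesc = inner_product_forward::desc(prop_kind::forward,
                                             botMD, wgtMD, topMD);
  auto fwdPD = inner_product_forward::primitive_desc(fwdDesc, eng);
  auto bwdWgtDesc =
      inner_product_backward_weights::desc(botMD, wgtMD, biasMD, topMD);
  auto bwdWgtPD =
      inner_product_backward_weights::primitive_desc(bwdWgtDesc, eng, fwdPD);
  auto bwdDataDesc = inner_product_backward_data::desc(botMD, wgtMD, topMD);
  auto bwdDataPD =
      inner_product_backward_data::primitive_desc(bwdDataDesc, eng, fwdPD);

  // Wrap the user buffers as MKL-DNN memories (zero-copy).
  memory botValM({botMD, eng}, botVal.data());
  memory botGradM({botMD, eng}, botGrad.data());
  memory wgtValM({wgtMD, eng}, wgtVal.data());
  memory wgtGradM({wgtMD, eng}, wgtGrad.data());
  memory topGradM({topMD, eng}, topGrad.data());
  memory biasGradM({biasMD, eng}, biasGrad.data());

  // Backward weights: (src value, dst grad) -> (weight grad, bias grad).
  inner_product_backward_weights bwdWgt(bwdWgtPD, botValM, topGradM,
                                        wgtGradM, biasGradM);
  // Backward data: (dst grad, weight value) -> src grad.
  inner_product_backward_data bwdData(bwdDataPD, topGradM, wgtValM, botGradM);

  // Submit both primitives in one pipeline, as mkldnnBackwardFC does.
  std::vector<primitive> pipeline = {bwdWgt, bwdData};
  stream(stream::kind::eager).submit(pipeline).wait();
  return 0;
}

The hint mechanism is the reason the patch recreates fwdDesc/fwdPD even though only pipelineBwd_ is ever submitted: seeding the backward primitive_descs with the forward one lets the library choose backward implementations whose memory formats match the forward pass.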