提交 1203ebc4 编写于 作者: T tensor-tang

add mkldnn fc backward

上级 90d5be74
......@@ -77,7 +77,6 @@ void MkldnnFcLayer::reshape() {
void MkldnnFcLayer::forward(PassType passType) {
Layer::forward(passType);
reshape();
{
......@@ -97,6 +96,40 @@ void MkldnnFcLayer::forward(PassType passType) {
}
void MkldnnFcLayer::backward(const UpdateCallback& callback) {
; // bool hasBias = biases_ && biases_->getWGrad();
/* Do derivation */ {
REGISTER_TIMER_INFO("BpActTimer", getName().c_str());
backwardActivation();
}
bool hasBias = biases_ && biases_->getWGrad();
{
REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
real* inVal = getInputValue(0)->getData();
real* inGrad =
getInputGrad(0) != nullptr ? getInputGrad(0)->getData() : NULL;
real* outGrad = getOutputGrad()->getData();
real* wgtGrad = weight_->getWGrad()->getData();
real* wgtVal = weight_->getW()->getData();
real* biasGrad = hasBias ? biases_->getWGrad()->getData() : NULL;
mkldnnBackwardFC(bs_,
ic_,
ih_,
iw_,
inGrad,
inVal,
oc_,
outGrad,
wgtGrad,
wgtVal,
biasGrad);
}
{
REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
weight_->getParameterPtr()->incUpdate(callback);
if (hasBias) {
biases_->getParameterPtr()->incUpdate(callback);
}
}
}
} // namespace paddle
......@@ -88,6 +88,94 @@ void MkldnnLayer::mkldnnForwardFC(int bs,
stream_->submit(pipelineFwd_);
}
void MkldnnLayer::resetBackwardFC(int bs,
int ic,
int ih,
int iw,
real* botDiff,
real* botData,
int oc,
real* topDiff,
real* wgtDiff,
real* wgtData,
real* biasDiff) {
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
engine_ = CpuEngine::Instance().getEngine();
// backward weight
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
: createMD({bs, ic}, format::nc);
mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
: createMD({oc, ic}, format::oi);
mem::desc topMD = createMD({bs, oc}, format::nc);
mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
: createMD({}, format::format_undef);
fc_fwd::desc fwdDesc =
fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
fc_bwdWgt::desc bwdWgtDesc =
biasDiff != NULL ? fc_bwdWgt::desc(botMD, wgtMD, biasMD, topMD)
: fc_bwdWgt::desc(botMD, wgtMD, topMD);
fc_bwdWgt::primitive_desc bwdWgtPD =
fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
mem botVal = mem(mem::primitive_desc(botMD, engine_), botData);
mem wgtGrad = mem(mem::primitive_desc(wgtMD, engine_), wgtDiff);
mem topGrad = mem(mem::primitive_desc(topMD, engine_), topDiff);
if (biasDiff != NULL) {
mem biasGrad = mem(mem::primitive_desc(biasMD, engine_), biasDiff);
bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, botVal, topGrad, wgtGrad, biasGrad));
} else {
bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, botVal, topGrad, wgtGrad));
}
pipelineBwd_.clear();
pipelineBwd_.push_back(*bwdWgt_);
// backward data
if (botDiff == NULL) {
return;
}
fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(botMD, wgtMD, topMD);
fc_bwdData::primitive_desc bwdDataPD =
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
mem botGrad = mem(mem::primitive_desc(botMD, engine_), botDiff);
mem wgtVal = mem(mem::primitive_desc(wgtMD, engine_), wgtData);
bwdData_.reset(new fc_bwdData(bwdDataPD, topGrad, wgtVal, botGrad));
pipelineBwd_.push_back(*bwdData_);
}
void MkldnnLayer::mkldnnBackwardFC(int bs,
int ic,
int ih,
int iw,
real* botDiff,
real* botData,
int oc,
real* topDiff,
real* wgtDiff,
real* wgtData,
real* biasDiff) {
// if input size changed, reset it
resetBackwardFC(bs,
ic,
ih,
iw,
botDiff,
botData,
oc,
topDiff,
wgtDiff,
wgtData,
biasDiff);
// just forward
// update botdata
stream_->submit(pipelineBwd_);
}
mem::desc MkldnnLayer::createMD(mem::dims dims,
mem::format fmt,
mem::data_type type) {
......
......@@ -42,6 +42,8 @@ protected:
std::shared_ptr<MkldnnStream> stream_;
std::shared_ptr<mkldnn::primitive> fwd_;
std::shared_ptr<mkldnn::primitive> bwdWgt_;
std::shared_ptr<mkldnn::primitive> bwdData_;
std::vector<mkldnn::primitive> pipelineFwd_;
std::vector<mkldnn::primitive> pipelineBwd_;
......@@ -56,7 +58,10 @@ public:
oh_(0),
ow_(0),
engine_(mkldnn::engine::cpu, 0),
stream_(nullptr) {}
stream_(nullptr),
fwd_(nullptr),
bwdWgt_(nullptr),
bwdData_(nullptr) {}
~MkldnnLayer() {}
......@@ -82,6 +87,30 @@ public:
real* wgtData,
real* biasData);
void resetBackwardFC(int bs,
int ic,
int ih,
int iw,
real* botDiff,
real* botData,
int oc,
real* topDiff,
real* wgtDiff,
real* wgtData,
real* biasDiff);
void mkldnnBackwardFC(int bs,
int ic,
int ih,
int iw,
real* botDiff,
real* botData,
int oc,
real* topDiff,
real* wgtDiff,
real* wgtData,
real* biasDiff);
// TODO(TJ): move to MkldnnMatrix
// create memory desc
inline mkldnn::memory::desc createMD(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册