提交 e18fbd82 编写于 作者: T tensor-tang

skip reset mkldnn when input size does not change

上级 6373291c
...@@ -49,7 +49,6 @@ void MkldnnLayer::resetForwardFC(int bs, ...@@ -49,7 +49,6 @@ void MkldnnLayer::resetForwardFC(int bs,
real* wgtData, real* wgtData,
real* biasData) { real* biasData) {
bool hasSpatial = ih == 1 && iw == 1 ? false : true; bool hasSpatial = ih == 1 && iw == 1 ? false : true;
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw) mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
: createMD({bs, ic}, format::nc); : createMD({bs, ic}, format::nc);
mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw) mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
...@@ -58,7 +57,12 @@ void MkldnnLayer::resetForwardFC(int bs, ...@@ -58,7 +57,12 @@ void MkldnnLayer::resetForwardFC(int bs,
: createMD({}, format::format_undef); : createMD({}, format::format_undef);
mem::desc topMD = createMD({bs, oc}, format::nc); mem::desc topMD = createMD({bs, oc}, format::nc);
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData)); mem::primitive_desc botPD = mem::primitive_desc(botMD, engine_);
if (inVal_ && inVal_->get_primitive_desc() == botPD) {
return;
}
inVal_.reset(new mem(botPD, botData));
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData)); wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData)); outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData));
...@@ -111,7 +115,6 @@ void MkldnnLayer::resetBackwardFC(int bs, ...@@ -111,7 +115,6 @@ void MkldnnLayer::resetBackwardFC(int bs,
real* wgtData, real* wgtData,
real* biasDiff) { real* biasDiff) {
bool hasSpatial = ih == 1 && iw == 1 ? false : true; bool hasSpatial = ih == 1 && iw == 1 ? false : true;
engine_ = CpuEngine::Instance().getEngine();
// backward weight // backward weight
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw) mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
...@@ -122,9 +125,19 @@ void MkldnnLayer::resetBackwardFC(int bs, ...@@ -122,9 +125,19 @@ void MkldnnLayer::resetBackwardFC(int bs,
mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x) mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
: createMD({}, format::format_undef); : createMD({}, format::format_undef);
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData)); mem::primitive_desc topPD = mem::primitive_desc(botMD, engine_);
if (outGrad_ && outGrad_->get_primitive_desc() == topPD) {
return;
}
if (inVal_) {
// update data
inVal_->set_data_handle(botData);
} else {
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
}
wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff)); wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff));
outGrad_.reset(new mem(mem::primitive_desc(topMD, engine_), topDiff)); outGrad_.reset(new mem(topPD, topDiff));
fc_fwd::desc fwdDesc = fc_fwd::desc fwdDesc =
fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD); fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
...@@ -154,7 +167,12 @@ void MkldnnLayer::resetBackwardFC(int bs, ...@@ -154,7 +167,12 @@ void MkldnnLayer::resetBackwardFC(int bs,
fc_bwdData::primitive_desc bwdDataPD = fc_bwdData::primitive_desc bwdDataPD =
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD); fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff)); inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff));
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData)); if (wgtVal_) {
// update data
wgtVal_->set_data_handle(wgtData);
} else {
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
}
bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_)); bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
pipelineBwd_.push_back(*bwdData_); pipelineBwd_.push_back(*bwdData_);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册