提交 e18fbd82 编写于 作者: T tensor-tang

skip reset mkldnn when input size does not change

上级 6373291c
......@@ -49,7 +49,6 @@ void MkldnnLayer::resetForwardFC(int bs,
real* wgtData,
real* biasData) {
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
: createMD({bs, ic}, format::nc);
mem::desc wgtMD = hasSpatial ? createMD({oc, ic, ih, iw}, format::oihw)
......@@ -58,7 +57,12 @@ void MkldnnLayer::resetForwardFC(int bs,
: createMD({}, format::format_undef);
mem::desc topMD = createMD({bs, oc}, format::nc);
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
mem::primitive_desc botPD = mem::primitive_desc(botMD, engine_);
if (inVal_ && inVal_->get_primitive_desc() == botPD) {
return;
}
inVal_.reset(new mem(botPD, botData));
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
outVal_.reset(new mem(mem::primitive_desc(topMD, engine_), topData));
......@@ -111,7 +115,6 @@ void MkldnnLayer::resetBackwardFC(int bs,
real* wgtData,
real* biasDiff) {
bool hasSpatial = ih == 1 && iw == 1 ? false : true;
engine_ = CpuEngine::Instance().getEngine();
// backward weight
mem::desc botMD = hasSpatial ? createMD({bs, ic, ih, iw}, format::nchw)
......@@ -122,9 +125,19 @@ void MkldnnLayer::resetBackwardFC(int bs,
mem::desc biasMD = biasDiff != NULL ? createMD({oc}, format::x)
: createMD({}, format::format_undef);
mem::primitive_desc topPD = mem::primitive_desc(botMD, engine_);
if (outGrad_ && outGrad_->get_primitive_desc() == topPD) {
return;
}
if (inVal_) {
// update data
inVal_->set_data_handle(botData);
} else {
inVal_.reset(new mem(mem::primitive_desc(botMD, engine_), botData));
}
wgtGrad_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtDiff));
outGrad_.reset(new mem(mem::primitive_desc(topMD, engine_), topDiff));
outGrad_.reset(new mem(topPD, topDiff));
fc_fwd::desc fwdDesc =
fc_fwd::desc(mkldnn::prop_kind::forward, botMD, wgtMD, topMD);
......@@ -154,7 +167,12 @@ void MkldnnLayer::resetBackwardFC(int bs,
fc_bwdData::primitive_desc bwdDataPD =
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
inGrad_.reset(new mem(mem::primitive_desc(botMD, engine_), botDiff));
if (wgtVal_) {
// update data
wgtVal_->set_data_handle(wgtData);
} else {
wgtVal_.reset(new mem(mem::primitive_desc(wgtMD, engine_), wgtData));
}
bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
pipelineBwd_.push_back(*bwdData_);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册