提交 c961fbf0 编写于 作者: T tensor-tang

change the condition to reset the forward in MKLDNNLayer

上级 a8eeef86
......@@ -43,7 +43,8 @@ void MKLDNNAddtoLayer::reshape(
reshapeInput(bs, ih, iw);
ic = inputLayers_[0]->getSize() / ih / iw;
CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize());
CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw);
CHECK_EQ(inputLayers_[0]->getOutputValue()->getElementCnt(),
(size_t)bs * ic * ih * iw);
for (size_t i = 0; i < inputLayers_.size(); i++) {
CHECK_EQ(int64_t(bs), inputLayers_[i]->getOutput().getBatchSize());
CHECK_EQ(layerSize_, inputLayers_[i]->getSize());
......
......@@ -121,7 +121,8 @@ void MKLDNNBatchNormLayer::reshape(
oh = ih;
ow = iw;
// ic_ and oc can not be changed
CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
CHECK_EQ((size_t)ic,
inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw)
<< "Input channel can not be changed";
reshapeOutput(oh, ow);
resizeOutput(bs, oc * oh * ow);
......
......@@ -36,7 +36,8 @@ void MKLDNNConcatLayer::reshape(
reshapeInput(bs, ih, iw);
ic = inputLayers_[0]->getSize() / ih / iw;
CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize());
CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw);
CHECK_EQ(inputLayers_[0]->getOutputValue()->getElementCnt(),
(size_t)bs * ic * ih * iw);
CHECK_GT(inputLayers_.size(), 1UL);
channels_.resize(inputLayers_.size());
channels_[0] = ic;
......
......@@ -66,6 +66,15 @@ public:
<< ", " << ow_;
}
size_t keepCondition() {
  // The condition is the summed element count of every input's output value,
  // so the forward pass is reset whenever that total changes.
  // The first input is always read; the remaining ones are added in order.
  size_t elemCntSum = 0;
  size_t idx = 0;
  do {
    elemCntSum += inputLayers_[idx]->getOutputValue()->getElementCnt();
  } while (++idx < inputLayers_.size());
  return elemCntSum;
}
protected:
void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out);
......
......@@ -48,16 +48,13 @@ void MKLDNNLayer::forward(PassType passType) {
REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
CHECK(!inputLayers_.empty());
copySeqInfoToOutputs();
size_t elemenCnt = inputLayers_[0]->getOutputValue()->getElementCnt();
if (inputElemenCnt_ != elemenCnt) {
if (condition_ != keepCondition()) {
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
// reset when input total sizes changed, not only the batchsize
inputElemenCnt_ = elemenCnt;
condition_ = keepCondition();
reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_);
printSizeInfo();
// the output_.value and output_.grad are shared with CPU device
shareCPUDevice();
pipelineFwd_.clear();
inVals_.resize(inputLayers_.size(), nullptr);
extInVals_.resize(inputLayers_.size(), nullptr);
......
......@@ -34,8 +34,6 @@ typedef std::shared_ptr<MKLDNNLayer> MKLDNNLayerPtr;
*/
class MKLDNNLayer : public Layer {
protected:
// input value element count
size_t inputElemenCnt_;
// batch size
int bs_;
// these sizes always come from the first input layer
......@@ -44,6 +42,8 @@ protected:
// output image channel, height and width
int oc_, oh_, ow_;
// the condition value that decides whether forward needs to be reset
size_t condition_;
// backward also needs to be reset after the forward handle is reset
bool needResetBwd_;
......@@ -103,14 +103,7 @@ protected:
public:
explicit MKLDNNLayer(const LayerConfig& config)
: Layer(config),
inputElemenCnt_(0),
bs_(0),
ic_(0),
ih_(0),
iw_(0),
oc_(0),
oh_(0),
ow_(0),
condition_(0),
needResetBwd_(true),
outputOnlyMKLDNN_(false),
engine_(mkldnn::engine::cpu, 0),
......@@ -173,6 +166,15 @@ public:
void addOutputArgument(int deviceId) { Layer::addOutputArgument(deviceId); }
protected:
/**
 * Returns the value used to decide whether the forward pass must be reset.
 * As long as this value stays the same as the stored one, no reset is needed.
 * It is virtual because some layers may need a different reset condition.
 */
inline virtual size_t keepCondition() {
  // Use the element count of the first input's output value, so a change in
  // any input dimension (not only the batch size) is detected.
  const auto& firstInputValue = inputLayers_[0]->getOutputValue();
  return firstInputValue->getElementCnt();
}
/**
* reshape the input image sizes and input batchsize
*/
......
......@@ -61,7 +61,8 @@ void MKLDNNPoolLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw);
// ic_ and oc can not be changed
CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic)
CHECK_EQ((size_t)ic,
inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw)
<< "Input channel can not be changed";
// cal output sizes
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册