未验证 提交 55bee85e 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #5779 from tensor-tang/refine

refine MKLDNNLayer
...@@ -38,12 +38,13 @@ bool MKLDNNAddtoLayer::init(const LayerMap& layerMap, ...@@ -38,12 +38,13 @@ bool MKLDNNAddtoLayer::init(const LayerMap& layerMap,
} }
void MKLDNNAddtoLayer::reshape( void MKLDNNAddtoLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
CHECK_EQ(layerSize_, getSize()) << "this layer size can not be changed"; CHECK_EQ(layerSize_, getSize()) << "this layer size can not be changed";
reshapeInput(bs, ih, iw); reshapeInput(bs, ih, iw);
ic = inputLayers_[0]->getSize() / ih / iw; ic = inputLayers_[0]->getSize() / ih / iw;
CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize()); CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize());
CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw); CHECK_EQ(inputLayers_[0]->getOutputValue()->getElementCnt(),
(size_t)bs * ic * ih * iw);
for (size_t i = 0; i < inputLayers_.size(); i++) { for (size_t i = 0; i < inputLayers_.size(); i++) {
CHECK_EQ(int64_t(bs), inputLayers_[i]->getOutput().getBatchSize()); CHECK_EQ(int64_t(bs), inputLayers_[i]->getOutput().getBatchSize());
CHECK_EQ(layerSize_, inputLayers_[i]->getSize()); CHECK_EQ(layerSize_, inputLayers_[i]->getSize());
...@@ -57,47 +58,43 @@ void MKLDNNAddtoLayer::reshape( ...@@ -57,47 +58,43 @@ void MKLDNNAddtoLayer::reshape(
} }
void MKLDNNAddtoLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNAddtoLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdBuffers(inVals_, bias, out); resetFwdBuffers(inputs, biasVal_, out);
in = inVals_[0];
std::shared_ptr<sum::primitive_desc> fwdPD; std::shared_ptr<sum::primitive_desc> fwdPD;
std::shared_ptr<sum::primitive_desc> biasPD; std::shared_ptr<sum::primitive_desc> biasPD;
resetFwdPD(fwdPD, biasPD, inVals_, bias, out); resetFwdPD(fwdPD, biasPD, inputs, biasVal_, out);
resetFwdPipeline(pipeline, fwdPD, biasPD, inVals_, bias, out); resetFwdPipeline(pipeline, fwdPD, biasPD, inputs, biasVal_, out);
} }
void MKLDNNAddtoLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNAddtoLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetBwdBuffers(inGrads_, bias, out); resetBwdBuffers(inputs, biasGrad_, out);
in = inGrads_[0];
// backward only need share output grad to input grad // backward only need share output grad to input grad
for (size_t i = 0; i < inGrads_.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
if (inGrads_[i] != nullptr) { if (inputs[i] != nullptr) {
inGrads_[i] = out; inputs[i] = out;
inputLayers_[i]->getOutputGrad()->setData(inGrads_[i]->getData()); inputLayers_[i]->getOutputGrad()->setData(inputs[i]->getData());
} }
} }
// backward bias // backward bias
bwdBias_ = nullptr; bwdBias_ = nullptr;
if (bias) { if (biasGrad_) {
std::vector<float> scales(bs_, 1.0); std::vector<float> scales(bs_, 1.0);
std::vector<memory::primitive_desc> srcPDs(bs_, bias->getPrimitiveDesc()); std::vector<memory::primitive_desc> srcPDs(bs_,
auto biasPD = sum::primitive_desc(bias->getMemoryDesc(), scales, srcPDs); biasGrad_->getPrimitiveDesc());
auto biasPD =
sum::primitive_desc(biasGrad_->getMemoryDesc(), scales, srcPDs);
std::vector<primitive::at> srcs; std::vector<primitive::at> srcs;
for (size_t i = 0; i < grads_.size(); ++i) { for (size_t i = 0; i < grads_.size(); ++i) {
srcs.push_back(*(grads_[i])); srcs.push_back(*(grads_[i]));
} }
bwdBias_.reset(new sum(biasPD, srcs, *bias)); bwdBias_.reset(new sum(biasPD, srcs, *biasGrad_));
pipeline.push_back(*bwdBias_); pipeline.push_back(*bwdBias_);
} }
} }
...@@ -208,7 +205,7 @@ void MKLDNNAddtoLayer::resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, ...@@ -208,7 +205,7 @@ void MKLDNNAddtoLayer::resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
inputs.resize(inputLayers_.size()); inputs.resize(inputLayers_.size());
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
resetInGrad(inputs[i], inVal_->getPrimitiveDesc(), i); resetInGrad(inputs[i], inVals_[i]->getPrimitiveDesc(), i);
CHECK_PRIMITIVE_DESC_EQ(inputs[i], out->getPrimitiveDesc()); CHECK_PRIMITIVE_DESC_EQ(inputs[i], out->getPrimitiveDesc());
} }
......
...@@ -26,9 +26,6 @@ namespace paddle { ...@@ -26,9 +26,6 @@ namespace paddle {
*/ */
class MKLDNNAddtoLayer : public MKLDNNLayer { class MKLDNNAddtoLayer : public MKLDNNLayer {
protected: protected:
std::vector<MKLDNNMatrixPtr> inVals_;
std::vector<MKLDNNMatrixPtr> inGrads_;
// layer size == ic * ih * iw == oc * oh *ow, and can not be changed // layer size == ic * ih * iw == oc * oh *ow, and can not be changed
size_t layerSize_; size_t layerSize_;
...@@ -50,52 +47,19 @@ public: ...@@ -50,52 +47,19 @@ public:
const ParameterMap& parameterMap) override; const ParameterMap& parameterMap) override;
void reshape( void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override; void updateWeights(const UpdateCallback& callback) override;
void printValueFormat() override {
for (size_t i = 0; i < inVals_.size(); ++i) {
VLOG(MKLDNN_FMTS) << i << " input: " << inVals_[i]->getFormat() << " >>>";
}
if (outVal_) {
VLOG(MKLDNN_FMTS) << outVal_->getFormat() << " >>> ";
}
if (extOutVal_) {
VLOG(MKLDNN_FMTS) << extOutVal_->getFormat();
}
}
void printGradFormat() override {
if (extOutGrad_) {
VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat();
}
if (outGrad_) {
VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< ";
}
for (size_t i = 0; i < inGrads_.size(); ++i) {
VLOG(MKLDNN_FMTS) << i << " input: " << inGrads_[i]->getFormat() << "<<<";
}
}
protected: protected:
/**
* Forward functions: reset buffers(inputs, output, bias),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
...@@ -110,17 +74,10 @@ protected: ...@@ -110,17 +74,10 @@ protected:
std::vector<MKLDNNMatrixPtr>& inputs, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(inputs, output, bias)
*/
void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* prepare for bias
*/
void prepareBias(MKLDNNMatrixPtr& bias, void prepareBias(MKLDNNMatrixPtr& bias,
const MatrixPtr& biasMat, const MatrixPtr& biasMat,
const MKLDNNMatrixPtr& out, const MKLDNNMatrixPtr& out,
......
...@@ -116,21 +116,20 @@ void MKLDNNBatchNormLayer::calMovingMeanAndVar() { ...@@ -116,21 +116,20 @@ void MKLDNNBatchNormLayer::calMovingMeanAndVar() {
} }
void MKLDNNBatchNormLayer::reshape( void MKLDNNBatchNormLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw); reshapeInput(bs, ih, iw);
oh = ih; oh = ih;
ow = iw; ow = iw;
// ic_ and oc can not be changed // ic_ and oc can not be changed
CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) CHECK_EQ((size_t)ic,
inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw)
<< "Input channel can not be changed"; << "Input channel can not be changed";
reshapeOutput(oh, ow); reshapeOutput(oh, ow);
resizeOutput(bs, oc * oh * ow); resizeOutput(bs, oc * oh * ow);
} }
void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
// In training phase, it will always calculate mean and var, // In training phase, it will always calculate mean and var,
// so useGlobalStats must be false. // so useGlobalStats must be false.
...@@ -140,25 +139,23 @@ void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline, ...@@ -140,25 +139,23 @@ void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
useGlobalStats_ = false; useGlobalStats_ = false;
} }
resetFwdBuffers(in, wgt, out); resetFwdBuffers(inputs[0], wgtVal_, out);
resetFwdPD(fwdPD_, in, wgt, out); resetFwdPD(fwdPD_, inputs[0], wgtVal_, out);
resetFwdPipeline(pipeline, fwdPD_, in, wgt, out); resetFwdPipeline(pipeline, fwdPD_, inputs[0], wgtVal_, out);
} }
void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
std::shared_ptr<bn_bwd::primitive_desc> pd; std::shared_ptr<bn_bwd::primitive_desc> pd;
resetBwdBuffers(in, wgt, out); resetBwdBuffers(inputs[0], wgtGrad_, out);
resetBwdPD(pd, in, wgt, out); resetBwdPD(pd, inputs[0], wgtGrad_, out);
resetBwdPipeline(pipeline, pd, in, wgt, out); resetBwdPipeline(pipeline, pd, inputs[0], wgtGrad_, out);
} }
void MKLDNNBatchNormLayer::forward(PassType passType) { void MKLDNNBatchNormLayer::forward(PassType passType) {
...@@ -260,9 +257,9 @@ void MKLDNNBatchNormLayer::resetFwdPipeline( ...@@ -260,9 +257,9 @@ void MKLDNNBatchNormLayer::resetFwdPipeline(
void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, void MKLDNNBatchNormLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
CHECK(inVal_ && outVal_); CHECK(inVals_[0] && outVal_);
resetOutGrad(out, outVal_->getPrimitiveDesc()); resetOutGrad(out, outVal_->getPrimitiveDesc());
resetInGrad(in, inVal_->getPrimitiveDesc()); resetInGrad(in, inVals_[0]->getPrimitiveDesc());
if (gradScaleShift_) { if (gradScaleShift_) {
CHECK(wgtVal_); CHECK(wgtVal_);
resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc()); resetWithMatrix(wgt, gradScaleShift_, wgtVal_->getPrimitiveDesc());
...@@ -297,11 +294,12 @@ void MKLDNNBatchNormLayer::resetBwdPipeline( ...@@ -297,11 +294,12 @@ void MKLDNNBatchNormLayer::resetBwdPipeline(
if (pd == nullptr) { if (pd == nullptr) {
return; return;
} }
CHECK(inVal_); CHECK(inVals_[0]);
bwdData_.reset( bwdData_.reset(
wgt && wgtVal_ wgt && wgtVal_
? new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *wgtVal_, *in, *wgt) ? new bn_bwd(
: new bn_bwd(*pd, *inVal_, *mean_, *var_, *out, *in)); *pd, *inVals_[0], *mean_, *var_, *out, *wgtVal_, *in, *wgt)
: new bn_bwd(*pd, *inVals_[0], *mean_, *var_, *out, *in));
pipeline.push_back(*bwdData_); pipeline.push_back(*bwdData_);
} }
......
...@@ -73,18 +73,14 @@ public: ...@@ -73,18 +73,14 @@ public:
void forward(PassType passType) override; void forward(PassType passType) override;
void reshape( void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override; void updateWeights(const UpdateCallback& callback) override;
...@@ -98,11 +94,7 @@ protected: ...@@ -98,11 +94,7 @@ protected:
* moving = moving * AvgFraction + local * (1 - AvgFraction) * moving = moving * AvgFraction + local * (1 - AvgFraction)
*/ */
void calMovingMeanAndVar(); void calMovingMeanAndVar();
/**
* Forward functions: reset buffers(input, weight, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in, void resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
...@@ -115,12 +107,6 @@ protected: ...@@ -115,12 +107,6 @@ protected:
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, weight, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in, void resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
......
...@@ -32,17 +32,16 @@ bool MKLDNNConcatLayer::init(const LayerMap& layerMap, ...@@ -32,17 +32,16 @@ bool MKLDNNConcatLayer::init(const LayerMap& layerMap,
} }
void MKLDNNConcatLayer::reshape( void MKLDNNConcatLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw); reshapeInput(bs, ih, iw);
ic = inputLayers_[0]->getSize() / ih / iw; ic = inputLayers_[0]->getSize() / ih / iw;
CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize()); CHECK_EQ((size_t)ic * ih * iw, inputLayers_[0]->getSize());
CHECK_EQ(inputElemenCnt_, (size_t)bs * ic * ih * iw); CHECK_EQ(inputLayers_[0]->getOutputValue()->getElementCnt(),
(size_t)bs * ic * ih * iw);
CHECK_GT(inputLayers_.size(), 1UL); CHECK_GT(inputLayers_.size(), 1UL);
channels_.resize(inputLayers_.size()); channels_.resize(inputLayers_.size());
channels_[0] = ic; channels_[0] = ic;
// need change the output channel, so use oc_ instead oc = ic;
// TODO(TJ): change API, use &oc
oc_ = ic;
for (size_t i = 1; i < inputLayers_.size(); i++) { for (size_t i = 1; i < inputLayers_.size(); i++) {
int batchsize, height, witdh; int batchsize, height, witdh;
reshapeInput(batchsize, height, witdh, i); reshapeInput(batchsize, height, witdh, i);
...@@ -52,37 +51,31 @@ void MKLDNNConcatLayer::reshape( ...@@ -52,37 +51,31 @@ void MKLDNNConcatLayer::reshape(
channels_[i] = inputLayers_[i]->getSize() / height / witdh; channels_[i] = inputLayers_[i]->getSize() / height / witdh;
CHECK_EQ((size_t)channels_[i] * height * witdh, inputLayers_[i]->getSize()); CHECK_EQ((size_t)channels_[i] * height * witdh, inputLayers_[i]->getSize());
oc_ += channels_[i]; oc += channels_[i];
} }
oh = ih; oh = ih;
ow = iw; ow = iw;
reshapeOutput(oh, ow); reshapeOutput(oh, ow);
resizeOutput(bs, oc_ * oh * ow); resizeOutput(bs, oc * oh * ow);
} }
void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdBuffers(inVals_, out); resetFwdBuffers(inputs, out);
in = inVals_[0];
std::shared_ptr<concat::primitive_desc> fwdPD; std::shared_ptr<concat::primitive_desc> fwdPD;
resetFwdPD(fwdPD, inVals_, out); resetFwdPD(fwdPD, inputs, out);
resetFwdPipeline(pipeline, fwdPD, inVals_, out); resetFwdPipeline(pipeline, fwdPD, inputs, out);
} }
void MKLDNNConcatLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNConcatLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetBwdBuffers(inGrads_, out); resetBwdBuffers(inputs, out);
in = inGrads_[0];
resetBwdPipeline(pipeline, bwds_, inGrads_, out); resetBwdPipeline(pipeline, bwds_, inputs, out);
} }
void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
...@@ -90,10 +83,7 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, ...@@ -90,10 +83,7 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
inputs.resize(inputLayers_.size()); inputs.resize(inputLayers_.size());
bool has8c = false, has16c = false, hasnc = false; bool has8c = false, has16c = false, hasnc = false;
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
// resetInValue will use ic_ so temporary change as current input's channel resetInValue(inputs[i], nullptr, i, channels_[i]);
// TODO(TJ): change ic_ as vector then can remove channels_
ic_ = channels_[i];
resetInValue(inputs[i], nullptr, i);
CHECK(inputs[i]); CHECK(inputs[i]);
auto dm = inputs[i]->getDims(); auto dm = inputs[i]->getDims();
// inputs format can be different, but ndims must equal // inputs format can be different, but ndims must equal
...@@ -114,8 +104,6 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, ...@@ -114,8 +104,6 @@ void MKLDNNConcatLayer::resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
has16c = true; has16c = true;
} }
} }
// change back, ic_ always save the input 0 size
ic_ = channels_[0];
format outFmt; format outFmt;
if (has16c && oc_ % 16 == 0) { if (has16c && oc_ % 16 == 0) {
...@@ -168,14 +156,9 @@ void MKLDNNConcatLayer::resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, ...@@ -168,14 +156,9 @@ void MKLDNNConcatLayer::resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
inputs.resize(inputLayers_.size()); inputs.resize(inputLayers_.size());
for (size_t i = 0; i < inputs.size(); i++) { for (size_t i = 0; i < inputs.size(); i++) {
CHECK(inVals_[i]); CHECK(inVals_[i]);
// resetInGrad will use inVal_
// TODO(TJ): change move inVals_ to MKLDNNLayer ans remove inVal_
inVal_ = inVals_[i];
resetInGrad(inputs[i], inVals_[i]->getPrimitiveDesc(), i); resetInGrad(inputs[i], inVals_[i]->getPrimitiveDesc(), i);
CHECK_PRIMITIVE_DESC_EQ(inputs[i], inVals_[i]->getPrimitiveDesc()); CHECK_PRIMITIVE_DESC_EQ(inputs[i], inVals_[i]->getPrimitiveDesc());
} }
// change back, inVal_ always save the input 0
inVal_ = inVals_[0];
} }
void MKLDNNConcatLayer::resetBwdPipeline( void MKLDNNConcatLayer::resetBwdPipeline(
......
...@@ -26,8 +26,6 @@ namespace paddle { ...@@ -26,8 +26,6 @@ namespace paddle {
*/ */
class MKLDNNConcatLayer : public MKLDNNLayer { class MKLDNNConcatLayer : public MKLDNNLayer {
protected: protected:
std::vector<MKLDNNMatrixPtr> inVals_;
std::vector<MKLDNNMatrixPtr> inGrads_;
std::vector<std::shared_ptr<mkldnn::primitive>> bwds_; std::vector<std::shared_ptr<mkldnn::primitive>> bwds_;
// input channel numbers // input channel numbers
std::vector<int> channels_; std::vector<int> channels_;
...@@ -47,18 +45,14 @@ public: ...@@ -47,18 +45,14 @@ public:
const ParameterMap& parameterMap) override; const ParameterMap& parameterMap) override;
void reshape( void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void printSizeInfo() override { void printSizeInfo() override {
...@@ -72,38 +66,16 @@ public: ...@@ -72,38 +66,16 @@ public:
<< ", " << ow_; << ", " << ow_;
} }
void printValueFormat() override { size_t keepCondition() {
for (size_t i = 0; i < inVals_.size(); ++i) { // reset when the total element size of all inputs changed
VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName() size_t totalSize = inputLayers_[0]->getOutputValue()->getElementCnt();
<< ": " << inVals_[i]->getFormat() << " >>>"; for (size_t i = 1; i < inputLayers_.size(); ++i) {
} totalSize += inputLayers_[i]->getOutputValue()->getElementCnt();
if (outVal_) {
VLOG(MKLDNN_FMTS) << outVal_->getFormat() << " >>> ";
}
if (extOutVal_) {
VLOG(MKLDNN_FMTS) << extOutVal_->getFormat();
}
}
void printGradFormat() override {
if (extOutGrad_) {
VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat();
}
if (outGrad_) {
VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< ";
}
for (size_t i = 0; i < inGrads_.size(); ++i) {
VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
<< ": " << inGrads_[i]->getFormat() << "<<<";
} }
return totalSize;
} }
protected: protected:
/**
* Forward functions: reset buffers(inputs, output, bias),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr<mkldnn::concat::primitive_desc>& pd, void resetFwdPD(std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
...@@ -113,11 +85,6 @@ protected: ...@@ -113,11 +85,6 @@ protected:
std::shared_ptr<mkldnn::concat::primitive_desc>& pd, std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
std::vector<MKLDNNMatrixPtr>& inputs, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(inputs, output, bias)
* reset primitives and pipeline
*/
void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline, void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
......
...@@ -90,7 +90,7 @@ void MKLDNNConvLayer::convertWeightsToPaddle() { ...@@ -90,7 +90,7 @@ void MKLDNNConvLayer::convertWeightsToPaddle() {
} }
void MKLDNNConvLayer::reshape( void MKLDNNConvLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw); reshapeInput(bs, ih, iw);
// cal output sizes // cal output sizes
...@@ -105,21 +105,17 @@ void MKLDNNConvLayer::reshape( ...@@ -105,21 +105,17 @@ void MKLDNNConvLayer::reshape(
} }
void MKLDNNConvLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNConvLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdPD(fwdPD_); resetFwdPD(fwdPD_);
resetFwdBuffers(fwdPD_, in, wgt, bias, out); resetFwdBuffers(fwdPD_, inputs[0], wgtVal_, biasVal_, out);
resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); resetFwdPipeline(pipeline, fwdPD_, inputs[0], wgtVal_, biasVal_, out);
} }
void MKLDNNConvLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNConvLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
std::shared_ptr<conv_bwdWgt::primitive_desc> bwdWgtPD; std::shared_ptr<conv_bwdWgt::primitive_desc> bwdWgtPD;
std::shared_ptr<conv_bwdData::primitive_desc> bwdDataPD; std::shared_ptr<conv_bwdData::primitive_desc> bwdDataPD;
...@@ -128,9 +124,10 @@ void MKLDNNConvLayer::resetBwd(std::vector<primitive>& pipeline, ...@@ -128,9 +124,10 @@ void MKLDNNConvLayer::resetBwd(std::vector<primitive>& pipeline,
resetBwdDataPD(bwdDataPD); resetBwdDataPD(bwdDataPD);
resetBwdBuffers(bwdWgtPD, bwdDataPD, in, wgt, bias, out); resetBwdBuffers(bwdWgtPD, bwdDataPD, inputs[0], wgtGrad_, biasGrad_, out);
resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); resetBwdPipeline(
pipeline, bwdWgtPD, bwdDataPD, inputs[0], wgtGrad_, biasGrad_, out);
} }
void MKLDNNConvLayer::updateWeights(const UpdateCallback& callback) { void MKLDNNConvLayer::updateWeights(const UpdateCallback& callback) {
...@@ -236,14 +233,14 @@ void MKLDNNConvLayer::resetBwdWgtPD( ...@@ -236,14 +233,14 @@ void MKLDNNConvLayer::resetBwdWgtPD(
loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR); loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR);
// create backward weight using input, output and weight value memory desc // create backward weight using input, output and weight value memory desc
CHECK(inVal_) << "Should have internal input value"; CHECK(inVals_[0]) << "Should have internal input value";
CHECK(outVal_) << "Should have internal output value"; CHECK(outVal_) << "Should have internal output value";
CHECK(wgtVal_) << "Should have weight value"; CHECK(wgtVal_) << "Should have weight value";
algorithm algo = algorithm::convolution_direct; algorithm algo = algorithm::convolution_direct;
padding_kind padKind = padding_kind::zero; padding_kind padKind = padding_kind::zero;
auto bwdWgtDesc = biasVal_ != nullptr auto bwdWgtDesc = biasVal_ != nullptr
? conv_bwdWgt::desc(algo, ? conv_bwdWgt::desc(algo,
inVal_->getMemoryDesc(), inVals_[0]->getMemoryDesc(),
wgtVal_->getMemoryDesc(), wgtVal_->getMemoryDesc(),
biasVal_->getMemoryDesc(), biasVal_->getMemoryDesc(),
outVal_->getMemoryDesc(), outVal_->getMemoryDesc(),
...@@ -252,7 +249,7 @@ void MKLDNNConvLayer::resetBwdWgtPD( ...@@ -252,7 +249,7 @@ void MKLDNNConvLayer::resetBwdWgtPD(
padR, padR,
padKind) padKind)
: conv_bwdWgt::desc(algo, : conv_bwdWgt::desc(algo,
inVal_->getMemoryDesc(), inVals_[0]->getMemoryDesc(),
wgtVal_->getMemoryDesc(), wgtVal_->getMemoryDesc(),
outVal_->getMemoryDesc(), outVal_->getMemoryDesc(),
strides, strides,
...@@ -260,7 +257,7 @@ void MKLDNNConvLayer::resetBwdWgtPD( ...@@ -260,7 +257,7 @@ void MKLDNNConvLayer::resetBwdWgtPD(
padR, padR,
padKind); padKind);
pd.reset(new conv_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_)); pd.reset(new conv_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_));
CHECK_PRIMITIVE_DESC_EQ(inVal_, pd->src_primitive_desc()); CHECK_PRIMITIVE_DESC_EQ(inVals_[0], pd->src_primitive_desc());
CHECK_PRIMITIVE_DESC_EQ( CHECK_PRIMITIVE_DESC_EQ(
outVal_, outVal_,
pd->diff_dst_primitive_desc(), pd->diff_dst_primitive_desc(),
...@@ -280,12 +277,12 @@ void MKLDNNConvLayer::resetBwdDataPD( ...@@ -280,12 +277,12 @@ void MKLDNNConvLayer::resetBwdDataPD(
memory::dims wgtDims, biasDims, strides, dilations, padL, padR; memory::dims wgtDims, biasDims, strides, dilations, padL, padR;
loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR); loadConvSettings(wgtDims, biasDims, strides, dilations, padL, padR);
CHECK(inVal_) << "Should have internal input value"; CHECK(inVals_[0]) << "Should have internal input value";
CHECK(outVal_) << "Should have internal output value"; CHECK(outVal_) << "Should have internal output value";
// create backward data using input and output value memory desc // create backward data using input and output value memory desc
// but using weight memory desc with any format // but using weight memory desc with any format
auto bwdDataDesc = conv_bwdData::desc(algorithm::convolution_direct, auto bwdDataDesc = conv_bwdData::desc(algorithm::convolution_direct,
inVal_->getMemoryDesc(), inVals_[0]->getMemoryDesc(),
MKLDNNMatrix::createMemoryDesc(wgtDims), MKLDNNMatrix::createMemoryDesc(wgtDims),
outVal_->getMemoryDesc(), outVal_->getMemoryDesc(),
strides, strides,
...@@ -294,7 +291,7 @@ void MKLDNNConvLayer::resetBwdDataPD( ...@@ -294,7 +291,7 @@ void MKLDNNConvLayer::resetBwdDataPD(
padding_kind::zero); padding_kind::zero);
pd.reset(new conv_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_)); pd.reset(new conv_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_));
CHECK_PRIMITIVE_DESC_EQ( CHECK_PRIMITIVE_DESC_EQ(
inVal_, inVals_[0],
pd->diff_src_primitive_desc(), pd->diff_src_primitive_desc(),
"primitive desc of in value and grad should be equal"); "primitive desc of in value and grad should be equal");
CHECK_PRIMITIVE_DESC_EQ( CHECK_PRIMITIVE_DESC_EQ(
...@@ -346,12 +343,12 @@ void MKLDNNConvLayer::resetBwdPipeline( ...@@ -346,12 +343,12 @@ void MKLDNNConvLayer::resetBwdPipeline(
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
CHECK(inVal_); CHECK(inVals_[0]);
// add bwdWgt handle // add bwdWgt handle
if (bias) { if (bias) {
bwdWgt_.reset(new conv_bwdWgt(*wgtPD, *inVal_, *out, *wgt, *bias)); bwdWgt_.reset(new conv_bwdWgt(*wgtPD, *inVals_[0], *out, *wgt, *bias));
} else { } else {
bwdWgt_.reset(new conv_bwdWgt(*wgtPD, *inVal_, *out, *wgt)); bwdWgt_.reset(new conv_bwdWgt(*wgtPD, *inVals_[0], *out, *wgt));
} }
pipeline.push_back(*bwdWgt_); pipeline.push_back(*bwdWgt_);
......
...@@ -69,18 +69,14 @@ public: ...@@ -69,18 +69,14 @@ public:
const ParameterMap& parameterMap) override; const ParameterMap& parameterMap) override;
void reshape( void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override; void updateWeights(const UpdateCallback& callback) override;
...@@ -107,48 +103,26 @@ protected: ...@@ -107,48 +103,26 @@ protected:
mkldnn::memory::dims& padL, mkldnn::memory::dims& padL,
mkldnn::memory::dims& padR); mkldnn::memory::dims& padR);
/**
* reset the forward primitive descriptor.
*/
void resetFwdPD(std::shared_ptr<conv_fwd::primitive_desc>& pd); void resetFwdPD(std::shared_ptr<conv_fwd::primitive_desc>& pd);
/**
* reset the MKLDNNMatrix buffers used in forward.
*/
void resetFwdBuffers(std::shared_ptr<conv_fwd::primitive_desc>& pd, void resetFwdBuffers(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* reset the forward pipeline.
*/
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline, void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<conv_fwd::primitive_desc>& pd, std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* reset the backward weight primitive descriptor.
*/
void resetBwdWgtPD(std::shared_ptr<conv_bwdWgt::primitive_desc>& pd); void resetBwdWgtPD(std::shared_ptr<conv_bwdWgt::primitive_desc>& pd);
/**
* reset the backward data primitive descriptor.
*/
void resetBwdDataPD(std::shared_ptr<conv_bwdData::primitive_desc>& pd); void resetBwdDataPD(std::shared_ptr<conv_bwdData::primitive_desc>& pd);
/**
* reset the MKLDNNMatrix buffers used in backward.
*/
void resetBwdBuffers(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD, void resetBwdBuffers(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
std::shared_ptr<conv_bwdData::primitive_desc>& dataPD, std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* reset the backward pipeline.
*/
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline, void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD, std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
std::shared_ptr<conv_bwdData::primitive_desc>& dataPD, std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
......
...@@ -74,7 +74,7 @@ void MKLDNNFcLayer::convertWeightsToPaddle() { ...@@ -74,7 +74,7 @@ void MKLDNNFcLayer::convertWeightsToPaddle() {
} }
void MKLDNNFcLayer::reshape( void MKLDNNFcLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw); reshapeInput(bs, ih, iw);
CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize()); CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize());
...@@ -87,32 +87,29 @@ void MKLDNNFcLayer::reshape( ...@@ -87,32 +87,29 @@ void MKLDNNFcLayer::reshape(
} }
void MKLDNNFcLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNFcLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdBuffers(in, wgt, bias, out); resetFwdBuffers(inputs[0], wgtVal_, biasVal_, out);
resetFwdPD(fwdPD_, in, wgt, bias, out); resetFwdPD(fwdPD_, inputs[0], wgtVal_, biasVal_, out);
resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); resetFwdPipeline(pipeline, fwdPD_, inputs[0], wgtVal_, biasVal_, out);
} }
void MKLDNNFcLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNFcLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
std::shared_ptr<fc_bwdWgt::primitive_desc> bwdWgtPD; std::shared_ptr<fc_bwdWgt::primitive_desc> bwdWgtPD;
std::shared_ptr<fc_bwdData::primitive_desc> bwdDataPD; std::shared_ptr<fc_bwdData::primitive_desc> bwdDataPD;
resetBwdBuffers(in, wgt, bias, out); resetBwdBuffers(inputs[0], wgtGrad_, biasGrad_, out);
resetBwdWgtPD(bwdWgtPD, wgt, bias, out); resetBwdWgtPD(bwdWgtPD, wgtGrad_, biasGrad_, out);
resetBwdDataPD(bwdDataPD, in, out); resetBwdDataPD(bwdDataPD, inputs[0], out);
resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); resetBwdPipeline(
pipeline, bwdWgtPD, bwdDataPD, inputs[0], wgtGrad_, biasGrad_, out);
} }
void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) { void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) {
...@@ -193,9 +190,9 @@ void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, ...@@ -193,9 +190,9 @@ void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
CHECK(inVal_ && outVal_); CHECK(inVals_[0] && outVal_);
resetOutGrad(out, outVal_->getPrimitiveDesc()); resetOutGrad(out, outVal_->getPrimitiveDesc());
resetInGrad(in, inVal_->getPrimitiveDesc()); resetInGrad(in, inVals_[0]->getPrimitiveDesc());
CHECK(wgtVal_); CHECK(wgtVal_);
resetWithMatrix(wgt, weight_->getWGrad(), wgtVal_->getPrimitiveDesc()); resetWithMatrix(wgt, weight_->getWGrad(), wgtVal_->getPrimitiveDesc());
...@@ -212,14 +209,15 @@ void MKLDNNFcLayer::resetBwdWgtPD( ...@@ -212,14 +209,15 @@ void MKLDNNFcLayer::resetBwdWgtPD(
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
CHECK(inVal_); CHECK(inVals_[0]);
fc_bwdWgt::desc bwdWgtDesc = bias ? fc_bwdWgt::desc(inVal_->getMemoryDesc(), fc_bwdWgt::desc bwdWgtDesc =
wgt->getMemoryDesc(), bias ? fc_bwdWgt::desc(inVals_[0]->getMemoryDesc(),
bias->getMemoryDesc(), wgt->getMemoryDesc(),
out->getMemoryDesc()) bias->getMemoryDesc(),
: fc_bwdWgt::desc(inVal_->getMemoryDesc(), out->getMemoryDesc())
wgt->getMemoryDesc(), : fc_bwdWgt::desc(inVals_[0]->getMemoryDesc(),
out->getMemoryDesc()); wgt->getMemoryDesc(),
out->getMemoryDesc());
pd.reset(new fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_)); pd.reset(new fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_));
} }
...@@ -245,11 +243,11 @@ void MKLDNNFcLayer::resetBwdPipeline( ...@@ -245,11 +243,11 @@ void MKLDNNFcLayer::resetBwdPipeline(
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
CHECK(inVal_); CHECK(inVals_[0]);
if (bias) { if (bias) {
bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVal_, *out, *wgt, *bias)); bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVals_[0], *out, *wgt, *bias));
} else { } else {
bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVal_, *out, *wgt)); bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVals_[0], *out, *wgt));
} }
pipeline.push_back(*bwdWgt_); pipeline.push_back(*bwdWgt_);
......
...@@ -52,18 +52,14 @@ public: ...@@ -52,18 +52,14 @@ public:
const ParameterMap& parameterMap) override; const ParameterMap& parameterMap) override;
void reshape( void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override; void updateWeights(const UpdateCallback& callback) override;
...@@ -73,11 +69,6 @@ public: ...@@ -73,11 +69,6 @@ public:
void convertWeightsToPaddle() override; void convertWeightsToPaddle() override;
protected: protected:
/**
* Forward functions: reset buffers(input, output, weight and bias),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in, void resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
...@@ -93,13 +84,6 @@ protected: ...@@ -93,13 +84,6 @@ protected:
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, output, weight and bias),
* reset primitive descriptor for backward weight,
* reset primitive descriptor for backward data,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in, void resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
......
...@@ -48,31 +48,20 @@ void MKLDNNLayer::forward(PassType passType) { ...@@ -48,31 +48,20 @@ void MKLDNNLayer::forward(PassType passType) {
REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
CHECK(!inputLayers_.empty()); CHECK(!inputLayers_.empty());
copySeqInfoToOutputs(); copySeqInfoToOutputs();
size_t elemenCnt = inputLayers_[0]->getOutputValue()->getElementCnt(); if (condition_ != keepCondition()) {
if (inputElemenCnt_ != elemenCnt) {
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward"; VLOG(MKLDNN_BASE) << getName() << " reset mkldnn forward";
// reset when input total sizes changed, not only the batchsize condition_ = keepCondition();
inputElemenCnt_ = elemenCnt;
pipelineFwd_.clear();
reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_); reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_);
// all cpu device output grad or value share output's printSizeInfo();
// the output_.value and output_.grad are shared with CPU device
shareCPUDevice(); shareCPUDevice();
resetFwd(pipelineFwd_, inVal_, wgtVal_, biasVal_, outVal_); pipelineFwd_.clear();
// MKLDNNLayer output value should be MKLDNNMatrix inVals_.resize(inputLayers_.size(), nullptr);
// so external output value is necessary. extInVals_.resize(inputLayers_.size(), nullptr);
// Then external input value is not necessary, cvtInVals_.resize(inputLayers_.size(), nullptr);
// since input may be mkldnn internal buffer. resetFwd(pipelineFwd_, inVals_, outVal_);
CHECK(extOutVal_) << "external output value is necessary"; prepareValueConversions(pipelineFwd_);
output_.value = std::dynamic_pointer_cast<Matrix>(extOutVal_);
CHECK(inVal_ && outVal_) << "internal memories are necessary";
if (cvtInVal_) {
pipelineFwd_.insert(pipelineFwd_.begin(), *cvtInVal_);
}
if (cvtOutVal_) {
pipelineFwd_.push_back(*cvtOutVal_);
}
convertWeightsFromPaddle(); convertWeightsFromPaddle();
printSizeInfo();
printValueFormat(); printValueFormat();
needResetBwd_ = true; needResetBwd_ = true;
} }
...@@ -80,8 +69,8 @@ void MKLDNNLayer::forward(PassType passType) { ...@@ -80,8 +69,8 @@ void MKLDNNLayer::forward(PassType passType) {
if (inputLayers_[0]->getType() == "data" && inputLayers_.size() == 1) { if (inputLayers_[0]->getType() == "data" && inputLayers_.size() == 1) {
// Update input value data when input layer is "data" type, // Update input value data when input layer is "data" type,
// since the input value data address might be changed. // since the input value data address might be changed.
CHECK(extInVal_); CHECK(extInVals_[0]);
extInVal_->setData(getInputValue(0, CPU_DEVICE)->getData()); extInVals_[0]->setData(getInputValue(0, CPU_DEVICE)->getData());
} }
if (!outputOnlyMKLDNN_) { if (!outputOnlyMKLDNN_) {
...@@ -99,22 +88,13 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) { ...@@ -99,22 +88,13 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) {
if (needResetBwd_) { if (needResetBwd_) {
VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward"; VLOG(MKLDNN_BASE) << getName() << " reset mkldnn backward";
pipelineBwd_.clear(); pipelineBwd_.clear();
inGrads_.resize(inputLayers_.size(), nullptr);
extInGrads_.resize(inputLayers_.size(), nullptr);
cvtInGrads_.resize(inputLayers_.size(), nullptr);
pipelineMergeGrad_.clear(); pipelineMergeGrad_.clear();
mergeGrad_ = nullptr; mergeGrad_ = nullptr;
resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); resetBwd(pipelineBwd_, inGrads_, outGrad_);
// external output grad is not necessary prepareGradConversions(pipelineBwd_);
// since output may be mkldnn internal buffer or merge them directly.
CHECK(outGrad_) << "internal output grad is necessary";
if (extOutGrad_) {
CHECK_EQ(extOutGrad_->getData(), output_.grad->getData())
<< "the external buffer should share the same data with output_.grad";
}
if (cvtOutGrad_) {
pipelineBwd_.insert(pipelineBwd_.begin(), *cvtOutGrad_);
}
if (cvtInGrad_) {
pipelineBwd_.push_back(*cvtInGrad_);
}
printGradFormat(); printGradFormat();
needResetBwd_ = false; needResetBwd_ = false;
} }
...@@ -141,8 +121,8 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) { ...@@ -141,8 +121,8 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) {
void MKLDNNLayer::reshapeInput(int& batchsize, void MKLDNNLayer::reshapeInput(int& batchsize,
int& height, int& height,
int& width, int& width,
size_t inputIdx) { size_t idx) {
const Argument& input = inputLayers_[inputIdx]->getOutput(); const Argument& input = inputLayers_[idx]->getOutput();
batchsize = input.getBatchSize(); batchsize = input.getBatchSize();
int h = input.getFrameHeight(); int h = input.getFrameHeight();
int w = input.getFrameWidth(); int w = input.getFrameWidth();
...@@ -176,27 +156,30 @@ void MKLDNNLayer::resetWithMatrix(MKLDNNMatrixPtr& dnn, ...@@ -176,27 +156,30 @@ void MKLDNNLayer::resetWithMatrix(MKLDNNMatrixPtr& dnn,
void MKLDNNLayer::resetInValue( void MKLDNNLayer::resetInValue(
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
const std::shared_ptr<memory::primitive_desc>& intPD, const std::shared_ptr<memory::primitive_desc>& intPD,
size_t inputIdx) { size_t idx,
cvtInVal_ = nullptr; int inputChannel) {
extInVal_ = nullptr; cvtInVals_[idx] = nullptr;
extInVals_[idx] = nullptr;
in = nullptr; in = nullptr;
CHECK_GT(bs_ * ic_ * ih_ * iw_, 0); inputChannel = inputChannel == 0 ? ic_ : inputChannel;
CHECK_GT(bs_ * inputChannel * ih_ * iw_, 0);
auto extPD = MKLDNNMatrix::createPrimitiveDesc( auto extPD = MKLDNNMatrix::createPrimitiveDesc(
{bs_, ic_, ih_, iw_}, format::nchw, engine_); {bs_, inputChannel, ih_, iw_}, format::nchw, engine_);
const MatrixPtr& inMat = inputLayers_[inputIdx]->getOutputValue(); const MatrixPtr& inMat = inputLayers_[idx]->getOutputValue();
extInVal_ = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat); extInVals_[idx] = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
CHECK_EQ(inputIsOnlyMKLDNN(), extInVal_ != nullptr); CHECK_EQ(inputIsOnlyMKLDNN(), extInVals_[idx] != nullptr);
if (extInVal_ == nullptr || extInVal_->getFormat() == format::nc) { if (extInVals_[idx] == nullptr ||
extInVal_ = MKLDNNMatrix::create(extPD, inMat); extInVals_[idx]->getFormat() == format::nc) {
extInVals_[idx] = MKLDNNMatrix::create(extPD, inMat);
} }
in = extInVal_; in = extInVals_[idx];
if (nullptr == intPD || in->getPrimitiveDesc() == *intPD) { if (nullptr == intPD || in->getPrimitiveDesc() == *intPD) {
return; return;
} }
// need create reorder // need create reorder
in = MKLDNNMatrix::create(*intPD); in = MKLDNNMatrix::create(*intPD);
cvtInVal_ = MKLDNNMatrix::createReorder(extInVal_, in); cvtInVals_[idx] = MKLDNNMatrix::createReorder(extInVals_[idx], in);
CHECK(cvtInVal_) << "should not be emptry"; CHECK(cvtInVals_[idx]) << "should not be emptry";
} }
void MKLDNNLayer::resetOutValue(MKLDNNMatrixPtr& out, void MKLDNNLayer::resetOutValue(MKLDNNMatrixPtr& out,
...@@ -218,11 +201,11 @@ void MKLDNNLayer::resetOutValue(MKLDNNMatrixPtr& out, ...@@ -218,11 +201,11 @@ void MKLDNNLayer::resetOutValue(MKLDNNMatrixPtr& out,
void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in, void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in,
memory::primitive_desc intPD, memory::primitive_desc intPD,
size_t inputIdx) { size_t idx) {
cvtInGrad_ = nullptr; cvtInGrads_[idx] = nullptr;
extInGrad_ = nullptr; extInGrads_[idx] = nullptr;
in = nullptr; in = nullptr;
LayerPtr& input = inputLayers_[inputIdx]; LayerPtr& input = inputLayers_[idx];
if (input->getOutputGrad() == nullptr) { if (input->getOutputGrad() == nullptr) {
// no need input grad // no need input grad
return; return;
...@@ -237,23 +220,25 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in, ...@@ -237,23 +220,25 @@ void MKLDNNLayer::resetInGrad(MKLDNNMatrixPtr& in,
in = MKLDNNMatrix::create(intPD, inMat); in = MKLDNNMatrix::create(intPD, inMat);
Argument& arg = input->getOutput(this->getName()); Argument& arg = input->getOutput(this->getName());
arg.grad = std::dynamic_pointer_cast<Matrix>(in); arg.grad = std::dynamic_pointer_cast<Matrix>(in);
CHECK_PRIMITIVE_DESC_EQ(inVal_, intPD); CHECK_PRIMITIVE_DESC_EQ(inVals_[idx], intPD);
if (inputIsOnlyMKLDNN()) { if (inputIsOnlyMKLDNN()) {
return; return;
} }
extInGrad_ = in; extInGrads_[idx] = in;
if (isPaddleFormat(extInGrad_->getFormat())) { if (isPaddleFormat(extInGrads_[idx]->getFormat())) {
return; return;
} }
// need create reorder // need create reorder
CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat())) CHECK(extInVals_[idx] != nullptr &&
isPaddleFormat(extInVals_[idx]->getFormat()))
<< "should have external input value and the format must be nchw(nc)"; << "should have external input value and the format must be nchw(nc)";
extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat); extInGrads_[idx] =
CHECK_PRIMITIVE_DESC_EQ(inVal_, intPD); MKLDNNMatrix::create(extInVals_[idx]->getPrimitiveDesc(), inMat);
CHECK_PRIMITIVE_DESC_EQ(inVals_[idx], intPD);
in = MKLDNNMatrix::create(intPD); in = MKLDNNMatrix::create(intPD);
cvtInGrad_ = MKLDNNMatrix::createReorder(in, extInGrad_); cvtInGrads_[idx] = MKLDNNMatrix::createReorder(in, extInGrads_[idx]);
CHECK(cvtInGrad_); CHECK(cvtInGrads_[idx]);
} }
void MKLDNNLayer::resetOutGrad(MKLDNNMatrixPtr& out, void MKLDNNLayer::resetOutGrad(MKLDNNMatrixPtr& out,
......
...@@ -34,15 +34,16 @@ typedef std::shared_ptr<MKLDNNLayer> MKLDNNLayerPtr; ...@@ -34,15 +34,16 @@ typedef std::shared_ptr<MKLDNNLayer> MKLDNNLayerPtr;
*/ */
class MKLDNNLayer : public Layer { class MKLDNNLayer : public Layer {
protected: protected:
// input value element count
size_t inputElemenCnt_;
// batch size // batch size
int bs_; int bs_;
// they sizes are always from the first input layer
// input image channel, height and width // input image channel, height and width
int ic_, ih_, iw_; int ic_, ih_, iw_;
// output image channel, height and width // output image channel, height and width
int oc_, oh_, ow_; int oc_, oh_, ow_;
// the condition that forward need be reset
size_t condition_;
// backward also need reset after reset forward handle // backward also need reset after reset forward handle
bool needResetBwd_; bool needResetBwd_;
...@@ -67,18 +68,18 @@ protected: ...@@ -67,18 +68,18 @@ protected:
* When all layers are mkldnn layers, they could save internal data. * When all layers are mkldnn layers, they could save internal data.
*/ */
// below MKLDNNMatrix buffers are all internal buffers // below MKLDNNMatrix buffers are all internal buffers
MKLDNNMatrixPtr inVal_; std::vector<MKLDNNMatrixPtr> inVals_;
MKLDNNMatrixPtr inGrad_; std::vector<MKLDNNMatrixPtr> inGrads_;
MKLDNNMatrixPtr outVal_; MKLDNNMatrixPtr outVal_;
MKLDNNMatrixPtr outGrad_; MKLDNNMatrixPtr outGrad_;
// below are external value and grad // below are external value and grad
MKLDNNMatrixPtr extInVal_; std::vector<MKLDNNMatrixPtr> extInVals_;
MKLDNNMatrixPtr extInGrad_; std::vector<MKLDNNMatrixPtr> extInGrads_;
MKLDNNMatrixPtr extOutVal_; MKLDNNMatrixPtr extOutVal_;
MKLDNNMatrixPtr extOutGrad_; MKLDNNMatrixPtr extOutGrad_;
// convert handle between external and internal buffers // convert handle between external and internal buffers
std::shared_ptr<mkldnn::reorder> cvtInVal_; std::vector<std::shared_ptr<mkldnn::reorder>> cvtInVals_;
std::shared_ptr<mkldnn::reorder> cvtInGrad_; std::vector<std::shared_ptr<mkldnn::reorder>> cvtInGrads_;
std::shared_ptr<mkldnn::reorder> cvtOutVal_; std::shared_ptr<mkldnn::reorder> cvtOutVal_;
std::shared_ptr<mkldnn::reorder> cvtOutGrad_; std::shared_ptr<mkldnn::reorder> cvtOutGrad_;
...@@ -102,14 +103,7 @@ protected: ...@@ -102,14 +103,7 @@ protected:
public: public:
explicit MKLDNNLayer(const LayerConfig& config) explicit MKLDNNLayer(const LayerConfig& config)
: Layer(config), : Layer(config),
inputElemenCnt_(0), condition_(0),
bs_(0),
ic_(0),
ih_(0),
iw_(0),
oc_(0),
oh_(0),
ow_(0),
needResetBwd_(true), needResetBwd_(true),
outputOnlyMKLDNN_(false), outputOnlyMKLDNN_(false),
engine_(mkldnn::engine::cpu, 0), engine_(mkldnn::engine::cpu, 0),
...@@ -125,31 +119,28 @@ public: ...@@ -125,31 +119,28 @@ public:
virtual void backward(const UpdateCallback& callback); virtual void backward(const UpdateCallback& callback);
/** /**
* reshape the input image sizes * reshape the input and output channels and image sizes
* and reset output image and buffer size * and reset output buffer size
* output channel can not be changed
*/ */
virtual void reshape( virtual void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) = 0; int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) = 0;
/** /**
* reset the mkldnn forward primitve and memories * reset the mkldnn forward primitve and memories
* only would be called when input size changes * only would be called when input size changes
* weight and bias buffers should be coverd by child class itself
*/ */
virtual void resetFwd(std::vector<mkldnn::primitive>& pipeline, virtual void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) = 0; MKLDNNMatrixPtr& out) = 0;
/** /**
* reset the mkldnn backward primitve and memories * reset the mkldnn backward primitve and memories
* only would be called when needed * only would be called when needed
* weight and bias buffers should be coverd by child class itself
*/ */
virtual void resetBwd(std::vector<mkldnn::primitive>& pipeline, virtual void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) = 0; MKLDNNMatrixPtr& out) = 0;
/** /**
...@@ -175,13 +166,19 @@ public: ...@@ -175,13 +166,19 @@ public:
void addOutputArgument(int deviceId) { Layer::addOutputArgument(deviceId); } void addOutputArgument(int deviceId) { Layer::addOutputArgument(deviceId); }
protected: protected:
/**
* Some layers may have different condition to reset the forward.
* The function returns the condition that do not need reset forward.
*/
inline virtual size_t keepCondition() {
// reset when the first input element size changed, not only the batchsize
return inputLayers_[0]->getOutputValue()->getElementCnt();
}
/** /**
* reshape the input image sizes and input batchsize * reshape the input image sizes and input batchsize
*/ */
void reshapeInput(int& batchsize, void reshapeInput(int& batchsize, int& height, int& width, size_t idx = 0);
int& height,
int& width,
size_t inputIdx = 0);
/** /**
* reshape output image sizes * reshape output image sizes
...@@ -199,11 +196,13 @@ protected: ...@@ -199,11 +196,13 @@ protected:
/** /**
* reset input value from input MKLDNNMatrix and internal primitive desc. * reset input value from input MKLDNNMatrix and internal primitive desc.
* reset both internal and external buffer and create reorder if necessary. * reset both internal and external buffer and create reorder if necessary.
* input channel may be different in concat.
*/ */
void resetInValue( void resetInValue(
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
const std::shared_ptr<mkldnn::memory::primitive_desc>& intPD = nullptr, const std::shared_ptr<mkldnn::memory::primitive_desc>& intPD = nullptr,
size_t inputIdx = 0); size_t idx = 0,
int inputChannel = 0);
/** /**
* reset output value from internal primitive desc. * reset output value from internal primitive desc.
...@@ -218,7 +217,7 @@ protected: ...@@ -218,7 +217,7 @@ protected:
*/ */
void resetInGrad(MKLDNNMatrixPtr& in, void resetInGrad(MKLDNNMatrixPtr& in,
mkldnn::memory::primitive_desc intPD, mkldnn::memory::primitive_desc intPD,
size_t inputIdx = 0); size_t idx = 0);
/** /**
* reset output grad from internal primitive desc. * reset output grad from internal primitive desc.
...@@ -296,17 +295,19 @@ protected: ...@@ -296,17 +295,19 @@ protected:
* print the mkldnn memory format of value * print the mkldnn memory format of value
*/ */
virtual void printValueFormat() { virtual void printValueFormat() {
if (extInVal_) { for (size_t i = 0; i < inVals_.size(); ++i) {
VLOG(MKLDNN_FMTS) << extInVal_->getFormat() << " >>> "; if (!inVals_[i]) {
} continue;
if (inVal_) { }
VLOG(MKLDNN_FMTS) << inVal_->getFormat() << " >>>"; VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
<< ": " << (extInVals_[i] ? extInVals_[i]->getFormat()
: inVals_[i]->getFormat())
<< " >>> " << inVals_[i]->getFormat() << " >>>";
} }
if (outVal_) { if (outVal_) {
VLOG(MKLDNN_FMTS) << outVal_->getFormat() << " >>> "; VLOG(MKLDNN_FMTS) << outVal_->getFormat() << " >>> "
} << (extOutVal_ ? extOutVal_->getFormat()
if (extOutVal_) { : outVal_->getFormat());
VLOG(MKLDNN_FMTS) << extOutVal_->getFormat();
} }
if (wgtVal_) { if (wgtVal_) {
VLOG(MKLDNN_FMTS) << "Weight value format: " << wgtVal_->getFormat(); VLOG(MKLDNN_FMTS) << "Weight value format: " << wgtVal_->getFormat();
...@@ -320,17 +321,19 @@ protected: ...@@ -320,17 +321,19 @@ protected:
* print the mkldnn memory format of grad * print the mkldnn memory format of grad
*/ */
virtual void printGradFormat() { virtual void printGradFormat() {
if (extOutGrad_) {
VLOG(MKLDNN_FMTS) << extOutGrad_->getFormat();
}
if (outGrad_) { if (outGrad_) {
VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< "; VLOG(MKLDNN_FMTS) << outGrad_->getFormat() << " <<< "
<< (extOutGrad_ ? extOutGrad_->getFormat()
: outGrad_->getFormat());
} }
if (inGrad_) { for (size_t i = 0; i < inGrads_.size(); ++i) {
VLOG(MKLDNN_FMTS) << inGrad_->getFormat() << " <<<"; if (!inGrads_[i]) {
} continue;
if (extInGrad_) { }
VLOG(MKLDNN_FMTS) << extInGrad_->getFormat() << " <<< "; VLOG(MKLDNN_FMTS) << "Input " << i << ", " << inputLayers_[i]->getName()
<< ": " << (extInGrads_[i] ? extInGrads_[i]->getFormat()
: inGrads_[i]->getFormat())
<< " <<< " << inGrads_[i]->getFormat() << " <<<";
} }
if (wgtGrad_) { if (wgtGrad_) {
VLOG(MKLDNN_FMTS) << "Weight grad format: " << wgtGrad_->getFormat(); VLOG(MKLDNN_FMTS) << "Weight grad format: " << wgtGrad_->getFormat();
...@@ -437,6 +440,41 @@ private: ...@@ -437,6 +440,41 @@ private:
outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims; outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims;
} }
} }
void prepareValueConversions(std::vector<mkldnn::primitive>& pipeline) {
// MKLDNNLayer output value should be MKLDNNMatrix
// so external output value is necessary.
// Then external input value is not necessary,
// since input may be mkldnn internal buffer.
CHECK(extOutVal_) << "external output value is necessary";
output_.value = std::dynamic_pointer_cast<Matrix>(extOutVal_);
CHECK(inVals_[0] && outVal_) << "internal memories are necessary";
for (size_t i = 0; i < cvtInVals_.size(); ++i) {
if (cvtInVals_[i]) {
pipeline.insert(pipeline.begin(), *cvtInVals_[i]);
}
}
if (cvtOutVal_) {
pipeline.push_back(*cvtOutVal_);
}
}
void prepareGradConversions(std::vector<mkldnn::primitive>& pipeline) {
// external output grad is not necessary
// since output may be mkldnn internal buffer or merge them directly.
CHECK(outGrad_) << "internal output grad is necessary";
if (extOutGrad_) {
CHECK_EQ(extOutGrad_->getData(), output_.grad->getData())
<< "the external buffer should share the same data with output_.grad";
}
if (cvtOutGrad_) {
pipeline.insert(pipeline.begin(), *cvtOutGrad_);
}
for (size_t i = 0; i < cvtInGrads_.size(); ++i) {
if (cvtInGrads_[i]) {
pipeline.push_back(*cvtInGrads_[i]);
}
}
}
}; };
} // namespace paddle } // namespace paddle
...@@ -58,10 +58,11 @@ bool MKLDNNPoolLayer::init(const LayerMap& layerMap, ...@@ -58,10 +58,11 @@ bool MKLDNNPoolLayer::init(const LayerMap& layerMap,
} }
void MKLDNNPoolLayer::reshape( void MKLDNNPoolLayer::reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) { int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) {
reshapeInput(bs, ih, iw); reshapeInput(bs, ih, iw);
// ic_ and oc can not be changed // ic_ and oc can not be changed
CHECK_EQ(inputElemenCnt_ / bs / ih / iw, (size_t)ic) CHECK_EQ((size_t)ic,
inputLayers_[0]->getOutputValue()->getElementCnt() / bs / ih / iw)
<< "Input channel can not be changed"; << "Input channel can not be changed";
// cal output sizes // cal output sizes
...@@ -74,29 +75,25 @@ void MKLDNNPoolLayer::reshape( ...@@ -74,29 +75,25 @@ void MKLDNNPoolLayer::reshape(
} }
void MKLDNNPoolLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNPoolLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdBuffers(in, out); resetFwdBuffers(inputs[0], out);
resetFwdPD(fwdPD_, in, out); resetFwdPD(fwdPD_, inputs[0], out);
resetFwdPipeline(pipeline, fwdPD_, in, out); resetFwdPipeline(pipeline, fwdPD_, inputs[0], out);
} }
void MKLDNNPoolLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNPoolLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
std::shared_ptr<pool_bwd::primitive_desc> pd; std::shared_ptr<pool_bwd::primitive_desc> pd;
resetBwdBuffers(in, out); resetBwdBuffers(inputs[0], out);
resetBwdPD(pd, in, out); resetBwdPD(pd, inputs[0], out);
resetBwdPipeline(pipeline, pd, in, out); resetBwdPipeline(pipeline, pd, inputs[0], out);
} }
void MKLDNNPoolLayer::resetFwdBuffers(MKLDNNMatrixPtr& in, void MKLDNNPoolLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
...@@ -151,9 +148,9 @@ void MKLDNNPoolLayer::resetFwdPipeline( ...@@ -151,9 +148,9 @@ void MKLDNNPoolLayer::resetFwdPipeline(
void MKLDNNPoolLayer::resetBwdBuffers(MKLDNNMatrixPtr& in, void MKLDNNPoolLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
CHECK(inVal_ && outVal_); CHECK(inVals_[0] && outVal_);
resetOutGrad(out, outVal_->getPrimitiveDesc()); resetOutGrad(out, outVal_->getPrimitiveDesc());
resetInGrad(in, inVal_->getPrimitiveDesc()); resetInGrad(in, inVals_[0]->getPrimitiveDesc());
} }
void MKLDNNPoolLayer::resetBwdPD(std::shared_ptr<pool_bwd::primitive_desc>& pd, void MKLDNNPoolLayer::resetBwdPD(std::shared_ptr<pool_bwd::primitive_desc>& pd,
......
...@@ -53,18 +53,14 @@ public: ...@@ -53,18 +53,14 @@ public:
const ParameterMap& parameterMap) override; const ParameterMap& parameterMap) override;
void reshape( void reshape(
int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override; int& bs, int& ic, int& ih, int& iw, int& oc, int& oh, int& ow) override;
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void printSizeInfo() override { void printSizeInfo() override {
...@@ -75,11 +71,6 @@ public: ...@@ -75,11 +71,6 @@ public:
} }
protected: protected:
/**
* Forward functions: reset buffers(input, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); void resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr<pool_fwd::primitive_desc>& pd, void resetFwdPD(std::shared_ptr<pool_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr in, MKLDNNMatrixPtr in,
...@@ -88,12 +79,6 @@ protected: ...@@ -88,12 +79,6 @@ protected:
std::shared_ptr<pool_fwd::primitive_desc>& pd, std::shared_ptr<pool_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); void resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out);
void resetBwdPD(std::shared_ptr<pool_bwd::primitive_desc>& pd, void resetBwdPD(std::shared_ptr<pool_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册