提交 c397599d 编写于 作者: T tensor-tang

remove weight and bias in MKLDNN reset function, since not all layers have weight and bias.

and remove some comments.
上级 a9490a10
...@@ -58,25 +58,21 @@ void MKLDNNAddtoLayer::reshape( ...@@ -58,25 +58,21 @@ void MKLDNNAddtoLayer::reshape(
void MKLDNNAddtoLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNAddtoLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdBuffers(inVals_, bias, out); resetFwdBuffers(inVals_, biasVal_, out);
in = inVals_[0]; in = inVals_[0];
std::shared_ptr<sum::primitive_desc> fwdPD; std::shared_ptr<sum::primitive_desc> fwdPD;
std::shared_ptr<sum::primitive_desc> biasPD; std::shared_ptr<sum::primitive_desc> biasPD;
resetFwdPD(fwdPD, biasPD, inVals_, bias, out); resetFwdPD(fwdPD, biasPD, inVals_, biasVal_, out);
resetFwdPipeline(pipeline, fwdPD, biasPD, inVals_, bias, out); resetFwdPipeline(pipeline, fwdPD, biasPD, inVals_, biasVal_, out);
} }
void MKLDNNAddtoLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNAddtoLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetBwdBuffers(inGrads_, bias, out); resetBwdBuffers(inGrads_, biasGrad_, out);
in = inGrads_[0]; in = inGrads_[0];
// backward only need share output grad to input grad // backward only need share output grad to input grad
...@@ -89,15 +85,17 @@ void MKLDNNAddtoLayer::resetBwd(std::vector<primitive>& pipeline, ...@@ -89,15 +85,17 @@ void MKLDNNAddtoLayer::resetBwd(std::vector<primitive>& pipeline,
// backward bias // backward bias
bwdBias_ = nullptr; bwdBias_ = nullptr;
if (bias) { if (biasGrad_) {
std::vector<float> scales(bs_, 1.0); std::vector<float> scales(bs_, 1.0);
std::vector<memory::primitive_desc> srcPDs(bs_, bias->getPrimitiveDesc()); std::vector<memory::primitive_desc> srcPDs(bs_,
auto biasPD = sum::primitive_desc(bias->getMemoryDesc(), scales, srcPDs); biasGrad_->getPrimitiveDesc());
auto biasPD =
sum::primitive_desc(biasGrad_->getMemoryDesc(), scales, srcPDs);
std::vector<primitive::at> srcs; std::vector<primitive::at> srcs;
for (size_t i = 0; i < grads_.size(); ++i) { for (size_t i = 0; i < grads_.size(); ++i) {
srcs.push_back(*(grads_[i])); srcs.push_back(*(grads_[i]));
} }
bwdBias_.reset(new sum(biasPD, srcs, *bias)); bwdBias_.reset(new sum(biasPD, srcs, *biasGrad_));
pipeline.push_back(*bwdBias_); pipeline.push_back(*bwdBias_);
} }
} }
......
...@@ -54,14 +54,10 @@ public: ...@@ -54,14 +54,10 @@ public:
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override; void updateWeights(const UpdateCallback& callback) override;
...@@ -91,11 +87,6 @@ public: ...@@ -91,11 +87,6 @@ public:
} }
protected: protected:
/**
* Forward functions: reset buffers(inputs, output, bias),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
...@@ -110,17 +101,10 @@ protected: ...@@ -110,17 +101,10 @@ protected:
std::vector<MKLDNNMatrixPtr>& inputs, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(inputs, output, bias)
*/
void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* prepare for bias
*/
void prepareBias(MKLDNNMatrixPtr& bias, void prepareBias(MKLDNNMatrixPtr& bias,
const MatrixPtr& biasMat, const MatrixPtr& biasMat,
const MKLDNNMatrixPtr& out, const MKLDNNMatrixPtr& out,
......
...@@ -129,8 +129,6 @@ void MKLDNNBatchNormLayer::reshape( ...@@ -129,8 +129,6 @@ void MKLDNNBatchNormLayer::reshape(
void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
// In training phase, it will always calculate mean and var, // In training phase, it will always calculate mean and var,
// so useGlobalStats must be false. // so useGlobalStats must be false.
...@@ -140,25 +138,23 @@ void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline, ...@@ -140,25 +138,23 @@ void MKLDNNBatchNormLayer::resetFwd(std::vector<primitive>& pipeline,
useGlobalStats_ = false; useGlobalStats_ = false;
} }
resetFwdBuffers(in, wgt, out); resetFwdBuffers(in, wgtVal_, out);
resetFwdPD(fwdPD_, in, wgt, out); resetFwdPD(fwdPD_, in, wgtVal_, out);
resetFwdPipeline(pipeline, fwdPD_, in, wgt, out); resetFwdPipeline(pipeline, fwdPD_, in, wgtVal_, out);
} }
void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNBatchNormLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
std::shared_ptr<bn_bwd::primitive_desc> pd; std::shared_ptr<bn_bwd::primitive_desc> pd;
resetBwdBuffers(in, wgt, out); resetBwdBuffers(in, wgtGrad_, out);
resetBwdPD(pd, in, wgt, out); resetBwdPD(pd, in, wgtGrad_, out);
resetBwdPipeline(pipeline, pd, in, wgt, out); resetBwdPipeline(pipeline, pd, in, wgtGrad_, out);
} }
void MKLDNNBatchNormLayer::forward(PassType passType) { void MKLDNNBatchNormLayer::forward(PassType passType) {
......
...@@ -77,14 +77,10 @@ public: ...@@ -77,14 +77,10 @@ public:
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override; void updateWeights(const UpdateCallback& callback) override;
...@@ -98,11 +94,7 @@ protected: ...@@ -98,11 +94,7 @@ protected:
* moving = moving * AvgFraction + local * (1 - AvgFraction) * moving = moving * AvgFraction + local * (1 - AvgFraction)
*/ */
void calMovingMeanAndVar(); void calMovingMeanAndVar();
/**
* Forward functions: reset buffers(input, weight, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in, void resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
...@@ -115,12 +107,6 @@ protected: ...@@ -115,12 +107,6 @@ protected:
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, weight, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in, void resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
......
...@@ -60,8 +60,6 @@ void MKLDNNConcatLayer::reshape( ...@@ -60,8 +60,6 @@ void MKLDNNConcatLayer::reshape(
void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdBuffers(inVals_, out); resetFwdBuffers(inVals_, out);
in = inVals_[0]; in = inVals_[0];
...@@ -74,8 +72,6 @@ void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline, ...@@ -74,8 +72,6 @@ void MKLDNNConcatLayer::resetFwd(std::vector<primitive>& pipeline,
void MKLDNNConcatLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNConcatLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetBwdBuffers(inGrads_, out); resetBwdBuffers(inGrads_, out);
in = inGrads_[0]; in = inGrads_[0];
......
...@@ -51,14 +51,10 @@ public: ...@@ -51,14 +51,10 @@ public:
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void printSizeInfo() override { void printSizeInfo() override {
...@@ -99,11 +95,6 @@ public: ...@@ -99,11 +95,6 @@ public:
} }
protected: protected:
/**
* Forward functions: reset buffers(inputs, output, bias),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void resetFwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr<mkldnn::concat::primitive_desc>& pd, void resetFwdPD(std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
...@@ -113,11 +104,6 @@ protected: ...@@ -113,11 +104,6 @@ protected:
std::shared_ptr<mkldnn::concat::primitive_desc>& pd, std::shared_ptr<mkldnn::concat::primitive_desc>& pd,
std::vector<MKLDNNMatrixPtr>& inputs, std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(inputs, output, bias)
* reset primitives and pipeline
*/
void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs, void resetBwdBuffers(std::vector<MKLDNNMatrixPtr>& inputs,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline, void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
......
...@@ -106,20 +106,16 @@ void MKLDNNConvLayer::reshape( ...@@ -106,20 +106,16 @@ void MKLDNNConvLayer::reshape(
void MKLDNNConvLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNConvLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdPD(fwdPD_); resetFwdPD(fwdPD_);
resetFwdBuffers(fwdPD_, in, wgt, bias, out); resetFwdBuffers(fwdPD_, in, wgtVal_, biasVal_, out);
resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); resetFwdPipeline(pipeline, fwdPD_, in, wgtVal_, biasVal_, out);
} }
void MKLDNNConvLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNConvLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
std::shared_ptr<conv_bwdWgt::primitive_desc> bwdWgtPD; std::shared_ptr<conv_bwdWgt::primitive_desc> bwdWgtPD;
std::shared_ptr<conv_bwdData::primitive_desc> bwdDataPD; std::shared_ptr<conv_bwdData::primitive_desc> bwdDataPD;
...@@ -128,9 +124,9 @@ void MKLDNNConvLayer::resetBwd(std::vector<primitive>& pipeline, ...@@ -128,9 +124,9 @@ void MKLDNNConvLayer::resetBwd(std::vector<primitive>& pipeline,
resetBwdDataPD(bwdDataPD); resetBwdDataPD(bwdDataPD);
resetBwdBuffers(bwdWgtPD, bwdDataPD, in, wgt, bias, out); resetBwdBuffers(bwdWgtPD, bwdDataPD, in, wgtGrad_, biasGrad_, out);
resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgtGrad_, biasGrad_, out);
} }
void MKLDNNConvLayer::updateWeights(const UpdateCallback& callback) { void MKLDNNConvLayer::updateWeights(const UpdateCallback& callback) {
......
...@@ -73,14 +73,10 @@ public: ...@@ -73,14 +73,10 @@ public:
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override; void updateWeights(const UpdateCallback& callback) override;
...@@ -107,48 +103,26 @@ protected: ...@@ -107,48 +103,26 @@ protected:
mkldnn::memory::dims& padL, mkldnn::memory::dims& padL,
mkldnn::memory::dims& padR); mkldnn::memory::dims& padR);
/**
* reset the forward primitive descriptor.
*/
void resetFwdPD(std::shared_ptr<conv_fwd::primitive_desc>& pd); void resetFwdPD(std::shared_ptr<conv_fwd::primitive_desc>& pd);
/**
* reset the MKLDNNMatrix buffers used in forward.
*/
void resetFwdBuffers(std::shared_ptr<conv_fwd::primitive_desc>& pd, void resetFwdBuffers(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* reset the forward pipeline.
*/
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline, void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<conv_fwd::primitive_desc>& pd, std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* reset the backward weight primitive descriptor.
*/
void resetBwdWgtPD(std::shared_ptr<conv_bwdWgt::primitive_desc>& pd); void resetBwdWgtPD(std::shared_ptr<conv_bwdWgt::primitive_desc>& pd);
/**
* reset the backward data primitive descriptor.
*/
void resetBwdDataPD(std::shared_ptr<conv_bwdData::primitive_desc>& pd); void resetBwdDataPD(std::shared_ptr<conv_bwdData::primitive_desc>& pd);
/**
* reset the MKLDNNMatrix buffers used in backward.
*/
void resetBwdBuffers(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD, void resetBwdBuffers(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
std::shared_ptr<conv_bwdData::primitive_desc>& dataPD, std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* reset the backward pipeline.
*/
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline, void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD, std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
std::shared_ptr<conv_bwdData::primitive_desc>& dataPD, std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
......
...@@ -88,31 +88,27 @@ void MKLDNNFcLayer::reshape( ...@@ -88,31 +88,27 @@ void MKLDNNFcLayer::reshape(
void MKLDNNFcLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNFcLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdBuffers(in, wgt, bias, out); resetFwdBuffers(in, wgtVal_, biasVal_, out);
resetFwdPD(fwdPD_, in, wgt, bias, out); resetFwdPD(fwdPD_, in, wgtVal_, biasVal_, out);
resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out); resetFwdPipeline(pipeline, fwdPD_, in, wgtVal_, biasVal_, out);
} }
void MKLDNNFcLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNFcLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
std::shared_ptr<fc_bwdWgt::primitive_desc> bwdWgtPD; std::shared_ptr<fc_bwdWgt::primitive_desc> bwdWgtPD;
std::shared_ptr<fc_bwdData::primitive_desc> bwdDataPD; std::shared_ptr<fc_bwdData::primitive_desc> bwdDataPD;
resetBwdBuffers(in, wgt, bias, out); resetBwdBuffers(in, wgtGrad_, biasGrad_, out);
resetBwdWgtPD(bwdWgtPD, wgt, bias, out); resetBwdWgtPD(bwdWgtPD, wgtGrad_, biasGrad_, out);
resetBwdDataPD(bwdDataPD, in, out); resetBwdDataPD(bwdDataPD, in, out);
resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out); resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgtGrad_, biasGrad_, out);
} }
void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) { void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) {
......
...@@ -56,14 +56,10 @@ public: ...@@ -56,14 +56,10 @@ public:
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void updateWeights(const UpdateCallback& callback) override; void updateWeights(const UpdateCallback& callback) override;
...@@ -73,11 +69,6 @@ public: ...@@ -73,11 +69,6 @@ public:
void convertWeightsToPaddle() override; void convertWeightsToPaddle() override;
protected: protected:
/**
* Forward functions: reset buffers(input, output, weight and bias),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in, void resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
...@@ -93,13 +84,6 @@ protected: ...@@ -93,13 +84,6 @@ protected:
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, output, weight and bias),
* reset primitive descriptor for backward weight,
* reset primitive descriptor for backward data,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in, void resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
......
...@@ -58,7 +58,7 @@ void MKLDNNLayer::forward(PassType passType) { ...@@ -58,7 +58,7 @@ void MKLDNNLayer::forward(PassType passType) {
printSizeInfo(); printSizeInfo();
// all cpu device output grad or value share output's // all cpu device output grad or value share output's
shareCPUDevice(); shareCPUDevice();
resetFwd(pipelineFwd_, inVal_, wgtVal_, biasVal_, outVal_); resetFwd(pipelineFwd_, inVal_, outVal_);
// MKLDNNLayer output value should be MKLDNNMatrix // MKLDNNLayer output value should be MKLDNNMatrix
// so external output value is necessary. // so external output value is necessary.
// Then external input value is not necessary, // Then external input value is not necessary,
...@@ -101,7 +101,7 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) { ...@@ -101,7 +101,7 @@ void MKLDNNLayer::backward(const UpdateCallback& callback) {
pipelineBwd_.clear(); pipelineBwd_.clear();
pipelineMergeGrad_.clear(); pipelineMergeGrad_.clear();
mergeGrad_ = nullptr; mergeGrad_ = nullptr;
resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_); resetBwd(pipelineBwd_, inGrad_, outGrad_);
// external output grad is not necessary // external output grad is not necessary
// since output may be mkldnn internal buffer or merge them directly. // since output may be mkldnn internal buffer or merge them directly.
CHECK(outGrad_) << "internal output grad is necessary"; CHECK(outGrad_) << "internal output grad is necessary";
......
...@@ -134,21 +134,19 @@ public: ...@@ -134,21 +134,19 @@ public:
/** /**
* reset the mkldnn forward primitve and memories * reset the mkldnn forward primitve and memories
* only would be called when input size changes * only would be called when input size changes
* weight and bias buffers should be coverd by child class itself
*/ */
virtual void resetFwd(std::vector<mkldnn::primitive>& pipeline, virtual void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) = 0; MKLDNNMatrixPtr& out) = 0;
/** /**
* reset the mkldnn backward primitve and memories * reset the mkldnn backward primitve and memories
* only would be called when needed * only would be called when needed
* weight and bias buffers should be coverd by child class itself
*/ */
virtual void resetBwd(std::vector<mkldnn::primitive>& pipeline, virtual void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) = 0; MKLDNNMatrixPtr& out) = 0;
/** /**
......
...@@ -75,8 +75,6 @@ void MKLDNNPoolLayer::reshape( ...@@ -75,8 +75,6 @@ void MKLDNNPoolLayer::reshape(
void MKLDNNPoolLayer::resetFwd(std::vector<primitive>& pipeline, void MKLDNNPoolLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
resetFwdBuffers(in, out); resetFwdBuffers(in, out);
...@@ -87,8 +85,6 @@ void MKLDNNPoolLayer::resetFwd(std::vector<primitive>& pipeline, ...@@ -87,8 +85,6 @@ void MKLDNNPoolLayer::resetFwd(std::vector<primitive>& pipeline,
void MKLDNNPoolLayer::resetBwd(std::vector<primitive>& pipeline, void MKLDNNPoolLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
std::shared_ptr<pool_bwd::primitive_desc> pd; std::shared_ptr<pool_bwd::primitive_desc> pd;
......
...@@ -57,14 +57,10 @@ public: ...@@ -57,14 +57,10 @@ public:
void resetFwd(std::vector<mkldnn::primitive>& pipeline, void resetFwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void resetBwd(std::vector<mkldnn::primitive>& pipeline, void resetBwd(std::vector<mkldnn::primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) override; MKLDNNMatrixPtr& out) override;
void printSizeInfo() override { void printSizeInfo() override {
...@@ -75,11 +71,6 @@ public: ...@@ -75,11 +71,6 @@ public:
} }
protected: protected:
/**
* Forward functions: reset buffers(input, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); void resetFwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr<pool_fwd::primitive_desc>& pd, void resetFwdPD(std::shared_ptr<pool_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr in, MKLDNNMatrixPtr in,
...@@ -88,12 +79,6 @@ protected: ...@@ -88,12 +79,6 @@ protected:
std::shared_ptr<pool_fwd::primitive_desc>& pd, std::shared_ptr<pool_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& out); MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, output),
* reset primitive descriptor,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out); void resetBwdBuffers(MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& out);
void resetBwdPD(std::shared_ptr<pool_bwd::primitive_desc>& pd, void resetBwdPD(std::shared_ptr<pool_bwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册