提交 f2317b67 编写于 作者: T tensor-tang

separate resetFwd and resetBwd to some sub functions

上级 66fdbd0c
...@@ -18,6 +18,9 @@ limitations under the License. */ ...@@ -18,6 +18,9 @@ limitations under the License. */
#include "mkldnn.hpp" #include "mkldnn.hpp"
namespace paddle { namespace paddle {
typedef mkldnn::convolution_forward conv_fwd;
typedef mkldnn::convolution_backward_weights conv_bwdWgt;
typedef mkldnn::convolution_backward_data conv_bwdData;
/** /**
* @brief A subclass of MKLDNNLayer conv layer. * @brief A subclass of MKLDNNLayer conv layer.
...@@ -43,7 +46,7 @@ protected: ...@@ -43,7 +46,7 @@ protected:
std::shared_ptr<mkldnn::reorder> cvtWgtVal_; std::shared_ptr<mkldnn::reorder> cvtWgtVal_;
// save forward primitive_desc, which can be used backward // save forward primitive_desc, which can be used backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwdPD_; std::shared_ptr<conv_fwd::primitive_desc> fwdPD_;
// MKLDNNMatrixPtr which should be created from CPU Device // MKLDNNMatrixPtr which should be created from CPU Device
MKLDNNMatrixPtr cpuInVal_; MKLDNNMatrixPtr cpuInVal_;
...@@ -99,7 +102,6 @@ public: ...@@ -99,7 +102,6 @@ public:
void convertWeightsToPaddle() override; void convertWeightsToPaddle() override;
protected:
void printSizeInfo() override { void printSizeInfo() override {
MKLDNNLayer::printSizeInfo(); MKLDNNLayer::printSizeInfo();
VLOG(MKLDNN_SIZES) << getName() << ": fh: " << fh_ << ", fw: " << fw_ VLOG(MKLDNN_SIZES) << getName() << ": fh: " << fh_ << ", fw: " << fw_
...@@ -116,6 +118,7 @@ protected: ...@@ -116,6 +118,7 @@ protected:
VLOG(MKLDNN_FMTS) << " >>> " << cpuOutVal_->getFormat(); VLOG(MKLDNN_FMTS) << " >>> " << cpuOutVal_->getFormat();
} }
} }
void printGradFormatFlow() override { void printGradFormatFlow() override {
if (cpuInGrad_) { if (cpuInGrad_) {
VLOG(MKLDNN_FMTS) << cpuInGrad_->getFormat() << " <<<"; VLOG(MKLDNN_FMTS) << cpuInGrad_->getFormat() << " <<<";
...@@ -126,6 +129,107 @@ protected: ...@@ -126,6 +129,107 @@ protected:
} }
} }
protected:
/**
* load the dims settings of this conv
*/
void loadConvSettings(mkldnn::memory::dims& wgt,
mkldnn::memory::dims& bias,
mkldnn::memory::dims& stride,
mkldnn::memory::dims& dilation,
mkldnn::memory::dims& padL,
mkldnn::memory::dims& padR);
/**
* reset the forward primitive descriptor.
*/
void resetFwdPD(std::shared_ptr<conv_fwd::primitive_desc>& pd);
/**
* reset the MKLDNNMatrix buffers used in forward.
*/
void resetFwdBuffers(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* reset the forward pipeline.
*/
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* reset MKLDNNMatrix of input value
*/
void resetInValue(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in);
/**
* reset MKLDNNMatrix of weight and bias value
*/
void resetWgtBiasValue(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias);
/**
* reset MKLDNNMatrix of output value
*/
void resetOutValue(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& out);
/**
* reset the backward weight primitive descriptor.
*/
void resetBwdWgtPD(std::shared_ptr<conv_bwdWgt::primitive_desc>& pd);
/**
* reset the backward data primitive descriptor.
*/
void resetBwdDataPD(std::shared_ptr<conv_bwdData::primitive_desc>& pd);
/**
* reset the MKLDNNMatrix buffers used in backward.
*/
void resetBwdBuffers(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* reset the backward pipeline.
*/
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* reset MKLDNNMatrix of output grad
*/
void resetOutGrad(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
MKLDNNMatrixPtr& out);
/**
* reset MKLDNNMatrix of weight and bias grad
*/
void resetWgtBiasGrad(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias);
/**
* reset MKLDNNMatrix of input grad
*/
void resetInGrad(std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& in);
/**
* reset MKLDNNMatrix of weight value for backward data
* since the primitive_desc would be different with wgtVal_
*/
void resetWgtValBwdData(std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& wgt);
/** /**
* get padding_r according to * get padding_r according to
* https://github.com/01org/mkl-dnn/blob/master/tests/gtests/ * https://github.com/01org/mkl-dnn/blob/master/tests/gtests/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册