提交 f2317b67 编写于 作者: T tensor-tang

separate resetFwd and resetBwd to some sub functions

上级 66fdbd0c
......@@ -18,6 +18,9 @@ limitations under the License. */
#include "mkldnn.hpp"
namespace paddle {
typedef mkldnn::convolution_forward conv_fwd;
typedef mkldnn::convolution_backward_weights conv_bwdWgt;
typedef mkldnn::convolution_backward_data conv_bwdData;
/**
* @brief A subclass of MKLDNNLayer conv layer.
......@@ -43,7 +46,7 @@ protected:
std::shared_ptr<mkldnn::reorder> cvtWgtVal_;
// save forward primitive_desc, which can be used backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> fwdPD_;
std::shared_ptr<conv_fwd::primitive_desc> fwdPD_;
// MKLDNNMatrixPtr which should be created from CPU Device
MKLDNNMatrixPtr cpuInVal_;
......@@ -99,7 +102,6 @@ public:
void convertWeightsToPaddle() override;
protected:
void printSizeInfo() override {
MKLDNNLayer::printSizeInfo();
VLOG(MKLDNN_SIZES) << getName() << ": fh: " << fh_ << ", fw: " << fw_
......@@ -116,6 +118,7 @@ protected:
VLOG(MKLDNN_FMTS) << " >>> " << cpuOutVal_->getFormat();
}
}
void printGradFormatFlow() override {
if (cpuInGrad_) {
VLOG(MKLDNN_FMTS) << cpuInGrad_->getFormat() << " <<<";
......@@ -126,6 +129,107 @@ protected:
}
}
protected:
/**
* load the dims settings of this conv
*/
void loadConvSettings(mkldnn::memory::dims& wgt,
mkldnn::memory::dims& bias,
mkldnn::memory::dims& stride,
mkldnn::memory::dims& dilation,
mkldnn::memory::dims& padL,
mkldnn::memory::dims& padR);
/**
* reset the forward primitive descriptor.
*/
void resetFwdPD(std::shared_ptr<conv_fwd::primitive_desc>& pd);
/**
* reset the MKLDNNMatrix buffers used in forward.
*/
void resetFwdBuffers(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* reset the forward pipeline.
*/
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* reset MKLDNNMatrix of input value
*/
void resetInValue(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in);
/**
* reset MKLDNNMatrix of weight and bias value
*/
void resetWgtBiasValue(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias);
/**
* reset MKLDNNMatrix of output value
*/
void resetOutValue(std::shared_ptr<conv_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& out);
/**
* reset the backward weight primitive descriptor.
*/
void resetBwdWgtPD(std::shared_ptr<conv_bwdWgt::primitive_desc>& pd);
/**
* reset the backward data primitive descriptor.
*/
void resetBwdDataPD(std::shared_ptr<conv_bwdData::primitive_desc>& pd);
/**
* reset the MKLDNNMatrix buffers used in backward.
*/
void resetBwdBuffers(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* reset the backward pipeline.
*/
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* reset MKLDNNMatrix of output grad
*/
void resetOutGrad(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
MKLDNNMatrixPtr& out);
/**
* reset MKLDNNMatrix of weight and bias grad
*/
void resetWgtBiasGrad(std::shared_ptr<conv_bwdWgt::primitive_desc>& wgtPD,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias);
/**
* reset MKLDNNMatrix of input grad
*/
void resetInGrad(std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& in);
/**
* reset MKLDNNMatrix of weight value for backward data
* since the primitive_desc would be different with wgtVal_
*/
void resetWgtValBwdData(std::shared_ptr<conv_bwdData::primitive_desc>& dataPD,
MKLDNNMatrixPtr& wgt);
/**
* get padding_r according to
* https://github.com/01org/mkl-dnn/blob/master/tests/gtests/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册