提交 b0d9b68a 编写于 作者: T tensor-tang

unify functions of mkldnn_fc and refine comments

上级 13d00053
...@@ -285,10 +285,9 @@ void MKLDNNConvLayer::resetWgtBiasValue( ...@@ -285,10 +285,9 @@ void MKLDNNConvLayer::resetWgtBiasValue(
wgt = MKLDNNMatrix::create(weight_->getW(), pd->weights_primitive_desc()); wgt = MKLDNNMatrix::create(weight_->getW(), pd->weights_primitive_desc());
VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat(); VLOG(MKLDNN_FMTS) << "Weight value format: " << wgt->getFormat();
bias = nullptr; bias = (biases_ && biases_->getW())
if (biases_ && biases_->getW()) { ? MKLDNNMatrix::create(biases_->getW(), pd->bias_primitive_desc())
bias = MKLDNNMatrix::create(biases_->getW(), pd->bias_primitive_desc()); : nullptr;
}
} }
void MKLDNNConvLayer::resetOutValue( void MKLDNNConvLayer::resetOutValue(
...@@ -356,6 +355,7 @@ void MKLDNNConvLayer::resetBwdWgtPD( ...@@ -356,6 +355,7 @@ void MKLDNNConvLayer::resetBwdWgtPD(
void MKLDNNConvLayer::resetBwdDataPD( void MKLDNNConvLayer::resetBwdDataPD(
std::shared_ptr<conv_bwdData::primitive_desc>& pd) { std::shared_ptr<conv_bwdData::primitive_desc>& pd) {
pd = nullptr;
if (inputLayers_[0]->getOutput().grad == nullptr) { if (inputLayers_[0]->getOutput().grad == nullptr) {
return; return;
} }
...@@ -476,6 +476,7 @@ void MKLDNNConvLayer::resetWgtBiasGrad( ...@@ -476,6 +476,7 @@ void MKLDNNConvLayer::resetWgtBiasGrad(
<< "primitive desc of weight grad and value should be equal"; << "primitive desc of weight grad and value should be equal";
VLOG(MKLDNN_FMTS) << "weight grad format: " << wgt->getFormat(); VLOG(MKLDNN_FMTS) << "weight grad format: " << wgt->getFormat();
bias = nullptr;
if (biasVal_ == nullptr) { if (biasVal_ == nullptr) {
return; return;
} }
......
...@@ -17,9 +17,6 @@ limitations under the License. */ ...@@ -17,9 +17,6 @@ limitations under the License. */
using namespace mkldnn; // NOLINT using namespace mkldnn; // NOLINT
typedef memory::format format; typedef memory::format format;
typedef inner_product_forward fc_fwd;
typedef inner_product_backward_weights fc_bwdWgt;
typedef inner_product_backward_data fc_bwdData;
namespace paddle { namespace paddle {
...@@ -93,35 +90,88 @@ void MKLDNNFcLayer::reshape( ...@@ -93,35 +90,88 @@ void MKLDNNFcLayer::reshape(
printSizeInfo(); printSizeInfo();
} }
void MKLDNNFcLayer::resetFwd(std::vector<mkldnn::primitive>& pipeline, void MKLDNNFcLayer::resetFwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in, MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
pipeline.clear(); resetFwdBuffers(in, wgt, bias, out);
bool hasBias = biases_ && biases_->getW();
const MatrixPtr& wgtVal = weight_->getW(); resetFwdPD(fwdPD_, in, wgt, bias, out);
const MatrixPtr& biasVal = hasBias ? biases_->getW() : nullptr;
const MatrixPtr& outVal = output_.value; resetFwdPipeline(pipeline, fwdPD_, in, wgt, bias, out);
printValueFormatFlow();
}
void MKLDNNFcLayer::resetBwd(std::vector<primitive>& pipeline,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
std::shared_ptr<fc_bwdWgt::primitive_desc> bwdWgtPD;
std::shared_ptr<fc_bwdData::primitive_desc> bwdDataPD;
resetBwdBuffers(in, wgt, bias, out);
resetBwdWgtPD(bwdWgtPD, wgt, bias, out);
resetBwdDataPD(bwdDataPD, in, out);
resetBwdPipeline(pipeline, bwdWgtPD, bwdDataPD, in, wgt, bias, out);
printGradFormatFlow();
}
void MKLDNNFcLayer::updateInputData() {
inVal_->setData(getInputValue(0, CPU_DEVICE)->getData());
}
void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) {
weight_->getParameterPtr()->incUpdate(callback);
if (biases_ && biases_->getWGrad()) {
biases_->getParameterPtr()->incUpdate(callback);
}
}
void MKLDNNFcLayer::resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
resetInValue(in);
resetWgtBiasValue(wgt, bias);
resetOutValue(out);
}
void MKLDNNFcLayer::resetInValue(MKLDNNMatrixPtr& in) {
if (inputIsOnlyMKLDNN()) { if (inputIsOnlyMKLDNN()) {
const MatrixPtr& inVal = getInputValue(0); const MatrixPtr& dnnIn = getInputValue(0);
in = std::dynamic_pointer_cast<MKLDNNMatrix>(inVal); in = std::dynamic_pointer_cast<MKLDNNMatrix>(dnnIn);
CHECK(in) << "Input should be MKLDNNMatrix"; CHECK(in) << "Input should be MKLDNNMatrix";
} else { } else {
CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet"; CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet";
const MatrixPtr& inVal = getInputValue(0, CPU_DEVICE); const MatrixPtr& cpuIn = getInputValue(0, CPU_DEVICE);
in = MKLDNNMatrix::create( in = MKLDNNMatrix::create(
inVal, memory::dims{bs_, ic_, ih_, iw_}, format::nchw, engine_); cpuIn, {bs_, ic_, ih_, iw_}, format::nchw, engine_);
} }
in->downSpatial(); in->downSpatial();
}
void MKLDNNFcLayer::resetWgtBiasValue(MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias) {
wgt = MKLDNNMatrix::create( wgt = MKLDNNMatrix::create(
wgtVal, memory::dims{oc_, ic_, ih_, iw_}, format::oihw, engine_); weight_->getW(), {oc_, ic_, ih_, iw_}, format::oihw, engine_);
wgt->downSpatial(); wgt->downSpatial();
bias = hasBias ? MKLDNNMatrix::create(biasVal, {oc_}, format::x, engine_)
bias = (biases_ && biases_->getW())
? MKLDNNMatrix::create(biases_->getW(), {oc_}, format::x, engine_)
: nullptr; : nullptr;
out = MKLDNNMatrix::create(outVal, {bs_, oc_}, format::nc, engine_); }
void MKLDNNFcLayer::resetOutValue(MKLDNNMatrixPtr& out) {
out = MKLDNNMatrix::create(output_.value, {bs_, oc_}, format::nc, engine_);
// change original output value to mkldnn output value // change original output value to mkldnn output value
output_.value = std::dynamic_pointer_cast<Matrix>(out); output_.value = std::dynamic_pointer_cast<Matrix>(out);
if (!outputIsOnlyMKLDNN()) { if (!outputIsOnlyMKLDNN()) {
...@@ -129,10 +179,18 @@ void MKLDNNFcLayer::resetFwd(std::vector<mkldnn::primitive>& pipeline, ...@@ -129,10 +179,18 @@ void MKLDNNFcLayer::resetFwd(std::vector<mkldnn::primitive>& pipeline,
// just share point // just share point
getOutput(CPU_DEVICE).value->setData(output_.value->getData()); getOutput(CPU_DEVICE).value->setData(output_.value->getData());
} }
}
// create forward handle void MKLDNNFcLayer::resetFwdPD(std::shared_ptr<fc_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr in,
MKLDNNMatrixPtr wgt,
MKLDNNMatrixPtr bias,
MKLDNNMatrixPtr out) {
CHECK(in);
CHECK(wgt);
CHECK(out);
prop_kind pk = prop_kind::forward; prop_kind pk = prop_kind::forward;
fc_fwd::desc fwdDesc = hasBias ? fc_fwd::desc(pk, fc_fwd::desc fwdDesc = bias != nullptr ? fc_fwd::desc(pk,
in->getMemoryDesc(), in->getMemoryDesc(),
wgt->getMemoryDesc(), wgt->getMemoryDesc(),
bias->getMemoryDesc(), bias->getMemoryDesc(),
...@@ -141,34 +199,39 @@ void MKLDNNFcLayer::resetFwd(std::vector<mkldnn::primitive>& pipeline, ...@@ -141,34 +199,39 @@ void MKLDNNFcLayer::resetFwd(std::vector<mkldnn::primitive>& pipeline,
in->getMemoryDesc(), in->getMemoryDesc(),
wgt->getMemoryDesc(), wgt->getMemoryDesc(),
out->getMemoryDesc()); out->getMemoryDesc());
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); pd.reset(new fc_fwd::primitive_desc(fwdDesc, engine_));
if (hasBias) { }
fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *bias, *out));
void MKLDNNFcLayer::resetFwdPipeline(
std::vector<primitive>& pipeline,
std::shared_ptr<fc_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
pipeline.clear();
if (bias) {
fwd_.reset(new fc_fwd(*pd, *in, *wgt, *bias, *out));
} else { } else {
fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *out)); fwd_.reset(new fc_fwd(*pd, *in, *wgt, *out));
} }
printValueFormatFlow();
pipeline.push_back(*fwd_); pipeline.push_back(*fwd_);
} }
void MKLDNNFcLayer::resetBwd(std::vector<mkldnn::primitive>& pipeline, void MKLDNNFcLayer::resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias, MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) { MKLDNNMatrixPtr& out) {
pipeline.clear(); resetOutGrad(out);
if (!needResetBwd_) {
return;
}
needResetBwd_ = false;
bool hasBias = biases_ && biases_->getWGrad();
/// backward weight resetWgtBiasGrad(wgt, bias);
CHECK(inVal_) << "Should have input value";
const MatrixPtr& wgtGrad = weight_->getWGrad();
const MatrixPtr& biasGrad = hasBias ? biases_->getWGrad() : nullptr;
resetInGrad(in);
}
void MKLDNNFcLayer::resetOutGrad(MKLDNNMatrixPtr& out) {
// TODO(TJ): merge outgrad // TODO(TJ): merge outgrad
int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
// for MKLDNN device: // for MKLDNN device:
...@@ -178,66 +241,88 @@ void MKLDNNFcLayer::resetBwd(std::vector<mkldnn::primitive>& pipeline, ...@@ -178,66 +241,88 @@ void MKLDNNFcLayer::resetBwd(std::vector<mkldnn::primitive>& pipeline,
// for CPU device: // for CPU device:
// fc do not need to convert from cpu device since output is always nc format // fc do not need to convert from cpu device since output is always nc format
// only need create from cpu device // only need create from cpu device
const MatrixPtr& outGrad = getOutput(device).grad; CHECK(outVal_);
out = MKLDNNMatrix::create(outGrad, outVal_->getPrimitiveDesc()); out =
wgt = MKLDNNMatrix::create(wgtGrad, wgtVal_->getPrimitiveDesc()); MKLDNNMatrix::create(getOutput(device).grad, outVal_->getPrimitiveDesc());
bias = hasBias ? MKLDNNMatrix::create(biasGrad, biasVal_->getPrimitiveDesc()) }
: nullptr;
// create memory primitive desc void MKLDNNFcLayer::resetWgtBiasGrad(MKLDNNMatrixPtr& wgt,
fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, MKLDNNMatrixPtr& bias) {
inVal_->getMemoryDesc(), CHECK(wgtVal_);
wgt->getMemoryDesc(), wgt = MKLDNNMatrix::create(weight_->getWGrad(), wgtVal_->getPrimitiveDesc());
out->getMemoryDesc());
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); bias = nullptr;
fc_bwdWgt::desc bwdWgtDesc = hasBias if (biasVal_ == nullptr) {
? fc_bwdWgt::desc(inVal_->getMemoryDesc(), return;
}
bias =
MKLDNNMatrix::create(biases_->getWGrad(), biasVal_->getPrimitiveDesc());
}
void MKLDNNFcLayer::resetInGrad(MKLDNNMatrixPtr& in) {
in = nullptr;
const MatrixPtr& inGrad = inputLayers_[0]->getOutput().grad;
if (inGrad == nullptr) {
return;
}
// TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done
CHECK(inVal_);
in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc());
}
void MKLDNNFcLayer::resetBwdWgtPD(
std::shared_ptr<fc_bwdWgt::primitive_desc>& pd,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
CHECK(inVal_);
fc_bwdWgt::desc bwdWgtDesc = bias ? fc_bwdWgt::desc(inVal_->getMemoryDesc(),
wgt->getMemoryDesc(), wgt->getMemoryDesc(),
bias->getMemoryDesc(), bias->getMemoryDesc(),
out->getMemoryDesc()) out->getMemoryDesc())
: fc_bwdWgt::desc(inVal_->getMemoryDesc(), : fc_bwdWgt::desc(inVal_->getMemoryDesc(),
wgt->getMemoryDesc(), wgt->getMemoryDesc(),
out->getMemoryDesc()); out->getMemoryDesc());
fc_bwdWgt::primitive_desc bwdWgtPD = pd.reset(new fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, *fwdPD_));
fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD); }
void MKLDNNFcLayer::resetBwdDataPD(
std::shared_ptr<fc_bwdData::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& out) {
pd = nullptr;
if (in == nullptr) {
return;
}
CHECK(wgtVal_);
fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(
in->getMemoryDesc(), wgtVal_->getMemoryDesc(), out->getMemoryDesc());
pd.reset(new fc_bwdData::primitive_desc(bwdDataDesc, engine_, *fwdPD_));
}
if (hasBias) { void MKLDNNFcLayer::resetBwdPipeline(
bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt, *bias)); std::vector<primitive>& pipeline,
std::shared_ptr<fc_bwdWgt::primitive_desc>& bwdWgtPD,
std::shared_ptr<fc_bwdData::primitive_desc>& bwdDataPD,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out) {
pipeline.clear();
CHECK(inVal_);
if (bias) {
bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVal_, *out, *wgt, *bias));
} else { } else {
bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt)); bwdWgt_.reset(new fc_bwdWgt(*bwdWgtPD, *inVal_, *out, *wgt));
} }
pipeline.push_back(*bwdWgt_); pipeline.push_back(*bwdWgt_);
/// backward data if (bwdDataPD == nullptr) {
const MatrixPtr& inGrad = inputLayers_[0]->getOutput().grad;
if (inGrad == nullptr) {
return; return;
} }
if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
// TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done
} else {
in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc());
}
fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(
inVal_->getMemoryDesc(), wgt->getMemoryDesc(), out->getMemoryDesc());
fc_bwdData::primitive_desc bwdDataPD =
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
CHECK(wgtVal_) << "Should have weight memory"; CHECK(wgtVal_) << "Should have weight memory";
bwdData_.reset(new fc_bwdData(bwdDataPD, *out, *wgtVal_, *in)); bwdData_.reset(new fc_bwdData(*bwdDataPD, *out, *wgtVal_, *in));
printGradFormatFlow();
pipeline.push_back(*bwdData_); pipeline.push_back(*bwdData_);
} }
void MKLDNNFcLayer::updateInputData() {
inVal_->setData(getInputValue(0, CPU_DEVICE)->getData());
}
void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) {
weight_->getParameterPtr()->incUpdate(callback);
if (biases_ && biases_->getWGrad()) {
biases_->getParameterPtr()->incUpdate(callback);
}
}
} // namespace paddle } // namespace paddle
...@@ -18,6 +18,9 @@ limitations under the License. */ ...@@ -18,6 +18,9 @@ limitations under the License. */
#include "mkldnn.hpp" #include "mkldnn.hpp"
namespace paddle { namespace paddle {
typedef mkldnn::inner_product_forward fc_fwd;
typedef mkldnn::inner_product_backward_weights fc_bwdWgt;
typedef mkldnn::inner_product_backward_data fc_bwdData;
/** /**
* @brief A subclass of MKLDNNLayer fc layer. * @brief A subclass of MKLDNNLayer fc layer.
...@@ -32,6 +35,9 @@ protected: ...@@ -32,6 +35,9 @@ protected:
// if has already init the weight // if has already init the weight
bool hasInitedWgt_; bool hasInitedWgt_;
// save forward primitive_desc, which can be used backward
std::shared_ptr<fc_fwd::primitive_desc> fwdPD_;
// fc weight and bias // fc weight and bias
std::unique_ptr<Weight> weight_; std::unique_ptr<Weight> weight_;
std::unique_ptr<Weight> biases_; std::unique_ptr<Weight> biases_;
...@@ -67,6 +73,59 @@ public: ...@@ -67,6 +73,59 @@ public:
void convertWeightsFromPaddle() override; void convertWeightsFromPaddle() override;
void convertWeightsToPaddle() override; void convertWeightsToPaddle() override;
protected:
/**
* Forward functions: reset buffers(input, output, weight and bias),
* reset primitive descriptor,
* reset pipeline.
*/
void resetFwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
void resetInValue(MKLDNNMatrixPtr& in);
void resetWgtBiasValue(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias);
void resetOutValue(MKLDNNMatrixPtr& out);
void resetFwdPD(std::shared_ptr<fc_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr in,
MKLDNNMatrixPtr wgt,
MKLDNNMatrixPtr bias,
MKLDNNMatrixPtr out);
void resetFwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<fc_fwd::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
/**
* Backward functions: reset buffers(input, output, weight and bias),
* reset primitive descriptor for backward weight,
* reset primitive descriptor for backward data,
* reset pipeline.
*/
void resetBwdBuffers(MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
void resetOutGrad(MKLDNNMatrixPtr& out);
void resetWgtBiasGrad(MKLDNNMatrixPtr& wgt, MKLDNNMatrixPtr& bias);
void resetInGrad(MKLDNNMatrixPtr& in);
void resetBwdWgtPD(std::shared_ptr<fc_bwdWgt::primitive_desc>& pd,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
void resetBwdDataPD(std::shared_ptr<fc_bwdData::primitive_desc>& pd,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& out);
void resetBwdPipeline(std::vector<mkldnn::primitive>& pipeline,
std::shared_ptr<fc_bwdWgt::primitive_desc>& bwdWgtPD,
std::shared_ptr<fc_bwdData::primitive_desc>& bwdDataPD,
MKLDNNMatrixPtr& in,
MKLDNNMatrixPtr& wgt,
MKLDNNMatrixPtr& bias,
MKLDNNMatrixPtr& out);
}; };
} // namespace paddle } // namespace paddle
...@@ -66,11 +66,12 @@ public: ...@@ -66,11 +66,12 @@ public:
/** /**
* Create reorder primitive. * Create reorder primitive.
* Create a mkldnn::reorder handle for converting src MKLDNNMatrix to dst. * Create a mkldnn::reorder handle for converting src MKLDNNMatrix to dst.
* checkData: for whether to check the data handle of src and dst is the same. * checkData: whether to check the data handle of src and dst.
* if true, means check it and do not want support inplace reorder; * if true, it will check the data and do not allow them equal;
* otherwise do not check data which means the created reorder * otherwise, it will not check them, then the reorder created
* maybe inplace buffer and do not guarantee the logical is correct * may have inplace buffer.
* since not all format or conversion support inplace. * Do not set false, if you can not guarantee the inplace logical
* would work with your reorder.
*/ */
static std::shared_ptr<mkldnn::reorder> createReorder( static std::shared_ptr<mkldnn::reorder> createReorder(
const MKLDNNMatrixPtr& src, const MKLDNNMatrixPtr& src,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册