提交 4eecd0c2 编写于 作者: T tensor-tang

use MKLDNNMatrix in fc backward

上级 4bffbd30
...@@ -158,10 +158,8 @@ void MKLDNNFcLayer::resetFwd() { ...@@ -158,10 +158,8 @@ void MKLDNNFcLayer::resetFwd() {
hasSpatial_ ? memory::dims{oc_, ic_, ih_, iw_} : memory::dims{oc_, ic_}, hasSpatial_ ? memory::dims{oc_, ic_, ih_, iw_} : memory::dims{oc_, ic_},
hasSpatial_ ? format::oihw : format::oi, hasSpatial_ ? format::oihw : format::oi,
engine_); engine_);
biasVal_ = biasVal_ =
hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr; hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr;
outVal_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_); outVal_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_);
// change original output to mkldnn output // change original output to mkldnn output
...@@ -193,46 +191,41 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -193,46 +191,41 @@ void MKLDNNFcLayer::resetBwd() {
return; return;
} }
needResetBwd_ = false; needResetBwd_ = false;
bool hasBias = biases_ && biases_->getWGrad(); bool hasBias = biases_ && biases_->getWGrad();
real* iData = getInputValue(0)->getData();
real* iDiff = getInputGrad(0) != nullptr ? getInputGrad(0)->getData() : NULL;
real* oDiff = getOutputGrad()->getData();
real* wDiff = weight_->getWGrad()->getData();
real* bDiff = hasBias ? biases_->getWGrad()->getData() : NULL;
/// backward weight /// backward weight
// create memory desc for backward memory CHECK(inVal_) << "Should have input value";
memory::desc iMD = hasSpatial_ ? createMD({bs_, ic_, ih_, iw_}, format::nchw) const MatrixPtr& wgt = weight_->getWGrad();
: createMD({bs_, ic_}, format::nc); const MatrixPtr& bias = hasBias ? biases_->getWGrad() : nullptr;
memory::desc wMD = hasSpatial_ ? createMD({oc_, ic_, ih_, iw_}, format::oihw) const MatrixPtr& out = output_.grad;
: createMD({oc_, ic_}, format::oi);
memory::desc oMD = createMD({bs_, oc_}, format::nc); wgtGrad_ = MKLDNNMatrix::create(
memory::desc bMD = bDiff != NULL ? createMD({oc_}, format::x) wgt, wgtVal_->getDims(), wgtVal_->getFormat(), engine_);
: createMD({}, format::format_undef); biasGrad_ =
hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr;
if (inVal_) {
// update data
inVal_->set_data_handle(iData);
} else {
LOG(FATAL) << "Should not be empty";
// inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData));
}
// create memory primitive desc and memory self outGrad_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_);
wgtGrad_.reset(new memory(memory::primitive_desc(wMD, engine_), wDiff)); // change original output to mkldnn output
outGrad_.reset(new memory(memory::primitive_desc(oMD, engine_), oDiff)); // TODO: right?
output_.grad = std::dynamic_pointer_cast<Matrix>(outGrad_);
fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, iMD, wMD, oMD); // create memory primitive desc
fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward,
inVal_->getMD(),
wgtGrad_->getMD(),
outGrad_->getMD());
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_); fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
fc_bwdWgt::desc bwdWgtDesc = bDiff != NULL fc_bwdWgt::desc bwdWgtDesc =
? fc_bwdWgt::desc(iMD, wMD, bMD, oMD) hasBias ? fc_bwdWgt::desc(inVal_->getMD(),
: fc_bwdWgt::desc(iMD, wMD, oMD); wgtGrad_->getMD(),
biasGrad_->getMD(),
outGrad_->getMD())
: fc_bwdWgt::desc(
inVal_->getMD(), wgtGrad_->getMD(), outGrad_->getMD());
fc_bwdWgt::primitive_desc bwdWgtPD = fc_bwdWgt::primitive_desc bwdWgtPD =
fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD); fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
if (bDiff != NULL) { if (hasBias) {
biasGrad_.reset(new memory(memory::primitive_desc(bMD, engine_), bDiff));
bwdWgt_.reset( bwdWgt_.reset(
new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_)); new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_));
} else { } else {
...@@ -242,13 +235,19 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -242,13 +235,19 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdWgt_); pipelineBwd_.push_back(*bwdWgt_);
/// backward data /// backward data
if (iDiff == NULL) { const MatrixPtr& in = getInputGrad(0);
if (in == nullptr) {
return; return;
} }
fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(iMD, wMD, oMD); fc_bwdData::desc bwdDataDesc =
fc_bwdData::desc(inVal_->getMD(), wgtGrad_->getMD(), outGrad_->getMD());
fc_bwdData::primitive_desc bwdDataPD = fc_bwdData::primitive_desc bwdDataPD =
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD); fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
inGrad_.reset(new memory(memory::primitive_desc(iMD, engine_), iDiff));
// TODO: check right, just from ingrad?
inGrad_ =
MKLDNNMatrix::create(in, inVal_->getDims(), inVal_->getFormat(), engine_);
CHECK(wgtVal_) << "Should have weight memory"; CHECK(wgtVal_) << "Should have weight memory";
bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_)); bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
pipelineBwd_.push_back(*bwdData_); pipelineBwd_.push_back(*bwdData_);
...@@ -264,7 +263,7 @@ void MKLDNNFcLayer::forward(PassType passType) { ...@@ -264,7 +263,7 @@ void MKLDNNFcLayer::forward(PassType passType) {
// update input data // update input data
// since it might be changed if this is after data layer // since it might be changed if this is after data layer
real* iData = getInputValue(0)->getData(); real* iData = getInputValue(0)->getData();
inVal_->set_data_handle(iData); inVal_->updateData(iData);
// just submit forward pipeline // just submit forward pipeline
stream_->submit(pipelineFwd_); stream_->submit(pipelineFwd_);
...@@ -288,7 +287,7 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) { ...@@ -288,7 +287,7 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
// update diff // update diff
real* oDiff = getOutputGrad()->getData(); real* oDiff = getOutputGrad()->getData();
outGrad_->set_data_handle(oDiff); outGrad_->updateData(oDiff);
// just sumbmit backward pipeline // just sumbmit backward pipeline
stream_->submit(pipelineBwd_); stream_->submit(pipelineBwd_);
......
...@@ -52,16 +52,15 @@ protected: ...@@ -52,16 +52,15 @@ protected:
std::vector<mkldnn::primitive> pipelineFwd_; std::vector<mkldnn::primitive> pipelineFwd_;
std::vector<mkldnn::primitive> pipelineBwd_; std::vector<mkldnn::primitive> pipelineBwd_;
// TODO(TJ): change below memory as MKLDNNMatrixPtr type // MKLDNNMatrixPtr
// MKLDNNMatrixPtr ;
MKLDNNMatrixPtr inVal_; MKLDNNMatrixPtr inVal_;
std::shared_ptr<mkldnn::memory> inGrad_; MKLDNNMatrixPtr inGrad_;
MKLDNNMatrixPtr outVal_; MKLDNNMatrixPtr outVal_;
std::shared_ptr<mkldnn::memory> outGrad_; MKLDNNMatrixPtr outGrad_;
MKLDNNMatrixPtr wgtVal_; MKLDNNMatrixPtr wgtVal_;
std::shared_ptr<mkldnn::memory> wgtGrad_; MKLDNNMatrixPtr wgtGrad_;
MKLDNNMatrixPtr biasVal_; MKLDNNMatrixPtr biasVal_;
std::shared_ptr<mkldnn::memory> biasGrad_; MKLDNNMatrixPtr biasGrad_;
public: public:
explicit MKLDNNLayer(const LayerConfig& config) explicit MKLDNNLayer(const LayerConfig& config)
...@@ -84,17 +83,24 @@ public: ...@@ -84,17 +83,24 @@ public:
virtual bool init(const LayerMap& layerMap, virtual bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) { const ParameterMap& parameterMap) {
CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
<< "Please set WITH_MKLDNN=ON "
<< "and set use_mkldnn=True";
if (useGpu_ == true) {
LOG(WARNING) << "Do not support GPU yet, will change to useGpu = false";
useGpu_ = false;
}
// set device id before Layer::init
setDevice(MKLDNN_DEVICE);
// change param device to MKLDNN device
setParamsDevice(MKLDNN_DEVICE, parameterMap);
if (!Layer::init(layerMap, parameterMap)) { if (!Layer::init(layerMap, parameterMap)) {
return false; return false;
} }
CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
<< "Please set WITH_MKLDNN=ON "
<< "and set use_mkldnn=True";
stream_.reset(new MKLDNNStream()); stream_.reset(new MKLDNNStream());
engine_ = CPUEngine::Instance().getEngine(); engine_ = CPUEngine::Instance().getEngine();
setDeviceID(MKLDNN_DEVICE);
return true; return true;
} }
...@@ -136,10 +142,33 @@ public: ...@@ -136,10 +142,33 @@ public:
} }
protected: protected:
void setDeviceID(int id) { /**
deviceId_ = id; * Set deviceId of this layer.
output_.deviceId = id; */
// TODO: handle mkldnn device or add mkldnn device to other void setDevice(int id) { deviceId_ = id; }
/**
* Set deviceId of the params used in this layer.
*/
void setParamsDevice(int id, const ParameterMap& parameterMap) {
for (auto& inputConfig : config_.inputs()) {
if (inputConfig.has_input_parameter_name()) {
ParameterPtr parameter;
std::string name = inputConfig.input_parameter_name();
CHECK(mapGet(name, parameterMap, &parameter))
<< "Cannot find input parameter " << name << " for layer "
<< getName();
parameter->setDevice(id);
}
}
if (config_.has_bias_parameter_name()) {
ParameterPtr parameter;
std::string name = config_.bias_parameter_name();
CHECK(mapGet(name, parameterMap, &parameter))
<< "Cannot find bias parameter " << name << " for layer "
<< getName();
parameter->setDevice(id);
}
} }
}; };
......
...@@ -44,6 +44,8 @@ public: ...@@ -44,6 +44,8 @@ public:
set_data_handle(CpuMatrix::getData()); set_data_handle(CpuMatrix::getData());
} }
~MKLDNNMatrix() {}
static MKLDNNMatrixPtr create( static MKLDNNMatrixPtr create(
const MatrixPtr& m, const MatrixPtr& m,
mkldnn::memory::dims dims, mkldnn::memory::dims dims,
...@@ -52,21 +54,42 @@ public: ...@@ -52,21 +54,42 @@ public:
mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32); mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32);
/** /**
* Get primitive descriptor * Get primitive descriptor.
*/ */
mkldnn::memory::primitive_desc getPD() { return this->get_primitive_desc(); } mkldnn::memory::primitive_desc getPD() { return this->get_primitive_desc(); }
/** /**
* Get memory descriptor * Get memory descriptor.
*/ */
mkldnn::memory::desc getMD() { return getPD().desc(); } mkldnn::memory::desc getMD() { return getPD().desc(); }
/** /**
* Get format * Get dims.
*/ */
int getFormat() { return getMD().data.format; } mkldnn::memory::dims getDims() {
mkldnn::memory::dims dst;
int* src = getMD().data.dims;
int ndims = getMD().data.ndims;
dst.resize(ndims);
for (int i = 0; i < ndims; ++i) {
dst[i] = src[i];
}
return dst;
}
~MKLDNNMatrix() {} /**
* Get format.
*/
mkldnn::memory::format getFormat() {
return (mkldnn::memory::format)(getMD().data.format);
}
/**
* Update the memory data handle.
* Caution: This will not check the buffer size of the data,
* it should be coverd by user.
*/
void updateData(void* data) { set_data_handle(data); }
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册