Commit 4eecd0c2 — authored by tensor-tang

Use MKLDNNMatrix in the FC layer's backward pass

Parent commit: 4bffbd30
......@@ -158,10 +158,8 @@ void MKLDNNFcLayer::resetFwd() {
hasSpatial_ ? memory::dims{oc_, ic_, ih_, iw_} : memory::dims{oc_, ic_},
hasSpatial_ ? format::oihw : format::oi,
engine_);
biasVal_ =
hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr;
outVal_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_);
// change original output to mkldnn output
......@@ -193,46 +191,41 @@ void MKLDNNFcLayer::resetBwd() {
return;
}
needResetBwd_ = false;
bool hasBias = biases_ && biases_->getWGrad();
real* iData = getInputValue(0)->getData();
real* iDiff = getInputGrad(0) != nullptr ? getInputGrad(0)->getData() : NULL;
real* oDiff = getOutputGrad()->getData();
real* wDiff = weight_->getWGrad()->getData();
real* bDiff = hasBias ? biases_->getWGrad()->getData() : NULL;
/// backward weight
// create memory desc for backward memory
memory::desc iMD = hasSpatial_ ? createMD({bs_, ic_, ih_, iw_}, format::nchw)
: createMD({bs_, ic_}, format::nc);
memory::desc wMD = hasSpatial_ ? createMD({oc_, ic_, ih_, iw_}, format::oihw)
: createMD({oc_, ic_}, format::oi);
memory::desc oMD = createMD({bs_, oc_}, format::nc);
memory::desc bMD = bDiff != NULL ? createMD({oc_}, format::x)
: createMD({}, format::format_undef);
if (inVal_) {
// update data
inVal_->set_data_handle(iData);
} else {
LOG(FATAL) << "Should not be empty";
// inVal_.reset(new memory(memory::primitive_desc(iMD, engine_), iData));
}
// create memory primitive desc and memory self
wgtGrad_.reset(new memory(memory::primitive_desc(wMD, engine_), wDiff));
outGrad_.reset(new memory(memory::primitive_desc(oMD, engine_), oDiff));
CHECK(inVal_) << "Should have input value";
const MatrixPtr& wgt = weight_->getWGrad();
const MatrixPtr& bias = hasBias ? biases_->getWGrad() : nullptr;
const MatrixPtr& out = output_.grad;
wgtGrad_ = MKLDNNMatrix::create(
wgt, wgtVal_->getDims(), wgtVal_->getFormat(), engine_);
biasGrad_ =
hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr;
fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward, iMD, wMD, oMD);
outGrad_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_);
// change original output to mkldnn output
// TODO: right?
output_.grad = std::dynamic_pointer_cast<Matrix>(outGrad_);
// create memory primitive desc
fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward,
inVal_->getMD(),
wgtGrad_->getMD(),
outGrad_->getMD());
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
fc_bwdWgt::desc bwdWgtDesc = bDiff != NULL
? fc_bwdWgt::desc(iMD, wMD, bMD, oMD)
: fc_bwdWgt::desc(iMD, wMD, oMD);
fc_bwdWgt::desc bwdWgtDesc =
hasBias ? fc_bwdWgt::desc(inVal_->getMD(),
wgtGrad_->getMD(),
biasGrad_->getMD(),
outGrad_->getMD())
: fc_bwdWgt::desc(
inVal_->getMD(), wgtGrad_->getMD(), outGrad_->getMD());
fc_bwdWgt::primitive_desc bwdWgtPD =
fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
if (bDiff != NULL) {
biasGrad_.reset(new memory(memory::primitive_desc(bMD, engine_), bDiff));
if (hasBias) {
bwdWgt_.reset(
new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_));
} else {
......@@ -242,13 +235,19 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdWgt_);
/// backward data
if (iDiff == NULL) {
const MatrixPtr& in = getInputGrad(0);
if (in == nullptr) {
return;
}
fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(iMD, wMD, oMD);
fc_bwdData::desc bwdDataDesc =
fc_bwdData::desc(inVal_->getMD(), wgtGrad_->getMD(), outGrad_->getMD());
fc_bwdData::primitive_desc bwdDataPD =
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
inGrad_.reset(new memory(memory::primitive_desc(iMD, engine_), iDiff));
// TODO: check right, just from ingrad?
inGrad_ =
MKLDNNMatrix::create(in, inVal_->getDims(), inVal_->getFormat(), engine_);
CHECK(wgtVal_) << "Should have weight memory";
bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
pipelineBwd_.push_back(*bwdData_);
......@@ -264,7 +263,7 @@ void MKLDNNFcLayer::forward(PassType passType) {
// update input data
// since it might be changed if this is after data layer
real* iData = getInputValue(0)->getData();
inVal_->set_data_handle(iData);
inVal_->updateData(iData);
// just submit forward pipeline
stream_->submit(pipelineFwd_);
......@@ -288,7 +287,7 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
// update diff
real* oDiff = getOutputGrad()->getData();
outGrad_->set_data_handle(oDiff);
outGrad_->updateData(oDiff);
// just sumbmit backward pipeline
stream_->submit(pipelineBwd_);
......
......@@ -52,16 +52,15 @@ protected:
std::vector<mkldnn::primitive> pipelineFwd_;
std::vector<mkldnn::primitive> pipelineBwd_;
// TODO(TJ): change below memory as MKLDNNMatrixPtr type
// MKLDNNMatrixPtr ;
// MKLDNNMatrixPtr
MKLDNNMatrixPtr inVal_;
std::shared_ptr<mkldnn::memory> inGrad_;
MKLDNNMatrixPtr inGrad_;
MKLDNNMatrixPtr outVal_;
std::shared_ptr<mkldnn::memory> outGrad_;
MKLDNNMatrixPtr outGrad_;
MKLDNNMatrixPtr wgtVal_;
std::shared_ptr<mkldnn::memory> wgtGrad_;
MKLDNNMatrixPtr wgtGrad_;
MKLDNNMatrixPtr biasVal_;
std::shared_ptr<mkldnn::memory> biasGrad_;
MKLDNNMatrixPtr biasGrad_;
public:
explicit MKLDNNLayer(const LayerConfig& config)
......@@ -84,17 +83,24 @@ public:
// Initialize the MKLDNN layer: validate the use_mkldnn flag, force CPU
// execution, move this layer and its parameters onto the MKLDNN device,
// then create the mkldnn stream and engine.
//
// NOTE(review): this span is rendered from a diff, so both the old and new
// sides are present: the CHECK on FLAGS_use_mkldnn appears twice, and both
// setDevice(MKLDNN_DEVICE) (new) and setDeviceID(MKLDNN_DEVICE) (old) are
// shown — only one of each belongs in the final source.
virtual bool init(const LayerMap& layerMap,
                  const ParameterMap& parameterMap) {
  CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
                          << "Please set WITH_MKLDNN=ON "
                          << "and set use_mkldnn=True";
  // MKLDNN runs on CPU only; silently downgrade a GPU request.
  if (useGpu_ == true) {
    LOG(WARNING) << "Do not support GPU yet, will change to useGpu = false";
    useGpu_ = false;
  }
  // set device id before Layer::init
  setDevice(MKLDNN_DEVICE);
  // change param device to MKLDNN device
  setParamsDevice(MKLDNN_DEVICE, parameterMap);
  if (!Layer::init(layerMap, parameterMap)) {
    return false;
  }
  CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
                          << "Please set WITH_MKLDNN=ON "
                          << "and set use_mkldnn=True";
  stream_.reset(new MKLDNNStream());
  engine_ = CPUEngine::Instance().getEngine();
  setDeviceID(MKLDNN_DEVICE);
  return true;
}
......@@ -136,10 +142,33 @@ public:
}
protected:
void setDeviceID(int id) {
deviceId_ = id;
output_.deviceId = id;
// TODO: handle mkldnn device or add mkldnn device to other
/**
 * Record the device id this layer runs on.
 */
void setDevice(int id) {
  deviceId_ = id;
}
/**
* Set deviceId of the params used in this layer.
*/
void setParamsDevice(int id, const ParameterMap& parameterMap) {
for (auto& inputConfig : config_.inputs()) {
if (inputConfig.has_input_parameter_name()) {
ParameterPtr parameter;
std::string name = inputConfig.input_parameter_name();
CHECK(mapGet(name, parameterMap, &parameter))
<< "Cannot find input parameter " << name << " for layer "
<< getName();
parameter->setDevice(id);
}
}
if (config_.has_bias_parameter_name()) {
ParameterPtr parameter;
std::string name = config_.bias_parameter_name();
CHECK(mapGet(name, parameterMap, &parameter))
<< "Cannot find bias parameter " << name << " for layer "
<< getName();
parameter->setDevice(id);
}
}
};
......
......@@ -44,6 +44,8 @@ public:
set_data_handle(CpuMatrix::getData());
}
~MKLDNNMatrix() {}
static MKLDNNMatrixPtr create(
const MatrixPtr& m,
mkldnn::memory::dims dims,
......@@ -52,21 +54,42 @@ public:
mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32);
/**
 * Get the primitive descriptor of the underlying mkldnn memory.
 */
mkldnn::memory::primitive_desc getPD() { return this->get_primitive_desc(); }
/**
 * Get the memory descriptor (derived from the primitive descriptor).
 */
mkldnn::memory::desc getMD() { return getPD().desc(); }
/**
* Get format
* Get dims.
*/
int getFormat() { return getMD().data.format; }
// Copy the dimensions out of the memory descriptor into a dims vector.
mkldnn::memory::dims getDims() {
  mkldnn::memory::desc md = getMD();
  const int* dimData = md.data.dims;
  const int rank = md.data.ndims;
  return mkldnn::memory::dims(dimData, dimData + rank);
}
~MKLDNNMatrix() {}
/**
 * Get the memory format of the underlying mkldnn memory.
 */
mkldnn::memory::format getFormat() {
  // Prefer static_cast over the original C-style cast: the conversion is a
  // plain int -> enum conversion and static_cast makes that intent explicit
  // and greppable.
  return static_cast<mkldnn::memory::format>(getMD().data.format);
}
/**
 * Update the data handle of the underlying mkldnn memory in place.
 * Caution: the buffer size of the new data is NOT checked here;
 * ensuring the buffer is large enough is the caller's responsibility.
 */
void updateData(void* data) { set_data_handle(data); }
};
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.