提交 faf827ba 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #4008 from tensor-tang/refine

Refine MKLDNNMatrix and MKLDNNLayer
...@@ -77,24 +77,6 @@ void MKLDNNFcLayer::convertWeightsToPaddle() { ...@@ -77,24 +77,6 @@ void MKLDNNFcLayer::convertWeightsToPaddle() {
wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim); wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim);
} }
/**
 * Copy the output info to the outputs kept for other devices, then let
 * every CPU-device output share the value of output_.
 * Logs a warning when more than one CPU-device output is found.
 */
void MKLDNNFcLayer::convertOutputToOtherDevice() {
  copyOutputInfoToOtherDevice();
  // The fc output value needs no conversion for the CPU device,
  // so a CPU-device entry only shares the value pointer.
  int cpuOutputs = 0;
  for (auto& other : outputOtherDevice_) {
    if (other.deviceId != CPU_DEVICE) {
      continue;
    }
    other.value = output_.value;
    ++cpuOutputs;
  }
  if (cpuOutputs > 1) {
    LOG(WARNING) << "should not have more than one CPU devie";
  }
}
void MKLDNNFcLayer::reshape() { void MKLDNNFcLayer::reshape() {
const Argument& input = getInput(0, getPrev(0)->getDeviceId()); const Argument& input = getInput(0, getPrev(0)->getDeviceId());
int batchSize = input.getBatchSize(); int batchSize = input.getBatchSize();
...@@ -155,7 +137,10 @@ void MKLDNNFcLayer::resetFwd() { ...@@ -155,7 +137,10 @@ void MKLDNNFcLayer::resetFwd() {
// change original output value to mkldnn output value // change original output value to mkldnn output value
output_.value = std::dynamic_pointer_cast<Matrix>(outVal_); output_.value = std::dynamic_pointer_cast<Matrix>(outVal_);
if (!outputIsOnlyMKLDNN()) { if (!outputIsOnlyMKLDNN()) {
convertOutputToOtherDevice(); copyOutputInfoToOtherDevice();
// fc cpu output value do not need create convert
// just share point
getOutput(CPU_DEVICE).value->setData(output_.value->getData());
} }
// create forward handle // create forward handle
...@@ -235,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -235,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdWgt_); pipelineBwd_.push_back(*bwdWgt_);
/// backward data /// backward data
device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE; const MatrixPtr& in = inputLayers_[0]->getOutput().grad;
const MatrixPtr& in = getInputGrad(0, device);
if (in == nullptr) { if (in == nullptr) {
return; return;
} }
if (getInput(0, device).getAllCount() > 1) { if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
// TODO(TJ): use outputMaps_ ways when merge outgrad done // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done
} else { } else {
inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc()); inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc());
} }
...@@ -258,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() { ...@@ -258,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdData_); pipelineBwd_.push_back(*bwdData_);
} }
/**
 * Point inVal_ at the data of the CPU-device input value.
 * Only acts when the first input layer reports the type "data";
 * for any other input layer type this is a no-op.
 */
void MKLDNNFcLayer::updateInputData() {
  const bool fedByDataLayer = inputLayers_[0]->getType() == "data";
  if (fedByDataLayer) {
    inVal_->setData(getInputValue(0, CPU_DEVICE)->getData());
  }
}
void MKLDNNFcLayer::forward(PassType passType) { void MKLDNNFcLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
reshape(); reshape();
{ {
REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str()); REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
syncInputValue(); updateInputData();
// just submit forward pipeline // just submit forward pipeline
stream_->submit(pipelineFwd_); stream_->submit(pipelineFwd_);
...@@ -286,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) { ...@@ -286,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str()); REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
resetBwd(); resetBwd();
syncOutputGrad();
// just sumbmit backward pipeline // just sumbmit backward pipeline
stream_->submit(pipelineBwd_); stream_->submit(pipelineBwd_);
} }
......
...@@ -53,6 +53,8 @@ public: ...@@ -53,6 +53,8 @@ public:
void backward(const UpdateCallback& callback) override; void backward(const UpdateCallback& callback) override;
void updateInputData() override;
protected: protected:
/** /**
* reshape the input image sizes * reshape the input image sizes
...@@ -72,8 +74,6 @@ protected: ...@@ -72,8 +74,6 @@ protected:
* only would be called when needed * only would be called when needed
*/ */
void resetBwd(); void resetBwd();
void convertOutputToOtherDevice() override;
}; };
} // namespace paddle } // namespace paddle
...@@ -114,10 +114,10 @@ public: ...@@ -114,10 +114,10 @@ public:
virtual void convertWeightsToPaddle() {} virtual void convertWeightsToPaddle() {}
/** /**
* convert MKLDNN output to other device. * Update input value data when input layer is "data" type.
* only support CPU device yet * Since the input value data address might be changed.
*/ */
virtual void convertOutputToOtherDevice() {} virtual void updateInputData() {}
/** /**
* print info about sizes * print info about sizes
...@@ -155,6 +155,7 @@ protected: ...@@ -155,6 +155,7 @@ protected:
* copy base info and do not copy data value * copy base info and do not copy data value
*/ */
void copyOutputInfoToOtherDevice() { void copyOutputInfoToOtherDevice() {
int cnt = 0;
for (size_t i = 0; i < outputOtherDevice_.size(); i++) { for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
outputOtherDevice_[i].setFrameHeight(output_.getFrameHeight()); outputOtherDevice_[i].setFrameHeight(output_.getFrameHeight());
outputOtherDevice_[i].setFrameWidth(output_.getFrameWidth()); outputOtherDevice_[i].setFrameWidth(output_.getFrameWidth());
...@@ -163,6 +164,12 @@ protected: ...@@ -163,6 +164,12 @@ protected:
outputOtherDevice_[i].subSequenceStartPositions = outputOtherDevice_[i].subSequenceStartPositions =
output_.subSequenceStartPositions; output_.subSequenceStartPositions;
outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims; outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims;
if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
++cnt;
}
}
if (cnt > 1) {
LOG(WARNING) << "should not have more than one CPU devie";
} }
} }
...@@ -193,32 +200,6 @@ protected: ...@@ -193,32 +200,6 @@ protected:
return outputOtherDevice_.size() == 0; return outputOtherDevice_.size() == 0;
} }
/**
 * Sync input value data.
 * Does nothing when inputIsOnlyMKLDNN() is true; otherwise passes the
 * data pointer of the CPU-device input value to inVal_->updateData().
 */
void syncInputValue() {
  if (!inputIsOnlyMKLDNN()) {
    // refresh the input data pointer,
    // since it might have changed if this layer follows a data layer
    inVal_->updateData(getInputValue(0, CPU_DEVICE)->getData());
  }
}
/**
 * Sync output grad data.
 * Does nothing when outputIsOnlyMKLDNN() is true; otherwise passes the
 * data pointer of the CPU-device output grad to outGrad_->updateData().
 */
void syncOutputGrad() {
  if (!outputIsOnlyMKLDNN()) {
    // refresh the diff pointer from the CPU-device output
    outGrad_->updateData(getOutput(CPU_DEVICE).grad->getData());
  }
}
/** /**
* Set deviceId of this layer. * Set deviceId of this layer.
*/ */
......
...@@ -33,14 +33,12 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) { ...@@ -33,14 +33,12 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) {
size_t width = cnts / dims[0]; size_t width = cnts / dims[0];
m = Matrix::create(height, width, false, false); m = Matrix::create(height, width, false, false);
} }
CHECK(m) << " Matrix should not be empty"; CHECK(m) << " Matrix should not be empty";
CpuMatrixPtr cpuMatrix = std::dynamic_pointer_cast<CpuMatrix>(m); CpuMatrixPtr cpuMatrix = std::dynamic_pointer_cast<CpuMatrix>(m);
CHECK(cpuMatrix) << "Only support create from CPU matrix yet"; CHECK(cpuMatrix) << "Only support create from CPU matrix yet";
CHECK_EQ(cpuMatrix->getElementCnt(), cnts) << "Count size does not match";
CHECK_EQ(cnts, m->getElementCnt()) << "Count size does not match"; return std::make_shared<MKLDNNMatrix>(cpuMatrix, pd);
return std::make_shared<MKLDNNMatrix>(
m->getData(), m->getHeight(), m->getWidth(), pd);
} }
MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
...@@ -138,7 +136,7 @@ void MKLDNNMatrix::downSpatial() { ...@@ -138,7 +136,7 @@ void MKLDNNMatrix::downSpatial() {
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr), mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
"could not create a memory primitive"); "could not create a memory primitive");
reset(result); reset(result);
set_data_handle(getData()); set_data_handle(data_);
} }
} // namespace paddle } // namespace paddle
...@@ -30,11 +30,10 @@ typedef std::shared_ptr<MKLDNNMatrix> MKLDNNMatrixPtr; ...@@ -30,11 +30,10 @@ typedef std::shared_ptr<MKLDNNMatrix> MKLDNNMatrixPtr;
*/ */
class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory { class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
public: public:
MKLDNNMatrix(real* data, MKLDNNMatrix(CpuMatrixPtr m, mkldnn::memory::primitive_desc pd)
size_t height, : CpuMatrix(m->getData(), m->getHeight(), m->getWidth(), false),
size_t width, mkldnn::memory(pd, m->getData()),
mkldnn::memory::primitive_desc pd) m_(m) {}
: CpuMatrix(data, height, width, false), mkldnn::memory(pd, data) {}
~MKLDNNMatrix() {} ~MKLDNNMatrix() {}
...@@ -81,11 +80,29 @@ public: ...@@ -81,11 +80,29 @@ public:
void downSpatial(); void downSpatial();
/** /**
* Update the memory data handle. * set the memory data handle.
* Caution: This will not check the buffer size of the data, * Caution: This will not check the buffer size of the data,
* it should be coverd by user. * it should be coverd by user.
*/ */
void updateData(void* data) { set_data_handle(data); } void setData(real* data) {
set_data_handle(data);
CpuMatrix::setData(data);
m_.reset();
}
/**
 * Override Matrix::getData.
 * Checks that data_ and the mkldnn memory data handle hold the same
 * address before returning data_.
 */
real* getData() override {
  void* handle = get_data_handle();
  CHECK_EQ(static_cast<void*>(data_), handle);
  return data_;
}
/**
 * Const overload of getData.
 * Checks that data_ and the mkldnn memory data handle hold the same
 * address before returning data_.
 */
const real* getData() const override {
  void* handle = get_data_handle();
  CHECK_EQ(static_cast<void*>(data_), handle);
  return data_;
}
/** /**
* Get primitive descriptor. * Get primitive descriptor.
...@@ -143,6 +160,10 @@ protected: ...@@ -143,6 +160,10 @@ protected:
memory::format srcFmt, memory::format srcFmt,
memory::format dstFmt, memory::format dstFmt,
memory::dims dm); memory::dims dm);
private:
// save the CpuMatrixPtr in case the buffer released outside
CpuMatrixPtr m_;
}; };
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册