Commit bfbd066f authored by tensor-tang

refine

Parent fe51f726
MKLDNNFcLayer.cpp

@@ -77,6 +77,24 @@ void MKLDNNFcLayer::convertWeightsToPaddle() {
   wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim);
 }
 
+void MKLDNNFcLayer::convertOutputToOtherDevice() {
+  copyOutputInfoToOtherDevice();
+  // find the other CPU device and reorder output to the CPU device
+  int cnt = 0;
+  for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
+    if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
+      // the fc CPU output value does not need converting,
+      // just share the pointer
+      outputOtherDevice_[i].value = output_.value;
+      ++cnt;
+    }
+  }
+
+  if (cnt > 1) {
+    LOG(WARNING) << "should not have more than one CPU device";
+  }
+}
+
 void MKLDNNFcLayer::reshape() {
   const Argument& input = getInput(0, getPrev(0)->getDeviceId());
   int batchSize = input.getBatchSize();
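To make the pointer-sharing above concrete, here is a minimal, self-contained sketch; the `Matrix`/`Argument` stand-ins and the device-id values are simplified assumptions for illustration, not Paddle's real definitions:

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Simplified stand-ins for paddle::Matrix and paddle::Argument.
struct Matrix {};
using MatrixPtr = std::shared_ptr<Matrix>;

struct Argument {
  int deviceId;
  MatrixPtr value;
};

constexpr int CPU_DEVICE = -1;     // assumption: sentinel ids only
constexpr int MKLDNN_DEVICE = -2;  // assumption: sentinel ids only

int main() {
  Argument output{MKLDNN_DEVICE, std::make_shared<Matrix>()};
  std::vector<Argument> others{{CPU_DEVICE, nullptr}};

  // Mirror of the loop in convertOutputToOtherDevice(): since an FC
  // output is always in plain "nc" layout, the CPU consumer can alias
  // the MKLDNN output buffer instead of copying it.
  for (auto& arg : others) {
    if (arg.deviceId == CPU_DEVICE) {
      arg.value = output.value;
    }
  }
  assert(others[0].value.get() == output.value.get());  // shared, no copy
  return 0;
}
```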
@@ -116,7 +134,7 @@ void MKLDNNFcLayer::resetFwd() {
   const MatrixPtr& bias = hasBias ? biases_->getW() : nullptr;
   const MatrixPtr& out = output_.value;
 
-  if (prevIsMKLDNN()) {
+  if (prevIsOnlyMKLDNN()) {
     const MatrixPtr& in = getInputValue(0);
     inVal_ = std::dynamic_pointer_cast<MKLDNNMatrix>(in);
     CHECK(inVal_) << "Input should be MKLDNNMatrix";
@@ -136,30 +154,21 @@ void MKLDNNFcLayer::resetFwd() {
   // change original output value to mkldnn output value
   output_.value = std::dynamic_pointer_cast<Matrix>(outVal_);
-  if (!nextIsMKLDNN()) {
-    Argument cpuOutput;
-    for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
-      if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
-        cpuOutput = outputOtherDevice_[i];
-      }
-    }
-    cpuOutput.setFrameHeight(output_.getFrameHeight());
-    cpuOutput.setFrameWidth(output_.getFrameWidth());
-    // fc cpu output value do not need convert
-    cpuOutput.value = output_.value;
+  if (!nextIsOnlyMKLDNN()) {
+    convertOutputToOtherDevice();
   }
 
   // create forward handle
   prop_kind pk = prop_kind::forward;
-  fc_fwd::desc fwdDesc =
-      hasBias ? fc_fwd::desc(pk,
-                             inVal_->getMD(),
-                             wgtVal_->getMD(),
-                             biasVal_->getMD(),
-                             outVal_->getMD())
-              : fc_fwd::desc(
-                    pk, inVal_->getMD(), wgtVal_->getMD(), outVal_->getMD());
+  fc_fwd::desc fwdDesc = hasBias ? fc_fwd::desc(pk,
+                                                inVal_->getMemoryDesc(),
+                                                wgtVal_->getMemoryDesc(),
+                                                biasVal_->getMemoryDesc(),
+                                                outVal_->getMemoryDesc())
+                                 : fc_fwd::desc(pk,
+                                                inVal_->getMemoryDesc(),
+                                                wgtVal_->getMemoryDesc(),
+                                                outVal_->getMemoryDesc());
   fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
   if (hasBias) {
     fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_));
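For reference, the `fc_fwd::desc`/`primitive_desc` pair assembled above maps onto the MKL-DNN 0.x inner-product API. A hedged, stand-alone sketch, assuming `fc_fwd` aliases `mkldnn::inner_product_forward` (the toy sizes and the bias-free variant are arbitrary choices for illustration):

```cpp
#include <mkldnn.hpp>
using namespace mkldnn;

int main() {
  engine eng(engine::kind::cpu, 0);

  // Memory descriptors: batch 2, 8 inputs, 4 outputs (toy sizes).
  memory::desc srcMD({2, 8}, memory::data_type::f32, memory::format::nc);
  memory::desc wgtMD({4, 8}, memory::data_type::f32, memory::format::oi);
  memory::desc dstMD({2, 4}, memory::data_type::f32, memory::format::nc);

  // Same shape as the fc_fwd::desc call in resetFwd(), minus the bias.
  inner_product_forward::desc fwdDesc(prop_kind::forward, srcMD, wgtMD, dstMD);
  inner_product_forward::primitive_desc fwdPD(fwdDesc, eng);

  // Memories bound to the layouts the primitive descriptor chose.
  memory src(fwdPD.src_primitive_desc());
  memory wgt(fwdPD.weights_primitive_desc());
  memory dst(fwdPD.dst_primitive_desc());

  inner_product_forward fwd(fwdPD, src, wgt, dst);
  stream(stream::kind::eager).submit({fwd}).wait();
  return 0;
}
```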
@@ -184,36 +193,38 @@ void MKLDNNFcLayer::resetBwd() {
   const MatrixPtr& wgt = weight_->getWGrad();
   const MatrixPtr& bias = hasBias ? biases_->getWGrad() : nullptr;
 
-  // TODO(TJ): merge topdiffs
-  if (nextIsMKLDNN()) {
+  // TODO(TJ): merge outgrad
+  if (nextIsOnlyMKLDNN()) {
     // can not directly cast outputgrad to mkldnnmatrix,
     // since each layer can not write the inputgrad to mkldnn inputgrad.
     // So just create from matrix with outputvalue format.
     const MatrixPtr& out = getOutput(MKLDNN_DEVICE).grad;
-    outGrad_ = MKLDNNMatrix::create(out, outVal_->getPD());
+    outGrad_ = MKLDNNMatrix::create(out, outVal_->getPrimitiveDesc());
   } else {
     const MatrixPtr& out = getOutput(CPU_DEVICE).grad;
     // fc do not need to convert from cpu device since output always nc
     // only need create from cpu device
-    outGrad_ = MKLDNNMatrix::create(out, outVal_->getPD());
+    outGrad_ = MKLDNNMatrix::create(out, outVal_->getPrimitiveDesc());
   }
 
-  wgtGrad_ = MKLDNNMatrix::create(wgt, wgtVal_->getPD());
-  biasGrad_ = hasBias ? MKLDNNMatrix::create(bias, biasVal_->getPD()) : nullptr;
+  wgtGrad_ = MKLDNNMatrix::create(wgt, wgtVal_->getPrimitiveDesc());
+  biasGrad_ = hasBias ? MKLDNNMatrix::create(bias, biasVal_->getPrimitiveDesc())
+                      : nullptr;
 
   // create memory primitive desc
   fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward,
-                                      inVal_->getMD(),
-                                      wgtGrad_->getMD(),
-                                      outGrad_->getMD());
+                                      inVal_->getMemoryDesc(),
+                                      wgtGrad_->getMemoryDesc(),
+                                      outGrad_->getMemoryDesc());
   fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
-  fc_bwdWgt::desc bwdWgtDesc =
-      hasBias ? fc_bwdWgt::desc(inVal_->getMD(),
-                                wgtGrad_->getMD(),
-                                biasGrad_->getMD(),
-                                outGrad_->getMD())
-              : fc_bwdWgt::desc(
-                    inVal_->getMD(), wgtGrad_->getMD(), outGrad_->getMD());
+  fc_bwdWgt::desc bwdWgtDesc = hasBias
+                                   ? fc_bwdWgt::desc(inVal_->getMemoryDesc(),
+                                                     wgtGrad_->getMemoryDesc(),
+                                                     biasGrad_->getMemoryDesc(),
+                                                     outGrad_->getMemoryDesc())
+                                   : fc_bwdWgt::desc(inVal_->getMemoryDesc(),
+                                                     wgtGrad_->getMemoryDesc(),
+                                                     outGrad_->getMemoryDesc());
   fc_bwdWgt::primitive_desc bwdWgtPD =
       fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
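One non-obvious point above: resetBwd() rebuilds `fwdDesc`/`fwdPD` only because MKL-DNN's backward primitive descriptors take a forward primitive_desc as a mandatory construction hint. A minimal sketch of that dependency, under the same aliasing assumption (`fc_bwdWgt` ≈ `inner_product_backward_weights`, `fc_bwdData` ≈ `inner_product_backward_data`, toy sizes):

```cpp
#include <mkldnn.hpp>
using namespace mkldnn;

int main() {
  engine eng(engine::kind::cpu, 0);
  memory::desc srcMD({2, 8}, memory::data_type::f32, memory::format::nc);
  memory::desc wgtMD({4, 8}, memory::data_type::f32, memory::format::oi);
  memory::desc dstMD({2, 4}, memory::data_type::f32, memory::format::nc);

  // The forward primitive_desc is rebuilt here because every backward
  // primitive_desc below requires it as a hint.
  inner_product_forward::desc fwdDesc(prop_kind::forward, srcMD, wgtMD, dstMD);
  inner_product_forward::primitive_desc fwdPD(fwdDesc, eng);

  // Gradient of weights: analogous to fc_bwdWgt in the layer code.
  inner_product_backward_weights::desc bwdWgtDesc(srcMD, wgtMD, dstMD);
  inner_product_backward_weights::primitive_desc bwdWgtPD(
      bwdWgtDesc, eng, fwdPD);

  // Gradient of input: analogous to fc_bwdData in the layer code.
  inner_product_backward_data::desc bwdDataDesc(srcMD, wgtMD, dstMD);
  inner_product_backward_data::primitive_desc bwdDataPD(
      bwdDataDesc, eng, fwdPD);
  return 0;
}
```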
@@ -227,30 +238,20 @@ void MKLDNNFcLayer::resetBwd() {
   pipelineBwd_.push_back(*bwdWgt_);
 
   /// backward data
-  if (prevIsMKLDNN()) {
-    const MatrixPtr& in = getInputGrad(0, MKLDNN_DEVICE);
-    if (in == nullptr) {
-      return;
-    }
-    if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
-      // TODO(TJ): use outputMaps_ ways when merge topdiff done
-    } else {
-      inGrad_ = MKLDNNMatrix::create(in, inVal_->getPD());
-    }
-  } else {
-    const MatrixPtr& in = getInputGrad(0, CPU_DEVICE);
-    if (in == nullptr) {
-      return;
-    }
-    if (getInput(0, CPU_DEVICE).getAllCount() > 1) {
-      // TODO(TJ): use outputMaps_ ways when merge topdiff done
-    } else {
-      inGrad_ = MKLDNNMatrix::create(in, inVal_->getPD());
-    }
+  int device = prevIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
+  const MatrixPtr& in = getInputGrad(0, device);
+  if (in == nullptr) {
+    return;
+  }
+  if (getInput(0, device).getAllCount() > 1) {
+    // TODO(TJ): use outputMaps_ ways when merge outgrad done
+  } else {
+    inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc());
   }
 
-  fc_bwdData::desc bwdDataDesc =
-      fc_bwdData::desc(inVal_->getMD(), wgtGrad_->getMD(), outGrad_->getMD());
+  fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(inVal_->getMemoryDesc(),
+                                                  wgtGrad_->getMemoryDesc(),
+                                                  outGrad_->getMemoryDesc());
   fc_bwdData::primitive_desc bwdDataPD =
       fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
MKLDNNFcLayer.h

@@ -72,6 +72,8 @@ protected:
    * only would be called when needed
    */
   void resetBwd();
+
+  void convertOutputToOtherDevice() override;
 };
 
 }  // namespace paddle
MKLDNNLayer.h

@@ -86,10 +86,7 @@ public:
     CHECK(FLAGS_use_mkldnn) << "MkldnnLayers only support use_mkldnn."
                             << "Please set WITH_MKLDNN=ON "
                             << "and set use_mkldnn=True";
-    if (useGpu_ == true) {
-      LOG(WARNING) << "Do not support GPU yet, will change to useGpu = false";
-      useGpu_ = false;
-    }
+    CHECK(!useGpu_) << "Do not support GPU yet";
 
     // set device id before Layer::init
     setDevice(MKLDNN_DEVICE);
@@ -116,6 +113,12 @@ public:
    */
   virtual void convertWeightsToPaddle() {}
 
+  /**
+   * convert MKLDNN output to other device.
+   * only support CPU device yet
+   */
+  virtual void convertOutputToOtherDevice() {}
+
   /**
    * print info about sizes
    */
@@ -147,22 +150,25 @@ protected:
 protected:
   /**
-   * If next layer only has MKLDNN type.
-   * Otherwise, only support otherdevice CPU device.
+   * copy image size and sequence info to other device
    */
-  bool nextIsMKLDNN() {
+  void copyOutputInfoToOtherDevice() {
     for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
-      CHECK_EQ(outputOtherDevice_[i].deviceId, CPU_DEVICE)
-          << "Only support other device is CPU yet";
+      outputOtherDevice_[i].setFrameHeight(output_.getFrameHeight());
+      outputOtherDevice_[i].setFrameWidth(output_.getFrameWidth());
+      outputOtherDevice_[i].sequenceStartPositions =
+          output_.sequenceStartPositions;
+      outputOtherDevice_[i].subSequenceStartPositions =
+          output_.subSequenceStartPositions;
+      outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims;
     }
-    return outputOtherDevice_.size() == 0;
   }
 
   /**
-   * Is previous layer MKLDNN type.
-   * Otherwise, only support otherdevice CPU device.
+   * If the previous layer only has MKLDNN type.
+   * Otherwise, only support the previous layer using CPU device.
    */
-  bool prevIsMKLDNN(int index = 0) {
+  bool prevIsOnlyMKLDNN(int index = 0) {
     int prevDevice = getPrev(index)->getDeviceId();
     if (prevDevice == MKLDNN_DEVICE) {
       return true;
@@ -173,11 +179,23 @@ protected:
     }
   }
 
+  /**
+   * If the output only has MKLDNN device.
+   * Otherwise, the other devices should only use CPU device.
+   */
+  bool nextIsOnlyMKLDNN() {
+    for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
+      CHECK_EQ(outputOtherDevice_[i].deviceId, CPU_DEVICE)
+          << "Only support other device is CPU yet";
+    }
+    return outputOtherDevice_.size() == 0;
+  }
+
   /**
    * Sync input value data
    */
   void syncInputValue() {
-    if (prevIsMKLDNN()) {
+    if (prevIsOnlyMKLDNN()) {
       return;
     }
     real* iData = getInputValue(0, CPU_DEVICE)->getData();
@@ -190,7 +208,7 @@ protected:
    * Sync output grad data
    */
   void syncOutputGrad() {
-    if (nextIsMKLDNN()) {
+    if (nextIsOnlyMKLDNN()) {
       return;
     }
MKLDNNMatrix.cpp

@@ -31,7 +31,6 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) {
   if (m == nullptr) {
     size_t height = dims[0];
     size_t width = cnts / dims[0];
-    // LOG(INFO) << height << "," << width;
     m = Matrix::create(height, width, false, false);
   }
@@ -40,10 +39,8 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) {
   CHECK(cpuMatrix) << "Only support create from CPU matrix yet";
   CHECK_EQ(cnts, m->getElementCnt()) << "Count size does not match";
 
-  size_t width = m->getWidth();
-  size_t height = m->getHeight();
-  real* data = m->getData();
-  return std::make_shared<MKLDNNMatrix>(data, height, width, pd);
+  return std::make_shared<MKLDNNMatrix>(
+      m->getData(), m->getHeight(), m->getWidth(), pd);
 }
 
 MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
@@ -51,9 +48,7 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
                                      memory::format fmt,
                                      engine& eg,
                                      mkldnn::memory::data_type dtype) {
-  memory::desc md = memory::desc(dims, dtype, fmt);
-  memory::primitive_desc pd = memory::primitive_desc(md, eg);
-  return create(m, pd);
+  return create(m, memory::primitive_desc(memory::desc(dims, dtype, fmt), eg));
}
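After this cleanup the dims/format overload simply forwards to the primitive_desc one. A hedged usage fragment, using only the two signatures visible above and assuming it sits in a translation unit that includes MKLDNNMatrix.h (the dims and format are arbitrary examples):

```cpp
using namespace mkldnn;

engine eng(engine::kind::cpu, 0);
memory::dims dims = {64, 32};  // e.g. a batch of 64 rows, 32 values each

// Convenience overload: builds memory::desc + primitive_desc internally.
// Passing m == nullptr makes create() allocate a CPU matrix of
// height = dims[0], width = elementCnt / dims[0].
MKLDNNMatrixPtr a = MKLDNNMatrix::create(
    nullptr, dims, memory::format::nc, eng, memory::data_type::f32);

// Equivalent explicit form through the primitive_desc overload.
memory::primitive_desc pd(
    memory::desc(dims, memory::data_type::f32, memory::format::nc), eng);
MKLDNNMatrixPtr b = MKLDNNMatrix::create(nullptr, pd);
```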
@@ -64,9 +59,7 @@ void MKLDNNMatrix::reorderDataFrom(const MKLDNNMatrixPtr& m,
     return;
   }
   CHECK_EQ(getElementCnt(), m->getElementCnt()) << "size should equal";
-  real* srcData = getData();
-  real* dstData = m->getData();
-  reorderOnce(srcData, dstData, srcFmt, dstFmt, targetDim);
+  reorderOnce(getData(), m->getData(), srcFmt, dstFmt, targetDim);
 }
 
@@ -77,9 +70,7 @@ void MKLDNNMatrix::reorderDataTo(const MKLDNNMatrixPtr& m,
     return;
   }
   CHECK_EQ(getElementCnt(), m->getElementCnt()) << "size should equal";
-  real* srcData = getData();
-  real* dstData = m->getData();
-  reorderOnce(srcData, dstData, srcFmt, dstFmt, targetDim);
+  reorderOnce(getData(), m->getData(), srcFmt, dstFmt, targetDim);
 }
@@ -120,8 +111,9 @@ void MKLDNNMatrix::downSpatial() {
     return;
   }
 
-  memory::dims srcDims = getDims();
+  // TODO(TJ): change H(height) and W(width) if support nhwc or more
   const int H = 2, W = 3;
+  memory::dims srcDims = getDims();
   if (srcDims[H] != 1 || srcDims[W] != 1) {
     // can not down spatial
     return;
@@ -141,13 +133,12 @@ void MKLDNNMatrix::downSpatial() {
   }
 
   memory::desc md = memory::desc(dstDims, getDtype(), dstFmt);
   memory::primitive_desc pd = memory::primitive_desc(md, getEngine());
-  void* data = getData();
   mkldnn_primitive_t result;
   mkldnn::error::wrap_c_api(
       mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
       "could not create a memory primitive");
   reset(result);
-  set_data_handle(data);
+  set_data_handle(getData());
 }
 
 }  // namespace paddle
MKLDNNMatrix.h

@@ -56,9 +56,9 @@ public:
 public:
   /**
    * Reorder this MKLDNNMatrix from other format.
-   * Support inplace reorder
-   * Pay attention: this function would only reorder the data layout.
+   * Support inplace reorder.
+   * @note: this function would only reorder the data layout.
    * will NOT change this original dim or format info
    */
   void reorderDataFrom(const MKLDNNMatrixPtr& m,
                        memory::format srcFmt,
@@ -66,9 +66,9 @@ public:
   /**
    * Reorder this MKLDNNMatrix to other format.
-   * Support inplace reorder
-   * Pay attention: this function would only reorder the data layout.
+   * Support inplace reorder.
+   * @note: this function would only reorder the data layout.
    * will NOT change the dst dim or format info
    */
   void reorderDataTo(const MKLDNNMatrixPtr& m,
                      memory::format dstFmt,
@@ -90,18 +90,20 @@ public:
   /**
    * Get primitive descriptor.
    */
-  mkldnn::memory::primitive_desc getPD() { return this->get_primitive_desc(); }
+  mkldnn::memory::primitive_desc getPrimitiveDesc() {
+    return this->get_primitive_desc();
+  }
 
   /**
    * Get memory descriptor.
    */
-  mkldnn::memory::desc getMD() { return getPD().desc(); }
+  mkldnn::memory::desc getMemoryDesc() { return getPrimitiveDesc().desc(); }
 
   /**
    * Get dimensions.
    */
   mkldnn::memory::dims getDims() {
-    mkldnn::memory::desc md = getMD();
+    mkldnn::memory::desc md = getMemoryDesc();
     const int* src = md.data.dims;
     int ndims = md.data.ndims;
     mkldnn::memory::dims dst;
@@ -116,24 +118,25 @@ public:
   /**
    * Get format.
    */
   mkldnn::memory::format getFormat() {
-    return (mkldnn::memory::format)(getMD().data.format);
+    return (mkldnn::memory::format)(getMemoryDesc().data.format);
   }
 
   /**
    * Get memory data type.
    */
   mkldnn::memory::data_type getDtype() {
-    return (mkldnn::memory::data_type)(getMD().data.data_type);
+    return (mkldnn::memory::data_type)(getMemoryDesc().data.data_type);
   }
 
   /**
    * Get engine.
    */
-  mkldnn::engine getEngine() { return getPD().get_engine(); }
+  mkldnn::engine getEngine() { return getPrimitiveDesc().get_engine(); }
 
 protected:
   /**
-   * Do once reorder supported inplace.
+   * Do reorder once.
+   * Can support inplace.
    */
   void reorderOnce(void* srcData,
                    void* dstData,
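The renamed getters (getPD → getPrimitiveDesc, getMD → getMemoryDesc) read more naturally at call sites. A hypothetical debug helper, not part of this commit, built only on the accessors declared above and reusing Paddle's LOG macro:

```cpp
#include <sstream>

// Hypothetical helper: dump an MKLDNNMatrix's layout via the renamed getters.
// Relies only on getDims()/getFormat()/getDtype() from MKLDNNMatrix.h;
// the plain MKL-DNN 0.x enums print as their integer values.
void printLayout(const MKLDNNMatrixPtr& m) {
  mkldnn::memory::dims dims = m->getDims();
  std::ostringstream oss;
  for (size_t i = 0; i < dims.size(); ++i) {
    oss << (i ? "x" : "") << dims[i];
  }
  LOG(INFO) << "dims: " << oss.str() << ", format: " << m->getFormat()
            << ", dtype: " << m->getDtype();
}
```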