Commit 9e38dafa authored by tensor-tang

change MKLDNNMatrix create interface since MatrixPtr is not always required

Parent c1914543
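The heart of this commit is an argument-order swap: the optional `MatrixPtr` moves behind the required memory descriptor and gains a `nullptr` default, so allocation-only call sites no longer need to pass an explicit `nullptr`. Below is a minimal, self-contained sketch of that refactoring pattern using hypothetical stand-in types (`Matrix`, `PrimitiveDesc`, and a free `create` function), not Paddle's actual classes:

```cpp
#include <iostream>
#include <memory>

struct Matrix {};                        // stand-in for Paddle's Matrix
using MatrixPtr = std::shared_ptr<Matrix>;

struct PrimitiveDesc {};                 // stand-in for mkldnn's primitive_desc

// Old shape was create(MatrixPtr m, PrimitiveDesc pd), which forced callers
// that only wanted a fresh buffer to write create(nullptr, pd).
// New shape: the optional argument trails and defaults to nullptr.
MatrixPtr create(PrimitiveDesc /*pd*/, MatrixPtr m = nullptr) {
  // Wrap the existing buffer if one was given, otherwise allocate a new one.
  return m ? m : std::make_shared<Matrix>();
}

int main() {
  PrimitiveDesc pd;
  MatrixPtr external = std::make_shared<Matrix>();
  MatrixPtr wrapped = create(pd, external);  // wrap an existing buffer
  MatrixPtr fresh = create(pd);              // allocation-only call site
  std::cout << (wrapped == external) << " " << (fresh != nullptr) << "\n";
  return 0;
}
```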
@@ -126,7 +126,7 @@ void MKLDNNEltwiseActivation::resetFwd(Argument& act) {
   copyInVal_ = nullptr;
   if (act.grad && algo == algorithm::eltwise_tanh) {
     // tanh need save src input for backward
-    inVal_ = MKLDNNMatrix::create(nullptr, val_->getPrimitiveDesc());
+    inVal_ = MKLDNNMatrix::create(val_->getPrimitiveDesc());
     copyInVal_ = std::make_shared<mkldnn::reorder>(*val_, *inVal_);
     CHECK(copyInVal_) << "should not be emptry";
     pipelineFwd_.push_back(*copyInVal_);
@@ -145,7 +145,7 @@ void MKLDNNEltwiseActivation::resetBwd(Argument& act) {
   algorithm algo = getAlgo(this->getName());
   float alpha = getBwdAlpha();
   float beta = getBeta();
-  grad_ = MKLDNNMatrix::create(act.grad, val_->getPrimitiveDesc());
+  grad_ = MKLDNNMatrix::create(val_->getPrimitiveDesc(), act.grad);
   auto eng = CPUEngine::Instance().getEngine();
   auto bwdDesc = eltwise_bwd::desc(
       algo, grad_->getMemoryDesc(), val_->getMemoryDesc(), alpha, beta);
@@ -230,7 +230,7 @@ void MKLDNNActivation::resetFwd(Argument& act) {
   int ic = cnt_ / bs / ih / iw;
   CHECK_EQ(cnt_, (size_t)bs * ic * ih * iw);
   val_ = MKLDNNMatrix::create(
-      act.value, {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_);
+      {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_, act.value);
   CHECK(val_);
   val_->downSpatial();
 }
...
@@ -370,8 +370,7 @@ void MKLDNNConvLayer::resetWgtValBwdData(
   // since the primitive_desc would be different with wgtVal_
   CHECK(wgtVal_) << "should have weight value";
   if (dataPD->weights_primitive_desc() != wgtVal_->getPrimitiveDesc()) {
-    wgtValBwdData_ =
-        MKLDNNMatrix::create(nullptr, dataPD->weights_primitive_desc());
+    wgtValBwdData_ = MKLDNNMatrix::create(dataPD->weights_primitive_desc());
     cvtWgtVal_ = MKLDNNMatrix::createReorder(wgtVal_, wgtValBwdData_);
     CHECK(cvtWgtVal_);
   } else {
...
@@ -323,7 +323,7 @@ protected:
     if (mat == nullptr) {
       return;
     }
-    dnn = MKLDNNMatrix::create(mat, pd);
+    dnn = MKLDNNMatrix::create(pd, mat);
   }

   /**
@@ -343,7 +343,7 @@ protected:
     in = std::dynamic_pointer_cast<MKLDNNMatrix>(inMat);
     CHECK_EQ(inputIsOnlyMKLDNN(), in != nullptr);
     if (in == nullptr || in->getFormat() == mkldnn::memory::format::nc) {
-      in = MKLDNNMatrix::create(inMat, extPD);
+      in = MKLDNNMatrix::create(extPD, inMat);
     }
     extInVal_ = isPaddleFormat(in->getFormat()) ? in : nullptr;
     if (in->getFormat() == mkldnn::memory::format::nc) {
@@ -353,8 +353,8 @@ protected:
       return;
     }
     // need create reorder
-    in = MKLDNNMatrix::create(nullptr, *intPD);
-    extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(inMat, extPD);
+    in = MKLDNNMatrix::create(*intPD);
+    extInVal_ = extInVal_ ? extInVal_ : MKLDNNMatrix::create(extPD, inMat);
     cvtInVal_ = MKLDNNMatrix::createReorder(extInVal_, in);
     CHECK(cvtInVal_) << "should not be emptry";
   }
@@ -366,18 +366,18 @@ protected:
   void resetOutValue(MKLDNNMatrixPtr& out,
                      mkldnn::memory::primitive_desc intPD) {
     cvtOutVal_ = nullptr;
-    out = MKLDNNMatrix::create(output_.value, intPD);
+    out = MKLDNNMatrix::create(intPD, output_.value);
     extOutVal_ = out;
     if (outputIsOnlyMKLDNN() || isPaddleFormat(extOutVal_->getFormat())) {
       return;
     }
     // need create reorder
     CHECK_GT(bs_ * oc_ * oh_ * ow_, 0);
-    extOutVal_ = MKLDNNMatrix::create(output_.value,
-                                      {bs_, oc_, oh_, ow_},
-                                      mkldnn::memory::format::nchw,
-                                      engine_);
-    out = MKLDNNMatrix::create(nullptr, intPD);
+    extOutVal_ = MKLDNNMatrix::create(mkldnn::memory::dims{bs_, oc_, oh_, ow_},
+                                      mkldnn::memory::format::nchw,
+                                      engine_,
+                                      output_.value);
+    out = MKLDNNMatrix::create(intPD);
     cvtOutVal_ = MKLDNNMatrix::createReorder(out, extOutVal_);
     CHECK(cvtOutVal_) << "should not be empty";
   }
@@ -402,7 +402,7 @@ protected:
     // and the mkldnn input layer will merge them to actual prev->output_.grad
     const MatrixPtr& inMat =
         input->getOutputMapSize() <= 1 ? input->getOutputGrad() : nullptr;
-    in = MKLDNNMatrix::create(inMat, intPD);
+    in = MKLDNNMatrix::create(intPD, inMat);
     Argument& arg = input->getOutput(this->getName());
     arg.grad = std::dynamic_pointer_cast<Matrix>(in);
     CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD)
@@ -418,10 +418,10 @@ protected:
     // need create reorder
     CHECK(extInVal_ != nullptr && isPaddleFormat(extInVal_->getFormat()))
         << "should have external input value and the format must be nchw(nc)";
-    extInGrad_ = MKLDNNMatrix::create(inMat, extInVal_->getPrimitiveDesc());
+    extInGrad_ = MKLDNNMatrix::create(extInVal_->getPrimitiveDesc(), inMat);
     CHECK(inVal_ != nullptr && inVal_->getPrimitiveDesc() == intPD)
         << "should have internal input value and primitive desc must equal";
-    in = MKLDNNMatrix::create(nullptr, intPD);
+    in = MKLDNNMatrix::create(intPD);
     cvtInGrad_ = MKLDNNMatrix::createReorder(in, extInGrad_);
     CHECK(cvtInGrad_);
   }
@@ -440,7 +440,7 @@ protected:
     extOutGrad_ = nullptr;
     out = nullptr;
     MatrixPtr& outMat = output_.grad;
-    out = MKLDNNMatrix::create(outMat, intPD);
+    out = MKLDNNMatrix::create(intPD, outMat);
     resetMergeGrad(out);
     if (outputIsOnlyMKLDNN()) {
       return;
@@ -453,10 +453,10 @@ protected:
     // need create reorder
     CHECK(extOutVal_ != nullptr && isPaddleFormat(extOutVal_->getFormat()))
         << "should have external output value and the format must be nchw(nc)";
-    extOutGrad_ = MKLDNNMatrix::create(outMat, extOutVal_->getPrimitiveDesc());
+    extOutGrad_ = MKLDNNMatrix::create(extOutVal_->getPrimitiveDesc(), outMat);
     CHECK(outVal_ != nullptr && outVal_->getPrimitiveDesc() == intPD)
         << "should have internal output value and primitive desc must equal";
-    out = MKLDNNMatrix::create(nullptr, intPD);
+    out = MKLDNNMatrix::create(intPD);
     cvtOutGrad_ = MKLDNNMatrix::createReorder(extOutGrad_, out);
     CHECK(cvtOutGrad_);
   }
@@ -499,7 +499,7 @@ protected:
     tmpOutGrad_ = out;
     tmpCvt_ = nullptr;
     if (out->getPrimitiveDesc() != srcPDs[0]) {
-      tmpOutGrad_ = MKLDNNMatrix::create(nullptr, srcPDs[0]);
+      tmpOutGrad_ = MKLDNNMatrix::create(srcPDs[0]);
       tmpCvt_ = MKLDNNMatrix::createReorder(tmpOutGrad_, out);
       CHECK(tmpCvt_);
       pipelineMergeGrad_.push_back(*tmpCvt_);
...
@@ -18,7 +18,7 @@ using namespace mkldnn;  // NOLINT
 namespace paddle {

-MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) {
+MKLDNNMatrixPtr MKLDNNMatrix::create(memory::primitive_desc pd, MatrixPtr m) {
   memory::desc md = pd.desc();
   size_t ndims = md.data.ndims;
   int* dims = md.data.dims;
@@ -41,12 +41,12 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) {
   return std::make_shared<MKLDNNMatrix>(cpuMatrix, pd);
 }

-MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
-                                     memory::dims dims,
+MKLDNNMatrixPtr MKLDNNMatrix::create(memory::dims dims,
                                      memory::format fmt,
                                      engine& eg,
+                                     MatrixPtr m,
                                      mkldnn::memory::data_type dtype) {
-  return create(m, createPrimitiveDesc(dims, fmt, eg, dtype));
+  return create(createPrimitiveDesc(dims, fmt, eg, dtype), m);
 }

 std::shared_ptr<reorder> MKLDNNMatrix::createReorder(const MKLDNNMatrixPtr& src,
...
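Note the C++ convention at work in this file: the `= nullptr` defaults appear only in the header declarations (next hunk), while these out-of-line definitions list `MatrixPtr m` without repeating the default, since a default argument may be specified only once. The dims-based overload simply builds a `primitive_desc` via `createPrimitiveDesc` and forwards to the pd-based overload, so the optional `MatrixPtr` stays in the same trailing position in both.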
@@ -40,16 +40,17 @@ public:
   /**
    * Create MKLDNNMatrix from a MatrixPtr and memory primitive_desc
    */
-  static MKLDNNMatrixPtr create(MatrixPtr m, mkldnn::memory::primitive_desc pd);
+  static MKLDNNMatrixPtr create(mkldnn::memory::primitive_desc pd,
+                                MatrixPtr m = nullptr);

   /**
    * Create MKLDNNMatrix from a MatrixPtr and memory details info
    */
   static MKLDNNMatrixPtr create(
-      MatrixPtr m,
       mkldnn::memory::dims dims,
       mkldnn::memory::format fmt,
       mkldnn::engine& eg,
+      MatrixPtr m = nullptr,
       mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32);

   /**
...
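Taken together, the call sites in this commit fall into three shapes under the new interface; the lines below are lifted directly from the hunks above:

```cpp
// Wrap an existing Paddle buffer: the primitive_desc now comes first.
grad_ = MKLDNNMatrix::create(val_->getPrimitiveDesc(), act.grad);

// Allocation-only: the trailing MatrixPtr defaults to nullptr, so the
// former create(nullptr, pd) call sites simply drop the placeholder.
inVal_ = MKLDNNMatrix::create(val_->getPrimitiveDesc());

// Dims-based overload: dims, format, and engine first, optional buffer last.
val_ = MKLDNNMatrix::create(
    {bs, ic, ih, iw}, mkldnn::memory::format::nchw, *engine_, act.value);
```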