From 63ee7290f2656ed96dd773e84ad8b7b78600733a Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 22 Nov 2017 11:26:50 +0800 Subject: [PATCH] remove the tmp buffer --- paddle/gserver/layers/MKLDNNLayer.cpp | 18 ++---------------- paddle/gserver/layers/MKLDNNLayer.h | 5 ----- 2 files changed, 2 insertions(+), 21 deletions(-) diff --git a/paddle/gserver/layers/MKLDNNLayer.cpp b/paddle/gserver/layers/MKLDNNLayer.cpp index 28969d01a1..6fbf3c7fde 100644 --- a/paddle/gserver/layers/MKLDNNLayer.cpp +++ b/paddle/gserver/layers/MKLDNNLayer.cpp @@ -294,22 +294,8 @@ void MKLDNNLayer::resetMergeGrad(MKLDNNMatrixPtr& out) { srcs.push_back(*src); } - // TODO(TJ): remove me when mkldnn sum support different formats - for (size_t i = 1; i < srcPDs.size(); ++i) { - CHECK(srcPDs[0] == srcPDs[i]); - } - tmpOutGrad_ = out; - tmpCvt_ = nullptr; - if (out->getPrimitiveDesc() != srcPDs[0]) { - tmpOutGrad_ = MKLDNNMatrix::create(srcPDs[0]); - tmpCvt_ = MKLDNNMatrix::createReorder(tmpOutGrad_, out); - CHECK(tmpCvt_); - pipelineMergeGrad_.push_back(*tmpCvt_); - } - - auto sumPD = - sum::primitive_desc(tmpOutGrad_->getMemoryDesc(), scales, srcPDs); - mergeGrad_.reset(new sum(sumPD, srcs, *tmpOutGrad_)); + auto sumPD = sum::primitive_desc(out->getMemoryDesc(), scales, srcPDs); + mergeGrad_.reset(new sum(sumPD, srcs, *out)); pipelineMergeGrad_.insert(pipelineMergeGrad_.begin(), *mergeGrad_); } diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h index 8d1271da21..e48b9b5a91 100644 --- a/paddle/gserver/layers/MKLDNNLayer.h +++ b/paddle/gserver/layers/MKLDNNLayer.h @@ -94,11 +94,6 @@ protected: std::vector pipelineMergeGrad_; // tmp input argument to save input grad, only used to merge grad Argument tmpInArg_; - // since mkldnn sum do not support different formats: - // can refer to https://github.com/01org/mkl-dnn/issues/134 - // so need create reorder manually and save tmp MKLDNNMatrix - MKLDNNMatrixPtr tmpOutGrad_; - std::shared_ptr tmpCvt_; public: explicit MKLDNNLayer(const LayerConfig& config) -- GitLab