diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp index 21a8f73c3e650d4b3c3b86247594cd965f4ead35..a710479bab82ed52122cf59bb14a05ccbd4aa05c 100644 --- a/paddle/math/MKLDNNMatrix.cpp +++ b/paddle/math/MKLDNNMatrix.cpp @@ -152,12 +152,7 @@ void MKLDNNMatrix::downSpatial() { } memory::desc md = memory::desc(dstDims, getDtype(), dstFmt); memory::primitive_desc pd = memory::primitive_desc(md, getEngine()); - mkldnn_primitive_t result; - mkldnn::error::wrap_c_api( - mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr), - "could not create a memory primitive"); - reset(result); - set_data_handle(data_); + resetMKLDNNMemory(pd, data_); } } // namespace paddle diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h index 54cfefe23b3dc70fd12fd2ca8886c941047b59f7..39d40a1f61609a649d3341c170d24b0604921ac2 100644 --- a/paddle/math/MKLDNNMatrix.h +++ b/paddle/math/MKLDNNMatrix.h @@ -145,6 +145,27 @@ public: m_.reset(); } + /** + * override the CpuMatrix::resize + */ + void resize(size_t newHeight, size_t newWidth) override { + m_->resize(newHeight, newWidth); + if (data_ == m_->getData() && elementCnt_ == newHeight * newWidth) { + return; + } + CpuMatrix::setData(data_); + height_ = newHeight; + width_ = newWidth; + elementCnt_ = newHeight * newWidth; + stride_ = width_; + auto pd = mkldnn::memory::primitive_desc( + mkldnn::memory::desc({(int)newHeight, (int)newWidth}, + getDtype(), + mkldnn::memory::format::nc), + getEngine()); + resetMKLDNNMemory(pd, data_); + } + /** * override Matrix::getData * check data before return @@ -215,6 +236,17 @@ protected: memory::format srcFmt, memory::format dstFmt, memory::dims dm); + /** + * reset this MKLDNN Memory from primitve desc + */ + void resetMKLDNNMemory(memory::primitive_desc pd, real* data) { + mkldnn_primitive_t result; + mkldnn::error::wrap_c_api( + mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr), + "could not create a memory primitive"); + reset(result); + set_data_handle(data); + } private: // save the CpuMatrixPtr in case the buffer released outside