未验证 提交 1c31bb94 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #5543 from tensor-tang/ds2

add resize of MKLDNNMatrix
...@@ -152,12 +152,7 @@ void MKLDNNMatrix::downSpatial() { ...@@ -152,12 +152,7 @@ void MKLDNNMatrix::downSpatial() {
} }
memory::desc md = memory::desc(dstDims, getDtype(), dstFmt); memory::desc md = memory::desc(dstDims, getDtype(), dstFmt);
memory::primitive_desc pd = memory::primitive_desc(md, getEngine()); memory::primitive_desc pd = memory::primitive_desc(md, getEngine());
mkldnn_primitive_t result; resetMKLDNNMemory(pd, data_);
mkldnn::error::wrap_c_api(
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
"could not create a memory primitive");
reset(result);
set_data_handle(data_);
} }
} // namespace paddle } // namespace paddle
...@@ -145,6 +145,27 @@ public: ...@@ -145,6 +145,27 @@ public:
m_.reset(); m_.reset();
} }
/**
* override the CpuMatrix::resize
*/
void resize(size_t newHeight, size_t newWidth) override {
m_->resize(newHeight, newWidth);
if (data_ == m_->getData() && elementCnt_ == newHeight * newWidth) {
return;
}
CpuMatrix::setData(data_);
height_ = newHeight;
width_ = newWidth;
elementCnt_ = newHeight * newWidth;
stride_ = width_;
auto pd = mkldnn::memory::primitive_desc(
mkldnn::memory::desc({(int)newHeight, (int)newWidth},
getDtype(),
mkldnn::memory::format::nc),
getEngine());
resetMKLDNNMemory(pd, data_);
}
/** /**
* override Matrix::getData * override Matrix::getData
* check data before return * check data before return
...@@ -215,6 +236,17 @@ protected: ...@@ -215,6 +236,17 @@ protected:
memory::format srcFmt, memory::format srcFmt,
memory::format dstFmt, memory::format dstFmt,
memory::dims dm); memory::dims dm);
/**
* reset this MKLDNN Memory from primitve desc
*/
void resetMKLDNNMemory(memory::primitive_desc pd, real* data) {
mkldnn_primitive_t result;
mkldnn::error::wrap_c_api(
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
"could not create a memory primitive");
reset(result);
set_data_handle(data);
}
private: private:
// save the CpuMatrixPtr in case the buffer released outside // save the CpuMatrixPtr in case the buffer released outside
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册