提交 ca3db070 编写于 作者: T tensor-tang

add createReorder and createMemoryDesc in MKLDNNMatrix

上级 171fee2c
...@@ -49,6 +49,31 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, ...@@ -49,6 +49,31 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
return create(m, memory::primitive_desc(memory::desc(dims, dtype, fmt), eg)); return create(m, memory::primitive_desc(memory::desc(dims, dtype, fmt), eg));
} }
std::shared_ptr<reorder> MKLDNNMatrix::createReorder(const MKLDNNMatrixPtr& src,
const MKLDNNMatrixPtr& dst,
bool checkData) {
if (src == dst) {
return nullptr;
}
if (src->getPrimitiveDesc() == dst->getPrimitiveDesc()) {
return nullptr;
}
if (checkData && (src->getData() == dst->getData())) {
LOG(FATAL) << "can not create reorder with inplace data";
return nullptr;
}
memory::dims srcDims = src->getDims();
memory::dims dstDims = dst->getDims();
CHECK_EQ(srcDims.size(), dstDims.size());
for (size_t i = 0; i < srcDims.size(); ++i) {
CHECK_EQ(srcDims[i], dstDims[i]);
}
return std::make_shared<reorder>(*src, *dst);
}
void MKLDNNMatrix::reorderDataFrom(const MKLDNNMatrixPtr& m, void MKLDNNMatrix::reorderDataFrom(const MKLDNNMatrixPtr& m,
memory::format srcFmt, memory::format srcFmt,
memory::dims targetDim) { memory::dims targetDim) {
......
...@@ -52,6 +52,25 @@ public: ...@@ -52,6 +52,25 @@ public:
mkldnn::engine& eg, mkldnn::engine& eg,
mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32); mkldnn::memory::data_type dtype = mkldnn::memory::data_type::f32);
/**
* Create Memory descriptor.
* default with any format and f32 dtype
*/
static mkldnn::memory::desc createMemoryDesc(
const mkldnn::memory::dims& dims,
const mkldnn::memory::format& fmt = mkldnn::memory::format::any,
const mkldnn::memory::data_type& dtype = mkldnn::memory::data_type::f32) {
return mkldnn::memory::desc(dims, dtype, fmt);
}
/**
* Create reorder primitive.
*/
static std::shared_ptr<mkldnn::reorder> createReorder(
const MKLDNNMatrixPtr& src,
const MKLDNNMatrixPtr& dst,
bool checkData = true);
public: public:
/** /**
* Reorder this MKLDNNMatrix from other format. * Reorder this MKLDNNMatrix from other format.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册