Unverified · Commit ce3ee9bb authored by piotrekobiIntel, committed by GitHub

Changed first batch of deprecated mkldnn headers and function names to new oneDNN names (#37040)

* Change first batch of mkldnn headers and namespace names to dnnl

* Revert changes to tensor.h, which require approval

* Format changes with pre-commit

* Add int32 tests

* Fix int32 tests and call GetDataFromTensor for int32

* Fix test
Parent 9c5d5665
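Reviewer context: the rename is mechanical. Since oneDNN 1.0 the library's canonical C++ namespace is dnnl, and the old mkldnn spellings survive only as deprecated compatibility aliases, so every mkldnn:: name below has a one-to-one dnnl:: replacement. A minimal sketch of the new spellings (assumes oneDNN >= 1.0; not code from this PR):

// Hedged sketch, not part of the diff: the dnnl spellings this PR moves to.
#include "dnnl.hpp"

int main() {
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);   // was mkldnn::engine
  dnnl::stream stream(engine);                       // was mkldnn::stream
  dnnl::memory::desc md({2, 3, 4, 5},                // was mkldnn::memory::desc
                        dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nchw);
  dnnl::memory mem(md, engine);                      // was mkldnn::memory
  return 0;
}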
@@ -100,21 +100,21 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
 }
 
 #ifdef PADDLE_WITH_MKLDNN
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;
 
-void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
+void* GetDataFromTensor(const Tensor& tensor, dnnl::memory::data_type type) {
   switch (type) {
-    case mkldnn::memory::data_type::f32:
+    case dnnl::memory::data_type::f32:
       return platform::to_void_cast(tensor.data<float>());
-    case mkldnn::memory::data_type::s8:
+    case dnnl::memory::data_type::s8:
       return platform::to_void_cast(tensor.data<int8_t>());
-    case mkldnn::memory::data_type::u8:
+    case dnnl::memory::data_type::u8:
       return platform::to_void_cast(tensor.data<unsigned char>());
-    case mkldnn::memory::data_type::s32:
+    case dnnl::memory::data_type::s32:
      return platform::to_void_cast(tensor.data<int32_t>());
-    case mkldnn::memory::data_type::bf16:
+    case dnnl::memory::data_type::bf16:
      return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
    default:
      PADDLE_THROW(
......
@@ -37,7 +37,7 @@ namespace paddle {
 namespace framework {
 
 #ifdef PADDLE_WITH_MKLDNN
-using MKLDNNDataType = mkldnn::memory::data_type;
+using MKLDNNDataType = dnnl::memory::data_type;
 
 inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
   switch (layout) {
......
@@ -44,7 +44,7 @@ TEST(DataTransform, DataLayoutFunction) {
 }
 
 #ifdef PADDLE_WITH_MKLDNN
-TEST(DataTransform, GetDataFromTensorDNNL) {
+TEST(DataTransformBf16, GetDataFromTensorDNNL) {
   auto place = paddle::platform::CPUPlace();
   paddle::framework::Tensor in = paddle::framework::Tensor();
   in.mutable_data<paddle::platform::bfloat16>(
@@ -55,4 +55,14 @@ TEST(DataTransform, GetDataFromTensorDNNL) {
   EXPECT_EQ(in_data, paddle::platform::to_void_cast(
                          in.data<paddle::platform::bfloat16>()));
 }
+
+TEST(DataTransformInt32, GetDataFromTensorDNNL) {
+  auto place = paddle::platform::CPUPlace();
+  paddle::framework::Tensor in = paddle::framework::Tensor();
+  in.mutable_data<int32_t>(paddle::framework::make_ddim({2, 3, 1, 2}), place);
+
+  void* in_data =
+      paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::s32);
+  EXPECT_EQ(in_data, paddle::platform::to_void_cast(in.data<int32_t>()));
+}
 #endif
@@ -310,4 +310,117 @@ TEST(DataTypeTransform, CPUTransform) {
                 static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
     }
   }
+
+  // data type transform from/to int32
+  {
+    paddle::framework::Tensor in;
+    paddle::framework::Tensor out;
+
+    int32_t* ptr =
+        in.mutable_data<int32_t>(paddle::framework::make_ddim({2, 3}), place);
+    int data_number = 2 * 3;
+
+    for (int i = 0; i < data_number; ++i) {
+      ptr[i] = i;
+    }
+
+    // transform from int32 to other data types
+    paddle::framework::TransDataType(kernel_int32, kernel_fp32, in, &out);
+    float* out_data_float = out.data<float>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_fp64, in, &out);
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
+    paddle::platform::bfloat16* out_data_bf16 =
+        out.data<paddle::platform::bfloat16>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_bf16[i],
+                static_cast<paddle::platform::bfloat16>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_int64, in, &out);
+    int64_t* out_data_int64 = out.data<int64_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_bool, in, &out);
+    bool* out_data_bool = out.data<bool>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
+    }
+
+    // transform float to int32
+    float* in_data_float =
+        in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_float[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_fp32, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_float[i]));
+    }
+
+    // transform double to int32
+    double* in_data_double =
+        in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_double[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_fp64, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_double[i]));
+    }
+
+    // transform bfloat16 to int32
+    paddle::platform::bfloat16* in_data_bf16 =
+        in.mutable_data<paddle::platform::bfloat16>(
+            paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bf16[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_bf16[i]));
+    }
+
+    // transform int64 to int32
+    int64_t* in_data_int64 =
+        in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int64[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_int64, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_int64[i]));
+    }
+
+    // transform bool to int32
+    bool* in_data_bool =
+        in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bool[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_bool, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_bool[i]));
+    }
+  }
 }
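Note: the kernel_int32, kernel_fp32, kernel_fp64, kernel_bf16, kernel_int64 and kernel_bool descriptors used above come from the elided part of this test file. A hedged sketch of what such definitions typically look like — the constructor arguments here are an assumption for illustration, not content of this diff:

// Hypothetical reconstruction; the real definitions are in the elided lines.
auto place = paddle::platform::CPUPlace();
paddle::framework::OpKernelType kernel_fp32(
    paddle::framework::proto::VarType::FP32, place);
paddle::framework::OpKernelType kernel_int32(
    paddle::framework::proto::VarType::INT32, place);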
@@ -26,9 +26,9 @@ namespace operators {
 
 using framework::DataLayout;
 using framework::Tensor;
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::stream;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::stream;
 
 template <typename T, dnnl::algorithm BINARY_OP>
 class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
......
@@ -31,12 +31,11 @@ class LSTMMKLDNNHandler
  public:
   LSTMMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                     const platform::MKLDNNDeviceContext& dev_ctx,
-                    const mkldnn::engine mkldnn_engine,
-                    platform::Place cpu_place, const LoDTensor* input,
-                    const Tensor* weight_h, const Tensor* h0, const Tensor* c0,
-                    const bool is_reverse, const int64_t N, const int64_t Ti,
-                    const int64_t IC, const int64_t OC,
-                    const std::string& unique_name)
+                    const dnnl::engine mkldnn_engine, platform::Place cpu_place,
+                    const LoDTensor* input, const Tensor* weight_h,
+                    const Tensor* h0, const Tensor* c0, const bool is_reverse,
+                    const int64_t N, const int64_t Ti, const int64_t IC,
+                    const int64_t OC, const std::string& unique_name)
       : RNNMKLDNNHandler<T, dnnl::lstm_forward, T_out>(
             ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0,
             is_reverse, N, Ti, IC, OC, 4,
......
@@ -30,12 +30,11 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
  public:
   RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                    const platform::MKLDNNDeviceContext& dev_ctx,
-                   const mkldnn::engine mkldnn_engine,
-                   platform::Place cpu_place, const LoDTensor* input,
-                   const Tensor* weight_h, const Tensor* h0,
-                   const bool is_reverse, const int64_t N, const int64_t Ti,
-                   const int64_t IC, const int64_t OC, const int64_t G,
-                   const std::string& unique_name)
+                   const dnnl::engine mkldnn_engine, platform::Place cpu_place,
+                   const LoDTensor* input, const Tensor* weight_h,
+                   const Tensor* h0, const bool is_reverse, const int64_t N,
+                   const int64_t Ti, const int64_t IC, const int64_t OC,
+                   const int64_t G, const std::string& unique_name)
       : platform::MKLDNNHandlerT<T, T_alg>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
             CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>(), Ti)),
......
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/operators/mkldnn/axpy_handler.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/device_context.h"
......
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/dequantize_op.h"
@@ -23,13 +23,13 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;
 using platform::to_void_cast;
 using Tensor = framework::Tensor;
 using framework::DataLayout;
-using mkldnn::stream;
+using dnnl::stream;
 using platform::GetMKLDNNFormat;
 
 template <typename T>
@@ -64,7 +64,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
     auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
     auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
-    mkldnn::memory::data_type src_dt =
+    dnnl::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
     MKLDNNMemoryFormat src_fmt = input->format();
@@ -76,34 +76,34 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
     const std::string key_src_mem = key + "@s";
     const std::string key_dst_mem = key + "@d";
 
-    std::shared_ptr<mkldnn::memory> src_memory;
-    std::shared_ptr<mkldnn::memory> dst_memory;
+    std::shared_ptr<dnnl::memory> src_memory;
+    std::shared_ptr<dnnl::memory> dst_memory;
     std::shared_ptr<reorder> reorder_p;
     reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
 
     if (reorder_p == nullptr) {
-      mkldnn::primitive_attr attri;
+      dnnl::primitive_attr attri;
       int mask = 0;
       float reorder_scale = 1. / scale_data;
       attri.set_output_scales(mask, {reorder_scale});
 
       if (with_shift) {
-        mkldnn::post_ops post_operations;
+        dnnl::post_ops post_operations;
         post_operations.append_sum();
         attri.set_post_ops(post_operations);
         std::fill(output_data, output_data + output->numel(), reorder_shift);
       }
 
       auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
-      src_memory = std::make_shared<mkldnn::memory>(
-          src_md, engine, to_void_cast<T>(input_data));
+      src_memory = std::make_shared<dnnl::memory>(src_md, engine,
+                                                  to_void_cast<T>(input_data));
 
       auto dst_md =
           platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
                                   platform::MKLDNNFormatForSize(
                                       dst_tz.size(), MKLDNNMemoryFormat::nchw));
-      dst_memory = std::make_shared<mkldnn::memory>(
+      dst_memory = std::make_shared<dnnl::memory>(
           dst_md, engine, to_void_cast<float>(output_data));
 
       auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
@@ -113,12 +113,12 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
       dev_ctx.SetBlob(key_src_mem, src_memory);
       dev_ctx.SetBlob(key_dst_mem, dst_memory);
     } else {
-      src_memory = std::static_pointer_cast<mkldnn::memory>(
-          dev_ctx.GetBlob(key_src_mem));
+      src_memory =
+          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
       src_memory->set_data_handle(to_void_cast<T>(input_data));
 
-      dst_memory = std::static_pointer_cast<mkldnn::memory>(
-          dev_ctx.GetBlob(key_dst_mem));
+      dst_memory =
+          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
       if (with_shift)
         std::fill(output_data, output_data + output->numel(), reorder_shift);
       dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
......
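For context on what DeQuantOpKernel assembles above: a single dnnl::reorder whose primitive_attr carries an output scale of 1/scale_data (plus a sum post-op when with_shift is set). A self-contained sketch of that pattern, assuming oneDNN 2.x APIs (primitive_attr::set_output_scales was removed later, in oneDNN 3.x):

// Hedged sketch, not PR code: an s8 -> f32 reorder with an output scale,
// the core of the dequantize kernel above.
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  std::vector<int8_t> src_buf = {-128, -64, 0, 64, 127, 1, 2, 3};
  std::vector<float> dst_buf(src_buf.size());

  dnnl::memory::desc src_md({1, 8}, dnnl::memory::data_type::s8,
                            dnnl::memory::format_tag::nc);
  dnnl::memory::desc dst_md({1, 8}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nc);
  dnnl::memory src(src_md, eng, src_buf.data());
  dnnl::memory dst(dst_md, eng, dst_buf.data());

  dnnl::primitive_attr attr;
  attr.set_output_scales(/*mask=*/0, {1.f / 127.f});  // reorder_scale

  dnnl::reorder::primitive_desc rpd(eng, src_md, eng, dst_md, attr);
  dnnl::reorder(rpd).execute(strm, src, dst);
  strm.wait();  // dst_buf[i] == src_buf[i] / 127.f
  return 0;
}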
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/quantize_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -21,13 +21,13 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;
 using platform::to_void_cast;
 using Tensor = framework::Tensor;
 using framework::DataLayout;
-using mkldnn::stream;
+using dnnl::stream;
 using platform::GetMKLDNNFormat;
 
 template <typename T>
@@ -65,19 +65,19 @@ class QuantOpKernel : public framework::OpKernel<T> {
     bool bfloat16 = ctx.Attr<bool>("bfloat16");
 
     // TODO(jczaja): Refactor with Acquire API
-    std::shared_ptr<mkldnn::memory> src_memory;
-    std::shared_ptr<mkldnn::memory> dst_memory;
+    std::shared_ptr<dnnl::memory> src_memory;
+    std::shared_ptr<dnnl::memory> dst_memory;
     std::shared_ptr<reorder> reorder_p;
 
     std::string out_layout = ctx.Attr<std::string>("output_format");
     MKLDNNMemoryFormat out_format =
         platform::data_format_to_memory_format(out_layout);
-    mkldnn::primitive_attr attri;
+    dnnl::primitive_attr attri;
     int mask = 0;
     attri.set_output_scales(mask, {scale_data});
 
     if (with_shift) {
-      mkldnn::post_ops post_operations;
+      dnnl::post_ops post_operations;
       post_operations.append_sum();
       attri.set_post_ops(post_operations);
       uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
@@ -87,10 +87,10 @@ class QuantOpKernel : public framework::OpKernel<T> {
 
     auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
                                           input->format());
-    src_memory = std::make_shared<mkldnn::memory>(src_md, engine,
-                                                  to_void_cast<T>(input_data));
+    src_memory = std::make_shared<dnnl::memory>(src_md, engine,
                                                to_void_cast<T>(input_data));
 
-    std::shared_ptr<mkldnn::memory::desc> dst_md;
+    std::shared_ptr<dnnl::memory::desc> dst_md;
     if (bfloat16) {
       platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
           ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
......
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/requantize_op.h"
@@ -93,7 +93,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
     int mask = 0;
     attri.set_output_scales(mask, {reorder_scale});
 
     if (with_shift) {
-      mkldnn::post_ops post_operations;
+      dnnl::post_ops post_operations;
       post_operations.append_sum();
       attri.set_post_ops(post_operations);
       uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
......
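The with_shift branches in the quantize and requantize kernels above share one trick: pre-fill the output buffer with the shift, then attach a sum post-op so the reorder accumulates on top of it, i.e. dst = shift + scale * src. A standalone sketch under the same oneDNN 2.x assumption:

// Hedged sketch, not PR code: quantize f32 -> u8 with scale and shift via a
// sum post-op; the destination is pre-filled with the shift value.
#include <algorithm>
#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  std::vector<float> src_buf = {-1.f, -0.5f, 0.f, 0.5f};
  std::vector<uint8_t> dst_buf(src_buf.size());
  const float scale = 127.f;
  const uint8_t shift = 128;

  dnnl::memory::desc src_md({1, 4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nc);
  dnnl::memory::desc dst_md({1, 4}, dnnl::memory::data_type::u8,
                            dnnl::memory::format_tag::nc);
  dnnl::memory src(src_md, eng, src_buf.data());
  dnnl::memory dst(dst_md, eng, dst_buf.data());

  dnnl::primitive_attr attr;
  attr.set_output_scales(/*mask=*/0, {scale});
  dnnl::post_ops ops;
  ops.append_sum();  // accumulate into the pre-filled destination
  attr.set_post_ops(ops);

  std::fill(dst_buf.begin(), dst_buf.end(), shift);  // pre-fill with shift

  dnnl::reorder::primitive_desc rpd(eng, src_md, eng, dst_md, attr);
  dnnl::reorder(rpd).execute(strm, src, dst);
  strm.wait();  // dst_buf[i] ~= saturate(shift + scale * src_buf[i])
  return 0;
}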
@@ -20,22 +20,22 @@ limitations under the License. */
 #include <string>
 #include <utility>
 #include <vector>
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 #ifdef PADDLE_WITH_MKLDNN
-using MKLDNNMemoryFormat = mkldnn::memory::format_tag;
+using MKLDNNMemoryFormat = dnnl::memory::format_tag;
 #endif
 namespace platform {
 
-using MKLDNNStream = mkldnn::stream;
-using MKLDNNEngine = mkldnn::engine;
-using MKLDNNMemory = mkldnn::memory;
-using MKLDNNMemoryDescriptor = mkldnn::memory::desc;
-using MKLDNNPrimitive = mkldnn::primitive;
-using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
+using MKLDNNStream = dnnl::stream;
+using MKLDNNEngine = dnnl::engine;
+using MKLDNNMemory = dnnl::memory;
+using MKLDNNMemoryDescriptor = dnnl::memory::desc;
+using MKLDNNPrimitive = dnnl::primitive;
+using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;
 
 typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
 typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
@@ -62,7 +62,7 @@ using tf_pd = typename Type::primitive_desc;
 template <typename Type, typename Engine, typename... Args>
 std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                     Args&&... args) {
-  auto desc = tf_desc<Type>(mkldnn::prop_kind::forward, (args)...);
+  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, (args)...);
   auto pd = new tf_pd<Type>(desc, e);
   return std::shared_ptr<tf_pd<Type>>(pd);
 }
@@ -129,10 +129,10 @@ struct mkldnn_dummy_primitive {
   struct desc {};
 };
 
-inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
-                                          mkldnn::memory::data_type data_type,
-                                          MKLDNNMemoryFormat format) {
-  return mkldnn::memory::desc({dims}, data_type, format);
+inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
+                                        dnnl::memory::data_type data_type,
+                                        MKLDNNMemoryFormat format) {
+  return dnnl::memory::desc({dims}, data_type, format);
 }
 
 inline void ClearMKLDNNCache(const platform::Place& place,
@@ -159,36 +159,35 @@ inline void DontClearMKLDNNCache(const platform::Place& place) {
 }
 
 template <typename Type>
-mkldnn::memory::data_type MKLDNNGetDataType() {
-  return mkldnn::memory::data_type::undef;
+dnnl::memory::data_type MKLDNNGetDataType() {
+  return dnnl::memory::data_type::undef;
 }
 
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
-  return mkldnn::memory::data_type::f32;
+inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
+  return dnnl::memory::data_type::f32;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<int32_t>() {
-  return mkldnn::memory::data_type::s32;
+inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
+  return dnnl::memory::data_type::s32;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<int8_t>() {
-  return mkldnn::memory::data_type::s8;
+inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
+  return dnnl::memory::data_type::s8;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
-  return mkldnn::memory::data_type::u8;
+inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
+  return dnnl::memory::data_type::u8;
 }
 template <>
-inline mkldnn::memory::data_type
-MKLDNNGetDataType<paddle::platform::bfloat16>() {
-  return mkldnn::memory::data_type::bf16;
+inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
+  return dnnl::memory::data_type::bf16;
 }
 
-inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
-                    const mkldnn::engine& engine) {
-  auto reorder_prim = mkldnn::reorder(src, dst);
+inline void Reorder(dnnl::memory src, dnnl::memory dst,
+                    const dnnl::engine& engine) {
+  auto reorder_prim = dnnl::reorder(src, dst);
   auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
   platform::RecordEvent record_reorder("int_reorder",
                                        platform::EventRole::kUniqueOp);
@@ -196,8 +195,7 @@ inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
   astream.wait();
 }
 
-inline mkldnn::memory::format_tag GetMKLDNNFormat(
-    mkldnn::memory::desc mem_desc) {
+inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
   auto ndims = mem_desc.data.ndims;
   auto strides = mem_desc.data.format_desc.blocking.strides;
   auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
@@ -205,62 +203,62 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
   auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;
 
   if (ndims == 1) {
-    return mkldnn::memory::format_tag::x;
+    return dnnl::memory::format_tag::x;
   } else if (ndims == 2) {
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1]) {
-        return mkldnn::memory::format_tag::nc;
+        return dnnl::memory::format_tag::nc;
       } else {
-        return mkldnn::memory::format_tag::cn;
+        return dnnl::memory::format_tag::cn;
       }
     }
   } else if (ndims == 3) {
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
-        return mkldnn::memory::format_tag::ncw;
+        return dnnl::memory::format_tag::ncw;
      } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
-        return mkldnn::memory::format_tag::ntc;
+        return dnnl::memory::format_tag::ntc;
      } else {
-        return mkldnn::memory::format_tag::nwc;
+        return dnnl::memory::format_tag::nwc;
      }
    }
  } else if (ndims == 4) {
    if (inner_nblks == 0) {
      if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
          strides[2] >= strides[3]) {
-        return mkldnn::memory::format_tag::nchw;
+        return dnnl::memory::format_tag::nchw;
      } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                 strides[1] >= strides[0]) {
-        return mkldnn::memory::format_tag::cdba;
+        return dnnl::memory::format_tag::cdba;
      } else {
-        return mkldnn::memory::format_tag::nhwc;
+        return dnnl::memory::format_tag::nhwc;
      }
    } else if (inner_nblks == 1) {
      if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw16c;
+        return dnnl::memory::format_tag::nChw16c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw8c;
+        return dnnl::memory::format_tag::nChw8c;
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdb8a;
+          return dnnl::memory::format_tag::Acdb8a;
        }
      } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw4c;
+        return dnnl::memory::format_tag::nChw4c;
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdb16a;
+          return dnnl::memory::format_tag::Acdb16a;
        }
      }
    } else if (inner_nblks == 2) {
      if (inner_blks[0] == 16 && inner_blks[1] == 16) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
-          return mkldnn::memory::format_tag::OIhw16i16o;
+          return dnnl::memory::format_tag::OIhw16i16o;
        }
      } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
        if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
-          return mkldnn::memory::format_tag::OIhw8i8o;
+          return dnnl::memory::format_tag::OIhw8i8o;
        }
      }
    }
@@ -268,38 +266,38 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
           strides[2] >= strides[3] && strides[3] >= strides[4]) {
-        return mkldnn::memory::format_tag::ncdhw;
+        return dnnl::memory::format_tag::ncdhw;
       } else {
-        return mkldnn::memory::format_tag::ndhwc;
+        return dnnl::memory::format_tag::ndhwc;
       }
     } else if (inner_nblks == 1) {
       if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[4] && strides[4] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdeb8a;
+          return dnnl::memory::format_tag::Acdeb8a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::Abcde8a;
+          return dnnl::memory::format_tag::Abcde8a;
        }
      } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::aBcde8b;
+          return dnnl::memory::format_tag::aBcde8b;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
        if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[4] && strides[4] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdeb16a;
+          return dnnl::memory::format_tag::Acdeb16a;
        }
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::Abcde16a;
+          return dnnl::memory::format_tag::Abcde16a;
        }
      } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::aBcde16b;
+          return dnnl::memory::format_tag::aBcde16b;
        }
      }
    }
@@ -325,10 +323,10 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
   // for (int i=0;i<inner_nblks;++i) {
   //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
   // }
-  return mkldnn::memory::format_tag::undef;
+  return dnnl::memory::format_tag::undef;
 }
 
-inline mkldnn::memory::format_tag GetMKLDNNFormat(const mkldnn::memory memory) {
+inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
   auto mem_desc = memory.get_desc();
   return GetMKLDNNFormat(mem_desc);
 }
@@ -441,24 +439,24 @@ inline void AppendKey(std::string* key, const T& num) {
 
 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::memory::format_tag& format) {
+                      const dnnl::memory::format_tag& format) {
   key->append(std::to_string(static_cast<int>(format)));
 }
 
 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::memory::data_type& data_type) {
+                      const dnnl::memory::data_type& data_type) {
   key->append(std::to_string(static_cast<int>(data_type)));
 }
 
 template <>
-inline void AppendKey(std::string* key, const mkldnn::algorithm& algorithm) {
+inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
   key->append(std::to_string(static_cast<int>(algorithm)));
 }
 
 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::normalization_flags& flags) {
+                      const dnnl::normalization_flags& flags) {
   key->append(std::to_string(static_cast<int>(flags)));
 }
......
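Sanity check for the GetMKLDNNFormat logic above: for a plain NCHW f32 descriptor, inner_nblks is 0 and the blocking strides decrease monotonically, so the stride chain matches the nchw branch. A pure-oneDNN sketch (assumes the v2.x memory::desc::data layout this header reads):

// Hedged sketch, not PR code: inspect the same fields GetMKLDNNFormat reads.
#include <iostream>
#include "dnnl.hpp"

int main() {
  dnnl::memory::desc md({2, 3, 4, 5}, dnnl::memory::data_type::f32,
                        dnnl::memory::format_tag::nchw);
  const auto& blk = md.data.format_desc.blocking;
  std::cout << "ndims: " << md.data.ndims << "\n";          // 4
  std::cout << "inner_nblks: " << blk.inner_nblks << "\n";  // 0
  for (int i = 0; i < md.data.ndims; ++i)                   // 60 20 5 1
    std::cout << "stride[" << i << "]: " << blk.strides[i] << "\n";
  // Decreasing strides with no inner blocks -> format_tag::nchw above.
  return 0;
}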