Unverified commit ce3ee9bb, authored by piotrekobiIntel, committed by GitHub

Changed first batch of deprecated mkldnn headers and function names to new oneDNN names (#37040)

* Change first batch of mkldnn headers and namespace names to dnnl

* Revert changes to tensor.h, which require approval

* Format changes with pre-commit

* Add int32 tests

* Fix int32 tests and call GetDataFromTensor for int32

* Fix test
Parent 9c5d5665
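
The renames below rely on oneDNN's own compatibility layer: since the library was rebranded from MKL-DNN to oneDNN, the old mkldnn.hpp header and mkldnn namespace have been kept only as deprecated wrappers around dnnl.hpp and the dnnl namespace. A minimal sketch of that aliasing mechanism (illustrative only; the real compatibility header also remaps the C API symbols through macros):

// Illustrative sketch of the oneDNN compatibility shim, not the exact header.
#include "dnnl.hpp"

// Every mkldnn::* name resolves to its dnnl::* counterpart, so
// "using mkldnn::memory;" and "using dnnl::memory;" name the same type.
namespace mkldnn = dnnl;

Because the two namespaces alias the same entities, every hunk in this commit is a behavior-preserving rename.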
@@ -100,21 +100,21 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
 }

 #ifdef PADDLE_WITH_MKLDNN
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;

-void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
+void* GetDataFromTensor(const Tensor& tensor, dnnl::memory::data_type type) {
   switch (type) {
-    case mkldnn::memory::data_type::f32:
+    case dnnl::memory::data_type::f32:
       return platform::to_void_cast(tensor.data<float>());
-    case mkldnn::memory::data_type::s8:
+    case dnnl::memory::data_type::s8:
       return platform::to_void_cast(tensor.data<int8_t>());
-    case mkldnn::memory::data_type::u8:
+    case dnnl::memory::data_type::u8:
       return platform::to_void_cast(tensor.data<unsigned char>());
-    case mkldnn::memory::data_type::s32:
+    case dnnl::memory::data_type::s32:
       return platform::to_void_cast(tensor.data<int32_t>());
-    case mkldnn::memory::data_type::bf16:
+    case dnnl::memory::data_type::bf16:
       return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
     default:
       PADDLE_THROW(
@@ -37,7 +37,7 @@ namespace paddle {
 namespace framework {

 #ifdef PADDLE_WITH_MKLDNN
-using MKLDNNDataType = mkldnn::memory::data_type;
+using MKLDNNDataType = dnnl::memory::data_type;

 inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
   switch (layout) {
@@ -44,7 +44,7 @@ TEST(DataTransform, DataLayoutFunction) {
 }

 #ifdef PADDLE_WITH_MKLDNN
-TEST(DataTransform, GetDataFromTensorDNNL) {
+TEST(DataTransformBf16, GetDataFromTensorDNNL) {
   auto place = paddle::platform::CPUPlace();
   paddle::framework::Tensor in = paddle::framework::Tensor();
   in.mutable_data<paddle::platform::bfloat16>(
@@ -55,4 +55,14 @@ TEST(DataTransform, GetDataFromTensorDNNL) {
   EXPECT_EQ(in_data, paddle::platform::to_void_cast(
                          in.data<paddle::platform::bfloat16>()));
 }
+
+TEST(DataTransformInt32, GetDataFromTensorDNNL) {
+  auto place = paddle::platform::CPUPlace();
+  paddle::framework::Tensor in = paddle::framework::Tensor();
+  in.mutable_data<int32_t>(paddle::framework::make_ddim({2, 3, 1, 2}), place);
+
+  void* in_data =
+      paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::s32);
+  EXPECT_EQ(in_data, paddle::platform::to_void_cast(in.data<int32_t>()));
+}
 #endif
@@ -310,4 +310,117 @@ TEST(DataTypeTransform, CPUTransform) {
                 static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
     }
   }
+
+  // data type transform from/to int32
+  {
+    paddle::framework::Tensor in;
+    paddle::framework::Tensor out;
+
+    int32_t* ptr =
+        in.mutable_data<int32_t>(paddle::framework::make_ddim({2, 3}), place);
+    int data_number = 2 * 3;
+
+    for (int i = 0; i < data_number; ++i) {
+      ptr[i] = i;
+    }
+
+    // transform from int32 to other data types
+    paddle::framework::TransDataType(kernel_int32, kernel_fp32, in, &out);
+    float* out_data_float = out.data<float>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_fp64, in, &out);
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
+    paddle::platform::bfloat16* out_data_bf16 =
+        out.data<paddle::platform::bfloat16>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_bf16[i],
+                static_cast<paddle::platform::bfloat16>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_int64, in, &out);
+    int64_t* out_data_int64 = out.data<int64_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_bool, in, &out);
+    bool* out_data_bool = out.data<bool>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
+    }
+
+    // transform float to int32
+    float* in_data_float =
+        in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_float[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_fp32, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_float[i]));
+    }
+
+    // transform double to int32
+    double* in_data_double =
+        in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_double[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_fp64, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_double[i]));
+    }
+
+    // transform bfloat16 to int32
+    paddle::platform::bfloat16* in_data_bf16 =
+        in.mutable_data<paddle::platform::bfloat16>(
+            paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bf16[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_bf16[i]));
+    }
+
+    // transform int64 to int32
+    int64_t* in_data_int64 =
+        in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int64[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_int64, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_int64[i]));
+    }
+
+    // transform bool to int32
+    bool* in_data_bool =
+        in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bool[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_bool, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_bool[i]));
+    }
+  }
 }
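
The kernel_int32, kernel_fp32, kernel_fp64, kernel_bf16, kernel_int64, and kernel_bool keys used above are defined earlier in this test file, outside the visible hunk. They are presumably plain CPU OpKernelType keys, roughly like this hypothetical reconstruction:

// Hypothetical reconstruction of the kernel keys referenced above; the
// actual definitions live earlier in the test file.
auto place = paddle::platform::CPUPlace();
auto kernel_int32 = paddle::framework::OpKernelType(
    paddle::framework::proto::VarType::INT32, place);
auto kernel_fp32 = paddle::framework::OpKernelType(
    paddle::framework::proto::VarType::FP32, place);

TransDataType reads the source and destination element types from these keys and performs an elementwise cast, which is exactly what each EXPECT_EQ against static_cast asserts.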
@@ -26,9 +26,9 @@ namespace operators {

 using framework::DataLayout;
 using framework::Tensor;
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::stream;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::stream;

 template <typename T, dnnl::algorithm BINARY_OP>
 class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
@@ -31,12 +31,11 @@ class LSTMMKLDNNHandler
  public:
   LSTMMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                     const platform::MKLDNNDeviceContext& dev_ctx,
-                    const mkldnn::engine mkldnn_engine,
-                    platform::Place cpu_place, const LoDTensor* input,
-                    const Tensor* weight_h, const Tensor* h0, const Tensor* c0,
-                    const bool is_reverse, const int64_t N, const int64_t Ti,
-                    const int64_t IC, const int64_t OC,
-                    const std::string& unique_name)
+                    const dnnl::engine mkldnn_engine, platform::Place cpu_place,
+                    const LoDTensor* input, const Tensor* weight_h,
+                    const Tensor* h0, const Tensor* c0, const bool is_reverse,
+                    const int64_t N, const int64_t Ti, const int64_t IC,
+                    const int64_t OC, const std::string& unique_name)
       : RNNMKLDNNHandler<T, dnnl::lstm_forward, T_out>(
             ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0,
             is_reverse, N, Ti, IC, OC, 4,
@@ -30,12 +30,11 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
  public:
   RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                    const platform::MKLDNNDeviceContext& dev_ctx,
-                   const mkldnn::engine mkldnn_engine,
-                   platform::Place cpu_place, const LoDTensor* input,
-                   const Tensor* weight_h, const Tensor* h0,
-                   const bool is_reverse, const int64_t N, const int64_t Ti,
-                   const int64_t IC, const int64_t OC, const int64_t G,
-                   const std::string& unique_name)
+                   const dnnl::engine mkldnn_engine, platform::Place cpu_place,
+                   const LoDTensor* input, const Tensor* weight_h,
+                   const Tensor* h0, const bool is_reverse, const int64_t N,
+                   const int64_t Ti, const int64_t IC, const int64_t OC,
+                   const int64_t G, const std::string& unique_name)
       : platform::MKLDNNHandlerT<T, T_alg>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
             CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>(), Ti)),
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <string>
 #include <vector>

-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/operators/mkldnn/axpy_handler.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/device_context.h"
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/dequantize_op.h"
@@ -23,13 +23,13 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;
 using platform::to_void_cast;
 using Tensor = framework::Tensor;
 using framework::DataLayout;
-using mkldnn::stream;
+using dnnl::stream;
 using platform::GetMKLDNNFormat;

 template <typename T>
@@ -64,7 +64,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
     auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
     auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
-    mkldnn::memory::data_type src_dt =
+    dnnl::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
     MKLDNNMemoryFormat src_fmt = input->format();
@@ -76,34 +76,34 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
     const std::string key_src_mem = key + "@s";
     const std::string key_dst_mem = key + "@d";

-    std::shared_ptr<mkldnn::memory> src_memory;
-    std::shared_ptr<mkldnn::memory> dst_memory;
+    std::shared_ptr<dnnl::memory> src_memory;
+    std::shared_ptr<dnnl::memory> dst_memory;
     std::shared_ptr<reorder> reorder_p;
     reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));

     if (reorder_p == nullptr) {
-      mkldnn::primitive_attr attri;
+      dnnl::primitive_attr attri;
       int mask = 0;
       float reorder_scale = 1. / scale_data;
       attri.set_output_scales(mask, {reorder_scale});

       if (with_shift) {
-        mkldnn::post_ops post_operations;
+        dnnl::post_ops post_operations;
         post_operations.append_sum();
         attri.set_post_ops(post_operations);
         std::fill(output_data, output_data + output->numel(), reorder_shift);
       }

       auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
-      src_memory = std::make_shared<mkldnn::memory>(
-          src_md, engine, to_void_cast<T>(input_data));
+      src_memory = std::make_shared<dnnl::memory>(src_md, engine,
+                                                  to_void_cast<T>(input_data));
       auto dst_md =
           platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
                                   platform::MKLDNNFormatForSize(
                                       dst_tz.size(), MKLDNNMemoryFormat::nchw));
-      dst_memory = std::make_shared<mkldnn::memory>(
+      dst_memory = std::make_shared<dnnl::memory>(
           dst_md, engine, to_void_cast<float>(output_data));

       auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
@@ -113,12 +113,12 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
       dev_ctx.SetBlob(key_src_mem, src_memory);
       dev_ctx.SetBlob(key_dst_mem, dst_memory);
     } else {
-      src_memory = std::static_pointer_cast<mkldnn::memory>(
-          dev_ctx.GetBlob(key_src_mem));
+      src_memory =
+          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
       src_memory->set_data_handle(to_void_cast<T>(input_data));
-      dst_memory = std::static_pointer_cast<mkldnn::memory>(
-          dev_ctx.GetBlob(key_dst_mem));
+      dst_memory =
+          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
       if (with_shift)
         std::fill(output_data, output_data + output->numel(), reorder_shift);
       dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
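
Stripped of Paddle's blob caching, the dequantize kernel above boils down to a single dnnl::reorder from the quantized integer type to f32 with an output scale of 1 / scale_data. A self-contained sketch using the oneDNN 2.x API this code targets (set_output_scales was later replaced by runtime scales in oneDNN 3.x; buffer contents and names are illustrative):

#include <vector>
#include "dnnl.hpp"

// Minimal s8 -> f32 dequantization via reorder + output scales.
void DequantizeSketch(float scale_data) {
  dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0);
  dnnl::stream cpu_stream(cpu_engine);

  std::vector<int8_t> src_buf = {-64, 0, 64, 127};
  std::vector<float> dst_buf(src_buf.size());

  dnnl::memory::desc src_md({1, 4}, dnnl::memory::data_type::s8,
                            dnnl::memory::format_tag::nc);
  dnnl::memory::desc dst_md({1, 4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nc);
  dnnl::memory src(src_md, cpu_engine, src_buf.data());
  dnnl::memory dst(dst_md, cpu_engine, dst_buf.data());

  // dst = (1 / scale_data) * src, e.g. scale_data = 127 maps 127 -> 1.0f.
  dnnl::primitive_attr attr;
  attr.set_output_scales(0 /* mask: one common scale */, {1.f / scale_data});

  auto reorder_pd = dnnl::reorder::primitive_desc(cpu_engine, src_md,
                                                  cpu_engine, dst_md, attr);
  dnnl::reorder(reorder_pd).execute(cpu_stream, src, dst);
  cpu_stream.wait();
}

The kernel's extra machinery (dev_ctx.GetBlob / SetBlob) only caches the reorder primitive and its memories between invocations, keyed by shape and scale.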
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/quantize_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -21,13 +21,13 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;
 using platform::to_void_cast;
 using Tensor = framework::Tensor;
 using framework::DataLayout;
-using mkldnn::stream;
+using dnnl::stream;
 using platform::GetMKLDNNFormat;

 template <typename T>
@@ -65,19 +65,19 @@ class QuantOpKernel : public framework::OpKernel<T> {
     bool bfloat16 = ctx.Attr<bool>("bfloat16");

     // TODO(jczaja): Refactor with Acquire API
-    std::shared_ptr<mkldnn::memory> src_memory;
-    std::shared_ptr<mkldnn::memory> dst_memory;
+    std::shared_ptr<dnnl::memory> src_memory;
+    std::shared_ptr<dnnl::memory> dst_memory;
     std::shared_ptr<reorder> reorder_p;

     std::string out_layout = ctx.Attr<std::string>("output_format");
     MKLDNNMemoryFormat out_format =
         platform::data_format_to_memory_format(out_layout);
-    mkldnn::primitive_attr attri;
+    dnnl::primitive_attr attri;
     int mask = 0;
     attri.set_output_scales(mask, {scale_data});

     if (with_shift) {
-      mkldnn::post_ops post_operations;
+      dnnl::post_ops post_operations;
       post_operations.append_sum();
       attri.set_post_ops(post_operations);
       uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
@@ -87,10 +87,10 @@ class QuantOpKernel : public framework::OpKernel<T> {
     auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
                                           input->format());
-    src_memory = std::make_shared<mkldnn::memory>(src_md, engine,
-                                                  to_void_cast<T>(input_data));
+    src_memory = std::make_shared<dnnl::memory>(src_md, engine,
+                                                to_void_cast<T>(input_data));

-    std::shared_ptr<mkldnn::memory::desc> dst_md;
+    std::shared_ptr<dnnl::memory::desc> dst_md;
     if (bfloat16) {
       platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
           ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
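
The with_shift branch in both the quantize and dequantize kernels uses the same trick: a sum post-op makes the reorder accumulate into its destination, so pre-filling the output buffer with the shift value computes scale * x + shift in one primitive. A hedged fragment of just the attribute setup, in the kernel's scope (shift_value stands in for the kernel's shift attribute, which is outside the visible hunk):

// dst is pre-filled with the shift; append_sum() turns the reorder into
// dst = output_scale * src + dst, i.e. a fused quantize-with-shift.
dnnl::primitive_attr attri;
attri.set_output_scales(0 /* mask */, {scale_data});

dnnl::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);

uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
std::fill(output_data, output_data + output->numel(),
          static_cast<uint8_t>(shift_value));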
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/requantize_op.h"
@@ -93,7 +93,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
     int mask = 0;
     attri.set_output_scales(mask, {reorder_scale});

     if (with_shift) {
-      mkldnn::post_ops post_operations;
+      dnnl::post_ops post_operations;
       post_operations.append_sum();
       attri.set_post_ops(post_operations);
       uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
@@ -20,22 +20,22 @@ limitations under the License. */
 #include <string>
 #include <utility>
 #include <vector>

-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"

 namespace paddle {
 #ifdef PADDLE_WITH_MKLDNN
-using MKLDNNMemoryFormat = mkldnn::memory::format_tag;
+using MKLDNNMemoryFormat = dnnl::memory::format_tag;
 #endif
 namespace platform {

-using MKLDNNStream = mkldnn::stream;
-using MKLDNNEngine = mkldnn::engine;
-using MKLDNNMemory = mkldnn::memory;
-using MKLDNNMemoryDescriptor = mkldnn::memory::desc;
-using MKLDNNPrimitive = mkldnn::primitive;
-using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
+using MKLDNNStream = dnnl::stream;
+using MKLDNNEngine = dnnl::engine;
+using MKLDNNMemory = dnnl::memory;
+using MKLDNNMemoryDescriptor = dnnl::memory::desc;
+using MKLDNNPrimitive = dnnl::primitive;
+using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;

 typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
 typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
@@ -62,7 +62,7 @@ using tf_pd = typename Type::primitive_desc;
 template <typename Type, typename Engine, typename... Args>
 std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                     Args&&... args) {
-  auto desc = tf_desc<Type>(mkldnn::prop_kind::forward, (args)...);
+  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, (args)...);
   auto pd = new tf_pd<Type>(desc, e);
   return std::shared_ptr<tf_pd<Type>>(pd);
 }
@@ -129,10 +129,10 @@ struct mkldnn_dummy_primitive {
   struct desc {};
 };

-inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
-                                          mkldnn::memory::data_type data_type,
-                                          MKLDNNMemoryFormat format) {
-  return mkldnn::memory::desc({dims}, data_type, format);
+inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
+                                        dnnl::memory::data_type data_type,
+                                        MKLDNNMemoryFormat format) {
+  return dnnl::memory::desc({dims}, data_type, format);
 }

 inline void ClearMKLDNNCache(const platform::Place& place,
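
For orientation, MKLDNNMemDesc is a thin convenience wrapper over the dnnl::memory::desc constructor. A hedged usage sketch (the shape and format are illustrative):

// Build an f32, NCHW memory descriptor for a hypothetical 1x3x224x224
// tensor using the helpers from this header.
std::vector<int64_t> dims = {1, 3, 224, 224};
auto md = paddle::platform::MKLDNNMemDesc(
    dims, paddle::platform::MKLDNNGetDataType<float>(),
    paddle::MKLDNNMemoryFormat::nchw);

MKLDNNGetDataType<T>() (renamed in the next hunk) maps C++ element types to dnnl::memory::data_type at compile time, falling back to undef for unsupported types.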
@@ -159,36 +159,35 @@ inline void DontClearMKLDNNCache(const platform::Place& place) {
 }

 template <typename Type>
-mkldnn::memory::data_type MKLDNNGetDataType() {
-  return mkldnn::memory::data_type::undef;
+dnnl::memory::data_type MKLDNNGetDataType() {
+  return dnnl::memory::data_type::undef;
 }

 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
-  return mkldnn::memory::data_type::f32;
+inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
+  return dnnl::memory::data_type::f32;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<int32_t>() {
-  return mkldnn::memory::data_type::s32;
+inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
+  return dnnl::memory::data_type::s32;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<int8_t>() {
-  return mkldnn::memory::data_type::s8;
+inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
+  return dnnl::memory::data_type::s8;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
-  return mkldnn::memory::data_type::u8;
+inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
+  return dnnl::memory::data_type::u8;
 }
 template <>
-inline mkldnn::memory::data_type
-MKLDNNGetDataType<paddle::platform::bfloat16>() {
-  return mkldnn::memory::data_type::bf16;
+inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
+  return dnnl::memory::data_type::bf16;
 }

-inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
-                    const mkldnn::engine& engine) {
-  auto reorder_prim = mkldnn::reorder(src, dst);
+inline void Reorder(dnnl::memory src, dnnl::memory dst,
+                    const dnnl::engine& engine) {
+  auto reorder_prim = dnnl::reorder(src, dst);
   auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

   platform::RecordEvent record_reorder("int_reorder",
                                        platform::EventRole::kUniqueOp);
@@ -196,8 +195,7 @@ inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
   astream.wait();
 }

-inline mkldnn::memory::format_tag GetMKLDNNFormat(
-    mkldnn::memory::desc mem_desc) {
+inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
   auto ndims = mem_desc.data.ndims;
   auto strides = mem_desc.data.format_desc.blocking.strides;
   auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
@@ -205,62 +203,62 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
   auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;

   if (ndims == 1) {
-    return mkldnn::memory::format_tag::x;
+    return dnnl::memory::format_tag::x;
   } else if (ndims == 2) {
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1]) {
-        return mkldnn::memory::format_tag::nc;
+        return dnnl::memory::format_tag::nc;
       } else {
-        return mkldnn::memory::format_tag::cn;
+        return dnnl::memory::format_tag::cn;
       }
     }
   } else if (ndims == 3) {
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
-        return mkldnn::memory::format_tag::ncw;
+        return dnnl::memory::format_tag::ncw;
       } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
-        return mkldnn::memory::format_tag::ntc;
+        return dnnl::memory::format_tag::ntc;
       } else {
-        return mkldnn::memory::format_tag::nwc;
+        return dnnl::memory::format_tag::nwc;
       }
     }
   } else if (ndims == 4) {
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
           strides[2] >= strides[3]) {
-        return mkldnn::memory::format_tag::nchw;
+        return dnnl::memory::format_tag::nchw;
       } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                  strides[1] >= strides[0]) {
-        return mkldnn::memory::format_tag::cdba;
+        return dnnl::memory::format_tag::cdba;
       } else {
-        return mkldnn::memory::format_tag::nhwc;
+        return dnnl::memory::format_tag::nhwc;
       }
     } else if (inner_nblks == 1) {
       if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw16c;
+        return dnnl::memory::format_tag::nChw16c;
       } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw8c;
+        return dnnl::memory::format_tag::nChw8c;
       } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdb8a;
+          return dnnl::memory::format_tag::Acdb8a;
         }
       } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw4c;
+        return dnnl::memory::format_tag::nChw4c;
       } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdb16a;
+          return dnnl::memory::format_tag::Acdb16a;
         }
       }
     } else if (inner_nblks == 2) {
       if (inner_blks[0] == 16 && inner_blks[1] == 16) {
         if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
-          return mkldnn::memory::format_tag::OIhw16i16o;
+          return dnnl::memory::format_tag::OIhw16i16o;
         }
       } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
         if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
-          return mkldnn::memory::format_tag::OIhw8i8o;
+          return dnnl::memory::format_tag::OIhw8i8o;
         }
       }
     }
@@ -268,38 +266,38 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
           strides[2] >= strides[3] && strides[3] >= strides[4]) {
-        return mkldnn::memory::format_tag::ncdhw;
+        return dnnl::memory::format_tag::ncdhw;
       } else {
-        return mkldnn::memory::format_tag::ndhwc;
+        return dnnl::memory::format_tag::ndhwc;
       }
     } else if (inner_nblks == 1) {
       if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[4] && strides[4] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdeb8a;
+          return dnnl::memory::format_tag::Acdeb8a;
         }
         if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
             strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::Abcde8a;
+          return dnnl::memory::format_tag::Abcde8a;
         }
       } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
         if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
             strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::aBcde8b;
+          return dnnl::memory::format_tag::aBcde8b;
         }
       } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[4] && strides[4] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdeb16a;
+          return dnnl::memory::format_tag::Acdeb16a;
         }
         if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
             strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::Abcde16a;
+          return dnnl::memory::format_tag::Abcde16a;
        }
       } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
         if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
             strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::aBcde16b;
+          return dnnl::memory::format_tag::aBcde16b;
         }
       }
     }
@@ -308,7 +306,7 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
     if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
         strides[2] >= strides[3] && strides[3] >= strides[4] &&
         strides[4] >= strides[5]) {
-      return mkldnn::memory::format_tag::abcdef;
+      return dnnl::memory::format_tag::abcdef;
     }
   }
@@ -325,10 +323,10 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
   // for (int i=0;i<inner_nblks;++i) {
   //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
   // }
-  return mkldnn::memory::format_tag::undef;
+  return dnnl::memory::format_tag::undef;
 }

-inline mkldnn::memory::format_tag GetMKLDNNFormat(const mkldnn::memory memory) {
+inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
   auto mem_desc = memory.get_desc();
   return GetMKLDNNFormat(mem_desc);
 }
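
GetMKLDNNFormat infers a format tag purely from the descriptor's blocking metadata: stride orderings pick the plain layouts, while inner_blks/inner_idxs pick the blocked ones. A quick sanity check of the stride logic (a sketch; md.data is the oneDNN 2.x C-struct view of the descriptor that this code already relies on):

// For dims {2, 3, 4, 5} in NCHW dimension order:
//   nchw strides are {60, 20, 5, 1}  -> monotonically decreasing -> nchw
//   nhwc strides are {60, 1, 15, 3}  -> channel stride is smallest
dnnl::memory::desc md({2, 3, 4, 5}, dnnl::memory::data_type::f32,
                      dnnl::memory::format_tag::nhwc);
const auto* strides = md.data.format_desc.blocking.strides;
// Neither the nchw nor the cdba stride ordering holds for {60, 1, 15, 3},
// so GetMKLDNNFormat(md) falls through to format_tag::nhwc.

If no pattern matches, the function returns format_tag::undef, which is why the commented-out debug prints above were left in place for diagnosing unrecognized layouts.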
@@ -441,24 +439,24 @@ inline void AppendKey(std::string* key, const T& num) {
 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::memory::format_tag& format) {
+                      const dnnl::memory::format_tag& format) {
   key->append(std::to_string(static_cast<int>(format)));
 }

 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::memory::data_type& data_type) {
+                      const dnnl::memory::data_type& data_type) {
   key->append(std::to_string(static_cast<int>(data_type)));
 }

 template <>
-inline void AppendKey(std::string* key, const mkldnn::algorithm& algorithm) {
+inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
   key->append(std::to_string(static_cast<int>(algorithm)));
 }

 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::normalization_flags& flags) {
+                      const dnnl::normalization_flags& flags) {
   key->append(std::to_string(static_cast<int>(flags)));
 }