Unverified commit ce3ee9bb, authored by piotrekobiIntel, committed by GitHub

Changed first batch of deprecated mkldnn headers and function names to new oneDNN names (#37040)

* Change first batch of mkldnn headers and namespace names to dnnl

* Revert changes to tensor.h, which require approval

* Format changes with pre-commit

* Add int32 tests

* Fix int32 tests and call GetDataFromTensor for int32

* Fix test
Parent commit: 9c5d5665
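The whole diff below is the same mechanical rename. As a quick orientation, here is a minimal sketch of the before/after spellings (my own illustration against oneDNN 1.x+, where dnnl.hpp supersedes mkldnn.hpp; it is not code from this commit):

    // Deprecated spellings this PR removes          New oneDNN spellings
    //   #include "mkldnn.hpp"                 ->      #include "dnnl.hpp"
    //   mkldnn::memory, mkldnn::engine, ...   ->      dnnl::memory, dnnl::engine, ...
    //   MKLDNN_ARG_FROM / MKLDNN_ARG_TO       ->      DNNL_ARG_FROM / DNNL_ARG_TO
    #include "dnnl.hpp"

    int main() {
      dnnl::engine engine(dnnl::engine::kind::cpu, 0);
      dnnl::stream astream(engine);
      dnnl::memory::desc md({1, 3, 8, 8}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::nchw);
      dnnl::memory src(md, engine), dst(md, engine);
      // Same reorder pattern the Paddle handlers below use, with renamed macros.
      dnnl::reorder(src, dst).execute(
          astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
      astream.wait();
      return 0;
    }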
@@ -100,21 +100,21 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
 }
 #ifdef PADDLE_WITH_MKLDNN
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
-void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;
+void* GetDataFromTensor(const Tensor& tensor, dnnl::memory::data_type type) {
   switch (type) {
-    case mkldnn::memory::data_type::f32:
+    case dnnl::memory::data_type::f32:
       return platform::to_void_cast(tensor.data<float>());
-    case mkldnn::memory::data_type::s8:
+    case dnnl::memory::data_type::s8:
       return platform::to_void_cast(tensor.data<int8_t>());
-    case mkldnn::memory::data_type::u8:
+    case dnnl::memory::data_type::u8:
       return platform::to_void_cast(tensor.data<unsigned char>());
-    case mkldnn::memory::data_type::s32:
+    case dnnl::memory::data_type::s32:
       return platform::to_void_cast(tensor.data<int32_t>());
-    case mkldnn::memory::data_type::bf16:
+    case dnnl::memory::data_type::bf16:
       return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
     default:
       PADDLE_THROW(
...
@@ -37,7 +37,7 @@ namespace paddle {
 namespace framework {
 #ifdef PADDLE_WITH_MKLDNN
-using MKLDNNDataType = mkldnn::memory::data_type;
+using MKLDNNDataType = dnnl::memory::data_type;
 inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
   switch (layout) {
...
@@ -44,7 +44,7 @@ TEST(DataTransform, DataLayoutFunction) {
 }
 #ifdef PADDLE_WITH_MKLDNN
-TEST(DataTransform, GetDataFromTensorDNNL) {
+TEST(DataTransformBf16, GetDataFromTensorDNNL) {
   auto place = paddle::platform::CPUPlace();
   paddle::framework::Tensor in = paddle::framework::Tensor();
   in.mutable_data<paddle::platform::bfloat16>(
@@ -55,4 +55,14 @@ TEST(DataTransform, GetDataFromTensorDNNL) {
   EXPECT_EQ(in_data, paddle::platform::to_void_cast(
                          in.data<paddle::platform::bfloat16>()));
 }
+
+TEST(DataTransformInt32, GetDataFromTensorDNNL) {
+  auto place = paddle::platform::CPUPlace();
+  paddle::framework::Tensor in = paddle::framework::Tensor();
+  in.mutable_data<int32_t>(paddle::framework::make_ddim({2, 3, 1, 2}), place);
+  void* in_data =
+      paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::s32);
+  EXPECT_EQ(in_data, paddle::platform::to_void_cast(in.data<int32_t>()));
+}
 #endif
@@ -310,4 +310,117 @@ TEST(DataTypeTransform, CPUTransform) {
                 static_cast<paddle::platform::bfloat16>(in_data_bool[i]).x);
     }
   }
+
+  // data type transform from/to int32
+  {
+    paddle::framework::Tensor in;
+    paddle::framework::Tensor out;
+
+    int32_t* ptr =
+        in.mutable_data<int32_t>(paddle::framework::make_ddim({2, 3}), place);
+    int data_number = 2 * 3;
+
+    for (int i = 0; i < data_number; ++i) {
+      ptr[i] = i;
+    }
+
+    // transform from int32 to other data types
+    paddle::framework::TransDataType(kernel_int32, kernel_fp32, in, &out);
+    float* out_data_float = out.data<float>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_float[i], static_cast<float>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_fp64, in, &out);
+    double* out_data_double = out.data<double>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_double[i], static_cast<double>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_bf16, in, &out);
+    paddle::platform::bfloat16* out_data_bf16 =
+        out.data<paddle::platform::bfloat16>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_bf16[i],
+                static_cast<paddle::platform::bfloat16>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_int64, in, &out);
+    int64_t* out_data_int64 = out.data<int64_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_int64[i], static_cast<int64_t>(ptr[i]));
+    }
+
+    paddle::framework::TransDataType(kernel_int32, kernel_bool, in, &out);
+    bool* out_data_bool = out.data<bool>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(out_data_bool[i], static_cast<bool>(ptr[i]));
+    }
+
+    // transform float to int32
+    float* in_data_float =
+        in.mutable_data<float>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_float[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_fp32, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_float[i]));
+    }
+
+    // transform double to int32
+    double* in_data_double =
+        in.mutable_data<double>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_double[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_fp64, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_double[i]));
+    }
+
+    // transform bfloat16 to int32
+    paddle::platform::bfloat16* in_data_bf16 =
+        in.mutable_data<paddle::platform::bfloat16>(
+            paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bf16[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_bf16, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_bf16[i]));
+    }
+
+    // transform int64 to int32
+    int64_t* in_data_int64 =
+        in.mutable_data<int64_t>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_int64[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_int64, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_int64[i]));
+    }
+
+    // transform bool to int32
+    bool* in_data_bool =
+        in.mutable_data<bool>(paddle::framework::make_ddim({2, 3}), place);
+    for (int i = 0; i < data_number; ++i) {
+      in_data_bool[i] = i;
+    }
+
+    paddle::framework::TransDataType(kernel_bool, kernel_int32, in, &out);
+    ptr = out.data<int32_t>();
+    for (int i = 0; i < data_number; ++i) {
+      EXPECT_EQ(ptr[i], static_cast<int32_t>(in_data_bool[i]));
+    }
+  }
 }
@@ -26,9 +26,9 @@ namespace operators {
 using framework::DataLayout;
 using framework::Tensor;
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::stream;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::stream;
 template <typename T, dnnl::algorithm BINARY_OP>
 class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
...
@@ -31,12 +31,11 @@ class LSTMMKLDNNHandler
  public:
   LSTMMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                     const platform::MKLDNNDeviceContext& dev_ctx,
-                    const mkldnn::engine mkldnn_engine,
-                    platform::Place cpu_place, const LoDTensor* input,
-                    const Tensor* weight_h, const Tensor* h0, const Tensor* c0,
-                    const bool is_reverse, const int64_t N, const int64_t Ti,
-                    const int64_t IC, const int64_t OC,
-                    const std::string& unique_name)
+                    const dnnl::engine mkldnn_engine, platform::Place cpu_place,
+                    const LoDTensor* input, const Tensor* weight_h,
+                    const Tensor* h0, const Tensor* c0, const bool is_reverse,
+                    const int64_t N, const int64_t Ti, const int64_t IC,
+                    const int64_t OC, const std::string& unique_name)
       : RNNMKLDNNHandler<T, dnnl::lstm_forward, T_out>(
             ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0,
             is_reverse, N, Ti, IC, OC, 4,
...
@@ -30,12 +30,11 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
  public:
   RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
                    const platform::MKLDNNDeviceContext& dev_ctx,
-                   const mkldnn::engine mkldnn_engine,
-                   platform::Place cpu_place, const LoDTensor* input,
-                   const Tensor* weight_h, const Tensor* h0,
-                   const bool is_reverse, const int64_t N, const int64_t Ti,
-                   const int64_t IC, const int64_t OC, const int64_t G,
-                   const std::string& unique_name)
+                   const dnnl::engine mkldnn_engine, platform::Place cpu_place,
+                   const LoDTensor* input, const Tensor* weight_h,
+                   const Tensor* h0, const bool is_reverse, const int64_t N,
+                   const int64_t Ti, const int64_t IC, const int64_t OC,
+                   const int64_t G, const std::string& unique_name)
       : platform::MKLDNNHandlerT<T, T_alg>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
             CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>(), Ti)),
...
@@ -17,7 +17,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/operators/mkldnn/axpy_handler.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/device_context.h"
...
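A note on why a "first batch" rename can land without touching every file at once: oneDNN's transitional releases keep the old mkldnn spellings alive as aliases of the dnnl ones, so migrated and unmigrated translation units still compile and link together. A hedged sketch of what such a compatibility shim amounts to (illustrative only; the header name and exact mechanism are assumptions, not the actual oneDNN header):

    // compat_mkldnn.hpp (hypothetical name): forwards deprecated names
    // to their dnnl equivalents during an incremental migration.
    #pragma once
    #include "dnnl.hpp"

    namespace mkldnn = dnnl;  // old namespace becomes an alias of the new one

    #ifndef MKLDNN_ARG_FROM
    #define MKLDNN_ARG_FROM DNNL_ARG_FROM  // old macros map 1:1 onto new ones
    #define MKLDNN_ARG_TO DNNL_ARG_TO
    #endif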
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/dequantize_op.h"
@@ -23,13 +23,13 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;
 using platform::to_void_cast;
 using Tensor = framework::Tensor;
 using framework::DataLayout;
-using mkldnn::stream;
+using dnnl::stream;
 using platform::GetMKLDNNFormat;
 template <typename T>
@@ -64,7 +64,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
     auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
     auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
-    mkldnn::memory::data_type src_dt =
+    dnnl::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
     MKLDNNMemoryFormat src_fmt = input->format();
@@ -76,34 +76,34 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
     const std::string key_src_mem = key + "@s";
     const std::string key_dst_mem = key + "@d";
-    std::shared_ptr<mkldnn::memory> src_memory;
-    std::shared_ptr<mkldnn::memory> dst_memory;
+    std::shared_ptr<dnnl::memory> src_memory;
+    std::shared_ptr<dnnl::memory> dst_memory;
     std::shared_ptr<reorder> reorder_p;
     reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
     if (reorder_p == nullptr) {
-      mkldnn::primitive_attr attri;
+      dnnl::primitive_attr attri;
       int mask = 0;
       float reorder_scale = 1. / scale_data;
       attri.set_output_scales(mask, {reorder_scale});
       if (with_shift) {
-        mkldnn::post_ops post_operations;
+        dnnl::post_ops post_operations;
         post_operations.append_sum();
         attri.set_post_ops(post_operations);
         std::fill(output_data, output_data + output->numel(), reorder_shift);
       }
       auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
-      src_memory = std::make_shared<mkldnn::memory>(
-          src_md, engine, to_void_cast<T>(input_data));
+      src_memory = std::make_shared<dnnl::memory>(src_md, engine,
+                                                  to_void_cast<T>(input_data));
       auto dst_md =
           platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
                                   platform::MKLDNNFormatForSize(
                                       dst_tz.size(), MKLDNNMemoryFormat::nchw));
-      dst_memory = std::make_shared<mkldnn::memory>(
+      dst_memory = std::make_shared<dnnl::memory>(
           dst_md, engine, to_void_cast<float>(output_data));
       auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
@@ -113,12 +113,12 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
       dev_ctx.SetBlob(key_src_mem, src_memory);
       dev_ctx.SetBlob(key_dst_mem, dst_memory);
     } else {
-      src_memory = std::static_pointer_cast<mkldnn::memory>(
-          dev_ctx.GetBlob(key_src_mem));
+      src_memory =
+          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
       src_memory->set_data_handle(to_void_cast<T>(input_data));
-      dst_memory = std::static_pointer_cast<mkldnn::memory>(
-          dev_ctx.GetBlob(key_dst_mem));
+      dst_memory =
+          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
       if (with_shift)
         std::fill(output_data, output_data + output->numel(), reorder_shift);
       dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/quantize_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -21,13 +21,13 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-using mkldnn::memory;
-using mkldnn::primitive;
-using mkldnn::reorder;
+using dnnl::memory;
+using dnnl::primitive;
+using dnnl::reorder;
 using platform::to_void_cast;
 using Tensor = framework::Tensor;
 using framework::DataLayout;
-using mkldnn::stream;
+using dnnl::stream;
 using platform::GetMKLDNNFormat;
 template <typename T>
@@ -65,19 +65,19 @@ class QuantOpKernel : public framework::OpKernel<T> {
     bool bfloat16 = ctx.Attr<bool>("bfloat16");
     // TODO(jczaja): Refactor with Acquire API
-    std::shared_ptr<mkldnn::memory> src_memory;
-    std::shared_ptr<mkldnn::memory> dst_memory;
+    std::shared_ptr<dnnl::memory> src_memory;
+    std::shared_ptr<dnnl::memory> dst_memory;
     std::shared_ptr<reorder> reorder_p;
     std::string out_layout = ctx.Attr<std::string>("output_format");
     MKLDNNMemoryFormat out_format =
         platform::data_format_to_memory_format(out_layout);
-    mkldnn::primitive_attr attri;
+    dnnl::primitive_attr attri;
     int mask = 0;
     attri.set_output_scales(mask, {scale_data});
     if (with_shift) {
-      mkldnn::post_ops post_operations;
+      dnnl::post_ops post_operations;
       post_operations.append_sum();
       attri.set_post_ops(post_operations);
       uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
@@ -87,10 +87,10 @@ class QuantOpKernel : public framework::OpKernel<T> {
     auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
                                           input->format());
-    src_memory = std::make_shared<mkldnn::memory>(src_md, engine,
-                                                  to_void_cast<T>(input_data));
-    std::shared_ptr<mkldnn::memory::desc> dst_md;
+    src_memory = std::make_shared<dnnl::memory>(src_md, engine,
+                                                to_void_cast<T>(input_data));
+    std::shared_ptr<dnnl::memory::desc> dst_md;
     if (bfloat16) {
       platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
          ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/requantize_op.h"
@@ -93,7 +93,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
     int mask = 0;
     attri.set_output_scales(mask, {reorder_scale});
     if (with_shift) {
-      mkldnn::post_ops post_operations;
+      dnnl::post_ops post_operations;
       post_operations.append_sum();
      attri.set_post_ops(post_operations);
       uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
...
@@ -20,22 +20,22 @@ limitations under the License. */
 #include <string>
 #include <utility>
 #include <vector>
-#include "mkldnn.hpp"
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
 namespace paddle {
 #ifdef PADDLE_WITH_MKLDNN
-using MKLDNNMemoryFormat = mkldnn::memory::format_tag;
+using MKLDNNMemoryFormat = dnnl::memory::format_tag;
 #endif
 namespace platform {
-using MKLDNNStream = mkldnn::stream;
-using MKLDNNEngine = mkldnn::engine;
-using MKLDNNMemory = mkldnn::memory;
-using MKLDNNMemoryDescriptor = mkldnn::memory::desc;
-using MKLDNNPrimitive = mkldnn::primitive;
-using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
+using MKLDNNStream = dnnl::stream;
+using MKLDNNEngine = dnnl::engine;
+using MKLDNNMemory = dnnl::memory;
+using MKLDNNMemoryDescriptor = dnnl::memory::desc;
+using MKLDNNPrimitive = dnnl::primitive;
+using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;
 typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
 typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
@@ -62,7 +62,7 @@ using tf_pd = typename Type::primitive_desc;
 template <typename Type, typename Engine, typename... Args>
 std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                     Args&&... args) {
-  auto desc = tf_desc<Type>(mkldnn::prop_kind::forward, (args)...);
+  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, (args)...);
   auto pd = new tf_pd<Type>(desc, e);
   return std::shared_ptr<tf_pd<Type>>(pd);
 }
@@ -129,10 +129,10 @@ struct mkldnn_dummy_primitive {
   struct desc {};
 };
-inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
-                                          mkldnn::memory::data_type data_type,
-                                          MKLDNNMemoryFormat format) {
-  return mkldnn::memory::desc({dims}, data_type, format);
+inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
+                                        dnnl::memory::data_type data_type,
+                                        MKLDNNMemoryFormat format) {
+  return dnnl::memory::desc({dims}, data_type, format);
 }
 inline void ClearMKLDNNCache(const platform::Place& place,
@@ -159,36 +159,35 @@ inline void DontClearMKLDNNCache(const platform::Place& place) {
 }
 template <typename Type>
-mkldnn::memory::data_type MKLDNNGetDataType() {
-  return mkldnn::memory::data_type::undef;
+dnnl::memory::data_type MKLDNNGetDataType() {
+  return dnnl::memory::data_type::undef;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
-  return mkldnn::memory::data_type::f32;
+inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
+  return dnnl::memory::data_type::f32;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<int32_t>() {
-  return mkldnn::memory::data_type::s32;
+inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
+  return dnnl::memory::data_type::s32;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<int8_t>() {
-  return mkldnn::memory::data_type::s8;
+inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
+  return dnnl::memory::data_type::s8;
 }
 template <>
-inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
-  return mkldnn::memory::data_type::u8;
+inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
+  return dnnl::memory::data_type::u8;
 }
 template <>
-inline mkldnn::memory::data_type
-MKLDNNGetDataType<paddle::platform::bfloat16>() {
-  return mkldnn::memory::data_type::bf16;
+inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
+  return dnnl::memory::data_type::bf16;
 }
-inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
-                    const mkldnn::engine& engine) {
-  auto reorder_prim = mkldnn::reorder(src, dst);
+inline void Reorder(dnnl::memory src, dnnl::memory dst,
+                    const dnnl::engine& engine) {
+  auto reorder_prim = dnnl::reorder(src, dst);
   auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
   platform::RecordEvent record_reorder("int_reorder",
                                        platform::EventRole::kUniqueOp);
@@ -196,8 +195,7 @@ inline void Reorder(mkldnn::memory src, mkldnn::memory dst,
   astream.wait();
 }
-inline mkldnn::memory::format_tag GetMKLDNNFormat(
-    mkldnn::memory::desc mem_desc) {
+inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
   auto ndims = mem_desc.data.ndims;
   auto strides = mem_desc.data.format_desc.blocking.strides;
   auto inner_nblks = mem_desc.data.format_desc.blocking.inner_nblks;
@@ -205,62 +203,62 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
   auto inner_idxs = mem_desc.data.format_desc.blocking.inner_idxs;
   if (ndims == 1) {
-    return mkldnn::memory::format_tag::x;
+    return dnnl::memory::format_tag::x;
   } else if (ndims == 2) {
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1]) {
-        return mkldnn::memory::format_tag::nc;
+        return dnnl::memory::format_tag::nc;
       } else {
-        return mkldnn::memory::format_tag::cn;
+        return dnnl::memory::format_tag::cn;
       }
     }
   } else if (ndims == 3) {
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2]) {
-        return mkldnn::memory::format_tag::ncw;
+        return dnnl::memory::format_tag::ncw;
       } else if (strides[1] >= strides[0] && strides[0] >= strides[2]) {
-        return mkldnn::memory::format_tag::ntc;
+        return dnnl::memory::format_tag::ntc;
       } else {
-        return mkldnn::memory::format_tag::nwc;
+        return dnnl::memory::format_tag::nwc;
       }
     }
   } else if (ndims == 4) {
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
           strides[2] >= strides[3]) {
-        return mkldnn::memory::format_tag::nchw;
+        return dnnl::memory::format_tag::nchw;
       } else if (strides[2] >= strides[3] && strides[3] >= strides[1] &&
                  strides[1] >= strides[0]) {
-        return mkldnn::memory::format_tag::cdba;
+        return dnnl::memory::format_tag::cdba;
       } else {
-        return mkldnn::memory::format_tag::nhwc;
+        return dnnl::memory::format_tag::nhwc;
      }
     } else if (inner_nblks == 1) {
       if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw16c;
+        return dnnl::memory::format_tag::nChw16c;
       } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw8c;
+        return dnnl::memory::format_tag::nChw8c;
       } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdb8a;
+          return dnnl::memory::format_tag::Acdb8a;
         }
       } else if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
-        return mkldnn::memory::format_tag::nChw4c;
+        return dnnl::memory::format_tag::nChw4c;
       } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
            strides[3] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdb16a;
+          return dnnl::memory::format_tag::Acdb16a;
        }
       }
     } else if (inner_nblks == 2) {
       if (inner_blks[0] == 16 && inner_blks[1] == 16) {
         if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
-          return mkldnn::memory::format_tag::OIhw16i16o;
+          return dnnl::memory::format_tag::OIhw16i16o;
         }
       } else if (inner_blks[0] == 8 && inner_blks[1] == 8) {
         if (inner_idxs[0] == 1 && inner_idxs[1] == 0) {
-          return mkldnn::memory::format_tag::OIhw8i8o;
+          return dnnl::memory::format_tag::OIhw8i8o;
         }
       }
     }
@@ -268,38 +266,38 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
     if (inner_nblks == 0) {
       if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
           strides[2] >= strides[3] && strides[3] >= strides[4]) {
-        return mkldnn::memory::format_tag::ncdhw;
+        return dnnl::memory::format_tag::ncdhw;
       } else {
-        return mkldnn::memory::format_tag::ndhwc;
+        return dnnl::memory::format_tag::ndhwc;
       }
     } else if (inner_nblks == 1) {
       if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[4] && strides[4] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdeb8a;
+          return dnnl::memory::format_tag::Acdeb8a;
         }
         if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
             strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::Abcde8a;
+          return dnnl::memory::format_tag::Abcde8a;
         }
       } else if (inner_blks[0] == 8 && inner_idxs[0] == 1) {
         if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::aBcde8b;
+          return dnnl::memory::format_tag::aBcde8b;
         }
       } else if (inner_blks[0] == 16 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[4] && strides[4] >= strides[1]) {
-          return mkldnn::memory::format_tag::Acdeb16a;
+          return dnnl::memory::format_tag::Acdeb16a;
         }
         if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::Abcde16a;
+          return dnnl::memory::format_tag::Abcde16a;
        }
       } else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
            strides[2] >= strides[3] && strides[3] >= strides[4]) {
-          return mkldnn::memory::format_tag::aBcde16b;
+          return dnnl::memory::format_tag::aBcde16b;
         }
       }
     }
@@ -308,7 +306,7 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
     if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
         strides[2] >= strides[3] && strides[3] >= strides[4] &&
         strides[4] >= strides[5]) {
-      return mkldnn::memory::format_tag::abcdef;
+      return dnnl::memory::format_tag::abcdef;
     }
   }
 }
@@ -325,10 +323,10 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
   // for (int i=0;i<inner_nblks;++i) {
   //   std::cout<<"INNER_IDXS["<<i<<"]: "<<inner_idxs[i]<<std::endl;
   // }
-  return mkldnn::memory::format_tag::undef;
+  return dnnl::memory::format_tag::undef;
 }
-inline mkldnn::memory::format_tag GetMKLDNNFormat(const mkldnn::memory memory) {
+inline dnnl::memory::format_tag GetMKLDNNFormat(const dnnl::memory memory) {
   auto mem_desc = memory.get_desc();
   return GetMKLDNNFormat(mem_desc);
 }
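To make the stride matching in GetMKLDNNFormat concrete: a plain nchw descriptor has monotonically non-increasing strides across dimensions 0..3 and inner_nblks == 0, so the first 4-D branch fires. A small sketch (my own example, not from the commit):

    // {8, 3, 32, 32} with a trivial nchw layout has strides {3072, 1024, 32, 1}:
    // strides[0] >= strides[1] >= strides[2] >= strides[3] and inner_nblks == 0,
    // so GetMKLDNNFormat(md) resolves to dnnl::memory::format_tag::nchw.
    dnnl::memory::desc md({8, 3, 32, 32}, dnnl::memory::data_type::f32,
                          dnnl::memory::format_tag::nchw);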
@@ -441,24 +439,24 @@ inline void AppendKey(std::string* key, const T& num) {
 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::memory::format_tag& format) {
+                      const dnnl::memory::format_tag& format) {
   key->append(std::to_string(static_cast<int>(format)));
 }
 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::memory::data_type& data_type) {
+                      const dnnl::memory::data_type& data_type) {
   key->append(std::to_string(static_cast<int>(data_type)));
 }
 template <>
-inline void AppendKey(std::string* key, const mkldnn::algorithm& algorithm) {
+inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
   key->append(std::to_string(static_cast<int>(algorithm)));
 }
 template <>
 inline void AppendKey(std::string* key,
-                      const mkldnn::normalization_flags& flags) {
+                      const dnnl::normalization_flags& flags) {
   key->append(std::to_string(static_cast<int>(flags)));
 }
...
@@ -33,14 +33,14 @@ namespace platform {
 using framework::DataLayout;
 using framework::Tensor;
 using user_function = std::function<std::shared_ptr<float>(const float*)>;
-using memory = mkldnn::memory;
+using memory = dnnl::memory;
 template <typename T, typename TForward,
           typename TBackward = mkldnn_dummy_primitive,
           typename TBackward_params = mkldnn_dummy_primitive>
 class MKLDNNHandlerNoCachingT {
  public:
-  MKLDNNHandlerNoCachingT(mkldnn::engine engine, platform::Place cpu_place)
+  MKLDNNHandlerNoCachingT(dnnl::engine engine, platform::Place cpu_place)
       : engine_(engine), place_(cpu_place), fwd_pd_(nullptr), bwd_pd_(nullptr) {
     platform::MKLDNNDeviceContext::tls().log_lib_version();
   }
@@ -60,7 +60,7 @@ class MKLDNNHandlerNoCachingT {
     return std::make_shared<TBackward_params>(*bwd_w_pd_);
   }
-  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
+  std::shared_ptr<dnnl::memory> AcquireSrcMemory(
       const framework::Tensor* input) {
     const T* input_data = input->data<T>();
     return this->AcquireMemoryFromPrimitive(fwd_pd_->src_desc(),
@@ -68,33 +68,33 @@ class MKLDNNHandlerNoCachingT {
   }
   template <typename T_out = T>
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
+  std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output) {
     T_out* ptr =
         output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
     return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), ptr);
   }
   template <typename T_out = T>
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
+  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
     return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc());
   }
   template <typename T_out = T>
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(
+  std::shared_ptr<dnnl::memory> AcquireDstMemory(
       const framework::Tensor* output) {
     const T_out* output_data = output->data<T_out>();
     return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(),
                                             to_void_cast<T_out>(output_data));
   }
-  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
+  std::shared_ptr<dnnl::memory> AcquireDiffDstMemory(
       const framework::Tensor* diffdst) {
     const T* ptr = diffdst->data<T>();
     return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_dst_desc(),
                                             to_void_cast<T>(ptr));
   }
-  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
+  std::shared_ptr<dnnl::memory> AcquireDiffSrcMemory(
       framework::Tensor* diffsrc) {
     T* ptr =
         diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
@@ -102,7 +102,7 @@ class MKLDNNHandlerNoCachingT {
   }
   // Buffer of given Tensor is used for oneDNN computation
-  std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(
+  std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(
      framework::Tensor* diff_weights) {
     PADDLE_ENFORCE_NOT_NULL(
         bwd_w_pd_,
@@ -115,7 +115,7 @@ class MKLDNNHandlerNoCachingT {
   }
   // Buffer is allocated by oneDNN to store computation results
-  std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(void) {
+  std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) {
     PADDLE_ENFORCE_NOT_NULL(
         bwd_w_pd_,
         platform::errors::Unavailable(
@@ -179,37 +179,36 @@ class MKLDNNHandlerNoCachingT {
         bwd_desc, engine_, *fwd_pd_);
   }
-  std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
-      mkldnn::memory::desc md, void* ptr) {
-    return std::make_shared<mkldnn::memory>(md, engine_, ptr);
+  std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
+      dnnl::memory::desc md, void* ptr) {
+    return std::make_shared<dnnl::memory>(md, engine_, ptr);
   }
-  std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
-      mkldnn::memory::desc md) {
-    return std::make_shared<mkldnn::memory>(md, engine_);
+  std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
+      dnnl::memory::desc md) {
+    return std::make_shared<dnnl::memory>(md, engine_);
   }
-  void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
-                      const std::shared_ptr<mkldnn::memory>& target_memory_p) {
+  void AcquireReorder(const std::shared_ptr<dnnl::memory>& user_memory_p,
+                      const std::shared_ptr<dnnl::memory>& target_memory_p) {
     auto reorder_p =
-        std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
+        std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     platform::RecordEvent record_reorder("int_reorder",
                                          platform::EventRole::kUniqueOp);
-    reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
-                                 {MKLDNN_ARG_TO, *target_memory_p}});
+    reorder_p->execute(astream, {{DNNL_ARG_FROM, *user_memory_p},
+                                 {DNNL_ARG_TO, *target_memory_p}});
     astream.wait();
   }
   template <typename F = T>
-  std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
-      const mkldnn::memory::desc& user_md,
-      const mkldnn::memory::desc& target_md, void* ptr,
-      bool is_persistent = false,
+  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
+      const dnnl::memory::desc& user_md, const dnnl::memory::desc& target_md,
+      void* ptr, bool is_persistent = false,
       std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
-    std::shared_ptr<mkldnn::memory> target_memory_p;
+    std::shared_ptr<dnnl::memory> target_memory_p;
     if (custom_reorder_func) {
       auto reordered_data =
           custom_reorder_func(reinterpret_cast<const F*>(ptr));
@@ -217,15 +216,15 @@ class MKLDNNHandlerNoCachingT {
     }
     auto user_memory_p = std::make_shared<dnnl::memory>(user_md, engine_, ptr);
     if (user_md != target_md) {
-      target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
+      target_memory_p = std::make_shared<dnnl::memory>(target_md, engine_);
       auto reorder_p =
           std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
       auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
-      reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
-                                   {MKLDNN_ARG_TO, *target_memory_p}});
+      reorder_p->execute(astream, {{DNNL_ARG_FROM, *user_memory_p},
+                                   {DNNL_ARG_TO, *target_memory_p}});
       astream.wait();
     } else {
       target_memory_p = user_memory_p;
@@ -233,7 +232,7 @@ class MKLDNNHandlerNoCachingT {
     return target_memory_p;
   }
-  mkldnn::engine engine_;
+  dnnl::engine engine_;
   platform::Place place_;
   std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
   std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
@@ -245,7 +244,7 @@ template <typename T, typename TForward,
           typename TBackward_params = mkldnn_dummy_primitive>
 class MKLDNNHandlerT {
  public:
-  MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+  MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, dnnl::engine engine,
                  platform::Place cpu_place, const std::string& base_key)
       : dev_ctx_(dev_ctx),
         engine_(engine),
@@ -294,7 +293,7 @@ class MKLDNNHandlerT {
     return backward_p;
   }
-  std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
+  std::shared_ptr<dnnl::memory> AcquireSrcMemory(
       const framework::Tensor* input) {
     const T* input_data = input->data<T>();
     return this->AcquireMemoryFromPrimitive(
@@ -302,7 +301,7 @@ class MKLDNNHandlerT {
   }
   template <typename T_out = T>
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) {
+  std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output) {
     T_out* ptr =
         output->mutable_data<T_out>(place_, fwd_pd_->dst_desc().get_size());
     return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), ptr,
@@ -310,12 +309,12 @@ class MKLDNNHandlerT {
   }
   template <typename T_out = T>
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(void) {
+  std::shared_ptr<dnnl::memory> AcquireDstMemory(void) {
     return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_desc(), "@dstt_mem_p");
   }
   template <typename T_out = T>
-  std::shared_ptr<mkldnn::memory> AcquireDstMemory(
+  std::shared_ptr<dnnl::memory> AcquireDstMemory(
       const framework::Tensor* output) {
     const T_out* output_data = output->data<T_out>();
     return this->AcquireMemoryFromPrimitive(bwd_pd_->dst_desc(),
@@ -323,14 +322,14 @@ class MKLDNNHandlerT {
                                             "@bwd-dst_mem_p");
   }
-  std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
+  std::shared_ptr<dnnl::memory> AcquireDiffDstMemory(
       const framework::Tensor* diffdst) {
     const T* ptr = diffdst->data<T>();
     return this->AcquireMemoryFromPrimitive(
         bwd_pd_->diff_dst_desc(), to_void_cast<T>(ptr), "@diff_dst_mem_p");
   }
-  std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
+  std::shared_ptr<dnnl::memory> AcquireDiffSrcMemory(
      framework::Tensor* diffsrc) {
     T* ptr =
         diffsrc->mutable_data<T>(place_, bwd_pd_->diff_src_desc().get_size());
@@ -339,7 +338,7 @@ class MKLDNNHandlerT {
   }
   // Buffer of given Tensor is used for oneDNN computation
-  std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(
+  std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(
       framework::Tensor* diff_weights) {
     PADDLE_ENFORCE_NOT_NULL(
         bwd_w_pd_,
@@ -352,7 +351,7 @@ class MKLDNNHandlerT {
   }
   // Buffer is allocated by oneDNN to store computation results
-  std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(void) {
+  std::shared_ptr<dnnl::memory> AcquireDiffWeightsMemory(void) {
     PADDLE_ENFORCE_NOT_NULL(
         bwd_w_pd_,
         platform::errors::Unavailable(
@@ -467,19 +466,19 @@ class MKLDNNHandlerT {
     }
   }
-  std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
+  std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
       const std::string& suffix) {
-    return std::static_pointer_cast<mkldnn::memory>(
+    return std::static_pointer_cast<dnnl::memory>(
         dev_ctx_.GetBlob(key_ + suffix));
   }
-  std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
-      mkldnn::memory::desc md, void* ptr, const std::string& suffix) {
+  std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
+      dnnl::memory::desc md, void* ptr, const std::string& suffix) {
     const auto local_key = key_ + suffix;
     auto mem_p =
-        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+        std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
     if (mem_p == nullptr) {
-      mem_p = std::make_shared<mkldnn::memory>(md, engine_, ptr);
+      mem_p = std::make_shared<dnnl::memory>(md, engine_, ptr);
       dev_ctx_.SetBlob(local_key, mem_p);
     } else {
       mem_p->set_data_handle(ptr);
@@ -487,37 +486,36 @@ class MKLDNNHandlerT {
     return mem_p;
   }
-  std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
-      mkldnn::memory::desc md, const std::string& suffix) {
+  std::shared_ptr<dnnl::memory> AcquireMemoryFromPrimitive(
+      dnnl::memory::desc md, const std::string& suffix) {
     const auto local_key = key_ + suffix;
     auto mem_p =
-        std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
+        std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
     if (mem_p == nullptr) {
-      mem_p = std::make_shared<mkldnn::memory>(md, engine_);
+      mem_p = std::make_shared<dnnl::memory>(md, engine_);
       dev_ctx_.SetBlob(local_key, mem_p);
     }
     return mem_p;
   }
-  void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
-                      const std::shared_ptr<mkldnn::memory>& target_memory_p) {
+  void AcquireReorder(const std::shared_ptr<dnnl::memory>& user_memory_p,
+                      const std::shared_ptr<dnnl::memory>& target_memory_p) {
     auto reorder_p =
-        std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
+        std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     platform::RecordEvent record_reorder("int_reorder",
                                          platform::EventRole::kUniqueOp);
-    reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
-                                 {MKLDNN_ARG_TO, *target_memory_p}});
+    reorder_p->execute(astream, {{DNNL_ARG_FROM, *user_memory_p},
+                                 {DNNL_ARG_TO, *target_memory_p}});
     astream.wait();
   }
   template <typename F = T>
-  std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
-      const mkldnn::memory::desc& user_md,
-      const mkldnn::memory::desc& target_md, void* ptr,
-      const std::string& suffix, bool is_persistent = false,
+  std::shared_ptr<dnnl::memory> AcquireMemoryWithReorder(
+      const dnnl::memory::desc& user_md, const dnnl::memory::desc& target_md,
+      void* ptr, const std::string& suffix, bool is_persistent = false,
       std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
       const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
     const auto target_key = key_ + suffix + "_target";
@@ -537,7 +535,7 @@ class MKLDNNHandlerT {
       auto user_memory_p =
           std::make_shared<dnnl::memory>(user_md, engine_, ptr);
       if (user_md != target_md) {
-        target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
+        target_memory_p = std::make_shared<dnnl::memory>(target_md, engine_);
         dnnl::reorder::primitive_desc reorder_pdesc;
         if (is_int8<T>()) {
           dnnl::primitive_attr attr;
@@ -554,8 +552,8 @@ class MKLDNNHandlerT {
       auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);
-      reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
-                                   {MKLDNN_ARG_TO, *target_memory_p}});
+      reorder_p->execute(astream, {{DNNL_ARG_FROM, *user_memory_p},
+                                   {DNNL_ARG_TO, *target_memory_p}});
       astream.wait();
     } else {
       target_memory_p = user_memory_p;
...@@ -571,27 +569,26 @@ class MKLDNNHandlerT { ...@@ -571,27 +569,26 @@ class MKLDNNHandlerT {
// TODO(jczaja): Here we detect if reorder is cached it means it is needed // TODO(jczaja): Here we detect if reorder is cached it means it is needed
// need to change this to get rid of keys // need to change this to get rid of keys
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>( auto reorder_p = std::static_pointer_cast<dnnl::reorder>(
dev_ctx_.GetBlob(key_reorder_p)); dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) { if (reorder_p != nullptr) {
platform::RecordEvent record_reorder("int_reorder", platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp); platform::EventRole::kUniqueOp);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p}, reorder_p->execute(astream, {{DNNL_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}}); {DNNL_ARG_TO, *target_memory_p}});
astream.wait(); astream.wait();
} }
} }
return target_memory_p; return target_memory_p;
} }
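The int8 branch above attaches output scales to the reorder through dnnl::primitive_attr. A hedged sketch of that sub-pattern, using the oneDNN 2.x API current at the time of this PR (the engine, memories, scale value, and mask are illustrative):

// Sketch: build a scaled int8 reorder via primitive_attr.
dnnl::reorder make_scaled_reorder(const dnnl::engine& eng,
                                  const dnnl::memory& src_mem,
                                  const dnnl::memory& dst_mem) {
  dnnl::primitive_attr attr;
  attr.set_output_scales(/*mask=*/0, {0.5f});  // one scale for the whole tensor
  auto pd = dnnl::reorder::primitive_desc(
      eng, src_mem.get_desc(), eng, dst_mem.get_desc(), attr);
  return dnnl::reorder(pd);
}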
std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) { std::shared_ptr<dnnl::memory> AcquireMemory(const std::string& suffix) {
const auto local_key = key_ + suffix; const auto local_key = key_ + suffix;
return std::static_pointer_cast<mkldnn::memory>( return std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(local_key));
dev_ctx_.GetBlob(local_key));
} }
const MKLDNNDeviceContext& dev_ctx_; const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_; dnnl::engine engine_;
platform::Place place_; platform::Place place_;
std::string key_common_; std::string key_common_;
std::string key_; std::string key_;
@@ -605,7 +602,7 @@ class BinaryMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> { : public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public: public:
BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis, BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis,
const mkldnn::engine engine, platform::Place cpu_place, const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y, Tensor* z, const Tensor* x, const Tensor* y, Tensor* z,
float scale_x, float scale_y, float scale_z, float scale_x, float scale_y, float scale_z,
const dnnl::post_ops& post_ops = dnnl::post_ops()) const dnnl::post_ops& post_ops = dnnl::post_ops())
@@ -662,7 +659,7 @@ class BinaryMKLDNNHandler
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md, this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
dst_md); dst_md);
} }
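BinaryMKLDNNHandler ultimately wraps a dnnl::binary primitive built from the two source descriptors and the attributes assembled above. A standalone sketch of that primitive, stripped of the handler machinery (algorithm, shapes, and names are illustrative, oneDNN 2.x descriptor API assumed):

// Sketch: elementwise add of two f32 tensors with dnnl::binary.
void binary_add_example(const dnnl::engine& eng, dnnl::stream& strm,
                        dnnl::memory& src0, dnnl::memory& src1,
                        dnnl::memory& dst) {
  auto d = dnnl::binary::desc(dnnl::algorithm::binary_add,
                              src0.get_desc(), src1.get_desc(),
                              dst.get_desc());
  auto pd = dnnl::binary::primitive_desc(d, eng);
  dnnl::binary(pd).execute(strm, {{DNNL_ARG_SRC_0, src0},
                                  {DNNL_ARG_SRC_1, src1},
                                  {DNNL_ARG_DST, dst}});
  strm.wait();
}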
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory( std::shared_ptr<dnnl::memory> AcquireSecondSrcMemory(
const framework::Tensor* input) { const framework::Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->src1_desc(),
@@ -707,7 +704,7 @@ class BroadcastDataMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> { : public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
public: public:
BroadcastDataMKLDNNHandler(const dnnl::algorithm algo, BroadcastDataMKLDNNHandler(const dnnl::algorithm algo,
const mkldnn::engine engine, const dnnl::engine engine,
platform::Place cpu_place, const Tensor* out, platform::Place cpu_place, const Tensor* out,
const Tensor* x, float scale_x, float scale_y, const Tensor* x, float scale_x, float scale_y,
const std::vector<int64_t>& input_dims) const std::vector<int64_t>& input_dims)
@@ -735,7 +732,7 @@ class BroadcastDataMKLDNNHandler
} }
template <typename T_out = T> template <typename T_out = T>
std::shared_ptr<mkldnn::memory> AcquireDstMemory(framework::Tensor* output) { std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output) {
T_out* ptr = output->mutable_data<T_out>( T_out* ptr = output->mutable_data<T_out>(
this->place_, this->fwd_pd_->dst_desc().get_size()); this->place_, this->fwd_pd_->dst_desc().get_size());
memset(ptr, 0, this->fwd_pd_->dst_desc().get_size()); memset(ptr, 0, this->fwd_pd_->dst_desc().get_size());
@@ -748,7 +745,7 @@ class ReductionMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction> { : public platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction> {
public: public:
ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p, ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p,
const float eps, const mkldnn::engine engine, const float eps, const dnnl::engine engine,
platform::Place cpu_place, const Tensor* x, platform::Place cpu_place, const Tensor* x,
const Tensor* y, std::vector<int64_t> y_tz, const Tensor* y, std::vector<int64_t> y_tz,
const dnnl::primitive_attr& attr = NULL) const dnnl::primitive_attr& attr = NULL)
@@ -777,15 +774,15 @@ class ReductionMKLDNNHandler
template <typename T> template <typename T>
class ActivationMKLDNNHandler class ActivationMKLDNNHandler
: public MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward, : public MKLDNNHandlerNoCachingT<T, dnnl::eltwise_forward,
mkldnn::eltwise_backward> { dnnl::eltwise_backward> {
public: public:
ActivationMKLDNNHandler(mkldnn::algorithm algorithm, ActivationMKLDNNHandler(dnnl::algorithm algorithm,
const framework::ExecutionContext& ctx, const framework::ExecutionContext& ctx,
const mkldnn::engine engine, Place cpu_place, const dnnl::engine engine, Place cpu_place,
const framework::Tensor* in_x) const framework::Tensor* in_x)
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward, : platform::MKLDNNHandlerNoCachingT<T, dnnl::eltwise_forward,
mkldnn::eltwise_backward>(engine, dnnl::eltwise_backward>(engine,
cpu_place) { cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0; float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0; float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
@@ -811,7 +808,7 @@ class ActivationMKLDNNHandler
: ctx.Attr<float>("max"); : ctx.Attr<float>("max");
} else { } else {
// paddle uses beta but mkldnn uses alpha for swish // paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) { if (algorithm == dnnl::algorithm::eltwise_swish) {
std::swap(alpha, beta); std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) { } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold"); alpha = ctx.Attr<float>("threshold");
@@ -827,24 +824,24 @@ class ActivationMKLDNNHandler
auto src_tz = framework::vectorize<int64_t>(in_x->dims()); auto src_tz = framework::vectorize<int64_t>(in_x->dims());
auto src_fmt = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format(); auto src_fmt = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : in_x->format();
auto md = auto md =
mkldnn::memory::desc(src_tz, platform::MKLDNNGetDataType<T>(), src_fmt); dnnl::memory::desc(src_tz, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm, md, alpha, beta); algorithm, md, alpha, beta);
} }
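The forward path above reduces to a single eltwise_forward descriptor once alpha and beta are resolved (note the alpha/beta swap for swish that the constructor performs). For reference, a hedged sketch of the bare oneDNN 2.x pattern the handler wraps, with ReLU chosen purely for illustration:

// Sketch: eltwise (ReLU) forward primitive descriptor.
dnnl::eltwise_forward::primitive_desc make_relu_fwd_pd(
    const dnnl::engine& eng, const dnnl::memory::desc& md,
    float alpha = 0.f, float beta = 0.f) {
  auto d = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_training,
                                       dnnl::algorithm::eltwise_relu,
                                       md, alpha, beta);
  return dnnl::eltwise_forward::primitive_desc(d, eng);
}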
ActivationMKLDNNHandler(mkldnn::algorithm algorithm, ActivationMKLDNNHandler(dnnl::algorithm algorithm,
const framework::ExecutionContext& ctx, const framework::ExecutionContext& ctx,
const mkldnn::engine engine, Place cpu_place, const dnnl::engine engine, Place cpu_place,
const framework::Tensor* in_x, const Tensor* out_grad) const framework::Tensor* in_x, const Tensor* out_grad)
: platform::MKLDNNHandlerNoCachingT<T, mkldnn::eltwise_forward, : platform::MKLDNNHandlerNoCachingT<T, dnnl::eltwise_forward,
mkldnn::eltwise_backward>(engine, dnnl::eltwise_backward>(engine,
cpu_place) { cpu_place) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0; float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 0;
float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0; float beta = ctx.HasAttr("beta") ? ctx.Attr<float>("beta") : 0;
// paddle uses beta but mkldnn uses alpha for swish // paddle uses beta but mkldnn uses alpha for swish
if (algorithm == mkldnn::algorithm::eltwise_swish) { if (algorithm == dnnl::algorithm::eltwise_swish) {
std::swap(alpha, beta); std::swap(alpha, beta);
} else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) { } else if (algorithm == dnnl::algorithm::eltwise_bounded_relu) {
alpha = ctx.Attr<float>("threshold"); alpha = ctx.Attr<float>("threshold");
@@ -870,13 +867,13 @@ class ActivationMKLDNNHandler
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), src_fmt); dims, platform::MKLDNNGetDataType<T>(), src_fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
algorithm, src_md, alpha, beta); algorithm, src_md, alpha, beta);
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md, this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta); alpha, beta);
} }
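The backward constructor registers both descriptors because oneDNN requires the forward primitive_desc as a hint when creating the backward one. A sketch of that pairing under the same illustrative assumptions as above:

// Sketch: eltwise backward pd, hinted by the forward pd (oneDNN 2.x).
dnnl::eltwise_backward::primitive_desc make_relu_bwd_pd(
    const dnnl::engine& eng,
    const dnnl::memory::desc& diff_dst_md,
    const dnnl::memory::desc& src_md,
    const dnnl::eltwise_forward::primitive_desc& fwd_pd) {
  auto d = dnnl::eltwise_backward::desc(dnnl::algorithm::eltwise_relu,
                                        diff_dst_md, src_md,
                                        /*alpha=*/0.f, /*beta=*/0.f);
  return dnnl::eltwise_backward::primitive_desc(d, eng, fwd_pd);
}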
std::shared_ptr<mkldnn::memory> AcquireBackwardSrcMemory( std::shared_ptr<dnnl::memory> AcquireBackwardSrcMemory(
const framework::Tensor* input) { const framework::Tensor* input) {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(), return this->AcquireMemoryFromPrimitive(this->bwd_pd_->src_desc(),
@@ -888,7 +885,7 @@ class ReorderMKLDNNHandler {
public: public:
ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
framework::proto::VarType::Type vtype, framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype, mkldnn::engine engine) dnnl::memory::data_type dtype, dnnl::engine engine)
: dims_(dims), : dims_(dims),
vtype_(vtype), vtype_(vtype),
vtype_dst_(vtype), vtype_dst_(vtype),
@@ -898,10 +895,9 @@ class ReorderMKLDNNHandler {
ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT ReorderMKLDNNHandler(std::vector<int64_t>& dims, // NOLINT
framework::proto::VarType::Type vtype, framework::proto::VarType::Type vtype,
mkldnn::memory::data_type dtype, dnnl::memory::data_type dtype,
framework::proto::VarType::Type vtype_dst, framework::proto::VarType::Type vtype_dst,
mkldnn::memory::data_type dtype_dst, dnnl::memory::data_type dtype_dst, dnnl::engine engine)
mkldnn::engine engine)
: dims_(dims), : dims_(dims),
vtype_(vtype), vtype_(vtype),
vtype_dst_(vtype_dst), vtype_dst_(vtype_dst),
@@ -909,56 +905,56 @@ class ReorderMKLDNNHandler {
dtype_dst_(dtype_dst), dtype_dst_(dtype_dst),
engine_(engine) {} engine_(engine) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
const MKLDNNMemoryFormat& fmt, void* ptr) { void* ptr) {
auto md = mkldnn::memory::desc(dims_, dtype_, fmt); auto md = dnnl::memory::desc(dims_, dtype_, fmt);
return std::make_shared<mkldnn::memory>(md, engine_, ptr); return std::make_shared<dnnl::memory>(md, engine_, ptr);
} }
std::shared_ptr<mkldnn::memory> AcquireSubmemory( std::shared_ptr<dnnl::memory> AcquireSubmemory(
const std::vector<int64_t>& dims, const std::vector<int64_t>& offset, const std::vector<int64_t>& dims, const std::vector<int64_t>& offset,
const std::shared_ptr<mkldnn::memory>& mem_p) { const std::shared_ptr<dnnl::memory>& mem_p) {
auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset}); auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset});
auto sub_mem_p = std::make_shared<mkldnn::memory>(sub_md, engine_, auto sub_mem_p = std::make_shared<dnnl::memory>(sub_md, engine_,
mem_p->get_data_handle()); mem_p->get_data_handle());
return sub_mem_p; return sub_mem_p;
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemory( std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output,
framework::Tensor* output, const MKLDNNMemoryFormat& fmt, const MKLDNNMemoryFormat& fmt,
platform::Place place) { platform::Place place) {
auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt); auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
auto dst_data = output->mutable_data(place, vtype_dst_, dst_md.get_size()); auto dst_data = output->mutable_data(place, vtype_dst_, dst_md.get_size());
return std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data); return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
} }
std::shared_ptr<mkldnn::memory> AcquireDstMemory( std::shared_ptr<dnnl::memory> AcquireDstMemory(
framework::Tensor* output, const std::vector<int64_t>& dims, framework::Tensor* output, const std::vector<int64_t>& dims,
const MKLDNNMemoryFormat& fmt, platform::Place place) { const MKLDNNMemoryFormat& fmt, platform::Place place) {
auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt); auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt);
auto dst_data = output->mutable_data(place, vtype_dst_, dst_md.get_size()); auto dst_data = output->mutable_data(place, vtype_dst_, dst_md.get_size());
return std::make_shared<mkldnn::memory>(dst_md, engine_, dst_data); return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
} }
std::shared_ptr<mkldnn::reorder> AcquireReorder( std::shared_ptr<dnnl::reorder> AcquireReorder(
std::shared_ptr<mkldnn::memory> dst_memory_p, std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<mkldnn::memory> src_memory_p) { std::shared_ptr<dnnl::memory> src_memory_p) {
return std::make_shared<mkldnn::reorder>(*(src_memory_p), *(dst_memory_p)); return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
} }
private: private:
std::vector<int64_t> dims_; std::vector<int64_t> dims_;
framework::proto::VarType::Type vtype_, vtype_dst_; framework::proto::VarType::Type vtype_, vtype_dst_;
mkldnn::memory::data_type dtype_, dtype_dst_; dnnl::memory::data_type dtype_, dtype_dst_;
mkldnn::engine engine_; dnnl::engine engine_;
}; };
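Putting the class together, a typical call sequence for ReorderMKLDNNHandler looks roughly like the following; this is grounded only in the signatures visible above, and the tensor names, format, and dims are hypothetical:

// Sketch: reorder a tensor from its stored format to NCHW.
platform::ReorderMKLDNNHandler handler(
    dims, vtype, dnnl::memory::data_type::f32, engine);
auto src_mem = handler.AcquireSrcMemory(in.format(),
                                        platform::to_void_cast(in_data));
auto dst_mem = handler.AcquireDstMemory(&out, MKLDNNMemoryFormat::nchw, place);
auto reorder_p = handler.AcquireReorder(dst_mem, src_mem);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *src_mem, *dst_mem);
astream.wait();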
template <typename T> template <typename T>
static void SetDstMemoryQuantized( static void SetDstMemoryQuantized(
const framework::ExecutionContext& ctx, framework::Tensor* output, const framework::ExecutionContext& ctx, framework::Tensor* output,
std::vector<int64_t> dst_tz, const mkldnn::engine& engine, std::vector<int64_t> dst_tz, const dnnl::engine& engine,
std::shared_ptr<mkldnn::memory::desc>& dst_md, // NOLINT std::shared_ptr<dnnl::memory::desc>& dst_md, // NOLINT
std::shared_ptr<mkldnn::memory>& dst_memory, // NOLINT std::shared_ptr<dnnl::memory>& dst_memory, // NOLINT
MKLDNNMemoryFormat output_format) { MKLDNNMemoryFormat output_format) {
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
const size_t dst_dims = dst_tz.size(); const size_t dst_dims = dst_tz.size();
@@ -974,9 +970,9 @@ static void SetDstMemoryQuantized(
{dst_tz}, paddle::framework::ToMKLDNNDataType( {dst_tz}, paddle::framework::ToMKLDNNDataType(
framework::DataTypeTrait<T>::DataType()), framework::DataTypeTrait<T>::DataType()),
dst_fmt); dst_fmt);
dst_md.reset(new mkldnn::memory::desc(tmp_dst_md)); dst_md.reset(new dnnl::memory::desc(tmp_dst_md));
dst_memory.reset( dst_memory.reset(
new mkldnn::memory(*dst_md, engine, to_void_cast<T>(output_data))); new dnnl::memory(*dst_md, engine, to_void_cast<T>(output_data)));
} }
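And a hedged example of how an int8 kernel might invoke SetDstMemoryQuantized, based only on the signature above (all variable names hypothetical):

std::shared_ptr<dnnl::memory::desc> dst_md;
std::shared_ptr<dnnl::memory> dst_memory;
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
                                        dst_md, dst_memory,
                                        MKLDNNMemoryFormat::nhwc);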
} // namespace platform } // namespace platform