Unverified commit fa051eec, authored by Sławomir Siwek, committed by GitHub

[PHI decoupling] Move MKLDNN code (#48352)

Parent 85914800
...@@ -14,11 +14,8 @@ ...@@ -14,11 +14,8 @@
#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_reuse.h"
#endif
#include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -92,119 +89,5 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var, ...@@ -92,119 +89,5 @@ void TransDataLayout(const OpKernelType& kernel_type_for_var,
out->set_layout(expected_kernel_type.data_layout_); out->set_layout(expected_kernel_type.data_layout_);
} }
#ifdef PADDLE_WITH_MKLDNN
using dnnl::memory;
using dnnl::primitive;
using dnnl::reorder;
void* GetDataFromTensor(const phi::DenseTensor& tensor,
dnnl::memory::data_type type) {
switch (type) {
case dnnl::memory::data_type::f32:
return phi::funcs::to_void_cast(tensor.data<float>());
case dnnl::memory::data_type::s8:
return phi::funcs::to_void_cast(tensor.data<int8_t>());
case dnnl::memory::data_type::u8:
return phi::funcs::to_void_cast(tensor.data<unsigned char>());
case dnnl::memory::data_type::s32:
return phi::funcs::to_void_cast(tensor.data<int32_t>());
case dnnl::memory::data_type::bf16:
return phi::funcs::to_void_cast(
tensor.data<paddle::platform::bfloat16>());
default:
PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided."));
}
}
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
const phi::DenseTensor& in,
phi::DenseTensor* out) {
auto in_layout = kernel_type_for_var.data_layout_;
auto out_layout = expected_kernel_type.data_layout_;
auto place = expected_kernel_type.place_;
PADDLE_ENFORCE(
in_layout == DataLayout::ONEDNN && out_layout != DataLayout::ONEDNN,
platform::errors::InvalidArgument(
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"));
innerTransDataLayoutFromMKLDNN(
in_layout,
paddle::platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout(),
in,
out,
place);
}
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
DataLayout out_layout,
const phi::DenseTensor& in,
phi::DenseTensor* out,
platform::Place place,
bool always_copy) {
// Set default as NCHW in case not specified
out_layout =
out_layout == DataLayout::kAnyLayout ? DataLayout::kNCHW : out_layout;
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
auto& cpu_engine = dev_ctx->GetEngine();
auto in_tz = phi::vectorize<int64_t>(in.dims());
auto out_tz = in_tz;
memory::data_type in_type =
ToMKLDNNDataType(framework::TransToProtoVarType(in.dtype()));
PADDLE_ENFORCE_NE(
in_type,
memory::data_type::undef,
platform::errors::InvalidArgument(
"Input tensor type (%s) is not supported.",
DataTypeToString(framework::TransToProtoVarType(in.dtype()))));
auto out_format =
phi::funcs::OneDNNFormatForSize(in_tz.size(), ToOneDNNFormat(out_layout));
dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format);
// output tensor has the same dims as input. Reorder doesn't change dims
out->set_mem_desc(out_mem_desc);
out->Resize(in.dims());
// Note(0x45f): Using initialized() to support slice Tensors
// with shapes like [0, 0, 0].
if (in.initialized() && ((in.mem_desc() != out->mem_desc()) || always_copy)) {
void* in_data = GetDataFromTensor(in, in_type);
phi::funcs::ReorderOneDNNHandler handler(
in_tz, in.dtype(), in_type, cpu_engine);
auto reorder_src_memory_p =
handler.AcquireSrcMemory(in.mem_desc(), in_data);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(out, out->mem_desc(), place);
auto reorder_p =
handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("ext_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
} else {
out->ShareDataWith(in);
}
// For expected NHWC data format we need to reshape the Output tensor
// As MKL-DNN description was in NCHW and paddle is expecting NHWC
phi::funcs::MatchShapeToLayout(out, in_layout, out_layout);
out->set_layout(DataLayout::kNCHW);
}
#endif
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
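The fluid-side implementation removed above now lives in phi; the later hunks in this commit call it as `phi::funcs::TransDataLayoutFromOneDNN`. Below is a minimal sketch of the new call pattern, assuming the relocated declaration sits in `paddle/phi/kernels/funcs/data_layout_transform.h` (the include added to the fluid header in the next hunk) and keeps the parameter order visible at the updated call sites; header paths and the wrapper function are illustrative, not part of the commit.

```cpp
// Sketch only: convert a tensor left in ONEDNN layout back to the layout
// Paddle expects, mirroring the call sites updated later in this commit.
#include "paddle/phi/backends/onednn/onednn_context.h"       // assumed header for phi::OneDNNContext
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"  // assumed home of the moved helper

void ToPaddleLayout(const phi::DenseTensor& in, phi::DenseTensor* out) {
  phi::funcs::TransDataLayoutFromOneDNN(
      in.layout(),                                             // ONEDNN source layout
      phi::OneDNNContext::tls().get_cur_paddle_data_layout(),  // layout the model runs in
      in,
      out,
      phi::CPUPlace());  // fluid call sites pass platform::CPUPlace(); always_copy defaults to false
}
```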
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/op_kernel_type.h" #include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -29,7 +30,7 @@ class OpKernelType; ...@@ -29,7 +30,7 @@ class OpKernelType;
} // namespace paddle } // namespace paddle
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/phi/backends/onednn/onednn_helper.h"
#endif #endif
namespace paddle { namespace paddle {
...@@ -51,54 +52,6 @@ struct CastDataLayout { ...@@ -51,54 +52,6 @@ struct CastDataLayout {
void apply(); void apply();
}; };
#ifdef PADDLE_WITH_MKLDNN
using OneDNNDataType = dnnl::memory::data_type;
inline OneDNNMemoryFormat ToOneDNNFormat(const DataLayout& layout) {
switch (layout) {
case DataLayout::kNHWC:
return OneDNNMemoryFormat::nhwc;
case DataLayout::kNCHW:
return OneDNNMemoryFormat::nchw;
case DataLayout::kNCDHW:
return OneDNNMemoryFormat::ncdhw;
case DataLayout::kNDHWC:
return OneDNNMemoryFormat::ndhwc;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Fail to convert layout %s to oneDNN format.",
phi::DataLayoutToString(layout)));
}
}
inline OneDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static std::unordered_map<int, OneDNNDataType> dict{
{DataTypeTrait<float>::DataType(), OneDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType(), OneDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType(), OneDNNDataType::u8},
{DataTypeTrait<int32_t>::DataType(), OneDNNDataType::s32},
{DataTypeTrait<platform::bfloat16>::DataType(), OneDNNDataType::bf16}};
auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second;
return OneDNNDataType::undef;
}
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
DataLayout out_layout,
const phi::DenseTensor& in,
phi::DenseTensor* out,
platform::Place place,
bool always_copy = false);
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type,
const phi::DenseTensor& in,
phi::DenseTensor* out);
void* GetDataFromTensor(const phi::DenseTensor& tensor, OneDNNDataType type);
#endif
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to); std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
void TransDataLayout(const OpKernelType& kernel_type_for_var, void TransDataLayout(const OpKernelType& kernel_type_for_var,
......
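The mapping helpers deleted from this header (`ToOneDNNFormat`, `ToMKLDNNDataType`) resurface in later hunks as `phi::funcs::ToOneDNNFormat` and `phi::funcs::ToOneDNNDataType`, with the latter keyed off phi's `DataType` rather than `proto::VarType::Type`. A hedged sketch of how a memory descriptor is now assembled from them; the header location and the standalone helper function are assumptions based on the include swapped in above and the data_transform.cc hunk below.

```cpp
// Sketch: build a dnnl::memory::desc for a tensor from its logical layout,
// using the relocated phi helpers. Not a drop-in utility from the commit.
#include "paddle/phi/backends/onednn/onednn_helper.h"  // assumed to declare the helpers used here
#include "paddle/phi/core/dense_tensor.h"

dnnl::memory::desc DescribeForOneDNN(const phi::DenseTensor& t,
                                     phi::DataLayout layout) {
  // Pick an nchw/nhwc/... format tag matching the tensor rank and layout.
  auto fmt = phi::funcs::OneDNNFormatForSize(t.dims().size(),
                                             phi::funcs::ToOneDNNFormat(layout));
  // Data type now comes straight from phi::DataType instead of proto::VarType.
  auto dt = phi::funcs::ToOneDNNDataType(t.dtype());
  return dnnl::memory::desc(phi::vectorize<int64_t>(t.dims()), dt, fmt);
}
```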
...@@ -53,7 +53,7 @@ TEST(DataTransformBf16, GetDataFromTensorDNNL) { ...@@ -53,7 +53,7 @@ TEST(DataTransformBf16, GetDataFromTensorDNNL) {
place); place);
void* in_data = void* in_data =
paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::bf16); phi::funcs::GetDataFromTensor(in, dnnl::memory::data_type::bf16);
EXPECT_EQ(in_data, EXPECT_EQ(in_data,
phi::funcs::to_void_cast(in.data<paddle::platform::bfloat16>())); phi::funcs::to_void_cast(in.data<paddle::platform::bfloat16>()));
} }
...@@ -64,7 +64,7 @@ TEST(DataTransformInt32, GetDataFromTensorDNNL) { ...@@ -64,7 +64,7 @@ TEST(DataTransformInt32, GetDataFromTensorDNNL) {
in.mutable_data<int32_t>(phi::make_ddim({2, 3, 1, 2}), place); in.mutable_data<int32_t>(phi::make_ddim({2, 3, 1, 2}), place);
void* in_data = void* in_data =
paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::s32); phi::funcs::GetDataFromTensor(in, dnnl::memory::data_type::s32);
EXPECT_EQ(in_data, phi::funcs::to_void_cast(in.data<int32_t>())); EXPECT_EQ(in_data, phi::funcs::to_void_cast(in.data<int32_t>()));
} }
#endif #endif
...@@ -57,11 +57,11 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -57,11 +57,11 @@ void TransformData(const OpKernelType &expected_kernel_type,
"No layout transform needed between two oneDNN OPKernels.")); "No layout transform needed between two oneDNN OPKernels."));
if (lin != DataLayout::ONEDNN && lout == DataLayout::ONEDNN) { if (lin != DataLayout::ONEDNN && lout == DataLayout::ONEDNN) {
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Case1 - transform from Non-ONEDNN OPKernel to ONEDNN OPKernel
// Just set layout/format. No real transform occurs // Just set layout/format. No real transform occurs
auto out_format = phi::funcs::OneDNNFormatForSize(in.dims().size(), auto out_format = phi::funcs::OneDNNFormatForSize(
ToOneDNNFormat(lin)); in.dims().size(), phi::funcs::ToOneDNNFormat(lin));
out.ShareDataWith(input_tensor); out.ShareDataWith(input_tensor);
// For NHWC data we need to reshape tensors as MKL-DNN // For NHWC data we need to reshape tensors as MKL-DNN
// is expecting NHWC dims description order // is expecting NHWC dims description order
...@@ -69,26 +69,36 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -69,26 +69,36 @@ void TransformData(const OpKernelType &expected_kernel_type,
phi::funcs::MatchShapeToLayout(&out, lin, lout); phi::funcs::MatchShapeToLayout(&out, lin, lout);
// We register only NHWC assuming that model is consistent e.g. either // We register only NHWC assuming that model is consistent e.g. either
// NHWC or NCHW // NHWC or NCHW
paddle::platform::MKLDNNDeviceContext::tls() phi::OneDNNContext::tls().set_cur_paddle_data_layout(lin);
.set_cur_paddle_data_layout(lin);
} }
dnnl::memory::desc out_mem_desc( dnnl::memory::desc out_mem_desc(
vectorize(out.dims()), vectorize(out.dims()),
ToMKLDNNDataType(TransToProtoVarType(in.type())), phi::funcs::ToOneDNNDataType(in.dtype()),
out_format); out_format);
out.set_mem_desc(out_mem_desc); out.set_mem_desc(out_mem_desc);
} else { } else {
// Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel // Case2 - transform from ONEDNN OPKernel to Non-ONEDNN OPKernel
// Do transform via MKLDNN lib // Do transform via ONEDNN lib
TransDataLayoutFromMKLDNN( PADDLE_ENFORCE(
kernel_type_for_var, expected_kernel_type, in, &out); kernel_type_for_var.data_layout_ == DataLayout::ONEDNN &&
expected_kernel_type.data_layout_ != DataLayout::ONEDNN,
platform::errors::InvalidArgument(
"TransDataLayoutFromOneDNN only supports "
"transform from ONEDNN to non-ONEDNN"));
phi::funcs::TransDataLayoutFromOneDNN(
kernel_type_for_var.data_layout_,
phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
in,
&out,
expected_kernel_type.place_);
} }
} else { } else {
// Case3 - transform between Non-MKLDNN OPKernels // Case3 - transform between Non-ONEDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
} }
#else #else
// Case3 - transform between Non-MKLDNN OPKernels // Case3 - transform between Non-ONEDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
#endif #endif
transformed = true; transformed = true;
......
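From here on, every kernel hunk swaps `platform::MKLDNNDeviceContext` for `phi::OneDNNContext`: the context is pulled from the `ExecutionContext`, the engine from `GetEngine()`, and the execution stream from the context's thread-locals. A condensed sketch of that recurring pattern follows; it is not any specific kernel from the commit, and primitive construction is elided.

```cpp
// Recurring post-refactor plumbing inside an operator .cc file.
using phi::OneDNNContext;

template <typename T>
void ComputeSketch(const paddle::framework::ExecutionContext& ctx) {
  const auto& dev_ctx = ctx.template device_context<OneDNNContext>();
  const auto& engine = dev_ctx.GetEngine();           // was MKLDNNDeviceContext::GetEngine()
  auto& astream = OneDNNContext::tls().get_stream();  // was platform::MKLDNNDeviceContext::tls()

  // ... build oneDNN memories/primitives against `engine`,
  // execute them on `astream`, then synchronize:
  astream.wait();
  (void)engine;
}
```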
...@@ -494,8 +494,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key, ...@@ -494,8 +494,8 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
if ((tensor_in->layout() == DataLayout::ONEDNN) && if ((tensor_in->layout() == DataLayout::ONEDNN) &&
(var->IsType<phi::DenseTensor>() == true) && (var->IsType<phi::DenseTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::ONEDNN) && (expected_kernel_key.data_layout_ != DataLayout::ONEDNN) &&
(paddle::platform::MKLDNNDeviceContext::tls() (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
.get_cur_paddle_data_layout() == DataLayout::kNHWC)) { DataLayout::kNHWC)) {
VLOG(7) << "Created reshaped dummy input based on MKL-DNN " VLOG(7) << "Created reshaped dummy input based on MKL-DNN "
"phi::DenseTensor , " "phi::DenseTensor , "
"but kNHWC layout" "but kNHWC layout"
......
...@@ -2304,8 +2304,8 @@ Scope* OperatorWithKernel::PrepareData( ...@@ -2304,8 +2304,8 @@ Scope* OperatorWithKernel::PrepareData(
if ((tensor_in->layout() == DataLayout::ONEDNN) && if ((tensor_in->layout() == DataLayout::ONEDNN) &&
(var->IsType<phi::DenseTensor>() == true) && (var->IsType<phi::DenseTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::ONEDNN) && (expected_kernel_key.data_layout_ != DataLayout::ONEDNN) &&
(paddle::platform::MKLDNNDeviceContext::tls() (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
.get_cur_paddle_data_layout() == DataLayout::kNHWC) && DataLayout::kNHWC) &&
(tensor_in->dims().size() >= 3)) { (tensor_in->dims().size() >= 3)) {
// Mixed execution : oneDNN and GPU is not supported! // Mixed execution : oneDNN and GPU is not supported!
if (!new_scope) { if (!new_scope) {
...@@ -2757,8 +2757,8 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar( ...@@ -2757,8 +2757,8 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
// then we also need to rotate shape NHWC -> NCHW // then we also need to rotate shape NHWC -> NCHW
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) &&
paddle::platform::MKLDNNDeviceContext::tls() phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.place(),
phi::DataLayout::kNHWC); phi::DataLayout::kNHWC);
......
...@@ -868,13 +868,12 @@ void AnalysisPredictor::MkldnnPreSet( ...@@ -868,13 +868,12 @@ void AnalysisPredictor::MkldnnPreSet(
const std::vector<std::vector<int>> &inputs_shape) { const std::vector<std::vector<int>> &inputs_shape) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
VLOG(2) << "AnalysisPredictor::ZeroCopyRun get_cur_mkldnn_session_id=" VLOG(2) << "AnalysisPredictor::ZeroCopyRun get_cur_mkldnn_session_id="
<< platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id(); << phi::OneDNNContext::tls().get_cur_mkldnn_session_id();
// In cache clearing mode. // In cache clearing mode.
if (config_.mkldnn_cache_capacity_ > 0) { if (config_.mkldnn_cache_capacity_ > 0) {
VLOG(2) << "In mkldnn cache clear mode."; VLOG(2) << "In mkldnn cache clear mode.";
platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id( phi::OneDNNContext::tls().set_cur_mkldnn_session_id(
platform::MKLDNNDeviceContextThreadLocals:: phi::OneDNNContextThreadLocals::kMKLDNNSessionID_CacheClearing);
kMKLDNNSessionID_CacheClearing);
// Set current_input_shape for caching dynamic shape. // Set current_input_shape for caching dynamic shape.
std::stringstream ss; std::stringstream ss;
for (size_t i = 0; i < inputs_shape.size(); ++i) { for (size_t i = 0; i < inputs_shape.size(); ++i) {
...@@ -883,9 +882,9 @@ void AnalysisPredictor::MkldnnPreSet( ...@@ -883,9 +882,9 @@ void AnalysisPredictor::MkldnnPreSet(
} }
} }
VLOG(2) << "Set input shape=" << ss.str(); VLOG(2) << "Set input shape=" << ss.str();
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(ss.str()); phi::OneDNNContext::tls().set_cur_input_shape_str(ss.str());
} }
platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity( phi::OneDNNContext::tls().set_cur_input_shape_cache_capacity(
config_.mkldnn_cache_capacity_); config_.mkldnn_cache_capacity_);
#endif #endif
...@@ -895,11 +894,11 @@ void AnalysisPredictor::MkldnnPostReset() { ...@@ -895,11 +894,11 @@ void AnalysisPredictor::MkldnnPostReset() {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
// In cache clearing mode. // In cache clearing mode.
if (config_.mkldnn_cache_capacity_ > 0 && if (config_.mkldnn_cache_capacity_ > 0 &&
static_cast<platform::MKLDNNDeviceContext *>( static_cast<phi::OneDNNContext *>(
(&platform::DeviceContextPool::Instance())->Get(platform::CPUPlace())) (&platform::DeviceContextPool::Instance())->Get(platform::CPUPlace()))
->GetCachedObjectsNumber() > 0) { ->GetCachedObjectsNumber() > 0) {
if (VLOG_IS_ON(2)) { if (VLOG_IS_ON(2)) {
auto shape_blob_size = static_cast<platform::MKLDNNDeviceContext *>( auto shape_blob_size = static_cast<phi::OneDNNContext *>(
(&platform::DeviceContextPool::Instance()) (&platform::DeviceContextPool::Instance())
->Get(platform::CPUPlace())) ->Get(platform::CPUPlace()))
->GetShapeBlobSize(); ->GetShapeBlobSize();
......
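The predictor's cache-clearing path now drives the same thread-local state through `phi::OneDNNContext::tls()`. Below is a condensed sketch of the sequence in the MkldnnPreSet hunk above; `capacity` and `shape_str` are placeholders for `config_.mkldnn_cache_capacity_` and the composed input-shape string, and the helper function itself is not from the commit.

```cpp
// Condensed from MkldnnPreSet: enable per-shape cache clearing and record the
// current input shape so cached oneDNN objects can be scoped to it.
#include <string>
#include "paddle/phi/backends/onednn/onednn_context.h"  // assumed header

void PreSetSketch(int capacity, const std::string& shape_str) {
  if (capacity > 0) {
    phi::OneDNNContext::tls().set_cur_mkldnn_session_id(
        phi::OneDNNContextThreadLocals::kMKLDNNSessionID_CacheClearing);
    phi::OneDNNContext::tls().set_cur_input_shape_str(shape_str);
  }
  phi::OneDNNContext::tls().set_cur_input_shape_cache_capacity(capacity);
}
```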
...@@ -378,10 +378,9 @@ void Tensor::CopyToCpuImpl(T *data, ...@@ -378,10 +378,9 @@ void Tensor::CopyToCpuImpl(T *data,
if (paddle::platform::is_cpu_place(t_place)) { if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == phi::DataLayout::ONEDNN) if (tensor->layout() == phi::DataLayout::ONEDNN)
paddle::framework::innerTransDataLayoutFromMKLDNN( phi::funcs::TransDataLayoutFromOneDNN(
tensor->layout(), tensor->layout(),
paddle::platform::MKLDNNDeviceContext::tls() phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
.get_cur_paddle_data_layout(),
*tensor, *tensor,
&out, &out,
paddle::platform::CPUPlace(), paddle::platform::CPUPlace(),
...@@ -661,12 +660,12 @@ std::vector<int> Tensor::shape() const { ...@@ -661,12 +660,12 @@ std::vector<int> Tensor::shape() const {
tensor_, tensor_,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Not found tensor called %s in the scope", name_)); "Not found tensor called %s in the scope", name_));
// mkldnn may do layout transforms internally, so we need to reorder before // oneDNN may do layout transforms internally, so we need to reorder before
// return // return
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == phi::DataLayout::ONEDNN) { if (tensor->layout() == phi::DataLayout::ONEDNN) {
phi::DataLayout out_layout = paddle::platform::MKLDNNDeviceContext::tls() phi::DataLayout out_layout =
.get_cur_paddle_data_layout(); phi::OneDNNContext::tls().get_cur_paddle_data_layout();
// Set default as NCHW in case not specified // Set default as NCHW in case not specified
out_layout = out_layout == phi::DataLayout::kAnyLayout out_layout = out_layout == phi::DataLayout::kAnyLayout
? phi::DataLayout::kNCHW ? phi::DataLayout::kNCHW
...@@ -853,10 +852,9 @@ void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t, ...@@ -853,10 +852,9 @@ void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t,
if (paddle::platform::is_cpu_place(t_place)) { if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == phi::DataLayout::ONEDNN) if (tensor->layout() == phi::DataLayout::ONEDNN)
paddle::framework::innerTransDataLayoutFromMKLDNN( phi::funcs::TransDataLayoutFromOneDNN(
tensor->layout(), tensor->layout(),
paddle::platform::MKLDNNDeviceContext::tls() phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
.get_cur_paddle_data_layout(),
*tensor, *tensor,
&out, &out,
paddle::platform::CPUPlace(), paddle::platform::CPUPlace(),
......
...@@ -581,10 +581,9 @@ AnalysisPredictor::MkldnnQuantizer::Histogram( ...@@ -581,10 +581,9 @@ AnalysisPredictor::MkldnnQuantizer::Histogram(
void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const { void AnalysisPredictor::MkldnnQuantizer::ClearDeviceContext() const {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx = phi::OneDNNContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(predictor_.place_); (phi::OneDNNContext*)pool.Get(predictor_.place_);
dev_ctx->ResetBlobMap( dev_ctx->ResetBlobMap(phi::OneDNNContext::tls().get_curr_exec());
paddle::platform::MKLDNNDeviceContext::tls().get_curr_exec());
} }
void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const { void AnalysisPredictor::MkldnnQuantizer::PrepareArgument() const {
......
...@@ -79,8 +79,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs, ...@@ -79,8 +79,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
int GetNumCachedObjects(void) { int GetNumCachedObjects(void) {
auto &pool = platform::DeviceContextPool::Instance(); auto &pool = platform::DeviceContextPool::Instance();
platform::CPUPlace place; platform::CPUPlace place;
auto onednn_dev_ctx = auto onednn_dev_ctx = dynamic_cast<phi::OneDNNContext *>(pool.Get(place));
dynamic_cast<platform::MKLDNNDeviceContext *>(pool.Get(place));
return onednn_dev_ctx->GetCachedObjectsNumber(); return onednn_dev_ctx->GetCachedObjectsNumber();
} }
......
...@@ -33,13 +33,12 @@ static void DataCopy(const phi::DenseTensor &src_item, ...@@ -33,13 +33,12 @@ static void DataCopy(const phi::DenseTensor &src_item,
phi::DenseTensor out; phi::DenseTensor out;
// Convert to desired Paddle layout, apart from grads of filter // Convert to desired Paddle layout, apart from grads of filter
// as params are not subject to paddle's data_format // as params are not subject to paddle's data_format
VLOG(4) << "innerTransDataLayoutFromMKLDNN"; VLOG(4) << "TransDataLayoutFromOneDNN";
framework::innerTransDataLayoutFromMKLDNN( phi::funcs::TransDataLayoutFromOneDNN(
src_item.layout(), src_item.layout(),
fetch_var_name == framework::GradVarName("Filter") fetch_var_name == framework::GradVarName("Filter")
? phi::DataLayout::kNCHW ? phi::DataLayout::kNCHW
: paddle::platform::MKLDNNDeviceContext::tls() : phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
.get_cur_paddle_data_layout(),
src_item, src_item,
&out, &out,
platform::CPUPlace()); platform::CPUPlace());
......
...@@ -41,12 +41,11 @@ static void DeepCopy(const phi::DenseTensor &src_item, ...@@ -41,12 +41,11 @@ static void DeepCopy(const phi::DenseTensor &src_item,
phi::DenseTensor out; phi::DenseTensor out;
// Convert to desired Paddle layout, apart from grads of filter // Convert to desired Paddle layout, apart from grads of filter
// as params are not subject to paddle's data_format // as params are not subject to paddle's data_format
framework::innerTransDataLayoutFromMKLDNN( phi::funcs::TransDataLayoutFromOneDNN(
src_item.layout(), src_item.layout(),
fetch_var_name == framework::GradVarName("Filter") fetch_var_name == framework::GradVarName("Filter")
? phi::DataLayout::kNCHW ? phi::DataLayout::kNCHW
: paddle::platform::MKLDNNDeviceContext::tls() : phi::OneDNNContext::tls().get_cur_paddle_data_layout(),
.get_cur_paddle_data_layout(),
src_item, src_item,
&out, &out,
platform::CPUPlace()); platform::CPUPlace());
......
...@@ -115,7 +115,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -115,7 +115,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
// if model is using NHWC and any of shapes in at least 3D // if model is using NHWC and any of shapes in at least 3D
bool should_rotate = bool should_rotate =
ctx->IsRunMKLDNNKernel() && ctx->IsRunMKLDNNKernel() &&
(platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() == (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC) && phi::DataLayout::kNHWC) &&
(x_dims.size() >= 3 || y_dims.size() >= 3); (x_dims.size() >= 3 || y_dims.size() >= 3);
if (should_rotate) { if (should_rotate) {
...@@ -177,8 +177,8 @@ class ElementwiseOp : public framework::OperatorWithKernel { ...@@ -177,8 +177,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
// then we also need to rotate shape NHWC -> NCHW // then we also need to rotate shape NHWC -> NCHW
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) &&
paddle::platform::MKLDNNDeviceContext::tls() phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.place(),
phi::DataLayout::kNHWC); phi::DataLayout::kNHWC);
......
...@@ -28,6 +28,7 @@ using dnnl::memory; ...@@ -28,6 +28,7 @@ using dnnl::memory;
using dnnl::primitive; using dnnl::primitive;
using dnnl::stream; using dnnl::stream;
using phi::DataLayout; using phi::DataLayout;
using phi::OneDNNContext;
using phi::funcs::BinaryOneDNNHandler; using phi::funcs::BinaryOneDNNHandler;
inline std::vector<int64_t> CalculateBroadcastedDims( inline std::vector<int64_t> CalculateBroadcastedDims(
...@@ -63,9 +64,8 @@ inline void AddSubNonBroadcast( ...@@ -63,9 +64,8 @@ inline void AddSubNonBroadcast(
auto reorder_p = auto reorder_p =
reorder_handler->AcquireReorder(dst_memory, src_memory, reorder_attr); reorder_handler->AcquireReorder(dst_memory, src_memory, reorder_attr);
reorder_p->execute(platform::MKLDNNDeviceContext::tls().get_stream(), reorder_p->execute(
*src_memory, OneDNNContext::tls().get_stream(), *src_memory, *dst_memory);
*dst_memory);
} }
template <typename T> template <typename T>
...@@ -99,7 +99,7 @@ inline void BroadcastReduction(const framework::ExecutionContext& ctx, ...@@ -99,7 +99,7 @@ inline void BroadcastReduction(const framework::ExecutionContext& ctx,
dst_memory = reduction_handler.AcquireDstMemory(grad_tensor); dst_memory = reduction_handler.AcquireDstMemory(grad_tensor);
auto reduction_p = reduction_handler.AcquireForwardPrimitive(); auto reduction_p = reduction_handler.AcquireForwardPrimitive();
auto astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto astream = OneDNNContext::tls().get_stream();
reduction_p->execute(astream, reduction_p->execute(astream,
{ {
{DNNL_ARG_SRC, *src_memory}, {DNNL_ARG_SRC, *src_memory},
...@@ -126,8 +126,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -126,8 +126,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const auto& dev_ctx = const auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<phi::DenseTensor>("X"); auto* x = ctx.Input<phi::DenseTensor>("X");
...@@ -188,7 +187,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -188,7 +187,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
const auto binary_prim = handler.AcquireForwardPrimitive(); const auto binary_prim = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
const std::unordered_map<int, dnnl::memory> args = { const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *src_x_memory}, {DNNL_ARG_SRC_0, *src_x_memory},
...@@ -217,8 +216,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -217,8 +216,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<phi::DenseTensor>("X"); auto* x = ctx.Input<phi::DenseTensor>("X");
...@@ -257,7 +255,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -257,7 +255,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
std::shared_ptr<dnnl::memory> dst_memory; std::shared_ptr<dnnl::memory> dst_memory;
std::shared_ptr<dnnl::memory> broadcast_src_memory = reorder_src_memory; std::shared_ptr<dnnl::memory> broadcast_src_memory = reorder_src_memory;
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
if (dx) { if (dx) {
// elementwise_add & elementwise_sub // elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add || if (BINARY_OP == dnnl::algorithm::binary_add ||
......
...@@ -20,14 +20,16 @@ limitations under the License. */ ...@@ -20,14 +20,16 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using phi::OneDNNContext;
using phi::funcs::OneDNNGetDataType; using phi::funcs::OneDNNGetDataType;
using phi::funcs::OneDNNMemDesc; using phi::funcs::OneDNNMemDesc;
using phi::funcs::RNNReorderType;
template <typename T, typename T_out = T> template <typename T, typename T_out = T>
class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> { class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
public: public:
GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, GRUMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx, const OneDNNContext& dev_ctx,
const dnnl::engine mkldnn_engine, const dnnl::engine mkldnn_engine,
platform::Place cpu_place, platform::Place cpu_place,
const phi::DenseTensor* input, const phi::DenseTensor* input,
...@@ -142,7 +144,7 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> { ...@@ -142,7 +144,7 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
memory_p = std::make_shared<dnnl::memory>( memory_p = std::make_shared<dnnl::memory>(
this->fwd_pd_->weights_layer_desc(), this->engine_); this->fwd_pd_->weights_layer_desc(), this->engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
dnnl::reorder(user_memory, *memory_p, this->attr_) dnnl::reorder(user_memory, *memory_p, this->attr_)
.execute(astream, user_memory, *memory_p); .execute(astream, user_memory, *memory_p);
...@@ -196,7 +198,7 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> { ...@@ -196,7 +198,7 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
memory_p = std::make_shared<dnnl::memory>( memory_p = std::make_shared<dnnl::memory>(
this->fwd_pd_->weights_iter_desc(), this->engine_); this->fwd_pd_->weights_iter_desc(), this->engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
dnnl::reorder(user_memory, *memory_p, this->attr_) dnnl::reorder(user_memory, *memory_p, this->attr_)
.execute(astream, user_memory, *memory_p); .execute(astream, user_memory, *memory_p);
...@@ -253,8 +255,7 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -253,8 +255,7 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
template <typename Tout = T> template <typename Tout = T>
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
// Get Tensors // Get Tensors
...@@ -349,7 +350,7 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -349,7 +350,7 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
auto gru_forward_p = handler.AcquireForwardPrimitive(); auto gru_forward_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
gru_forward_p->execute(astream, gru_args); gru_forward_p->execute(astream, gru_args);
astream.wait(); astream.wait();
...@@ -361,13 +362,13 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -361,13 +362,13 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
hidden_data, hidden_data,
input_lod, input_lod,
is_reverse, is_reverse,
platform::RNNReorderType::NTC_PP); RNNReorderType::NTC_PP);
} else { } else {
handler.reorderRNNdata(hidden_onednn_data, handler.reorderRNNdata(hidden_onednn_data,
hidden_data, hidden_data,
input_lod, input_lod,
is_reverse, is_reverse,
platform::RNNReorderType::TNC_PP); RNNReorderType::TNC_PP);
} }
} }
}; };
......
...@@ -20,15 +20,17 @@ limitations under the License. */ ...@@ -20,15 +20,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using phi::OneDNNContext;
using phi::funcs::OneDNNGetDataType; using phi::funcs::OneDNNGetDataType;
using phi::funcs::OneDNNMemDesc; using phi::funcs::OneDNNMemDesc;
using phi::funcs::RNNReorderType;
template <typename T, typename T_out = T> template <typename T, typename T_out = T>
class LSTMMKLDNNHandler class LSTMMKLDNNHandler
: public RNNMKLDNNHandler<T, dnnl::lstm_forward, T_out> { : public RNNMKLDNNHandler<T, dnnl::lstm_forward, T_out> {
public: public:
LSTMMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, LSTMMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx, const OneDNNContext& dev_ctx,
const dnnl::engine mkldnn_engine, const dnnl::engine mkldnn_engine,
platform::Place cpu_place, platform::Place cpu_place,
const phi::DenseTensor* input, const phi::DenseTensor* input,
...@@ -186,7 +188,7 @@ class LSTMMKLDNNHandler ...@@ -186,7 +188,7 @@ class LSTMMKLDNNHandler
memory_p = std::make_shared<dnnl::memory>( memory_p = std::make_shared<dnnl::memory>(
this->fwd_pd_->weights_layer_desc(), this->engine_); this->fwd_pd_->weights_layer_desc(), this->engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
dnnl::reorder(user_memory, *memory_p, this->attr_) dnnl::reorder(user_memory, *memory_p, this->attr_)
.execute(astream, user_memory, *memory_p); .execute(astream, user_memory, *memory_p);
...@@ -218,7 +220,7 @@ class LSTMMKLDNNHandler ...@@ -218,7 +220,7 @@ class LSTMMKLDNNHandler
memory_p = std::make_shared<dnnl::memory>( memory_p = std::make_shared<dnnl::memory>(
this->fwd_pd_->weights_iter_desc(), this->engine_); this->fwd_pd_->weights_iter_desc(), this->engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
dnnl::reorder(user_memory, *memory_p, this->attr_) dnnl::reorder(user_memory, *memory_p, this->attr_)
.execute(astream, user_memory, *memory_p); .execute(astream, user_memory, *memory_p);
...@@ -308,7 +310,7 @@ class LSTMMKLDNNHandler ...@@ -308,7 +310,7 @@ class LSTMMKLDNNHandler
memory_p = std::make_shared<dnnl::memory>( memory_p = std::make_shared<dnnl::memory>(
this->fwd_pd_->src_iter_c_desc(), this->engine_); this->fwd_pd_->src_iter_c_desc(), this->engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
dnnl::reorder(user_c0_memory, *memory_p) dnnl::reorder(user_c0_memory, *memory_p)
.execute(astream, user_c0_memory, *memory_p); .execute(astream, user_c0_memory, *memory_p);
...@@ -335,8 +337,7 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> { ...@@ -335,8 +337,7 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
template <typename Tout = T> template <typename Tout = T>
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
// Get Tensors // Get Tensors
...@@ -444,7 +445,7 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> { ...@@ -444,7 +445,7 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
auto lstm_forward_p = handler.AcquireForwardPrimitive(); auto lstm_forward_p = handler.AcquireForwardPrimitive();
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
lstm_forward_p->execute(astream, lstm_args); lstm_forward_p->execute(astream, lstm_args);
astream.wait(); astream.wait();
...@@ -456,13 +457,13 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> { ...@@ -456,13 +457,13 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
hidden_data, hidden_data,
input_lod, input_lod,
is_reverse, is_reverse,
platform::RNNReorderType::NTC_PP); RNNReorderType::NTC_PP);
} else { } else {
handler.reorderRNNdata(hidden_onednn_data, handler.reorderRNNdata(hidden_onednn_data,
hidden_data, hidden_data,
input_lod, input_lod,
is_reverse, is_reverse,
platform::RNNReorderType::TNC_PP); RNNReorderType::TNC_PP);
} }
} }
}; };
......
...@@ -19,14 +19,15 @@ limitations under the License. */ ...@@ -19,14 +19,15 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using paddle::platform::CreateKey; using phi::funcs::CreateKey;
using phi::funcs::OneDNNGetDataType; using phi::funcs::OneDNNGetDataType;
using phi::funcs::RNNReorderType;
template <typename T, typename T_alg, typename T_out = T> template <typename T, typename T_alg, typename T_out = T>
class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> { class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
public: public:
RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx, const phi::OneDNNContext& dev_ctx,
const dnnl::engine mkldnn_engine, const dnnl::engine mkldnn_engine,
platform::Place cpu_place, platform::Place cpu_place,
const phi::DenseTensor* input, const phi::DenseTensor* input,
...@@ -51,7 +52,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> { ...@@ -51,7 +52,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
G(G) { G(G) {
// Create memory key without Ti because weights, bias and h0 memories // Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do // do not depend on Ti size but primitive and input/output memory do
memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( memory_key_ = phi::funcs::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx, CreateKey(dev_ctx, unique_name, OneDNNGetDataType<T>())); dev_ctx, CreateKey(dev_ctx, unique_name, OneDNNGetDataType<T>()));
// Is it int8 kernel // Is it int8 kernel
...@@ -86,10 +87,10 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> { ...@@ -86,10 +87,10 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
void* output_data, void* output_data,
std::vector<size_t> lod, std::vector<size_t> lod,
const bool is_reverse, const bool is_reverse,
platform::RNNReorderType reorder_type) { RNNReorderType reorder_type) {
switch (reorder_type) { switch (reorder_type) {
// Reorder input memory [WORDS, C] + LoD -> [N, T, C] // Reorder input memory [WORDS, C] + LoD -> [N, T, C]
case platform::RNNReorderType::PP_NTC: { case RNNReorderType::PP_NTC: {
auto* input_data_iter = reinterpret_cast<T*>(input_data); auto* input_data_iter = reinterpret_cast<T*>(input_data);
auto* output_data_iter = reinterpret_cast<T*>(output_data); auto* output_data_iter = reinterpret_cast<T*>(output_data);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
...@@ -102,7 +103,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> { ...@@ -102,7 +103,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
} }
} break; } break;
// Reorder input memory [WORDS, C] + LoD -> [T, N, C] // Reorder input memory [WORDS, C] + LoD -> [T, N, C]
case platform::RNNReorderType::PP_TNC: { case RNNReorderType::PP_TNC: {
auto* input_data_iter = reinterpret_cast<T*>(input_data); auto* input_data_iter = reinterpret_cast<T*>(input_data);
auto* output_data_iter = reinterpret_cast<T*>(output_data); auto* output_data_iter = reinterpret_cast<T*>(output_data);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
...@@ -117,7 +118,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> { ...@@ -117,7 +118,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
} }
} break; } break;
// Reorder output values to PP format [N, T, C] -> [WORDS, C] // Reorder output values to PP format [N, T, C] -> [WORDS, C]
case platform::RNNReorderType::NTC_PP: { case RNNReorderType::NTC_PP: {
auto* input_data_iter = reinterpret_cast<T_out*>(input_data); auto* input_data_iter = reinterpret_cast<T_out*>(input_data);
auto* output_data_iter = reinterpret_cast<T_out*>(output_data); auto* output_data_iter = reinterpret_cast<T_out*>(output_data);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
...@@ -130,7 +131,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> { ...@@ -130,7 +131,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
} }
} break; } break;
// Reorder output values to PP format [T, N, C] -> [WORDS, C] // Reorder output values to PP format [T, N, C] -> [WORDS, C]
case platform::RNNReorderType::TNC_PP: { case RNNReorderType::TNC_PP: {
auto* input_data_iter = reinterpret_cast<T_out*>(input_data); auto* input_data_iter = reinterpret_cast<T_out*>(input_data);
auto* output_data_iter = reinterpret_cast<T_out*>(output_data); auto* output_data_iter = reinterpret_cast<T_out*>(output_data);
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
...@@ -166,17 +167,11 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> { ...@@ -166,17 +167,11 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC); memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC);
if (is_NTC(this->fwd_pd_->src_desc())) { if (is_NTC(this->fwd_pd_->src_desc())) {
reorderRNNdata(x_data, reorderRNNdata(
x_onednn_data, x_data, x_onednn_data, input_lod, is_reverse, RNNReorderType::PP_NTC);
input_lod,
is_reverse,
platform::RNNReorderType::PP_NTC);
} else { } else {
reorderRNNdata(x_data, reorderRNNdata(
x_onednn_data, x_data, x_onednn_data, input_lod, is_reverse, RNNReorderType::PP_TNC);
input_lod,
is_reverse,
platform::RNNReorderType::PP_TNC);
} }
return memory_p; return memory_p;
} }
...@@ -219,7 +214,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> { ...@@ -219,7 +214,7 @@ class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(), memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(),
this->engine_); this->engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
dnnl::reorder(user_h0_memory, *memory_p, attr_) dnnl::reorder(user_h0_memory, *memory_p, attr_)
.execute(astream, user_h0_memory, *memory_p); .execute(astream, user_h0_memory, *memory_p);
......
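In the RNN/GRU handlers above, the cached-memory key is now assembled with `phi::funcs::CreateKey` and `phi::funcs::ExtendKeyWithThreadInfoIfNeeded` instead of the `paddle::platform` versions. A small sketch of that key construction; `unique_name` is a placeholder (the handlers pass something like `ctx.OutputName("Hidden")`), and the wrapper function and header assumptions are not from the commit.

```cpp
// Sketch of the memory-key construction used by the handlers after the move.
#include <string>
#include "paddle/phi/backends/onednn/onednn_helper.h"  // assumed to declare CreateKey and friends

std::string MakeMemoryKey(const phi::OneDNNContext& dev_ctx,
                          const std::string& unique_name) {
  return phi::funcs::ExtendKeyWithThreadInfoIfNeeded(
      dev_ctx,
      phi::funcs::CreateKey(dev_ctx, unique_name,
                            phi::funcs::OneDNNGetDataType<float>()));
}
```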
...@@ -26,11 +26,11 @@ limitations under the License. */ ...@@ -26,11 +26,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using paddle::platform::CreateKey;
using phi::vectorize; using phi::vectorize;
using phi::funcs::OneDNNGetDataType; using phi::funcs::OneDNNGetDataType;
using phi::funcs::OneDNNMemDesc; using phi::funcs::OneDNNMemDesc;
using Direction = dnnl::rnn_direction; using Direction = dnnl::rnn_direction;
using phi::OneDNNContext;
namespace { namespace {
...@@ -52,7 +52,7 @@ template <typename T, typename T_out = T> ...@@ -52,7 +52,7 @@ template <typename T, typename T_out = T>
class MultiGRUHandler { class MultiGRUHandler {
public: public:
MultiGRUHandler(const paddle::framework::ExecutionContext& ctx, MultiGRUHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx) const OneDNNContext& dev_ctx)
: dev_ctx_(dev_ctx), : dev_ctx_(dev_ctx),
engine_(dev_ctx.GetEngine()), engine_(dev_ctx.GetEngine()),
place_(ctx.GetPlace()), place_(ctx.GetPlace()),
...@@ -112,8 +112,9 @@ class MultiGRUHandler { ...@@ -112,8 +112,9 @@ class MultiGRUHandler {
const std::string unique_name = ctx.OutputName("Hidden"); const std::string unique_name = ctx.OutputName("Hidden");
// Create memory key without Ti because weights, bias and h0 memories // Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do // do not depend on Ti size but primitive and input/output memory do
memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded( memory_key_ = phi::funcs::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx, CreateKey(dev_ctx, unique_name, OneDNNGetDataType<T>())); dev_ctx,
phi::funcs::CreateKey(dev_ctx, unique_name, OneDNNGetDataType<T>()));
key_ = memory_key_; key_ = memory_key_;
key_.append("T").append(std::to_string(Ti_)); key_.append("T").append(std::to_string(Ti_));
...@@ -320,7 +321,7 @@ class MultiGRUHandler { ...@@ -320,7 +321,7 @@ class MultiGRUHandler {
auto gru_forward_p0 = AcquireGruPrimitive(layer, dir); auto gru_forward_p0 = AcquireGruPrimitive(layer, dir);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
gru_forward_p0->execute(astream, gru_args); gru_forward_p0->execute(astream, gru_args);
astream.wait(); astream.wait();
return out_mem; return out_mem;
...@@ -343,7 +344,7 @@ class MultiGRUHandler { ...@@ -343,7 +344,7 @@ class MultiGRUHandler {
memory_p = std::make_shared<dnnl::memory>( memory_p = std::make_shared<dnnl::memory>(
gru_pds_[{layer, dir}]->src_iter_desc(), engine_); gru_pds_[{layer, dir}]->src_iter_desc(), engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
dnnl::reorder(user_h0_memory, *memory_p, attrs_[2 * layer + (dir == R2L)]) dnnl::reorder(user_h0_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
.execute(astream, user_h0_memory, *memory_p); .execute(astream, user_h0_memory, *memory_p);
...@@ -383,7 +384,7 @@ class MultiGRUHandler { ...@@ -383,7 +384,7 @@ class MultiGRUHandler {
memory_p = std::make_shared<dnnl::memory>( memory_p = std::make_shared<dnnl::memory>(
gru_pds_[{layer, dir}]->weights_layer_desc(), engine_); gru_pds_[{layer, dir}]->weights_layer_desc(), engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)]) dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
.execute(astream, user_memory, *memory_p); .execute(astream, user_memory, *memory_p);
...@@ -440,7 +441,7 @@ class MultiGRUHandler { ...@@ -440,7 +441,7 @@ class MultiGRUHandler {
memory_p = std::make_shared<dnnl::memory>( memory_p = std::make_shared<dnnl::memory>(
gru_pds_[{layer, dir}]->weights_iter_desc(), engine_); gru_pds_[{layer, dir}]->weights_iter_desc(), engine_);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)]) dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
.execute(astream, user_memory, *memory_p); .execute(astream, user_memory, *memory_p);
...@@ -547,7 +548,7 @@ class MultiGRUHandler { ...@@ -547,7 +548,7 @@ class MultiGRUHandler {
auto concat_p = AcquireConcatPrimitive(layer); auto concat_p = AcquireConcatPrimitive(layer);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
concat_p->execute(astream, concat_args); concat_p->execute(astream, concat_args);
astream.wait(); astream.wait();
return out_mem; return out_mem;
...@@ -654,7 +655,7 @@ class MultiGRUHandler { ...@@ -654,7 +655,7 @@ class MultiGRUHandler {
int64_t N_, Ti_; int64_t N_, Ti_;
std::vector<int64_t> ICs, OCs; std::vector<int64_t> ICs, OCs;
const platform::MKLDNNDeviceContext& dev_ctx_; const OneDNNContext& dev_ctx_;
const dnnl::engine engine_; const dnnl::engine engine_;
const platform::Place place_; const platform::Place place_;
const bool origin_mode_; const bool origin_mode_;
...@@ -695,8 +696,7 @@ class MultiGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -695,8 +696,7 @@ class MultiGRUMKLDNNKernel : public framework::OpKernel<T> {
template <typename Tout = T> template <typename Tout = T>
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
MultiGRUHandler<T, Tout> handler(ctx, dev_ctx); MultiGRUHandler<T, Tout> handler(ctx, dev_ctx);
int layers = handler.getLayers(); int layers = handler.getLayers();
......
...@@ -587,7 +587,7 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -587,7 +587,7 @@ class MatMulOp : public framework::OperatorWithKernel {
// to be computed like instead x*y we are to do y*x // to be computed like instead x*y we are to do y*x
bool channelwise_onednn = bool channelwise_onednn =
context->IsRunMKLDNNKernel() && context->IsRunMKLDNNKernel() &&
(platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() == (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC); phi::DataLayout::kNHWC);
if (channelwise_onednn) { if (channelwise_onednn) {
std::swap(dim_x, dim_y); std::swap(dim_x, dim_y);
...@@ -717,8 +717,8 @@ class MatMulOp : public framework::OperatorWithKernel { ...@@ -717,8 +717,8 @@ class MatMulOp : public framework::OperatorWithKernel {
// then we also need to rotate shape NHWC -> NCHW // then we also need to rotate shape NHWC -> NCHW
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) &&
paddle::platform::MKLDNNDeviceContext::tls() phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.place(),
phi::DataLayout::kNHWC); phi::DataLayout::kNHWC);
......
...@@ -154,8 +154,8 @@ class MatMulV2Op : public framework::OperatorWithKernel { ...@@ -154,8 +154,8 @@ class MatMulV2Op : public framework::OperatorWithKernel {
// op previously) then we also need to rotate shape NHWC -> NCHW // op previously) then we also need to rotate shape NHWC -> NCHW
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) && if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) && (tensor.layout() != phi::DataLayout::ONEDNN) &&
paddle::platform::MKLDNNDeviceContext::tls() phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) { phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.place(),
phi::DataLayout::kNHWC); phi::DataLayout::kNHWC);
......
...@@ -49,8 +49,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -49,8 +49,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
"255 and greater or equal to 0, but got %f", "255 and greater or equal to 0, but got %f",
quantization_shift)); quantization_shift));
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
auto x_tz = phi::vectorize<int64_t>(x->dims()); auto x_tz = phi::vectorize<int64_t>(x->dims());
auto x_type = phi::funcs::ToOneDNNDataType(x->dtype()); auto x_type = phi::funcs::ToOneDNNDataType(x->dtype());
...@@ -78,7 +77,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -78,7 +77,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
auto reorder_p = reorder_handler.AcquireReorder( auto reorder_p = reorder_handler.AcquireReorder(
reorder_dst_memory_p, reorder_src_memory_p, attrs); reorder_dst_memory_p, reorder_src_memory_p, attrs);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
......
...@@ -28,9 +28,9 @@ using dnnl::prop_kind; ...@@ -28,9 +28,9 @@ using dnnl::prop_kind;
using dnnl::stream; using dnnl::stream;
using framework::DDim; using framework::DDim;
using framework::ExecutionContext; using framework::ExecutionContext;
using phi::OneDNNContext;
using phi::funcs::OneDNNGetDataType; using phi::funcs::OneDNNGetDataType;
using phi::funcs::to_void_cast; using phi::funcs::to_void_cast;
using platform::MKLDNNDeviceContext;
struct InnerProductCache { struct InnerProductCache {
dnnl::inner_product_forward inner_product_p; dnnl::inner_product_forward inner_product_p;
...@@ -45,7 +45,7 @@ class FCMKLDNNHandler ...@@ -45,7 +45,7 @@ class FCMKLDNNHandler
dnnl::inner_product_forward> { dnnl::inner_product_forward> {
public: public:
FCMKLDNNHandler(const paddle::framework::ExecutionContext& ctx, FCMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx, const OneDNNContext& dev_ctx,
const phi::DenseTensor* x, const phi::DenseTensor* x,
const phi::DenseTensor* weights, const phi::DenseTensor* weights,
const phi::DenseTensor* bias, const phi::DenseTensor* bias,
...@@ -220,7 +220,7 @@ class FCMKLDNNHandler ...@@ -220,7 +220,7 @@ class FCMKLDNNHandler
auto reorder_p = std::make_shared<dnnl::reorder>( auto reorder_p = std::make_shared<dnnl::reorder>(
*user_memory_p, *target_memory_p, attrs); *user_memory_p, *target_memory_p, attrs);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
{ {
platform::RecordEvent record_reorder( platform::RecordEvent record_reorder(
"int_reorder", "int_reorder",
...@@ -237,7 +237,7 @@ class FCMKLDNNHandler ...@@ -237,7 +237,7 @@ class FCMKLDNNHandler
} }
std::string memory_key_; std::string memory_key_;
const platform::MKLDNNDeviceContext& dev_ctx_; const OneDNNContext& dev_ctx_;
public: public:
std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder( std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
...@@ -388,7 +388,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> { ...@@ -388,7 +388,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
dnnl::memory x_mem(x_md, engine, to_void_cast<T_in>(x->data<T_in>())); dnnl::memory x_mem(x_md, engine, to_void_cast<T_in>(x->data<T_in>()));
auto reorder_p = dnnl::reorder(x_mem, *src_mem); auto reorder_p = dnnl::reorder(x_mem, *src_mem);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
reorder_p.execute(astream, x_mem, *src_mem); reorder_p.execute(astream, x_mem, *src_mem);
astream.wait(); astream.wait();
} else { } else {
...@@ -398,8 +398,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> { ...@@ -398,8 +398,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
template <typename T_out, typename T_w> template <typename T_out, typename T_w>
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx = const auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<phi::DenseTensor>("Input"); const auto* x = ctx.Input<phi::DenseTensor>("Input");
...@@ -417,12 +416,12 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> { ...@@ -417,12 +416,12 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
std::string cache_key; std::string cache_key;
cache_key.reserve(64); cache_key.reserve(64);
cache_key = platform::ExtendKeyWithThreadInfoIfNeeded( cache_key = phi::funcs::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx, dev_ctx,
platform::CreateKey(dev_ctx, phi::funcs::CreateKey(dev_ctx,
ctx.InputName("Input"), ctx.InputName("Input"),
ctx.InputName("W"), ctx.InputName("W"),
phi::vectorize(x->dims()))); phi::vectorize(x->dims())));
auto inner_product_cache = auto inner_product_cache =
std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key)); std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));
...@@ -479,7 +478,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> { ...@@ -479,7 +478,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
fc_p = handler.AcquireForwardPrimitive(); fc_p = handler.AcquireForwardPrimitive();
} }
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
std::unordered_map<int, dnnl::memory> fc_args = { std::unordered_map<int, dnnl::memory> fc_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_SRC, *src_memory_p},
......
...@@ -130,8 +130,7 @@ class InterpolateOneDNNKernel : public framework::OpKernel<T> { ...@@ -130,8 +130,7 @@ class InterpolateOneDNNKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const auto& dev_ctx = const auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<phi::DenseTensor>("X"); const auto* x = ctx.Input<phi::DenseTensor>("X");
...@@ -155,7 +154,7 @@ class InterpolateOneDNNKernel : public framework::OpKernel<T> { ...@@ -155,7 +154,7 @@ class InterpolateOneDNNKernel : public framework::OpKernel<T> {
auto resampling_prim = handler.AcquireForwardPrimitive(); auto resampling_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = { const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
resampling_prim->execute(astream, args); resampling_prim->execute(astream, args);
astream.wait(); astream.wait();
......
...@@ -98,8 +98,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -98,8 +98,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis"); const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
auto src_tz = phi::vectorize(x->dims()); auto src_tz = phi::vectorize(x->dims());
...@@ -125,7 +124,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -125,7 +124,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto layer_norm_p = handler.AcquireForwardPrimitive(); auto layer_norm_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory}, std::unordered_map<int, dnnl::memory> args = {{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_DST, *dst_memory}}; {DNNL_ARG_DST, *dst_memory}};
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using paddle::platform::MKLDNNDeviceContext; using phi::OneDNNContext;
template <typename T> template <typename T>
class LRNOneDNNHandler class LRNOneDNNHandler
...@@ -124,8 +124,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -124,8 +124,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
true, true,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Operator DNNL LRN must use CPUPlace")); "Operator DNNL LRN must use CPUPlace"));
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
auto x = ctx.Input<phi::DenseTensor>("X"); auto x = ctx.Input<phi::DenseTensor>("X");
...@@ -142,7 +141,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -142,7 +141,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto workspace_memory = handler.AcquireWorkspaceMemory(mid); auto workspace_memory = handler.AcquireWorkspaceMemory(mid);
mid->set_layout(phi::DataLayout::ONEDNN); mid->set_layout(phi::DataLayout::ONEDNN);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
if (!workspace_memory->get_desc().is_zero()) { if (!workspace_memory->get_desc().is_zero()) {
mid->set_mem_desc(workspace_memory->get_desc()); mid->set_mem_desc(workspace_memory->get_desc());
lrn_p->execute(astream, lrn_p->execute(astream,
...@@ -179,7 +178,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -179,7 +178,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto out_grad = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out")); auto out_grad = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto in_x_grad = ctx.Output<phi::DenseTensor>(framework::GradVarName("X")); auto in_x_grad = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<OneDNNContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
LRNOneDNNHandler<T> handler( LRNOneDNNHandler<T> handler(
...@@ -192,7 +191,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -192,7 +191,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto lrn_bwd = handler.AcquireBackwardPrimitive(); auto lrn_bwd = handler.AcquireBackwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
lrn_bwd->execute(astream, lrn_bwd->execute(astream,
{{DNNL_ARG_SRC, *src_memory}, {{DNNL_ARG_SRC, *src_memory},
{DNNL_ARG_DIFF_DST, *diff_dst_memory}, {DNNL_ARG_DIFF_DST, *diff_dst_memory},
......
...@@ -21,12 +21,11 @@ namespace { ...@@ -21,12 +21,11 @@ namespace {
using dnnl::memory; using dnnl::memory;
using paddle::framework::ExecutionContext; using paddle::framework::ExecutionContext;
using paddle::platform::MatMulV2MKLDNNHandler; using paddle::platform::MatMulV2MKLDNNHandler;
using paddle::platform::MKLDNNDeviceContext; using phi::OneDNNContext;
using phi::vectorize; using phi::vectorize;
using phi::funcs::OneDNNGetDataType; using phi::funcs::OneDNNGetDataType;
using Tensor = phi::DenseTensor; using Tensor = phi::DenseTensor;
using paddle::framework::GradVarName; using paddle::framework::GradVarName;
using phi::make_ddim;
// Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
...@@ -43,7 +42,7 @@ static Tensor FoldOuterDims(const Tensor &input) { ...@@ -43,7 +42,7 @@ static Tensor FoldOuterDims(const Tensor &input) {
// (Warning: This requires transposing data and writes into new memory.) // (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3. // Identity op if the tensor is not of rank 3.
template <typename T> template <typename T>
static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx, static Tensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx,
const Tensor *input) { const Tensor *input) {
auto input_dims = vectorize(input->dims()); auto input_dims = vectorize(input->dims());
if (input_dims.size() != 3) { if (input_dims.size() != 3) {
...@@ -55,8 +54,7 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx, ...@@ -55,8 +54,7 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx,
auto output_dims = vectorize(output.dims()); auto output_dims = vectorize(output.dims());
memory::data_type input_type = paddle::framework::ToMKLDNNDataType( memory::data_type input_type = phi::funcs::ToOneDNNDataType(input->dtype());
paddle::framework::TransToProtoVarType(input->dtype()));
phi::funcs::ReorderOneDNNHandler reorder_handler( phi::funcs::ReorderOneDNNHandler reorder_handler(
output_dims, input->dtype(), input_type, dev_ctx.GetEngine()); output_dims, input->dtype(), input_type, dev_ctx.GetEngine());
...@@ -67,7 +65,7 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx, ...@@ -67,7 +65,7 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx,
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p); reorder_dst_memory_p);
auto &astream = MKLDNNDeviceContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
...@@ -153,7 +151,7 @@ class MatMulMKLDNNHandler ...@@ -153,7 +151,7 @@ class MatMulMKLDNNHandler
{DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p}};
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
// Simulate batch matmul by processing in loop // Simulate batch matmul by processing in loop
void *x_ptr = src_memory_p->get_data_handle(); void *x_ptr = src_memory_p->get_data_handle();
...@@ -366,7 +364,7 @@ void ExecuteMatMulV2(const ExecutionContext &ctx, ...@@ -366,7 +364,7 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
*residual_data_memory_p}); *residual_data_memory_p});
} }
auto &astream = MKLDNNDeviceContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); matmul_p->execute(astream, matmul_args);
astream.wait(); astream.wait();
...@@ -402,7 +400,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -402,7 +400,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
: false; : false;
constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto &dev_ctx = ctx.template device_context<OneDNNContext>();
const auto &onednn_engine = dev_ctx.GetEngine(); const auto &onednn_engine = dev_ctx.GetEngine();
auto *x = ctx.Input<phi::DenseTensor>("X"); auto *x = ctx.Input<phi::DenseTensor>("X");
...@@ -531,8 +529,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -531,8 +529,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
ctx.Attr<int>("head_number"))); ctx.Attr<int>("head_number")));
} }
const auto &dev_ctx = const auto &dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine(); const auto &onednn_engine = dev_ctx.GetEngine();
auto x = *ctx.Input<phi::DenseTensor>("X"); auto x = *ctx.Input<phi::DenseTensor>("X");
...@@ -639,7 +636,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -639,7 +636,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
private: private:
void ExecuteMatMulGrad(const ExecutionContext &ctx, void ExecuteMatMulGrad(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx, const OneDNNContext &dev_ctx,
const dnnl::engine &engine, const dnnl::engine &engine,
phi::DenseTensor *x, phi::DenseTensor *x,
bool trans_x, bool trans_x,
...@@ -685,7 +682,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -685,7 +682,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
{DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p}};
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); matmul_p->execute(astream, matmul_args);
astream.wait(); astream.wait();
......
...@@ -27,8 +27,8 @@ namespace operators { ...@@ -27,8 +27,8 @@ namespace operators {
using framework::DDim; using framework::DDim;
using framework::ExecutionContext; using framework::ExecutionContext;
using phi::OneDNNContext;
using platform::MatMulV2MKLDNNHandler; using platform::MatMulV2MKLDNNHandler;
using platform::MKLDNNDeviceContext;
using dnnl::inner_product_forward; using dnnl::inner_product_forward;
using dnnl::memory; using dnnl::memory;
...@@ -105,7 +105,7 @@ class MulPrimitiveFactory { ...@@ -105,7 +105,7 @@ class MulPrimitiveFactory {
auto reorder = dnnl::reorder(reorder_pd); auto reorder = dnnl::reorder(reorder_pd);
auto &astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
{ {
platform::RecordEvent record_reorder( platform::RecordEvent record_reorder(
"int_reorder", "int_reorder",
...@@ -183,7 +183,7 @@ class MulPrimitiveFactory { ...@@ -183,7 +183,7 @@ class MulPrimitiveFactory {
} }
void Execute() { void Execute() {
auto &astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
(*mul_).execute(astream, (*mul_).execute(astream,
{{DNNL_ARG_SRC, *x_input_}, {{DNNL_ARG_SRC, *x_input_},
{DNNL_ARG_WEIGHTS, *y_input_}, {DNNL_ARG_WEIGHTS, *y_input_},
...@@ -278,7 +278,7 @@ class MulPrimitiveFactory { ...@@ -278,7 +278,7 @@ class MulPrimitiveFactory {
auto reorder = dnnl::reorder(src_mem, dst_mem); auto reorder = dnnl::reorder(src_mem, dst_mem);
auto &astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
{ {
platform::RecordEvent record_reorder( platform::RecordEvent record_reorder(
"int_reorder", "int_reorder",
...@@ -313,19 +313,19 @@ class MulPrimitiveFactory { ...@@ -313,19 +313,19 @@ class MulPrimitiveFactory {
/* OT: output data type */ /* OT: output data type */
template <typename XT, typename YT, typename OT> template <typename XT, typename YT, typename OT>
std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory( std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
const MKLDNNDeviceContext &dev_ctx, const OneDNNContext &dev_ctx,
const ExecutionContext &ctx, const ExecutionContext &ctx,
const Tensor *input_x, const Tensor *input_x,
const Tensor *input_y, const Tensor *input_y,
const dnnl::engine &mkldnn_engine) { const dnnl::engine &mkldnn_engine) {
std::string key = std::string key =
platform::CreateKey(dev_ctx, phi::funcs::CreateKey(dev_ctx,
framework::TransToProtoVarType(input_x->dtype()), framework::TransToProtoVarType(input_x->dtype()),
phi::vectorize(input_x->dims()), phi::vectorize(input_x->dims()),
framework::TransToProtoVarType(input_y->dtype()), framework::TransToProtoVarType(input_y->dtype()),
phi::vectorize(input_y->dims()), phi::vectorize(input_y->dims()),
ctx.OutputName("Out")); ctx.OutputName("Out"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); key = phi::funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>( auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>(
dev_ctx.GetBlob(key)); dev_ctx.GetBlob(key));
...@@ -341,7 +341,7 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory( ...@@ -341,7 +341,7 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
/* XT: input x data type, YT: input y data type */ /* XT: input x data type, YT: input y data type */
template <typename XT, typename YT> template <typename XT, typename YT>
inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx, inner_product_forward GetMulPrimitive(const OneDNNContext &dev_ctx,
const ExecutionContext &ctx, const ExecutionContext &ctx,
const Tensor *input_x, const Tensor *input_x,
const Tensor *input_y, const Tensor *input_y,
...@@ -372,8 +372,8 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> { ...@@ -372,8 +372,8 @@ class MulMKLDNNINT8Kernel : public framework::OpKernel<XT> {
true, true,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Mul must use CPUPlace")); "Operator DNNL Mul must use CPUPlace"));
platform::MKLDNNDeviceContext::tls().log_lib_version(); OneDNNContext::tls().log_lib_version();
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<OneDNNContext>();
auto &mkldnn_engine = dev_ctx.GetEngine(); auto &mkldnn_engine = dev_ctx.GetEngine();
const Tensor *x = ctx.Input<phi::DenseTensor>("X"); const Tensor *x = ctx.Input<phi::DenseTensor>("X");
...@@ -401,7 +401,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> { ...@@ -401,7 +401,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
protected: protected:
void ExecuteMatMul(const ExecutionContext &ctx, void ExecuteMatMul(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx, const OneDNNContext &dev_ctx,
const dnnl::engine &onednn_engine, const dnnl::engine &onednn_engine,
const platform::Place &cpu_place, const platform::Place &cpu_place,
const Tensor *x, const Tensor *x,
...@@ -434,7 +434,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> { ...@@ -434,7 +434,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
{DNNL_ARG_WEIGHTS, *weights_memory_p}, {DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_DST, *dst_memory_p}};
auto &astream = MKLDNNDeviceContext::tls().get_stream(); auto &astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); matmul_p->execute(astream, matmul_args);
astream.wait(); astream.wait();
...@@ -447,7 +447,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> { ...@@ -447,7 +447,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
private: private:
void RunKernel(const ExecutionContext &ctx) const { void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto &dev_ctx = ctx.template device_context<OneDNNContext>();
const auto &onednn_engine = dev_ctx.GetEngine(); const auto &onednn_engine = dev_ctx.GetEngine();
const auto *x = ctx.Input<phi::DenseTensor>("X"); const auto *x = ctx.Input<phi::DenseTensor>("X");
......
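Besides the context type, the fc and mul kernels switch their blob-cache keys from `platform::CreateKey` / `platform::ExtendKeyWithThreadInfoIfNeeded` to the `phi::funcs` equivalents. A rough sketch of the migrated key construction, assuming the helpers are declared in `onednn_helper.h`; the wrapper name and argument list are illustrative only:

```cpp
#include <string>
#include <vector>

#include "paddle/phi/backends/onednn/onednn_helper.h"  // assumed location of CreateKey

// Build a cache key the way the migrated kernels do: serialize the relevant
// shapes/names, then optionally append a "-t:<thread-id>" suffix when
// per-thread caching is in effect.
std::string BuildBlobKeySketch(const phi::OneDNNContext& dev_ctx,
                               const std::vector<int64_t>& x_dims,
                               const std::string& out_name) {
  std::string key = phi::funcs::CreateKey(dev_ctx, x_dims, out_name);
  return phi::funcs::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
}
```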
...@@ -51,8 +51,7 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -51,8 +51,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
"255 and greater or equal to 0, but got %f", "255 and greater or equal to 0, but got %f",
quantization_shift)); quantization_shift));
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
auto x_tz = phi::vectorize<int64_t>(x->dims()); auto x_tz = phi::vectorize<int64_t>(x->dims());
...@@ -95,7 +94,7 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -95,7 +94,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
auto reorder_p = reorder_handler.AcquireReorder( auto reorder_p = reorder_handler.AcquireReorder(
reorder_dst_memory_p, reorder_src_memory_p, attrs); reorder_dst_memory_p, reorder_src_memory_p, attrs);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
......
...@@ -63,8 +63,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -63,8 +63,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
"shift for signed input.")); "shift for signed input."));
} }
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
auto src_tz = phi::vectorize(input->dims()); auto src_tz = phi::vectorize(input->dims());
...@@ -102,7 +101,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -102,7 +101,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
auto reorder_p = auto reorder_p =
reorder_handler.AcquireReorder(dst_memory_p, src_memory_p, attrs); reorder_handler.AcquireReorder(dst_memory_p, src_memory_p, attrs);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *src_memory_p, *dst_memory_p); reorder_p->execute(astream, *src_memory_p, *dst_memory_p);
astream.wait(); astream.wait();
......
...@@ -59,8 +59,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> { ...@@ -59,8 +59,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
private: private:
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx = const auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
auto* x = ctx.Input<phi::DenseTensor>("X"); auto* x = ctx.Input<phi::DenseTensor>("X");
...@@ -84,7 +83,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> { ...@@ -84,7 +83,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p); reorder_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
...@@ -304,8 +303,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> { ...@@ -304,8 +303,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
private: private:
void RunKernel(const framework::ExecutionContext& ctx) const { void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx = const auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
auto* dout = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out")); auto* dout = ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
...@@ -329,7 +327,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> { ...@@ -329,7 +327,7 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
reorder_src_memory_p); reorder_src_memory_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
......
...@@ -37,8 +37,7 @@ template <typename T> ...@@ -37,8 +37,7 @@ template <typename T>
class ShuffleChannelMKLDNNKernel : public framework::OpKernel<T> { class ShuffleChannelMKLDNNKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const auto& dev_ctx = const auto& dev_ctx = ctx.template device_context<phi::OneDNNContext>();
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<phi::DenseTensor>("X"); const auto* x = ctx.Input<phi::DenseTensor>("X");
...@@ -55,7 +54,7 @@ class ShuffleChannelMKLDNNKernel : public framework::OpKernel<T> { ...@@ -55,7 +54,7 @@ class ShuffleChannelMKLDNNKernel : public framework::OpKernel<T> {
auto shuffle_p = handler.AcquireForwardPrimitive(); auto shuffle_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = phi::OneDNNContext::tls().get_stream();
shuffle_p->execute( shuffle_p->execute(
astream, astream,
{{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}); {{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}});
......
...@@ -52,8 +52,7 @@ class CacheTester { ...@@ -52,8 +52,7 @@ class CacheTester {
// Clear oneDNN cache // Clear oneDNN cache
auto &pool = platform::DeviceContextPool::Instance(); auto &pool = platform::DeviceContextPool::Instance();
platform::CPUPlace place; platform::CPUPlace place;
onednn_dev_ctx_ = onednn_dev_ctx_ = dynamic_cast<phi::OneDNNContext *>(pool.Get(place));
dynamic_cast<platform::MKLDNNDeviceContext *>(pool.Get(place));
onednn_dev_ctx_->ResetBlobMap(nullptr); onednn_dev_ctx_->ResetBlobMap(nullptr);
} }
...@@ -63,7 +62,7 @@ class CacheTester { ...@@ -63,7 +62,7 @@ class CacheTester {
} }
private: private:
platform::MKLDNNDeviceContext *onednn_dev_ctx_; phi::OneDNNContext *onednn_dev_ctx_;
}; };
template <typename T> template <typename T>
......
...@@ -23,6 +23,7 @@ namespace operators { ...@@ -23,6 +23,7 @@ namespace operators {
using Tensor = phi::DenseTensor; using Tensor = phi::DenseTensor;
using phi::DataLayout; using phi::DataLayout;
using phi::OneDNNContext;
template <typename T> template <typename T>
class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
...@@ -32,15 +33,14 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -32,15 +33,14 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
true, true,
paddle::platform::errors::PreconditionNotMet( paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Transpose must use CPUPlace")); "Operator DNNL Transpose must use CPUPlace"));
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& dnnl_engine = dev_ctx.GetEngine(); const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis"); std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
int ndims = transpose_axis.size(); int ndims = transpose_axis.size();
const phi::DenseTensor* x = ctx.Input<Tensor>("X"); const phi::DenseTensor* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
platform::SetInMemDescWithLogicalLayoutFusesSupport( platform::SetInMemDescWithLogicalLayoutFusesSupport(
ctx, const_cast<phi::DenseTensor*>(x), x->mem_desc()); ctx, const_cast<phi::DenseTensor*>(x), x->mem_desc());
...@@ -131,12 +131,11 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -131,12 +131,11 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
if (!dx) return; if (!dx) return;
auto& dev_ctx = auto& dev_ctx = ctx.template device_context<OneDNNContext>();
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& dnnl_engine = dev_ctx.GetEngine(); const auto& dnnl_engine = dev_ctx.GetEngine();
std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis"); std::vector<int> transpose_axis = ctx.Attr<std::vector<int>>("axis");
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
int ndims = transpose_axis.size(); int ndims = transpose_axis.size();
if (ndims == 1) { if (ndims == 1) {
......
...@@ -78,52 +78,51 @@ class TransferLayoutFunctor { ...@@ -78,52 +78,51 @@ class TransferLayoutFunctor {
"No layout transform needed between two oneDNN OPKernels.")); "No layout transform needed between two oneDNN OPKernels."));
if (in_layout != DataLayout::ONEDNN && out_layout == DataLayout::ONEDNN) { if (in_layout != DataLayout::ONEDNN && out_layout == DataLayout::ONEDNN) {
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Case1 - transform from Non-ONEDNN OPKernel to ONEDNN OPKernel
// Just set layout/format. No real transform occur // Just set layout/format. No real transform occur
auto out_format = phi::funcs::OneDNNFormatForSize( auto out_format = phi::funcs::OneDNNFormatForSize(
in_tensor.dims().size(), framework::ToOneDNNFormat(in_layout)); in_tensor.dims().size(), phi::funcs::ToOneDNNFormat(in_layout));
out_tensor.ShareDataWith(in_tensor); out_tensor.ShareDataWith(in_tensor);
// For NHWC data we need reshape of tensors as MKL-DNN // For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order // is expecting NHWC dims description order
if (in_layout == DataLayout::kNHWC) { if (in_layout == DataLayout::kNHWC) {
VLOG(4) << "kNHWC"; VLOG(4) << "kNHWC";
phi::funcs::MatchShapeToLayout(&out_tensor, in_layout, out_layout); phi::funcs::MatchShapeToLayout(&out_tensor, in_layout, out_layout);
paddle::platform::MKLDNNDeviceContext::tls() phi::OneDNNContext::tls().set_cur_paddle_data_layout(in_layout);
.set_cur_paddle_data_layout(in_layout);
} }
auto out_tz = phi::vectorize<int64_t>(out_tensor.dims()); auto out_tz = phi::vectorize<int64_t>(out_tensor.dims());
dnnl::memory::data_type in_type = framework::ToMKLDNNDataType( dnnl::memory::data_type in_type =
framework::TransToProtoVarType(in_tensor.dtype())); phi::funcs::ToOneDNNDataType(in_tensor.dtype());
dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format); dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format);
out_tensor.set_mem_desc(out_mem_desc); out_tensor.set_mem_desc(out_mem_desc);
} else { } else {
auto target_layout = paddle::platform::MKLDNNDeviceContext::tls() auto target_layout =
.get_cur_paddle_data_layout(); phi::OneDNNContext::tls().get_cur_paddle_data_layout();
// NOTE(zhiqiu): hot fix, follow the same logic in DataCopy() in // NOTE(zhiqiu): hot fix, follow the same logic in DataCopy() in
// fetch_op.cc // fetch_op.cc
if (out_layout == DataLayout::kNCHW && if (out_layout == DataLayout::kNCHW &&
in_name_ == framework::GradVarName("Filter")) { in_name_ == framework::GradVarName("Filter")) {
target_layout = out_layout; target_layout = out_layout;
} }
VLOG(4) << "innerTransDataLayoutFromMKLDNN: " << in_layout << "->" VLOG(4) << "TransDataLayoutFromOneDNN: " << in_layout << "->"
<< target_layout; << target_layout;
// Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel // Case2 - transform from ONEDNN OPKernel to Non-ONEDNN OPKernel
// Do transform via MKLDNN lib // Do transform via ONEDNN lib
paddle::framework::innerTransDataLayoutFromMKLDNN(in_layout, phi::funcs::TransDataLayoutFromOneDNN(in_layout,
target_layout, target_layout,
in_tensor, in_tensor,
&out_tensor, &out_tensor,
dev_ctx_.GetPlace()); dev_ctx_.GetPlace());
} }
} else { } else {
// Case3 - transform between Non-MKLDNN OPKernels // Case3 - transform between Non-ONEDNN OPKernels
TransDataLayout(dev_ctx_, in_tensor, &out_tensor); TransDataLayout(dev_ctx_, in_tensor, &out_tensor);
} }
#else #else
// Case3 - transform between Non-MKLDNN OPKernels // Case3 - transform between Non-ONEDNN OPKernels
TransDataLayout(dev_ctx_, in_tensor, &out_tensor); TransDataLayout(dev_ctx_, in_tensor, &out_tensor);
#endif #endif
framework::SetTensorToVariable(*in_, out_tensor, out_); framework::SetTensorToVariable(*in_, out_tensor, out_);
......
...@@ -82,8 +82,8 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -82,8 +82,8 @@ class TransposeOp : public framework::OperatorWithKernel {
// Here we need to match dims to paddle layout // Here we need to match dims to paddle layout
// as we are producing non-oneDNN result // as we are producing non-oneDNN result
if (ctx->IsRunMKLDNNKernel() && (x_dims.size() >= 3) && if (ctx->IsRunMKLDNNKernel() && (x_dims.size() >= 3) &&
(paddle::platform::MKLDNNDeviceContext::tls() (phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC)) { phi::DataLayout::kNHWC)) {
auto dims = phi::vectorize<int>(x_dims); auto dims = phi::vectorize<int>(x_dims);
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end()); std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
x_dims = x_dims.reshape(dims); x_dims = x_dims.reshape(dims);
......
...@@ -224,7 +224,7 @@ void EmplaceDeviceContexts( ...@@ -224,7 +224,7 @@ void EmplaceDeviceContexts(
for (auto& p : set) { for (auto& p : set) {
if (platform::is_cpu_place(p)) { if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
EmplaceDeviceContext<MKLDNNDeviceContext>( EmplaceDeviceContext<phi::OneDNNContext>(
place_to_device_context, place_to_device_context,
p, p,
disable_setting_default_stream_for_allocator); disable_setting_default_stream_for_allocator);
......
...@@ -312,11 +312,6 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> { ...@@ -312,11 +312,6 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
}; };
#endif #endif
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNDeviceContextThreadLocals = phi::OneDNNContextThreadLocals;
using MKLDNNDeviceContext = phi::OneDNNContext;
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE #ifdef PADDLE_WITH_CUSTOM_DEVICE
class CustomDeviceContext : public phi::CustomContext { class CustomDeviceContext : public phi::CustomContext {
public: public:
......
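With the `MKLDNNDeviceContext` / `MKLDNNDeviceContextThreadLocals` aliases deleted from device_context.h, call sites fetch the context through `phi::OneDNNContext` directly. A minimal sketch of the pool lookup, mirroring the CacheTester change above; header paths are assumed and the function name is illustrative:

```cpp
#include "paddle/fluid/platform/device_context.h"       // assumed path
#include "paddle/phi/backends/onednn/onednn_context.h"  // assumed path

// Retrieve the CPU oneDNN device context from the global pool and cast it to
// phi::OneDNNContext, as the updated test helper does.
phi::OneDNNContext* GetOneDNNContextSketch() {
  auto& pool = paddle::platform::DeviceContextPool::Instance();
  return dynamic_cast<phi::OneDNNContext*>(
      pool.Get(paddle::platform::CPUPlace()));
}
```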
...@@ -24,27 +24,20 @@ limitations under the License. */ ...@@ -24,27 +24,20 @@ limitations under the License. */
#include "dnnl.hpp" // NOLINT #include "dnnl.hpp" // NOLINT
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/backends/onednn/onednn_helper.h" #include "paddle/phi/backends/onednn/onednn_helper.h"
namespace paddle { namespace paddle {
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
using OneDNNMemoryFormat = dnnl::memory::format_tag; using OneDNNMemoryFormat = dnnl::memory::format_tag;
using phi::OneDNNContext;
#endif #endif
namespace platform { namespace platform {
template <class Type>
using tf_desc = typename Type::desc;
template <class Type>
using tf_pd = typename Type::primitive_desc;
inline void ClearMKLDNNCache(const platform::Place& place, inline void ClearMKLDNNCache(const platform::Place& place,
void* ptr = nullptr) { void* ptr = nullptr) {
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx = OneDNNContext* dev_ctx = reinterpret_cast<OneDNNContext*>(pool.Get(place));
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->ResetBlobMap(ptr); dev_ctx->ResetBlobMap(ptr);
} }
} }
...@@ -53,71 +46,11 @@ inline void DontClearMKLDNNCache(const platform::Place& place) { ...@@ -53,71 +46,11 @@ inline void DontClearMKLDNNCache(const platform::Place& place) {
// Clear mkl-dnn cache, // Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx = OneDNNContext* dev_ctx = reinterpret_cast<OneDNNContext*>(pool.Get(place));
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->BlockNextCacheClearing(); dev_ctx->BlockNextCacheClearing();
} }
} }
inline void Reorder(dnnl::memory src,
dnnl::memory dst,
const dnnl::engine& engine) {
auto reorder_prim = dnnl::reorder(src, dst);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("int_reorder",
platform::TracerEventType::UserDefined,
2,
platform::EventRole::kUniqueOp);
reorder_prim.execute(astream, src, dst);
astream.wait();
}
inline std::string ThreadIDasStr(void) {
return std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
}
template <typename T>
inline void AppendKey(std::string* key, const T& num) {
key->append(std::to_string(num));
}
template <>
inline void AppendKey(std::string* key,
const dnnl::memory::format_tag& format) {
key->append(std::to_string(static_cast<int>(format)));
}
template <>
inline void AppendKey(std::string* key,
const dnnl::memory::data_type& data_type) {
key->append(std::to_string(static_cast<int>(data_type)));
}
template <>
inline void AppendKey(std::string* key, const dnnl::algorithm& algorithm) {
key->append(std::to_string(static_cast<int>(algorithm)));
}
template <>
inline void AppendKey(std::string* key,
const dnnl::normalization_flags& flags) {
key->append(std::to_string(static_cast<int>(flags)));
}
inline void AppendKey(std::string* key, const std::string& str) {
key->append(str);
}
inline void AppendKey(std::string* key, const char* str) { key->append(str); }
template <typename T>
inline void AppendKey(std::string* key, const std::vector<T>& dims) {
for (size_t i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
// If MKLDNN build and CPU place then register suffix in DeviceContext // If MKLDNN build and CPU place then register suffix in DeviceContext
inline void AttachPointerHashToMKLDNNKey(void* ptr, inline void AttachPointerHashToMKLDNNKey(void* ptr,
const platform::Place& place) { const platform::Place& place) {
...@@ -128,49 +61,30 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr, ...@@ -128,49 +61,30 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
static std::mutex static_vars_barrier; static std::mutex static_vars_barrier;
static_vars_barrier.lock(); static_vars_barrier.lock();
static auto first_exec = ptr; static auto first_exec = ptr;
static auto first_thread = ThreadIDasStr(); static auto first_thread = phi::funcs::ThreadIDasStr();
static_vars_barrier.unlock(); static_vars_barrier.unlock();
if (first_exec != ptr) { if (first_exec != ptr) {
paddle::platform::MKLDNNDeviceContext::tls().set_key_suffix( OneDNNContext::tls().set_key_suffix(
"E" + std::to_string(reinterpret_cast<uintptr_t>(ptr))); "E" + std::to_string(reinterpret_cast<uintptr_t>(ptr)));
} }
// Let's register address of current executor // Let's register address of current executor
paddle::platform::MKLDNNDeviceContext::tls().set_curr_exec(ptr); OneDNNContext::tls().set_curr_exec(ptr);
// For first thread // For first thread
if (first_thread == ThreadIDasStr()) { if (first_thread == phi::funcs::ThreadIDasStr()) {
paddle::platform::MKLDNNDeviceContext::tls().disable_tid_in_key(); OneDNNContext::tls().disable_tid_in_key();
} }
} }
} }
template <typename... ArgTypes>
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
ArgTypes&&... args) {
std::string key;
key.reserve(64);
using expand_type = int[];
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
key += paddle::platform::MKLDNNDeviceContext::tls().get_key_suffix();
return key;
}
inline std::string ExtendKeyWithThreadInfoIfNeeded(
const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
return (paddle::platform::MKLDNNDeviceContext::tls().is_tid_used_in_key() ==
true)
? key + "-t:" + ThreadIDasStr()
: key;
}
inline void RegisterModelLayout( inline void RegisterModelLayout(
std::vector<std::unique_ptr<framework::OperatorBase>>& ops, // NOLINT std::vector<std::unique_ptr<framework::OperatorBase>>& ops, // NOLINT
const platform::Place& place) { const platform::Place& place) {
if (platform::is_cpu_place(place)) { if (platform::is_cpu_place(place)) {
// If there is already registered NHWC then quit this call // If there is already registered NHWC then quit this call
// not to overwrite setting with analysis of internal "while" op block // not to overwrite setting with analysis of internal "while" op block
if (platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout() == if (OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC) phi::DataLayout::kNHWC)
return; return;
...@@ -179,7 +93,7 @@ inline void RegisterModelLayout( ...@@ -179,7 +93,7 @@ inline void RegisterModelLayout(
const std::string& attrib_name) -> bool { const std::string& attrib_name) -> bool {
if (op->HasAttr(attrib_name)) { if (op->HasAttr(attrib_name)) {
auto data_format = op->Attr<std::string>(attrib_name); auto data_format = op->Attr<std::string>(attrib_name);
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( OneDNNContext::tls().set_cur_paddle_data_layout(
data_format.compare("NHWC") == 0 ? phi::DataLayout::kNHWC data_format.compare("NHWC") == 0 ? phi::DataLayout::kNHWC
: phi::DataLayout::kNCHW); : phi::DataLayout::kNCHW);
return true; return true;
...@@ -208,8 +122,6 @@ inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) { ...@@ -208,8 +122,6 @@ inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16"; return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";
} }
enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };
} // namespace platform } // namespace platform
inline std::string FindInputNameByVarName(framework::OpDesc* op, inline std::string FindInputNameByVarName(framework::OpDesc* op,
......
...@@ -284,5 +284,7 @@ inline std::string ExtendKeyWithThreadInfoIfNeeded(const OneDNNContext& dev_ctx, ...@@ -284,5 +284,7 @@ inline std::string ExtendKeyWithThreadInfoIfNeeded(const OneDNNContext& dev_ctx,
: key; : key;
} }
enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };
} // namespace funcs } // namespace funcs
} // namespace phi } // namespace phi
...@@ -48,16 +48,16 @@ void* GetDataFromTensor(const DenseTensor& tensor, ...@@ -48,16 +48,16 @@ void* GetDataFromTensor(const DenseTensor& tensor,
case dnnl::memory::data_type::bf16: case dnnl::memory::data_type::bf16:
return to_void_cast(tensor.data<dtype::bfloat16>()); return to_void_cast(tensor.data<dtype::bfloat16>());
default: default:
PADDLE_THROW(errors::InvalidArgument("Wrong mkldnn type provided.")); PADDLE_THROW(errors::InvalidArgument("Wrong oneDNN type provided."));
} }
} }
void innerTransDataLayoutFromOneDNN(DataLayout in_layout, void TransDataLayoutFromOneDNN(DataLayout in_layout,
DataLayout out_layout, DataLayout out_layout,
const DenseTensor& in, const DenseTensor& in,
DenseTensor* out, DenseTensor* out,
Place place, Place place,
bool always_copy) { bool always_copy) {
// Set default as NCHW in case not specified // Set default as NCHW in case not specified
out_layout = out_layout == DataLayout::ANY ? DataLayout::NCHW : out_layout; out_layout = out_layout == DataLayout::ANY ? DataLayout::NCHW : out_layout;
......
...@@ -43,7 +43,7 @@ inline OneDNNMemoryFormat ToOneDNNFormat(const DataLayout& layout) { ...@@ -43,7 +43,7 @@ inline OneDNNMemoryFormat ToOneDNNFormat(const DataLayout& layout) {
return OneDNNMemoryFormat::ndhwc; return OneDNNMemoryFormat::ndhwc;
default: default:
PADDLE_THROW( PADDLE_THROW(
errors::InvalidArgument("Fail to convert layout %s to MKLDNN format.", errors::InvalidArgument("Fail to convert layout %s to oneDNN format.",
::phi::DataLayoutToString(layout))); ::phi::DataLayoutToString(layout)));
} }
} }
...@@ -77,12 +77,12 @@ inline OneDNNDataType ToOneDNNDataType(DataType type) { ...@@ -77,12 +77,12 @@ inline OneDNNDataType ToOneDNNDataType(DataType type) {
return OneDNNDataType::undef; return OneDNNDataType::undef;
} }
void innerTransDataLayoutFromOneDNN(DataLayout in_layout, void TransDataLayoutFromOneDNN(DataLayout in_layout,
DataLayout out_layout, DataLayout out_layout,
const DenseTensor& in, const DenseTensor& in,
DenseTensor* out, DenseTensor* out,
Place place, Place place,
bool always_copy = false); bool always_copy = false);
void* GetDataFromTensor(const DenseTensor& tensor, OneDNNDataType type); void* GetDataFromTensor(const DenseTensor& tensor, OneDNNDataType type);
#endif #endif
......
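The declaration above drops the `inner` prefix, so callers now invoke `phi::funcs::TransDataLayoutFromOneDNN` directly. A hedged usage sketch matching that signature, converting an oneDNN-layout tensor back to NCHW; the header path and wrapper name are assumptions, and `always_copy` keeps its default of false:

```cpp
#include "paddle/phi/kernels/funcs/data_layout_transform.h"  // assumed header path

// Convert a tensor that carries oneDNN layout metadata back to plain NCHW,
// e.g. before handing it to a non-oneDNN kernel.
void OneDNNToNCHWSketch(const phi::DenseTensor& in,
                        phi::DenseTensor* out,
                        const phi::Place& place) {
  phi::funcs::TransDataLayoutFromOneDNN(
      phi::DataLayout::ONEDNN, phi::DataLayout::NCHW, in, out, place);
}
```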
...@@ -130,7 +130,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx, ...@@ -130,7 +130,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
dst_layout != DataLayout::ONEDNN) { dst_layout != DataLayout::ONEDNN) {
// Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel // Case2 - transform from MKLDNN OPKernel to Non-MKLDNN OPKernel
// Do transform via MKLDNN lib // Do transform via MKLDNN lib
funcs::innerTransDataLayoutFromOneDNN( funcs::TransDataLayoutFromOneDNN(
src_layout, dst_layout, x, out, dev_ctx.GetPlace()); src_layout, dst_layout, x, out, dev_ctx.GetPlace());
} else if (src_layout == DataLayout::ONEDNN && } else if (src_layout == DataLayout::ONEDNN &&
dst_layout == DataLayout::ONEDNN) { dst_layout == DataLayout::ONEDNN) {
......