未验证 提交 8a339d24 编写于 作者: S Sławomir Siwek 提交者: GitHub

mkldnn directory cleanup (#47779)

* cleanup unused code

* unify is_int8 is_bfloat16

* Simplify matmul_v2 FWD kernel

* remove RunKernel methods

* remove import namespace

* remove headers

* clean fluid/phi cross imports

* remove fluid axpy_handler

* delete fluid methods

* activations

* OneDNNMemDesc

* MKLDNNFormatForSize

* MatchShapeToLayout

* MKLDNNMemoryFormat

* MKLDNNFormat

* ReorderMKLDNNHandler

* to_void_cast

* review suggestions

* interpolate

* remove fluid dependency
上级 4e09b089
......@@ -101,15 +101,16 @@ void* GetDataFromTensor(const phi::DenseTensor& tensor,
dnnl::memory::data_type type) {
switch (type) {
case dnnl::memory::data_type::f32:
return platform::to_void_cast(tensor.data<float>());
return phi::funcs::to_void_cast(tensor.data<float>());
case dnnl::memory::data_type::s8:
return platform::to_void_cast(tensor.data<int8_t>());
return phi::funcs::to_void_cast(tensor.data<int8_t>());
case dnnl::memory::data_type::u8:
return platform::to_void_cast(tensor.data<unsigned char>());
return phi::funcs::to_void_cast(tensor.data<unsigned char>());
case dnnl::memory::data_type::s32:
return platform::to_void_cast(tensor.data<int32_t>());
return phi::funcs::to_void_cast(tensor.data<int32_t>());
case dnnl::memory::data_type::bf16:
return platform::to_void_cast(tensor.data<paddle::platform::bfloat16>());
return phi::funcs::to_void_cast(
tensor.data<paddle::platform::bfloat16>());
default:
PADDLE_THROW(
platform::errors::InvalidArgument("Wrong mkldnn type provided."));
......@@ -125,7 +126,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
auto place = expected_kernel_type.place_;
PADDLE_ENFORCE(
in_layout == DataLayout::kMKLDNN && out_layout != DataLayout::kMKLDNN,
in_layout == DataLayout::ONEDNN && out_layout != DataLayout::ONEDNN,
platform::errors::InvalidArgument(
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"));
......@@ -165,7 +166,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
DataTypeToString(framework::TransToProtoVarType(in.dtype()))));
auto out_format =
platform::MKLDNNFormatForSize(in_tz.size(), ToMKLDNNFormat(out_layout));
phi::funcs::OneDNNFormatForSize(in_tz.size(), ToOneDNNFormat(out_layout));
dnnl::memory::desc out_mem_desc(out_tz, in_type, out_format);
// output tensor has the same dims as input. Reorder doesn't change dims
......@@ -177,8 +178,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
if (in.initialized() && ((in.mem_desc() != out->mem_desc()) || always_copy)) {
void* in_data = GetDataFromTensor(in, in_type);
platform::ReorderMKLDNNHandler handler(
in_tz, framework::TransToProtoVarType(in.dtype()), in_type, cpu_engine);
phi::funcs::ReorderOneDNNHandler handler(
in_tz, in.dtype(), in_type, cpu_engine);
auto reorder_src_memory_p =
handler.AcquireSrcMemory(in.mem_desc(), in_data);
......@@ -199,7 +200,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
}
// For the expected NHWC data format we need to reshape the output tensor,
// as the MKL-DNN description was in NCHW and Paddle is expecting NHWC
platform::MatchShapeToLayout(out, in_layout, out_layout);
phi::funcs::MatchShapeToLayout(out, in_layout, out_layout);
out->set_layout(DataLayout::kNCHW);
}
......
......@@ -52,51 +52,35 @@ struct CastDataLayout {
};
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNDataType = dnnl::memory::data_type;
using OneDNNDataType = dnnl::memory::data_type;
inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
inline OneDNNMemoryFormat ToOneDNNFormat(const DataLayout& layout) {
switch (layout) {
case DataLayout::kNHWC:
return MKLDNNMemoryFormat::nhwc;
return OneDNNMemoryFormat::nhwc;
case DataLayout::kNCHW:
return MKLDNNMemoryFormat::nchw;
return OneDNNMemoryFormat::nchw;
case DataLayout::kNCDHW:
return MKLDNNMemoryFormat::ncdhw;
return OneDNNMemoryFormat::ncdhw;
case DataLayout::kNDHWC:
return MKLDNNMemoryFormat::ndhwc;
return OneDNNMemoryFormat::ndhwc;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Fail to convert layout %s to MKLDNN format.",
"Fail to convert layout %s to oneDNN format.",
phi::DataLayoutToString(layout)));
}
}
inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) {
switch (format) {
case MKLDNNMemoryFormat::nhwc:
return DataLayout::kNHWC;
case MKLDNNMemoryFormat::nchw:
return DataLayout::kNCHW;
case MKLDNNMemoryFormat::ncdhw:
return DataLayout::kNCDHW;
case MKLDNNMemoryFormat::ndhwc:
return DataLayout::kNDHWC;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Fail to convert MKLDNN format to paddle layout."));
}
}
inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static std::unordered_map<int, MKLDNNDataType> dict{
{DataTypeTrait<float>::DataType(), MKLDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType(), MKLDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType(), MKLDNNDataType::u8},
{DataTypeTrait<int32_t>::DataType(), MKLDNNDataType::s32},
{DataTypeTrait<platform::bfloat16>::DataType(), MKLDNNDataType::bf16}};
inline OneDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
static std::unordered_map<int, OneDNNDataType> dict{
{DataTypeTrait<float>::DataType(), OneDNNDataType::f32},
{DataTypeTrait<int8_t>::DataType(), OneDNNDataType::s8},
{DataTypeTrait<uint8_t>::DataType(), OneDNNDataType::u8},
{DataTypeTrait<int32_t>::DataType(), OneDNNDataType::s32},
{DataTypeTrait<platform::bfloat16>::DataType(), OneDNNDataType::bf16}};
auto iter = dict.find(static_cast<int>(type));
if (iter != dict.end()) return iter->second;
return MKLDNNDataType::undef;
return OneDNNDataType::undef;
}
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout,
......@@ -111,7 +95,7 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const phi::DenseTensor& in,
phi::DenseTensor* out);
void* GetDataFromTensor(const phi::DenseTensor& tensor, MKLDNNDataType type);
void* GetDataFromTensor(const phi::DenseTensor& tensor, OneDNNDataType type);
#endif
......
......@@ -54,9 +54,8 @@ TEST(DataTransformBf16, GetDataFromTensorDNNL) {
void* in_data =
paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::bf16);
EXPECT_EQ(
in_data,
paddle::platform::to_void_cast(in.data<paddle::platform::bfloat16>()));
EXPECT_EQ(in_data,
phi::funcs::to_void_cast(in.data<paddle::platform::bfloat16>()));
}
TEST(DataTransformInt32, GetDataFromTensorDNNL) {
......@@ -66,6 +65,6 @@ TEST(DataTransformInt32, GetDataFromTensorDNNL) {
void* in_data =
paddle::framework::GetDataFromTensor(in, dnnl::memory::data_type::s32);
EXPECT_EQ(in_data, paddle::platform::to_void_cast(in.data<int32_t>()));
EXPECT_EQ(in_data, phi::funcs::to_void_cast(in.data<int32_t>()));
}
#endif
......@@ -49,24 +49,24 @@ void TransformData(const OpKernelType &expected_kernel_type,
// do layout transform
if (NeedTransformLayout(lout, lin)) {
#ifdef PADDLE_WITH_MKLDNN
if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) {
if (lin == DataLayout::ONEDNN || lout == DataLayout::ONEDNN) {
PADDLE_ENFORCE_EQ(
!(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN),
!(lin == DataLayout::ONEDNN && lout == DataLayout::ONEDNN),
true,
platform::errors::PreconditionNotMet(
"No layout transform needed between two MKLDNN OPKernels."));
"No layout transform needed between two oneDNN OPKernels."));
if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) {
if (lin != DataLayout::ONEDNN && lout == DataLayout::ONEDNN) {
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occurs
auto out_format = platform::MKLDNNFormatForSize(in.dims().size(),
ToMKLDNNFormat(lin));
auto out_format = phi::funcs::OneDNNFormatForSize(in.dims().size(),
ToOneDNNFormat(lin));
out.ShareDataWith(input_tensor);
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
if (lin == DataLayout::kNHWC || lin == DataLayout::kNDHWC) {
platform::MatchShapeToLayout(&out, lin, lout);
phi::funcs::MatchShapeToLayout(&out, lin, lout);
// We register only NHWC assuming that model is consistent e.g. either
// NHWC or NCHW
paddle::platform::MKLDNNDeviceContext::tls()
......
......@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void ConvActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = paddle::platform::GetSupportedActivations();
auto act_types = phi::funcs::GetSupportedActivations();
std::vector<std::string> conv_types = {"conv2d"};
for (auto& act_type : act_types) {
......@@ -64,7 +64,7 @@ void ConvActivationMkldnnFusePass::FuseConvAct(Graph* graph,
OpDesc* conv_op = conv->Op();
OpDesc* act_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
......@@ -145,7 +145,7 @@ void ConvActivationMkldnnFusePass::FuseConvConcatAct(
OpDesc* conv_op = node->inputs[0]->Op();
OpDesc* act_op = activation_op->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
conv_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
......
......@@ -27,7 +27,7 @@ namespace ir {
using string::PrettyLogDetail;
void ElementwiseActivationOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = paddle::platform::GetSupportedActivations();
auto act_types = phi::funcs::GetSupportedActivations();
std::vector<std::string> elt_types = {
"elementwise_add", "elementwise_sub", "elementwise_mul"};
......@@ -76,7 +76,7 @@ void ElementwiseActivationOneDNNPass::FuseElementwiseAct(
}
auto *activation_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
elementwise_op->SetAttr(attr.second,
......
......@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = paddle::platform::GetSupportedActivations();
auto act_types = phi::funcs::GetSupportedActivations();
for (auto act_type : act_types) FuseFCAct(graph, act_type);
}
......@@ -61,7 +61,7 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
"is used."));
}
auto attr_map = paddle::platform::GetAttributeMap(act_type);
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (act_op->HasAttr(attr.first)) {
fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
......
......@@ -31,7 +31,7 @@ namespace ir {
class Graph;
void InterpolateMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
void InterpolateOneDNNPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
......@@ -70,4 +70,4 @@ void InterpolateMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
} // namespace paddle
REGISTER_PASS(interpolate_mkldnn_pass,
paddle::framework::ir::InterpolateMKLDNNPass);
paddle::framework::ir::InterpolateOneDNNPass);
......@@ -28,9 +28,9 @@ namespace ir {
*/
class Graph;
class InterpolateMKLDNNPass : public FusePassBase {
class InterpolateOneDNNPass : public FusePassBase {
public:
virtual ~InterpolateMKLDNNPass() {}
virtual ~InterpolateOneDNNPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
......
......@@ -25,7 +25,7 @@ namespace ir {
using string::PrettyLogDetail;
void MatmulActivationMkldnnFusePass::ApplyImpl(Graph* graph) const {
auto act_types = paddle::platform::GetSupportedActivations();
auto act_types = phi::funcs::GetSupportedActivations();
auto matmul_types = {"matmul", "matmul_v2"};
for (const auto& matmul_type : matmul_types)
......@@ -64,7 +64,7 @@ void MatmulActivationMkldnnFusePass::FuseMatmulAct(
OpDesc* matmul_op = matmul->Op();
OpDesc* act_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto& attrs : attr_map) {
if (act_op->HasAttr(attrs.first)) {
matmul_op->SetAttr(attrs.second, act_op->GetAttr(attrs.first));
......
......@@ -27,7 +27,7 @@ namespace ir {
using string::PrettyLogDetail;
void SoftplusActivationOneDNNPass::ApplyImpl(Graph *graph) const {
auto act_types = paddle::platform::GetSupportedActivations();
auto act_types = phi::funcs::GetSupportedActivations();
// Currently softplus can't be fused with hard_sigmoid
act_types.erase(
......@@ -75,7 +75,7 @@ void SoftplusActivationOneDNNPass::FuseSoftplusActivation(
}
auto *activation_op = activation->Op();
auto attr_map = paddle::platform::GetAttributeMap(act_type);
auto attr_map = phi::funcs::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (activation_op->HasAttr(attr.first)) {
softplus_op->SetAttr(attr.second, activation_op->GetAttr(attr.first));
......
......@@ -230,7 +230,7 @@ std::shared_ptr<OperatorBase> TransferLayout(const std::string& var_name,
#ifdef PADDLE_WITH_MKLDNN
// NOTE(zhiqiu): hot fix, follow the same logic in DataCopy() in fetch_op.cc
if (in_layout == phi::DataLayout::kMKLDNN &&
if (in_layout == phi::DataLayout::ONEDNN &&
var_name == framework::GradVarName("Filter") && is_fetch_v2) {
VLOG(4) << "Match special case(Filter && fetch_v2) " << var_name;
out_layout = phi::DataLayout::kNCHW;
......@@ -484,9 +484,9 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
// MKL-DNN shape of Var may differ from kNHWC Var
// In such a situation the corresponding resized Var
// has to be created and registered
if ((tensor_in->layout() == DataLayout::kMKLDNN) &&
if ((tensor_in->layout() == DataLayout::ONEDNN) &&
(var->IsType<phi::DenseTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
(expected_kernel_key.data_layout_ != DataLayout::ONEDNN) &&
(paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == DataLayout::kNHWC)) {
VLOG(7) << "Created reshaped dummy input based on MKL-DNN "
......
......@@ -244,7 +244,7 @@ void InterpretercoreInferShapeContext::ShareAllLoD(
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::kMKLDNN)
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
......@@ -309,7 +309,7 @@ void InterpretercoreInferShapeContext::ShareLoD(const std::string& in,
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::kMKLDNN)
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
......@@ -338,7 +338,7 @@ bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ ==
phi::DataLayout::kMKLDNN));
phi::DataLayout::ONEDNN));
} catch (std::bad_cast& exp) {
return false;
}
......
......@@ -102,8 +102,8 @@ inline bool NeedTransformLayout(const DataLayout& l, const DataLayout& r) {
(l != DataLayout::kAnyLayout && r != DataLayout::kAnyLayout && l != r);
#ifdef PADDLE_WITH_MKLDNN
// Layout transform needed for either non-MKLDNN to MKLDNN or vice versa
ret |= (l != DataLayout::kMKLDNN && r == DataLayout::kMKLDNN);
ret |= (l == DataLayout::kMKLDNN && r != DataLayout::kMKLDNN);
ret |= (l != DataLayout::ONEDNN && r == DataLayout::ONEDNN);
ret |= (l == DataLayout::ONEDNN && r != DataLayout::ONEDNN);
#endif
return ret;
}
......
......@@ -913,7 +913,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto* out_tensor = out_var->GetMutable<phi::DenseTensor>();
out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
if (in_tensor.layout() != DataLayout::kMKLDNN)
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
......@@ -978,7 +978,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
// This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
// OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
// in Compute()
if (in_tensor.layout() != DataLayout::kMKLDNN)
if (in_tensor.layout() != DataLayout::ONEDNN)
#endif
out_tensor->set_layout(in_tensor.layout());
}
......@@ -1006,7 +1006,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
return ((op_with_kernel.kernel_type()) &&
(op_with_kernel.kernel_type()->data_layout_ ==
phi::DataLayout::kMKLDNN));
phi::DataLayout::ONEDNN));
} catch (const std::bad_cast& exp) {
return false;
}
......@@ -1441,7 +1441,7 @@ bool OperatorWithKernel::SupportsKernelType(
this->CanMKLDNNBeUsed(exe_ctx, kernel_type.data_type_)) {
auto tmp_kernel_type = kernel_type;
tmp_kernel_type.library_type_ = framework::LibraryType::kMKLDNN;
tmp_kernel_type.data_layout_ = framework::DataLayout::kMKLDNN;
tmp_kernel_type.data_layout_ = framework::DataLayout::ONEDNN;
return kernels.find(tmp_kernel_type) != kernels.end();
}
#endif
......@@ -1637,7 +1637,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
phi_kernel_name = kernel_signature_->name;
// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are three statements in if condition:
// 1. Whether mkldnn kernel fallbacks to plain kernel;
......@@ -1648,7 +1648,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
!paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(exe_ctx, kernel_type_->data_type_)) {
kernel_type_->library_type_ = framework::LibraryType::kMKLDNN;
kernel_type_->data_layout_ = framework::DataLayout::kMKLDNN;
kernel_type_->data_layout_ = framework::DataLayout::ONEDNN;
}
#endif
......@@ -1897,7 +1897,7 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
if (!this->DnnFallback() && !paddle::platform::in_mkldnn_white_list(type_) &&
this->CanMKLDNNBeUsed(ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::ONEDNN;
}
#endif
......@@ -2295,16 +2295,16 @@ Scope* OperatorWithKernel::PrepareData(
// Var without buffer may be needed
// for some situation like InferShape().
// In this situation We cannot skip Var analysis, as
// MKL-DNN shape of Var may differ from kNHWC Var
// oneDNN shape of Var may differ from kNHWC Var
// In such a situation the corresponding resized Var
// has to be created and registered
if ((tensor_in->layout() == DataLayout::kMKLDNN) &&
if ((tensor_in->layout() == DataLayout::ONEDNN) &&
(var->IsType<phi::DenseTensor>() == true) &&
(expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
(expected_kernel_key.data_layout_ != DataLayout::ONEDNN) &&
(paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == DataLayout::kNHWC) &&
(tensor_in->dims().size() >= 3)) {
// Mixed execution : MKL-DNN and GPU is not supported!
// Mixed execution : oneDNN and GPU is not supported!
if (!new_scope) {
new_scope = &scope.NewScope();
}
......@@ -2312,9 +2312,9 @@ Scope* OperatorWithKernel::PrepareData(
in_vars->at(i) = trans_var;
auto out = trans_var->GetMutable<phi::DenseTensor>();
out->Resize(tensor_in->dims());
platform::MatchShapeToLayout(
phi::funcs::MatchShapeToLayout(
out, tensor_in->layout(), DataLayout::kNHWC);
VLOG(7) << "Created reshaped dummy input based on MKL-DNN "
VLOG(7) << "Created reshaped dummy input based on oneDNN "
"phi::DenseTensor , "
"but kNHWC layout"
<< in_name << " in Operator " << type_;
......@@ -2752,8 +2752,8 @@ OpKernelType OperatorWithKernel::GetKernelTypeForVar(
// When the op is first oneDNN op (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCHW
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN) &&
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
......
......@@ -70,7 +70,7 @@ TEST(PhiUtils, TransOpKernelTypeToPhiKernelKey) {
paddle::framework::OpKernelType op_kernel_type_mkldnn(
paddle::framework::proto::VarType::FP32,
paddle::platform::CPUPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
paddle::framework::LibraryType::kMKLDNN);
auto kernel_key_mkldnn =
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_mkldnn);
......
......@@ -57,7 +57,7 @@ void TensorCopyImpl(const TENSOR& src,
// oneDNN tensors due to padding may be of bigger size
// than numel()*size(type())
auto dst_ptr =
src.layout() == DataLayout::kMKLDNN
src.layout() == DataLayout::ONEDNN
? dst->mutable_data(dst_place, src.dtype(), src.memory_size())
: dst->mutable_data(dst_place, src.dtype());
#else
......@@ -72,7 +72,7 @@ void TensorCopyImpl(const TENSOR& src,
VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr;
#ifdef PADDLE_WITH_MKLDNN
auto size = src.layout() == DataLayout::kMKLDNN
auto size = src.layout() == DataLayout::ONEDNN
? src.memory_size()
: src.numel() * framework::DataTypeSize(src.dtype());
#else
......@@ -471,7 +471,7 @@ void TensorCopySync(const phi::DenseTensor& src,
dst->Resize(src.dims());
dst->set_layout(src.layout());
#ifdef PADDLE_WITH_MKLDNN
if (src.layout() == DataLayout::kMKLDNN) {
if (src.layout() == DataLayout::ONEDNN) {
dst->set_mem_desc(src.mem_desc());
}
#endif
......
......@@ -251,7 +251,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
bool IsRunMKLDNNKernel() const override {
return (op_kernel_type_ &&
(op_kernel_type_->data_layout_ == phi::DataLayout::kMKLDNN));
(op_kernel_type_->data_layout_ == phi::DataLayout::ONEDNN));
}
paddle::small_vector<framework::InferShapeVarPtr, phi::kInputSmallVectorSize>
......
......@@ -232,7 +232,7 @@ PreparedOp PrepareImpl(
std::string phi_kernel_name;
// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are three statements in if condition:
// 1. Whether mkldnn kernel fallbacks to plain kernel;
......@@ -242,7 +242,7 @@ PreparedOp PrepareImpl(
if (!op.DnnFallback() && !paddle::platform::in_mkldnn_white_list(op.Type()) &&
op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) {
expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
expected_kernel_key.data_layout_ = framework::DataLayout::ONEDNN;
}
#endif
......
......@@ -377,7 +377,7 @@ void Tensor::CopyToCpuImpl(T *data,
if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == phi::DataLayout::kMKLDNN)
if (tensor->layout() == phi::DataLayout::ONEDNN)
paddle::framework::innerTransDataLayoutFromMKLDNN(
tensor->layout(),
paddle::platform::MKLDNNDeviceContext::tls()
......@@ -664,7 +664,7 @@ std::vector<int> Tensor::shape() const {
// mkldnn may do a layout transform internally, so we need to reorder before
// returning
#ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == phi::DataLayout::kMKLDNN) {
if (tensor->layout() == phi::DataLayout::ONEDNN) {
phi::DataLayout out_layout = paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout();
// Set default as NCHW in case not specified
......@@ -852,7 +852,7 @@ void InternalUtils::CopyToCpuWithIoStream(paddle_infer::Tensor *t,
if (paddle::platform::is_cpu_place(t_place)) {
#ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == phi::DataLayout::kMKLDNN)
if (tensor->layout() == phi::DataLayout::ONEDNN)
paddle::framework::innerTransDataLayoutFromMKLDNN(
tensor->layout(),
paddle::platform::MKLDNNDeviceContext::tls()
......
......@@ -26,9 +26,6 @@ add_subdirectory(sequence_ops)
add_subdirectory(string)
add_subdirectory(jit)
add_subdirectory(prim_ops)
if(WITH_MKLDNN)
add_subdirectory(mkldnn)
endif()
if(WITH_DISTRIBUTE)
......
......@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/infermeta/backward.h"
......
......@@ -211,8 +211,8 @@ framework::OpKernelType BatchNormOp::GetKernelTypeForVar(
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "X") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_layout = ar.Get<std::string>("data_layout");
......@@ -401,8 +401,8 @@ framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if (((var_name == "X") || (var_name == framework::GradVarName("Y"))) &&
(expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_layout = ar.Get<std::string>("data_layout");
......
......@@ -29,7 +29,7 @@ static void DataCopy(const phi::DenseTensor &src_item,
if (src_item.IsInitialized() && src_item.numel() > 0) {
#ifdef PADDLE_WITH_MKLDNN
// Conversion from MKL-DNN to Paddle
if (src_item.layout() == phi::DataLayout::kMKLDNN) {
if (src_item.layout() == phi::DataLayout::ONEDNN) {
phi::DenseTensor out;
// Convert to desired Paddle layout, apart from grads of filter
// as params are not a subject to paddle's data_format
......
......@@ -37,7 +37,7 @@ static void DeepCopy(const phi::DenseTensor &src_item,
if (src_item.IsInitialized() && src_item.numel() > 0) {
#ifdef PADDLE_WITH_MKLDNN
// Conversion from MKL-DNN to Paddle
if (src_item.layout() == phi::DataLayout::kMKLDNN) {
if (src_item.layout() == phi::DataLayout::ONEDNN) {
phi::DenseTensor out;
// Convert to desired Paddle layout, apart from grads of filter
// as params are not a subject to paddle's data_format
......
......@@ -220,8 +220,8 @@ framework::OpKernelType ConvOp::GetKernelTypeForVar(
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......@@ -470,8 +470,8 @@ framework::OpKernelType ConvOpGrad::GetKernelTypeForVar(
// bias are having shape in NCHW order
if (((var_name == "Input") ||
(var_name == framework::GradVarName("Output"))) &&
(expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......
......@@ -48,8 +48,8 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
(expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......
......@@ -26,7 +26,7 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
......
......@@ -174,8 +174,8 @@ class ElementwiseOp : public framework::OperatorWithKernel {
// When elementwise is first oneDNN op (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCHW
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN) &&
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
......
......@@ -28,6 +28,7 @@ using dnnl::memory;
using dnnl::primitive;
using dnnl::stream;
using phi::DataLayout;
using phi::funcs::BinaryOneDNNHandler;
inline std::vector<int64_t> CalculateBroadcastedDims(
const phi::DenseTensor* x, const phi::DenseTensor* y) {
......@@ -51,11 +52,12 @@ inline std::vector<int64_t> CalculateBroadcastedDims(
return dst_tz_ex;
}
inline void AddSubNonBroadcast(platform::ReorderMKLDNNHandler* reorder_handler,
phi::DenseTensor* grad_tensor,
const std::shared_ptr<dnnl::memory>& src_memory,
const std::shared_ptr<dnnl::memory>& dst_memory,
const std::vector<float>& scales) {
inline void AddSubNonBroadcast(
phi::funcs::ReorderOneDNNHandler* reorder_handler,
phi::DenseTensor* grad_tensor,
const std::shared_ptr<dnnl::memory>& src_memory,
const std::shared_ptr<dnnl::memory>& dst_memory,
const std::vector<float>& scales) {
dnnl::primitive_attr reorder_attr;
reorder_attr.set_output_scales(0, scales);
auto reorder_p =
......@@ -84,7 +86,7 @@ inline void BroadcastReduction(const framework::ExecutionContext& ctx,
broadcast_reduction_attr.set_post_ops(po);
}
platform::ReductionMKLDNNHandler<T> reduction_handler(
phi::funcs::ReductionOneDNNHandler<T> reduction_handler(
dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
......@@ -132,18 +134,18 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
float scale_o = ctx.Attr<float>("Scale_out");
int axis = ctx.Attr<int>("axis");
platform::BinaryMKLDNNHandler<T> handler(BINARY_OP,
axis,
mkldnn_engine,
ctx.GetPlace(),
x,
y,
z,
scale_x,
scale_y,
scale_o,
true,
get_post_ops(ctx));
BinaryOneDNNHandler<T> handler(BINARY_OP,
axis,
mkldnn_engine,
ctx.GetPlace(),
x,
y,
z,
scale_x,
scale_y,
scale_o,
true,
get_post_ops(ctx));
// oneDNN's binary is optimized for broadcasting y into x, so in other case
// we have to swap tensors to achieve optimal performance
......@@ -239,16 +241,13 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
int axis = ctx.Attr<int>("axis");
auto tz = phi::vectorize<int64_t>(dout->dims());
auto proto_type_dout = framework::TransToProtoVarType(dout->dtype());
auto dout_type = phi::funcs::ToOneDNNDataType(dout->dtype());
platform::ReorderMKLDNNHandler reorder_handler(
tz,
proto_type_dout,
framework::ToMKLDNNDataType(proto_type_dout),
onednn_engine);
phi::funcs::ReorderOneDNNHandler reorder_handler(
tz, dout->dtype(), dout_type, onednn_engine);
auto reorder_src_memory = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
dout->mem_desc(), phi::funcs::to_void_cast(dout->data<T>()));
std::shared_ptr<dnnl::memory> dst_memory;
std::shared_ptr<dnnl::memory> broadcast_src_memory = reorder_src_memory;
......@@ -265,17 +264,17 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
&reorder_handler, dx, reorder_src_memory, dst_memory, scales);
}
} else { // elementwise_mul & elementwise_div
platform::BinaryMKLDNNHandler<T> binary_handler(BINARY_OP,
axis,
onednn_engine,
ctx.GetPlace(),
dout,
y,
dx,
1.0f,
1.0f,
1.0f,
false);
BinaryOneDNNHandler<T> binary_handler(BINARY_OP,
axis,
onednn_engine,
ctx.GetPlace(),
dout,
y,
dx,
1.0f,
1.0f,
1.0f,
false);
const auto src_dout_memory = binary_handler.AcquireSrcMemory(dout);
const auto src_y_memory = binary_handler.AcquireSecondSrcMemory(y);
......@@ -323,23 +322,22 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
std::shared_ptr<dnnl::memory> src_0_memory;
std::shared_ptr<dnnl::memory> src_1_memory;
platform::BinaryMKLDNNHandler<T> binary_handler(
dnnl::algorithm::binary_mul,
axis,
onednn_engine,
ctx.GetPlace(),
dout,
x,
nullptr,
1.0f,
1.0f,
1.0f,
false);
BinaryOneDNNHandler<T> binary_handler(dnnl::algorithm::binary_mul,
axis,
onednn_engine,
ctx.GetPlace(),
dout,
x,
nullptr,
1.0f,
1.0f,
1.0f,
false);
src_1_memory = binary_handler.AcquireSecondSrcMemory(x);
if (BINARY_OP == dnnl::algorithm::binary_div) {
platform::BinaryMKLDNNHandler<T> post_op_binary_handler(
BinaryOneDNNHandler<T> post_op_binary_handler(
dnnl::algorithm::binary_div,
axis,
onednn_engine,
......@@ -358,19 +356,18 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
po.append_binary(dnnl::algorithm::binary_div,
post_op_memory->get_desc());
binary_handler =
platform::BinaryMKLDNNHandler<T>(dnnl::algorithm::binary_mul,
axis,
onednn_engine,
ctx.GetPlace(),
dout,
out,
nullptr,
-1.0f,
1.0f,
1.0f,
false,
po);
binary_handler = BinaryOneDNNHandler<T>(dnnl::algorithm::binary_mul,
axis,
onednn_engine,
ctx.GetPlace(),
dout,
out,
nullptr,
-1.0f,
1.0f,
1.0f,
false,
po);
src_1_memory = binary_handler.AcquireSecondSrcMemory(out);
}
......
......@@ -20,10 +20,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::MKLDNNMemDesc;
using phi::CPUContext;
using platform::to_void_cast;
using phi::funcs::OneDNNGetDataType;
using phi::funcs::OneDNNMemDesc;
template <typename T, typename T_out = T>
class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
......@@ -73,7 +71,7 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
// Weights for int8 kernel are of a type s8
const auto weights_dt =
is_INT8 ? dnnl::memory::data_type::s8 : MKLDNNGetDataType<T>();
is_INT8 ? dnnl::memory::data_type::s8 : OneDNNGetDataType<T>();
// oneDNN RNN dimensions
const int64_t D = 1; // Directions
......@@ -81,18 +79,18 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
const int64_t G = 3; // Number of Gates, 3 for GRU
// Create memory descriptors
auto input_md = MKLDNNMemDesc(
{Ti, N, IC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::ntc);
auto input_md = OneDNNMemDesc(
{Ti, N, IC}, OneDNNGetDataType<T>(), OneDNNMemoryFormat::ntc);
auto weight_x_md =
MKLDNNMemDesc({L, D, IC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
OneDNNMemDesc({L, D, IC, G, OC}, weights_dt, OneDNNMemoryFormat::any);
auto weight_h_md =
MKLDNNMemDesc({L, D, OC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
auto bias_md = MKLDNNMemDesc(
{L, D, G, OC}, MKLDNNGetDataType<float>(), MKLDNNMemoryFormat::ldgo);
auto hidden_md = MKLDNNMemDesc(
{Ti, N, OC}, MKLDNNGetDataType<T_out>(), MKLDNNMemoryFormat::ntc);
auto h0_md = MKLDNNMemDesc(
{L, D, N, OC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::ldnc);
OneDNNMemDesc({L, D, OC, G, OC}, weights_dt, OneDNNMemoryFormat::any);
auto bias_md = OneDNNMemDesc(
{L, D, G, OC}, OneDNNGetDataType<float>(), OneDNNMemoryFormat::ldgo);
auto hidden_md = OneDNNMemDesc(
{Ti, N, OC}, OneDNNGetDataType<T_out>(), OneDNNMemoryFormat::ntc);
auto h0_md = OneDNNMemDesc(
{L, D, N, OC}, OneDNNGetDataType<T>(), OneDNNMemoryFormat::ldnc);
// Create GRU oneDNN primitive
const auto direction =
......@@ -121,9 +119,9 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(wx_key));
if (!memory_p) {
auto user_md = MKLDNNMemDesc({1, 1, this->IC, this->G, this->OC},
MKLDNNGetDataType<U>(),
MKLDNNMemoryFormat::ldigo);
auto user_md = OneDNNMemDesc({1, 1, this->IC, this->G, this->OC},
OneDNNGetDataType<U>(),
OneDNNMemoryFormat::ldigo);
auto user_memory = dnnl::memory(user_md, this->engine_);
auto* weight_x_data = reinterpret_cast<U*>(user_memory.get_data_handle());
......@@ -161,9 +159,9 @@ class GRUMKLDNNHandler : public RNNMKLDNNHandler<T, dnnl::gru_forward, T_out> {
std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(wh_key));
if (!memory_p) {
auto user_md = MKLDNNMemDesc({1, 1, this->OC, this->G, this->OC},
MKLDNNGetDataType<U>(),
MKLDNNMemoryFormat::ldigo);
auto user_md = OneDNNMemDesc({1, 1, this->OC, this->G, this->OC},
OneDNNGetDataType<U>(),
OneDNNMemoryFormat::ldigo);
auto user_memory = dnnl::memory(user_md, this->engine_);
// Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to
......@@ -357,7 +355,7 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle();
auto* hidden_data =
to_void_cast(hidden->mutable_data<Tout>(ctx.GetPlace()));
phi::funcs::to_void_cast(hidden->mutable_data<Tout>(ctx.GetPlace()));
if (handler.is_NTC()) {
handler.reorderRNNdata(hidden_onednn_data,
hidden_data,
......
......@@ -20,10 +20,8 @@ limitations under the License. */
namespace paddle {
namespace operators {
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::MKLDNNMemDesc;
using phi::CPUContext;
using platform::to_void_cast;
using phi::funcs::OneDNNGetDataType;
using phi::funcs::OneDNNMemDesc;
template <typename T, typename T_out = T>
class LSTMMKLDNNHandler
......@@ -80,7 +78,7 @@ class LSTMMKLDNNHandler
// Weights for int8 kernel are of a type s8
const auto weights_dt =
is_INT8 ? dnnl::memory::data_type::s8 : MKLDNNGetDataType<T>();
is_INT8 ? dnnl::memory::data_type::s8 : OneDNNGetDataType<T>();
// oneDNN RNN dimensions
const int64_t D = 1; // Directions
......@@ -88,21 +86,21 @@ class LSTMMKLDNNHandler
const int64_t G = 4; // Number of Gates, 4 for LSTM
// Create memory descriptors
auto input_md = MKLDNNMemDesc(
{Ti, N, IC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::tnc);
auto input_md = OneDNNMemDesc(
{Ti, N, IC}, OneDNNGetDataType<T>(), OneDNNMemoryFormat::tnc);
auto weight_x_md =
MKLDNNMemDesc({L, D, IC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
OneDNNMemDesc({L, D, IC, G, OC}, weights_dt, OneDNNMemoryFormat::any);
auto weight_h_md =
MKLDNNMemDesc({L, D, OC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
auto bias_md = MKLDNNMemDesc(
{L, D, G, OC}, MKLDNNGetDataType<float>(), MKLDNNMemoryFormat::ldgo);
auto hidden_md = MKLDNNMemDesc(
{Ti, N, OC}, MKLDNNGetDataType<T_out>(), MKLDNNMemoryFormat::any);
OneDNNMemDesc({L, D, OC, G, OC}, weights_dt, OneDNNMemoryFormat::any);
auto bias_md = OneDNNMemDesc(
{L, D, G, OC}, OneDNNGetDataType<float>(), OneDNNMemoryFormat::ldgo);
auto hidden_md = OneDNNMemDesc(
{Ti, N, OC}, OneDNNGetDataType<T_out>(), OneDNNMemoryFormat::any);
auto h0_md = MKLDNNMemDesc(
{L, D, N, OC}, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
auto c0_md = MKLDNNMemDesc(
{L, D, N, OC}, MKLDNNGetDataType<float>(), MKLDNNMemoryFormat::any);
auto h0_md = OneDNNMemDesc(
{L, D, N, OC}, OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
auto c0_md = OneDNNMemDesc(
{L, D, N, OC}, OneDNNGetDataType<float>(), OneDNNMemoryFormat::any);
// Create LSTM oneDNN primitive
const auto direction =
......@@ -123,9 +121,9 @@ class LSTMMKLDNNHandler
dnnl::memory::desc(),
dnnl::memory::desc());
} else {
auto weight_peephole_md = MKLDNNMemDesc({L, D, 3, OC},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldgo);
auto weight_peephole_md = OneDNNMemDesc({L, D, 3, OC},
OneDNNGetDataType<float>(),
OneDNNMemoryFormat::ldgo);
this->AcquireForwardPrimitiveDescriptor(
this->attr_,
dnnl::prop_kind::forward_inference,
......@@ -173,9 +171,9 @@ class LSTMMKLDNNHandler
std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(wx_key));
if (!memory_p) {
auto user_md = MKLDNNMemDesc({1, 1, this->IC, this->G, this->OC},
MKLDNNGetDataType<U>(),
MKLDNNMemoryFormat::ldigo);
auto user_md = OneDNNMemDesc({1, 1, this->IC, this->G, this->OC},
OneDNNGetDataType<U>(),
OneDNNMemoryFormat::ldigo);
auto user_memory = dnnl::memory(user_md, this->engine_);
auto* weight_x_data = reinterpret_cast<U*>(user_memory.get_data_handle());
......@@ -205,9 +203,9 @@ class LSTMMKLDNNHandler
std::static_pointer_cast<dnnl::memory>(this->dev_ctx_.GetBlob(wh_key));
if (!memory_p) {
auto user_md = MKLDNNMemDesc({1, 1, this->OC, this->G, this->OC},
MKLDNNGetDataType<U>(),
MKLDNNMemoryFormat::ldigo);
auto user_md = OneDNNMemDesc({1, 1, this->OC, this->G, this->OC},
OneDNNGetDataType<U>(),
OneDNNMemoryFormat::ldigo);
auto user_memory = dnnl::memory(user_md, this->engine_);
auto* weight_h_data = reinterpret_cast<U*>(user_memory.get_data_handle());
......@@ -264,9 +262,9 @@ class LSTMMKLDNNHandler
this->dev_ctx_.GetBlob(peepholes_key));
if (!memory_p) {
auto user_md = MKLDNNMemDesc({1, 1, 3, this->OC},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldgo);
auto user_md = OneDNNMemDesc({1, 1, 3, this->OC},
OneDNNGetDataType<float>(),
OneDNNMemoryFormat::ldgo);
auto user_memory = dnnl::memory(user_md, this->engine_);
memory_p = std::make_shared<dnnl::memory>(
this->fwd_pd_->weights_peephole_desc(), this->engine_);
......@@ -292,15 +290,16 @@ class LSTMMKLDNNHandler
if (!memory_p) {
auto user_c0_memory = dnnl::memory();
if (c0) {
user_c0_memory = dnnl::memory({{1, 1, this->N, this->OC},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldnc},
this->engine_,
to_void_cast(c0->data<float>()));
user_c0_memory =
dnnl::memory({{1, 1, this->N, this->OC},
OneDNNGetDataType<float>(),
OneDNNMemoryFormat::ldnc},
this->engine_,
phi::funcs::to_void_cast(c0->data<float>()));
} else {
user_c0_memory = dnnl::memory({{1, 1, this->N, this->OC},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldnc},
OneDNNGetDataType<float>(),
OneDNNMemoryFormat::ldnc},
this->engine_);
memset(user_c0_memory.get_data_handle(),
0,
......@@ -451,7 +450,7 @@ class FusionLSTMMKLDNNKernel : public framework::OpKernel<T> {
auto* hidden_onednn_data = hidden_onednn_memory_p->get_data_handle();
auto* hidden_data =
to_void_cast(hidden->mutable_data<Tout>(ctx.GetPlace()));
phi::funcs::to_void_cast(hidden->mutable_data<Tout>(ctx.GetPlace()));
if (handler.is_NTC()) {
handler.reorderRNNdata(hidden_onednn_data,
hidden_data,
......
......@@ -20,13 +20,10 @@ namespace paddle {
namespace operators {
using paddle::platform::CreateKey;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::MKLDNNMemDesc;
using phi::CPUContext;
using platform::to_void_cast;
using phi::funcs::OneDNNGetDataType;
template <typename T, typename T_alg, typename T_out = T>
class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
class RNNMKLDNNHandler : public phi::funcs::OneDNNHandlerT<T, T_alg> {
public:
RNNMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
......@@ -42,11 +39,11 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
const int64_t OC,
const int64_t G,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, T_alg>(
: phi::funcs::OneDNNHandlerT<T, T_alg>(
dev_ctx,
dev_ctx.GetEngine(),
cpu_place,
CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>(), Ti)),
CreateKey(dev_ctx, unique_name, OneDNNGetDataType<T>(), Ti)),
N(N),
Ti(Ti),
IC(IC),
......@@ -55,7 +52,7 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
// Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do
memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>()));
dev_ctx, CreateKey(dev_ctx, unique_name, OneDNNGetDataType<T>()));
// Is it int8 kernel
const bool is_INT8 = std::is_same<T, uint8_t>::value;
......@@ -163,7 +160,7 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
}
const auto& input_lod = input->lod()[0];
auto* x_data = to_void_cast(input->data<T>());
auto* x_data = phi::funcs::to_void_cast(input->data<T>());
auto* x_onednn_data = memory_p->get_data_handle();
memset(x_onednn_data, 0, sizeof(T) * N * Ti * IC);
......@@ -210,12 +207,12 @@ class RNNMKLDNNHandler : public platform::MKLDNNHandlerT<T, T_alg> {
auto user_h0_memory = dnnl::memory();
if (h0) {
user_h0_memory = dnnl::memory(
{{1, 1, N, OC}, MKLDNNGetDataType<U>(), MKLDNNMemoryFormat::ldnc},
{{1, 1, N, OC}, OneDNNGetDataType<U>(), OneDNNMemoryFormat::ldnc},
this->engine_,
to_void_cast(h0->data<U>()));
phi::funcs::to_void_cast(h0->data<U>()));
} else {
user_h0_memory = dnnl::memory(
{{1, 1, N, OC}, MKLDNNGetDataType<U>(), MKLDNNMemoryFormat::ldnc},
{{1, 1, N, OC}, OneDNNGetDataType<U>(), OneDNNMemoryFormat::ldnc},
this->engine_);
memset(user_h0_memory.get_data_handle(), 0, sizeof(U) * N * OC);
}
......
......@@ -27,11 +27,9 @@ namespace paddle {
namespace operators {
using paddle::platform::CreateKey;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::MKLDNNMemDesc;
using phi::CPUContext;
using phi::vectorize;
using platform::to_void_cast;
using phi::funcs::OneDNNGetDataType;
using phi::funcs::OneDNNMemDesc;
using Direction = dnnl::rnn_direction;
namespace {
......@@ -115,7 +113,7 @@ class MultiGRUHandler {
// Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do
memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>()));
dev_ctx, CreateKey(dev_ctx, unique_name, OneDNNGetDataType<T>()));
key_ = memory_key_;
key_.append("T").append(std::to_string(Ti_));
......@@ -176,26 +174,26 @@ class MultiGRUHandler {
const auto weights_dt =
is_int8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::f32;
auto x_md = MKLDNNMemDesc({Ti_, N_, ICs[layer]},
MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::ntc);
auto h0_md = MKLDNNMemDesc({L, D, N_, OCs[layer]},
MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::ldnc);
auto wx_md = MKLDNNMemDesc({L, D, ICs[layer], G, OCs[layer]},
auto x_md = OneDNNMemDesc({Ti_, N_, ICs[layer]},
OneDNNGetDataType<T>(),
OneDNNMemoryFormat::ntc);
auto h0_md = OneDNNMemDesc({L, D, N_, OCs[layer]},
OneDNNGetDataType<T>(),
OneDNNMemoryFormat::ldnc);
auto wx_md = OneDNNMemDesc({L, D, ICs[layer], G, OCs[layer]},
weights_dt,
MKLDNNMemoryFormat::any);
auto wh_md = MKLDNNMemDesc({L, D, OCs[layer], G, OCs[layer]},
OneDNNMemoryFormat::any);
auto wh_md = OneDNNMemDesc({L, D, OCs[layer], G, OCs[layer]},
weights_dt,
MKLDNNMemoryFormat::any);
auto b_md = MKLDNNMemDesc({L, D, G, OCs[layer]},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldgo);
OneDNNMemoryFormat::any);
auto b_md = OneDNNMemDesc({L, D, G, OCs[layer]},
OneDNNGetDataType<float>(),
OneDNNMemoryFormat::ldgo);
auto h_md =
MKLDNNMemDesc({Ti_, N_, OCs[layer]},
(layer == layers_ - 1) ? MKLDNNGetDataType<T_out>()
: MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::ntc);
OneDNNMemDesc({Ti_, N_, OCs[layer]},
(layer == layers_ - 1) ? OneDNNGetDataType<T_out>()
: OneDNNGetDataType<T>(),
OneDNNMemoryFormat::ntc);
auto desc = std::make_shared<dnnl::gru_forward::desc>(
dnnl::prop_kind::forward_inference,
......@@ -226,10 +224,10 @@ class MultiGRUHandler {
if (pd == nullptr) {
const int axis = 2;
auto in_md =
MKLDNNMemDesc({Ti_, N_, OCs[layer]},
(layer == layers_ - 1) ? MKLDNNGetDataType<T_out>()
: MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::ntc);
OneDNNMemDesc({Ti_, N_, OCs[layer]},
(layer == layers_ - 1) ? OneDNNGetDataType<T_out>()
: OneDNNGetDataType<T>(),
OneDNNMemoryFormat::ntc);
std::vector<dnnl::memory::desc> src_mds{in_md, in_md};
pd = std::make_shared<dnnl::concat::primitive_desc>(
......@@ -251,7 +249,7 @@ class MultiGRUHandler {
dev_ctx_.SetBlob(key, memory_p);
}
auto* x_data = to_void_cast(x_->data<T>());
auto* x_data = phi::funcs::to_void_cast(x_->data<T>());
auto* x_onednn_data = memory_p->get_data_handle();
memset(x_onednn_data, 0, sizeof(T) * N_ * Ti_ * ICs[0]);
......@@ -336,8 +334,8 @@ class MultiGRUHandler {
if (!memory_p) {
auto user_h0_memory = dnnl::memory();
user_h0_memory = dnnl::memory({{1, 1, N_, OCs[layer]},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldnc},
OneDNNGetDataType<float>(),
OneDNNMemoryFormat::ldnc},
engine_);
memset(
user_h0_memory.get_data_handle(), 0, sizeof(float) * N_ * OCs[layer]);
......@@ -360,9 +358,9 @@ class MultiGRUHandler {
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(key));
if (!memory_p) {
auto user_md = MKLDNNMemDesc({1, 1, ICs[layer], 3, OCs[layer]},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldigo);
auto user_md = OneDNNMemDesc({1, 1, ICs[layer], 3, OCs[layer]},
OneDNNGetDataType<float>(),
OneDNNMemoryFormat::ldigo);
auto user_memory = dnnl::memory(user_md, engine_);
auto* weight_x_data =
......@@ -400,9 +398,9 @@ class MultiGRUHandler {
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(key));
if (!memory_p) {
auto user_md = MKLDNNMemDesc({1, 1, OCs[layer], 3, OCs[layer]},
MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldigo);
auto user_md = OneDNNMemDesc({1, 1, OCs[layer], 3, OCs[layer]},
OneDNNGetDataType<float>(),
OneDNNMemoryFormat::ldigo);
auto user_memory = dnnl::memory(user_md, engine_);
// Reorder weights_h from PP format [OC, 2OC] + [OC, OC] to
......@@ -599,7 +597,8 @@ class MultiGRUHandler {
template <typename Tout>
void reorderOutput(std::shared_ptr<dnnl::memory> mem, int layer) {
auto* data = mem->get_data_handle();
auto* hidden_data = to_void_cast(hidden_->mutable_data<Tout>(place_));
auto* hidden_data =
phi::funcs::to_void_cast(hidden_->mutable_data<Tout>(place_));
if (isNTC(gru_pds_[{layers_ - 1, L2R}]->dst_desc())) {
reorderNTCtoPP(data, hidden_data, layers_ - 1);
......
......@@ -143,7 +143,7 @@ framework::OpKernelType MultiGRUOp::GetExpectedKernelType(
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
......
......@@ -348,8 +348,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_layout");
......
......@@ -452,8 +452,8 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_layout");
......
......@@ -233,8 +233,8 @@ class LRNOp : public framework::OperatorWithKernel {
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......@@ -357,8 +357,8 @@ class LRNOpGrad : public framework::OperatorWithKernel {
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......
......@@ -715,8 +715,8 @@ class MatMulOp : public framework::OperatorWithKernel {
// When matmul is first oneDNN op in a chain (there was some non oneDNN op
// previously)
// then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN) &&
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
......
......@@ -152,8 +152,8 @@ class MatMulV2Op : public framework::OperatorWithKernel {
#ifdef PADDLE_WITH_MKLDNN
// When matmul_v2 is first oneDNN op in a chain (there was some non oneDNN
// op previously) then we also need to rotate shape NHWC -> NCWH
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN) &&
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN) &&
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout() == phi::DataLayout::kNHWC) {
return framework::OpKernelType(expected_kernel_type.data_type_,
......
# Declares the mkldnn_axpy_handler library target, built from axpy_handler.cc
# and depending on the place, device_context and enforce targets.
cc_library(
mkldnn_axpy_handler
SRCS axpy_handler.cc
DEPS place device_context enforce)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mkldnn/axpy_handler.h"
#include <cinttypes>
#include <memory>
#include <string>
#include <vector>
#include "dnnl.hpp"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
namespace plat = paddle::platform;
namespace {
template <typename T>
class AXPYHandler {
public:
AXPYHandler(const dnnl::engine mkldnn_engine, int n, float alpha) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
auto md = dnnl::memory::desc(
{n}, plat::MKLDNNGetDataType<T>(), dnnl::memory::format_tag::x);
src_mem_ = dnnl::memory(md, mkldnn_engine, DNNL_MEMORY_NONE);
dst_mem_ = dnnl::memory(md, mkldnn_engine, DNNL_MEMORY_NONE);
dnnl::primitive_attr reorder_attr;
dnnl::post_ops post_operations;
if (alpha != 1.f) {
std::vector<float> scales(1, alpha);
reorder_attr.set_output_scales(0, scales);
}
post_operations.append_sum(1.0f);
reorder_attr.set_post_ops(post_operations);
reorder_p_ = dnnl::reorder(src_mem_, dst_mem_, reorder_attr);
}
dnnl::memory &AcquireSrcMemory(const T *x) {
src_mem_.set_data_handle(plat::to_void_cast<T>(x));
return src_mem_;
}
dnnl::memory &AcquireDstMemory(T *y) {
dst_mem_.set_data_handle(y);
return dst_mem_;
}
const dnnl::reorder &AcquireReorder() { return reorder_p_; }
private:
dnnl::memory src_mem_;
dnnl::memory dst_mem_;
dnnl::reorder reorder_p_;
};
template class AXPYHandler<float>;
template class AXPYHandler<plat::bfloat16>;
/// Scalar AXPY: adds `alpha * x[i]` to `y[i]` for the first `n` elements.
/// Does nothing when `n` is zero or negative.
template <typename T>
static void naive_axpy(int n, T alpha, const T *x, T *y) {
  for (int idx = 0; idx < n; ++idx) {
    y[idx] += alpha * x[idx];
  }
}
} // namespace
// Private implementation behind OneDNNAXPYHandler<T>. Owns the oneDNN-backed
// AXPYHandler together with the element count and coefficient it was built for.
template <typename T>
class OneDNNAXPYHandler<T>::Impl {
public:
// Stores `n` and `alpha` and creates the AXPYHandler.
Impl(int64_t n, T alpha);
// Executes AXPY on the buffers `x` (input) and `y` (updated in place).
void operator()(const T *x, T *y);
private:
// oneDNN reorder wrapper used by operator().
std::unique_ptr<AXPYHandler<T>> handler_;
// Number of elements processed per call.
int64_t n_;
// Scaling coefficient applied to x.
T alpha_;
};
/// Records `n` and `alpha`, then builds the AXPYHandler on the engine of the
/// device context registered in the pool for the CPU place.
template <typename T>
OneDNNAXPYHandler<T>::Impl::Impl(int64_t n, T alpha) : n_{n}, alpha_{alpha} {
  auto *onednn_ctx = dynamic_cast<plat::MKLDNNDeviceContext *>(
      plat::DeviceContextPool::Instance().Get(plat::CPUPlace()));
  // alpha is handed to the handler as float regardless of T.
  handler_ = std::make_unique<AXPYHandler<T>>(
      onednn_ctx->GetEngine(), n, static_cast<float>(alpha));
}
/// Runs AXPY on `x` / `y`. Buffers with fewer than 100 elements go through the
/// scalar loop; larger ones are executed by the oneDNN reorder primitive.
template <typename T>
void OneDNNAXPYHandler<T>::Impl::operator()(const T *x, T *y) {
  if (n_ >= 100) {
    // Attach the caller's buffers to the pre-built memory objects.
    dnnl::memory &src_mem = handler_->AcquireSrcMemory(x);
    dnnl::memory &dst_mem = handler_->AcquireDstMemory(y);
    auto reorder_prim = handler_->AcquireReorder();

    auto &onednn_stream = plat::MKLDNNDeviceContext::tls().get_stream();
    reorder_prim.execute(onednn_stream, src_mem, dst_mem);
    // Block until the primitive has finished.
    onednn_stream.wait();
  } else {
    naive_axpy(n_, alpha_, x, y);
  }
}
template <typename T>
OneDNNAXPYHandler<T>::OneDNNAXPYHandler(int64_t n, T alpha)
: pimpl_{new Impl{n, alpha}, [](Impl *impl) { delete impl; }} {
VLOG(4) << "[OneDNN] OneDNNAXPYHandler<" << typeid(T).name() << ">, "
<< "n: " << n << ", alpha: " << alpha;
}
/// Forwards the call to the private implementation.
template <typename T>
void OneDNNAXPYHandler<T>::operator()(const T *x, T *y) {
  (*pimpl_)(x, y);
}
template class OneDNNAXPYHandler<float>;
template class OneDNNAXPYHandler<plat::bfloat16>;
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
namespace paddle {
namespace operators {
///
/// @brief Helper class for AXPY execution using oneDNN library.
///
/// Instances are neither copyable nor movable.
///
/// @tparam T Data type.
///
template <typename T>
class OneDNNAXPYHandler {
 public:
  // Canonical (const-reference) signatures for the deleted copy operations.
  OneDNNAXPYHandler(const OneDNNAXPYHandler&) = delete;
  OneDNNAXPYHandler(OneDNNAXPYHandler&&) = delete;
  OneDNNAXPYHandler& operator=(const OneDNNAXPYHandler&) = delete;
  OneDNNAXPYHandler& operator=(OneDNNAXPYHandler&&) = delete;

  ///
  /// @brief Constructor.
  ///
  /// @param[in] n     The number of elements in tensor (assumed 1D tensor)
  /// @param[in] alpha The alpha coefficient.
  ///
  OneDNNAXPYHandler(int64_t n, T alpha);

  ///
  /// @brief Executes AXPY.
  ///
  /// @param[in]  x The pointer to input X tensor data.
  /// @param[out] y The pointer to output Y tensor data.
  ///
  void operator()(const T* x, T* y);

 private:
  OneDNNAXPYHandler() = delete;

  // (arogowie-intel) Private implementation idiom to hide dependency
  // on OneDNN headers.
  class Impl;
  // A custom deleter is needed because std::unique_ptr's default deleter
  // cannot be instantiated for Impl, which is an incomplete type here.
  std::unique_ptr<Impl, void (*)(Impl*)> pimpl_;
};
} // namespace operators
} // namespace paddle
......@@ -24,13 +24,11 @@ namespace operators {
using dnnl::memory;
using dnnl::primitive;
using dnnl::reorder;
using dnnl::stream;
using paddle::platform::MKLDNNDeviceContext;
using platform::to_void_cast;
template <typename T>
class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
class BatchNormMKLDNNHandler : public phi::funcs::OneDNNHandlerNoCachingT<
T,
dnnl::batch_normalization_forward,
dnnl::batch_normalization_backward> {
......@@ -40,9 +38,9 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
const Tensor *in_x,
const Tensor *scale,
const Tensor *out_grad)
: platform::MKLDNNHandlerNoCachingT<T,
dnnl::batch_normalization_forward,
dnnl::batch_normalization_backward>(
: phi::funcs::OneDNNHandlerNoCachingT<T,
dnnl::batch_normalization_forward,
dnnl::batch_normalization_backward>(
mkldnn_engine, ctx.GetPlace()) {
auto scale_tz = phi::vectorize<int64_t>(scale->dims());
PADDLE_ENFORCE_EQ(
......@@ -98,8 +96,8 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
std::shared_ptr<dnnl::memory> AcquireMeanMemory(
const phi::DenseTensor *mean) {
const T *mean_data = mean->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
to_void_cast<T>(mean_data));
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->mean_desc(), phi::funcs::to_void_cast<T>(mean_data));
}
std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor *mean) {
......@@ -112,8 +110,9 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandlerNoCachingT<
std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
const phi::DenseTensor *variance) {
const T *variance_data = variance->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
to_void_cast<T>(variance_data));
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->variance_desc(),
phi::funcs::to_void_cast<T>(variance_data));
}
std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
......
......@@ -23,13 +23,14 @@ namespace operators {
using Tensor = phi::DenseTensor;
using phi::DataLayout;
using phi::funcs::OneDNNMemDesc;
inline dnnl::memory::dims GetWeightsTz(const phi::DenseTensor* filter,
const int groups) {
auto weights_tz = phi::vectorize(filter->dims());
int g = std::max(groups, 1);
int g_dim = (g > 1) ? 1 : 0;
platform::GetGroupConvWeightsTz(weights_tz, g);
phi::funcs::GetGroupConvWeightsTz(weights_tz, g);
// gIOHW -> gOIHW || IOHW -> OIHW
std::swap(weights_tz[g_dim + 0], weights_tz[g_dim + 1]);
return weights_tz;
......@@ -37,7 +38,8 @@ inline dnnl::memory::dims GetWeightsTz(const phi::DenseTensor* filter,
template <typename T, typename K, typename T_out>
class ConvTransposeMKLDNNHandlerT
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward> {
: public phi::funcs::OneDNNHandlerNoCachingT<T,
dnnl::deconvolution_forward> {
public:
ConvTransposeMKLDNNHandlerT(const framework::ExecutionContext& ctx,
const dnnl::engine mkldnn_engine,
......@@ -45,7 +47,7 @@ class ConvTransposeMKLDNNHandlerT
const phi::DenseTensor* filter,
const phi::DenseTensor* bias,
phi::DenseTensor* output)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>(
: phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>(
mkldnn_engine, ctx.GetPlace()),
is_test_(ctx.Attr<bool>("is_test")) {
PADDLE_ENFORCE_EQ(is_test_,
......@@ -57,16 +59,16 @@ class ConvTransposeMKLDNNHandlerT
PADDLE_ENFORCE_EQ(
input->layout(),
DataLayout::kMKLDNN,
DataLayout::ONEDNN,
platform::errors::InvalidArgument(
"Got wrong layout = %d for Input tensor.", input->layout()));
PADDLE_ENFORCE_EQ(
filter->layout(),
DataLayout::kMKLDNN,
DataLayout::ONEDNN,
platform::errors::InvalidArgument(
"The filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN,
DataLayout::ONEDNN,
filter->layout()));
PADDLE_ENFORCE_EQ(
......@@ -85,10 +87,10 @@ class ConvTransposeMKLDNNHandlerT
if (bias) {
PADDLE_ENFORCE_EQ(
bias->layout(),
DataLayout::kMKLDNN,
DataLayout::ONEDNN,
platform::errors::InvalidArgument(
"The bias tensor's laytout should be %d, but got %d.",
DataLayout::kMKLDNN,
DataLayout::ONEDNN,
bias->layout()));
PADDLE_ENFORCE_EQ(
......@@ -136,25 +138,24 @@ class ConvTransposeMKLDNNHandlerT
const auto src_tz = phi::vectorize(input->dims());
const auto weights_tz = GetWeightsTz(filter, groups);
const auto dst_tz = phi::vectorize(output->dims());
const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
const auto mkldnn_paddings = phi::funcs::ToOneDNNPadding(paddings);
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
const auto chosen_memory_format = MKLDNNMemoryFormat::any;
const auto chosen_memory_format = OneDNNMemoryFormat::any;
auto data_type = dnnl::memory::data_type::f32;
if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
std::is_same<T_out, platform::bfloat16>::value)
data_type = dnnl::memory::data_type::bf16;
const auto src_md =
platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
const auto src_md = OneDNNMemDesc(src_tz, data_type, chosen_memory_format);
const auto weights_md =
platform::MKLDNNMemDesc(weights_tz, data_type, chosen_memory_format);
const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
OneDNNMemDesc(weights_tz, data_type, chosen_memory_format);
const auto dst_md = OneDNNMemDesc(
dst_tz, phi::funcs::OneDNNGetDataType<T_out>(), chosen_memory_format);
const dnnl::primitive_attr conv_trans_attr = CreateConvAttrs(ctx);
auto fwd_prop_kind = is_test_ ? dnnl::prop_kind::forward_inference
......@@ -162,7 +163,7 @@ class ConvTransposeMKLDNNHandlerT
if (bias) {
std::vector<int64_t> bias_tz = phi::vectorize(bias->dims());
const auto bias_md =
platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
OneDNNMemDesc(bias_tz, data_type, OneDNNMemoryFormat::x);
this->AcquireForwardPrimitiveDescriptor(
conv_trans_attr,
fwd_prop_kind,
......@@ -221,10 +222,10 @@ class ConvTransposeMKLDNNHandlerT
// Builds the source (input) memory object for the deconvolution primitive.
//
// Reads the typed data pointer of `input` and hands it, together with the
// tensor's own memory descriptor and the forward primitive descriptor's source
// descriptor, to the base handler's AcquireMemoryWithReorder; whatever that
// call returns is returned unchanged.
//
// @param input  tensor whose data<T>() buffer and mem_desc() describe the
//               user-side source data.
// @return shared pointer to the dnnl::memory produced by the base handler.
//
// NOTE(review): AcquireMemoryWithReorder is not defined in this view;
// presumably it only reorders when the two descriptors differ - confirm.
// NOTE(review): this text looks like a rendered diff. Where two near-identical
// lines follow each other, the `platform::` / MKLDNN-named line appears to be
// the pre-change text and the `phi::funcs::` / OneDNN-named line its
// replacement.
std::shared_ptr<dnnl::memory> AcquireSrcMemoryWithReorder(
const phi::DenseTensor* input) {
const T* input_data = input->data<T>();
return platform::MKLDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>::
return phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::deconvolution_forward>::
AcquireMemoryWithReorder(input->mem_desc(),
this->fwd_pd_->src_desc(),
// The data pointer is const T*; to_void_cast converts it to the untyped
// pointer argument expected here (helper not shown in this view).
platform::to_void_cast<T>(input_data));
phi::funcs::to_void_cast<T>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireWeightsMemoryWithReorder(
......@@ -236,16 +237,16 @@ class ConvTransposeMKLDNNHandlerT
auto weights_tz = GetWeightsTz(filter, groups);
int g = std::max(groups, 1);
auto user_src_md = platform::MKLDNNMemDesc(
auto user_src_md = OneDNNMemDesc(
weights_tz,
platform::MKLDNNGetDataType<K>(),
(g == 1) ? MKLDNNMemoryFormat::iohw : MKLDNNMemoryFormat::giohw);
phi::funcs::OneDNNGetDataType<K>(),
(g == 1) ? OneDNNMemoryFormat::iohw : OneDNNMemoryFormat::giohw);
return this->template AcquireMemoryWithReorder<K>(
dev_ctx,
user_src_md,
this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data),
phi::funcs::to_void_cast<K>(filter_data),
key,
"@weights_mem_p",
is_test_);
......@@ -276,7 +277,7 @@ class ConvTransposeMKLDNNHandlerT
target_memory_p =
std::make_shared<dnnl::memory>(target_md, this->engine_);
dnnl::reorder::primitive_desc reorder_pdesc;
if (platform::is_int8<T>()) {
if (phi::funcs::is_int8<T>()) {
dnnl::primitive_attr attr;
attr.set_output_scales(mask, scale_data);
reorder_pdesc = dnnl::reorder::primitive_desc(
......@@ -334,17 +335,17 @@ class ConvTransposeMKLDNNHandlerT
const std::string& key,
const phi::DenseTensor* bias) {
const K* bias_data = bias->data<K>();
auto user_bias_md =
platform::MKLDNNMemDesc(phi::vectorize(bias->dims()),
platform::MKLDNNGetDataType<K>(),
MKLDNNMemoryFormat::x);
return this->AcquireMemoryWithReorder(dev_ctx,
user_bias_md,
this->fwd_pd_->bias_desc(),
platform::to_void_cast<K>(bias_data),
key,
"@bias_mem_p",
is_test_);
auto user_bias_md = OneDNNMemDesc(phi::vectorize(bias->dims()),
phi::funcs::OneDNNGetDataType<K>(),
OneDNNMemoryFormat::x);
return this->AcquireMemoryWithReorder(
dev_ctx,
user_bias_md,
this->fwd_pd_->bias_desc(),
phi::funcs::to_void_cast<K>(bias_data),
key,
"@bias_mem_p",
is_test_);
}
private:
......
......@@ -26,10 +26,8 @@ namespace operators {
using dnnl::memory;
using dnnl::primitive;
using dnnl::reorder;
using platform::to_void_cast;
using Tensor = phi::DenseTensor;
using dnnl::stream;
using phi::DataLayout;
template <typename T>
class DeQuantOpKernel : public framework::OpKernel<T> {
......@@ -55,8 +53,8 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::MKLDNNDeviceContext>();
auto x_tz = phi::vectorize<int64_t>(x->dims());
auto x_paddle_dtype = framework::TransToProtoVarType(x->dtype());
auto out_paddle_dtype = framework::TransToProtoVarType(out->dtype());
auto x_type = phi::funcs::ToOneDNNDataType(x->dtype());
auto out_type = phi::funcs::ToOneDNNDataType(out->dtype());
dnnl::primitive_attr attrs;
static constexpr int32_t mask = 0; // same shift and scale for whole tensor
......@@ -69,16 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
DNNL_ARG_SRC, mask, {static_cast<int32_t>(quantization_shift)});
}
platform::ReorderMKLDNNHandler reorder_handler(
x_tz,
x_paddle_dtype,
framework::ToMKLDNNDataType(x_paddle_dtype),
out_paddle_dtype,
framework::ToMKLDNNDataType(out_paddle_dtype),
dev_ctx.GetEngine());
phi::funcs::ReorderOneDNNHandler reorder_handler(
x_tz, x->dtype(), x_type, out->dtype(), out_type, dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
x->mem_desc(), phi::funcs::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, x->mem_desc(), dev_ctx.GetPlace());
......
......@@ -29,14 +29,9 @@ using dnnl::stream;
using framework::DDim;
using framework::ExecutionContext;
using LoDTensor = phi::DenseTensor;
using phi::funcs::OneDNNGetDataType;
using phi::funcs::to_void_cast;
using platform::MKLDNNDeviceContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast;
// Compile-time predicate: true exactly when T is one of the two 8-bit integer
// element types (int8_t or uint8_t), false for every other type.
template <typename T>
constexpr bool IsInt8() {
  constexpr bool is_signed_byte = std::is_same<T, int8_t>::value;
  constexpr bool is_unsigned_byte = std::is_same<T, uint8_t>::value;
  return is_signed_byte || is_unsigned_byte;
}
struct InnerProductCache {
dnnl::inner_product_forward inner_product_p;
......@@ -47,8 +42,8 @@ struct InnerProductCache {
};
template <typename T_in, typename T_w, typename T_out>
class FCMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T_in,
dnnl::inner_product_forward> {
: public phi::funcs::OneDNNHandlerNoCachingT<T_in,
dnnl::inner_product_forward> {
public:
FCMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
......@@ -59,7 +54,7 @@ class FCMKLDNNHandler
const int in_num_col_dims,
dnnl::engine mkldnn_engine,
platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T_in, dnnl::inner_product_forward>(
: phi::funcs::OneDNNHandlerNoCachingT<T_in, dnnl::inner_product_forward>(
mkldnn_engine, cpu_place),
dev_ctx_(dev_ctx) {
this->memory_key_ = ctx.InputName("W");
......@@ -82,14 +77,14 @@ class FCMKLDNNHandler
dnnl::memory::desc bias_md;
auto src_md = dnnl::memory::desc(
{MB, IC}, MKLDNNGetDataType<T_in>(), dnnl::memory::format_tag::any);
{MB, IC}, OneDNNGetDataType<T_in>(), dnnl::memory::format_tag::any);
auto weights_md = dnnl::memory::desc(
{OC, IC}, MKLDNNGetDataType<T_w>(), dnnl::memory::format_tag::any);
{OC, IC}, OneDNNGetDataType<T_w>(), dnnl::memory::format_tag::any);
auto dst_md = dnnl::memory::desc(
{MB, OC}, MKLDNNGetDataType<T_out>(), dnnl::memory::format_tag::any);
{MB, OC}, OneDNNGetDataType<T_out>(), dnnl::memory::format_tag::any);
if (bias) {
bias_md = dnnl::memory::desc({bias->numel()},
MKLDNNGetDataType<float>(),
OneDNNGetDataType<float>(),
dnnl::memory::format_tag::a);
}
......@@ -110,7 +105,7 @@ class FCMKLDNNHandler
std::vector<float> output_shift_scale;
float scale = 1.0f;
if (IsInt8<T_w>()) {
if (phi::funcs::is_int8<T_w>()) {
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1);
attributes.set_output_scales(mask, output_shift_scale);
......@@ -250,7 +245,7 @@ class FCMKLDNNHandler
const std::vector<float>& scale_weights) {
const float* bias_data = bias->data<float>();
if (IsInt8<T_w>() == false) {
if (phi::funcs::is_int8<T_w>() == false) {
// for BF16/FP32 bias is 1D and has no scales, so reorder is not needed
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->bias_desc(),
to_void_cast<float>(bias_data));
......@@ -267,7 +262,7 @@ class FCMKLDNNHandler
attrs.set_output_scales(mask, scale_data);
auto user_md = dnnl::memory::desc({bias->dims()[0]},
MKLDNNGetDataType<float>(),
OneDNNGetDataType<float>(),
dnnl::memory::format_tag::a);
memory_p = this->AcquireMemoryWithReorderAndAttrs(
......@@ -292,10 +287,10 @@ class FCMKLDNNHandler
auto weights_dims = this->fwd_pd_->weights_desc().dims();
auto user_md = dnnl::memory::desc(weights_dims,
MKLDNNGetDataType<float>(),
OneDNNGetDataType<float>(),
dnnl::memory::format_tag::io);
if (IsInt8<T_w>()) {
if (phi::funcs::is_int8<T_w>()) {
dnnl::primitive_attr attrs;
int mask = CreateMask(0, scale_data.size() > 1);
attrs.set_output_scales(mask, scale_data);
......@@ -358,7 +353,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
IF_CHANGE_FC_TW_TYPENAME((std::is_same<T_in, uint8_t>::value), ([&] {
if (force_fp32_output) {
this->RunKernel<float, T_w>(ctx);
} else if (IsInt8<T_in>()) {
} else if (phi::funcs::is_int8<T_in>()) {
if (fuse_relu) {
this->RunKernel<uint8_t, T_w>(ctx);
} else {
......
......@@ -25,29 +25,28 @@ using dnnl::reorder;
using dnnl::resampling_forward;
using dnnl::stream;
using phi::DataLayout;
using platform::to_void_cast;
// Handler for the oneDNN resampling (interpolation) forward primitive,
// templated on the element type T (float by default).
//
// All work happens in the constructor: it derives a destination memory
// descriptor from `out` and asks the base handler to create the forward
// primitive descriptor. The class adds no other members in this view.
//
// NOTE(review): this text looks like a rendered diff. Each renamed line
// appears twice - the MKLDNN / `platform::` spelling appears to be the
// pre-change text and the OneDNN / `phi::funcs::` spelling directly after it
// the replacement.
template <typename T = float>
class InterpolateMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward> {
class InterpolateOneDNNHandler
: public phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::resampling_forward> {
public:
// @param algo       resampling algorithm forwarded to the primitive descriptor.
// @param engine     dnnl engine passed to the base handler.
// @param cpu_place  place passed to the base handler.
// @param x          input tensor; only its mem_desc() is read here.
// @param out        output tensor; only its dims() are read here.
InterpolateMKLDNNHandler(const dnnl::algorithm algo,
InterpolateOneDNNHandler(const dnnl::algorithm algo,
const dnnl::engine engine,
platform::Place cpu_place,
const phi::DenseTensor* x,
phi::DenseTensor* out)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::resampling_forward>(
: phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::resampling_forward>(
engine, cpu_place) {
// Destination descriptor: shape taken from `out`, element type from T, and
// memory format left as `any`.
// NOTE(review): `any` presumably lets the primitive pick the layout - confirm.
const auto dst_tz = phi::vectorize(out->dims());
const auto dst_md = memory::desc(
dst_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
dst_tz, phi::funcs::OneDNNGetDataType<T>(), OneDNNMemoryFormat::any);
// The propagation kind is hard-coded to forward_inference; the source
// descriptor is the one already attached to `x`.
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, algo, x->mem_desc(), dst_md);
}
};
template <typename T = float>
class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
class InterpolateOneDNNKernel : public framework::OpKernel<T> {
std::vector<int> ComputeOutputShape(
const framework::ExecutionContext& ctx) const {
const auto* x = ctx.Input<phi::DenseTensor>("X");
......@@ -147,7 +146,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
framework::DDim dim_out = phi::make_ddim(out_dims_vec);
out->Resize(dim_out);
InterpolateMKLDNNHandler<T> handler(
InterpolateOneDNNHandler<T> handler(
algo, mkldnn_engine, ctx.GetPlace(), x, out);
auto src_memory_p = handler.AcquireSrcMemory(x);
......@@ -173,10 +172,10 @@ namespace ops = paddle::operators;
// Kernel registrations for the interpolate operators, both under the MKLDNN
// library tag and ::paddle::platform::CPUPlace.
// NOTE(review): this text looks like a rendered diff. The
// ops::InterpolateMKLDNNKernel<...> argument lines appear to be the pre-change
// text and the ops::InterpolateOneDNNKernel<...> lines that follow them the
// replacement.

// nearest_interp is registered for float, int8_t and uint8_t.
REGISTER_OP_KERNEL(nearest_interp,
MKLDNN,
::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>,
ops::InterpolateMKLDNNKernel<int8_t>,
ops::InterpolateMKLDNNKernel<uint8_t>);
ops::InterpolateOneDNNKernel<float>,
ops::InterpolateOneDNNKernel<int8_t>,
ops::InterpolateOneDNNKernel<uint8_t>);
// bilinear_interp is registered for float only.
REGISTER_OP_KERNEL(bilinear_interp,
MKLDNN,
::paddle::platform::CPUPlace,
ops::InterpolateMKLDNNKernel<float>);
ops::InterpolateOneDNNKernel<float>);
......@@ -20,18 +20,19 @@ namespace paddle {
namespace operators {
template <typename T>
class LayerNormMKLDNNHandler
: public platform::
MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward> {
class LayerNormOneDNNHandler
: public phi::funcs::
OneDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward> {
public:
LayerNormMKLDNNHandler(const std::vector<int64_t>& dims,
LayerNormOneDNNHandler(const std::vector<int64_t>& dims,
const float& epsilon,
const dnnl::normalization_flags& flags,
const bool& is_test,
const phi::DenseTensor* x,
const dnnl::engine engine,
platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::layer_normalization_forward>(
: phi::funcs::OneDNNHandlerNoCachingT<T,
dnnl::layer_normalization_forward>(
engine, cpu_place) {
const auto fwd_prop_kind = is_test ? dnnl::prop_kind::forward_inference
: dnnl::prop_kind::forward_training;
......@@ -103,7 +104,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
flags |= dnnl::normalization_flags::use_scale_shift;
}
LayerNormMKLDNNHandler<T> handler(
LayerNormOneDNNHandler<T> handler(
src_tz, epsilon, flags, is_test, x, mkldnn_engine, ctx.GetPlace());
auto src_memory = handler.AcquireSrcMemory(x);
......
......@@ -20,17 +20,17 @@ namespace operators {
using paddle::platform::MKLDNNDeviceContext;
template <typename T>
class LRNMKLDNNHandler
: public platform::
MKLDNNHandlerNoCachingT<T, dnnl::lrn_forward, dnnl::lrn_backward> {
class LRNOneDNNHandler
: public phi::funcs::
OneDNNHandlerNoCachingT<T, dnnl::lrn_forward, dnnl::lrn_backward> {
public:
LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
LRNOneDNNHandler(const framework::ExecutionContext& ctx,
const dnnl::engine mkldnn_engine,
platform::Place cpu_place,
const phi::DenseTensor* input)
: platform::
MKLDNNHandlerNoCachingT<T, dnnl::lrn_forward, dnnl::lrn_backward>(
: phi::funcs::
OneDNNHandlerNoCachingT<T, dnnl::lrn_forward, dnnl::lrn_backward>(
mkldnn_engine, cpu_place) {
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in a caffe way:
......@@ -55,14 +55,14 @@ class LRNMKLDNNHandler
k);
}
LRNMKLDNNHandler(const framework::ExecutionContext& ctx,
LRNOneDNNHandler(const framework::ExecutionContext& ctx,
const dnnl::engine mkldnn_engine,
platform::Place cpu_place,
const phi::DenseTensor* in_x,
const phi::DenseTensor* out_grad,
phi::DenseTensor* in_x_grad)
: platform::
MKLDNNHandlerNoCachingT<T, dnnl::lrn_forward, dnnl::lrn_backward>(
: phi::funcs::
OneDNNHandlerNoCachingT<T, dnnl::lrn_forward, dnnl::lrn_backward>(
mkldnn_engine, cpu_place) {
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"),
......@@ -107,7 +107,7 @@ class LRNMKLDNNHandler
const T* workspace_data = workspace->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->workspace_desc(),
platform::to_void_cast<T>(workspace_data));
phi::funcs::to_void_cast<T>(workspace_data));
}
};
......@@ -132,7 +132,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto out = ctx.Output<phi::DenseTensor>("Out");
auto mid = ctx.Output<phi::DenseTensor>("MidOut");
LRNMKLDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), x);
LRNOneDNNHandler<T> handler(ctx, mkldnn_engine, ctx.GetPlace(), x);
auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(out);
......@@ -140,7 +140,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto lrn_p = handler.AcquireForwardPrimitive();
auto workspace_memory = handler.AcquireWorkspaceMemory(mid);
mid->set_layout(phi::DataLayout::kMKLDNN);
mid->set_layout(phi::DataLayout::ONEDNN);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (!workspace_memory->get_desc().is_zero()) {
......@@ -182,7 +182,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
LRNMKLDNNHandler<T> handler(
LRNOneDNNHandler<T> handler(
ctx, mkldnn_engine, ctx.GetPlace(), in_x, out_grad, in_x_grad);
auto src_memory = handler.AcquireSrcMemory(in_x);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace paddle {
namespace operators {
using framework::ExecutionContext;
using platform::MKLDNNDeviceContext;
using Tensor = phi::DenseTensor;
// Operator kernel class for the matmul gradient, templated on the element type
// T and derived from framework::OpKernel<T>.
// Only declarations appear in this header; none of the member bodies are
// defined here.
template <typename T>
class MatMulGradMKLDNNKernel : public framework::OpKernel<T> {
public:
// Kernel entry point; overrides the framework::OpKernel<T> virtual.
void Compute(const ExecutionContext& ctx) const override;
private:
// Helper that takes two operand tensors (x, y), one transpose flag and one
// is_fold_init_dims flag per operand, and an output tensor.
// NOTE(review): presumably performs a single matmul whose result is written
// to `out`, with the is_fold_init_dims_* flags choosing how a higher-rank
// operand is flattened - confirm against the definition.
void ExecuteMatMulGrad(const ExecutionContext& ctx,
const MKLDNNDeviceContext& dev_ctx,
const dnnl::engine& engine,
phi::DenseTensor* x,
bool trans_x,
bool is_fold_init_dims_x,
phi::DenseTensor* y,
bool trans_y,
bool is_fold_init_dims_y,
phi::DenseTensor* out) const;
// Internal helper that receives only the execution context.
// NOTE(review): likely holds the body that Compute forwards to - verify.
void RunKernel(const ExecutionContext& ctx) const;
};
} // namespace operators
} // namespace paddle
......@@ -11,16 +11,19 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace {
using dnnl::memory;
using paddle::framework::ExecutionContext;
using paddle::platform::MatMulV2MKLDNNHandler;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast;
using phi::vectorize;
using phi::funcs::OneDNNGetDataType;
using Tensor = phi::DenseTensor;
using paddle::framework::GradVarName;
using phi::make_ddim;
......@@ -54,15 +57,11 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx,
memory::data_type input_type = paddle::framework::ToMKLDNNDataType(
paddle::framework::TransToProtoVarType(input->dtype()));
paddle::platform::ReorderMKLDNNHandler reorder_handler(
output_dims,
paddle::framework::TransToProtoVarType(input->dtype()),
input_type,
dev_ctx.GetEngine());
phi::funcs::ReorderOneDNNHandler reorder_handler(
output_dims, input->dtype(), input_type, dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
memory::format_tag::abc,
paddle::platform::to_void_cast(input->data<T>()));
memory::format_tag::abc, phi::funcs::to_void_cast(input->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
&output, memory::format_tag::bac, dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
......@@ -76,16 +75,6 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx,
return output;
}
// Compile-time check for 8-bit integer element types: yields true when T is
// int8_t or uint8_t and false otherwise.
template <typename T>
constexpr bool IsInt8() {
  constexpr bool matches_s8 = std::is_same<T, int8_t>::value;
  constexpr bool matches_u8 = std::is_same<T, uint8_t>::value;
  return matches_s8 || matches_u8;
}
template <typename T>
constexpr bool IsBfloat16() {
return std::is_same<T, paddle::platform::bfloat16>::value;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static paddle::framework::DDim RowMatrixDimsFromVector(
......@@ -112,7 +101,7 @@ phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
template <typename XT, typename YT, typename OT>
class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatMulMKLDNNHandler(const dnnl::engine engine,
paddle::platform::Place cpu_place,
......@@ -122,8 +111,8 @@ class MatMulMKLDNNHandler
bool trans_y,
Tensor *out,
float scale)
: paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) {
: phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) {
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x->dims(), 0, trans_x);
auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y->dims(), 0, trans_y);
......@@ -146,9 +135,9 @@ class MatMulMKLDNNHandler
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_dims, MKLDNNGetDataType<OT>(), out_strides);
auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, OneDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_dims, OneDNNGetDataType<OT>(), out_strides);
dnnl::primitive_attr attrs;
if (scale != 1.0f) attrs.set_output_scales(0, {scale});
......@@ -158,8 +147,9 @@ class MatMulMKLDNNHandler
// Wraps the second matmul operand in a memory object described by the forward
// primitive descriptor's weights descriptor.
//
// @param input  tensor whose data<YT>() buffer is handed over; the buffer is
//               passed by pointer to AcquireMemoryFromPrimitive.
// @return shared pointer returned by AcquireMemoryFromPrimitive.
//
// NOTE(review): AcquireMemoryFromPrimitive is not defined in this view, so
// whether it copies or merely references the buffer cannot be told from here.
// NOTE(review): this text looks like a rendered diff - the first `return`
// statement (unqualified to_void_cast) appears to be the pre-change text and
// the second (phi::funcs::to_void_cast) its replacement.
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor *input) {
const YT *input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<YT>(input_data));
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(),
phi::funcs::to_void_cast<YT>(input_data));
}
public:
......@@ -350,18 +340,14 @@ bool IsOutputFused(const ExecutionContext &ctx) {
template <typename T, typename T_out>
void ExecuteMatMulV2(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine onednn_engine,
paddle::platform::Place cpu_place,
const Tensor *x,
const std::vector<int64_t> &x_dims,
bool trans_x,
const Tensor *y,
const std::vector<int64_t> &y_dims,
bool trans_y,
Tensor *out,
const std::vector<int64_t> &out_dims,
int execution_number = 0) {
Tensor *out) {
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
MatMulV2MKLDNNHandler<T, T, T_out> handler(ctx,
......@@ -399,14 +385,13 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
// TODO(jczaja): Explain why int8 format of dst is ABCD and do not need
// permute
if (IsOutputFused(ctx) && !IsInt8<T_out>()) {
if (IsOutputFused(ctx) && !phi::funcs::is_int8<T_out>()) {
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_Out");
auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
out->set_mem_desc(
permuted_md.reshape(phi::vectorize<int64_t>(out->dims())));
out->set_mem_desc(permuted_md.reshape(vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc(
dst_memory_p->get_desc().reshape(phi::vectorize<int64_t>(out->dims())));
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
}
......@@ -423,20 +408,75 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
constexpr bool is_int8 = IsInt8<T>();
constexpr bool is_bfloat16 = IsBfloat16<T>();
constexpr bool is_int8 = phi::funcs::is_int8<T>();
constexpr bool is_bfloat16 = phi::funcs::is_bfloat16<T>();
const bool force_fp32_output = ctx.HasAttr("force_fp32_output")
? ctx.Attr<bool>("force_fp32_output")
: false;
constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
auto *x = ctx.Input<phi::DenseTensor>("X");
auto *y = ctx.Input<phi::DenseTensor>("Y");
auto *out = ctx.Output<phi::DenseTensor>("Out");
bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr<bool>("trans_x")
: ctx.Attr<bool>("transpose_X");
bool trans_y = ctx.HasAttr("trans_y") ? ctx.Attr<bool>("trans_y")
: ctx.Attr<bool>("transpose_Y");
auto x_dims = vectorize(GetDimForInput(ctx, "X"));
auto y_dims = vectorize(GetDimForInput(ctx, "Y"));
int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3);
std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);
CalculateMatrixDims(ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, out);
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
RunKernel<float>(ctx);
ExecuteMatMulV2<T, float>(ctx,
onednn_engine,
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
} else if (is_bfloat16) {
RunKernel<paddle::platform::bfloat16>(ctx);
ExecuteMatMulV2<T, paddle::platform::bfloat16>(ctx,
onednn_engine,
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
} else if (fuse_relu) {
RunKernel<uint8_t>(ctx);
ExecuteMatMulV2<T, uint8_t>(ctx,
onednn_engine,
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
} else {
RunKernel<int8_t>(ctx);
ExecuteMatMulV2<T, int8_t>(ctx,
onednn_engine,
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
}
}
......@@ -446,7 +486,6 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
const std::vector<int64_t> &y_dims,
std::vector<int64_t> *x_bd_dims,
std::vector<int64_t> *y_bd_dims,
std::vector<int64_t> *out_dims,
Tensor *out) const {
if (x_dims.size() == 1) {
(*x_bd_dims)[(*x_bd_dims).size() - 1] = x_dims[0];
......@@ -470,6 +509,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
}
if (!IsOutputFused(ctx) && x_dims.size() > 2 && y_dims.size() > 2) {
auto out_dims = vectorize(out->dims());
for (size_t i = 0; i < (*x_bd_dims).size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
(*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
......@@ -483,126 +523,194 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
(*x_bd_dims)[i],
i,
(*y_bd_dims)[i]));
(*out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
(out_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
}
out->Resize(phi::make_ddim((*out_dims)));
out->Resize(phi::make_ddim((out_dims)));
}
}
};
template <typename T_out>
void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
auto *x = ctx.Input<phi::DenseTensor>("X");
auto *y = ctx.Input<phi::DenseTensor>("Y");
auto *out = ctx.Output<phi::DenseTensor>("Out");
bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr<bool>("trans_x")
: ctx.Attr<bool>("transpose_X");
bool trans_y = ctx.HasAttr("trans_y") ? ctx.Attr<bool>("trans_y")
: ctx.Attr<bool>("transpose_Y");
auto x_dims = vectorize(GetDimForInput(ctx, "X"));
auto y_dims = vectorize(GetDimForInput(ctx, "Y"));
auto out_dims = vectorize(out->dims());
template <typename T>
class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext &ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"),
1,
paddle::platform::errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3);
const auto &dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
std::vector<int64_t> x_bd_dims(ndims, 1);
std::vector<int64_t> y_bd_dims(ndims, 1);
auto x = *ctx.Input<phi::DenseTensor>("X");
auto y = *ctx.Input<phi::DenseTensor>("Y");
auto dout =
*ctx.Input<phi::DenseTensor>(paddle::framework::GradVarName("Out"));
auto *dx =
ctx.Output<phi::DenseTensor>(paddle::framework::GradVarName("X"));
auto *dy =
ctx.Output<phi::DenseTensor>(paddle::framework::GradVarName("Y"));
bool transpose_x = ctx.HasAttr("transpose_X")
? ctx.Attr<bool>("transpose_X")
: ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.HasAttr("transpose_Y")
? ctx.Attr<bool>("transpose_Y")
: ctx.Attr<bool>("trans_y");
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
paddle::framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
}
CalculateMatrixDims(
ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, &out_dims, out);
paddle::framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
}
}
ExecuteMatMulV2<T, T_out>(ctx,
if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &y, true, true, &dout, true, false, dx);
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &dout, true, true, &x, true, false, dy);
} else if (transpose_x) {
this->ExecuteMatMulGrad(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out,
out_dims);
}
};
template <typename T>
class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
&y,
false,
false,
&dout,
true,
false,
dx);
this->ExecuteMatMulGrad(ctx,
dev_ctx,
onednn_engine,
&x,
false,
false,
&dout,
false,
true,
dy);
} else if (transpose_y) {
this->ExecuteMatMulGrad(ctx,
dev_ctx,
onednn_engine,
&dout,
false,
false,
&y,
false,
true,
dx);
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &dout, true, true, &x, false, true, dy);
} else {
this->ExecuteMatMulGrad(ctx,
dev_ctx,
onednn_engine,
&dout,
false,
false,
&y,
true,
false,
dx);
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &x, true, true, &dout, false, true, dy);
}
private:
void CalculateGradMatrixDims(const ExecutionContext &ctx,
Tensor *dx_tmp,
Tensor *dy_tmp,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &dy_dims,
std::vector<int64_t> *dx_bd_dims,
std::vector<int64_t> *dy_bd_dims) const {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
(*dx_bd_dims)[i] = dy_dims[i];
} else {
(*dy_bd_dims)[i] = dx_dims[i];
}
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
dx->set_mem_desc(x.mem_desc());
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
dy->set_mem_desc(y.mem_desc());
}
}
dx_tmp->Resize(phi::make_ddim((*dx_bd_dims)));
dx_tmp->mutable_data<T>(ctx.GetPlace());
dy_tmp->Resize(phi::make_ddim((*dy_bd_dims)));
dy_tmp->mutable_data<T>(ctx.GetPlace());
}
void ReduceSumForMatmulGradOutput(
const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine onednn_engine,
const Tensor *dx_tmp,
Tensor *dx,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &squeezed_dims) const {
paddle::platform::ReductionMKLDNNHandler<T> handler(
dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
onednn_engine,
ctx.GetPlace(),
dx_tmp,
dx,
dx_dims);
private:
void ExecuteMatMulGrad(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine &engine,
phi::DenseTensor *x,
bool trans_x,
bool is_fold_init_dims_x,
phi::DenseTensor *y,
bool trans_y,
bool is_fold_init_dims_y,
phi::DenseTensor *out) const {
// gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2;
Tensor x_combined, y_combined;
if (!need_combine) {
x_combined = *x;
y_combined = *y;
} else {
x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
: FoldFirstAndLastDims<T>(dev_ctx, x);
y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
: FoldFirstAndLastDims<T>(dev_ctx, y);
}
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
MatMulMKLDNNHandler<T, T, T> handler(engine,
ctx.GetPlace(),
&x_combined,
trans_x,
&y_combined,
trans_y,
out,
alpha);
auto &astream = MKLDNNDeviceContext::tls().get_stream();
auto reduction_p = handler.AcquireForwardPrimitive();
const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
const auto dst_memory_p = handler.AcquireDstMemory(out);
reduction_p->execute(astream, reduction_args);
astream.wait();
auto matmul_p = handler.AcquireForwardPrimitive();
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
}
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
int new_size) const {
std::vector<int64_t> new_dims(new_size, 1);
for (size_t i = 0; i < dims.size(); ++i) {
new_dims[new_size - dims.size() + i] = dims[i];
}
auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
return new_dims;
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
};
void RunKernel(const ExecutionContext &ctx) const {
template <typename T>
class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext &ctx) const override {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
......@@ -660,113 +768,39 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
if (trans_x && trans_y) {
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
y,
y_dims,
true,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
true,
&dy_tmp,
dy_bd_dims,
2);
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, y, y_dims, true, dout, dout_dims, true, &dx_tmp);
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, dout, dout_dims, true, x, x_dims, true, &dy_tmp);
} else if (trans_x) {
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, y, y_dims, false, dout, dout_dims, true, &dx_tmp);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
y,
y_dims,
false,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
false,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
&dy_tmp);
} else if (trans_y) {
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
false,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
&dx_tmp);
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, dout, dout_dims, true, x, x_dims, false, &dy_tmp);
} else {
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
true,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, dout, dout_dims, false, y, y_dims, true, &dx_tmp);
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, x, x_dims, true, dout, dout_dims, false, &dy_tmp);
}
if (x_dims != dx_bd_dims) {
......@@ -776,7 +810,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
&dx_tmp,
dx,
x_dims,
phi::vectorize(x->dims()));
vectorize(x->dims()));
} else {
*dx = std::move(dx_tmp);
}
......@@ -787,7 +821,7 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
&dy_tmp,
dy,
y_dims,
phi::vectorize(y->dims()));
vectorize(y->dims()));
} else {
*dy = std::move(dy_tmp);
}
......@@ -797,162 +831,76 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
}
private:
paddle::operators::MatMulGradMKLDNNKernel<T> matmul_v1_grad_mkldnn_kernel;
};
} // anonymous namespace
namespace paddle {
namespace operators {
// Kernel entry point for the matmul gradient.
//
// When the op carries a `head_number` attribute it must be exactly 1;
// any other value is reported as Unimplemented. After that check the
// work is handed to RunKernel(ctx).
template <typename T>
void MatMulGradMKLDNNKernel<T>::Compute(const ExecutionContext &ctx) const {
  const bool has_head_number = ctx.HasAttr("head_number");
  if (has_head_number) {
    // Read the attribute once and reuse it for both the comparison and the
    // error message.
    const int head_number = ctx.Attr<int>("head_number");
    PADDLE_ENFORCE_EQ(
        head_number,
        1,
        platform::errors::Unimplemented(
            "oneDNN matmul doesn't support multiple heads. Expected "
            "head_number=1. But received `head_number` is %d",
            head_number));
  }

  RunKernel(ctx);
}
void CalculateGradMatrixDims(const ExecutionContext &ctx,
Tensor *dx_tmp,
Tensor *dy_tmp,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &dy_dims,
std::vector<int64_t> *dx_bd_dims,
std::vector<int64_t> *dy_bd_dims) const {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
(*dx_bd_dims)[i] = dy_dims[i];
} else {
(*dy_bd_dims)[i] = dx_dims[i];
}
}
}
template <typename T>
void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine &engine,
Tensor *x,
bool trans_x,
bool is_fold_init_dims_x,
Tensor *y,
bool trans_y,
bool is_fold_init_dims_y,
Tensor *out) const {
// gradient is calculated in a different way when broadcasting is used
bool need_combine = (x->dims().size() == 3 || y->dims().size() == 3) &&
out->dims().size() == 2;
Tensor x_combined, y_combined;
if (!need_combine) {
x_combined = *x;
y_combined = *y;
} else {
x_combined = is_fold_init_dims_x ? FoldOuterDims(*x)
: FoldFirstAndLastDims<T>(dev_ctx, x);
y_combined = is_fold_init_dims_y ? FoldOuterDims(*y)
: FoldFirstAndLastDims<T>(dev_ctx, y);
dx_tmp->Resize(phi::make_ddim((*dx_bd_dims)));
dx_tmp->mutable_data<T>(ctx.GetPlace());
dy_tmp->Resize(phi::make_ddim((*dy_bd_dims)));
dy_tmp->mutable_data<T>(ctx.GetPlace());
}
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
MatMulMKLDNNHandler<T, T, T> handler(engine,
ctx.GetPlace(),
&x_combined,
trans_x,
&y_combined,
trans_y,
out,
alpha);
const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
const auto dst_memory_p = handler.AcquireDstMemory(out);
void ReduceSumForMatmulGradOutput(
const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine onednn_engine,
const Tensor *dx_tmp,
Tensor *dx,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &squeezed_dims) const {
phi::funcs::ReductionOneDNNHandler<T> handler(
dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
onednn_engine,
ctx.GetPlace(),
dx_tmp,
dx,
dx_dims);
auto matmul_p = handler.AcquireForwardPrimitive();
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
auto &astream = MKLDNNDeviceContext::tls().get_stream();
auto reduction_p = handler.AcquireForwardPrimitive();
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
reduction_p->execute(astream, reduction_args);
astream.wait();
template <typename T>
void MatMulGradMKLDNNKernel<T>::RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
auto x = *ctx.Input<phi::DenseTensor>("X");
auto y = *ctx.Input<phi::DenseTensor>("Y");
auto dout = *ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto *dx = ctx.Output<phi::DenseTensor>(framework::GradVarName("X"));
auto *dy = ctx.Output<phi::DenseTensor>(framework::GradVarName("Y"));
bool transpose_x = ctx.HasAttr("transpose_X") ? ctx.Attr<bool>("transpose_X")
: ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.HasAttr("transpose_Y") ? ctx.Attr<bool>("transpose_Y")
: ctx.Attr<bool>("trans_y");
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
framework::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
dx->Resize(x.dims());
}
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
}
framework::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
dy->Resize(y.dims());
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
int new_size) const {
std::vector<int64_t> new_dims(new_size, 1);
for (size_t i = 0; i < dims.size(); ++i) {
new_dims[new_size - dims.size() + i] = dims[i];
}
}
if (transpose_x && transpose_y) {
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &y, true, true, &dout, true, false, dx);
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &dout, true, true, &x, true, false, dy);
} else if (transpose_x) {
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &y, false, false, &dout, true, false, dx);
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &x, false, false, &dout, false, true, dy);
} else if (transpose_y) {
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &dout, false, false, &y, false, true, dx);
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &dout, true, true, &x, false, true, dy);
} else {
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &dout, false, false, &y, true, false, dx);
this->ExecuteMatMulGrad(
ctx, dev_ctx, onednn_engine, &x, true, true, &dout, false, true, dy);
}
if (dx) {
if (dx_dims != x.dims()) {
dx->Resize(dx_dims);
dx->set_mem_desc(x.mem_desc());
}
}
if (dy) {
if (dy_dims != y.dims()) {
dy->Resize(dy_dims);
dy->set_mem_desc(y.mem_desc());
}
return new_dims;
}
}
template class MatMulGradMKLDNNKernel<float>;
template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
private:
MatMulGradMKLDNNKernel<T> matmul_v1_grad_mkldnn_kernel;
};
} // anonymous namespace
REGISTER_OP_KERNEL(matmul,
MKLDNN,
......@@ -965,8 +913,8 @@ REGISTER_OP_KERNEL(matmul,
REGISTER_OP_KERNEL(matmul_grad,
MKLDNN,
::paddle::platform::CPUPlace,
ops::MatMulGradMKLDNNKernel<float>,
ops::MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);
MatMulGradMKLDNNKernel<float>,
MatMulGradMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(matmul_v2,
MKLDNN,
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
// Forward activation kernel that adapts a functor to the OpKernel interface.
//
// The Functor type is expected to provide, as used below:
//   * ELEMENT_TYPE  - element type this kernel is instantiated for,
//   * GetAttrs()    - a range of (attribute name, destination) pairs,
//   * operator()(const framework::ExecutionContext&) - the actual computation.
template <typename Functor>
class MKLDNNActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  // Verifies that input "X" and output "Out" are present, copies the float
  // attributes from the context into the functor, then invokes the functor.
  void Compute(const framework::ExecutionContext& context) const override {
    OP_INOUT_CHECK(context.HasInput("X"), "Input", "X", "Activation");
    // "Out" is an output slot, so it has to be looked up among the outputs.
    // The previous HasInput("Out") query searched the input map instead,
    // which contradicts the "Output" role reported by this check.
    OP_INOUT_CHECK(context.HasOutput("Out"), "Output", "Out", "Activation");

    Functor functor;
    // For every (name, destination) pair the functor exposes, read the float
    // attribute `name` from the op and store it through `destination`.
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(context);
  }
};
// Gradient activation kernel that adapts a functor to the OpKernel interface.
//
// Unlike the forward wrapper, no input/output presence checks are made here:
// the float attributes are loaded into the functor and the functor is run.
template <typename Functor>
class MKLDNNActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    Functor grad_functor;

    // Each entry pairs an attribute name (`first`) with the location that
    // should receive its float value (`second`).
    auto attr_slots = grad_functor.GetAttrs();
    for (auto& name_and_dest : attr_slots) {
      const auto& attr_name = name_and_dest.first;
      *name_and_dest.second = context.Attr<float>(attr_name);
    }

    grad_functor(context);
  }
};
} // namespace operators
} // namespace paddle
......@@ -30,7 +30,6 @@ using LoDTensor = phi::DenseTensor;
using platform::MatMulV2MKLDNNHandler;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
using dnnl::inner_product_forward;
using dnnl::memory;
......@@ -73,7 +72,7 @@ class MulPrimitiveFactory {
return *(mul_);
}
auto src_desc = CreateMemDescriptor<XT>(&x_matrix, MKLDNNMemoryFormat::nc);
auto src_desc = CreateMemDescriptor<XT>(&x_matrix, OneDNNMemoryFormat::nc);
x_input_ = CreateMemory<XT>(src_desc, &x_matrix);
if (is_int8_) {
......@@ -84,7 +83,7 @@ class MulPrimitiveFactory {
y_input_ = TransposeInputY(&y_matrix);
}
auto dst_desc = CreateMemDescriptor<OT>(output, MKLDNNMemoryFormat::any);
auto dst_desc = CreateMemDescriptor<OT>(output, OneDNNMemoryFormat::any);
mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx);
Execute();
......@@ -126,8 +125,8 @@ class MulPrimitiveFactory {
auto ndims = input_y.get_desc().data.ndims;
auto y_dims = std::vector<int64_t>(dims, dims + ndims);
auto user_y_desc = CreateMemDescriptor<YT>(y_dims, MKLDNNMemoryFormat::oi);
auto y_desc = CreateMemDescriptor<int8_t>(y_dims, MKLDNNMemoryFormat::oi);
auto user_y_desc = CreateMemDescriptor<YT>(y_dims, OneDNNMemoryFormat::oi);
auto y_desc = CreateMemDescriptor<int8_t>(y_dims, OneDNNMemoryFormat::oi);
return ReorderWithScale(
user_y_desc, y_desc, input_y.get_data_handle(), scale_y);
......@@ -205,8 +204,8 @@ class MulPrimitiveFactory {
auto dst_mdesc =
data->dims().size() >= 4
? (data->dims().size() == 5
? CreateMemDescriptor<T>(data, MKLDNNMemoryFormat::ncdhw)
: CreateMemDescriptor<T>(data, MKLDNNMemoryFormat::nchw))
? CreateMemDescriptor<T>(data, OneDNNMemoryFormat::ncdhw)
: CreateMemDescriptor<T>(data, OneDNNMemoryFormat::nchw))
: src_mdesc;
if (src_mdesc != dst_mdesc) {
......@@ -214,8 +213,8 @@ class MulPrimitiveFactory {
Reorder(src_mdesc,
dst_mdesc,
to_void_cast<T>(data->data<T>()),
to_void_cast<T>(x_tmp.data<T>()));
phi::funcs::to_void_cast<T>(data->data<T>()),
phi::funcs::to_void_cast<T>(x_tmp.data<T>()));
x_tmp.Resize(data->dims());
x_tmp.set_mem_desc(dst_mdesc);
......@@ -230,7 +229,7 @@ class MulPrimitiveFactory {
void UpdateDataPointers(const ExecutionContext &ctx,
Tensor *out,
const Tensor *in) {
x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
x_input_->set_data_handle(phi::funcs::to_void_cast<XT>(in->data<XT>()));
output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
out->set_mem_desc(output_->get_desc());
}
......@@ -238,23 +237,24 @@ class MulPrimitiveFactory {
template <typename T>
memory::desc CreateMemDescriptor(
const Tensor *tensor,
MKLDNNMemoryFormat format,
memory::data_type type = platform::MKLDNNGetDataType<T>()) {
OneDNNMemoryFormat format,
memory::data_type type = phi::funcs::OneDNNGetDataType<T>()) {
auto dims = phi::vectorize<int64_t>(tensor->dims());
return platform::MKLDNNMemDesc(dims, type, format);
return phi::funcs::OneDNNMemDesc(dims, type, format);
}
template <typename T>
memory::desc CreateMemDescriptor(
const std::vector<int64_t> &dims,
MKLDNNMemoryFormat format,
memory::data_type type = platform::MKLDNNGetDataType<T>()) {
return platform::MKLDNNMemDesc(dims, type, format);
OneDNNMemoryFormat format,
memory::data_type type = phi::funcs::OneDNNGetDataType<T>()) {
return phi::funcs::OneDNNMemDesc(dims, type, format);
}
template <typename T>
memory CreateMemory(const memory::desc &desc, const Tensor *tensor) {
return memory(desc, engine_, to_void_cast<T>(tensor->data<T>()));
return memory(
desc, engine_, phi::funcs::to_void_cast<T>(tensor->data<T>()));
}
memory CreateDstMemory(
......@@ -266,7 +266,7 @@ class MulPrimitiveFactory {
OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
output->set_mem_desc(dst_desc);
return memory(dst_desc, engine_, to_void_cast<OT>(output_data));
return memory(dst_desc, engine_, phi::funcs::to_void_cast<OT>(output_data));
}
memory Reorder(const memory::desc &src_desc,
......@@ -296,9 +296,10 @@ class MulPrimitiveFactory {
memory TransposeInputY(const Tensor *input_y) {
auto dims = phi::vectorize<int64_t>(input_y->dims());
std::swap(dims[0], dims[1]); // Correct output dimensions
auto src_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::oi);
return Reorder(src_desc, dst_desc, to_void_cast<YT>(input_y->data<YT>()));
auto src_desc = CreateMemDescriptor<YT>(dims, OneDNNMemoryFormat::io);
auto dst_desc = CreateMemDescriptor<YT>(dims, OneDNNMemoryFormat::oi);
return Reorder(
src_desc, dst_desc, phi::funcs::to_void_cast<YT>(input_y->data<YT>()));
}
const dnnl::engine &engine_;
......
......@@ -25,7 +25,6 @@ namespace operators {
using dnnl::memory;
using dnnl::primitive;
using dnnl::reorder;
using platform::to_void_cast;
using Tensor = phi::DenseTensor;
using dnnl::stream;
using phi::DataLayout;
......@@ -72,28 +71,24 @@ class QuantOpKernel : public framework::OpKernel<T> {
DNNL_ARG_DST, mask, {static_cast<int32_t>(quantization_shift)});
}
framework::proto::VarType::Type x_paddle_dtype =
framework::TransToProtoVarType(x->dtype());
framework::proto::VarType::Type out_paddle_dtype;
auto x_type = phi::funcs::ToOneDNNDataType(x->dtype());
DataType out_dtype;
if (bfloat16) {
out_paddle_dtype = framework::proto::VarType::BF16;
out_dtype = DataType::BFLOAT16;
} else if (is_negative_input && !with_shift) {
out_paddle_dtype = framework::proto::VarType::INT8;
out_dtype = DataType::INT8;
} else {
out_paddle_dtype = framework::proto::VarType::UINT8;
out_dtype = DataType::UINT8;
}
platform::ReorderMKLDNNHandler reorder_handler(
x_tz,
x_paddle_dtype,
framework::ToMKLDNNDataType(x_paddle_dtype),
out_paddle_dtype,
framework::ToMKLDNNDataType(out_paddle_dtype),
dev_ctx.GetEngine());
auto out_type = phi::funcs::ToOneDNNDataType(out_dtype);
phi::funcs::ReorderOneDNNHandler reorder_handler(
x_tz, x->dtype(), x_type, out_dtype, out_type, dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
x->mem_desc(), phi::funcs::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, x->mem_desc(), dev_ctx.GetPlace());
......
......@@ -24,7 +24,6 @@ namespace operators {
using dnnl::memory;
using dnnl::reorder;
using platform::to_void_cast;
using Tensor = phi::DenseTensor;
namespace {
......@@ -88,10 +87,10 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
if (reorder_p == nullptr) {
auto src_dt = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(input->dtype()));
auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt;
auto dst_dt = with_shift ? framework::OneDNNDataType::u8 : src_dt;
src_memory = std::make_shared<dnnl::memory>(
input->mem_desc(), engine, to_void_cast<T>(input_data));
input->mem_desc(), engine, phi::funcs::to_void_cast<T>(input_data));
auto xstrides = input->mem_desc().data.format_desc.blocking.strides;
......@@ -112,11 +111,11 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
clip_to_uint8(shift_out - reorder_scale * shift_in);
std::memset(output_data, reorder_shift, output->numel());
dst_memory = std::make_shared<dnnl::memory>(
dst_md, engine, to_void_cast<uint8_t>(output_data));
dst_md, engine, phi::funcs::to_void_cast<uint8_t>(output_data));
} else {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory = std::make_shared<dnnl::memory>(
dst_md, engine, to_void_cast<T>(output_data));
dst_md, engine, phi::funcs::to_void_cast<T>(output_data));
}
auto reorder_pd =
......@@ -129,7 +128,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
} else {
src_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
src_memory->set_data_handle(to_void_cast<T>(input_data));
src_memory->set_data_handle(phi::funcs::to_void_cast<T>(input_data));
dst_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
......
......@@ -30,8 +30,6 @@ enum class ReshapeKernelOpName {
namespace paddle {
namespace operators {
using platform::to_void_cast;
static std::vector<int> extract_shape(
const std::vector<const phi::DenseTensor*>& list_new_shape_tensor) {
std::vector<int> vec_new_shape;
......@@ -73,16 +71,12 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
auto x_vec_dims = phi::vectorize(x_dims);
dnnl::memory::data_type x_type =
framework::ToMKLDNNDataType(framework::TransToProtoVarType(x->dtype()));
platform::ReorderMKLDNNHandler reorder_handler(
x_vec_dims,
framework::TransToProtoVarType(x->dtype()),
x_type,
onednn_engine);
auto x_type = phi::funcs ::ToOneDNNDataType(x->dtype());
phi::funcs::ReorderOneDNNHandler reorder_handler(
x_vec_dims, x->dtype(), x_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
x->mem_desc(), phi::funcs::to_void_cast(x->data<T>()));
out->Resize(x_dims); // to match x numel, format is changed later
// reorder is done into a plain tag to allow usage with blocked formats
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
......@@ -347,16 +341,12 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel<T, op_name> {
auto dout_vec_dims = phi::vectorize(dout->dims());
dnnl::memory::data_type dout_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(dout->dtype()));
platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims,
framework::TransToProtoVarType(dout->dtype()),
dout_type,
onednn_engine);
auto dout_type = phi::funcs::ToOneDNNDataType(dout->dtype());
phi::funcs::ReorderOneDNNHandler reorder_handler(
dout_vec_dims, dout->dtype(), dout_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
dout->mem_desc(), phi::funcs::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx, this->getPlainFormatTag(dout), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
......
......@@ -17,17 +17,16 @@ limitations under the License. */
namespace paddle {
namespace operators {
using platform::MKLDNNGetDataType;
template <typename T>
class ShuffleChannelMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::shuffle_forward> {
: public phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::shuffle_forward> {
public:
ShuffleChannelMKLDNNHandler(const phi::DenseTensor* x,
const int group,
const dnnl::engine engine,
platform::Place cpu_place)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::shuffle_forward>(engine,
cpu_place) {
: phi::funcs::OneDNNHandlerNoCachingT<T, dnnl::shuffle_forward>(
engine, cpu_place) {
static constexpr int channel_axis = 1;
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_training, x->mem_desc(), channel_axis, group);
......
......@@ -53,19 +53,17 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto x_vec_dims = phi::vectorize(x->dims());
framework::proto::VarType::Type x_paddle_type =
framework::TransToProtoVarType(x->dtype());
dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(x_paddle_type);
platform::ReorderMKLDNNHandler reorder_handler(
x_vec_dims, x_paddle_type, x_type, dnnl_engine);
auto x_type = phi::funcs::ToOneDNNDataType(x->dtype());
phi::funcs::ReorderOneDNNHandler reorder_handler(
x_vec_dims, x->dtype(), x_type, dnnl_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
x->mem_desc(), phi::funcs::to_void_cast(x->data<T>()));
auto dst_md =
dnnl::memory::desc(x_vec_dims,
x->mem_desc().data_type(),
platform::GetPlainMKLDNNFormat(x_vec_dims.size()));
phi::funcs::GetPlainOneDNNFormat(x_vec_dims.size()));
// a trick is used here to fake transpose of out_md, so later it will be
// "untransposed", leaving output data in plain format tag
auto dst_strides = FakeTranposeStrides(dst_md, transpose_axis);
......@@ -148,17 +146,13 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
}
auto dout_vec_dims = phi::vectorize(dout->dims());
auto dout_type = phi::funcs::ToOneDNNDataType(dout->dtype());
framework::proto::VarType::Type dout_paddle_type =
framework::TransToProtoVarType(dout->dtype());
dnnl::memory::data_type dout_type =
framework::ToMKLDNNDataType(dout_paddle_type);
platform::ReorderMKLDNNHandler reorder_handler(
dout_vec_dims, dout_paddle_type, dout_type, dnnl_engine);
phi::funcs::ReorderOneDNNHandler reorder_handler(
dout_vec_dims, dout->dtype(), dout_type, dnnl_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
dout->mem_desc(), phi::funcs::to_void_cast(dout->data<T>()));
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(dx, dout->mem_desc(), ctx.GetPlace());
......
......@@ -708,7 +708,7 @@ class Pad2dOp : public framework::OperatorWithKernel {
.data.format_desc.blocking.inner_nblks == 0) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
......@@ -720,8 +720,8 @@ class Pad2dOp : public framework::OperatorWithKernel {
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......
......@@ -42,7 +42,7 @@ class Pad3dOp : public framework::OperatorWithKernel {
.data.format_desc.blocking.inner_nblks == 0) {
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
......@@ -54,8 +54,8 @@ class Pad3dOp : public framework::OperatorWithKernel {
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......
......@@ -58,8 +58,8 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......@@ -92,8 +92,8 @@ framework::OpKernelType PoolOpGrad::GetKernelTypeForVar(
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
if ((expected_kernel_type.data_layout_ == phi::DataLayout::kMKLDNN) &&
(tensor.layout() != phi::DataLayout::kMKLDNN)) {
if ((expected_kernel_type.data_layout_ == phi::DataLayout::ONEDNN) &&
(tensor.layout() != phi::DataLayout::ONEDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
......
......@@ -24,7 +24,7 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
......
......@@ -24,7 +24,7 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
......
......@@ -167,7 +167,7 @@ class SliceOp : public framework::OperatorWithKernel {
.data.format_desc.blocking.inner_nblks == 0)
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
......@@ -340,7 +340,7 @@ class SliceOpGrad : public framework::OperatorWithKernel {
.data.format_desc.blocking.inner_nblks == 0)
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
......
......@@ -124,7 +124,7 @@ class SplitOp : public framework::OperatorWithKernel {
if (x_md.data.format_desc.blocking.inner_nblks == 0)
return framework::OpKernelType(input_data_type,
ctx.GetPlace(),
phi::DataLayout::kMKLDNN,
phi::DataLayout::ONEDNN,
framework::LibraryType::kMKLDNN);
}
#endif
......
......@@ -128,7 +128,7 @@ class SqueezeOp : public framework::OperatorWithKernel {
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// phi::DataLayout::kMKLDNN,
// phi::DataLayout::ONEDNN,
// framework::LibraryType::kMKLDNN);
// }
// #endif
......@@ -155,7 +155,7 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// phi::DataLayout::kMKLDNN,
// phi::DataLayout::ONEDNN,
// framework::LibraryType::kMKLDNN);
// }
// #endif
......@@ -222,7 +222,7 @@ class Squeeze2Op : public framework::OperatorWithKernel {
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// phi::DataLayout::kMKLDNN,
// phi::DataLayout::ONEDNN,
// framework::LibraryType::kMKLDNN);
// }
// #endif
......@@ -270,7 +270,7 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
// #ifdef PADDLE_WITH_MKLDNN
// if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
// return framework::OpKernelType(input_data_type, ctx.GetPlace(),
// phi::DataLayout::kMKLDNN,
// phi::DataLayout::ONEDNN,
// framework::LibraryType::kMKLDNN);
// }
// #endif
......
......@@ -49,7 +49,7 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
auto *in_tensor = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in);
// NOTE(zhiqiu): hot fix, allow empty tensor of kMKLDNN layout to run this
// op
if (in_tensor->layout() != DataLayout::kMKLDNN) {
if (in_tensor->layout() != DataLayout::ONEDNN) {
PADDLE_ENFORCE_EQ(in_tensor->IsInitialized(),
true,
platform::errors::PreconditionNotMet(
......
......@@ -63,33 +63,32 @@ class TransferLayoutFunctor {
auto in_layout = static_cast<DataLayout>(src_layout_);
auto *tensor_out = out_->GetMutable<phi::DenseTensor>();
VLOG(4) << in_layout << "->" << out_layout << " " << in_tensor.layout();
if (!in_tensor.IsInitialized() && in_layout == DataLayout::kMKLDNN &&
if (!in_tensor.IsInitialized() && in_layout == DataLayout::ONEDNN &&
out_layout == DataLayout::kNHWC) {
tensor_out->Resize(in_tensor.dims());
tensor_out->set_layout(out_layout);
platform::MatchShapeToLayout(tensor_out, in_layout, out_layout);
phi::funcs::MatchShapeToLayout(tensor_out, in_layout, out_layout);
return;
}
if (in_layout == DataLayout::kMKLDNN || out_layout == DataLayout::kMKLDNN) {
if (in_layout == DataLayout::ONEDNN || out_layout == DataLayout::ONEDNN) {
PADDLE_ENFORCE_NE(
in_layout,
out_layout,
platform::errors::PreconditionNotMet(
"No layout transform needed between two MKLDNN OPKernels."));
"No layout transform needed between two oneDNN OPKernels."));
if (in_layout != DataLayout::kMKLDNN &&
out_layout == DataLayout::kMKLDNN) {
if (in_layout != DataLayout::ONEDNN && out_layout == DataLayout::ONEDNN) {
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur
auto out_format = platform::MKLDNNFormatForSize(
in_tensor.dims().size(), framework::ToMKLDNNFormat(in_layout));
auto out_format = phi::funcs::OneDNNFormatForSize(
in_tensor.dims().size(), framework::ToOneDNNFormat(in_layout));
out_tensor.ShareDataWith(in_tensor);
// For NHWC data we need reshape of tensors as MKL-DNN
// is expecting NHWC dims description order
if (in_layout == DataLayout::kNHWC) {
VLOG(4) << "kNHWC";
platform::MatchShapeToLayout(&out_tensor, in_layout, out_layout);
phi::funcs::MatchShapeToLayout(&out_tensor, in_layout, out_layout);
paddle::platform::MKLDNNDeviceContext::tls()
.set_cur_paddle_data_layout(in_layout);
}
......
......@@ -25,121 +25,19 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/backends/onednn/onednn_helper.h"
namespace paddle {
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNMemoryFormat = dnnl::memory::format_tag;
using OneDNNMemoryFormat = dnnl::memory::format_tag;
#endif
namespace platform {
using MKLDNNStream = dnnl::stream;
using MKLDNNEngine = dnnl::engine;
using MKLDNNMemory = dnnl::memory;
using MKLDNNMemoryDescriptor = dnnl::memory::desc;
using MKLDNNPrimitive = dnnl::primitive;
using MKLDNNPrimitiveDesc = dnnl::handle<dnnl_primitive_desc_t>;
typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
// Drops the const qualifier from a typed pointer and returns the same address
// as an untyped void*. Nothing is copied; the caller keeps ownership.
template <typename Type>
void* to_void_cast(const Type* t) {
  Type* mutable_ptr = const_cast<Type*>(t);
  return static_cast<void*>(mutable_ptr);
}
// Same address, const removed, reinterpreted as an untyped void*.
// Nothing is copied; the caller keeps ownership of the pointee.
template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
  Type* mutable_ptr = const_cast<Type*>(t);
  return reinterpret_cast<void*>(mutable_ptr);
}
// Shorthand for the nested `desc` member type of `Primitive`.
template <typename Primitive>
using tf_desc = typename Primitive::desc;
// Shorthand for the nested `primitive_desc` member type of `Primitive`.
template <typename Primitive>
using tf_pd = typename Primitive::primitive_desc;
// Builds a forward primitive descriptor for `Type` on engine `e`.
//
// The operation descriptor is constructed from dnnl::prop_kind::forward
// followed by `args`. Note that `args` are passed on as lvalues; they are not
// perfectly forwarded. The primitive descriptor is then created from that
// descriptor and `e`, and returned owned by a std::shared_ptr.
template <typename Type, typename Engine, typename... Args>
std::shared_ptr<tf_pd<Type>> MKLDNNFwdPrimitiveDesc(const Engine& e,
                                                    Args&&... args) {
  auto desc = tf_desc<Type>(dnnl::prop_kind::forward, (args)...);
  // Raw allocation is handed to the shared_ptr on the next line.
  auto pd = new tf_pd<Type>(desc, e);
  return std::shared_ptr<tf_pd<Type>>(pd);
}
// Builds a backward primitive descriptor for `Type` and returns it by value.
//
// The operation descriptor is constructed from `args` alone (no prop_kind is
// prepended here). The primitive descriptor is created from that descriptor,
// the engine `e`, and `p`.
// NOTE(review): `p` is presumably the matching forward primitive descriptor
// used as a hint - confirm against callers.
template <typename Type, typename Engine, typename Primitive, typename... Args>
tf_pd<Type> MKLDNNBwdPrimitiveDesc(const Engine& e,
                                   const Primitive& p,
                                   Args&&... args) {
  auto desc = tf_desc<Type>(args...);
  return tf_pd<Type>(desc, e, p);
}
// Rewrites the shape of `tensor_in` in place so that its dimension order
// matches the `to` layout when converting between kMKLDNN and the
// channel-last layouts (kNHWC / kNDHWC). Only the dims are touched here.
//
// Tensors with fewer than 3 dimensions are left untouched: the channel is
// either the 2nd or the last dimension, so both orders coincide for them.
// Every other (from, to) combination is a no-op.
inline void MatchShapeToLayout(phi::DenseTensor* tensor_in,
                               phi::DataLayout from,
                               phi::DataLayout to) {
  // Renders dims as "[d0,d1,...,dn]"; an empty vector renders as "".
  auto print_dims = [](const std::vector<int>& dims) -> std::string {
    if (dims.empty()) {
      return std::string();
    }
    std::ostringstream oss;
    oss << "[";
    // Every element but the last is followed by a comma.
    for (size_t i = 0; i + 1 < dims.size(); ++i) {
      oss << dims[i] << ",";
    }
    oss << dims.back() << "]";
    return oss.str();
  };

  if (tensor_in->dims().size() < 3) {
    VLOG(3) << "Keeping kMKLDNN/kNHWC/kNDHWC output_shape"
            << print_dims(phi::vectorize<int>(tensor_in->dims()));
    return;
  }

  const bool from_channel_last =
      (from == phi::DataLayout::kNHWC) || (from == phi::DataLayout::kNDHWC);
  const bool to_channel_last =
      (to == phi::DataLayout::kNHWC) || (to == phi::DataLayout::kNDHWC);

  if (from == phi::DataLayout::kMKLDNN && to_channel_last) {
    // Move the 2nd dimension (channel) to the end: N,C,... -> N,...,C
    auto shape = phi::vectorize<int>(tensor_in->dims());
    std::rotate(shape.begin() + 1, shape.begin() + 2, shape.end());
    tensor_in->Resize(phi::make_ddim(shape));
    VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC/kNDHWC output_shape"
            << print_dims(shape);
  } else if (from_channel_last && to == phi::DataLayout::kMKLDNN) {
    // Move the last dimension (channel) to 2nd place: N,...,C -> N,C,...
    auto shape = phi::vectorize<int>(tensor_in->dims());
    std::rotate(shape.begin() + 1, shape.end() - 1, shape.end());
    tensor_in->Resize(phi::make_ddim(shape));
    VLOG(3) << "Rotating Shape from: kNHWC/kNDHWC to: kMKLDNN output_shape"
            << print_dims(shape);
  }
}
// Placeholder primitive type: it only provides empty nested `desc` and
// `primitive_desc` types so it can stand in where a real primitive type with
// those members is expected.
struct mkldnn_dummy_primitive {
  struct desc {};
  struct primitive_desc {};
};
// Convenience wrapper that constructs a dnnl::memory::desc from the given
// dimensions, element data type and memory format tag.
inline dnnl::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                        dnnl::memory::data_type data_type,
                                        MKLDNNMemoryFormat format) {
  return dnnl::memory::desc({dims}, data_type, format);
}
inline void ClearMKLDNNCache(const platform::Place& place,
void* ptr = nullptr) {
// Clear mkl-dnn cache,
......@@ -161,33 +59,6 @@ inline void DontClearMKLDNNCache(const platform::Place& place) {
}
}
// Maps a C++ element type to the corresponding dnnl::memory::data_type.
// The primary template is the fallback for types without a specialization
// below and yields data_type::undef.
template <typename Type>
dnnl::memory::data_type MKLDNNGetDataType() {
  return dnnl::memory::data_type::undef;
}
// float -> f32
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<float>() {
  return dnnl::memory::data_type::f32;
}
// int32_t -> s32
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int32_t>() {
  return dnnl::memory::data_type::s32;
}
// int8_t -> s8
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<int8_t>() {
  return dnnl::memory::data_type::s8;
}
// uint8_t -> u8
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<uint8_t>() {
  return dnnl::memory::data_type::u8;
}
// paddle::platform::bfloat16 -> bf16
template <>
inline dnnl::memory::data_type MKLDNNGetDataType<paddle::platform::bfloat16>() {
  return dnnl::memory::data_type::bf16;
}
inline void Reorder(dnnl::memory src,
dnnl::memory dst,
const dnnl::engine& engine) {
......@@ -201,95 +72,6 @@ inline void Reorder(dnnl::memory src,
astream.wait();
}
// Returns the plain format tag whose letter count equals `tensor_rank`
// (a, ab, abc, ... up to abcdefghi for rank 9).
// Raises an Unimplemented error for ranks outside <1, 9>.
inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
  using Tag = dnnl::memory::format_tag;
  constexpr int kMaxRank = 9;
  if (tensor_rank < 1 || tensor_rank > kMaxRank) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Paddle support tensors with rank in range <1, 9>, but received "
        "tensor with rank: %d",
        tensor_rank));
  }
  // Indexed by rank - 1.
  static constexpr Tag kPlainTags[kMaxRank] = {Tag::a,
                                               Tag::ab,
                                               Tag::abc,
                                               Tag::abcd,
                                               Tag::abcde,
                                               Tag::abcdef,
                                               Tag::abcdefg,
                                               Tag::abcdefgh,
                                               Tag::abcdefghi};
  return kPlainTags[tensor_rank - 1];
}
// Adapts `data_format` to a tensor with `dims_size` dimensions.
//
//   1 dim  -> x
//   2 dims -> nc
//   3 dims -> nchw => ncw,   nhwc => nwc
//   4 dims -> goihw => oihw
//   5 dims -> goidhw => oidhw, nchw => ncdhw, nhwc => ndhwc
//   6 dims -> nchw => abcdef
//
// Any combination not listed above returns `data_format` unchanged.
inline MKLDNNMemoryFormat MKLDNNFormatForSize(size_t dims_size,
                                              MKLDNNMemoryFormat data_format) {
  using Format = MKLDNNMemoryFormat;
  switch (dims_size) {
    case 1:
      return Format::x;
    case 2:
      return Format::nc;
    case 3:
      if (data_format == Format::nchw) return Format::ncw;
      if (data_format == Format::nhwc) return Format::nwc;
      break;
    case 4:
      if (data_format == Format::goihw) return Format::oihw;
      break;
    case 5:
      if (data_format == Format::goidhw) return Format::oidhw;
      if (data_format == Format::nchw) return Format::ncdhw;
      if (data_format == Format::nhwc) return Format::ndhwc;
      break;
    case 6:
      if (data_format == Format::nchw) return Format::abcdef;
      break;
    default:
      break;
  }
  return data_format;
}
// Converts a textual data format into a memory format tag: the string is
// parsed with phi::StringToDataLayout, then kNHWC maps to nhwc, kNCHW maps to
// nchw, and every other layout maps to `any`.
inline MKLDNNMemoryFormat data_format_to_memory_format(
    const std::string& data_format) {
  const auto layout = phi::StringToDataLayout(data_format);
  if (layout == phi::DataLayout::kNHWC) {
    return MKLDNNMemoryFormat::nhwc;
  }
  if (layout == phi::DataLayout::kNCHW) {
    return MKLDNNMemoryFormat::nchw;
  }
  return MKLDNNMemoryFormat::any;
}
// Parses a format name, case-insensitively, into a memory format tag.
//
// Side effect: `*format` is lower-cased in place.
// Recognized names: "nchw", "nchw16c", "nchw8c", "nhwc"; anything else
// yields MKLDNNMemoryFormat::any.
inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
  // tolower() has undefined behaviour for negative arguments other than EOF,
  // so each char goes through unsigned char before the call.
  std::transform(format->begin(),
                 format->end(),
                 format->begin(),
                 [](unsigned char c) {
                   return static_cast<char>(::tolower(c));
                 });
  if (!format->compare("nchw")) {
    return MKLDNNMemoryFormat::nchw;
  } else if (!format->compare("nchw16c")) {
    return MKLDNNMemoryFormat::nChw16c;
  } else if (!format->compare("nchw8c")) {
    return MKLDNNMemoryFormat::nChw8c;
  } else if (!format->compare("nhwc")) {
    return MKLDNNMemoryFormat::nhwc;
  } else {
    return MKLDNNMemoryFormat::any;
  }
}
inline std::string ThreadIDasStr(void) {
return std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
......@@ -382,41 +164,6 @@ inline std::string ExtendKeyWithThreadInfoIfNeeded(
: key;
}
// Splits a flat padding list into {leading, trailing} padding vectors.
//
// With exactly 6 entries the input is read as
//   [front, back, top, bottom, left, right]
// and the result is {{front, top, left}, {back, bottom, right}}.
// Otherwise the first 4 entries are read as
//   [top, bottom, left, right]
// and the result is {{top, left}, {bottom, right}}.
//
// Values are kept as int64_t end to end, so large paddings are not
// truncated through an intermediate int.
// Precondition: `paddings` holds at least 4 elements.
inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    return {{paddings[0], paddings[2], paddings[4]},
            {paddings[1], paddings[3], paddings[5]}};
  }
  return {{paddings[0], paddings[2]}, {paddings[1], paddings[3]}};
}
// Adjusts the weight dimensions for group convolutions.
// For groups > 1 the group count becomes a new leading dimension and the
// former leading dimension is divided by it:
//   conv3d: [o, i, d, h, w] -> [g, o/g, i, d, h, w]
//   conv2d: [o, i, h, w]    -> [g, o/g, i, h, w]
// For groups <= 1 the vector is left unchanged.
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
                                  const int groups) {
  if (groups <= 1) {
    return;
  }
  weights_tz.insert(weights_tz.begin(), static_cast<int64_t>(groups));
  weights_tz[1] /= groups;
}
inline void RegisterModelLayout(
std::vector<std::unique_ptr<framework::OperatorBase>>& ops, // NOLINT
const platform::Place& place) {
......@@ -461,17 +208,8 @@ inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "bfloat16";
}
// True when the op's "mkldnn_data_type" attribute equals "float32".
inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
  const std::string data_type =
      op->GetAttrIfExists<std::string>("mkldnn_data_type");
  return data_type == "float32";
}
// Selects one of four reorder variants.
// NOTE(review): member names presumably read as <source>_<destination>
// layouts (NTC / TNC vs. PP) - confirm against the RNN kernels using it.
enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };
// Compile-time check: true iff T is exactly int8_t or uint8_t.
template <typename T>
constexpr bool is_int8() {
  using SignedMatch = std::is_same<T, int8_t>;
  using UnsignedMatch = std::is_same<T, uint8_t>;
  return SignedMatch::value || UnsignedMatch::value;
}
} // namespace platform
inline std::string FindInputNameByVarName(framework::OpDesc* op,
......
......@@ -30,32 +30,8 @@ limitations under the License. */
namespace paddle {
namespace platform {
// Callable taking a const float* and returning a shared_ptr<float>.
using user_function = std::function<std::shared_ptr<float>(const float*)>;
// Short name for dnnl::memory within this namespace.
using memory = dnnl::memory;

// The aliases below expose the phi::funcs OneDNN handler templates under
// MKLDNN-prefixed names in this namespace.

// Alias of phi::funcs::OneDNNHandlerT.
template <typename T,
          typename TForward,
          typename TBackward = mkldnn_dummy_primitive,
          typename TBackward_params = mkldnn_dummy_primitive>
using MKLDNNHandlerT =
    phi::funcs::OneDNNHandlerT<T, TForward, TBackward, TBackward_params>;
// Alias of phi::funcs::OneDNNHandlerNoCachingT.
template <typename T,
          typename TForward,
          typename TBackward = mkldnn_dummy_primitive,
          typename TBackward_params = mkldnn_dummy_primitive>
using MKLDNNHandlerNoCachingT = phi::funcs::
    OneDNNHandlerNoCachingT<T, TForward, TBackward, TBackward_params>;
// Alias of phi::funcs::ReductionOneDNNHandler.
template <typename T>
using ReductionMKLDNNHandler = phi::funcs::ReductionOneDNNHandler<T>;
// Alias of phi::funcs::BroadcastDataOneDNNHandler.
template <typename T>
using BroadcastDataMKLDNNHandler = phi::funcs::BroadcastDataOneDNNHandler<T>;
// Alias of phi::funcs::BinaryOneDNNHandler.
template <typename T>
using BinaryMKLDNNHandler = phi::funcs::BinaryOneDNNHandler<T>;
static void AppendActivation(const framework::ExecutionContext& ctx,
dnnl::post_ops& post_ops, // NOLINT
float activation_scale = 1.0f) {
......@@ -219,19 +195,9 @@ static void SetInMemDescWithLogicalLayoutFusesSupport(
}
}
// Compile-time check: true iff T is exactly int8_t or uint8_t.
template <typename T>
constexpr bool IsInt8() {
  return std::is_same<T, int8_t>::value ? true
                                        : std::is_same<T, uint8_t>::value;
}
template <typename T>
constexpr bool IsBfloat16() {
return std::is_same<T, paddle::platform::bfloat16>::value;
}
template <typename XT, typename YT, typename OT>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const framework::ExecutionContext& ctx,
const dnnl::engine engine,
......@@ -243,8 +209,8 @@ class MatMulV2MKLDNNHandler
bool is_output_fused,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override)
: paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) {
: phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
......@@ -305,13 +271,16 @@ class MatMulV2MKLDNNHandler
}
// TODO(jczaja): Why not for int8??
if (!IsInt8<OT>() && is_output_fused) {
if (!phi::funcs::is_int8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<OT>(), out_strides);
auto x_md =
memory::desc(x_dims, phi::funcs::OneDNNGetDataType<XT>(), x_strides);
auto y_md =
memory::desc(y_dims, phi::funcs::OneDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(
out_ddims, phi::funcs::OneDNNGetDataType<OT>(), out_strides);
const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
......@@ -347,7 +316,7 @@ class MatMulV2MKLDNNHandler
auto* residual_data = ctx.Input<phi::DenseTensor>("ResidualData");
auto residual_data_tz = phi::vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
MKLDNNGetDataType<OT>(),
phi::funcs::OneDNNGetDataType<OT>(),
dnnl::memory::format_tag::any);
post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md);
......@@ -389,8 +358,9 @@ class MatMulV2MKLDNNHandler
std::shared_ptr<memory> AcquireWeightsMemory(const phi::DenseTensor* input) {
const YT* input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<YT>(input_data));
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->weights_desc(),
phi::funcs::to_void_cast<YT>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor* output) {
......@@ -406,145 +376,5 @@ class MatMulV2MKLDNNHandler
}
};
// Returns, for the given activation type, a map from an attribute name to
// the fused attribute name ("fuse_alpha" / "fuse_beta") it corresponds to.
//
//   swish        : beta -> fuse_alpha
//   relu6        : threshold -> fuse_alpha
//   hard_sigmoid : slope -> fuse_alpha, offset -> fuse_beta
//   clip         : min -> fuse_alpha,   max -> fuse_beta
//   anything else: alpha -> fuse_alpha, beta -> fuse_beta
static std::unordered_map<std::string, std::string> GetAttributeMap(
    std::string act_type) {
  // Attribute names feeding fuse_alpha / fuse_beta; nullptr means "no entry".
  const char* alpha_source = "alpha";
  const char* beta_source = "beta";
  if (act_type == "swish") {
    alpha_source = "beta";
    beta_source = nullptr;
  } else if (act_type == "relu6") {
    alpha_source = "threshold";
    beta_source = nullptr;
  } else if (act_type == "hard_sigmoid") {
    alpha_source = "slope";
    beta_source = "offset";
  } else if (act_type == "clip") {
    alpha_source = "min";
    beta_source = "max";
  }

  std::unordered_map<std::string, std::string> attr_map;
  attr_map.emplace(alpha_source, "fuse_alpha");
  if (beta_source != nullptr) {
    attr_map.emplace(beta_source, "fuse_beta");
  }
  return attr_map;
}
// Returns the fixed list of supported activation names, in alphabetical
// order as written here.
static std::vector<std::string> GetSupportedActivations() {
  return {"abs",
          "clip",
          "gelu",
          "hard_sigmoid",
          "hard_swish",
          "leaky_relu",
          "mish",
          "relu",
          "relu6",
          "sigmoid",
          "sqrt",
          "swish",
          "tanh"};
}
// Helper that creates the dnnl::memory objects and dnnl::reorder primitives
// needed to reorder a tensor, optionally changing its data type.
//
// The handler stores a copy of the tensor dims, the source and destination
// variable types (vtype_ / vtype_dst_), the source and destination dnnl data
// types (dtype_ / dtype_dst_) and the engine on which memory is created.
class ReorderMKLDNNHandler {
 public:
  // Same-type reorder: the destination types are initialized from the source
  // types.
  ReorderMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
                       framework::proto::VarType::Type vtype,
                       dnnl::memory::data_type dtype,
                       dnnl::engine engine)
      : dims_(dims),
        vtype_(vtype),
        vtype_dst_(vtype),
        dtype_(dtype),
        dtype_dst_(dtype),
        engine_(engine) {}

  // Reorder with explicit, possibly different, destination types.
  ReorderMKLDNNHandler(std::vector<int64_t>& dims,  // NOLINT
                       framework::proto::VarType::Type vtype,
                       dnnl::memory::data_type dtype,
                       framework::proto::VarType::Type vtype_dst,
                       dnnl::memory::data_type dtype_dst,
                       dnnl::engine engine)
      : dims_(dims),
        vtype_(vtype),
        vtype_dst_(vtype_dst),
        dtype_(dtype),
        dtype_dst_(dtype_dst),
        engine_(engine) {}

  // Wraps the existing buffer `ptr` in a memory object described by `md`.
  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const dnnl::memory::desc& md,
                                                 void* ptr) {
    return std::make_shared<dnnl::memory>(md, engine_, ptr);
  }

  // Wraps the existing buffer `ptr` in a memory object whose descriptor is
  // built from the stored dims, the source data type and `fmt`.
  std::shared_ptr<dnnl::memory> AcquireSrcMemory(const MKLDNNMemoryFormat& fmt,
                                                 void* ptr) {
    auto md = dnnl::memory::desc(dims_, dtype_, fmt);
    return std::make_shared<dnnl::memory>(md, engine_, ptr);
  }

  // Creates a memory object for the sub-region (`dims` at `offset`) of
  // `mem_p`'s descriptor. It is constructed with `mem_p`'s data handle.
  std::shared_ptr<dnnl::memory> AcquireSubmemory(
      const std::vector<int64_t>& dims,
      const std::vector<int64_t>& offset,
      const std::shared_ptr<dnnl::memory>& mem_p) {
    auto sub_md = mem_p->get_desc().submemory_desc(dims, {offset});
    auto sub_mem_p = std::make_shared<dnnl::memory>(
        sub_md, engine_, mem_p->get_data_handle());
    return sub_mem_p;
  }

  // Destination memory with the stored dims, the destination data type and
  // format `fmt`. The buffer comes from `output->mutable_data`, sized by the
  // descriptor.
  std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor* output,
                                                 const MKLDNNMemoryFormat& fmt,
                                                 platform::Place place) {
    auto dst_md = platform::MKLDNNMemDesc(dims_, dtype_dst_, fmt);
    auto dst_data = output->mutable_data(
        place, framework::TransToPhiDataType(vtype_dst_), dst_md.get_size());
    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
  }

  // Destination memory described like `src_md`. When the destination variable
  // type differs from the source one, a copy of `src_md` with its data type
  // overridden to the destination data type is used instead.
  std::shared_ptr<dnnl::memory> AcquireDstMemory(
      phi::DenseTensor* output,
      const dnnl::memory::desc& src_md,
      platform::Place place) {
    if (vtype_dst_ == vtype_) {
      auto dst_data = output->mutable_data(
          place, framework::TransToPhiDataType(vtype_dst_), src_md.get_size());
      return std::make_shared<dnnl::memory>(src_md, engine_, dst_data);
    } else {
      auto dst_md = src_md;
      // Patch the data type field of the copied descriptor directly.
      dst_md.data.data_type = static_cast<dnnl_data_type_t>(dtype_dst_);
      auto dst_data = output->mutable_data(
          place, framework::TransToPhiDataType(vtype_dst_), dst_md.get_size());
      return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
    }
  }

  // Destination memory with caller-supplied `dims` (instead of the stored
  // ones), the destination data type and format `fmt`.
  std::shared_ptr<dnnl::memory> AcquireDstMemory(
      phi::DenseTensor* output,
      const std::vector<int64_t>& dims,
      const MKLDNNMemoryFormat& fmt,
      platform::Place place) {
    auto dst_md = platform::MKLDNNMemDesc(dims, dtype_dst_, fmt);
    auto dst_data = output->mutable_data(
        place, framework::TransToPhiDataType(vtype_dst_), dst_md.get_size());
    return std::make_shared<dnnl::memory>(dst_md, engine_, dst_data);
  }

  // Creates a reorder primitive from source to destination memory.
  // Note the parameter order: destination first, source second.
  std::shared_ptr<dnnl::reorder> AcquireReorder(
      std::shared_ptr<dnnl::memory> dst_memory_p,
      std::shared_ptr<dnnl::memory> src_memory_p) {
    return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
  }

  // Same as above, with additional primitive attributes `attrs`.
  std::shared_ptr<dnnl::reorder> AcquireReorder(
      std::shared_ptr<dnnl::memory> dst_memory_p,
      std::shared_ptr<dnnl::memory> src_memory_p,
      const dnnl::primitive_attr& attrs) {
    return std::make_shared<dnnl::reorder>(
        *(src_memory_p), *(dst_memory_p), attrs);
  }

 private:
  // Copy of the tensor dimensions passed to the constructor.
  std::vector<int64_t> dims_;
  // Source / destination variable types.
  framework::proto::VarType::Type vtype_, vtype_dst_;
  // Source / destination dnnl data types.
  dnnl::memory::data_type dtype_, dtype_dst_;
  // Engine on which all memory objects are created.
  dnnl::engine engine_;
};
} // namespace platform
} // namespace paddle
......@@ -195,28 +195,6 @@ inline std::string CreateKey(const OneDNNContext& dev_ctx, ArgTypes&&... args) {
return key;
}
// Splits a flat padding list into {leading, trailing} padding vectors.
//
// With exactly 6 entries the input is read as
//   [front, back, top, bottom, left, right]
// and the result is {{front, top, left}, {back, bottom, right}}.
// Otherwise the first 4 entries are read as
//   [top, bottom, left, right]
// and the result is {{top, left}, {bottom, right}}.
//
// Values are kept as int64_t end to end, so large paddings are not
// truncated through an intermediate int.
// Precondition: `paddings` holds at least 4 elements.
inline std::vector<std::vector<int64_t>> ToOnednnPadding(
    const std::vector<int64_t>& paddings) {
  if (paddings.size() == 6) {
    return {{paddings[0], paddings[2], paddings[4]},
            {paddings[1], paddings[3], paddings[5]}};
  }
  return {{paddings[0], paddings[2]}, {paddings[1], paddings[3]}};
}
// The function adjusts the vector of weight dimensions for group convolutions
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz, // NOLINT
const int groups) {
......@@ -306,10 +284,5 @@ inline std::string ExtendKeyWithThreadInfoIfNeeded(const OneDNNContext& dev_ctx,
: key;
}
// Compile-time check: true iff T is exactly int8_t or uint8_t.
template <typename T>
constexpr bool is_int8() {
  using IsSigned8 = std::is_same<T, int8_t>;
  using IsUnsigned8 = std::is_same<T, uint8_t>;
  return IsSigned8::value || IsUnsigned8::value;
}
} // namespace funcs
} // namespace phi
......@@ -35,11 +35,20 @@ limitations under the License. */
namespace phi {
namespace funcs {
// Callable taking a const float* and returning a shared_ptr<float>.
using user_function = std::function<std::shared_ptr<float>(const float*)>;
// Short name for dnnl::memory within this namespace.
using memory = dnnl::memory;
// oneDNN memory format tag.
using OneDNNMemoryFormat = dnnl::memory::format_tag;
// Compile-time check: true iff T is exactly int8_t or uint8_t.
template <typename T>
constexpr bool is_int8() {
  return std::is_same<T, uint8_t>::value ? true
                                         : std::is_same<T, int8_t>::value;
}
template <typename T>
constexpr bool is_bfloat16() {
return std::is_same<T, phi::dtype::bfloat16>::value;
}
static void AppendActivation(const OneDNNContext& dev_ctx,
dnnl::post_ops& post_ops, // NOLINT
float activation_scale = 1.0f) {
......@@ -101,6 +110,42 @@ static void AppendActivation(const OneDNNContext& dev_ctx,
}
}
// Returns, for the given activation type, a map from an attribute name to
// the fused attribute name ("fuse_alpha" / "fuse_beta") it corresponds to.
//
//   swish        : beta -> fuse_alpha
//   relu6        : threshold -> fuse_alpha
//   hard_sigmoid : slope -> fuse_alpha, offset -> fuse_beta
//   clip         : min -> fuse_alpha,   max -> fuse_beta
//   anything else: alpha -> fuse_alpha, beta -> fuse_beta
static std::unordered_map<std::string, std::string> GetAttributeMap(
    std::string act_type) {
  // Attribute names feeding fuse_alpha / fuse_beta; nullptr means "no entry".
  const char* alpha_source = "alpha";
  const char* beta_source = "beta";
  if (act_type == "swish") {
    alpha_source = "beta";
    beta_source = nullptr;
  } else if (act_type == "relu6") {
    alpha_source = "threshold";
    beta_source = nullptr;
  } else if (act_type == "hard_sigmoid") {
    alpha_source = "slope";
    beta_source = "offset";
  } else if (act_type == "clip") {
    alpha_source = "min";
    beta_source = "max";
  }

  std::unordered_map<std::string, std::string> attr_map;
  attr_map.emplace(alpha_source, "fuse_alpha");
  if (beta_source != nullptr) {
    attr_map.emplace(beta_source, "fuse_beta");
  }
  return attr_map;
}
// Returns the fixed list of supported activation names, in alphabetical
// order as written here.
static std::vector<std::string> GetSupportedActivations() {
  return {"abs",
          "clip",
          "gelu",
          "hard_sigmoid",
          "hard_swish",
          "leaky_relu",
          "mish",
          "relu",
          "relu6",
          "sigmoid",
          "sqrt",
          "swish",
          "tanh"};
}
template <typename T,
typename TForward,
typename TBackward = onednn_dummy_primitive,
......
......@@ -23,14 +23,13 @@ limitations under the License. */
public:
/* @jim19930609: Remove dependency on protobuf after Tensor Unification.
*/
*/
explicit DenseTensor(paddle::experimental::DataType dtype);
inline bool IsInitialized() const { return holder_ != nullptr; }
template <typename T>
T* mutable_data(const phi::Place& place,
size_t requested_size = 0);
T* mutable_data(const phi::Place& place, size_t requested_size = 0);
template <typename T>
T* mutable_data(const DDim& dims,
......@@ -41,15 +40,14 @@ void* mutable_data(const phi::Place& place,
paddle::experimental::DataType type,
size_t requested_size = 0);
void* mutable_data(const phi::Place& place,
size_t requested_size = 0);
void* mutable_data(const phi::Place& place, size_t requested_size = 0);
void* mutable_data(const phi::Place& place,
paddle::experimental::DataType type,
const phi::Stream& stream);
/* @jim19930609: Remove dependency on protobuf after Tensor Unification.
*/
*/
paddle::experimental::DataType type() const;
// memory size returns the holding memory size in byte.
......@@ -86,13 +84,11 @@ std::shared_ptr<phi::Allocation> MoveMemoryHolder() {
void ResetHolder(const std::shared_ptr<phi::Allocation>& holder);
void ResetHolderWithType(const std::shared_ptr<phi::Allocation>& holder,
paddle::experimental::DataType type);
paddle::experimental::DataType type);
void set_type(paddle::experimental::DataType type);
InplaceVersion& InplaceVersionCounter() {
return *inplace_version_counter_;
}
InplaceVersion& InplaceVersionCounter() { return *inplace_version_counter_; }
/*! The internal of two tensors share the same memory block. */
DenseTensor& ShareDataWith(const DenseTensor& src);
......@@ -116,11 +112,11 @@ following codes there.
#ifdef PADDLE_WITH_MKLDNN
public:
const dnnl::memory::desc& mem_desc() const;
const dnnl::memory::desc& mem_desc() const;
inline void set_mem_desc(const dnnl::memory::desc& mem_desc) {
mem_desc_ = mem_desc;
meta_.layout = DataLayout::kMKLDNN;
meta_.layout = DataLayout::ONEDNN;
}
#endif
......@@ -141,8 +137,8 @@ void set_lod(const LoD& lod);
LoD* mutable_lod();
/*
* Get the start offset and end offset of an element from LoD.
*/
* Get the start offset and end offset of an element from LoD.
*/
std::pair<size_t, size_t> lod_element(size_t level, size_t elem) const;
size_t NumLevels() const;
......
......@@ -40,14 +40,8 @@ else()
endif()
if(WITH_MKLDNN)
math_library(
selected_rows_functor
DEPS
selected_rows_utils
math_function
blas
mkldnn_axpy_handler
mixed_vector)
math_library(selected_rows_functor DEPS selected_rows_utils math_function
blas mixed_vector)
else()
math_library(selected_rows_functor DEPS selected_rows_utils math_function
blas mixed_vector)
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/fluid/platform/device/device_wrapper.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/operators/mkldnn/axpy_handler.h"
#include "paddle/phi/backends/onednn/axpy_handler.h"
#endif
namespace phi {
......@@ -371,7 +371,9 @@ add_sparse_inputs(const std::vector<const phi::SelectedRows*>& inputs,
auto& input_rows = input->rows();
#ifdef PADDLE_WITH_MKLDNN
paddle::operators::OneDNNAXPYHandler<T> axpy_handler(input_width, T(1.f));
OneDNNContext onednn_context(context.GetPlace());
funcs::OneDNNAXPYHandler<T> axpy_handler(
input_width, T(1.f), onednn_context.GetEngine());
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = rows_to_id.at(input_rows[i]);
axpy_handler(&input_data[i * input_width],
......@@ -869,11 +871,11 @@ struct UpdateToTensor<phi::CPUContext, T> {
PADDLE_ENFORCE_EQ(
in1_row_numel,
input2->numel() / in1_height,
phi::errors::InvalidArgument(
"The two inputs width must be equal."
"But received first input width = [%d], second input width = [%d]",
in1_row_numel,
input2->numel() / in1_height));
phi::errors::InvalidArgument("The two inputs width must be equal."
"But received first input width = [%d], "
"second input width = [%d]",
in1_row_numel,
input2->numel() / in1_height));
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
......
......@@ -154,7 +154,7 @@ class ConvOneDNNHandlerT
const auto dst_tz = phi::vectorize(output->dims());
const dnnl::memory::dims stride_dims = strides;
const auto onednn_paddings = funcs::ToOnednnPadding(paddings);
const auto onednn_paddings = funcs::ToOneDNNPadding(paddings);
const dnnl::memory::dims dilations_dims = dilations;
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
......@@ -326,7 +326,7 @@ class ConvOneDNNHandlerT
auto diff_dst_md = funcs::OneDNNMemDesc(
dst_tz, funcs::OneDNNGetDataType<T>(), chosen_memory_format);
auto onednn_paddings = funcs::ToOnednnPadding(paddings);
auto onednn_paddings = funcs::ToOneDNNPadding(paddings);
std::transform(
dilations.begin(), dilations.end(), dilations.begin(), [](int64_t i) {
return i - 1;
......
......@@ -291,8 +291,7 @@ void ConvKernel(const Context& dev_ctx,
dev_ctx.GetPlace().GetType(),
AllocationType::CPU,
phi::errors::PreconditionNotMet("Operator DNNL Conv must use CPUPlace"));
bool is_INT8 =
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
bool is_INT8 = funcs::is_int8<T>();
bool is_test = dev_ctx.HasDnnAttr("is_test")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("is_test"))
......
......@@ -138,7 +138,7 @@ void TransferLayoutMKLDNN(const Context& dev_ctx,
src_layout,
dst_layout,
errors::PreconditionNotMet(
"No layout transform needed between two MKLDNN OPKernels."));
"No layout transform needed between two oneDNN OPKernels."));
} else {
TransferLayoutGeneral<Context>(dev_ctx, x, dst_layout, out);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册