From 98aaf7974c59b20a9f22a7d59fd590b71580f3db Mon Sep 17 00:00:00 2001 From: jakpiase Date: Mon, 28 Nov 2022 15:40:28 +0100 Subject: [PATCH] Reenabled reshape, squeeze and flatten oneDNN kernels (#48359) * re-enabled reshape, squeeze and flatten kernels * added formatting --- paddle/fluid/operators/flatten_op.cc | 40 +++++++++++++ .../operators/mkldnn/reshape_mkldnn_op.cc | 31 ++-------- paddle/fluid/operators/reshape_op.cc | 18 ++++++ paddle/fluid/operators/squeeze_op.cc | 60 ++++++++++--------- 4 files changed, 94 insertions(+), 55 deletions(-) diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index 036f3b8222..65d3f809fa 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -87,6 +87,16 @@ class FlattenOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -159,6 +169,16 @@ class FlattenGradOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -223,6 +243,16 @@ class Flatten2Op : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -275,6 +305,16 @@ class Flatten2GradOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc index f1b321c5dd..902cd8509b 100644 --- a/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc +++ b/paddle/fluid/operators/mkldnn/reshape_mkldnn_op.cc @@ -80,7 +80,7 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { out->Resize(x_dims); // to match x numel, format is changed later // reorder is done into a plain tag to allow usage with blocked formats auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - out, getPlainFormatTag(x), ctx.GetPlace()); + out, phi::funcs::GetPlainOneDNNFormat(x_dims.size()), ctx.GetPlace()); auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); @@ -194,31 +194,6 @@ class ReshapeMKLDNNKernel : public framework::OpKernel { } protected: - static dnnl::memory::format_tag getPlainFormatTag( - const phi::DenseTensor* tensor) { - auto tensor_dims_size = tensor->dims().size(); - PADDLE_ENFORCE_EQ( - tensor_dims_size <= 6 && tensor_dims_size >= 1, - true, - platform::errors::InvalidArgument( - "Dims for squeeze_grad oneDNN op must be in range <1, 6>")); - - switch (tensor_dims_size) { - case 1: - return dnnl::memory::format_tag::a; - case 2: - return dnnl::memory::format_tag::ab; - case 3: - return dnnl::memory::format_tag::abc; - case 4: - return dnnl::memory::format_tag::abcd; - case 5: - return dnnl::memory::format_tag::abcde; - default: - return dnnl::memory::format_tag::abcdef; - } - } - static framework::DDim ValidateShape(const std::vector& shape, const framework::DDim& in_dims) { const int64_t in_size = phi::product(in_dims); @@ -348,7 +323,9 @@ class ReshapeGradMKLDNNKernel : public ReshapeMKLDNNKernel { auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( dout->mem_desc(), phi::funcs::to_void_cast(dout->data())); auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( - dx, this->getPlainFormatTag(dout), ctx.GetPlace()); + dx, + phi::funcs::GetPlainOneDNNFormat(dout_vec_dims.size()), + ctx.GetPlace()); auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p); diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index e143d3e144..161f230bac 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -258,6 +258,15 @@ class ReshapeOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } @@ -615,6 +624,15 @@ class Reshape2GradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 93a03c535f..1afc7ac8ec 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -125,13 +125,14 @@ class SqueezeOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - // #ifdef PADDLE_WITH_MKLDNN - // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - // return framework::OpKernelType(input_data_type, ctx.GetPlace(), - // phi::DataLayout::ONEDNN, - // framework::LibraryType::kMKLDNN); - // } - // #endif +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -152,13 +153,14 @@ class SqueezeGradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - // #ifdef PADDLE_WITH_MKLDNN - // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - // return framework::OpKernelType(input_data_type, ctx.GetPlace(), - // phi::DataLayout::ONEDNN, - // framework::LibraryType::kMKLDNN); - // } - // #endif +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -219,13 +221,14 @@ class Squeeze2Op : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - // #ifdef PADDLE_WITH_MKLDNN - // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - // return framework::OpKernelType(input_data_type, ctx.GetPlace(), - // phi::DataLayout::ONEDNN, - // framework::LibraryType::kMKLDNN); - // } - // #endif +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -267,13 +270,14 @@ class Squeeze2GradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - // #ifdef PADDLE_WITH_MKLDNN - // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - // return framework::OpKernelType(input_data_type, ctx.GetPlace(), - // phi::DataLayout::ONEDNN, - // framework::LibraryType::kMKLDNN); - // } - // #endif +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + phi::DataLayout::ONEDNN, + framework::LibraryType::kMKLDNN); + } +#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; -- GitLab