From 4b8d4ade5e06e7f989d15785c4a78792ce774a9b Mon Sep 17 00:00:00 2001 From: jiahongyu Date: Tue, 20 Sep 2022 10:58:19 +0000 Subject: [PATCH] refine mkldnn code --- .../fluid/operators/detection/prior_box_op.cc | 14 +--- paddle/fluid/operators/fused/fusion_gru_op.cc | 10 +-- .../fluid/operators/fused/fusion_lstm_op.cc | 10 +-- paddle/fluid/operators/fused/multi_gru_op.cc | 7 +- paddle/fluid/operators/matmul_op.cc | 1 - paddle/fluid/operators/matrix_rank_op.cc | 4 - paddle/fluid/operators/mul_op.cc | 48 +++++------ paddle/fluid/operators/pool_op.cc | 6 +- paddle/fluid/operators/prelu_op.cc | 48 +++++------ paddle/fluid/operators/qr_op.cc | 5 +- paddle/fluid/operators/quantize_op.cc | 7 +- paddle/fluid/operators/requantize_op.cc | 7 +- paddle/fluid/operators/svd_op.cc | 3 - paddle/fluid/operators/transpose_op.cc | 79 +++++++++---------- 14 files changed, 106 insertions(+), 143 deletions(-) diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index 03733e34ec6..37111738e06 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -35,13 +35,8 @@ class PriorBoxOp : public framework::OperatorWithKernel { auto input_input_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); - framework::LibraryType library_{framework::LibraryType::kPlain}; - framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_input_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, input_input_type)) { auto input_image_type = framework::TransToProtoVarType( ctx.Input("Image")->dtype()); int customized_type_value = @@ -54,13 +49,12 @@ class PriorBoxOp : public framework::OperatorWithKernel { } return framework::OpKernelType(input_input_type, ctx.GetPlace(), - layout_, - library_, + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, customized_type_value); } #endif - return framework::OpKernelType( - input_input_type, ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(input_input_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index e2d2cf071ca..ced0355b1a6 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -152,16 +152,16 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } void FusionGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc index 5454c90b3c5..e45fd196437 100644 --- a/paddle/fluid/operators/fused/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc @@ -175,16 +175,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } void FusionLSTMOpMaker::Make() { diff --git a/paddle/fluid/operators/fused/multi_gru_op.cc b/paddle/fluid/operators/fused/multi_gru_op.cc index 2a8917f1c00..2e40afa9c3a 100644 --- a/paddle/fluid/operators/fused/multi_gru_op.cc +++ b/paddle/fluid/operators/fused/multi_gru_op.cc @@ -143,14 +143,11 @@ void MultiGRUOp::InferShape(framework::InferShapeContext* ctx) const { framework::OpKernelType MultiGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kMKLDNN; - framework::DataLayout layout = framework::DataLayout::kMKLDNN; - return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout, - library); + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } void MultiGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index a49ceb42559..ae23ab91dcd 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -700,7 +700,6 @@ class MatMulOp : public framework::OperatorWithKernel { OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); #ifdef PADDLE_WITH_MKLDNN - using dnnl::memory; if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc index 4c9a7f56366..15904ae63b2 100644 --- a/paddle/fluid/operators/matrix_rank_op.cc +++ b/paddle/fluid/operators/matrix_rank_op.cc @@ -19,10 +19,6 @@ #include "paddle/fluid/operators/svd_helper.h" #include "paddle/phi/kernels/funcs/compare_functors.h" -#ifdef PADDLE_WITH_MKLDNN -#include "paddle/fluid/platform/mkldnn_helper.h" -#endif - namespace paddle { namespace operators { using DDim = framework::DDim; diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 2d4ca62955e..9e28d1d57be 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -41,17 +41,12 @@ class MulOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; if (input_data_type == framework::DataTypeTrait::DataType() || input_data_type == framework::DataTypeTrait::DataType()) { customized_type_value = kMULMKLDNNINT8; @@ -62,14 +57,15 @@ class MulOp : public framework::OperatorWithKernel { framework::DataTypeTrait::DataType()) { customized_type_value = kMULMKLDNNFP32; } + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, + customized_type_value); } #endif - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout, - library, - customized_type_value); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -140,17 +136,12 @@ class MulGradOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); -#ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; if (input_data_type == framework::DataTypeTrait::DataType() || input_data_type == framework::DataTypeTrait::DataType()) { customized_type_value = kMULMKLDNNINT8; @@ -161,14 +152,15 @@ class MulGradOp : public framework::OperatorWithKernel { framework::DataTypeTrait::DataType()) { customized_type_value = kMULMKLDNNFP32; } + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, + customized_type_value); } #endif - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout, - library, - customized_type_value); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index e8b35b89157..2e2fd00c647 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -42,8 +42,7 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) { framework::OpKernelType PoolOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = "AnyLayout"; - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -88,8 +87,7 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar( framework::OpKernelType PoolOpGrad::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = "AnyLayout"; - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index f7abaf648eb..4c9a7a388ff 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -23,26 +23,6 @@ namespace operators { using Tensor = framework::Tensor; -framework::OpKernelType innerGetKernelTypeForVar( - const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) { -#ifdef PADDLE_WITH_MKLDNN - auto isOneDNNKernelChosen = - (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN); - auto isNotOneDNNTensor = (tensor.layout() != framework::DataLayout::kMKLDNN); - auto isModelNHWC = - (paddle::platform::MKLDNNDeviceContext::tls() - .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC); - // All inputs (including alpha) need shape rotating - if (isOneDNNKernelChosen && isNotOneDNNTensor && isModelNHWC) { - return framework::OpKernelType(expected_kernel_type.data_type_, - tensor.place(), - framework::DataLayout::kNHWC); - } -#endif - return framework::OpKernelType( - expected_kernel_type.data_type_, tensor.place(), tensor.layout()); -} - class PReluOp : public framework::OperatorWithKernel { public: PReluOp(const std::string &type, @@ -72,7 +52,19 @@ class PReluOp : public framework::OperatorWithKernel { const std::string &var_name, const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { - return innerGetKernelTypeForVar(tensor, expected_kernel_type); +#ifdef PADDLE_WITH_MKLDNN + // All inputs (including alpha) need shape rotating + if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) && + (tensor.layout() != framework::DataLayout::kMKLDNN) && + paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), + framework::DataLayout::kNHWC); + } +#endif + return framework::OpKernelType( + expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } }; @@ -151,7 +143,19 @@ class PReluGradOp : public framework::OperatorWithKernel { const std::string &var_name, const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) const override { - return innerGetKernelTypeForVar(tensor, expected_kernel_type); +#ifdef PADDLE_WITH_MKLDNN + // All inputs (including alpha) need shape rotating + if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) && + (tensor.layout() != framework::DataLayout::kMKLDNN) && + paddle::platform::MKLDNNDeviceContext::tls() + .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) { + return framework::OpKernelType(expected_kernel_type.data_type_, + tensor.place(), + framework::DataLayout::kNHWC); + } +#endif + return framework::OpKernelType( + expected_kernel_type.data_type_, tensor.place(), tensor.layout()); } }; diff --git a/paddle/fluid/operators/qr_op.cc b/paddle/fluid/operators/qr_op.cc index e939ec7be2e..3eac56d1604 100644 --- a/paddle/fluid/operators/qr_op.cc +++ b/paddle/fluid/operators/qr_op.cc @@ -17,12 +17,9 @@ #include #include -#include "paddle/phi/core/ddim.h" -#ifdef PADDLE_WITH_MKLDNN -#include "paddle/fluid/platform/mkldnn_helper.h" -#endif #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/unary.h" diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index 45b4e9bac7f..9d956cde9fd 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -24,14 +24,11 @@ namespace operators { framework::OpKernelType QuantOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library_ = framework::LibraryType::kMKLDNN; - framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), - layout_, - library_); + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } void QuantOpMaker::Make() { diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index 5c48a22666e..dfb32a8e8f9 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -24,14 +24,11 @@ namespace operators { framework::OpKernelType ReQuantOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library_ = framework::LibraryType::kMKLDNN; - framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - return framework::OpKernelType( OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), - layout_, - library_); + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } void ReQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/svd_op.cc b/paddle/fluid/operators/svd_op.cc index 7f9fccddf72..afbfd80b8d5 100644 --- a/paddle/fluid/operators/svd_op.cc +++ b/paddle/fluid/operators/svd_op.cc @@ -21,9 +21,6 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/phi/core/ddim.h" #include "paddle/phi/infermeta/unary.h" -#ifdef PADDLE_WITH_MKLDNN -#include "paddle/fluid/platform/mkldnn_helper.h" -#endif namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index b342f01e46f..8ac9704df3a 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -99,19 +99,18 @@ class TransposeOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - auto &data_format = ctx.Attr("data_format"); - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - data_type, ctx.GetPlace(), layout_, library_); + auto &data_format = ctx.Attr("data_format"); + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_); } }; @@ -203,20 +202,19 @@ class TransposeOpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); auto data_type = OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - data_type, ctx.GetPlace(), layout_, library_); + std::string data_format = ctx.Attr("data_format"); + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_); } }; @@ -249,29 +247,27 @@ class Transpose2Op : public TransposeOp { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::proto::VarType::Type data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { using framework::proto::VarType; auto input_data_type = framework::TransToProtoVarType(ctx.Input("X")->dtype()); - customized_type_value = (input_data_type == VarType::INT8 || - input_data_type == VarType::UINT8) - ? kTransposeMKLDNNINT8 - : kTransposeMKLDNNFP32; + int customized_type_value = (input_data_type == VarType::INT8 || + input_data_type == VarType::UINT8) + ? kTransposeMKLDNNINT8 + : kTransposeMKLDNNFP32; + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, + customized_type_value); } #endif - return framework::OpKernelType( - data_type, ctx.GetPlace(), layout_, library_, customized_type_value); + std::string data_format = ctx.Attr("data_format"); + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_); } }; @@ -371,21 +367,20 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - std::string data_format = ctx.Attr("data_format"); - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); framework::proto::VarType::Type data_type = OperatorWithKernel::IndicateVarDataType(ctx, framework::GradVarName("Out")); #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - data_type, ctx.GetPlace(), layout_, library_); + std::string data_format = ctx.Attr("data_format"); + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_); } }; -- GitLab