diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc
index 03733e34ec670b3467c535eb887ea5995a630122..37111738e0633b912ce99a62bccdf1b25894235b 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cc
+++ b/paddle/fluid/operators/detection/prior_box_op.cc
@@ -35,13 +35,8 @@ class PriorBoxOp : public framework::OperatorWithKernel {
     auto input_input_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, input_input_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, input_input_type)) {
       auto input_image_type = framework::TransToProtoVarType(
           ctx.Input<framework::Tensor>("Image")->dtype());
       int customized_type_value =
@@ -54,13 +49,12 @@ class PriorBoxOp : public framework::OperatorWithKernel {
       }
       return framework::OpKernelType(input_input_type,
                                      ctx.GetPlace(),
-                                     layout_,
-                                     library_,
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
                                      customized_type_value);
     }
 #endif
-    return framework::OpKernelType(
-        input_input_type, ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(input_input_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index e2d2cf071caba8f7e9691a6bfdacaedeff8aee78..ced0355b1a68efaa6b3d8c2ee136ea952a574bff 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -152,16 +152,16 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
   if (this->CanMKLDNNBeUsed(ctx, data_type)) {
-    library = framework::LibraryType::kMKLDNN;
-    layout = framework::DataLayout::kMKLDNN;
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   framework::DataLayout::kMKLDNN,
+                                   framework::LibraryType::kMKLDNN);
   }
 #endif
-  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+  return framework::OpKernelType(data_type, ctx.GetPlace());
 }
 
 void FusionGRUOpMaker::Make() {
diff --git a/paddle/fluid/operators/fused/fusion_lstm_op.cc b/paddle/fluid/operators/fused/fusion_lstm_op.cc
index 5454c90b3c59618fc8c028265adfec1ce81d4b1f..e45fd196437d7f8cd03992e4bc4878154db0d09a 100644
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -175,16 +175,16 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library = framework::LibraryType::kPlain;
-  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
   if (this->CanMKLDNNBeUsed(ctx, data_type)) {
-    library = framework::LibraryType::kMKLDNN;
-    layout = framework::DataLayout::kMKLDNN;
+    return framework::OpKernelType(data_type,
+                                   ctx.GetPlace(),
+                                   framework::DataLayout::kMKLDNN,
+                                   framework::LibraryType::kMKLDNN);
   }
 #endif
-  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+  return framework::OpKernelType(data_type, ctx.GetPlace());
 }
 
 void FusionLSTMOpMaker::Make() {
diff --git a/paddle/fluid/operators/fused/multi_gru_op.cc b/paddle/fluid/operators/fused/multi_gru_op.cc
index 2a8917f1c005d2979453e0422608049718300443..2e40afa9c3a5730ad34c6d8a25cfa9c0ab2a9bc7 100644
--- a/paddle/fluid/operators/fused/multi_gru_op.cc
+++ b/paddle/fluid/operators/fused/multi_gru_op.cc
@@ -143,14 +143,11 @@ void MultiGRUOp::InferShape(framework::InferShapeContext* ctx) const {
 
 framework::OpKernelType MultiGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library = framework::LibraryType::kMKLDNN;
-  framework::DataLayout layout = framework::DataLayout::kMKLDNN;
-
   return framework::OpKernelType(
       OperatorWithKernel::IndicateVarDataType(ctx, "X"),
       ctx.GetPlace(),
-      layout,
-      library);
+      framework::DataLayout::kMKLDNN,
+      framework::LibraryType::kMKLDNN);
 }
 
 void MultiGRUOpMaker::Make() {
diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index a49ceb42559c5cf86ae467787525e637300b52a4..ae23ab91dcd99a09480d2d9451ff31f5ac1fd12d 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -700,7 +700,6 @@ class MatMulOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");
 
 #ifdef PADDLE_WITH_MKLDNN
-    using dnnl::memory;
     if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type,
                                      ctx.GetPlace(),
diff --git a/paddle/fluid/operators/matrix_rank_op.cc b/paddle/fluid/operators/matrix_rank_op.cc
index 4c9a7f563664ce0e879c671762c0e81d2399f103..15904ae63b2dc82c58cf012ef6e26efd4258ed0e 100644
--- a/paddle/fluid/operators/matrix_rank_op.cc
+++ b/paddle/fluid/operators/matrix_rank_op.cc
@@ -19,10 +19,6 @@
 #include "paddle/fluid/operators/svd_helper.h"
 #include "paddle/phi/kernels/funcs/compare_functors.h"
 
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
-
 namespace paddle {
 namespace operators {
 using DDim = framework::DDim;
diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc
index 2d4ca62955eb1474c72ea36695bc71847084c4d8..9e28d1d57be782ec3a4475e8d9dfd0e911938113 100644
--- a/paddle/fluid/operators/mul_op.cc
+++ b/paddle/fluid/operators/mul_op.cc
@@ -41,17 +41,12 @@ class MulOp : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-    int customized_type_value =
-        framework::OpKernelType::kDefaultCustomizedTypeValue;
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-#ifdef PADDLE_WITH_MKLDNN
-    if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      int customized_type_value =
+          framework::OpKernelType::kDefaultCustomizedTypeValue;
       if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
           input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
         customized_type_value = kMULMKLDNNINT8;
@@ -62,14 +57,15 @@ class MulOp : public framework::OperatorWithKernel {
                      framework::DataTypeTrait<float>::DataType()) {
         customized_type_value = kMULMKLDNNFP32;
       }
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
+                                     customized_type_value);
     }
 #endif
-    return framework::OpKernelType(input_data_type,
-                                   ctx.GetPlace(),
-                                   layout,
-                                   library,
-                                   customized_type_value);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
@@ -140,17 +136,12 @@ class MulGradOp : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const {
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-    int customized_type_value =
-        framework::OpKernelType::kDefaultCustomizedTypeValue;
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-#ifdef PADDLE_WITH_MKLDNN
-    if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
+#ifdef PADDLE_WITH_MKLDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
+      int customized_type_value =
+          framework::OpKernelType::kDefaultCustomizedTypeValue;
       if (input_data_type == framework::DataTypeTrait<int8_t>::DataType() ||
           input_data_type == framework::DataTypeTrait<uint8_t>::DataType()) {
         customized_type_value = kMULMKLDNNINT8;
@@ -161,14 +152,15 @@ class MulGradOp : public framework::OperatorWithKernel {
                      framework::DataTypeTrait<float>::DataType()) {
         customized_type_value = kMULMKLDNNFP32;
       }
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
+                                     customized_type_value);
     }
 #endif
-    return framework::OpKernelType(input_data_type,
-                                   ctx.GetPlace(),
-                                   layout,
-                                   library,
-                                   customized_type_value);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index e8b35b89157a315f4b1aacce06cea0ea4993c79f..2e2fd00c647fa5aebcf060c06a0fcc4cd4186028 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -42,8 +42,7 @@ bool CanMKLDNNSupportPool(const framework::ExecutionContext& ctx) {
 framework::OpKernelType PoolOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  std::string data_format = "AnyLayout";
-  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
   auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
@@ -88,8 +87,7 @@ framework::OpKernelType PoolOp::GetKernelTypeForVar(
 framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library_{framework::LibraryType::kPlain};
-  std::string data_format = "AnyLayout";
-  framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
   auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc
index f7abaf648ebcfd195921bc015a9c6b770045b166..4c9a7a388ffc8b753a2d5341e3324a5c447f9799 100644
--- a/paddle/fluid/operators/prelu_op.cc
+++ b/paddle/fluid/operators/prelu_op.cc
@@ -23,26 +23,6 @@ namespace operators {
 
 using Tensor = framework::Tensor;
-framework::OpKernelType innerGetKernelTypeForVar(
-    const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) {
-#ifdef PADDLE_WITH_MKLDNN
-  auto isOneDNNKernelChosen =
-      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN);
-  auto isNotOneDNNTensor = (tensor.layout() != framework::DataLayout::kMKLDNN);
-  auto isModelNHWC =
-      (paddle::platform::MKLDNNDeviceContext::tls()
-           .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC);
-  // All inputs (including alpha) need shape rotating
-  if (isOneDNNKernelChosen && isNotOneDNNTensor && isModelNHWC) {
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   tensor.place(),
-                                   framework::DataLayout::kNHWC);
-  }
-#endif
-  return framework::OpKernelType(
-      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-}
-
 class PReluOp : public framework::OperatorWithKernel {
  public:
   PReluOp(const std::string &type,
@@ -72,7 +52,19 @@ class PReluOp : public framework::OperatorWithKernel {
       const std::string &var_name,
       const Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
-    return innerGetKernelTypeForVar(tensor, expected_kernel_type);
+#ifdef PADDLE_WITH_MKLDNN
+    // All inputs (including alpha) need shape rotating
+    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+        (tensor.layout() != framework::DataLayout::kMKLDNN) &&
+        paddle::platform::MKLDNNDeviceContext::tls()
+                .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(),
+                                     framework::DataLayout::kNHWC);
+    }
+#endif
+    return framework::OpKernelType(
+        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
   }
 };
 
@@ -151,7 +143,19 @@ class PReluGradOp : public framework::OperatorWithKernel {
       const std::string &var_name,
       const Tensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
-    return innerGetKernelTypeForVar(tensor, expected_kernel_type);
+#ifdef PADDLE_WITH_MKLDNN
+    // All inputs (including alpha) need shape rotating
+    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+        (tensor.layout() != framework::DataLayout::kMKLDNN) &&
+        paddle::platform::MKLDNNDeviceContext::tls()
+                .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC) {
+      return framework::OpKernelType(expected_kernel_type.data_type_,
+                                     tensor.place(),
+                                     framework::DataLayout::kNHWC);
+    }
+#endif
+    return framework::OpKernelType(
+        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
   }
 };
 
diff --git a/paddle/fluid/operators/qr_op.cc b/paddle/fluid/operators/qr_op.cc
index e939ec7be2ee7010d6a30b7156a779543f704490..3eac56d1604b9acab58d4ae61b587859f9756722 100644
--- a/paddle/fluid/operators/qr_op.cc
+++ b/paddle/fluid/operators/qr_op.cc
@@ -17,12 +17,9 @@
 #include <memory>
 #include <vector>
 
-#include "paddle/phi/core/ddim.h"
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
 #include "paddle/fluid/framework/infershape_utils.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/infermeta_utils.h"
 #include "paddle/phi/infermeta/unary.h"
 
diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc
index 45b4e9bac7f3ccd32a033d53b02b637ebac8da45..9d956cde9fdfa6603cb15c3e3f30cf79e40aaae6 100644
--- a/paddle/fluid/operators/quantize_op.cc
+++ b/paddle/fluid/operators/quantize_op.cc
@@ -24,14 +24,11 @@ namespace operators {
 
 framework::OpKernelType QuantOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
-  framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
-
   return framework::OpKernelType(
       OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
       ctx.GetPlace(),
-      layout_,
-      library_);
+      framework::DataLayout::kMKLDNN,
+      framework::LibraryType::kMKLDNN);
 }
 
 void QuantOpMaker::Make() {
diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc
index 5c48a22666efa2b51e059a19eb3376327ff1c2b5..dfb32a8e8f9c78b9cb1c96cb6f63051e42128b8d 100644
--- a/paddle/fluid/operators/requantize_op.cc
+++ b/paddle/fluid/operators/requantize_op.cc
@@ -24,14 +24,11 @@ namespace operators {
 
 framework::OpKernelType ReQuantOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
-  framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
-
   return framework::OpKernelType(
       OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
       ctx.GetPlace(),
-      layout_,
-      library_);
+      framework::DataLayout::kMKLDNN,
+      framework::LibraryType::kMKLDNN);
 }
 
 void ReQuantOpMaker::Make() {
diff --git a/paddle/fluid/operators/svd_op.cc b/paddle/fluid/operators/svd_op.cc
index 7f9fccddf729a4e5479e477514cb6e2ad82493d2..afbfd80b8d5379ca915b104e84824aeeedefb661 100644
--- a/paddle/fluid/operators/svd_op.cc
+++ b/paddle/fluid/operators/svd_op.cc
@@ -21,9 +21,6 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/infermeta/unary.h"
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index b342f01e46ff7661d4d76483b59fa0cb05d1fa58..8ac9704df3a43e1f21f21b9d07be55264843b207 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -99,19 +99,18 @@ class TransposeOp : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    auto &data_format = ctx.Attr<std::string>("data_format");
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_);
+    auto &data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -203,20 +202,19 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
     auto data_type = OperatorWithKernel::IndicateVarDataType(
         ctx, framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_);
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -249,29 +247,27 @@ class Transpose2Op : public TransposeOp {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    int customized_type_value =
-        framework::OpKernelType::kDefaultCustomizedTypeValue;
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
     framework::proto::VarType::Type data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
       using framework::proto::VarType;
       auto input_data_type = framework::TransToProtoVarType(
           ctx.Input<framework::Tensor>("X")->dtype());
-      customized_type_value = (input_data_type == VarType::INT8 ||
-                               input_data_type == VarType::UINT8)
-                                  ? kTransposeMKLDNNINT8
-                                  : kTransposeMKLDNNFP32;
+      int customized_type_value = (input_data_type == VarType::INT8 ||
+                                   input_data_type == VarType::UINT8)
+                                      ? kTransposeMKLDNNINT8
+                                      : kTransposeMKLDNNFP32;
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
+                                     customized_type_value);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_, customized_type_value);
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
   }
 };
@@ -371,21 +367,20 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    std::string data_format = ctx.Attr<std::string>("data_format");
-    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
     framework::proto::VarType::Type data_type =
         OperatorWithKernel::IndicateVarDataType(ctx,
                                                 framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_);
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_);
   }
 };
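Every GetExpectedKernelType change above follows the same pattern: the old code declared mutable library/layout locals defaulting to kPlain/kAnyLayout and overwrote them inside the #ifdef PADDLE_WITH_MKLDNN block, while the new code returns the MKLDNN kernel type directly from the guarded branch and lets the fallback return rely on OpKernelType's default arguments. The sketch below is a minimal, self-contained illustration of that refactoring only; Library, Layout and KernelType are stand-in types, not the real Paddle framework classes.

// Illustrative sketch -- stand-ins for framework::LibraryType,
// framework::DataLayout and framework::OpKernelType.
#include <iostream>

#define PADDLE_WITH_MKLDNN 1  // assume an MKLDNN-enabled build for this sketch

enum class Library { kPlain, kMKLDNN };
enum class Layout { kAnyLayout, kMKLDNN };

struct KernelType {
  Library library = Library::kPlain;   // member defaults play the role of
  Layout layout = Layout::kAnyLayout;  // OpKernelType's default arguments
};

// Old shape: mutable locals threaded into a single return at the end.
KernelType ExpectedKernelTypeOld(bool can_use_mkldnn) {
  Library library = Library::kPlain;
  Layout layout = Layout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN
  if (can_use_mkldnn) {
    library = Library::kMKLDNN;
    layout = Layout::kMKLDNN;
  }
#endif
  return KernelType{library, layout};
}

// New shape: early return on the MKLDNN path, defaults on the plain path.
KernelType ExpectedKernelTypeNew(bool can_use_mkldnn) {
#ifdef PADDLE_WITH_MKLDNN
  if (can_use_mkldnn) {
    return KernelType{Library::kMKLDNN, Layout::kMKLDNN};
  }
#endif
  return KernelType{};  // kPlain / kAnyLayout via the member defaults
}

int main() {
  bool same = ExpectedKernelTypeOld(true).library ==
              ExpectedKernelTypeNew(true).library;
  std::cout << (same ? "equivalent" : "different") << "\n";
  return 0;
}

Both shapes choose the same kernel; the refactored form simply drops the dead locals and the redundant library == kPlain check on the non-MKLDNN path.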