diff --git a/paddle/fluid/operators/addmm_op.cc b/paddle/fluid/operators/addmm_op.cc index c7d6201ed2b7c96194fdb198fa0707e0289214f3..833285615f1699d84df2a7309554237494c5db72 100644 --- a/paddle/fluid/operators/addmm_op.cc +++ b/paddle/fluid/operators/addmm_op.cc @@ -39,29 +39,23 @@ class AddMMOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; - + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; if (input_data_type == framework::DataTypeTrait::DataType() || input_data_type == framework::DataTypeTrait::DataType()) { customized_type_value = kMULMKLDNNINT8; } + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, + customized_type_value); } #endif - - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout, - library, - customized_type_value); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/angle_op.cc b/paddle/fluid/operators/angle_op.cc index 5c18f4c6fcc7df1753dc7020a91e71e1151e565c..ccd5584e8dedfd0ceaf4f125a752920c5f171723 100644 --- a/paddle/fluid/operators/angle_op.cc +++ b/paddle/fluid/operators/angle_op.cc @@ -16,9 +16,6 @@ #include #include #include -#ifdef PADDLE_WITH_MKLDNN -#include "paddle/fluid/platform/mkldnn_helper.h" -#endif #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index a4a3f3cd2b054d09c56e5c0b50e47c5c589c90ed..84f22ebff408429fda812441242cba57d8d2ac51 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -195,18 +195,16 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( "Variance input should be of float type")); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } framework::OpKernelType BatchNormOp::GetKernelTypeForVar( @@ -396,19 +394,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( } // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 5cc991d8f1c751bedaaab820f2611e91736a988c..d883d2da291b234cc35e0c7459aef696e1baefec 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -80,9 +80,7 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( // op. Treat this as NCHW (default data_format value) if (dl != framework::DataLayout::kAnyLayout) { return framework::OpKernelType( - expected_kernel_type.data_type_, - tensor.place(), - framework::StringToDataLayout(data_format)); + expected_kernel_type.data_type_, tensor.place(), dl); } } #endif diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index a4cfb82bf8aaa390ed51455d41227e7086cbf502..4fc279e03a36f0eb447d4128c54854fb231bc015 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -200,18 +200,16 @@ class DataNormOp : public framework::OperatorWithKernel { "bias input should be of float type")); } // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -511,19 +509,18 @@ class DataNormGradOp : public framework::OperatorWithKernel { } // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } };