diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 0fb13e24b80e0f8f526b54c2093d6ca27422892b..fc6dbbade032297c521eb3dc4fc1e339ad0eb6e4 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -24,14 +24,13 @@ namespace operators { framework::OpKernelType DeQuantOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library_ = framework::LibraryType::kMKLDNN; - framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; - - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), - ctx.GetPlace(), - layout_, - library_); + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); + + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } void DeQuantOpMaker::Make() { diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc index 43bb6089a87dd2d56105e24732193bbf10289a48..1f993b58aa76f8b6caa149ec12e4fe0b9daba82f 100644 --- a/paddle/fluid/operators/fc_op.cc +++ b/paddle/fluid/operators/fc_op.cc @@ -126,26 +126,21 @@ class FCOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); if (ctx.Attr("use_mkldnn")) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; using framework::proto::VarType; - customized_type_value = (input_data_type == VarType::INT8 || - input_data_type == VarType::UINT8) - ? kFCMKLDNNINT8 - : kFCMKLDNNFP32; + int customized_type_value = (input_data_type == VarType::INT8 || + input_data_type == VarType::UINT8) + ? kFCMKLDNNINT8 + : kFCMKLDNNFP32; + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, + customized_type_value); } - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout, - library, - customized_type_value); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 8d92305eb6f15840f3901638bd9176861b132c6d..b80bc7320c1fd630899bd4be8a373a3d16be64bf 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -58,21 +58,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library{framework::LibraryType::kPlain}; - framework::DataLayout layout{framework::DataLayout::kAnyLayout}; auto data_type = static_cast(ctx.Attr("dtype")); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.device_context(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - data_type, ctx.device_context(), layout, library); + return framework::OpKernelType(data_type, ctx.device_context()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/gaussian_random_op_npu.cc b/paddle/fluid/operators/gaussian_random_op_npu.cc index 8b3af57d923fe5c6c80de321901f46a49c266100..bdeb106f81665f9c1daa839f52e0e21abbef7157 100644 --- a/paddle/fluid/operators/gaussian_random_op_npu.cc +++ b/paddle/fluid/operators/gaussian_random_op_npu.cc @@ -18,9 +18,6 @@ limitations under the License. */ #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#ifdef PADDLE_WITH_MKLDNN -#include "paddle/fluid/platform/mkldnn_helper.h" -#endif namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc index a16544b8ba3dee70df5bd387598e25856df42e68..c5ec8d1b21ab9dfc5687c76cc40107ead7b2f462 100644 --- a/paddle/fluid/operators/gelu_op.cc +++ b/paddle/fluid/operators/gelu_op.cc @@ -35,17 +35,16 @@ class GeluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::LibraryType library{framework::LibraryType::kPlain}; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } }; @@ -76,18 +75,17 @@ class GeluGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - framework::LibraryType library{framework::LibraryType::kPlain}; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN auto it = this->Attrs().find("use_mkldnn"); - if (library == framework::LibraryType::kPlain && - it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index 4c77e8b5b56c6593d9bcbd3154f861dd0f030a9a..213d14ec48f66b5632bf15ba0e523b701bac0b68 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -340,8 +340,6 @@ class InterpolateOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - framework::LibraryType library = framework::LibraryType::kPlain; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN @@ -349,12 +347,14 @@ class InterpolateOp : public framework::OperatorWithKernel { // TODO(danqing): support other interp_method if (this->CanMKLDNNBeUsed(ctx, data_type) && (interp_method == "nearest" || interp_method == "bilinear")) { - layout = framework::DataLayout::kMKLDNN; - library = framework::LibraryType::kMKLDNN; + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc index 62d9c547fa39791d35dd8de4f180649ee7a187af..1bb68699a855337e37c4c401140e6d4778a7a1c6 100644 --- a/paddle/fluid/operators/interpolate_v2_op.cc +++ b/paddle/fluid/operators/interpolate_v2_op.cc @@ -444,8 +444,6 @@ class InterpolateV2Op : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - framework::LibraryType library = framework::LibraryType::kPlain; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN @@ -453,12 +451,14 @@ class InterpolateV2Op : public framework::OperatorWithKernel { // TODO(danqing): support other interp_method if (this->CanMKLDNNBeUsed(ctx, data_type) && (interp_method == "nearest" || interp_method == "bilinear")) { - layout = framework::DataLayout::kMKLDNN; - library = framework::LibraryType::kMKLDNN; + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 0346e9b82868a0740779c1524270171581f7a101..13cd443473040071da5785fed327de943f4a3969 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -110,21 +110,19 @@ class LayerNormOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN int begin_norm_axis = ctx.Attr("begin_norm_axis"); - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type) && + if (this->CanMKLDNNBeUsed(ctx, input_data_type) && begin_norm_axis == ctx.Input("X")->dims().size() - 1) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index bd495664de601150cc94c475e1a709af2eadac5f..ca2fba56697fcce816f0d345707a69c693ded3ff 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -225,19 +225,18 @@ class LRNOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - data_type, ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( @@ -360,19 +359,18 @@ class LRNOpGrad : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library_ = framework::LibraryType::kMKLDNN; - layout_ = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - data_type, ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar(