From b2b9a1bb83a4912ae77fe92a9a1e5f5d61abe728 Mon Sep 17 00:00:00 2001
From: jiahongyu
Date: Sat, 17 Sep 2022 08:30:44 +0000
Subject: [PATCH] refine mkldnn code

---
 paddle/fluid/operators/dequantize_op.cc       | 15 ++++----
 paddle/fluid/operators/fc_op.cc               | 25 ++++++--------
 paddle/fluid/operators/gaussian_random_op.cc  | 14 ++++----
 .../fluid/operators/gaussian_random_op_npu.cc |  3 --
 paddle/fluid/operators/gelu_op.cc             | 26 +++++++-------
 paddle/fluid/operators/interpolate_op.cc      | 10 +++---
 paddle/fluid/operators/interpolate_v2_op.cc   | 10 +++---
 paddle/fluid/operators/layer_norm_op.cc       | 14 ++++----
 paddle/fluid/operators/lrn_op.cc              | 34 +++++++++----------
 9 files changed, 67 insertions(+), 84 deletions(-)

diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc
index 0fb13e24b80..fc6dbbade03 100644
--- a/paddle/fluid/operators/dequantize_op.cc
+++ b/paddle/fluid/operators/dequantize_op.cc
@@ -24,14 +24,13 @@ namespace operators {
 
 framework::OpKernelType DeQuantOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
-  framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
-  framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
-
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-      ctx.GetPlace(),
-      layout_,
-      library_);
+  auto input_data_type =
+      framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
+
+  return framework::OpKernelType(input_data_type,
+                                 ctx.GetPlace(),
+                                 framework::DataLayout::kMKLDNN,
+                                 framework::LibraryType::kMKLDNN);
 }
 
 void DeQuantOpMaker::Make() {
diff --git a/paddle/fluid/operators/fc_op.cc b/paddle/fluid/operators/fc_op.cc
index 43bb6089a87..1f993b58aa7 100644
--- a/paddle/fluid/operators/fc_op.cc
+++ b/paddle/fluid/operators/fc_op.cc
@@ -126,26 +126,21 @@ class FCOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-    int customized_type_value =
-        framework::OpKernelType::kDefaultCustomizedTypeValue;
     auto input_data_type =
         OperatorWithKernel::IndicateVarDataType(ctx, "Input");
     if (ctx.Attr<bool>("use_mkldnn")) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
       using framework::proto::VarType;
-      customized_type_value = (input_data_type == VarType::INT8 ||
-                               input_data_type == VarType::UINT8)
-                                  ? kFCMKLDNNINT8
-                                  : kFCMKLDNNFP32;
+      int customized_type_value = (input_data_type == VarType::INT8 ||
+                                   input_data_type == VarType::UINT8)
+                                      ? kFCMKLDNNINT8
+                                      : kFCMKLDNNFP32;
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN,
+                                     customized_type_value);
     }
-    return framework::OpKernelType(input_data_type,
-                                   ctx.GetPlace(),
-                                   layout,
-                                   library,
-                                   customized_type_value);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc
index 8d92305eb6f..b80bc7320c1 100644
--- a/paddle/fluid/operators/gaussian_random_op.cc
+++ b/paddle/fluid/operators/gaussian_random_op.cc
@@ -58,21 +58,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library{framework::LibraryType::kPlain};
-    framework::DataLayout layout{framework::DataLayout::kAnyLayout};
     auto data_type =
         static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
 
 #ifdef PADDLE_WITH_MKLDNN
-    if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.device_context(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
 
-    return framework::OpKernelType(
-        data_type, ctx.device_context(), layout, library);
+    return framework::OpKernelType(data_type, ctx.device_context());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
diff --git a/paddle/fluid/operators/gaussian_random_op_npu.cc b/paddle/fluid/operators/gaussian_random_op_npu.cc
index 8b3af57d923..bdeb106f816 100644
--- a/paddle/fluid/operators/gaussian_random_op_npu.cc
+++ b/paddle/fluid/operators/gaussian_random_op_npu.cc
@@ -18,9 +18,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/generator.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
-#ifdef PADDLE_WITH_MKLDNN
-#include "paddle/fluid/platform/mkldnn_helper.h"
-#endif
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc
index a16544b8ba3..c5ec8d1b21a 100644
--- a/paddle/fluid/operators/gelu_op.cc
+++ b/paddle/fluid/operators/gelu_op.cc
@@ -35,17 +35,16 @@ class GeluOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library{framework::LibraryType::kPlain};
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-    if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace());
   }
 };
 
@@ -76,18 +75,17 @@ class GeluGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::LibraryType library{framework::LibraryType::kPlain};
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     auto it = this->Attrs().find("use_mkldnn");
-    if (library == framework::LibraryType::kPlain &&
-        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
+    if (it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc
index 4c77e8b5b56..213d14ec48f 100644
--- a/paddle/fluid/operators/interpolate_op.cc
+++ b/paddle/fluid/operators/interpolate_op.cc
@@ -340,8 +340,6 @@ class InterpolateOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-    framework::LibraryType library = framework::LibraryType::kPlain;
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #ifdef PADDLE_WITH_MKLDNN
@@ -349,12 +347,14 @@ class InterpolateOp : public framework::OperatorWithKernel {
     // TODO(danqing): support other interp_method
     if (this->CanMKLDNNBeUsed(ctx, data_type) &&
         (interp_method == "nearest" || interp_method == "bilinear")) {
-      layout = framework::DataLayout::kMKLDNN;
-      library = framework::LibraryType::kMKLDNN;
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
 
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
diff --git a/paddle/fluid/operators/interpolate_v2_op.cc b/paddle/fluid/operators/interpolate_v2_op.cc
index 62d9c547fa3..1bb68699a85 100644
--- a/paddle/fluid/operators/interpolate_v2_op.cc
+++ b/paddle/fluid/operators/interpolate_v2_op.cc
@@ -444,8 +444,6 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
-    framework::LibraryType library = framework::LibraryType::kPlain;
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 
 #ifdef PADDLE_WITH_MKLDNN
@@ -453,12 +451,14 @@ class InterpolateV2Op : public framework::OperatorWithKernel {
     // TODO(danqing): support other interp_method
     if (this->CanMKLDNNBeUsed(ctx, data_type) &&
         (interp_method == "nearest" || interp_method == "bilinear")) {
-      layout = framework::DataLayout::kMKLDNN;
-      library = framework::LibraryType::kMKLDNN;
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
 
-    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc
index 0346e9b8286..13cd4434730 100644
--- a/paddle/fluid/operators/layer_norm_op.cc
+++ b/paddle/fluid/operators/layer_norm_op.cc
@@ -110,21 +110,19 @@ class LayerNormOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext &ctx) const override {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-    framework::LibraryType library = framework::LibraryType::kPlain;
-    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
 
 #ifdef PADDLE_WITH_MKLDNN
     int begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
-    if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, input_data_type) &&
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
         begin_norm_axis == ctx.Input("X")->dims().size() - 1) {
-      library = framework::LibraryType::kMKLDNN;
-      layout = framework::DataLayout::kMKLDNN;
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
 
-    return framework::OpKernelType(
-        input_data_type, ctx.GetPlace(), layout, library);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
 };
 
diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index bd495664de6..ca2fba56697 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -225,19 +225,18 @@ class LRNOp : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
@@ -360,19 +359,18 @@ class LRNOpGrad : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    framework::LibraryType library_{framework::LibraryType::kPlain};
-    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
     auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+
+    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
 #ifdef PADDLE_WITH_MKLDNN
-    if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx, data_type)) {
-      library_ = framework::LibraryType::kMKLDNN;
-      layout_ = framework::DataLayout::kMKLDNN;
+    if (this->CanMKLDNNBeUsed(ctx, data_type)) {
+      return framework::OpKernelType(data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
     }
 #endif
-    return framework::OpKernelType(
-        data_type, ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace());
   }
 
   framework::OpKernelType GetKernelTypeForVar(
-- 
GitLab