From ea96172efde585df86035de6c34582c0853e4655 Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Thu, 15 Sep 2022 19:49:05 +0800 Subject: [PATCH] refine PADDLE_WITH_MKLDNN code (#46053) * refine PADDLE_WITH_MKLDNN code * fix data_norm_op * polish addmm_op --- paddle/fluid/operators/addmm_op.cc | 24 +++++++----------- paddle/fluid/operators/angle_op.cc | 3 --- paddle/fluid/operators/batch_norm_op.cc | 27 +++++++++------------ paddle/fluid/operators/conv_transpose_op.cc | 4 +-- paddle/fluid/operators/data_norm_op.cc | 27 +++++++++------------ 5 files changed, 34 insertions(+), 51 deletions(-) diff --git a/paddle/fluid/operators/addmm_op.cc b/paddle/fluid/operators/addmm_op.cc index c7d6201ed2..833285615f 100644 --- a/paddle/fluid/operators/addmm_op.cc +++ b/paddle/fluid/operators/addmm_op.cc @@ -39,29 +39,23 @@ class AddMMOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const { - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; - int customized_type_value = - framework::OpKernelType::kDefaultCustomizedTypeValue; auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; - + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + int customized_type_value = + framework::OpKernelType::kDefaultCustomizedTypeValue; if (input_data_type == framework::DataTypeTrait::DataType() || input_data_type == framework::DataTypeTrait::DataType()) { customized_type_value = kMULMKLDNNINT8; } + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN, + customized_type_value); } #endif - - return framework::OpKernelType(input_data_type, - ctx.GetPlace(), - layout, - library, - customized_type_value); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/angle_op.cc b/paddle/fluid/operators/angle_op.cc index 5c18f4c6fc..ccd5584e8d 100644 --- a/paddle/fluid/operators/angle_op.cc +++ b/paddle/fluid/operators/angle_op.cc @@ -16,9 +16,6 @@ #include #include #include -#ifdef PADDLE_WITH_MKLDNN -#include "paddle/fluid/platform/mkldnn_helper.h" -#endif #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index a4a3f3cd2b..84f22ebff4 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -195,18 +195,16 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( "Variance input should be of float type")); // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } framework::OpKernelType BatchNormOp::GetKernelTypeForVar( @@ -396,19 +394,18 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( } // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 5cc991d8f1..d883d2da29 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -80,9 +80,7 @@ framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( // op. Treat this as NCHW (default data_format value) if (dl != framework::DataLayout::kAnyLayout) { return framework::OpKernelType( - expected_kernel_type.data_type_, - tensor.place(), - framework::StringToDataLayout(data_format)); + expected_kernel_type.data_type_, tensor.place(), dl); } } #endif diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index a4cfb82bf8..4fc279e03a 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -200,18 +200,16 @@ class DataNormOp : public framework::OperatorWithKernel { "bias input should be of float type")); } // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, input_data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType( - input_data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -511,19 +509,18 @@ class DataNormGradOp : public framework::OperatorWithKernel { } // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::LibraryType library = framework::LibraryType::kPlain; - framework::DataLayout layout = framework::DataLayout::kAnyLayout; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx, data_type)) { - library = framework::LibraryType::kMKLDNN; - layout = framework::DataLayout::kMKLDNN; + if (this->CanMKLDNNBeUsed(ctx, data_type)) { + return framework::OpKernelType(data_type, + ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif - return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace()); } }; -- GitLab