From 5bf25d1e8b6eef2eea8aa24f5dbacea0b832aae2 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Mon, 25 Jan 2021 15:58:00 +0100 Subject: [PATCH] More precise mkldnn kernel rules in GetExpectedKernelType (#29840) * More precise mkldnn kernel choice in GetExpectedKernelType * Fixes after review * Refresh develop for CI * CI experiment * get back from CI exper --- paddle/fluid/framework/operator.cc | 14 ++++---- paddle/fluid/framework/operator.h | 7 ++-- paddle/fluid/operators/activation_op.cc | 6 ++-- paddle/fluid/operators/addmm_op.cc | 2 +- paddle/fluid/operators/batch_norm_op.cc | 11 +++--- paddle/fluid/operators/concat_op.cc | 2 +- paddle/fluid/operators/conv_op.cc | 11 +++--- paddle/fluid/operators/conv_transpose_op.cc | 7 ++-- paddle/fluid/operators/data_norm_op.cc | 9 +++-- .../fluid/operators/detection/prior_box_op.cc | 2 +- .../elementwise/elementwise_div_op.h | 2 +- .../elementwise/elementwise_mul_op.h | 2 +- .../operators/elementwise/elementwise_op.h | 11 +++--- paddle/fluid/operators/fused/fusion_gru_op.cc | 7 ++-- paddle/fluid/operators/gaussian_random_op.cc | 9 ++--- paddle/fluid/operators/gelu_op.cc | 14 ++++---- paddle/fluid/operators/interpolate_op.cc | 7 ++-- paddle/fluid/operators/layer_norm_op.cc | 2 +- paddle/fluid/operators/lrn_op.cc | 16 ++++----- paddle/fluid/operators/matmul_op.cc | 2 +- paddle/fluid/operators/mul_op.cc | 2 +- paddle/fluid/operators/pool_op.cc | 12 +++---- paddle/fluid/operators/softmax_op.cc | 11 +++--- paddle/fluid/operators/sum_op.cc | 21 +++++------ paddle/fluid/operators/transpose_op.cc | 36 ++++++++++--------- 25 files changed, 111 insertions(+), 114 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index dcaebc10a74..cff160b386e 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1040,21 +1040,23 @@ static void CheckTensorNANOrInf(const std::string& op_type, op_type, name)); } -bool OperatorWithKernel::SupportsMKLDNN() const { +bool OperatorWithKernel::SupportsMKLDNN( + const proto::VarType::Type data_type) const { auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); return std::any_of(op_kernels.begin(), op_kernels.end(), - [](OpKernelMap::const_reference kern_pair) { + [data_type](OpKernelMap::const_reference kern_pair) { return platform::is_cpu_place(kern_pair.first.place_) && kern_pair.first.library_type_ == - LibraryType::kMKLDNN; + LibraryType::kMKLDNN && + kern_pair.first.data_type_ == data_type; }); } -bool OperatorWithKernel::CanMKLDNNBeUsed( - const framework::ExecutionContext& ctx) const { +bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, + proto::VarType::Type data_type) const { bool use_mkldnn_ctx = ctx.Attr("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace()); - return use_mkldnn_ctx && this->SupportsMKLDNN(); + return use_mkldnn_ctx && this->SupportsMKLDNN(data_type); } void OperatorWithKernel::RuntimeInferShape(const Scope& scope, diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index fd1cc18b951..4ad9bbd9d16 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -156,8 +156,6 @@ class OperatorBase { virtual bool SupportGPU() const { return false; } - virtual bool SupportsMKLDNN() const { return false; } - const std::string& Type() const { return type_; } bool HasAttr(const std::string& name) const { return attrs_.count(name); } @@ -492,9 +490,10 @@ class OperatorWithKernel : public OperatorBase { return platform::is_gpu_place(kern_pair.first.place_); }); } - bool SupportsMKLDNN() const override; + bool SupportsMKLDNN(proto::VarType::Type data_type) const; - bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const; + bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx, + proto::VarType::Type data_type) const; virtual void InferShape(InferShapeContext* ctx) const = 0; diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc index 69660644164..3643fd926d3 100644 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -93,6 +93,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, const std::string& name) { framework::LibraryType library{framework::LibraryType::kPlain}; framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = oper.IndicateVarDataType(ctx, name); // FIXME(liuwei1031) temporarily disable the code to unblock users // TODO(liuwei1031) figure out the reason behind // https://github.com/PaddlePaddle/Paddle/issues/16096 @@ -106,13 +107,12 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, #ifdef PADDLE_WITH_MKLDNN auto it = oper.Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() && - oper.CanMKLDNNBeUsed(ctx)) { + oper.CanMKLDNNBeUsed(ctx, data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType(oper.IndicateVarDataType(ctx, name), - ctx.GetPlace(), layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); } class ActivationOp : public framework::OperatorWithKernel { diff --git a/paddle/fluid/operators/addmm_op.cc b/paddle/fluid/operators/addmm_op.cc index f5b35cbd218..c56e3ca9a9a 100644 --- a/paddle/fluid/operators/addmm_op.cc +++ b/paddle/fluid/operators/addmm_op.cc @@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index f74aa259e89..fc31885824b 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -157,7 +157,8 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -524,17 +525,17 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && + this->CanMKLDNNBeUsed(ctx, data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout, - library); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); } framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index e84f0222142..bbc42d97146 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel { "All Inputs of Concat OP are Empty!")); } #ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 268b475f183..dd7bfbdaefe 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -155,7 +155,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } #endif #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; customized_type_value = @@ -556,6 +557,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( // TODO(pzelazko-intel): enable MKLDNN layout when it's ready std::string data_format = "AnyLayout"; framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); #ifdef PADDLE_WITH_CUDA if (platform::CanCUDNNBeUsed(ctx)) { @@ -564,7 +566,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { const std::string data_format = ctx.Attr("data_format"); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; @@ -572,9 +574,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( } #endif - auto type = framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), - layout_, library_, customized_type_value); + auto type = framework::OpKernelType(data_type, ctx.GetPlace(), layout_, + library_, customized_type_value); return type; } diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 7ff17e68b73..018d15e76c9 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -182,6 +182,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; bool use_cudnn = ctx.Attr("use_cudnn"); use_cudnn &= platform::is_gpu_place(ctx.GetPlace()); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input"); #ifdef PADDLE_WITH_CUDA if (platform::is_gpu_place(ctx.GetPlace())) { auto& dev_ctx = ctx.template device_context(); @@ -193,15 +194,13 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_); } framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar( diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 698c57482dd..91e8c04a3d3 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel { framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -483,18 +483,17 @@ class DataNormGradOp : public framework::OperatorWithKernel { // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); } }; diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index ef6332b6414..cf19e241109 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, input_input_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; auto input_image_type = ctx.Input("Image")->type(); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index b6f6151e133..5f4321f7273 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -207,7 +207,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out"); #ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index 66a9e6dd0fc..3bc12fe16d9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -34,7 +34,7 @@ class ElementwiseMulOp : public ElementwiseOp { OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); #ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index be10376f611..a09fe4b6760 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -110,7 +110,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y"); #ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -280,8 +280,9 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { return (ctx.Input("X")->dims() == ctx.Input("Y")->dims()); }; - if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" || - CanMKLDNNElementwiseAddGradBeUsed())) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type) && + (ctx.Type() != "elementwise_add_grad" || + CanMKLDNNElementwiseAddGradBeUsed())) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -331,7 +332,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); #ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -384,7 +385,7 @@ class ElementwiseOpDoubleGradWithoutDXDY } #ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index 71dccad0b58..e0ecd2cab53 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -133,15 +133,14 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType( const framework::ExecutionContext& ctx) const { framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (this->CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx, data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout, - library); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); } void FusionGRUOpMaker::Make() { diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index 9087a9e8d5c..ea8930fb6f7 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -112,18 +112,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::LibraryType library{framework::LibraryType::kPlain}; framework::DataLayout layout{framework::DataLayout::kAnyLayout}; + auto data_type = + static_cast(ctx.Attr("dtype")); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - static_cast(ctx.Attr("dtype")), - ctx.device_context(), layout, library); + return framework::OpKernelType(data_type, ctx.device_context(), layout, + library); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc index 6c33b05cac9..3293800e1c6 100644 --- a/paddle/fluid/operators/gelu_op.cc +++ b/paddle/fluid/operators/gelu_op.cc @@ -46,17 +46,16 @@ class GeluOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { framework::LibraryType library{framework::LibraryType::kPlain}; framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN auto it = this->Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && - it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) { + it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); } }; @@ -86,17 +85,16 @@ class GeluGradOp : public framework::OperatorWithKernel { const framework::ExecutionContext &ctx) const override { framework::LibraryType library{framework::LibraryType::kPlain}; framework::DataLayout layout = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN auto it = this->Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && - it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) { + it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); } }; diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc index f3699d0d7b6..6c488c387f8 100644 --- a/paddle/fluid/operators/interpolate_op.cc +++ b/paddle/fluid/operators/interpolate_op.cc @@ -322,20 +322,19 @@ class InterpolateOp : public framework::OperatorWithKernel { const framework::ExecutionContext& ctx) const override { framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::LibraryType library = framework::LibraryType::kPlain; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN auto interp_method = ctx.Attr("interp_method"); // TODO(danqing): support other interp_method - if (this->CanMKLDNNBeUsed(ctx) && + if (this->CanMKLDNNBeUsed(ctx, data_type) && (interp_method == "nearest" || interp_method == "bilinear")) { layout = framework::DataLayout::kMKLDNN; library = framework::LibraryType::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 23de34bc6fa..4980315d55e 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -124,7 +124,7 @@ class LayerNormOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 2d4123ccbd1..d6fc1434024 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -199,16 +199,16 @@ class LRNOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, + library_); } framework::OpKernelType GetKernelTypeForVar( @@ -339,16 +339,16 @@ class LRNOpGrad : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; // TODO(pzelazko-intel): enable MKLDNN layout when it's ready framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, + library_); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 668445d2429..e97565a6623 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -661,7 +661,7 @@ class MatMulOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN using mkldnn::memory; - if (this->CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index 9d6c52b98aa..5d168288953 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -106,7 +106,7 @@ class MulOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index b78ced8eee2..55651dcecf6 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -149,6 +149,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = "AnyLayout"; framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_CUDA if (platform::CanCUDNNBeUsed(ctx)) { @@ -157,15 +158,13 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_); } framework::OpKernelType PoolOp::GetKernelTypeForVar( @@ -205,6 +204,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = "AnyLayout"; framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_CUDA if (platform::CanCUDNNBeUsed(ctx)) { @@ -213,14 +213,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); - return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_, library_); } diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index ff750ab47a9..64030486eb4 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -64,6 +64,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_CUDA if (platform::CanCUDNNBeUsed(ctx)) { @@ -72,13 +73,12 @@ class SoftmaxOp : public framework::OperatorWithKernel { #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, platform::errors::InvalidArgument( @@ -188,7 +188,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - + auto input_data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); #ifdef PADDLE_WITH_CUDA if (platform::CanCUDNNBeUsed(ctx)) { library_ = framework::LibraryType::kCUDNN; @@ -196,13 +197,11 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, input_data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - auto input_data_type = OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")); if (input_data_type == framework::proto::VarType::FP16) { PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, platform::errors::InvalidArgument( diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index 57fa92b1995..741f86f3584 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -145,29 +145,26 @@ class SumOp : public framework::OperatorWithKernel { platform::errors::InvalidArgument( "Sum operator should have at least one tensor")); + auto data_type = static_cast(dtype); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx) && - (static_cast(dtype) == - framework::proto::VarType::FP32 || - static_cast(dtype) == - framework::proto::VarType::BF16) && + this->CanMKLDNNBeUsed(ctx, data_type) && + (data_type == framework::proto::VarType::FP32 || + data_type == framework::proto::VarType::BF16) && ctx.OutputVar("Out")->IsType()) { if (std::all_of(x_vars.begin(), x_vars.end(), [](const framework::Variable* v) { return v->IsType(); })) { - return framework::OpKernelType( - static_cast(dtype), - ctx.GetPlace(), framework::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN); + return framework::OpKernelType(data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } } #endif - return framework::OpKernelType( - static_cast(dtype), ctx.GetPlace(), - layout, library); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout, + library); } else if (x_vars[0]->IsType()) { for (auto& var : x_vars) { auto& value = var->Get().value(); diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index d9940ddca3e..465970451f5 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -86,16 +86,16 @@ class TransposeOp : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, + library_); } }; @@ -184,16 +184,17 @@ class TransposeOpGrad : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + auto data_type = OperatorWithKernel::IndicateVarDataType( + ctx, framework::GradVarName("Out")); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, + library_); } }; @@ -231,9 +232,11 @@ class Transpose2Op : public TransposeOp { int customized_type_value = framework::OpKernelType::kDefaultCustomizedTypeValue; framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + framework::proto::VarType::Type data_type = + OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; using framework::proto::VarType; @@ -244,9 +247,8 @@ class Transpose2Op : public TransposeOp { : kTransposeMKLDNNFP32; } #endif - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), - layout_, library_, customized_type_value); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_, + customized_type_value); } }; @@ -310,16 +312,18 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { framework::LibraryType library_{framework::LibraryType::kPlain}; std::string data_format = ctx.Attr("data_format"); framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + framework::proto::VarType::Type data_type = + OperatorWithKernel::IndicateVarDataType(ctx, + framework::GradVarName("Out")); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - this->CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx, data_type)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } #endif - return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType( - ctx, framework::GradVarName("Out")), - ctx.GetPlace(), layout_, library_); + return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, + library_); } }; -- GitLab