diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index dcaebc10a7408320831503dba4bd76ef55d2ed37..cff160b386eaac1510ba841778815000adebe25b 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1040,21 +1040,23 @@ static void CheckTensorNANOrInf(const std::string& op_type,
                         op_type, name));
 }

-bool OperatorWithKernel::SupportsMKLDNN() const {
+bool OperatorWithKernel::SupportsMKLDNN(
+    const proto::VarType::Type data_type) const {
   auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
   return std::any_of(op_kernels.begin(), op_kernels.end(),
-                     [](OpKernelMap::const_reference kern_pair) {
+                     [data_type](OpKernelMap::const_reference kern_pair) {
                        return platform::is_cpu_place(kern_pair.first.place_) &&
                               kern_pair.first.library_type_ ==
-                                  LibraryType::kMKLDNN;
+                                  LibraryType::kMKLDNN &&
+                              kern_pair.first.data_type_ == data_type;
                      });
 }

-bool OperatorWithKernel::CanMKLDNNBeUsed(
-    const framework::ExecutionContext& ctx) const {
+bool OperatorWithKernel::CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
+                                         proto::VarType::Type data_type) const {
   bool use_mkldnn_ctx =
       ctx.Attr<bool>("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace());
-  return use_mkldnn_ctx && this->SupportsMKLDNN();
+  return use_mkldnn_ctx && this->SupportsMKLDNN(data_type);
 }

 void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h
index fd1cc18b9513976c163f3b92c08dea8e0831d65e..4ad9bbd9d16cd62963ab989772216754c4ffddef 100644
--- a/paddle/fluid/framework/operator.h
+++ b/paddle/fluid/framework/operator.h
@@ -156,8 +156,6 @@ class OperatorBase {

   virtual bool SupportGPU() const { return false; }

-  virtual bool SupportsMKLDNN() const { return false; }
-
   const std::string& Type() const { return type_; }

   bool HasAttr(const std::string& name) const { return attrs_.count(name); }
@@ -492,9 +490,10 @@ class OperatorWithKernel : public OperatorBase {
       return platform::is_gpu_place(kern_pair.first.place_);
     });
   }
-  bool SupportsMKLDNN() const override;
+  bool SupportsMKLDNN(proto::VarType::Type data_type) const;

-  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const;
+  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
+                       proto::VarType::Type data_type) const;

   virtual void InferShape(InferShapeContext* ctx) const = 0;
diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc
index 696606441642c91e5dabacaa1af7e28a575e0557..3643fd926d33adbbab60a13d2de1d9fbb851941d 100644
--- a/paddle/fluid/operators/activation_op.cc
+++ b/paddle/fluid/operators/activation_op.cc
@@ -93,6 +93,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
                                       const std::string& name) {
   framework::LibraryType library{framework::LibraryType::kPlain};
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  auto data_type = oper.IndicateVarDataType(ctx, name);
   // FIXME(liuwei1031) temporarily disable the code to unblock users
   // TODO(liuwei1031) figure out the reason behind
   // https://github.com/PaddlePaddle/Paddle/issues/16096
@@ -106,13 +107,12 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
 #ifdef PADDLE_WITH_MKLDNN
   auto it = oper.Attrs().find("use_mkldnn");
   if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
-      oper.CanMKLDNNBeUsed(ctx)) {
+      oper.CanMKLDNNBeUsed(ctx, data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
 #endif
-  return framework::OpKernelType(oper.IndicateVarDataType(ctx, name),
-                                 ctx.GetPlace(), layout, library);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
 }

 class ActivationOp : public framework::OperatorWithKernel {
diff --git a/paddle/fluid/operators/addmm_op.cc b/paddle/fluid/operators/addmm_op.cc
index f5b35cbd21889b71110d2aa8fcc48eaa9eb73fa1..c56e3ca9a9a5366e859585c63d2214acf4dc288d 100644
--- a/paddle/fluid/operators/addmm_op.cc
+++ b/paddle/fluid/operators/addmm_op.cc
@@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc
index f74aa259e893a83ebb8b776b34a2899a911089fa..fc31885824b55f22bba77559d728a1e40d47e784 100644
--- a/paddle/fluid/operators/batch_norm_op.cc
+++ b/paddle/fluid/operators/batch_norm_op.cc
@@ -157,7 +157,8 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType(
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
-  if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
+  if (library == framework::LibraryType::kPlain &&
+      this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
@@ -524,17 +525,17 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

 #ifdef PADDLE_WITH_MKLDNN
-  if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
+  if (library == framework::LibraryType::kPlain &&
+      this->CanMKLDNNBeUsed(ctx, data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
 #endif

-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
-      library);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
 }

 framework::OpKernelType BatchNormGradOp::GetKernelTypeForVar(
diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc
index e84f0222142cadf66a22465e9a2a66c6a5ccb721..bbc42d97146f24e69d2f2337967e129af013fb6c 100644
--- a/paddle/fluid/operators/concat_op.cc
+++ b/paddle/fluid/operators/concat_op.cc
@@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel {
           "All Inputs of Concat OP are Empty!"));
     }
 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
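Note: the two framework hunks above are the heart of this patch. `SupportsMKLDNN` now answers "is there an MKLDNN kernel for *this* data type?" rather than "is there any MKLDNN kernel at all?". A minimal standalone sketch of the new predicate, using simplified stand-ins for Paddle's kernel-key types (`Place`, `Library`, `DataType`, and the `int` kernel stub here are illustrative, not Paddle's real definitions):

#include <algorithm>
#include <map>
#include <tuple>

enum class Place { kCPU, kGPU };
enum class Library { kPlain, kMKLDNN };
enum class DataType { kFP32, kBF16, kINT8 };

struct KernelKey {
  Place place;
  Library library;
  DataType data_type;
  bool operator<(const KernelKey& o) const {
    return std::tie(place, library, data_type) <
           std::tie(o.place, o.library, o.data_type);
  }
};

using KernelMap = std::map<KernelKey, int>;  // int stands in for the kernel fn

// Mirrors the new semantics: a CPU kernel, built with MKLDNN, registered for
// exactly the requested data type.
bool SupportsMKLDNN(const KernelMap& kernels, DataType data_type) {
  return std::any_of(kernels.begin(), kernels.end(),
                     [data_type](KernelMap::const_reference kv) {
                       return kv.first.place == Place::kCPU &&
                              kv.first.library == Library::kMKLDNN &&
                              kv.first.data_type == data_type;
                     });
}

Keying the lookup on the dtype is what lets callers fall back to the plain CPU kernel when, say, only an FP32 MKLDNN kernel is registered.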
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 268b475f18314a2241c6b9d90c65496f229cd08d..dd7bfbdaefeb2c9f98e26eb63f9b0d2ad73a5d5d 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -155,7 +155,8 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
   }
 #endif
 #ifdef PADDLE_WITH_MKLDNN
-  if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) {
+  if (library == framework::LibraryType::kPlain &&
+      this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
     customized_type_value =
@@ -556,6 +557,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
   std::string data_format = "AnyLayout";
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");

 #ifdef PADDLE_WITH_CUDA
   if (platform::CanCUDNNBeUsed(ctx)) {
@@ -564,7 +566,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
 #endif
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx)) {
+      this->CanMKLDNNBeUsed(ctx, data_type)) {
     const std::string data_format = ctx.Attr<std::string>("data_format");
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
@@ -572,9 +574,8 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
   }
 #endif

-  auto type = framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
-      layout_, library_, customized_type_value);
+  auto type = framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                      library_, customized_type_value);
   return type;
 }
diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc
index 7ff17e68b73a8e108057231d20470cd938d8dc17..018d15e76c920bf16867d2a93d589a3d57b6f1c6 100644
--- a/paddle/fluid/operators/conv_transpose_op.cc
+++ b/paddle/fluid/operators/conv_transpose_op.cc
@@ -182,6 +182,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
   framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
   bool use_cudnn = ctx.Attr<bool>("use_cudnn");
   use_cudnn &= platform::is_gpu_place(ctx.GetPlace());
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
 #ifdef PADDLE_WITH_CUDA
   if (platform::is_gpu_place(ctx.GetPlace())) {
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
@@ -193,15 +194,13 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
 #endif
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx)) {
+      this->CanMKLDNNBeUsed(ctx, data_type)) {
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
   }
 #endif

-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
 }

 framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
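The operator-side edits in this patch all follow one mechanical pattern, distilled here from the hunks themselves (a restatement of the diff, not new API; the input name "X" varies per op, e.g. "Input" for the conv ops above):

// The recurring call-site shape after this patch:
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  framework::LibraryType library = framework::LibraryType::kPlain;
  framework::DataLayout layout = framework::DataLayout::kAnyLayout;
  // Resolve the dtype once, before the #ifdef blocks, so the same value
  // feeds both the MKLDNN capability check and the returned kernel key.
  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
  if (library == framework::LibraryType::kPlain &&
      this->CanMKLDNNBeUsed(ctx, data_type)) {
    library = framework::LibraryType::kMKLDNN;
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
}

Hoisting the `IndicateVarDataType` call is also a small cleanup: the old code called it twice in some ops (once implicitly via the return statement), now it runs exactly once.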
diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc
index 698c57482dd06e8f74db53f494412cab981bad5d..91e8c04a3d3d858efff9ee88f9b0d487095b3aac 100644
--- a/paddle/fluid/operators/data_norm_op.cc
+++ b/paddle/fluid/operators/data_norm_op.cc
@@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel {
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
@@ -483,18 +483,17 @@ class DataNormGradOp : public framework::OperatorWithKernel {
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
     framework::LibraryType library = framework::LibraryType::kPlain;
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
 #endif

-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }
 };
diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc
index ef6332b6414aa7585bc67cf773881eff7bde6738..cf19e2411090a7ce8f529c7ae4018c5baeacf191 100644
--- a/paddle/fluid/operators/detection/prior_box_op.cc
+++ b/paddle/fluid/operators/detection/prior_box_op.cc
@@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel {
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_input_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
       auto input_image_type = ctx.Input<framework::Tensor>("Image")->type();
diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h
index b6f6151e13360441f1517bc9fe75c0dbc6a22249..5f4321f7273c99aa0add3da710e8f427a2bb3f30 100644
--- a/paddle/fluid/operators/elementwise/elementwise_div_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h
@@ -207,7 +207,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Out");

 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
index 66a9e6dd0fcf277904d8805ff66fee8c6971b1d7..3bc12fe16d979e47fd2535b56d7de55046e0d083 100644
--- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h
@@ -34,7 +34,7 @@ class ElementwiseMulOp : public ElementwiseOp {
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");

 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
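The elementwise ops (above and in elementwise_op.h below) resolve their dtype with `IndicateOrPromoteVarDataTypes(ctx, "X", "Y")`, so, as the name suggests, the value handed to `CanMKLDNNBeUsed` is the promoted common type of both inputs. A toy illustration of two-input promotion under an assumed simple widening order (Paddle's actual promotion rules are more involved):

#include <algorithm>

// Illustrative only: rank dtypes by width and promote to the wider one.
enum class DataType { kFP16 = 0, kFP32 = 1, kFP64 = 2 };

DataType PromoteTypes(DataType a, DataType b) {
  return static_cast<DataType>(
      std::max(static_cast<int>(a), static_cast<int>(b)));
}

// PromoteTypes(DataType::kFP16, DataType::kFP32) == DataType::kFP32, and that
// promoted type is the one that must have a registered MKLDNN kernel.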
diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h
index be10376f6111579377586a04a2cd8212cdcbd2e3..a09fe4b67604130e4cc3aa1385f87586a258d886 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op.h
@@ -110,7 +110,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateOrPromoteVarDataTypes(ctx, "X", "Y");

 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
@@ -280,8 +280,9 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
       return (ctx.Input<Tensor>("X")->dims() ==
              ctx.Input<Tensor>("Y")->dims());
     };

-    if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" ||
-                                       CanMKLDNNElementwiseAddGradBeUsed())) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
+        (ctx.Type() != "elementwise_add_grad" ||
+         CanMKLDNNElementwiseAddGradBeUsed())) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
@@ -331,7 +332,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");

 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
@@ -384,7 +385,7 @@ class ElementwiseOpDoubleGradWithoutDXDY
     }

 #ifdef PADDLE_WITH_MKLDNN
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc
index 71dccad0b581b0f7f043c989ca9e7854243590f7..e0ecd2cab535adcad76c333a1d2979addda97cb3 100644
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -133,15 +133,14 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   framework::LibraryType library = framework::LibraryType::kPlain;
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
-  if (this->CanMKLDNNBeUsed(ctx)) {
+  if (this->CanMKLDNNBeUsed(ctx, data_type)) {
     library = framework::LibraryType::kMKLDNN;
     layout = framework::DataLayout::kMKLDNN;
   }
 #endif
-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
-      library);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
 }

 void FusionGRUOpMaker::Make() {
diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc
index 9087a9e8d5c91330c799d8d75101a93eec3cadbb..ea8930fb6f73bc057dc8df580341ecea936a0bd3 100644
--- a/paddle/fluid/operators/gaussian_random_op.cc
+++ b/paddle/fluid/operators/gaussian_random_op.cc
@@ -112,18 +112,19 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
     framework::DataLayout layout{framework::DataLayout::kAnyLayout};
+    auto data_type =
+        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));

 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
 #endif

-    return framework::OpKernelType(
-        static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
-        ctx.device_context(), layout, library);
+    return framework::OpKernelType(data_type, ctx.device_context(), layout,
+                                   library);
   }

   framework::OpKernelType GetKernelTypeForVar(
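gaussian_random is the one op here with no input tensor to inspect: its dtype comes from an integer attribute cast to the proto enum, and hoisting that cast lets the same value reach both `CanMKLDNNBeUsed` and the kernel key. A toy version of that step (the enum and its values are placeholders; the real ones are generated from framework.proto):

// Placeholder for proto::VarType::Type; real values come from
// paddle/fluid/framework/framework.proto.
enum VarTypeType { kFP32Placeholder = 5, kFP64Placeholder = 6 };

// Attribute-driven ops start the kernel-type decision with a cast rather
// than a tensor inspection.
VarTypeType DtypeFromAttr(int dtype_attr) {
  return static_cast<VarTypeType>(dtype_attr);
}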
diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc
index 6c33b05cac955c22548846caef47b4859c880318..3293800e1c6206a7a810781d204ca5779a9ce400 100644
--- a/paddle/fluid/operators/gelu_op.cc
+++ b/paddle/fluid/operators/gelu_op.cc
@@ -46,17 +46,16 @@ class GeluOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     auto it = this->Attrs().find("use_mkldnn");
     if (library == framework::LibraryType::kPlain &&
-        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
+        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }
 };

@@ -86,17 +85,16 @@ class GeluGradOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext &ctx) const override {
     framework::LibraryType library{framework::LibraryType::kPlain};
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     auto it = this->Attrs().find("use_mkldnn");
     if (library == framework::LibraryType::kPlain &&
-        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) {
+        it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx, data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }
 };
diff --git a/paddle/fluid/operators/interpolate_op.cc b/paddle/fluid/operators/interpolate_op.cc
index f3699d0d7b6ed2db290c50bf3fb3f594e8c372e1..6c488c387f81500bf12b9a7cc8102944ffb301c4 100644
--- a/paddle/fluid/operators/interpolate_op.cc
+++ b/paddle/fluid/operators/interpolate_op.cc
@@ -322,20 +322,19 @@ class InterpolateOp : public framework::OperatorWithKernel {
       const framework::ExecutionContext& ctx) const override {
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
     framework::LibraryType library = framework::LibraryType::kPlain;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

 #ifdef PADDLE_WITH_MKLDNN
     auto interp_method = ctx.Attr<std::string>("interp_method");
     // TODO(danqing): support other interp_method
-    if (this->CanMKLDNNBeUsed(ctx) &&
+    if (this->CanMKLDNNBeUsed(ctx, data_type) &&
         (interp_method == "nearest" || interp_method == "bilinear")) {
       layout = framework::DataLayout::kMKLDNN;
       library = framework::LibraryType::kMKLDNN;
     }
 #endif

-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout, library);
   }

   framework::OpKernelType GetKernelTypeForVar(
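Gelu, like the activation ops earlier, first checks that a "use_mkldnn" attribute exists before calling `CanMKLDNNBeUsed`, since the framework-side check reads that attribute unconditionally. A toy version of the guard, with a plain map standing in for the attribute store (names here are assumptions):

#include <map>
#include <string>

using AttributeMap = std::map<std::string, bool>;  // simplified attr store

// Only consult "use_mkldnn" when the op was actually registered with it.
bool MKLDNNRequested(const AttributeMap& attrs) {
  auto it = attrs.find("use_mkldnn");
  return it != attrs.end() && it->second;
}

Interpolate shows the other composition style: `CanMKLDNNBeUsed` is necessary but not sufficient, and the op ANDs it with its own constraint (only "nearest" and "bilinear" methods).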
diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc
index 23de34bc6fa3e0a708eadf73d3c2dc1f70fa18ac..4980315d55eb4d3941ba55eb05262d3ffd83cd27 100644
--- a/paddle/fluid/operators/layer_norm_op.cc
+++ b/paddle/fluid/operators/layer_norm_op.cc
@@ -124,7 +124,7 @@ class LayerNormOp : public framework::OperatorWithKernel {

 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
     }
diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc
index 2d4123ccbd1ccd5e2f6c7ee52cff937f7e13b17b..d6fc143402464e40a9292b8c45cd937a348e5a66 100644
--- a/paddle/fluid/operators/lrn_op.cc
+++ b/paddle/fluid/operators/lrn_op.cc
@@ -199,16 +199,16 @@ class LRNOp : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }

   framework::OpKernelType GetKernelTypeForVar(
@@ -339,16 +339,16 @@ class LRNOpGrad : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
     framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }

   framework::OpKernelType GetKernelTypeForVar(
diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc
index 668445d2429e2977f26c569e01a50da66f136130..e97565a662318e0cd0f8fc53d10d6015d36ebd7f 100644
--- a/paddle/fluid/operators/matmul_op.cc
+++ b/paddle/fluid/operators/matmul_op.cc
@@ -661,7 +661,7 @@ class MatMulOp : public framework::OperatorWithKernel {

 #ifdef PADDLE_WITH_MKLDNN
     using mkldnn::memory;
-    if (this->CanMKLDNNBeUsed(ctx)) {
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       return framework::OpKernelType(input_data_type, ctx.GetPlace(),
                                      framework::DataLayout::kMKLDNN,
                                      framework::LibraryType::kMKLDNN);
diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc
index 9d6c52b98aad1e03647861b8d6bcb368a9c1f9d1..5d1682889535f848738b2244631fe34c89f29436 100644
--- a/paddle/fluid/operators/mul_op.cc
+++ b/paddle/fluid/operators/mul_op.cc
@@ -106,7 +106,7 @@ class MulOp : public framework::OperatorWithKernel {
     auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library = framework::LibraryType::kMKLDNN;
       layout = framework::DataLayout::kMKLDNN;
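Why thread the dtype through at all: the capability check and the returned kernel key must agree. Before this patch an op could pass `CanMKLDNNBeUsed` (which ignored the dtype) and then return a kernel key for a dtype that has no MKLDNN kernel registered. A toy decision function showing the corrected flow (types illustrative):

enum class Library { kPlain, kMKLDNN };

// supports_dtype is the answer from the dtype-aware SupportsMKLDNN lookup;
// under the old signature it was effectively "any MKLDNN kernel exists".
Library ChooseLibrary(bool use_mkldnn_attr, bool on_cpu, bool supports_dtype) {
  bool can_mkldnn = use_mkldnn_attr && on_cpu && supports_dtype;
  return can_mkldnn ? Library::kMKLDNN : Library::kPlain;
}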
diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc
index b78ced8eee263575dd6a7de772d80ec67ea5ec0b..55651dcecf6c290bb19def834611895d30237687 100644
--- a/paddle/fluid/operators/pool_op.cc
+++ b/paddle/fluid/operators/pool_op.cc
@@ -149,6 +149,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
   framework::LibraryType library_{framework::LibraryType::kPlain};
   std::string data_format = "AnyLayout";
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

 #ifdef PADDLE_WITH_CUDA
   if (platform::CanCUDNNBeUsed(ctx)) {
@@ -157,15 +158,13 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
 #endif
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx)) {
+      this->CanMKLDNNBeUsed(ctx, data_type)) {
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
   }
 #endif

-  return framework::OpKernelType(
-      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-      layout_, library_);
+  return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_);
 }

 framework::OpKernelType PoolOp::GetKernelTypeForVar(
@@ -205,6 +204,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
   framework::LibraryType library_{framework::LibraryType::kPlain};
   std::string data_format = "AnyLayout";
   framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

 #ifdef PADDLE_WITH_CUDA
   if (platform::CanCUDNNBeUsed(ctx)) {
@@ -213,14 +213,12 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
 #endif
 #ifdef PADDLE_WITH_MKLDNN
   if (library_ == framework::LibraryType::kPlain &&
-      this->CanMKLDNNBeUsed(ctx)) {
+      this->CanMKLDNNBeUsed(ctx, input_data_type)) {
     library_ = framework::LibraryType::kMKLDNN;
     layout_ = framework::DataLayout::kMKLDNN;
   }
 #endif

-  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
-
   return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
                                  library_);
 }
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index ff750ab47a963c2f1d24e0f74b616534acaa2c41..64030486eb4a5d4993339910b9f35c6d21ee222c 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -64,6 +64,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");

 #ifdef PADDLE_WITH_CUDA
     if (platform::CanCUDNNBeUsed(ctx)) {
@@ -72,13 +73,12 @@ class SoftmaxOp : public framework::OperatorWithKernel {
 #endif
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif

-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
     if (input_data_type == framework::proto::VarType::FP16) {
       PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                         platform::errors::InvalidArgument(
@@ -188,7 +188,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
-
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_CUDA
     if (platform::CanCUDNNBeUsed(ctx)) {
       library_ = framework::LibraryType::kCUDNN;
@@ -196,13 +197,11 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
 #endif
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, input_data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
-        ctx, framework::GradVarName("Out"));
     if (input_data_type == framework::proto::VarType::FP16) {
       PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                         platform::errors::InvalidArgument(
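SoftmaxOpGrad (and the transpose grad ops below) key the dtype off the output gradient via `framework::GradVarName("Out")`; hoisting the lookup also lets the FP16-requires-GPU check run on the same value. A sketch of the naming convention that lookup relies on, assuming the conventional "@GRAD" suffix Paddle appends to forward variable names:

#include <string>

// Assumed convention: gradient vars are named "<forward name>@GRAD".
std::string GradVarName(const std::string& name) { return name + "@GRAD"; }

// So the grad kernels above resolve their dtype from the "Out@GRAD" variable
// rather than from the forward input.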
diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc
index 57fa92b199581a0fdadd3286106caee739d3aea3..741f86f35848b2e626923e381bf007f351584789 100644
--- a/paddle/fluid/operators/sum_op.cc
+++ b/paddle/fluid/operators/sum_op.cc
@@ -145,29 +145,26 @@ class SumOp : public framework::OperatorWithKernel {
                         platform::errors::InvalidArgument(
                             "Sum operator should have at least one tensor"));

+      auto data_type = static_cast<framework::proto::VarType::Type>(dtype);
 #ifdef PADDLE_WITH_MKLDNN
       if (library == framework::LibraryType::kPlain &&
-          this->CanMKLDNNBeUsed(ctx) &&
-          (static_cast<framework::proto::VarType::Type>(dtype) ==
-               framework::proto::VarType::FP32 ||
-           static_cast<framework::proto::VarType::Type>(dtype) ==
-               framework::proto::VarType::BF16) &&
+          this->CanMKLDNNBeUsed(ctx, data_type) &&
+          (data_type == framework::proto::VarType::FP32 ||
+           data_type == framework::proto::VarType::BF16) &&
           ctx.OutputVar("Out")->IsType<framework::LoDTensor>()) {
         if (std::all_of(x_vars.begin(), x_vars.end(),
                         [](const framework::Variable* v) {
                           return v->IsType<framework::LoDTensor>();
                         })) {
-          return framework::OpKernelType(
-              static_cast<framework::proto::VarType::Type>(dtype),
-              ctx.GetPlace(), framework::DataLayout::kMKLDNN,
-              framework::LibraryType::kMKLDNN);
+          return framework::OpKernelType(data_type, ctx.GetPlace(),
+                                         framework::DataLayout::kMKLDNN,
+                                         framework::LibraryType::kMKLDNN);
         }
       }
 #endif
-      return framework::OpKernelType(
-          static_cast<framework::proto::VarType::Type>(dtype), ctx.GetPlace(),
-          layout, library);
+      return framework::OpKernelType(data_type, ctx.GetPlace(), layout,
+                                     library);
     } else if (x_vars[0]->IsType<framework::SelectedRows>()) {
       for (auto& var : x_vars) {
         auto& value = var->Get<framework::SelectedRows>().value();
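SumOp's gate is the most composite one in this patch: the resolved dtype must be FP32 or BF16, the output must be a LoDTensor, and every input variable must hold a LoDTensor (SelectedRows inputs fall through to the generic path). A toy restatement of that predicate (types illustrative):

#include <algorithm>
#include <vector>

enum class DataType { kFP32, kBF16, kINT64 };  // illustrative subset

bool CanUseMKLDNNSum(DataType dt, bool output_is_dense,
                     const std::vector<bool>& input_is_dense) {
  return (dt == DataType::kFP32 || dt == DataType::kBF16) &&
         output_is_dense &&
         std::all_of(input_is_dense.begin(), input_is_dense.end(),
                     [](bool dense) { return dense; });
}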
diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc
index d9940ddca3e3baa519c43cdb144e3504f2b40b75..465970451f5d105e6a33555ed241c4528e35d50a 100644
--- a/paddle/fluid/operators/transpose_op.cc
+++ b/paddle/fluid/operators/transpose_op.cc
@@ -86,16 +86,16 @@ class TransposeOp : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }
 };

@@ -184,16 +184,17 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    auto data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
   }
 };

@@ -231,9 +232,11 @@ class Transpose2Op : public TransposeOp {
     int customized_type_value =
         framework::OpKernelType::kDefaultCustomizedTypeValue;
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    framework::proto::VarType::Type data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "X");
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
       using framework::proto::VarType;
@@ -244,9 +247,8 @@ class Transpose2Op : public TransposeOp {
           : kTransposeMKLDNNFP32;
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout_, library_, customized_type_value);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_, library_,
+                                   customized_type_value);
   }
 };

@@ -310,16 +312,18 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
     framework::LibraryType library_{framework::LibraryType::kPlain};
     std::string data_format = ctx.Attr<std::string>("data_format");
     framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
+    framework::proto::VarType::Type data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx,
+                                                framework::GradVarName("Out"));
 #ifdef PADDLE_WITH_MKLDNN
     if (library_ == framework::LibraryType::kPlain &&
-        this->CanMKLDNNBeUsed(ctx)) {
+        this->CanMKLDNNBeUsed(ctx, data_type)) {
       library_ = framework::LibraryType::kMKLDNN;
       layout_ = framework::DataLayout::kMKLDNN;
     }
 #endif
-    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
-                                       ctx, framework::GradVarName("Out")),
-                                   ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(data_type, ctx.GetPlace(), layout_,
+                                   library_);
  }
};
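Putting it together, a runnable toy of the whole selection rule (the same illustrative types as the first sketch, repeated so this compiles on its own): an op that registers an MKLDNN kernel only for FP32 now upgrades to MKLDNN for FP32 queries and falls back to the plain kernel for BF16.

#include <algorithm>
#include <iostream>
#include <map>
#include <tuple>

enum class Place { kCPU };
enum class Library { kPlain, kMKLDNN };
enum class DataType { kFP32, kBF16 };

struct KernelKey {
  Place place;
  Library library;
  DataType data_type;
  bool operator<(const KernelKey& o) const {
    return std::tie(place, library, data_type) <
           std::tie(o.place, o.library, o.data_type);
  }
};
using KernelMap = std::map<KernelKey, int>;  // int stands in for the kernel fn

bool SupportsMKLDNN(const KernelMap& kernels, DataType dt) {
  return std::any_of(kernels.begin(), kernels.end(),
                     [dt](KernelMap::const_reference kv) {
                       return kv.first.place == Place::kCPU &&
                              kv.first.library == Library::kMKLDNN &&
                              kv.first.data_type == dt;
                     });
}

int main() {
  KernelMap kernels;
  kernels[{Place::kCPU, Library::kPlain, DataType::kFP32}] = 0;
  kernels[{Place::kCPU, Library::kPlain, DataType::kBF16}] = 0;
  kernels[{Place::kCPU, Library::kMKLDNN, DataType::kFP32}] = 0;  // FP32 only
  std::cout << SupportsMKLDNN(kernels, DataType::kFP32) << '\n';  // 1: MKLDNN
  std::cout << SupportsMKLDNN(kernels, DataType::kBF16) << '\n';  // 0: fallback
  return 0;
}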