diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 21fc293e84179da72be8cc5ee50de46a00fe9a0d..026c1092eb341b6aef0d6ea2260fd79887b22b7a 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1007,6 +1007,23 @@ static void CheckTensorNANOrInf(const std::string& op_type, op_type, name)); } +bool OperatorWithKernel::SupportsMKLDNN() const { + auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_); + return std::any_of(op_kernels.begin(), op_kernels.end(), + [](OpKernelMap::const_reference kern_pair) { + return platform::is_cpu_place(kern_pair.first.place_) && + kern_pair.first.library_type_ == + LibraryType::kMKLDNN; + }); +} + +bool OperatorWithKernel::CanMKLDNNBeUsed( + const framework::ExecutionContext& ctx) const { + bool use_mkldnn_ctx = + ctx.Attr("use_mkldnn") && platform::is_cpu_place(ctx.GetPlace()); + return use_mkldnn_ctx && this->SupportsMKLDNN(); +} + void OperatorWithKernel::RuntimeInferShape(const Scope& scope, const platform::Place& place, const RuntimeContext& ctx) const { diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index d493f350e69736fddc2cbda56a8e3967235bce8a..d5107ef5ca22b794f6113733b24e8efb3cd1701c 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -156,6 +156,8 @@ class OperatorBase { virtual bool SupportGPU() const { return false; } + virtual bool SupportsMKLDNN() const { return false; } + const std::string& Type() const { return type_; } bool HasAttr(const std::string& name) const { return attrs_.count(name); } @@ -490,6 +492,9 @@ class OperatorWithKernel : public OperatorBase { return platform::is_gpu_place(kern_pair.first.place_); }); } + bool SupportsMKLDNN() const override; + + bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) const; virtual void InferShape(InferShapeContext* ctx) const = 0; diff --git a/paddle/fluid/operators/activation_op.cc b/paddle/fluid/operators/activation_op.cc old mode 100755 new mode 100644 index 40951d5960352293236871c08370fc8243569e93..26b4ed71e00219fcb5f5942a69d11e983f245e89 --- a/paddle/fluid/operators/activation_op.cc +++ b/paddle/fluid/operators/activation_op.cc @@ -106,7 +106,7 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, #ifdef PADDLE_WITH_MKLDNN auto it = oper.Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() && - platform::CanMKLDNNBeUsed(ctx)) { + oper.CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/addmm_op.cc b/paddle/fluid/operators/addmm_op.cc index f6e6856c61588f08f58a4e92e14e2a78f63745e5..f5b35cbd21889b71110d2aa8fcc48eaa9eb73fa1 100644 --- a/paddle/fluid/operators/addmm_op.cc +++ b/paddle/fluid/operators/addmm_op.cc @@ -119,7 +119,7 @@ class AddMMOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index 370ba8619f18875f9c8ade531d7d97427a7505d6..f74aa259e893a83ebb8b776b34a2899a911089fa 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -157,8 +157,7 @@ framework::OpKernelType BatchNormOp::GetExpectedKernelType( framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -527,8 +526,7 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType( framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/concat_op.cc b/paddle/fluid/operators/concat_op.cc index 7937e432d22faa3ffd93e46a39b7b1cc5500dbf8..0b3697156d36b48e7d7cbf6c175daa160754e46c 100644 --- a/paddle/fluid/operators/concat_op.cc +++ b/paddle/fluid/operators/concat_op.cc @@ -83,7 +83,7 @@ class ConcatOp : public framework::OperatorWithKernel { "All Inputs of Concat OP are Empty!")); } #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc index 76ff1084fa61b4cc7fec3a59f39b956ec6582998..72355c7d3a45873cd58cbfbd41e8ab2732030be8 100644 --- a/paddle/fluid/operators/conv_op.cc +++ b/paddle/fluid/operators/conv_op.cc @@ -155,8 +155,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( } #endif #ifdef PADDLE_WITH_MKLDNN - if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + if (library == framework::LibraryType::kPlain && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; customized_type_value = @@ -565,7 +564,7 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { const std::string data_format = ctx.Attr("data_format"); library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 7e0e77214c5320aa9a807fc65531f163fa7ce09e..6c4844855591911c025230822768d091826cb794 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -193,7 +193,7 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/data_norm_op.cc b/paddle/fluid/operators/data_norm_op.cc index 45e77a99e6b3eb94f8ee15c8763dc32ff389d9ea..7dc1e23207d565a7a7636887e1cba78944acc01a 100644 --- a/paddle/fluid/operators/data_norm_op.cc +++ b/paddle/fluid/operators/data_norm_op.cc @@ -184,7 +184,7 @@ class DataNormOp : public framework::OperatorWithKernel { framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -486,7 +486,7 @@ class DataNormGradOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/detection/prior_box_op.cc b/paddle/fluid/operators/detection/prior_box_op.cc index 0d293bb964b615c68891be516534a05cc2277426..ef6332b6414aa7585bc67cf773881eff7bde6738 100644 --- a/paddle/fluid/operators/detection/prior_box_op.cc +++ b/paddle/fluid/operators/detection/prior_box_op.cc @@ -98,7 +98,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; auto input_image_type = ctx.Input("Image")->type(); diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.h b/paddle/fluid/operators/elementwise/elementwise_div_op.h index 5ac3ffe225dba575dd31a6c8ad4d228f98698f29..1d016fba34b46aca783e8c9b364e789f23ad77aa 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.h @@ -163,7 +163,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX"); #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index e4d3ea6d7291eff8911d8419cda96f2d2738b9a1..49456149c2ca81f9c42ff9a1e487ef16655359bc 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -33,7 +33,7 @@ class ElementwiseMulOp : public ElementwiseOp { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index ece6af1b5a6f562bd7ff81290f98e8636feabb0c..bbb240efaea5dcfeba75af3e89bc1413720050ac 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -108,7 +108,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -265,9 +265,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { return (ctx.Input("X")->dims() == ctx.Input("Y")->dims()); }; - if (platform::CanMKLDNNBeUsed(ctx) && - (ctx.Type() != "elementwise_add_grad" || - CanMKLDNNElementwiseAddGradBeUsed())) { + if (this->CanMKLDNNBeUsed(ctx) && (ctx.Type() != "elementwise_add_grad" || + CanMKLDNNElementwiseAddGradBeUsed())) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -304,7 +303,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut"); #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -343,7 +342,7 @@ class ElementwiseOpDoubleGradWithoutDXDY } #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index e3776a80b316089891282136022a4e6656360c6e..f5904039d4b6ef9794991687c535a0989864e9f6 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -133,7 +133,7 @@ framework::OpKernelType FusionGRUOp::GetExpectedKernelType( framework::LibraryType library = framework::LibraryType::kPlain; framework::DataLayout layout = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/gaussian_random_op.cc b/paddle/fluid/operators/gaussian_random_op.cc index fd2f48265ca6f4613d273207de79f96c1d2bcbea..840975f754f5afca3ad76251ac65cef35714a1b8 100644 --- a/paddle/fluid/operators/gaussian_random_op.cc +++ b/paddle/fluid/operators/gaussian_random_op.cc @@ -115,7 +115,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/gelu_op.cc b/paddle/fluid/operators/gelu_op.cc index 9ca0d30362c5aafbfd3b21cf5ac27853b2eb77cf..6c33b05cac955c22548846caef47b4859c880318 100644 --- a/paddle/fluid/operators/gelu_op.cc +++ b/paddle/fluid/operators/gelu_op.cc @@ -49,7 +49,7 @@ class GeluOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN auto it = this->Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && - it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) { + it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } @@ -89,7 +89,7 @@ class GeluGradOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN auto it = this->Attrs().find("use_mkldnn"); if (library == framework::LibraryType::kPlain && - it != this->Attrs().end() && platform::CanMKLDNNBeUsed(ctx)) { + it != this->Attrs().end() && this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/layer_norm_op.cc b/paddle/fluid/operators/layer_norm_op.cc index 79e3d3b90a93ae900961c8b36e653eae903b85a1..6f83a667a5941f881adc4920083fa2f9a2ca48aa 100644 --- a/paddle/fluid/operators/layer_norm_op.cc +++ b/paddle/fluid/operators/layer_norm_op.cc @@ -104,7 +104,7 @@ class LayerNormOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index fc9d61eb75b547ba90318180036b2c8126ebf4a3..2d4123ccbd1ccd5e2f6c7ee52cff937f7e13b17b 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -201,7 +201,7 @@ class LRNOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -341,7 +341,7 @@ class LRNOpGrad : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 129298edafcf9a42d5f2058786e946faffa6618b..639a6991a4ff0bb6bc1e3838f9ae818e57fb2344 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -659,7 +659,7 @@ class MatMulOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN using mkldnn::memory; - if (platform::CanMKLDNNBeUsed(ctx)) { + if (this->CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc index b3afba1e4f9791b1b9027ca038b495380f403773..9d6c52b98aad1e03647861b8d6bcb368a9c1f9d1 100644 --- a/paddle/fluid/operators/mul_op.cc +++ b/paddle/fluid/operators/mul_op.cc @@ -106,7 +106,7 @@ class MulOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library = framework::LibraryType::kMKLDNN; layout = framework::DataLayout::kMKLDNN; diff --git a/paddle/fluid/operators/pool_op.cc b/paddle/fluid/operators/pool_op.cc index 5b0980a98513bfa0bb619ac182fbfca4961dada2..b78ced8eee263575dd6a7de772d80ec67ea5ec0b 100644 --- a/paddle/fluid/operators/pool_op.cc +++ b/paddle/fluid/operators/pool_op.cc @@ -157,7 +157,7 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -213,7 +213,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index ff25f1911072c3a73e3ed365630b1ce18c78bd30..ff750ab47a963c2f1d24e0f74b616534acaa2c41 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -72,7 +72,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -196,7 +196,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { #endif #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/operators/sum_op.cc b/paddle/fluid/operators/sum_op.cc index faade79091c4afcc0d0bf9625619fca1815b6db9..57fa92b199581a0fdadd3286106caee739d3aea3 100644 --- a/paddle/fluid/operators/sum_op.cc +++ b/paddle/fluid/operators/sum_op.cc @@ -147,7 +147,7 @@ class SumOp : public framework::OperatorWithKernel { #ifdef PADDLE_WITH_MKLDNN if (library == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx) && + this->CanMKLDNNBeUsed(ctx) && (static_cast(dtype) == framework::proto::VarType::FP32 || static_cast(dtype) == diff --git a/paddle/fluid/operators/transpose_op.cc b/paddle/fluid/operators/transpose_op.cc index 0e870937ec1a51d577d92aca7da7c6853a68f786..a098327ab29af57d619dabe1c814eca6b97ee2cc 100644 --- a/paddle/fluid/operators/transpose_op.cc +++ b/paddle/fluid/operators/transpose_op.cc @@ -88,7 +88,7 @@ class TransposeOp : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -186,7 +186,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } @@ -233,7 +233,7 @@ class Transpose2Op : public TransposeOp { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; using framework::proto::VarType; @@ -298,7 +298,7 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { framework::DataLayout layout_ = framework::StringToDataLayout(data_format); #ifdef PADDLE_WITH_MKLDNN if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { + this->CanMKLDNNBeUsed(ctx)) { library_ = framework::LibraryType::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN; } diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 34f5759e4cd01607a63946174d2726ed00b8693c..797ff42f3c201458fd02caa445a9f5336a3cdb19 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -134,11 +134,6 @@ inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector& dims, return mkldnn::memory::desc({dims}, data_type, format); } -inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { - bool use_mkldnn = ctx.Attr("use_mkldnn"); - return use_mkldnn && platform::is_cpu_place(ctx.GetPlace()); -} - inline void ClearMKLDNNCache(const platform::Place& place) { // Clear mkl-dnn cache, if (platform::is_cpu_place(place)) {