diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 6bd451a118afeffd4f5dd80e2450c491ef6d55b0..00db09ece32156be414f98bc46eb8d47b5c21689 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -119,6 +119,26 @@ struct LRNGradFunctor { template struct LRNGradFunctor; template struct LRNGradFunctor; +namespace { + framework::OpKernelType GetExpectedLRNKernel( + const framework::ExecutionContext& ctx) { + framework::LibraryType library_{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + } +#endif + + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + layout_, library_); + } +} + class LRNOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -140,21 +160,8 @@ class LRNOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - framework::LibraryType library_{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - } -#endif - - std::string data_format = ctx.Attr("data_format"); - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); + const framework::ExecutionContext& ctx) const override { + return GetExpectedLRNKernel(ctx); } }; @@ -261,21 +268,8 @@ class LRNOpGrad : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - framework::LibraryType library_{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - } -#endif - - std::string data_format = ctx.Attr("data_format"); - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); + const framework::ExecutionContext& ctx) const override { + return GetExpectedLRNKernel(ctx); } }; } // namespace operators