From c51c446221ce63890a0c099da7f26b9bfa41cb48 Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Fri, 16 Mar 2018 10:05:54 -0400 Subject: [PATCH] Content of GetExpectedKernelType moved to standalone function --- paddle/fluid/operators/lrn_op.cc | 54 ++++++++++++++------------------ 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 6bd451a118..00db09ece3 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -119,6 +119,26 @@ struct LRNGradFunctor { template struct LRNGradFunctor; template struct LRNGradFunctor; +namespace { + framework::OpKernelType GetExpectedLRNKernel( + const framework::ExecutionContext& ctx) { + framework::LibraryType library_{framework::LibraryType::kPlain}; +#ifdef PADDLE_WITH_MKLDNN + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + } +#endif + + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + layout_, library_); + } +} + class LRNOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -140,21 +160,8 @@ class LRNOp : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - framework::LibraryType library_{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - } -#endif - - std::string data_format = ctx.Attr("data_format"); - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); + const framework::ExecutionContext& ctx) const override { + return GetExpectedLRNKernel(ctx); } }; @@ -261,21 +268,8 @@ class LRNOpGrad : public framework::OperatorWithKernel { } framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const { - framework::LibraryType library_{framework::LibraryType::kPlain}; -#ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - } -#endif - - std::string data_format = ctx.Attr("data_format"); - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); + const framework::ExecutionContext& ctx) const override { + return GetExpectedLRNKernel(ctx); } }; } // namespace operators -- GitLab