From 2d95527527fe3b27e06f254965c8eb4fbacb4abf Mon Sep 17 00:00:00 2001 From: Tomasz Patejko Date: Mon, 19 Mar 2018 06:10:27 -0400 Subject: [PATCH] Removing WITHIN_CHANNEL algorithm for lrn. CPU lrn operator works only with ACROSS_CHANNELS --- paddle/fluid/operators/lrn_mkldnn_op.cc | 27 ++++++-------------- paddle/fluid/operators/lrn_op.cc | 33 +++++++++++-------------- 2 files changed, 22 insertions(+), 38 deletions(-) diff --git a/paddle/fluid/operators/lrn_mkldnn_op.cc b/paddle/fluid/operators/lrn_mkldnn_op.cc index 334597ab05..a2971fcd14 100644 --- a/paddle/fluid/operators/lrn_mkldnn_op.cc +++ b/paddle/fluid/operators/lrn_mkldnn_op.cc @@ -22,18 +22,6 @@ namespace operators { using paddle::framework::Tensor; using paddle::platform::MKLDNNDeviceContext; -namespace { -mkldnn::algorithm LRNAlgorithm(const paddle::framework::ExecutionContext& ctx) { - mkldnn::algorithm algorithm = mkldnn::lrn_across_channels; - - std::string algorithm_str = ctx.Attr("algorithm"); - if (algorithm_str == "WITHIN_CHANNEL") { - algorithm = mkldnn::lrn_within_channel; - } - return algorithm; -} -} // namespace - template class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { public: @@ -64,8 +52,6 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { const float beta = ctx.Attr("beta"); const float k = ctx.Attr("k"); - auto algorithm = LRNAlgorithm(ctx); - auto e_mid = framework::EigenTensor::From(*mid); e_mid = e_mid.constant(k); @@ -77,8 +63,13 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel { auto dst_md = paddle::platform::MKLDNNMemDesc( dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw); - auto forward_desc = mkldnn::lrn_forward::desc{ - mkldnn::prop_kind::forward, algorithm, src_md, n, alpha, beta, k}; + auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward, + mkldnn::lrn_across_channels, + src_md, + n, + alpha, + beta, + k}; auto forward_pd = std::make_shared( forward_desc, mkldnn_engine); @@ -154,10 +145,8 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel { auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine}, static_cast(x_grad_data)}; - auto algorithm = LRNAlgorithm(ctx); - auto backward_desc = mkldnn::lrn_backward::desc{ - algorithm, src_md, diff_src_md, n, alpha, beta, k}; + mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k}; auto forward_pd = dev_ctx.GetBlob(key_pd); diff --git a/paddle/fluid/operators/lrn_op.cc b/paddle/fluid/operators/lrn_op.cc index 00db09ece3..bd72f0435e 100644 --- a/paddle/fluid/operators/lrn_op.cc +++ b/paddle/fluid/operators/lrn_op.cc @@ -120,24 +120,24 @@ template struct LRNGradFunctor; template struct LRNGradFunctor; namespace { - framework::OpKernelType GetExpectedLRNKernel( - const framework::ExecutionContext& ctx) { - framework::LibraryType library_{framework::LibraryType::kPlain}; +framework::OpKernelType GetExpectedLRNKernel( + const framework::ExecutionContext& ctx) { + framework::LibraryType library_{framework::LibraryType::kPlain}; #ifdef PADDLE_WITH_MKLDNN - if (library_ == framework::LibraryType::kPlain && - platform::CanMKLDNNBeUsed(ctx)) { - library_ = framework::LibraryType::kMKLDNN; - } + if (library_ == framework::LibraryType::kPlain && + platform::CanMKLDNNBeUsed(ctx)) { + library_ = framework::LibraryType::kMKLDNN; + } #endif - std::string data_format = ctx.Attr("data_format"); - // TODO(pzelazko-intel): enable MKLDNN layout when it's ready - framework::DataLayout layout_ = framework::StringToDataLayout(data_format); - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), - layout_, library_); - } + std::string data_format = ctx.Attr("data_format"); + // TODO(pzelazko-intel): enable MKLDNN layout when it's ready + framework::DataLayout layout_ = framework::StringToDataLayout(data_format); + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), ctx.GetPlace(), + layout_, library_); } +} // namespace class LRNOp : public framework::OperatorWithKernel { public: @@ -214,11 +214,6 @@ class LRNOpMaker : public framework::OpProtoAndCheckerMaker { "Defaults to \"NHWC\". Specify the data format of the output data, " "the input will be transformed automatically. ") .SetDefault("AnyLayout"); - AddAttr("algorithm", - "(string default ACROSS_CHANNELS" - "An optional string: \"ACROSS_CHANNELS\", " - "\"WITHIN_CHANNEL\". Used by MKLDNN library") - .SetDefault("ACROSS_CHANNELS"); AddComment(R"DOC( Local Response Normalization Operator. -- GitLab