From 1f680e4c3a93f34fa9cdeb732087459cac005210 Mon Sep 17 00:00:00 2001
From: eclipsess
Date: Mon, 11 Jun 2018 20:51:17 +0800
Subject: [PATCH] restore lrn fix error

---
 src/operators/kernel/arm/lrn_kernel.cpp | 29 ++++++++++++-----------------
 1 file changed, 12 insertions(+), 17 deletions(-)

diff --git a/src/operators/kernel/arm/lrn_kernel.cpp b/src/operators/kernel/arm/lrn_kernel.cpp
index b7dc0a2646..5ac4c67559 100644
--- a/src/operators/kernel/arm/lrn_kernel.cpp
+++ b/src/operators/kernel/arm/lrn_kernel.cpp
@@ -23,27 +23,22 @@ namespace operators {
 
 template <>
 void LrnKernel::Compute(const LrnParam &param) const {
-  const float alpha = param.Alpha();
   const Tensor *input_x = param.InputX();
   auto x_dims = input_x->dims();
   Tensor *out = param.Out();
   out->mutable_data();
-  if (alpha < 0.001) {
-    // coarse precision
-    out->ShareDataWith(*input_x);
-  } else {
-    /// data_format = NCHW
-    const int N = x_dims[0];
-    const int C = x_dims[1];
-    const int H = x_dims[2];
-    const int W = x_dims[3];
-
-    const int n = param.N();
-    const float beta = param.Beta();
-    const float k = param.K();
-    LRNFunctor lrnFunctor;
-    lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
-  }
+  /// data_format = NCHW
+  const int N = x_dims[0];
+  const int C = x_dims[1];
+  const int H = x_dims[2];
+  const int W = x_dims[3];
+
+  const int n = param.N();
+  const float alpha = param.Alpha();
+  const float beta = param.Beta();
+  const float k = param.K();
+  LRNFunctor lrnFunctor;
+  lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta);
 }
 
 template class LrnKernel;
--
GitLab
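
Note: the patch makes the kernel always delegate to LRNFunctor instead of short-circuiting when alpha is small, so a rough, self-contained sketch of the cross-channel LRN being computed may help readers of the diff. The function name LrnNCHW and the raw-pointer signature below are illustrative only, not the project's LRNFunctor API, and the sketch assumes alpha is applied as-is (not pre-divided by the window size n). With NCHW data, each output value is the input divided by (k + alpha * sum of squared inputs over a window of n channels) raised to the power beta; the removed `alpha < 0.001` shortcut copied the input through unchanged, which only approximates that result when k is close to 1.

#include <algorithm>
#include <cmath>

// Illustrative cross-channel LRN over NCHW data (not the project's LRNFunctor):
//   out[i][c][h][w] = in[i][c][h][w] /
//       pow(k + alpha * sum_{c' in window(c, n)} in[i][c'][h][w]^2, beta)
void LrnNCHW(const float *in, float *out, int N, int C, int H, int W,
             int n, float k, float alpha, float beta) {
  const int spatial = H * W;
  for (int i = 0; i < N; ++i) {
    const float *in_img = in + i * C * spatial;
    float *out_img = out + i * C * spatial;
    for (int c = 0; c < C; ++c) {
      // Window of n channels centred on c, clipped to [0, C).
      const int start = std::max(0, c - n / 2);
      const int end = std::min(C, c + n / 2 + 1);
      for (int s = 0; s < spatial; ++s) {
        // Sum of squares across the channel window at this spatial position.
        float sum_sq = 0.f;
        for (int cc = start; cc < end; ++cc) {
          const float v = in_img[cc * spatial + s];
          sum_sq += v * v;
        }
        out_img[c * spatial + s] =
            in_img[c * spatial + s] / std::pow(k + alpha * sum_sq, beta);
      }
    }
  }
}

A production kernel would typically maintain a sliding sum of squares per spatial position so the window sum is updated in constant time as c advances rather than recomputed; that kind of optimization is presumably what the LRNFunctor called in this patch provides.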