diff --git a/src/operators/kernel/arm/lrn_kernel.cpp b/src/operators/kernel/arm/lrn_kernel.cpp index b7dc0a2646c9594d7ef3e74d2d667f9d616cb332..5ac4c67559ebe1603230e0d50895d0702c38cb77 100644 --- a/src/operators/kernel/arm/lrn_kernel.cpp +++ b/src/operators/kernel/arm/lrn_kernel.cpp @@ -23,27 +23,22 @@ namespace operators { template <> void LrnKernel::Compute(const LrnParam &param) const { - const float alpha = param.Alpha(); const Tensor *input_x = param.InputX(); auto x_dims = input_x->dims(); Tensor *out = param.Out(); out->mutable_data(); - if (alpha < 0.001) { - // coarse precision - out->ShareDataWith(*input_x); - } else { - /// data_format = NCHW - const int N = x_dims[0]; - const int C = x_dims[1]; - const int H = x_dims[2]; - const int W = x_dims[3]; - - const int n = param.N(); - const float beta = param.Beta(); - const float k = param.K(); - LRNFunctor lrnFunctor; - lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta); - } + /// data_format = NCHW + const int N = x_dims[0]; + const int C = x_dims[1]; + const int H = x_dims[2]; + const int W = x_dims[3]; + + const int n = param.N(); + const float alpha = param.Alpha(); + const float beta = param.Beta(); + const float k = param.K(); + LRNFunctor lrnFunctor; + lrnFunctor(*input_x, out, N, C, H, W, n, k, alpha, beta); } template class LrnKernel;