diff --git a/lite/kernels/arm/lrn_compute.cc b/lite/kernels/arm/lrn_compute.cc
index 18e6654282c8810a8310e540c2851fecb116f2d8..0476b1e6bde99e7993d1b0feb53ab10ba1b8f9b5 100644
--- a/lite/kernels/arm/lrn_compute.cc
+++ b/lite/kernels/arm/lrn_compute.cc
@@ -31,16 +31,16 @@ void LrnCompute::Run() {
   int channel = x_dims[1];
   int h = x_dims[2];
   int w = x_dims[3];
-  const int local_size = param.local_size;
+  const int n = param.n;
   const float alpha = param.alpha;
   const float beta = param.beta;
   const float k = param.k;
   if (param.norm_region == "AcrossChannels") {
     lite::arm::math::compute_across_channels(
-        x_data, out_data, num, channel, h, w, local_size, alpha, beta, k);
+        x_data, out_data, num, channel, h, w, n, alpha, beta, k);
   } else {
     lite::arm::math::compute_within_channels(
-        x_data, out_data, num, channel, h, w, local_size, alpha, beta, k);
+        x_data, out_data, num, channel, h, w, n, alpha, beta, k);
   }
 }
 
@@ -53,4 +53,5 @@ REGISTER_LITE_KERNEL(
     lrn, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::LrnCompute, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("MidOut", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
diff --git a/lite/kernels/arm/lrn_compute_test.cc b/lite/kernels/arm/lrn_compute_test.cc
index 8e030006151c5834a68037800192ec7d9bc5d94d..e7030d00427e55c7faf333997cd90cba46260cd4 100644
--- a/lite/kernels/arm/lrn_compute_test.cc
+++ b/lite/kernels/arm/lrn_compute_test.cc
@@ -91,7 +91,7 @@ void lrn_compute_ref(const operators::LrnParam& param) {
   const dtype* x_data = param.X->data<dtype>();
   dtype* out_data = param.Out->mutable_data<dtype>();
   auto x_dims = param.X->dims();
-  int local_size = param.local_size;
+  int local_size = param.n;
   float alpha = param.alpha;
   float beta = param.beta;
   float k = param.k;
@@ -171,7 +171,7 @@ TEST(lrn_arm, compute) {
           }
           param.X = &x;
           param.Out = &output;
-          param.local_size = local_size;
+          param.n = local_size;
           param.alpha = alpha;
           param.beta = beta;
           param.k = k;
diff --git a/lite/operators/lrn_op.cc b/lite/operators/lrn_op.cc
index 34b00653f91d03f8e661fac56b5931d928be15b2..aff3e5af5566771411acf20736fdbec703f5def9 100644
--- a/lite/operators/lrn_op.cc
+++ b/lite/operators/lrn_op.cc
@@ -37,11 +37,13 @@ bool LrnOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
   auto Out_name = opdesc.Output("Out").front();
   param_.X = GetVar<lite::Tensor>(scope, X_name);
   param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
-  param_.local_size = opdesc.GetAttr<int>("local_size");
+  param_.n = opdesc.GetAttr<int>("n");
   param_.alpha = opdesc.GetAttr<float>("alpha");
   param_.beta = opdesc.GetAttr<float>("beta");
   param_.k = opdesc.GetAttr<float>("k");
-  param_.norm_region = opdesc.GetAttr<std::string>("norm_region");
+  if (opdesc.HasAttr("norm_region")) {
+    param_.norm_region = opdesc.GetAttr<std::string>("norm_region");
+  }
   return true;
 }
 
diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h
index e29bc5921697e9af7c9c495d471231c4e9aee0c6..035b3e18e81bb979dded329361799de2d99aaedb 100644
--- a/lite/operators/op_params.h
+++ b/lite/operators/op_params.h
@@ -522,8 +522,8 @@ struct GRUUnitParam {
 struct LrnParam {
   const lite::Tensor* X{};
   lite::Tensor* Out{};
-  int local_size{5};
-  float alpha{1.};
+  int n{5};
+  float alpha{1e-4};
   float beta{0.75};
   float k{1.};
   std::string norm_region{"AcrossChannels"};
diff --git a/lite/tests/kernels/lrn_compute_test.cc b/lite/tests/kernels/lrn_compute_test.cc
index 9ee43c5c60b4703f64e7a2575ec15ba59b618052..e306155514e7423dfcfccb3d7103050b50f9fdbe 100644
--- a/lite/tests/kernels/lrn_compute_test.cc
+++ b/lite/tests/kernels/lrn_compute_test.cc
@@ -158,7 +158,7 @@ class LrnComputeTester : public arena::TestCase {
     op_desc->SetOutput("Out", {output_});
     op_desc->SetAttr("alpha", alpha_);
     op_desc->SetAttr("beta", beta_);
-    op_desc->SetAttr("local_size", local_size_);
+    op_desc->SetAttr("n", local_size_);
    op_desc->SetAttr("k", k_);
     op_desc->SetAttr("norm_region", norm_region_);
   }