From 94255f6cf0fb884eb8e1605f233bb543259332d3 Mon Sep 17 00:00:00 2001 From: yiicy Date: Tue, 19 Nov 2019 20:23:33 +0800 Subject: [PATCH] fix lrn param, align to fluid, test=develop (#2452) --- lite/kernels/arm/lrn_compute.cc | 7 ++++--- lite/kernels/arm/lrn_compute_test.cc | 4 ++-- lite/operators/lrn_op.cc | 6 ++++-- lite/operators/op_params.h | 4 ++-- lite/tests/kernels/lrn_compute_test.cc | 2 +- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/lite/kernels/arm/lrn_compute.cc b/lite/kernels/arm/lrn_compute.cc index 18e6654282..0476b1e6bd 100644 --- a/lite/kernels/arm/lrn_compute.cc +++ b/lite/kernels/arm/lrn_compute.cc @@ -31,16 +31,16 @@ void LrnCompute::Run() { int channel = x_dims[1]; int h = x_dims[2]; int w = x_dims[3]; - const int local_size = param.local_size; + const int n = param.n; const float alpha = param.alpha; const float beta = param.beta; const float k = param.k; if (param.norm_region == "AcrossChannels") { lite::arm::math::compute_across_channels( - x_data, out_data, num, channel, h, w, local_size, alpha, beta, k); + x_data, out_data, num, channel, h, w, n, alpha, beta, k); } else { lite::arm::math::compute_within_channels( - x_data, out_data, num, channel, h, w, local_size, alpha, beta, k); + x_data, out_data, num, channel, h, w, n, alpha, beta, k); } } @@ -53,4 +53,5 @@ REGISTER_LITE_KERNEL( lrn, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::LrnCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) + .BindOutput("MidOut", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/lite/kernels/arm/lrn_compute_test.cc b/lite/kernels/arm/lrn_compute_test.cc index 8e03000615..e7030d0042 100644 --- a/lite/kernels/arm/lrn_compute_test.cc +++ b/lite/kernels/arm/lrn_compute_test.cc @@ -91,7 +91,7 @@ void lrn_compute_ref(const operators::LrnParam& param) { const dtype* x_data = param.X->data(); dtype* out_data = param.Out->mutable_data(); auto x_dims = param.X->dims(); - int local_size = param.local_size; + int local_size = param.n; float alpha = param.alpha; float beta = param.beta; float k = param.k; @@ -171,7 +171,7 @@ TEST(lrn_arm, compute) { } param.X = &x; param.Out = &output; - param.local_size = local_size; + param.n = local_size; param.alpha = alpha; param.beta = beta; param.k = k; diff --git a/lite/operators/lrn_op.cc b/lite/operators/lrn_op.cc index 34b00653f9..aff3e5af55 100644 --- a/lite/operators/lrn_op.cc +++ b/lite/operators/lrn_op.cc @@ -37,11 +37,13 @@ bool LrnOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) { auto Out_name = opdesc.Output("Out").front(); param_.X = GetVar(scope, X_name); param_.Out = GetMutableVar(scope, Out_name); - param_.local_size = opdesc.GetAttr("local_size"); + param_.n = opdesc.GetAttr("n"); param_.alpha = opdesc.GetAttr("alpha"); param_.beta = opdesc.GetAttr("beta"); param_.k = opdesc.GetAttr("k"); - param_.norm_region = opdesc.GetAttr("norm_region"); + if (opdesc.HasAttr("norm_region")) { + param_.norm_region = opdesc.GetAttr("norm_region"); + } return true; } diff --git a/lite/operators/op_params.h b/lite/operators/op_params.h index e29bc59216..035b3e18e8 100644 --- a/lite/operators/op_params.h +++ b/lite/operators/op_params.h @@ -522,8 +522,8 @@ struct GRUUnitParam { struct LrnParam { const lite::Tensor* X{}; lite::Tensor* Out{}; - int local_size{5}; - float alpha{1.}; + int n{5}; + float alpha{1e-4}; float beta{0.75}; float k{1.}; std::string norm_region{"AcrossChannels"}; diff --git a/lite/tests/kernels/lrn_compute_test.cc b/lite/tests/kernels/lrn_compute_test.cc index 9ee43c5c60..e306155514 100644 --- a/lite/tests/kernels/lrn_compute_test.cc +++ b/lite/tests/kernels/lrn_compute_test.cc @@ -158,7 +158,7 @@ class LrnComputeTester : public arena::TestCase { op_desc->SetOutput("Out", {output_}); op_desc->SetAttr("alpha", alpha_); op_desc->SetAttr("beta", beta_); - op_desc->SetAttr("local_size", local_size_); + op_desc->SetAttr("n", local_size_); op_desc->SetAttr("k", k_); op_desc->SetAttr("norm_region", norm_region_); } -- GitLab