Unverified commit 94255f6c, authored by Y yiicy, committed by GitHub

fix lrn param, align to fluid, test=develop (#2452)

Parent 8373aec5
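For reference, the attributes touched below (n, alpha, beta, k) are the parameters of across-channels Local Response Normalization. As fluid defines it (our paraphrase; the exact window clamping is the commonly documented form):

$$
Output(i, x, y) = \frac{Input(i, x, y)}{\left(k + \alpha \sum_{j=\max(0,\, i - n/2)}^{\min(C-1,\, i + n/2)} Input(j, x, y)^2\right)^{\beta}}
$$

where C is the number of channels and the sum runs over a window of n adjacent channels centered on channel i.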
@@ -31,16 +31,16 @@ void LrnCompute::Run() {
   int channel = x_dims[1];
   int h = x_dims[2];
   int w = x_dims[3];
-  const int local_size = param.local_size;
+  const int n = param.n;
   const float alpha = param.alpha;
   const float beta = param.beta;
   const float k = param.k;
   if (param.norm_region == "AcrossChannels") {
     lite::arm::math::compute_across_channels(
-        x_data, out_data, num, channel, h, w, local_size, alpha, beta, k);
+        x_data, out_data, num, channel, h, w, n, alpha, beta, k);
   } else {
     lite::arm::math::compute_within_channels(
-        x_data, out_data, num, channel, h, w, local_size, alpha, beta, k);
+        x_data, out_data, num, channel, h, w, n, alpha, beta, k);
   }
 }
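To make the calls above concrete, here is a minimal scalar sketch of across-channels LRN on NCHW float data. The actual lite::arm::math::compute_across_channels is an optimized ARM routine; this reference version and its fluid-style window clamping are assumptions for illustration, not code from this commit.

#include <algorithm>
#include <cmath>

// Reference (non-optimized) across-channels LRN on NCHW float data.
// Assumes fluid semantics: out = in / (k + alpha * sum(in_j^2))^beta,
// summed over a window of n channels centered on channel c.
void compute_across_channels_ref(const float* din, float* dout, int num,
                                 int channel, int h, int w, int n,
                                 float alpha, float beta, float k) {
  const int spatial = h * w;
  const int half = n / 2;
  for (int b = 0; b < num; ++b) {
    const float* in = din + b * channel * spatial;
    float* out = dout + b * channel * spatial;
    for (int c = 0; c < channel; ++c) {
      // Clamp the channel window to the valid range [0, channel - 1].
      const int c_begin = std::max(0, c - half);
      const int c_end = std::min(channel - 1, c + half);
      for (int s = 0; s < spatial; ++s) {
        float sum = 0.f;
        for (int j = c_begin; j <= c_end; ++j) {
          const float v = in[j * spatial + s];
          sum += v * v;
        }
        out[c * spatial + s] =
            in[c * spatial + s] / std::pow(k + alpha * sum, beta);
      }
    }
  }
}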
@@ -53,4 +53,5 @@ REGISTER_LITE_KERNEL(
     lrn, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::LrnCompute, def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("MidOut", {LiteType::GetTensorTy(TARGET(kARM))})
     .Finalize();
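On the added MidOut binding: fluid's lrn op declares a second output, MidOut, holding intermediate normalization terms that the backward pass reuses; binding it here lets the ARM kernel accept graphs that carry this output.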
@@ -91,7 +91,7 @@ void lrn_compute_ref(const operators::LrnParam& param) {
   const dtype* x_data = param.X->data<const dtype>();
   dtype* out_data = param.Out->mutable_data<dtype>();
   auto x_dims = param.X->dims();
-  int local_size = param.local_size;
+  int local_size = param.n;
   float alpha = param.alpha;
   float beta = param.beta;
   float k = param.k;
@@ -171,7 +171,7 @@ TEST(lrn_arm, compute) {
         }
         param.X = &x;
         param.Out = &output;
-        param.local_size = local_size;
+        param.n = local_size;
         param.alpha = alpha;
         param.beta = beta;
         param.k = k;
...
@@ -37,11 +37,13 @@ bool LrnOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
   auto Out_name = opdesc.Output("Out").front();
   param_.X = GetVar<lite::Tensor>(scope, X_name);
   param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
-  param_.local_size = opdesc.GetAttr<int>("local_size");
+  param_.n = opdesc.GetAttr<int>("n");
   param_.alpha = opdesc.GetAttr<float>("alpha");
   param_.beta = opdesc.GetAttr<float>("beta");
   param_.k = opdesc.GetAttr<float>("k");
-  param_.norm_region = opdesc.GetAttr<std::string>("norm_region");
+  if (opdesc.HasAttr("norm_region")) {
+    param_.norm_region = opdesc.GetAttr<std::string>("norm_region");
+  }
   return true;
 }
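The HasAttr guard makes norm_region optional: a program desc that does not carry the attribute (fluid's lrn op, to our knowledge, emits no norm_region) now keeps the "AcrossChannels" default from LrnParam instead of failing in GetAttr.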
...
@@ -522,8 +522,8 @@ struct GRUUnitParam {
 struct LrnParam {
   const lite::Tensor* X{};
   lite::Tensor* Out{};
-  int local_size{5};
-  float alpha{1.};
+  int n{5};
+  float alpha{1e-4};
   float beta{0.75};
   float k{1.};
   std::string norm_region{"AcrossChannels"};
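The new defaults, n{5}, alpha{1e-4}, beta{0.75} and k{1.}, match the documented defaults of fluid.layers.lrn; the alpha change from 1. to 1e-4 is the "align to fluid" part of the commit title.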
...
@@ -158,7 +158,7 @@ class LrnComputeTester : public arena::TestCase {
     op_desc->SetOutput("Out", {output_});
     op_desc->SetAttr("alpha", alpha_);
     op_desc->SetAttr("beta", beta_);
-    op_desc->SetAttr("local_size", local_size_);
+    op_desc->SetAttr("n", local_size_);
     op_desc->SetAttr("k", k_);
     op_desc->SetAttr("norm_region", norm_region_);
   }
...