未验证 提交 94255f6c 编写于 作者: Y yiicy 提交者: GitHub

fix lrn param, align to fluid, test=develop (#2452)

上级 8373aec5
......@@ -31,16 +31,16 @@ void LrnCompute::Run() {
int channel = x_dims[1];
int h = x_dims[2];
int w = x_dims[3];
const int local_size = param.local_size;
const int n = param.n;
const float alpha = param.alpha;
const float beta = param.beta;
const float k = param.k;
if (param.norm_region == "AcrossChannels") {
lite::arm::math::compute_across_channels(
x_data, out_data, num, channel, h, w, local_size, alpha, beta, k);
x_data, out_data, num, channel, h, w, n, alpha, beta, k);
} else {
lite::arm::math::compute_within_channels(
x_data, out_data, num, channel, h, w, local_size, alpha, beta, k);
x_data, out_data, num, channel, h, w, n, alpha, beta, k);
}
}
......@@ -53,4 +53,5 @@ REGISTER_LITE_KERNEL(
lrn, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::LrnCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("MidOut", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -91,7 +91,7 @@ void lrn_compute_ref(const operators::LrnParam& param) {
const dtype* x_data = param.X->data<const dtype>();
dtype* out_data = param.Out->mutable_data<dtype>();
auto x_dims = param.X->dims();
int local_size = param.local_size;
int local_size = param.n;
float alpha = param.alpha;
float beta = param.beta;
float k = param.k;
......@@ -171,7 +171,7 @@ TEST(lrn_arm, compute) {
}
param.X = &x;
param.Out = &output;
param.local_size = local_size;
param.n = local_size;
param.alpha = alpha;
param.beta = beta;
param.k = k;
......
......@@ -37,11 +37,13 @@ bool LrnOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto Out_name = opdesc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
param_.local_size = opdesc.GetAttr<int>("local_size");
param_.n = opdesc.GetAttr<int>("n");
param_.alpha = opdesc.GetAttr<float>("alpha");
param_.beta = opdesc.GetAttr<float>("beta");
param_.k = opdesc.GetAttr<float>("k");
param_.norm_region = opdesc.GetAttr<std::string>("norm_region");
if (opdesc.HasAttr("norm_region")) {
param_.norm_region = opdesc.GetAttr<std::string>("norm_region");
}
return true;
}
......
......@@ -522,8 +522,8 @@ struct GRUUnitParam {
struct LrnParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
int local_size{5};
float alpha{1.};
int n{5};
float alpha{1e-4};
float beta{0.75};
float k{1.};
std::string norm_region{"AcrossChannels"};
......
......@@ -158,7 +158,7 @@ class LrnComputeTester : public arena::TestCase {
op_desc->SetOutput("Out", {output_});
op_desc->SetAttr("alpha", alpha_);
op_desc->SetAttr("beta", beta_);
op_desc->SetAttr("local_size", local_size_);
op_desc->SetAttr("n", local_size_);
op_desc->SetAttr("k", k_);
op_desc->SetAttr("norm_region", norm_region_);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册