diff --git a/lite/backends/arm/math/activation.cc b/lite/backends/arm/math/activation.cc
index 805b87da09a3f07edbb0591ac9d1f9eb4488ab89..4df95c9d074f4b848357eafcb13930f9f3f7450a 100644
--- a/lite/backends/arm/math/activation.cc
+++ b/lite/backends/arm/math/activation.cc
@@ -771,7 +771,7 @@ void act_elu(
   int thread_remain = size % threads;
   int neon_loop_cnt_dim16 = nums_per_thread >> 4;
   int neon_loop_remain_dim16 = nums_per_thread & 15;
-  float32x4_t alpha = vdupq_n_f32(alpha);
+  float32x4_t valpha = vdupq_n_f32(alpha);
   float32x4_t vzero = vdupq_n_f32(0.f);
   float32x4_t vone = vdupq_n_f32(1.f);
   int cnt = neon_loop_remain_dim16 >> 2;
@@ -797,6 +797,10 @@ void act_elu(
       float32x4_t vb_sub = vsubq_f32(vb_exp, vone);
       float32x4_t vc_sub = vsubq_f32(vc_exp, vone);
       float32x4_t vd_sub = vsubq_f32(vd_exp, vone);
+      va_sub = vmulq_f32(va_sub, valpha);
+      vb_sub = vmulq_f32(vb_sub, valpha);
+      vc_sub = vmulq_f32(vc_sub, valpha);
+      vd_sub = vmulq_f32(vd_sub, valpha);
       float32x4_t va_min = vminq_f32(va_sub, vzero);
       float32x4_t vb_min = vminq_f32(vb_sub, vzero);
       float32x4_t vc_min = vminq_f32(vc_sub, vzero);
@@ -817,6 +821,7 @@ void act_elu(
       float32x4_t va_exp = exp_ps(va);
       float32x4_t va_max = vmaxq_f32(va, vzero);
       float32x4_t va_sub = vsubq_f32(va_exp, vone);
+      va_sub = vmulq_f32(va_sub, valpha);
       float32x4_t va_min = vminq_f32(va_sub, vzero);
       float32x4_t va_rst = vaddq_f32(va_max, va_min);
       vst1q_f32(ptr_out_thread, va_rst);
@@ -825,7 +830,7 @@ void act_elu(
     }
     for (int j = 0; j < remain; j++) {
       float beta = alpha * (expf(ptr_in_thread[0]) - 1);
-      float max = ptr_in[0] >= 0.f ? ptr_in_thread[0] : 0.f;
+      float max = ptr_in_thread[0] >= 0.f ? ptr_in_thread[0] : 0.f;
       float min = beta <= 0.f ? beta : 0.f;
       ptr_out_thread[0] = min + max;
       ptr_in_thread++;
diff --git a/lite/kernels/arm/activation_compute.cc b/lite/kernels/arm/activation_compute.cc
index a5da9dfec8ba9d93db2c493a81ae08da536840c7..c692d354630534859b7a4aaebf59269284a8ecd4 100644
--- a/lite/kernels/arm/activation_compute.cc
+++ b/lite/kernels/arm/activation_compute.cc
@@ -375,4 +375,4 @@ REGISTER_LITE_KERNEL(elu,
                      def)
     .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
-    .Finalize();
\ No newline at end of file
+    .Finalize();
diff --git a/lite/operators/activation_ops.cc b/lite/operators/activation_ops.cc
index 19a2134440d4d47adc5324e56e08f08457d1acf8..c519016aa89cda0ba2a3e9aa07766f89a756d3c6 100644
--- a/lite/operators/activation_ops.cc
+++ b/lite/operators/activation_ops.cc
@@ -87,7 +87,7 @@ bool ActivationOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
     param_.relu_threshold = opdesc.GetAttr<float>("threshold");
   } else if (opdesc.Type() == "elu") {
     param_.active_type = lite_api::ActivationType::kElu;
-    param_.param_.Elu_alpha = opdesc.GetAttr<float>("alpha");
+    param_.Elu_alpha = opdesc.GetAttr<float>("alpha");
   }
 
   VLOG(4) << "opdesc.Type():" << opdesc.Type();
diff --git a/lite/tests/kernels/activation_compute_test.cc b/lite/tests/kernels/activation_compute_test.cc
index a62c698f83fe10409af0bba8774135d3409358ea..5525dad9123f5f312fd1036efb7aa3932b2db7b6 100644
--- a/lite/tests/kernels/activation_compute_test.cc
+++ b/lite/tests/kernels/activation_compute_test.cc
@@ -39,7 +39,8 @@ enum activation_type_test {
   SQUARE,
   HARD_SWISH,
   RECIPROCAL,
-  THRESHOLDED_RELU
+  THRESHOLDED_RELU,
+  ELU
 };
 
 class ActivationComputeTester : public arena::TestCase {
@@ -56,6 +57,7 @@ class ActivationComputeTester : public arena::TestCase {
   float hard_swish_scale = 6.0;
   float hard_swish_offset = 3.0;
   float relu_threshold_ = 1.0;
+  float elu_alpha_ = 1.0;
   DDim dims_{{1}};
   std::string type_ = "";
   activation_type_test act_type_ = RELU;
@@ -67,6 +69,7 @@ class ActivationComputeTester : public arena::TestCase {
                                    float relu_clipped_coef,
                                    std::string prelu_mode,
                                    float swish_beta,
+                                   float elu_alpha,
                                    DDim dims,
                                    std::string type,
                                    activation_type_test act_type)
@@ -75,6 +78,7 @@ class ActivationComputeTester : public arena::TestCase {
         relu_clipped_coef_(relu_clipped_coef),
         prelu_mode_(prelu_mode),
         swish_beta_(swish_beta),
+        elu_alpha_(elu_alpha),
         dims_(dims),
         type_(type),
         act_type_(act_type) {}
@@ -87,6 +91,7 @@ class ActivationComputeTester : public arena::TestCase {
 
     auto* x = scope->FindTensor(input_);
     const auto* x_data = x->data<float>();
+    LOG(INFO) << act_type_;
     switch (act_type_) {
       case RELU: {
         for (int i = 0; i < dims_.production(); i++) {
@@ -226,8 +231,17 @@ class ActivationComputeTester : public arena::TestCase {
         }
         break;
       }
+      case ELU: {
+        for (int i = 0; i < dims_.production(); i++) {
+          float tmp = std::exp(x_data[i]) - 1;
+          float max = x_data[i] > 0.f ? x_data[i] : 0.f;
+          float min = x_data[i] < 0.f ? elu_alpha_ * tmp : 0.f;
+          output_data[i] = min + max;
+        }
+      }
+      break;
       default:
-        LOG(INFO) << "the type of activation is unknow.";
+        LOG(INFO) << "the type of activation " << act_type_ << " is unknown.";
     }
   }
 
@@ -256,6 +270,9 @@ class ActivationComputeTester : public arena::TestCase {
     if (act_type_ == THRESHOLDED_RELU) {
       op_desc->SetAttr("threshold", relu_threshold_);
     }
+    if (act_type_ == ELU) {
+      op_desc->SetAttr("alpha", elu_alpha_);
+    }
   }
 
   void PrepareData() override {
@@ -309,7 +326,7 @@ TEST(Activation_relu, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "relu", RELU));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "relu", RELU));
     arena::Arena arena(std::move(tester), place, abs_error);
     arena.TestPrecision();
   }
@@ -338,6 +355,7 @@ TEST(Activation_leaky_relu, precision) {
         6.,
         "all",
         0.,
+        1.0,
         DDim(dims),
         "leaky_relu",
         LEAKY_RELU));
@@ -370,6 +388,7 @@ TEST(Activation_relu_clipped, precision) {
         coef,
         "all",
         0.,
+        1.0,
         DDim(dims),
         "relu_clipped",
         RELU_CLIPPED));
@@ -387,7 +406,7 @@ TEST(Activation_prelu, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{{1, 3, 2, 4}}) {
     for (auto mode : {"all", "channel", "element"}) {
       std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-          place, "def", 0.01, 6, mode, 0., DDim(dims), "prelu", PRELU));
+          place, "def", 0.01, 6, mode, 0., 1.0, DDim(dims), "prelu", PRELU));
       arena::Arena arena(std::move(tester), place, 2e-5);
       arena.TestPrecision();
     }
@@ -411,7 +430,7 @@ TEST(Activation_sigmoid, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "sigmoid", SIGMOID));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "sigmoid", SIGMOID));
     arena::Arena arena(std::move(tester), place, abs_error);
     arena.TestPrecision();
   }
@@ -435,7 +454,7 @@ TEST(Activation_tanh, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "tanh", TANH));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "tanh", TANH));
     arena::Arena arena(std::move(tester), place, abs_error);
     arena.TestPrecision();
   }
@@ -450,7 +469,7 @@ TEST(Activation_swish, precision) {
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     for (auto coef : {0.01, 0.1}) {
       std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-          place, "def", 0.01, 6, "all", coef, DDim(dims), "swish", SWISH));
+          place, "def", 0.01, 6, "all", coef, 1.0, DDim(dims), "swish", SWISH));
       arena::Arena arena(std::move(tester), place, 2e-5);
       arena.TestPrecision();
     }
@@ -474,7 +493,7 @@ TEST(Activation_relu6, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "relu6", RELU6));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "relu6", RELU6));
     arena::Arena arena(std::move(tester), place, abs_error);
     arena.TestPrecision();
   }
@@ -496,7 +515,7 @@ TEST(Activation_log, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "log", LOG));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "log", LOG));
     arena::Arena arena(std::move(tester), place, abs_error);
     arena.TestPrecision();
   }
@@ -510,7 +529,7 @@ TEST(Activation_exp, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "exp", EXP));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "exp", EXP));
     arena::Arena arena(std::move(tester), place, 2e-5);
     arena.TestPrecision();
   }
@@ -524,7 +543,7 @@ TEST(Activation_floor, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "floor", FLOOR));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "floor", FLOOR));
     arena::Arena arena(std::move(tester), place, 2e-5);
     arena.TestPrecision();
   }
@@ -539,7 +558,7 @@ TEST(Activation_rsqrt, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "rsqrt", RSQRT));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "rsqrt", RSQRT));
     arena::Arena arena(std::move(tester), place, 2e-5);
     arena.TestPrecision();
   }
@@ -562,7 +581,7 @@ TEST(Activation_square, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "square", SQUARE));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "square", SQUARE));
     arena::Arena arena(std::move(tester), place, abs_error);
     arena.TestPrecision();
   }
@@ -581,7 +600,7 @@ TEST(Activation_gelu, precision) {
   for (auto dims : std::vector<std::vector<int64_t>>{
            {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
     std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
-        place, "def", 0.01, 6., "all", 0., DDim(dims), "gelu", GELU));
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "gelu", GELU));
     arena::Arena arena(std::move(tester), place, abs_error);
     arena.TestPrecision();
   }
@@ -607,6 +626,7 @@ TEST(activation_hard_swish, precision) {
         6.,
         "all",
         0.,
+        1.0,
         DDim(dims),
         "hard_swish",
         HARD_SWISH));
@@ -635,6 +655,7 @@ TEST(activation_reciprocal, precision) {
         6.,
         "all",
         0.,
+        1.0,
         DDim(dims),
         "reciprocal",
         RECIPROCAL));
@@ -665,6 +686,7 @@ TEST(Activation_thresholded_relu, precision) {
         6.,
         "all",
         0.,
+        1.0,
         DDim(dims),
         "thresholded_relu",
         THRESHOLDED_RELU));
@@ -673,5 +695,20 @@ TEST(Activation_thresholded_relu, precision) {
   }
 }
 
+TEST(Activation_elu, precision) {
+  LOG(INFO) << "test elu op";
+#ifdef LITE_WITH_ARM
+  Place place(TARGET(kARM));
+
+  for (auto dims : std::vector<std::vector<int64_t>>{
+           {1, 3, 2, 4}, {2, 3, 4}, {5, 4}, {8}}) {
+    std::unique_ptr<arena::TestCase> tester(new ActivationComputeTester(
+        place, "def", 0.01, 6., "all", 0., 1.0, DDim(dims), "elu", ELU));
+    arena::Arena arena(std::move(tester), place, 2e-5);
+    arena.TestPrecision();
+  }
+#endif
+}
+
 }  // namespace lite
 }  // namespace paddle