From b314b3543de648bbc0cb108df6a9b381052e3a06 Mon Sep 17 00:00:00 2001
From: Kenji Mouri
Date: Wed, 15 Mar 2023 19:33:37 +0800
Subject: [PATCH] Add SSE2 implementation of acos in x86 targets. (#4573)

---
Reviewer notes (this region is ignored when the patch is applied):
- acos_ps() computes arccosine on each float lane of an __m128 and replaces
  the per-lane scalar acos() calls in unary_op_acos::func_pack4.
- For |x| <= 0.5 it returns PI/2 - asin(x). For |x| > 0.5 it uses
  acos(|x|) = 2 * asin(sqrt((1 - |x|) / 2)) and mirrors negative inputs
  through acos(x) = PI - acos(|x|).
- asin(t) is approximated as t * (a0 + a1*t^2 + ... + a5*t^10) with the
  magic_a0..magic_a5 constants; in both branches t stays within [0, 0.5].
- Comment lines inside the hunk were only replaced one-for-one, so the
  hunk header and diffstat line counts below are unchanged.

 src/layer/x86/sse_mathfun.h   | 90 +++++++++++++++++++++++++++++++++++
 src/layer/x86/unaryop_x86.cpp |  9 +---
 2 files changed, 91 insertions(+), 8 deletions(-)

diff --git a/src/layer/x86/sse_mathfun.h b/src/layer/x86/sse_mathfun.h
index c6c5e19f..d09e62eb 100644
--- a/src/layer/x86/sse_mathfun.h
+++ b/src/layer/x86/sse_mathfun.h
@@ -921,4 +921,94 @@ static NCNN_FORCEINLINE __m128 asin_ps(__m128 x)
     return _mm_or_ps(final_approx, negative_mask);
 }
 
+static NCNN_FORCEINLINE __m128 acos_ps(__m128 x)
+{
+    const __m128 magic_negative_zero = _mm_set_ps1(-0.0f);
+    const __m128 magic_zero = _mm_set_ps1(0.0f);
+    const __m128 magic_half_one = _mm_set_ps1(0.5f);
+    const __m128 magic_one = _mm_set_ps1(1.0f);
+    const __m128 magic_a4 = _mm_set_ps1(0.023994016f);
+    const __m128 magic_a5 = _mm_set_ps1(0.042417344f);
+    const __m128 magic_a2 = _mm_set_ps1(0.07494697f);
+    const __m128 magic_a3 = _mm_set_ps1(0.045520633f);
+    const __m128 magic_a0 = _mm_set_ps1(1.0f);
+    const __m128 magic_a1 = _mm_set_ps1(0.166667819f);
+    const __m128 magic_half_pi = _mm_set_ps1(1.5707964f);
+    const __m128 magic_pi = _mm_set_ps1(3.1415927f);
+
+    // negative_mask = (magic_negative_zero & x);  -0.0f has only the sign bit set, so this keeps the sign of x
+    __m128 negative_mask = _mm_and_ps(magic_negative_zero, x);
+
+    // absolute = abs(x);  andnot clears the sign bit of every lane
+    __m128 absolute = _mm_andnot_ps(magic_negative_zero, x);
+
+    // Reference: https://en.wikipedia.org/wiki/Small-angle_approximation
+
+    // is_small_input = (absolute <= 0.5f);  all-ones lane mask where true, all-zeros where false
+    __m128 is_small_input = _mm_cmple_ps(absolute, magic_half_one);
+
+    // big_input_approx = sqrt(0.5f * (1 - absolute));  maps 0.5 < |x| <= 1 into [0, 0.5)
+    __m128 big_input_approx = _mm_sqrt_ps(_mm_mul_ps(
+        magic_half_one,
+        _mm_sub_ps(magic_one, absolute)));
+
+    // input_approx = (is_small_input ? absolute : big_input_approx);  branch-free select via and/andnot/or
+    __m128 input_approx = _mm_or_ps(
+        _mm_and_ps(is_small_input, absolute),
+        _mm_andnot_ps(is_small_input, big_input_approx));
+
+    // square_of_input_approx = input_approx * input_approx;
+    __m128 square_of_input_approx = _mm_mul_ps(input_approx, input_approx);
+
+    // fourth_power_of_input_approx =
+    //     square_of_input_approx * square_of_input_approx;
+    __m128 fourth_power_of_input_approx = _mm_mul_ps(
+        square_of_input_approx, square_of_input_approx);
+
+    // output_approx = a0 + a1*s + a2*s^2 + a3*s^3 + a4*s^4 + a5*s^5 with s = square_of_input_approx; x3 holds the even powers of s, x4 the odd powers divided by s.
+    // x1 = ((magic_a4 * fourth_power_of_input_approx) + magic_a2);
+    // x2 = ((magic_a5 * fourth_power_of_input_approx) + magic_a3);
+    // x3 = ((x1 * fourth_power_of_input_approx) + magic_a0);
+    // x4 = ((fourth_power_of_input_approx * x2) + magic_a1);
+    // output_approx = (x3 + (square_of_input_approx * x4));
+    __m128 output_approx = _mm_add_ps(
+        _mm_add_ps(
+            _mm_mul_ps(
+                _mm_add_ps(
+                    _mm_mul_ps(magic_a4, fourth_power_of_input_approx),
+                    magic_a2),
+                fourth_power_of_input_approx),
+            magic_a0),
+        _mm_mul_ps(
+            square_of_input_approx,
+            _mm_add_ps(
+                _mm_mul_ps(
+                    fourth_power_of_input_approx,
+                    _mm_add_ps(
+                        _mm_mul_ps(magic_a5, fourth_power_of_input_approx),
+                        magic_a3)),
+                magic_a1)));
+
+    // Multiplying the polynomial by its argument gives the odd series t * P(t^2), the approximation of asin(input_approx).
+    // x1 = (output_approx * input_approx);
+    __m128 x1 = _mm_mul_ps(output_approx, input_approx);
+
+    // |x| <= 0.5: acos(x) = PI/2 - asin(x); OR-ing negative_mask onto the non-negative x1 gives it the sign of x.
+    // small_final_approx = ((0.5 * PI) - (x1 | negative_mask));
+    __m128 small_final_approx = _mm_sub_ps(
+        magic_half_pi,
+        _mm_or_ps(x1, negative_mask));
+
+    // |x| > 0.5: acos(|x|) = 2 * x1; for x < 0, acos(x) = PI - acos(|x|), formed here as PI + (-(2 * x1)).
+    // big_final_approx = (((x < 0.0f) & PI) + ((x1 * 2) | negative_mask));
+    __m128 big_final_approx = _mm_add_ps(
+        _mm_and_ps(_mm_cmplt_ps(x, magic_zero), magic_pi),
+        _mm_or_ps(_mm_add_ps(x1, x1), negative_mask));
+
+    // return (is_small_input ? small_final_approx : big_final_approx);
+    return _mm_or_ps(
+        _mm_and_ps(is_small_input, small_final_approx),
+        _mm_andnot_ps(is_small_input, big_final_approx));
+}
+
 #endif // SSE_MATHFUN_H
diff --git a/src/layer/x86/unaryop_x86.cpp b/src/layer/x86/unaryop_x86.cpp
index ecb361f4..8e8c8444 100644
--- a/src/layer/x86/unaryop_x86.cpp
+++ b/src/layer/x86/unaryop_x86.cpp
@@ -464,14 +464,7 @@ struct unary_op_acos
 #if __SSE2__
     __m128 func_pack4(const __m128& x) const
     {
-        //TODO sse optimize
-        float tmp[4];
-        _mm_storeu_ps(tmp, x);
-        tmp[0] = acos(tmp[0]);
-        tmp[1] = acos(tmp[1]);
-        tmp[2] = acos(tmp[2]);
-        tmp[3] = acos(tmp[3]);
-        return _mm_loadu_ps(tmp);
+        return acos_ps(x);
     }
 #if __AVX__
     __m256 func_pack8(const __m256& x) const
-- 
GitLab