提交 ead611e1 编写于 作者: M Megvii Engine Team

perf(dnn): slightly improve arm neon transcendental function performance

GitOrigin-RevId: 210d88f81e23efd104ff32ddb57c06b39d0e3e03
上级 0d169524
......@@ -86,11 +86,11 @@ v4sf log_ps_f32(v4sf x) {
e = vaddq_f32(e, one);
/* part2:
if( x < SQRTHF ) {
e -= 1;
x = x + x - 1.0;
} else { x = x - 1.0; }
*/
* if( x < SQRTHF ) {
* e -= 1;
* x = x + x - 1.0;
* } else { x = x - 1.0; }
*/
v4su mask = vcltq_f32(x, vdupq_n_f32(c_cephes_SQRTHF));
v4sf tmp = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
x = vsubq_f32(x, one);
......@@ -101,38 +101,26 @@ v4sf log_ps_f32(v4sf x) {
v4sf z = vmulq_f32(x, x);
v4sf y = vdupq_n_f32(c_cephes_log_p0);
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p1));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p2));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p3));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p4));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p5));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p6));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p7));
y = vmulq_f32(y, x);
y = vaddq_f32(y, vdupq_n_f32(c_cephes_log_p8));
y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p1), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p2), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p3), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p4), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p5), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p6), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p7), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_log_p8), y, x);
y = vmulq_f32(y, x);
y = vmulq_f32(y, z);
tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q1));
y = vaddq_f32(y, tmp);
y = fma_ps_f32(y, e, vdupq_n_f32(c_cephes_log_q1));
tmp = vmulq_f32(z, vdupq_n_f32(0.5f));
y = vsubq_f32(y, tmp);
y = vmlsq_f32(y, z, vdupq_n_f32(0.5f));
tmp = vmulq_f32(e, vdupq_n_f32(c_cephes_log_q2));
x = vaddq_f32(x, y);
x = vaddq_f32(x, tmp);
x = fma_ps_f32(x, e, vdupq_n_f32(c_cephes_log_q2));
x = vreinterpretq_f32_u32(vorrq_u32(
vreinterpretq_u32_f32(x),
invalid_mask)); // negative arg will be NAN
vreinterpretq_u32_f32(x), invalid_mask)); // negative arg will be NAN
return x;
}
......@@ -159,7 +147,7 @@ v4sf exp_ps_f32(v4sf x) {
x = vmaxq_f32(x, vdupq_n_f32(c_exp_lo));
/* express exp(x) as exp(g + n*log(2)) */
fx = vmlaq_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
fx = fma_ps_f32(vdupq_n_f32(0.5f), x, vdupq_n_f32(c_cephes_LOG2EF));
/* perform a floorf */
tmp = vcvtq_f32_s32(vcvtq_s32_f32(fx));
......@@ -175,34 +163,20 @@ v4sf exp_ps_f32(v4sf x) {
x = vsubq_f32(x, tmp);
x = vsubq_f32(x, z);
static const float cephes_exp_p[6] = {c_cephes_exp_p0, c_cephes_exp_p1,
c_cephes_exp_p2, c_cephes_exp_p3,
c_cephes_exp_p4, c_cephes_exp_p5};
v4sf y = vld1q_dup_f32(cephes_exp_p + 0);
v4sf c1 = vld1q_dup_f32(cephes_exp_p + 1);
v4sf c2 = vld1q_dup_f32(cephes_exp_p + 2);
v4sf c3 = vld1q_dup_f32(cephes_exp_p + 3);
v4sf c4 = vld1q_dup_f32(cephes_exp_p + 4);
v4sf c5 = vld1q_dup_f32(cephes_exp_p + 5);
y = vmulq_f32(y, x);
z = vmulq_f32(x, x);
y = vaddq_f32(y, c1);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c2);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c3);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c4);
y = vmulq_f32(y, x);
y = vaddq_f32(y, c5);
y = vmulq_f32(y, z);
y = vaddq_f32(y, x);
v4sf y = vdupq_n_f32(c_cephes_exp_p0);
y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p1), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p2), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p3), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p4), y, x);
y = fma_ps_f32(vdupq_n_f32(c_cephes_exp_p5), y, x);
y = fma_ps_f32(x, y, z);
y = vaddq_f32(y, one);
/* build 2^n */
int32x4_t mm;
v4si mm;
mm = vcvtq_s32_f32(fx);
mm = vaddq_s32(mm, vdupq_n_s32(0x7f));
mm = vshlq_n_s32(mm, 23);
......@@ -249,8 +223,9 @@ float16x8_t exp_ps_f16(float16x8_t x) {
almost no extra price so both sin_ps_f32 and cos_ps_f32 make use of
sincos_ps_f32..
*/
void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) { // any x
v4sf xmm1, xmm2, xmm3, y;
void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) {
// any x
v4sf y;
v4su emm2;
......@@ -269,44 +244,36 @@ void sincos_ps_f32(v4sf x, v4sf* ysin, v4sf* ycos) { // any x
y = vcvtq_f32_u32(emm2);
/* get the polynom selection mask
there is one polynom for 0 <= x <= Pi/4
and another one for Pi/4<x<=Pi/2
Both branches will be computed.
*/
* there is one polynom for 0 <= x <= Pi/4
* and another one for Pi/4<x<=Pi/2
*
* Both branches will be computed.
*/
v4su poly_mask = vtstq_u32(emm2, vdupq_n_u32(2));
/* The magic pass: "Extended precision modular arithmetic"
x = ((x - y * DP1) - y * DP2) - y * DP3; */
xmm1 = vmulq_n_f32(y, c_minus_cephes_DP1);
xmm2 = vmulq_n_f32(y, c_minus_cephes_DP2);
xmm3 = vmulq_n_f32(y, c_minus_cephes_DP3);
x = vaddq_f32(x, xmm1);
x = vaddq_f32(x, xmm2);
x = vaddq_f32(x, xmm3);
* x = ((x - y * DP1) - y * DP2) - y * DP3; */
x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP1));
x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP2));
x = fma_ps_f32(x, y, vdupq_n_f32(c_minus_cephes_DP3));
sign_mask_sin = veorq_u32(sign_mask_sin, vtstq_u32(emm2, vdupq_n_u32(4)));
sign_mask_cos = vtstq_u32(vsubq_u32(emm2, vdupq_n_u32(2)), vdupq_n_u32(4));
/* Evaluate the first polynom (0 <= x <= Pi/4) in y1,
and the second polynom (Pi/4 <= x <= 0) in y2 */
* and the second polynom (Pi/4 <= x <= 0) in y2 */
v4sf z = vmulq_f32(x, x);
v4sf y1, y2;
y1 = vmulq_n_f32(z, c_coscof_p0);
y2 = vmulq_n_f32(z, c_sincof_p0);
y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p1));
y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p1));
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, z);
y1 = vaddq_f32(y1, vdupq_n_f32(c_coscof_p2));
y2 = vaddq_f32(y2, vdupq_n_f32(c_sincof_p2));
y1 = fma_ps_f32(vdupq_n_f32(c_coscof_p1), z, vdupq_n_f32(c_coscof_p0));
y2 = fma_ps_f32(vdupq_n_f32(c_sincof_p1), z, vdupq_n_f32(c_sincof_p0));
y1 = fma_ps_f32(vdupq_n_f32(c_coscof_p2), y1, z);
y2 = fma_ps_f32(vdupq_n_f32(c_sincof_p2), y2, z);
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, z);
y1 = vmulq_f32(y1, z);
y2 = vmulq_f32(y2, x);
y1 = vsubq_f32(y1, vmulq_f32(z, vdupq_n_f32(0.5f)));
y2 = vaddq_f32(y2, x);
y1 = vmlsq_f32(y1, z, vdupq_n_f32(0.5f));
y2 = fma_ps_f32(x, y2, x);
y1 = vaddq_f32(y1, vdupq_n_f32(1));
/* select the correct result from the two polynoms */
......@@ -407,20 +374,20 @@ v4sf sigmoid_ps_f32(v4sf src) {
auto val = vmaxq_f32(vdupq_n_f32(sigmoid_constants.lower_range), src);
val = vminq_f32(vdupq_n_f32(sigmoid_constants.upper_range), val);
auto squared = vmulq_f32(val, val);
auto p = vmlaq_f32(
auto p = fma_ps_f32(
vdupq_n_f32(sigmoid_constants.alpha_7), squared,
vdupq_n_f32(sigmoid_constants.alpha_9));
p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared);
p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared);
p = vmlaq_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared);
p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_5), p, squared);
p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_3), p, squared);
p = fma_ps_f32(vdupq_n_f32(sigmoid_constants.alpha_1), p, squared);
p = vmulq_f32(p, val);
auto q = vmlaq_f32(
auto q = fma_ps_f32(
vdupq_n_f32(sigmoid_constants.beta_8), squared,
vdupq_n_f32(sigmoid_constants.beta_10));
q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared);
q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared);
q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared);
q = vmlaq_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared);
q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_6), q, squared);
q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_4), q, squared);
q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_2), q, squared);
q = fma_ps_f32(vdupq_n_f32(sigmoid_constants.beta_0), q, squared);
return vaddq_f32(div_ps_f32(p, q), vdupq_n_f32(sigmoid_constants.one_half));
}
......
......@@ -54,7 +54,7 @@ v4sf cos_ps_f32(v4sf x);
v4sf tan_ps_f32(v4sf x);
static inline v4sf div_ps_f32(v4sf x, v4sf y) {
static inline v4sf div_ps_f32(v4sf& x, v4sf& y) {
#if MEGDNN_AARCH64
return vdivq_f32(x, y);
#else
......@@ -65,6 +65,12 @@ static inline v4sf div_ps_f32(v4sf x, v4sf y) {
#endif
}
#if defined(__ARM_FEATURE_FMA)
#define fma_ps_f32(c, b, a) vfmaq_f32((c), (a), (b))
#else
#define fma_ps_f32(c, b, a) vmlaq_f32((c), (a), (b))
#endif
v4sf sigmoid_ps_f32(v4sf x);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
......@@ -73,7 +79,7 @@ v4sf sigmoid_ps_f32(v4sf x);
*/
float16x8_t exp_ps_f16(float16x8_t x);
static inline float16x8_t div_ps_f16(float16x8_t x, float16x8_t y) {
static inline float16x8_t div_ps_f16(float16x8_t& x, float16x8_t& y) {
#if MEGDNN_AARCH64
return vdivq_f16(x, y);
#else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册