提交 23c1fda7 编写于 作者: M Megvii Engine Team

perf(arm_common): optimize sigmoid

GitOrigin-RevId: 7cb248a15b447f4fbb0008e419cdf07cf6387309
上级 b20cda6b
...@@ -371,6 +371,69 @@ v4sf tan_ps_f32(v4sf x) { ...@@ -371,6 +371,69 @@ v4sf tan_ps_f32(v4sf x) {
#undef c_cephes_log_q1 #undef c_cephes_log_q1
#undef c_cephes_log_q2 #undef c_cephes_log_q2
/**
 * \brief constants consumed by sigmoid_ps_f32
 *
 * sigmoid_ps_f32 clamps its input to [lower_range, upper_range] and then
 * evaluates x * P(x^2) / Q(x^2) + one_half, where P is built from the alpha_*
 * coefficients and Q from the beta_* coefficients. The numeric suffix of each
 * coefficient is the power of x it multiplies in the final expression.
 *
 * NOTE(review): the values look like the float logistic approximation used by
 * Eigen; confirm the provenance before changing any of them.
 */
static const struct {
//! lower clamp applied to the input
float lower_range;
//! upper clamp applied to the input
float upper_range;
//! numerator coefficients, highest order first
float alpha_9;
float alpha_7;
float alpha_5;
float alpha_3;
float alpha_1;
//! denominator coefficients, highest order first
float beta_10;
float beta_8;
float beta_6;
float beta_4;
float beta_2;
float beta_0;
//! offset added to the quotient
float one_half;
} sigmoid_constants = {
//! lower_range, upper_range
-18.0f,
18.0f,
//! alpha_9 ... alpha_1
4.37031012579801e-11f,
1.15627324459942e-07f,
6.08574864600143e-05f,
8.51377133304701e-03f,
2.48287947061529e-01f,
//! beta_10 ... beta_0
6.10247389755681e-13f,
5.76102136993427e-09f,
6.29106785017040e-06f,
1.70198817374094e-03f,
1.16817656904453e-01f,
9.93151921023180e-01f,
//! one_half
0.5f,
};
/**
 * \brief sigmoid of four packed floats via a rational approximation
 *
 * The input is clamped to [lower_range, upper_range] from sigmoid_constants,
 * then the result is computed as x * P(x^2) / Q(x^2) + one_half, with both
 * polynomials evaluated by Horner's scheme in x^2.
 *
 * \param src the four input values
 * \return the approximated sigmoid of each lane
 */
v4sf sigmoid_ps_f32(v4sf src) {
    const auto& c = sigmoid_constants;

    //! clamp the argument before evaluating the polynomials
    auto x = vmaxq_f32(vdupq_n_f32(c.lower_range), src);
    x = vminq_f32(vdupq_n_f32(c.upper_range), x);
    const auto x2 = vmulq_f32(x, x);

    //! numerator: odd polynomial, i.e. x times a polynomial in x^2
    auto numer = vmlaq_f32(vdupq_n_f32(c.alpha_7), x2, vdupq_n_f32(c.alpha_9));
    numer = vmlaq_f32(vdupq_n_f32(c.alpha_5), numer, x2);
    numer = vmlaq_f32(vdupq_n_f32(c.alpha_3), numer, x2);
    numer = vmlaq_f32(vdupq_n_f32(c.alpha_1), numer, x2);
    numer = vmulq_f32(numer, x);

    //! denominator: even polynomial in x^2
    auto denom = vmlaq_f32(vdupq_n_f32(c.beta_8), x2, vdupq_n_f32(c.beta_10));
    denom = vmlaq_f32(vdupq_n_f32(c.beta_6), denom, x2);
    denom = vmlaq_f32(vdupq_n_f32(c.beta_4), denom, x2);
    denom = vmlaq_f32(vdupq_n_f32(c.beta_2), denom, x2);
    denom = vmlaq_f32(vdupq_n_f32(c.beta_0), denom, x2);

    const auto quotient = div_ps_f32(numer, denom);
    return vaddq_f32(quotient, vdupq_n_f32(c.one_half));
}
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/**
 * \brief sigmoid of eight packed halves
 *
 * Each half of the vector is widened to float, passed through
 * sigmoid_ps_f32, and the two results are narrowed back and recombined.
 *
 * \param x the eight input values
 * \return the sigmoid of each lane, as computed by sigmoid_ps_f32
 */
float16x8_t sigmoid_ps_f16(float16x8_t x) {
    //! widen each 4-lane half to f32 and evaluate it there
    const float32x4_t lo_f32 = vcvt_f32_f16(vget_low_f16(x));
    const float32x4_t hi_f32 = vcvt_f32_f16(vget_high_f16(x));
    const float32x4_t lo_res = sigmoid_ps_f32(lo_f32);
    const float32x4_t hi_res = sigmoid_ps_f32(hi_f32);

    //! narrow back to f16 and stitch the halves together
    const float16x4_t lo_f16 = vcvt_f16_f32(lo_res);
    const float16x4_t hi_f16 = vcvt_f16_f32(hi_res);
    return vcombine_f16(lo_f16, hi_f16);
}
#endif
} // namespace arm_common } // namespace arm_common
} // namespace megdnn } // namespace megdnn
......
...@@ -54,11 +54,38 @@ v4sf cos_ps_f32(v4sf x); ...@@ -54,11 +54,38 @@ v4sf cos_ps_f32(v4sf x);
v4sf tan_ps_f32(v4sf x); v4sf tan_ps_f32(v4sf x);
/**
 * \brief lane-wise x / y for four packed floats
 *
 * Uses the hardware divide on AArch64; otherwise multiplies x by a refined
 * reciprocal estimate of y.
 */
static inline v4sf div_ps_f32(v4sf x, v4sf y) {
#if MEGDNN_AARCH64
    return vdivq_f32(x, y);
#else
    //! armv7 has no vector divide: take the reciprocal estimate of y, refine
    //! it with one vrecps step, then multiply by x
    const float32x4_t estimate = vrecpeq_f32(y);
    const float32x4_t correction = vrecpsq_f32(y, estimate);
    const float32x4_t inv_y = vmulq_f32(correction, estimate);
    return vmulq_f32(x, inv_y);
#endif
}
v4sf sigmoid_ps_f32(v4sf x);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
/** /**
* \brief compute for 8 half at once, the inner just invoke exp_ps_f32 twice * \brief compute for 8 half at once, the inner just invoke exp_ps_f32 twice
*/ */
float16x8_t exp_ps_f16(float16x8_t x); float16x8_t exp_ps_f16(float16x8_t x);
/**
 * \brief lane-wise x / y for eight packed halves
 *
 * Uses the hardware divide on AArch64; otherwise multiplies x by a refined
 * reciprocal estimate of y.
 */
static inline float16x8_t div_ps_f16(float16x8_t x, float16x8_t y) {
#if MEGDNN_AARCH64
    return vdivq_f16(x, y);
#else
    //! armv7 has no vector divide: take the reciprocal estimate of y, refine
    //! it with one vrecps step, then multiply by x
    const float16x8_t estimate = vrecpeq_f16(y);
    const float16x8_t correction = vrecpsq_f16(y, estimate);
    const float16x8_t inv_y = vmulq_f16(correction, estimate);
    return vmulq_f16(x, inv_y);
#endif
}
float16x8_t sigmoid_ps_f16(float16x8_t x);
#endif #endif
} // namespace arm_common } // namespace arm_common
......
...@@ -47,24 +47,14 @@ struct FuseAddSigmoidOp; ...@@ -47,24 +47,14 @@ struct FuseAddSigmoidOp;
vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \ vst1q_##_func_suffix(dst + SIMD_WIDTH, vitem.val[1]); \
} \ } \
_neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \ _neon_type operator()(const _neon_type& src0, const _neon_type& src1) const { \
auto zero_val = vdupq_n_##_func_suffix(0.f); \
auto one_val = vdupq_n_##_func_suffix(1.f); \
auto val1 = src0.val[0]; \ auto val1 = src0.val[0]; \
auto val2 = src0.val[1]; \ auto val2 = src0.val[1]; \
auto val3 = src1.val[0]; \ auto val3 = src1.val[0]; \
auto val4 = src1.val[1]; \ auto val4 = src1.val[1]; \
val1 = vaddq_##_func_suffix(val1, val3); \ val1 = vaddq_##_func_suffix(val1, val3); \
val2 = vaddq_##_func_suffix(val2, val4); \ val2 = vaddq_##_func_suffix(val2, val4); \
val1 = vsubq_##_func_suffix(zero_val, val1); \ val1 = sigmoid_ps_##_func_suffix(val1); \
val2 = vsubq_##_func_suffix(zero_val, val2); \ val2 = sigmoid_ps_##_func_suffix(val2); \
val1 = exp_ps_##_func_suffix(val1); \
val2 = exp_ps_##_func_suffix(val2); \
auto recipe1 = vaddq_##_func_suffix(one_val, val1); \
auto recipe2 = vaddq_##_func_suffix(one_val, val2); \
val1 = vrecpeq_##_func_suffix(recipe1); \
val2 = vrecpeq_##_func_suffix(recipe2); \
val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), val1); \
val2 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe2, val2), val2); \
return {{val1, val2}}; \ return {{val1, val2}}; \
} \ } \
}; };
......
...@@ -52,14 +52,7 @@ struct SigmoidOp; ...@@ -52,14 +52,7 @@ struct SigmoidOp;
return {{operator()(src.val[0]), operator()(src.val[1])}}; \ return {{operator()(src.val[0]), operator()(src.val[1])}}; \
} \ } \
_neon_type operator()(const _neon_type& src) const { \ _neon_type operator()(const _neon_type& src) const { \
auto zero_val = vdupq_n_##_func_suffix(0.f); \ return sigmoid_ps_##_func_suffix(src); \
auto one_val = vdupq_n_##_func_suffix(1.f); \
auto val1 = vsubq_##_func_suffix(zero_val, src); \
val1 = exp_ps_##_func_suffix(val1); \
auto recipe1 = vaddq_##_func_suffix(one_val, val1); \
val1 = vrecpeq_##_func_suffix(recipe1); \
val1 = vmulq_##_func_suffix(vrecpsq_##_func_suffix(recipe1, val1), val1); \
return val1; \
} \ } \
}; };
OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4) OP(dt_float32, float32x4_t, float32x4x2_t, f32, 4)
......
...@@ -318,7 +318,7 @@ def test_add_remove_output(): ...@@ -318,7 +318,7 @@ def test_add_remove_output():
out = g.run(a.numpy(), b.numpy()) out = g.run(a.numpy(), b.numpy())
np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy()) np.testing.assert_equal(out["new_o1"], ((a + b) * 3).numpy())
np.testing.assert_equal(out["new_o2"], (F.sigmoid((a + b))).numpy()) np.testing.assert_almost_equal(out["new_o2"], (F.sigmoid((a + b))).numpy())
def test_query(): def test_query():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册