From 04a7abda0a9dd576e1be7682f73120e3ddba2b55 Mon Sep 17 00:00:00 2001 From: hjchen2 Date: Thu, 10 Jan 2019 23:06:42 +0800 Subject: [PATCH] Further optimize softmax, reduce memory access --- src/operators/math/softmax.cpp | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/src/operators/math/softmax.cpp b/src/operators/math/softmax.cpp index e905ff2564..6b34f522ff 100644 --- a/src/operators/math/softmax.cpp +++ b/src/operators/math/softmax.cpp @@ -65,12 +65,14 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) { // find max float max = find_max(input, num_classes); - // exp(x - max) + // exp(x - max) and sum(exp(x - max)) int remain = num_classes; + float sum = 0.f; #if defined(__ARM_NEON) || defined(__ARM_NEON__) int loop = num_classes >> 3; remain = num_classes & 0x7; float32x4_t __max = vdupq_n_f32(max); + float32x4_t __sum = vdupq_n_f32(0.f); for (int i = 0; i < loop; ++i, input += 8, output += 8) { float32x4_t x0 = vld1q_f32(input); float32x4_t x1 = vld1q_f32(input + 4); @@ -78,29 +80,17 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) { x1 = vsubq_f32(x1, __max); x0 = exp_ps(x0); x1 = exp_ps(x1); - vst1q_f32(output, x0); - vst1q_f32(output + 4, x1); - } -#endif // __ARM_NEON__ - for (int i = 0; i < remain; ++i) { - output[i] = expf(input[i] - max); - } - - // sum(exp(x - max)) - float sum = 0.f; - output = y; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - float32x4_t __sum = vdupq_n_f32(0.f); - for (int i = 0; i < loop; ++i, output += 8) { - float32x4_t x0 = vld1q_f32(output); - float32x4_t x1 = vld1q_f32(output + 4); __sum = vaddq_f32(x0, __sum); __sum = vaddq_f32(x1, __sum); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); } sum += vaddvq_f32(__sum); #endif // __ARM_NEON__ for (int i = 0; i < remain; ++i) { - sum += output[i]; + float out = expf(input[i] - max); + sum += out; + output[i] = out; } // exp(x - max) / sum -- GitLab