diff --git a/src/operators/math/softmax.cpp b/src/operators/math/softmax.cpp index e905ff25643ee120e1630e02702a91286c7c2b41..6b34f522ff6caf32c20971d9cf38f93730fdb727 100644 --- a/src/operators/math/softmax.cpp +++ b/src/operators/math/softmax.cpp @@ -65,12 +65,14 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) { // find max float max = find_max(input, num_classes); - // exp(x - max) + // exp(x - max) and sum(exp(x - max)) int remain = num_classes; + float sum = 0.f; #if defined(__ARM_NEON) || defined(__ARM_NEON__) int loop = num_classes >> 3; remain = num_classes & 0x7; float32x4_t __max = vdupq_n_f32(max); + float32x4_t __sum = vdupq_n_f32(0.f); for (int i = 0; i < loop; ++i, input += 8, output += 8) { float32x4_t x0 = vld1q_f32(input); float32x4_t x1 = vld1q_f32(input + 4); @@ -78,29 +80,17 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) { x1 = vsubq_f32(x1, __max); x0 = exp_ps(x0); x1 = exp_ps(x1); - vst1q_f32(output, x0); - vst1q_f32(output + 4, x1); - } -#endif // __ARM_NEON__ - for (int i = 0; i < remain; ++i) { - output[i] = expf(input[i] - max); - } - - // sum(exp(x - max)) - float sum = 0.f; - output = y; -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - float32x4_t __sum = vdupq_n_f32(0.f); - for (int i = 0; i < loop; ++i, output += 8) { - float32x4_t x0 = vld1q_f32(output); - float32x4_t x1 = vld1q_f32(output + 4); __sum = vaddq_f32(x0, __sum); __sum = vaddq_f32(x1, __sum); + vst1q_f32(output, x0); + vst1q_f32(output + 4, x1); } sum += vaddvq_f32(__sum); #endif // __ARM_NEON__ for (int i = 0; i < remain; ++i) { - sum += output[i]; + float out = expf(input[i] - max); + sum += out; + output[i] = out; } // exp(x - max) / sum