提交 66745883 编写于 作者: H hjchen2

Further optimize softmax, reduce memory access

上级 8daf85c8
...@@ -65,12 +65,14 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) { ...@@ -65,12 +65,14 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) {
// find max // find max
float max = find_max(input, num_classes); float max = find_max(input, num_classes);
// exp(x - max) // exp(x - max) and sum(exp(x - max))
int remain = num_classes; int remain = num_classes;
float sum = 0.f;
#if defined(__ARM_NEON) || defined(__ARM_NEON__) #if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3; int loop = num_classes >> 3;
remain = num_classes & 0x7; remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max); float32x4_t __max = vdupq_n_f32(max);
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, input += 8, output += 8) { for (int i = 0; i < loop; ++i, input += 8, output += 8) {
float32x4_t x0 = vld1q_f32(input); float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4); float32x4_t x1 = vld1q_f32(input + 4);
...@@ -78,29 +80,17 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) { ...@@ -78,29 +80,17 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) {
x1 = vsubq_f32(x1, __max); x1 = vsubq_f32(x1, __max);
x0 = exp_ps(x0); x0 = exp_ps(x0);
x1 = exp_ps(x1); x1 = exp_ps(x1);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
output[i] = expf(input[i] - max);
}
// sum(exp(x - max))
float sum = 0.f;
output = y;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
__sum = vaddq_f32(x0, __sum); __sum = vaddq_f32(x0, __sum);
__sum = vaddq_f32(x1, __sum); __sum = vaddq_f32(x1, __sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
} }
sum += vaddvq_f32(__sum); sum += vaddvq_f32(__sum);
#endif // __ARM_NEON__ #endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) { for (int i = 0; i < remain; ++i) {
sum += output[i]; float out = expf(input[i] - max);
sum += out;
output[i] = out;
} }
// exp(x - max) / sum // exp(x - max) / sum
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册