提交 66745883 编写于 作者: H hjchen2

Further optimize softmax, reduce memory access

上级 8daf85c8
......@@ -65,12 +65,14 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) {
// find max
float max = find_max(input, num_classes);
// exp(x - max)
// exp(x - max) and sum(exp(x - max))
int remain = num_classes;
float sum = 0.f;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
int loop = num_classes >> 3;
remain = num_classes & 0x7;
float32x4_t __max = vdupq_n_f32(max);
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, input += 8, output += 8) {
float32x4_t x0 = vld1q_f32(input);
float32x4_t x1 = vld1q_f32(input + 4);
......@@ -78,29 +80,17 @@ void SoftmaxBasic(const float *input, int num_classes, float *y) {
x1 = vsubq_f32(x1, __max);
x0 = exp_ps(x0);
x1 = exp_ps(x1);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
output[i] = expf(input[i] - max);
}
// sum(exp(x - max))
float sum = 0.f;
output = y;
#if defined(__ARM_NEON) || defined(__ARM_NEON__)
float32x4_t __sum = vdupq_n_f32(0.f);
for (int i = 0; i < loop; ++i, output += 8) {
float32x4_t x0 = vld1q_f32(output);
float32x4_t x1 = vld1q_f32(output + 4);
__sum = vaddq_f32(x0, __sum);
__sum = vaddq_f32(x1, __sum);
vst1q_f32(output, x0);
vst1q_f32(output + 4, x1);
}
sum += vaddvq_f32(__sum);
#endif // __ARM_NEON__
for (int i = 0; i < remain; ++i) {
sum += output[i];
float out = expf(input[i] - max);
sum += out;
output[i] = out;
}
// exp(x - max) / sum
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册