Commit 3a631fbb authored by Chunwei

prune

Parent 0f9e7057
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/activation.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void act_relu<float>(const float* din, float* dout, int size, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt = nums_per_thread >> 4;
int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
for (int i = 0; i < threads; ++i) {
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
int cnt = neon_loop_cnt;
#ifdef __aarch64__
for (int num = 0; num < neon_loop_cnt; ++num) {
float32x4_t vr0 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr1 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr2 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr3 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
vr0 = vmaxq_f32(vr0, vzero);
vr1 = vmaxq_f32(vr1, vzero);
vr2 = vmaxq_f32(vr2, vzero);
vr3 = vmaxq_f32(vr3, vzero);
vst1q_f32(ptr_out_thread, vr0);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vr1);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vr2);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vr3);
ptr_out_thread += 4;
}
#else
if (cnt > 0) {
asm volatile(
"1: @ loop header\n"
"vld1.32 {d0-d3}, [%[din]]! @ load din 0\n"
"vld1.32 {d4-d7}, [%[din]]! @ load din 0\n"
"vmax.f32 q8, q0, %q[vzero] @ relu\n"
"vmax.f32 q9, q1, %q[vzero] @ relu\n"
"vmax.f32 q10, q2, %q[vzero] @ relu\n"
"vmax.f32 q11, q3, %q[vzero] @ relu\n"
"vst1.32 {d16-d19}, [%[dout]]! @ store result, add pointer\n"
"vst1.32 {d20-d23}, [%[dout]]! @ store result, add pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"bne 1b @ jump to main loop start "
"point\n"
: [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero)
: "cc", "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11");
}
#endif
for (int j = 0; j < neon_loop_remain; ++j) {
ptr_out_thread[0] = ptr_in_thread[0] > 0.f ? ptr_in_thread[0] : 0.f;
ptr_in_thread++;
ptr_out_thread++;
}
}
float* out_ptr_remain = dout + threads * nums_per_thread;
const float* in_ptr_remain = din + threads * nums_per_thread;
for (int j = 0; j < remain; ++j) {
out_ptr_remain[0] = in_ptr_remain[0] > 0.f ? in_ptr_remain[0] : 0.f;
in_ptr_remain++;
out_ptr_remain++;
}
}
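// Scalar reference for act_relu<float> above: an illustrative sketch of the
// semantics the NEON and assembly paths must reproduce (relu(x) = max(x, 0)).
// The helper name act_relu_ref is hypothetical and not used by the kernels.
static inline void act_relu_ref(const float* din, float* dout, int size) {
  for (int i = 0; i < size; ++i) {
    dout[i] = din[i] > 0.f ? din[i] : 0.f;
  }
}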
template <>
void act_relu_neg<float>(const float* din, float* dout, int size,
const float negative_slope, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt = nums_per_thread >> 4;
int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t valpha = vdupq_n_f32(negative_slope);
#pragma omp parallel for
for (int i = 0; i < threads; ++i) {
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
int cnt = neon_loop_cnt;
#ifdef __aarch64__
for (int num = 0; num < neon_loop_cnt; ++num) {
float32x4_t vr0 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr1 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr2 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr3 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
uint32x4_t vm0 = vcgeq_f32(vr0, vzero);
uint32x4_t vm1 = vcgeq_f32(vr1, vzero);
uint32x4_t vm2 = vcgeq_f32(vr2, vzero);
uint32x4_t vm3 = vcgeq_f32(vr3, vzero);
float32x4_t vn0 = vmulq_f32(vr0, valpha);
float32x4_t vn1 = vmulq_f32(vr1, valpha);
float32x4_t vn2 = vmulq_f32(vr2, valpha);
float32x4_t vn3 = vmulq_f32(vr3, valpha);
float32x4_t vo0 = vbslq_f32(vm0, vr0, vn0);
float32x4_t vo1 = vbslq_f32(vm1, vr1, vn1);
float32x4_t vo2 = vbslq_f32(vm2, vr2, vn2);
float32x4_t vo3 = vbslq_f32(vm3, vr3, vn3);
vst1q_f32(ptr_out_thread, vo0);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vo1);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vo2);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vo3);
ptr_out_thread += 4;
}
#else
if (cnt > 0) {
asm volatile(
"1: @ loop header\n"
"vld1.32 {d0-d3}, [%[din]]! @ load din 0\n"
"vld1.32 {d4-d7}, [%[din]]! @ load din 0\n"
"vcge.f32 q8, q0, %q[vzero] @ get mask\n"
"vcge.f32 q9, q1, %q[vzero] @ get mask\n"
"vcge.f32 q10, q2, %q[vzero] @ get mask\n"
"vcge.f32 q11, q3, %q[vzero] @ get mask\n"
"vmul.f32 q4, q0, %q[valpha] @ get neg data\n"
"vmul.f32 q5, q1, %q[valpha] @ get neg data\n"
"vmul.f32 q6, q2, %q[valpha] @ get neg data\n"
"vmul.f32 q7, q3, %q[valpha] @ get neg data\n"
"vbit q4, q0, q8 @ bitsel, insert q0 to q4, "
"if q8 is 1\n"
"vbit q5, q1, q9 @ bitsel, insert q1 to q5, "
"if q9 is 1\n"
"vbit q6, q2, q10 @ bitsel, insert q2 to q6, "
"if q10 is 1\n"
"vbit q7, q3, q11 @ bitsel, insert q3 to q7, "
"if q11 is 1\n"
"vst1.32 {d8-d11}, [%[dout]]! @ store result, add pointer\n"
"vst1.32 {d12-d15}, [%[dout]]! @ store result, add pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"bne 1b @ jump to main loop start "
"point\n"
: [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero), [valpha] "w"(valpha)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11");
}
#endif
for (int j = 0; j < neon_loop_remain; ++j) {
ptr_out_thread[0] = ptr_in_thread[0] > 0.f
? ptr_in_thread[0]
: ptr_in_thread[0] * negative_slope;
ptr_in_thread++;
ptr_out_thread++;
}
}
float* out_ptr_remain = dout + threads * nums_per_thread;
const float* in_ptr_remain = din + threads * nums_per_thread;
for (int j = 0; j < remain; ++j) {
out_ptr_remain[0] = in_ptr_remain[0] > 0.f
? in_ptr_remain[0]
: in_ptr_remain[0] * negative_slope;
in_ptr_remain++;
out_ptr_remain++;
}
}
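// Scalar reference for act_relu_neg<float> above (leaky ReLU): y = x when
// x >= 0, y = negative_slope * x otherwise. The NEON path builds the mask with
// vcgeq_f32 and selects between x and slope * x with vbslq_f32; this
// hypothetical helper states the same contract without vectorization.
static inline void act_relu_neg_ref(const float* din, float* dout, int size,
                                    float negative_slope) {
  for (int i = 0; i < size; ++i) {
    dout[i] = din[i] >= 0.f ? din[i] : din[i] * negative_slope;
  }
}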
template <>
void act_clipped_relu<float>(const float* din, float* dout, int size,
const float coef, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt = nums_per_thread >> 4;
int neon_loop_remain = nums_per_thread - (neon_loop_cnt << 4);
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t vclip = vdupq_n_f32(coef);
#pragma omp parallel for
for (int i = 0; i < threads; ++i) {
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
int cnt = neon_loop_cnt;
#ifdef __aarch64__
for (int num = 0; num < neon_loop_cnt; ++num) {
float32x4_t vr0 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr1 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr2 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vr3 = vld1q_f32(ptr_in_thread);
ptr_in_thread += 4;
float32x4_t vt0 = vmaxq_f32(vr0, vzero);
float32x4_t vt1 = vmaxq_f32(vr1, vzero);
float32x4_t vt2 = vmaxq_f32(vr2, vzero);
float32x4_t vt3 = vmaxq_f32(vr3, vzero);
float32x4_t vo0 = vminq_f32(vt0, vclip);
float32x4_t vo1 = vminq_f32(vt1, vclip);
float32x4_t vo2 = vminq_f32(vt2, vclip);
float32x4_t vo3 = vminq_f32(vt3, vclip);
vst1q_f32(ptr_out_thread, vo0);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vo1);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vo2);
ptr_out_thread += 4;
vst1q_f32(ptr_out_thread, vo3);
ptr_out_thread += 4;
}
#else
if (cnt > 0) {
asm volatile(
"1: @ loop header\n"
"vld1.32 {d0-d3}, [%[din]]! @ load din 0\n"
"vld1.32 {d4-d7}, [%[din]]! @ load din 0\n"
"vmax.f32 q8, q0, %q[vzero] @ relu\n"
"vmax.f32 q9, q1, %q[vzero] @ relu\n"
"vmax.f32 q10, q2, %q[vzero] @ relu\n"
"vmax.f32 q11, q3, %q[vzero] @ relu\n"
"vmin.f32 q4, q8, %q[vclip] @ clip relu\n"
"vmin.f32 q5, q9, %q[vclip] @ clip relu\n"
"vmin.f32 q6, q10, %q[vclip] @ clip relu\n"
"vmin.f32 q7, q11, %q[vclip] @ clip relu\n"
"vst1.32 {d8-d11}, [%[dout]]! @ store result, add pointer\n"
"vst1.32 {d12-d15}, [%[dout]]! @ store result, add pointer\n"
"subs %[cnt], #1 @ loop count minus 1\n"
"bne 1b @ jump to main loop start "
"point\n"
: [dout] "+r"(ptr_out_thread), [din] "+r"(ptr_in_thread),
[cnt] "+r"(cnt)
: [vzero] "w"(vzero), [vclip] "w"(vclip)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11");
}
#endif
for (int j = 0; j < neon_loop_remain; ++j) {
ptr_out_thread[0] = ptr_in_thread[0] > 0.f ? ptr_in_thread[0] : 0.f;
ptr_out_thread[0] = ptr_out_thread[0] < coef ? ptr_out_thread[0] : coef;
ptr_in_thread++;
ptr_out_thread++;
}
}
float* out_ptr_remain = dout + threads * nums_per_thread;
const float* in_ptr_remain = din + threads * nums_per_thread;
for (int j = 0; j < remain; ++j) {
out_ptr_remain[0] = in_ptr_remain[0] > 0.f ? in_ptr_remain[0] : 0.f;
out_ptr_remain[0] = out_ptr_remain[0] < coef ? out_ptr_remain[0] : coef;
in_ptr_remain++;
out_ptr_remain++;
}
}
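// Scalar reference for act_clipped_relu<float> above:
// y = min(max(x, 0), coef), i.e. ReLU6 when coef == 6. Hypothetical
// illustrative helper only.
static inline void act_clipped_relu_ref(const float* din, float* dout,
                                        int size, float coef) {
  for (int i = 0; i < size; ++i) {
    float t = din[i] > 0.f ? din[i] : 0.f;
    dout[i] = t < coef ? t : coef;
  }
}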
template <>
void act_prelu<float>(const float* din, float* dout, int outer_size,
int channel_size, int inner_size, bool channel_shared,
float* channel_slope, int threads) {
int stride_size = inner_size * channel_size;
int cnt = inner_size >> 4;
int remain = inner_size & 15;
float32x4_t vzero = vdupq_n_f32(0.f);
for (int n = 0; n < outer_size; n++) {
const float* data_in_batch = din + n * stride_size;
float* data_out_batch = dout + n * stride_size;
#pragma omp parallel for
for (int c = 0; c < channel_size; c++) {
const float* data_in_c = data_in_batch + c * inner_size;
float* data_out_c = data_out_batch + c * inner_size;
float slope = channel_shared ? channel_slope[0] : channel_slope[c];
float32x4_t vslope = vdupq_n_f32(slope);
#ifdef __aarch64__
for (int i = 0; i < cnt; ++i) {
float32x4_t vr0 = vld1q_f32(data_in_c);
float32x4_t vr1 = vld1q_f32(data_in_c + 4);
float32x4_t vr2 = vld1q_f32(data_in_c + 8);
float32x4_t vr3 = vld1q_f32(data_in_c + 12);
uint32x4_t vm0 = vcltq_f32(vr0, vzero);    // vr0 < 0
uint32x4_t vm1 = vcltq_f32(vr1, vzero);    // vr1 < 0
uint32x4_t vm2 = vcltq_f32(vr2, vzero);    // vr2 < 0
uint32x4_t vm3 = vcltq_f32(vr3, vzero);    // vr3 < 0
float32x4_t vo0 = vmulq_f32(vr0, vslope);  // vr0 * vslope
float32x4_t vo1 = vmulq_f32(vr1, vslope);  // vr1 * vslope
float32x4_t vo2 = vmulq_f32(vr2, vslope);  // vr2 * vslope
float32x4_t vo3 = vmulq_f32(vr3, vslope);  // vr3 * vslope
float32x4_t vos0 = vbslq_f32(vm0, vo0, vr0);
float32x4_t vos1 = vbslq_f32(vm1, vo1, vr1);
float32x4_t vos2 = vbslq_f32(vm2, vo2, vr2);
float32x4_t vos3 = vbslq_f32(vm3, vo3, vr3);
vst1q_f32(data_out_c, vos0);
vst1q_f32(data_out_c + 4, vos1);
vst1q_f32(data_out_c + 8, vos2);
vst1q_f32(data_out_c + 12, vos3);
data_in_c += 16;
data_out_c += 16;
}
#else
int cnt_loop = cnt;
if (cnt_loop > 0) {
asm volatile(
"vld1.32 {d0-d3}, [%[ptr_in]]! @ load "
"input to q0, q1\n"
"pld [%[ptr_in]] @ preload\n"
"pld [%[ptr_in], #64] @ preload\n"
"pld [%[ptr_in], #128] @ preload\n"
"pld [%[ptr_in], #192] @ preload\n"
"1: @main loop\n"
"vld1.32 {d4-d7}, [%[ptr_in]]! @ load input to "
"q2, q3\n"
"vclt.f32 q8, q0, %q[vzero] @vcle q0 <= vzero\n"
"vclt.f32 q9, q1, %q[vzero] @vcle q1 <= vzero\n"
"vmul.f32 q10, q0, %q[vslope] @vmul q0 * vslope\n"
"vmul.f32 q11, q1, %q[vslope] @vmul q1 * vslope\n"
"vclt.f32 q12, q2, %q[vzero] @vcle q2 <= vzero\n"
"vclt.f32 q13, q3, %q[vzero] @vcle q3 <= vzero\n"
"vmul.f32 q14, q2, %q[vslope] @vmul q2 * vslope\n"
"vmul.f32 q15, q3, %q[vslope] @vmul q3 * vslope\n"
"vbif.32 q10, q0, q8 @vbit q10, q0, q8\n"
"vbif.32 q11, q1, q9 @vbit q11, q1, q9\n"
"vbif.32 q14, q2, q12 @vbit q14, q2, "
"q12\n"
"vbif.32 q15, q3, q13 @vbit q15, q3, "
"q13\n"
"subs %[cnt], #1 @subs nn, 1\n"
"vld1.32 {d0-d3}, [%[ptr_in]]! @ load input to "
"q0, q1\n"
"vst1.f32 {d20-d23}, [%[dout]]! @store data\n"
"vst1.f32 {d28-d31}, [%[dout]]! @store data\n"
"bne 1b @bne nn\n"
"sub %[ptr_in], #32 @ ptr-32\n"
: [ptr_in] "+r"(data_in_c), [cnt] "+r"(cnt_loop),
[dout] "+r"(data_out_c)
: [vzero] "w"(vzero), [vslope] "w"(vslope)
: "cc", "memory", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11",
"q12", "q13", "q14", "q15");
}
#endif // __aarch64__
for (int i = remain; i > 0; i--) {
*(data_out_c++) =
data_in_c[0] > 0.f ? data_in_c[0] : data_in_c[0] * slope;
data_in_c++;
}
}
}
}
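// Scalar reference for act_prelu<float> above, assuming NCHW layout:
// outer_size is the batch count, inner_size == H * W, and channel c uses
// channel_slope[c] (or channel_slope[0] when channel_shared is true).
// Hypothetical illustrative helper only.
static inline void act_prelu_ref(const float* din, float* dout, int outer_size,
                                 int channel_size, int inner_size,
                                 bool channel_shared,
                                 const float* channel_slope) {
  for (int n = 0; n < outer_size; ++n) {
    for (int c = 0; c < channel_size; ++c) {
      float slope = channel_shared ? channel_slope[0] : channel_slope[c];
      const float* in = din + (n * channel_size + c) * inner_size;
      float* out = dout + (n * channel_size + c) * inner_size;
      for (int i = 0; i < inner_size; ++i) {
        out[i] = in[i] > 0.f ? in[i] : in[i] * slope;
      }
    }
  }
}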
template <>
void act_sigmoid<float>(const float* din, float* dout, int size, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
#pragma omp parallel for
for (int i = 0; i < threads; ++i) {
float32x4_t exp_vec = vdupq_n_f32(0.0f);
float32x4_t recip = vdupq_n_f32(0.0f);
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
exp_vec = exp_ps(vnegq_f32(vld1q_f32(ptr_in_thread)));
exp_vec = vaddq_f32(exp_vec, vdupq_n_f32(1.0f));
recip = vrecpeq_f32(exp_vec);
recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip);
recip = vmulq_f32(vrecpsq_f32(exp_vec, recip), recip);
vst1q_f32(ptr_out_thread, recip);
ptr_out_thread += 4;
ptr_in_thread += 4;
}
for (int j = 0; j < neon_loop_remain_dim4; ++j) {
ptr_out_thread[0] = 1.f / (1 + expf(-ptr_in_thread[0]));
ptr_in_thread++;
ptr_out_thread++;
}
}
float* ptr_out = dout + threads * nums_per_thread;
const float* ptr_in = din + threads * nums_per_thread;
for (int j = 0; j < remain; ++j) {
ptr_out[0] = 1.f / (1 + expf(-ptr_in[0]));
ptr_in++;
ptr_out++;
}
}
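// Scalar reference for act_sigmoid<float> above: y = 1 / (1 + exp(-x)). The
// vector path approximates the division with vrecpeq_f32 followed by two
// Newton-Raphson refinements (vrecpsq_f32(e, r) computes 2 - e * r, so
// r' = r * (2 - e * r) converges toward 1 / e). Hypothetical helper only.
static inline void act_sigmoid_ref(const float* din, float* dout, int size) {
  for (int i = 0; i < size; ++i) {
    dout[i] = 1.f / (1.f + expf(-din[i]));
  }
}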
// tanh : (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <>
void act_tanh<float>(const float* din, float* dout, int size, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
#pragma omp parallel for
for (int i = 0; i < threads; ++i) {
float32x4_t exp_plus_vec = vdupq_n_f32(0.0f);
float32x4_t exp_minus_vec = vdupq_n_f32(0.0f);
float32x4_t exp_sum_vec = vdupq_n_f32(0.0f);
float32x4_t exp_diff_vec = vdupq_n_f32(0.0f);
float32x4_t recip = vdupq_n_f32(0.0f);
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
exp_plus_vec = exp_ps(vld1q_f32(ptr_in_thread));
exp_minus_vec = exp_ps(vnegq_f32(vld1q_f32(ptr_in_thread)));
exp_sum_vec = vaddq_f32(exp_plus_vec, exp_minus_vec);
exp_diff_vec = vsubq_f32(exp_plus_vec, exp_minus_vec);
recip = div_ps(exp_diff_vec, exp_sum_vec);
vst1q_f32(ptr_out_thread, recip);
ptr_out_thread += 4;
ptr_in_thread += 4;
}
for (int j = 0; j < neon_loop_remain_dim4; ++j) {
ptr_out_thread[0] = (expf(ptr_in_thread[0]) - expf(-ptr_in_thread[0])) /
(expf(ptr_in_thread[0]) + expf(-ptr_in_thread[0]));
ptr_in_thread++;
ptr_out_thread++;
}
}
float* ptr_out = dout + threads * nums_per_thread;
const float* ptr_in = din + threads * nums_per_thread;
for (int j = 0; j < remain; ++j) {
ptr_out[0] = (expf(ptr_in[0]) - expf(-ptr_in[0])) /
(expf(ptr_in[0]) + expf(-ptr_in[0]));
ptr_in++;
ptr_out++;
}
}
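// Scalar reference for act_tanh<float> above. Note that the two-sided form
// (e^x - e^-x) / (e^x + e^-x) can overflow e^x for large |x|; libm's tanhf is
// the numerically safe scalar form, used here as an illustrative sketch only.
static inline void act_tanh_ref(const float* din, float* dout, int size) {
  for (int i = 0; i < size; ++i) {
    dout[i] = tanhf(din[i]);
  }
}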
// swish: x / (1 + exp(-(b * x)))
template <>
void act_swish<float>(const float* din, float* dout, int size, const float coef,
int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
const float beta = coef;
float32x4_t vbeta = vdupq_n_f32(beta);
float32x4_t vone = vdupq_n_f32(1.f);
#pragma omp parallel for
for (int i = 0; i < threads; ++i) {
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
float32x4_t va = vld1q_f32(ptr_in_thread); // x
float32x4_t vb = vnegq_f32(vld1q_f32(ptr_in_thread)); // -x
float32x4_t vsum = vmulq_f32(vb, vbeta);
vsum = exp_ps(vsum);
float32x4_t vc = vaddq_f32(vone, vsum);
float32x4_t vrst = div_ps(va, vc);
vst1q_f32(ptr_out_thread, vrst);
ptr_out_thread += 4;
ptr_in_thread += 4;
}
for (int j = 0; j < neon_loop_remain_dim4; ++j) {
ptr_out_thread[0] =
ptr_in_thread[0] / (1.f + expf(-ptr_in_thread[0] * beta));
ptr_in_thread++;
ptr_out_thread++;
}
}
float* ptr_out = dout + threads * nums_per_thread;
const float* ptr_in = din + threads * nums_per_thread;
for (int j = 0; j < remain; ++j) {
ptr_out[0] = ptr_in[0] / (1.f + expf(-ptr_in[0] * beta));
ptr_in++;
ptr_out++;
}
}
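// Scalar reference for act_swish<float> above: y = x / (1 + exp(-beta * x)),
// i.e. x * sigmoid(beta * x). Hypothetical illustrative helper only.
static inline void act_swish_ref(const float* din, float* dout, int size,
                                 float beta) {
  for (int i = 0; i < size; ++i) {
    dout[i] = din[i] / (1.f + expf(-beta * din[i]));
  }
}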
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void act_relu(const T* din, T* dout, int size, int threads);
template <typename T>
void act_relu_neg(const T* din, T* dout, int size, const float negative_slope,
int threads);
template <typename T>
void act_clipped_relu(const T* din, T* dout, int size, const float coef,
int threads);
template <typename T>
void act_prelu(const T* din, T* dout, int outer_size, int channel_size,
int inner_size, bool channel_shared, float* channel_slope,
int threads);
template <typename T>
void act_sigmoid(const T* din, T* dout, int size, int threads);
template <typename T>
void act_tanh(const T* din, T* dout, int size, int threads);
template <typename T>
void act_swish(const T* din, T* dout, int size, const float coef, int threads);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/concat.h"
#include <algorithm>
#include <cstring>
#include <limits>
#include <memory>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void concat_func(const std::vector<lite::Tensor *> &input, const int axis,
lite::Tensor *output) {
const int num = static_cast<int>(input.size());
int rows = 1;
auto dim_0 = input[0]->dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (int i = 0; i < num; ++i) {
int t_cols = input[i]->numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
// computation
for (int k = 0; k < out_rows; ++k) {
float *dst_ptr = output->mutable_data<float>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
const float *src_ptr = input[j]->data<float>() + k * col_len;
std::memcpy(dst_ptr + col_idx, src_ptr, sizeof(float) * col_len);
col_idx += col_len;
}
}
}
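// Shape sketch for concat_func above: every input is viewed as a 2-D block of
// rows x t_cols, where rows is the product of the dims before `axis` and
// t_cols = numel / rows. Concatenation along `axis` then reduces to a row-wise
// copy of each block into its column slice of the out_rows x out_cols output.
// E.g. (illustrative numbers only) concatenating shapes {2, 3, 4} and
// {2, 5, 4} on axis == 1 gives rows == 2, input cols 12 and 20, out_cols == 32,
// and an output of shape {2, 8, 4}.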
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void concat_func(const std::vector<lite::Tensor *> &input, const int axis,
lite::Tensor *output);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/dropout.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void dropout_down<float>(const float* din, float* dout, int num, float prob) {
const float scale = 1.0f - prob;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vscale = vdupq_n_f32(scale);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* din_ptr = din + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vmul0 = vmulq_f32(din0, vscale);
float32x4_t vmul1 = vmulq_f32(din1, vscale);
float32x4_t vmul2 = vmulq_f32(din2, vscale);
float32x4_t vmul3 = vmulq_f32(din3, vscale);
vst1q_f32(dout_ptr, vmul0);
vst1q_f32(dout_ptr + 4, vmul1);
vst1q_f32(dout_ptr + 8, vmul2);
vst1q_f32(dout_ptr + 12, vmul3);
}
if (remain > 0) {
const float* din_ptr = din + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr * scale;
dout_ptr++;
din_ptr++;
}
}
}
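// Scalar reference for dropout_down<float> above: inference-time dropout that
// scales outputs by (1 - dropout_prob) instead of masking, which appears to
// correspond to Paddle's "downgrade_in_infer" convention. Hypothetical helper
// only.
static inline void dropout_down_ref(const float* din, float* dout, int num,
                                    float prob) {
  const float scale = 1.0f - prob;
  for (int i = 0; i < num; ++i) {
    dout[i] = din[i] * scale;
  }
}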
template <>
void dropout_up<float>(const float* din, float* dout, int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* din_ptr = din + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
}
if (remain > 0) {
const float* din_ptr = din + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr;
dout_ptr++;
din_ptr++;
}
}
}
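// Note: dropout_up<float> above is an identity copy. It appears to correspond
// to the upscale-at-training convention, where the 1 / (1 - prob) scaling is
// applied during training, so inference passes the input through unchanged.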
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void dropout_down(const T* din, T* dout, int num, float prob);
template <typename T>
void dropout_up(const T* din, T* dout, int num);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/elementwise.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void elementwise_add<float>(const float* dinx, const float* diny, float* dout,
int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* dinx_ptr = dinx + (i << 4);
const float* diny_ptr = diny + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t dinx0 = vld1q_f32(dinx_ptr);
float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
float32x4_t diny0 = vld1q_f32(diny_ptr);
float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
dinx0 = vaddq_f32(dinx0, diny0);
dinx1 = vaddq_f32(dinx1, diny1);
dinx2 = vaddq_f32(dinx2, diny2);
dinx3 = vaddq_f32(dinx3, diny3);
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
const float* diny_ptr = diny + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *dinx_ptr + *diny_ptr;
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}
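// Scalar reference for elementwise_add<float> above: a plain per-element sum
// over two equally sized buffers, which the 16-wide NEON loop plus remainder
// loop must reproduce. Hypothetical illustrative helper only.
static inline void elementwise_add_ref(const float* dinx, const float* diny,
                                       float* dout, int num) {
  for (int i = 0; i < num; ++i) {
    dout[i] = dinx[i] + diny[i];
  }
}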
template <>
void elementwise_add_relu<float>(const float* dinx, const float* diny,
float* dout, int num) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* dinx_ptr = dinx + (i << 4);
const float* diny_ptr = diny + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t dinx0 = vld1q_f32(dinx_ptr);
float32x4_t dinx1 = vld1q_f32(dinx_ptr + 4);
float32x4_t dinx2 = vld1q_f32(dinx_ptr + 8);
float32x4_t dinx3 = vld1q_f32(dinx_ptr + 12);
float32x4_t diny0 = vld1q_f32(diny_ptr);
float32x4_t diny1 = vld1q_f32(diny_ptr + 4);
float32x4_t diny2 = vld1q_f32(diny_ptr + 8);
float32x4_t diny3 = vld1q_f32(diny_ptr + 12);
dinx0 = vaddq_f32(dinx0, diny0);
dinx1 = vaddq_f32(dinx1, diny1);
dinx2 = vaddq_f32(dinx2, diny2);
dinx3 = vaddq_f32(dinx3, diny3);
// relu
dinx0 = vmaxq_f32(dinx0, vzero);
dinx1 = vmaxq_f32(dinx1, vzero);
dinx2 = vmaxq_f32(dinx2, vzero);
dinx3 = vmaxq_f32(dinx3, vzero);
vst1q_f32(dout_ptr, dinx0);
vst1q_f32(dout_ptr + 4, dinx1);
vst1q_f32(dout_ptr + 8, dinx2);
vst1q_f32(dout_ptr + 12, dinx3);
}
if (remain > 0) {
const float* dinx_ptr = dinx + (cnt << 4);
const float* diny_ptr = diny + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
float tmp = *dinx_ptr + *diny_ptr;
*dout_ptr = tmp > 0.f ? tmp : 0.f;
dout_ptr++;
dinx_ptr++;
diny_ptr++;
}
}
}
template <>
void elementwise_add_broadcast<float>(const float* dinx, const float* diny,
float* dout, int batch, int channels,
int num) {
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
din2 = vaddq_f32(din2, rb);
din3 = vaddq_f32(din3, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vaddq_f32(din0, rb);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
*dout_ptr = *din_ptr + diny_data;
dout_ptr++;
din_ptr++;
}
}
}
}
}
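// Scalar reference for elementwise_add_broadcast<float> above: x has shape
// (batch, channels, num) and y has shape (channels,); y[j] is broadcast over
// the trailing `num` elements of channel j. Hypothetical helper only.
static inline void elementwise_add_broadcast_ref(const float* dinx,
                                                 const float* diny, float* dout,
                                                 int batch, int channels,
                                                 int num) {
  for (int i = 0; i < batch; ++i) {
    for (int j = 0; j < channels; ++j) {
      const float* x = dinx + (i * channels + j) * num;
      float* y = dout + (i * channels + j) * num;
      for (int k = 0; k < num; ++k) {
        y[k] = x[k] + diny[j];
      }
    }
  }
}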
template <>
void elementwise_add_relu_broadcast<float>(const float* dinx, const float* diny,
float* dout, int batch, int channels,
int num) {
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for collapse(2)
for (int i = 0; i < batch; ++i) {
for (int j = 0; j < channels; ++j) {
int offset = (i * channels + j) * num;
const float* din_ptr = dinx + offset;
const float diny_data = diny[j];
float* dout_ptr = dout + offset;
int cnt = num >> 4;
int remain = num % 16;
float32x4_t rb = vdupq_n_f32(diny_data);
for (int k = 0; k < cnt; ++k) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
din2 = vaddq_f32(din2, rb);
din3 = vaddq_f32(din3, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
din1 = vmaxq_f32(din1, vzero);
din2 = vmaxq_f32(din2, vzero);
din3 = vmaxq_f32(din3, vzero);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
din_ptr += 16;
dout_ptr += 16;
}
if (remain >= 8) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
din0 = vaddq_f32(din0, rb);
din1 = vaddq_f32(din1, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
din1 = vmaxq_f32(din1, vzero);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
din_ptr += 8;
dout_ptr += 8;
remain -= 8;
}
if (remain >= 4) {
float32x4_t din0 = vld1q_f32(din_ptr);
din0 = vaddq_f32(din0, rb);
// relu
din0 = vmaxq_f32(din0, vzero);
vst1q_f32(dout_ptr, din0);
din_ptr += 4;
dout_ptr += 4;
remain -= 4;
}
if (remain > 0) {
for (int p = 0; p < remain; p++) {
float tmp = *din_ptr + diny_data;
*dout_ptr = tmp > 0.f ? tmp : 0.f;
dout_ptr++;
din_ptr++;
}
}
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_add_relu(const T* dinx, const T* diny, T* dout, int num);
template <typename T>
void elementwise_add_broadcast(const T* dinx, const T* diny, T* dout, int batch,
int channels, int num);
template <typename T>
void elementwise_add_relu_broadcast(const T* dinx, const T* diny, T* dout,
int batch, int channels, int num);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/pooling.h"
#include <algorithm>
#include <limits>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void pooling_basic(const float* din, float* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type) {
// No explicit padding of the input tensor is needed; zero padding at the
// borders is handled inside this function.
int kernel_h = ksize[0];
int kernel_w = ksize[1];
int stride_h = strides[0];
int stride_w = strides[1];
int pad_h = paddings[0];
int pad_w = paddings[1];
int size_channel_in = win * hin;
int size_channel_out = wout * hout;
if (global_pooling) {
if (pooling_type == "max") { // Pooling_max
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; ++c) {
const float* din_ch = din_batch + c * size_channel_in; // in address
float tmp1 = din_ch[0];
for (int i = 0; i < size_channel_in; ++i) {
float tmp2 = din_ch[i];
tmp1 = tmp1 > tmp2 ? tmp1 : tmp2;
}
dout_batch[c] = tmp1;
}
}
} else if (pooling_type == "avg") {
// Pooling_average_include_padding
// Pooling_average_exclude_padding
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; ++c) {
const float* din_ch = din_batch + c * size_channel_in; // in address
float sum = 0.f;
for (int i = 0; i < size_channel_in; ++i) {
sum += din_ch[i];
}
dout_batch[c] = sum / size_channel_in;
}
}
} else {
LOG(FATAL) << "unsupported pooling type: " << pooling_type;
}
} else {
if (pooling_type == "max") {
// Pooling_max
for (int n = 0; n < num; ++n) {
float* dout_ch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_row = dout_ch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
for (int i = 0; i < hout; i++) {
for (int j = 0; j < wout; j++) {
int hstart = i * stride_h - pad_h;
int wstart = j * stride_w - pad_w;
int hend = std::min(hstart + kernel_h, hin + pad_h);
int wend = std::min(wstart + kernel_w, win + pad_w);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, hin);
wend = std::min(wend, win);
int pool_size = (hend - hstart) * (wend - wstart);
if (pool_size == 0) continue;
float tmp1 = din_ch[hstart * win + wstart];
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
float tmp2 = din_ch[h * win + w];
tmp1 = tmp1 > tmp2 ? tmp1 : tmp2;
}
}
dout_row[j] = tmp1;
}
dout_row += wout;
}
}
}
} else if (pooling_type == "avg") {
if (exclusive) {
// Pooling_average_exclude_padding
for (int n = 0; n < num; ++n) {
float* dout_ch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_row = dout_ch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
for (int i = 0; i < hout; i++) {
for (int j = 0; j < wout; j++) {
int hstart = i * stride_h - pad_h;
int wstart = j * stride_w - pad_w;
int hend = std::min(hstart + kernel_h, hin + pad_h);
int wend = std::min(wstart + kernel_w, win + pad_w);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, hin);
wend = std::min(wend, win);
int pool_size = (hend - hstart) * (wend - wstart);
if (pool_size == 0) continue;
float sum = 0.f;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
sum += din_ch[h * win + w];
}
}
dout_row[j] = sum / pool_size;
}
dout_row += wout;
}
}
}
} else { // Pooling_average_include_padding
for (int n = 0; n < num; ++n) {
float* dout_ch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_row = dout_ch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
for (int i = 0; i < hout; i++) {
for (int j = 0; j < wout; j++) {
int hstart = i * stride_h - pad_h;
int wstart = j * stride_w - pad_w;
int hend = std::min(hstart + kernel_h, hin + pad_h);
int wend = std::min(wstart + kernel_w, win + pad_w);
hstart = std::max(hstart, 0);
wstart = std::max(wstart, 0);
hend = std::min(hend, hin);
wend = std::min(wend, win);
int pool_size = (hend - hstart) * (wend - wstart);
if (pool_size == 0) continue;
float sum = 0.f;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
sum += din_ch[h * win + w];
}
}
dout_row[j] = sum / (kernel_w * kernel_h);
}
dout_row += wout;
}
}
}
}
} else {
LOG(FATAL) << "unsupported pooling type: " << pooling_type;
}
}
}
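// Window sketch for pooling_basic above: each output index maps to a clamped
// input window; along H, start = max(i * stride_h - pad_h, 0) and
// end = min(i * stride_h - pad_h + kernel_h, hin), and likewise along W.
// Exclusive averaging divides by the clamped area
// (hend - hstart) * (wend - wstart); inclusive averaging always divides by
// kernel_h * kernel_w, counting the zero padding. Hypothetical helper, shown
// for clarity only.
static inline void pool_window_ref(int i, int stride, int pad, int kernel,
                                   int limit, int* start, int* end) {
  *start = std::max(i * stride - pad, 0);
  *end = std::min(i * stride - pad + kernel, limit);
}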
void pooling_global_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win) {
int size_channel_in = win * hin;
int cnt = size_channel_in / 8;
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; ++c) {
const float* din_ch = din_batch + c * size_channel_in;
int i = 0;
float minval = std::numeric_limits<float>::lowest();
float32x4_t vmax = vdupq_n_f32(minval);
#ifdef __aarch64__
for (; i < cnt; i++) {
float32x4_t vdin1 = vld1q_f32(din_ch);
vmax = vmaxq_f32(vdin1, vmax);
float32x4_t vdin2 = vld1q_f32(din_ch + 4);
vmax = vmaxq_f32(vmax, vdin2);
din_ch += 8;
}
#else
int cnt_num = cnt;
if (cnt_num > 0) {
asm volatile(
"max_loop: @main loop\n"
"vld1.f32 {d0-d1}, [%[din_ch]]! @load q1,din_ch\n"
"vmax.f32 %q[vmax], %q[vmax], q0 @max vmax,vmax,din_ch\n"
"vld1.f32 {d2-d3}, [%[din_ch]]! @load 2nd 4 data\n"
"vmax.f32 %q[vmax], %q[vmax], q1 @compare 2nd 4 datas\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"bne max_loop @bne cnt_num\n"
: [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num), [vmax] "+w"(vmax)
:
: "cc", "memory", "q0", "q1");
}
#endif // __aarch64__
float32x2_t vmax_tmp = vmax_f32(vget_low_f32(vmax), vget_high_f32(vmax));
float tmp1 = vget_lane_f32(vmax_tmp, 0);
float tmp2 = vget_lane_f32(vmax_tmp, 1);
float max_tmp = tmp1 > tmp2 ? tmp1 : tmp2;
for (i = cnt * 8; i < size_channel_in; ++i) {
max_tmp = max_tmp > din_ch[0] ? max_tmp : din_ch[0];
din_ch++;
}
dout_batch[c] = max_tmp;
}
}
}
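// Horizontal-reduction sketch for pooling_global_max above: the four lanes of
// the accumulator fold as max(v) = max(max(v0, v2), max(v1, v3)), which is
// exactly the vmax_f32 of the low/high halves followed by the lane compare.
// Hypothetical helper, for clarity only.
static inline float hmax_f32x4_ref(float v0, float v1, float v2, float v3) {
  float a = v0 > v2 ? v0 : v2;  // fold low half against high half, lane 0
  float b = v1 > v3 ? v1 : v3;  // fold low half against high half, lane 1
  return a > b ? a : b;
}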
void pooling_global_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win) {
int size_channel_in = win * hin;
int cnt = size_channel_in / 4;
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
const float* din_ch = din_batch + c * size_channel_in; // in address
int i = 0;
float32x4_t vsum = vdupq_n_f32(0.0f);
#ifdef __aarch64__
for (; i < cnt; i++) {
vsum = vaddq_f32(vld1q_f32(din_ch), vsum);
din_ch += 4;
}
#else
int cnt_num = cnt;
if (cnt_num > 0) {
asm volatile(
"add_loop: @main loop\n"
"vld1.f32 {d0-d1}, [%[din_ch]]! @load q1,din_ch\n"
"vadd.f32 %q[vsum], %q[vsum], q0 @add vmax,vmax, din_ch\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"bne add_loop @bne num\n"
: [din_ch] "+r"(din_ch), [cnt_num] "+r"(cnt_num), [vsum] "+w"(vsum)
:
: "cc", "memory", "q0");
}
#endif // __aarch64__
float32x2_t vsum_tmp = vadd_f32(vget_low_f32(vsum), vget_high_f32(vsum));
float sum = vget_lane_f32(vsum_tmp, 0) + vget_lane_f32(vsum_tmp, 1);
for (i = cnt * 4; i < size_channel_in; i++) {
sum += din_ch[0];
din_ch++;
}
dout_batch[c] = sum / size_channel_in;
}
}
}
void pooling2x2s2_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win) {
int kernel = 2;
int stride = 2;
int padding = 0;
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
int w_needed = (wout << 1);
int h_needed = (hout << 1);
int w_limit = w_needed > win ? win : w_needed;
int h_limit = h_needed > hin ? hin : h_needed;
int w_even = (w_limit >> 1) << 1;
int h_even = (h_limit >> 1) << 1;
int w_unroll_size = (w_even >> 3) << 3;
// int w_unroll_remain = w_even - w_unroll_size;
int w_in_2 = win << 1;
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_ch = dout_batch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
const float* r0 = din_ch;
const float* r1 = r0 + win;
int h = 0;
for (; h < h_even; h += 2) {
int w = 0;
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t dr00 = vld1q_f32(&r0[w]);
float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
float32x4_t dr10 = vld1q_f32(&r1[w]);
float32x4_t dr11 = vld1q_f32(&r1[w + 4]);
float32x4_t dmax1 = vmaxq_f32(dr00, dr10);
float32x4_t dmax2 = vmaxq_f32(dr01, dr11);
float32x4_t dmax = vpmaxq_f32(dmax1, dmax2);
vst1q_f32(&dout_ch[w >> 1], dmax);
}
#else
float* dr_out = dout_ch;
const float* dr0 = r0;
const float* dr1 = r1;
int cnt_num = w_unroll_size >> 3;
if (cnt_num > 0) {
asm volatile(
"s2_max_loop: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n"
"vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1\n"
"vmax.f32 q0, q0, q2 @max q0,q0,q2\n"
"vmax.f32 q1, q1, q3 @max q1,q1,q2\n"
"vpmax.f32 d4, d0, d1 @max d4,d0,d1\n"
"vpmax.f32 d5, d2, d3 @max d5,d2,d3\n"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"bne s2_max_loop @bne cnt_num\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
:
: "cc", "memory", "q0", "q1", "q2", "q3");
}
w = w_unroll_size;
#endif // __aarch64__
for (; w < w_even; w += 2) {
dout_ch[w >> 1] =
std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1]));
}
for (; w < w_limit; ++w) { // run 0 or 1 time
dout_ch[w >> 1] = std::max(r0[w], r1[w]);
}
r0 += w_in_2; // << 1;
r1 += w_in_2; // << 1;
dout_ch += wout;
}
// process remain row (odd, last row)
for (; h < h_limit; h++) { // run 0 or 1 time
int w = 0;
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t dr00 = vld1q_f32(&r0[w]);
float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
float32x4_t dmax = vpmaxq_f32(dr00, dr01);
vst1q_f32(&dout_ch[w >> 1], dmax);
}
#else
float* dr_out = dout_ch;
const float* dr0 = r0;
int cnt_num = w_unroll_size >> 3;
if (cnt_num > 0) {
asm volatile(
"s2_max_loop1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n"
"vpmax.f32 d4, d0, d1 @max d4,d0,d1\n"
"vpmax.f32 d5, d2, d3 @max d5,d2,d3\n"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"bne s2_max_loop1 @bne cnt_num\n"
: [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num)
:
: "cc", "memory", "q0", "q1", "q2");
}
w = w_unroll_size;
#endif // __aarch64__
for (; w < w_even; w += 2) {
dout_ch[w >> 1] = std::max(r0[w], r0[w + 1]);
}
for (; w < w_limit; ++w) { // run 0 or 1 time
dout_ch[w >> 1] = r0[w];
}
}
}
}
}
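// Scalar reference for pooling2x2s2_max above (kernel 2, stride 2, no pad):
// each output element is the max over a 2x2 input patch, degrading to a
// single column or row at odd right/bottom edges. Hypothetical helper only.
static inline float max2x2_ref(const float* r0, const float* r1, int w) {
  return std::max(std::max(r0[w], r0[w + 1]), std::max(r1[w], r1[w + 1]));
}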
void pooling2x2s2_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive) {
int kernel = 2;
int stride = 2;
int padding = 0;
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
int w_needed = (wout << 1);
int h_needed = (hout << 1);
int w_limit = w_needed > win ? win : w_needed;
int h_limit = h_needed > hin ? hin : h_needed;
int w_even = (w_limit >> 1) << 1;
int h_even = (h_limit >> 1) << 1;
int w_unroll_size = (w_even >> 3) << 3;
// int w_unroll_remain = w_even - w_unroll_size;
int w_in_2 = win << 1;
const float coef = 1.f / 4.f;
const float coef_1 = exclusive ? 1.f : coef;
const float coef_2 = exclusive ? 1.f / 2.f : coef;
float32x4_t vcoef = vdupq_n_f32(coef);
float32x4_t vcoef_1 = vdupq_n_f32(coef_1);
float32x4_t vcoef_2 = vdupq_n_f32(coef_2);
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_ch = dout_batch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
const float* r0 = din_ch;
const float* r1 = r0 + win;
int h = 0;
for (; h < h_even; h += 2) {
int w = 0;
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t dr00 = vld1q_f32(&r0[w]);
float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
float32x4_t dr10 = vld1q_f32(&r1[w]);
float32x4_t dr11 = vld1q_f32(&r1[w + 4]);
float32x4_t dsum1 = vaddq_f32(dr00, dr10);
float32x4_t dsum2 = vaddq_f32(dr01, dr11);
float32x4_t dsum = vpaddq_f32(dsum1, dsum2);
float32x4_t res = vmulq_f32(dsum, vcoef);
vst1q_f32(&dout_ch[w >> 1], res);
}
#else
float* dr_out = dout_ch;
const float* dr0 = r0;
const float* dr1 = r1;
int cnt_num = w_unroll_size >> 3;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n"
"vld1.f32 {d4-d7}, [%[dr1]]! @load q1,dr1\n"
"vadd.f32 q0, q0, q2 @add q0,q0,q2\n"
"vadd.f32 q1, q1, q3 @add q1,q1,q2\n"
"vpadd.f32 d4, d0, d1 @add d4,d0,d1\n"
"vpadd.f32 d5, d2, d3 @add d5,d2,d3\n"
"vmul.f32 q2, q2, %q[vcoef] @mul q2,q2,vcoef\n"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"bne 1b @bne cnt_num\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[vcoef] "+w"(vcoef), [cnt_num] "+r"(cnt_num)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "w"(vcoef)
: "cc", "memory", "q0", "q1", "q2", "q3");
}
w = w_unroll_size;
#endif // __aarch64__
for (; w < w_even; w += 2) {
dout_ch[w >> 1] = (r0[w] + r0[w + 1] + r1[w] + r1[w + 1]) * coef;
}
for (; w < w_limit; ++w) { // run 0 or 1 time
dout_ch[w >> 1] = (r0[w] + r1[w]) * coef_2;
}
r0 += w_in_2; // << 1;
r1 += w_in_2; // << 1;
dout_ch += wout;
}
// process remain row (odd, last row)
for (; h < h_limit; h++) { // run 0 or 1 time
int w = 0;
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t dr00 = vld1q_f32(&r0[w]);
float32x4_t dr01 = vld1q_f32(&r0[w + 4]);
float32x4_t dsum = vpaddq_f32(dr00, dr01);
float32x4_t res = vmulq_f32(dsum, vcoef_2);
vst1q_f32(&dout_ch[w >> 1], res);
}
#else
float* dr_out = dout_ch;
const float* dr0 = r0;
int cnt_num = w_unroll_size >> 3;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load q0,dr0\n"
"vpadd.f32 d4, d0, d1 @add d4,d0,d1\n"
"vpadd.f32 d5, d2, d3 @add d5,d2,d3\n"
"vmul.f32 q2, q2, %q[vcoef_2] @mul q2,q2,vcoef_2\n"
"vst1.f32 {d4-d5}, [%[dr_out]]! @vst1 q2,dr_out\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"bne 1b @bne cnt_num\n"
: [dr0] "+r"(dr0), [dr_out] "+r"(dr_out), [vcoef_2] "+w"(vcoef_2),
[cnt_num] "+r"(cnt_num)
: "r"(dr0), "r"(dr_out), "r"(cnt_num), "w"(vcoef_2)
: "cc", "memory", "q0", "q1", "q2");
}
w = w_unroll_size;
#endif // __aarch64__
for (; w < w_even; w += 2) {
dout_ch[w >> 1] = (r0[w] + r0[w + 1]) * coef_2;
}
for (; w < w_limit; ++w) { // run 0 or 1 time
dout_ch[w >> 1] = r0[w] * coef_1;
}
}
}
}
}
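// Divisor sketch for pooling2x2s2_avg above: a window holding n valid inputs
// (n = 4 inside, 2 on an odd edge, 1 in an odd corner) is scaled by 1/n in
// exclusive mode, but always by 1/4 in inclusive mode, which counts the
// implicit zero padding in the divisor. Hypothetical helper, for clarity only.
static inline float avg_coef_ref(int valid, bool exclusive) {
  return exclusive ? 1.f / valid : 1.f / 4.f;
}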
void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win) {
int kernel = 3;
int stride = 1;
int padding = 1;
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
int w_unroll_size = ((win - 2) >> 2) << 2;
int w_unroll_remain = win - 2 - w_unroll_size;
const float minval = std::numeric_limits<float>::lowest();
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_ch = dout_batch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
const float* r0 = din_ch;
const float* r1 = r0 + win;
const float* r2 = r1 + win;
int cnt_num = w_unroll_size >> 2; // w_unroll_size / 4
float* dr_out = dout_ch;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 0;
int cnt = 1;
// left
dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
// first row with zero pad
#ifdef __aarch64__
for (; w < w_unroll_size; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_34_56 =
vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
vst1q_f32(&dout_ch[cnt], vmax);
cnt += 4;
}
#else
dr_out = dr_out + 1;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n"
"vmax.f32 q5, q0, q2 @max r0_1234,r1_1234\n"
"vmax.f32 d12, d2, d6 @max r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q5, q6, #1 @vext max_2345\n"
"vext.f32 q2, q5, q6, #2 @vext max_3456\n"
"vpmax.f32 d2, d10, d11 @pmax d4,max_1234,max_1234\n"
"vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n"
"vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n"
"vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n"
"vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n"
"sub %[dr0], #8 @sub w,8\n"
"sub %[dr1], #8 @sub w,8\n"
// swap
"vmov.f32 s0, s17 @mov\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s0 @mov\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
}
#endif
// remain
w = w_unroll_size;
for (int j = 0; j < w_unroll_remain; j++) {
float tmp_max = std::max(r0[j + w], r1[j + w]);
tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
dout_ch[j + w + 1] = tmp_max;
}
// right
float tmp = std::max(r0[win - 2], r1[win - 2]);
tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
dout_ch[wout - 1] = tmp;
// r0 = r1;
// r1 = r0 + w_in;
// r2 = r1 + w_in;
dout_ch += wout;
int h = 0;
for (; h < hin - 2; h += 1) {
// deal with left pad
float maxr0 = std::max(r0[0], r0[1]);
float maxr1 = std::max(r1[0], r1[1]);
float maxr2 = std::max(r2[0], r2[1]);
dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
w = 0;
cnt = 1;
for (; w < w_unroll_size; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_34_56 =
vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
vst1q_f32(&dout_ch[cnt], vmax);
cnt += 4;
}
#else
dr_out = dout_ch + 1;
dr0 = r0;
dr1 = r1;
dr2 = r2;
cnt_num = w_unroll_size >> 2;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7,dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d10}, [%[dr2]]! @load d4-d7, dr1\n"
"vmax.f32 q7, q0, q2 @max r0_1234,r1_1234\n"
"vmax.f32 d16, d2, d6 @max r0_5678,r1_5678\n"
"vmax.f32 q3, q7, q4 @max r0_1234,r1_1234\n"
"vmax.f32 d12, d16, d10 @max r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q3, q6, #1 @vext max_2345\n"
"vext.f32 q2, q3, q6, #2 @vext max_3456\n"
"vpmax.f32 d2, d6, d7 @pmax d4,max_1234,max_1234\n"
"vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n"
"vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n"
"vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n"
"vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n"
"sub %[dr0], #8 @sub w,8\n"
"sub %[dr1], #8 @sub w,8\n"
"sub %[dr2], #8 @sub w,8\n"
// swap
"vmov.f32 s0, s17 @mov\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s0 @mov\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8");
}
#endif
// remain
w = w_unroll_size;
for (int j = 0; j < w_unroll_remain; j++) {
float tmp_max = std::max(r0[j + w], r1[j + w]);
tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
tmp_max = std::max(tmp_max, std::max(r2[j + w], r2[j + w + 1]));
tmp_max = std::max(tmp_max, r2[j + w + 2]);
dout_ch[j + w + 1] = tmp_max;
}
// right
tmp = std::max(r0[win - 2], r1[win - 2]);
tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
tmp = std::max(tmp, std::max(r2[win - 2], r2[win - 1]));
dout_ch[wout - 1] = tmp;
r0 = r1;
r1 = r2;
r2 = r1 + win;
dout_ch += wout;
}
    // the last two lines
float maxr0 = std::max(r0[0], r0[1]);
float maxr1 = std::max(r1[0], r1[1]);
dout_ch[0] = std::max(maxr0, maxr1);
#ifdef __aarch64__
w = 0;
cnt = 1;
for (; w < w_unroll_size; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_3456 = vextq_f32(vmax_1234, vmax_5678, 2);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_34_56 =
vpmax_f32(vget_low_f32(vmax_3456), vget_high_f32(vmax_3456));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_234_456 = vmax_f32(vmax_23_45, vmax_34_56);
float32x4_t vmax = vdupq_n_f32(vget_lane_f32(vmax_123_345, 0));
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 0), vmax, 1);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_123_345, 1), vmax, 2);
vmax = vsetq_lane_f32(vget_lane_f32(vmax_234_456, 1), vmax, 3);
vst1q_f32(&dout_ch[cnt], vmax);
cnt += 4;
}
#else
dr_out = dout_ch + 1;
dr0 = r0;
dr1 = r1;
cnt_num = w_unroll_size >> 2;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n"
"vmax.f32 q5, q0, q2 @max r0_1234,r1_1234\n"
"vmax.f32 d12, d2, d6 @max r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q5, q6, #1 @vext max_2345\n"
"vext.f32 q2, q5, q6, #2 @vext max_3456\n"
"vpmax.f32 d2, d10, d11 @pmax d4,max_1234,max_1234\n"
"vpmax.f32 d3, d0, d1 @pmax d4,max_2345,max_2345\n"
"vpmax.f32 d6, d4, d5 @pmax d6,max_3456,max_3456\n"
"vmax.f32 d8, d2, d3 @max d2,vmax_12_34,vmax_23_45\n"
"vmax.f32 d9, d3, d6 @max d2,vmax_23_45,vmax_34_56\n"
"sub %[dr0], #8 @sub w,8\n"
"sub %[dr1], #8 @sub w,8\n"
// swap
"vmov.f32 s0, s17 @mov\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s0 @mov\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
}
#endif
      // remain
w = w_unroll_size;
for (int j = 0; j < w_unroll_remain; j++) {
float tmp_max = std::max(r0[j + w], r1[j + w]);
tmp_max = std::max(tmp_max, std::max(r0[j + w + 1], r1[j + w + 1]));
tmp_max = std::max(tmp_max, std::max(r0[j + w + 2], r1[j + w + 2]));
dout_ch[j + w + 1] = tmp_max;
}
tmp = std::max(r0[win - 2], r1[win - 2]);
tmp = std::max(tmp, std::max(r0[win - 1], r1[win - 1]));
dout_ch[wout - 1] = tmp;
}
}
}
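// 3x3 stride-1 average pooling with 1-pixel zero padding. In effect each
// output is
//   out(h, w) = sum(in[h-1..h+1][w-1..w+1]) * norm
// where norm is 1/9 in the interior; on the borders, exclusive mode divides
// by the in-bounds tap count instead (1/4 at corners, 1/6 along edges),
// while inclusive mode keeps 1/9 and counts the zero padding.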
void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive) {
int kernel = 3;
int stride = 1;
int padding = 1;
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
int w_unroll_size = ((win - 2) >> 2) << 2;
int w_unroll_remain = win - 2 - w_unroll_size;
const float coef = 1.f / 9.f;
const float coef_2 = exclusive ? 1.f / 2.f : coef;
const float coef_4 = exclusive ? 1.f / 4.f : coef;
const float coef_6 = exclusive ? 1.f / 6.f : coef;
float32x4_t vcoef = vdupq_n_f32(coef);
float32x4_t vcoef_2 = vdupq_n_f32(coef_2);
float32x4_t vcoef_4 = vdupq_n_f32(coef_4);
float32x4_t vcoef_6 = vdupq_n_f32(coef_6);
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_ch = dout_batch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
const float* r0 = din_ch;
const float* r1 = r0 + win;
const float* r2 = r1 + win;
int cnt_num = w_unroll_size >> 2; // w_unroll_size / 4
float* dr_out = dout_ch;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 0;
int cnt = 1;
// left
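      // top-left corner sees a 2x2 in-bounds patch, hence the coef_4
      // normalizer (1/4 exclusive, 1/9 inclusive)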
dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4;
// first row with zero pad
#ifdef __aarch64__
for (; w < w_unroll_size; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
vsum = vaddq_f32(vsum, vsum_3456);
vsum = vmulq_f32(vsum, vcoef_6);
vst1q_f32(&dout_ch[cnt], vsum);
cnt += 4;
}
#else
dr_out = dr_out + 1;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n"
"vadd.f32 q5, q0, q2 @max r0_1234,r1_1234\n"
"vadd.f32 d12, d2, d6 @max r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q5, q6, #1 @vext max_2345\n"
"vext.f32 q2, q5, q6, #2 @vext max_3456\n"
"vadd.f32 q1, q5, q0 @add 1234+2345\n"
"vadd.f32 q1, q1, q2 @add + 3456\n"
"vmul.f32 q4, q1, %q[vcoef_6] @mul * 1/9.f\n"
"sub %[dr0], #8 @sub w,8\n"
"sub %[dr1], #8 @sub w,8\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [vcoef_6] "+w"(vcoef_6)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
}
#endif
// remain
w = w_unroll_size;
for (int j = 0; j < w_unroll_remain; j++) {
float tmp_sum = r0[j + w] + r1[j + w];
tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
dout_ch[j + w + 1] = tmp_sum * coef_6;
}
// right
float tmp = r0[win - 2] + r1[win - 2];
tmp += (r0[win - 1] + r1[win - 1]);
dout_ch[wout - 1] = tmp * coef_4;
dout_ch += wout;
int h = 0;
for (; h < hin - 2; h += 1) {
// deal with left pad
        float sum0 = r0[0] + r0[1];
        float sum1 = r1[0] + r1[1];
        float sum2 = r2[0] + r2[1];
        dout_ch[0] = (sum0 + sum1 + sum2) * coef_6;
#ifdef __aarch64__
w = 0;
cnt = 1;
for (; w < w_unroll_size; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
vsum = vaddq_f32(vsum, vsum_3456);
vsum = vmulq_f32(vsum, vcoef);
vst1q_f32(&dout_ch[cnt], vsum);
cnt += 4;
}
#else
dr_out = dout_ch + 1;
dr0 = r0;
dr1 = r1;
dr2 = r2;
cnt_num = w_unroll_size >> 2;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d8-d9}, [%[dr2]]! @load d4-d7,dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d10}, [%[dr2]]! @load d4-d7,dr1\n"
"vadd.f32 q7, q0, q2 @max r0_1234,r1_1234\n"
"vadd.f32 d16, d2, d6 @max r0_5678,r1_5678\n"
"vadd.f32 q3, q7, q4 @max r0_1234,r1_1234\n"
"vadd.f32 d12, d16, d10 @max r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q3, q6, #1 @vext max_2345\n"
"vext.f32 q2, q3, q6, #2 @vext max_3456\n"
"vadd.f32 q1, q3, q0 @add 1234+2345\n"
"vadd.f32 q1, q1, q2 @add+3456\n"
"vmul.f32 q4, q1, %q[vcoef] @mul*1/9.f\n"
"sub %[dr0], #8 @sub w,8\n"
"sub %[dr1], #8 @sub w,8\n"
"sub %[dr2], #8 @sub w,8\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[vcoef] "+w"(vcoef)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8");
}
#endif
// remain
w = w_unroll_size;
for (int j = 0; j < w_unroll_remain; j++) {
float tmp_sum = r0[j + w] + r1[j + w];
tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
tmp_sum += (r2[j + w + 1] + r2[j + w + 2]);
tmp_sum += r2[j + w];
dout_ch[j + w + 1] = tmp_sum * coef;
}
// right
tmp = r0[win - 2] + r1[win - 2];
tmp += (r0[win - 1] + r1[win - 1]);
tmp += (r2[win - 2] + r2[win - 1]);
dout_ch[wout - 1] = tmp * coef_6;
r0 = r1;
r1 = r2;
r2 = r1 + win;
dout_ch += wout;
}
// last line
      float sum0 = r0[0] + r0[1];
      float sum1 = r1[0] + r1[1];
      dout_ch[0] = (sum0 + sum1) * coef_4;
#ifdef __aarch64__
w = 0;
cnt = 1;
for (; w < w_unroll_size; w += 4) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum = vaddq_f32(vsum_1234, vsum_2345);
vsum = vaddq_f32(vsum, vsum_3456);
vsum = vmulq_f32(vsum, vcoef_6);
vst1q_f32(&dout_ch[cnt], vsum);
cnt += 4;
}
#else
dr_out = dout_ch + 1;
dr0 = r0;
dr1 = r1;
cnt_num = w_unroll_size >> 2;
if (cnt_num > 0) {
asm volatile(
"1: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d4-d5}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d2}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6}, [%[dr1]]! @load d4-d7,dr1\n"
"vadd.f32 q5, q0, q2 @max r0_1234,r1_1234\n"
"vadd.f32 d12, d2, d6 @max r0_5678,r1_5678\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q5, q6, #1 @vext max_2345\n"
"vext.f32 q2, q5, q6, #2 @vext max_3456\n"
"vadd.f32 q1, q5, q0 @add 1234+2345\n"
"vadd.f32 q1, q1, q2 @add + 3456\n"
"vmul.f32 q4, q1, %q[vcoef_6] @mul * 1/9.f\n"
"sub %[dr0], #8 @sub w,8\n"
"sub %[dr1], #8 @sub w,8\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s1_max_loop\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [vcoef_6] "+w"(vcoef_6)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6");
}
#endif
// remain
w = w_unroll_size;
for (int j = 0; j < w_unroll_remain; j++) {
float tmp_sum = r0[j + w] + r1[j + w];
tmp_sum += (r0[j + w + 1] + r1[j + w + 1]);
tmp_sum += (r0[j + w + 2] + r1[j + w + 2]);
dout_ch[j + w + 1] = tmp_sum * coef_6;
}
// right
tmp = r0[win - 2] + r1[win - 2];
tmp += (r0[win - 1] + r1[win - 1]);
dout_ch[wout - 1] = tmp * coef_4;
}
}
}
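// 3x3 stride-2 max pooling with 1-pixel zero padding. The vector body
// consumes eight input columns per iteration and emits four outputs;
// consecutive windows overlap by one column, so pairwise vpmax over the
// 1234/2345 and 5678/6789 shifts yields max(1,2,3), max(3,4,5), max(5,6,7)
// and max(7,8,9).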
void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win) {
int kernel = 3;
int stride = 2;
int padding = 1;
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
int w_needed = (wout << 1) + 1;
int h_needed = (hout << 1) + 1;
int w_limit = w_needed > win ? win : w_needed;
int h_limit = h_needed > hin ? hin : h_needed;
int w_even = (w_limit >> 1) << 1;
int h_even = (h_limit >> 1) << 1;
int w_unroll_size = ((w_even - 1) >> 3) << 3;
int w_unroll_remain = w_even - 1 - w_unroll_size;
int w_remain = w_needed - w_limit - padding;
int h_remain = h_needed - h_limit - padding;
float minval = std::numeric_limits<float>::lowest();
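  // w_needed/h_needed is the input extent that wout/hout outputs would
  // consume (2 * out + 1 for kernel 3, stride 2); the *_remain terms flag a
  // trailing partial window. minval fills vector lanes that fall outside a
  // window so they can never win a max reduction.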
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_ch = dout_batch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
const float* r0 = din_ch;
const float* r1 = r0 + win;
const float* r2 = r1 + win;
int cnt_num = w_unroll_size >> 3;
int cnt_num_remain = w_unroll_remain >> 1;
float* dr_out = dout_ch;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 1;
int cnt = 1;
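      // top-left corner: with 1-pixel padding only a 2x2 patch is in bounds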
dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
// first row with zero pad
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&dout_ch[cnt], vmax_123_345);
vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
vmax2 = vpmax_f32(vmax2, vmax2);
dout_ch[cnt] = vget_lane_f32(vmax2, 0);
cnt++;
}
#else
dr0 = dr0 + 1;
dr1 = dr1 + 1;
dr_out = dr_out + 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
"vmax.f32 q6, q0, q3 @max r0_1234,r1_1234\n"
"vmax.f32 q7, q1, q4 @max r0_5678,r1_5678\n"
"vmax.f32 q8, q2, q5 @max r0_9101112,r1_9101112\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q7, q8, #1 @vext max_6789\n"
"vpmax.f32 d4, d12, d13 @pmax d4,vmax_1234,vmax_1234\n"
"vpmax.f32 d6, d14, d15 @pmax d6,vmax_5678,vmax_5678\n"
"vpmax.f32 d5, d0, d1 @pmax d5,vmax_2345,vmax_2345\n"
"vpmax.f32 d7, d2, d3 @pmax d7,vmax_6789,vmax_6789\n"
"vmax.f32 d8, d4, d5 @max d2,vmax_12_34,vmax_23_45\n"
"vmax.f32 d9, d6, d7 @max d2,vmax_56_78,vmax_67_89\n"
"sub %[dr0], #16 @add w,8\n"
"sub %[dr1], #16 @add w, 8\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"subs %[cnt_num], #1 @subs cnt_num, #1\n"
"bne 1b @bne s3_max_loop\n"
"3: @loop \n"
"cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n"
"ble 4f @ble exit\n"
"2: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vmov.f32 s3,s2 @movs3,s2\n"
"vmov.f32 s7,s6 @movs7,s6\n"
"vmax.f32 q0, q0, q1 @max q0,q0,q1\n"
"vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @subs cnt_num,#1\n"
"bne 2b @bne s3_max_loop_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
        float tmp = r0[wstart];
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp = std::max(tmp, std::max(r0[i], r1[i]));
}
dout_ch[w_even >> 1] = tmp;
}
r0 = r1;
r1 = r0 + win;
r2 = r1 + win;
dout_ch += wout;
int h = 2;
for (; h < h_even; h += 2) {
// deal with left pad
float maxr0 = std::max(r0[0], r0[1]);
float maxr1 = std::max(r1[0], r1[1]);
float maxr2 = std::max(r2[0], r2[1]);
dout_ch[0] = std::max(std::max(maxr0, maxr1), maxr2);
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&dout_ch[cnt], vmax_123_345);
vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
float32x4_t vr2 = vld1q_f32(&r2[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dout_ch[cnt] = vget_lane_f32(vmax, 0);
cnt++;
}
#else
dr_out = dout_ch + 1;
dr0 = (r0 + 1);
dr1 = (r1 + 1);
dr2 = (r2 + 1);
cnt_num = w_unroll_size >> 3;
cnt_num_remain = w_unroll_remain >> 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1\n"
"vmax.f32 q9, q0, q3 @max q0,q0,q2\n"
"vmax.f32 q10, q1, q4 @max q1,q1,q3\n"
"vmax.f32 q11, q2, q5 @max q1,q1,q3\n"
"vmax.f32 q0, q9, q6 @max q0,q0,q2 1234\n"
"vmax.f32 q3, q10, q7 @max q1,q1,q3 5678\n"
"vmax.f32 q1, q11, q8 @max q1,q1,q3 9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q4, q0, q3, #1 @vext 2345\n"
"vext.f32 q2, q3, q1, #1 @vext 6789\n"
"vpmax.f32 d10, d0, d1 @pmax d10,vmax_1234,vmax_1234\n"
"vpmax.f32 d12, d6, d7 @pmax d12,vmax_5678,vmax_5678\n"
"vpmax.f32 d11, d8, d9 @pmax d11,vmax_2345,vmax_2345\n"
"vpmax.f32 d13, d4, d5 @pmax d13,vmax_6789,vmax_6789\n"
"vmax.f32 d0, d10, d11 @pmax d0,vmax_12_34,vmax_23_45\n"
"vmax.f32 d1, d12, d13 @pmax d1,vmax_56_78,vmax_67_89\n"
"sub %[dr0], #16 @add w,8\n"
"sub %[dr1], #16 @add w,8\n"
"sub %[dr2], #16 @add w,8\n"
"vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"bne 1b @bne s3_max_loop_mid\n"
"3: @loop \n"
"cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n"
"vmov.f32 s3,s2 @movs3,s2\n"
"vmov.f32 s7,s6 @movs7,s6\n"
"vmov.f32 s11,s10 @movs11,s10\n"
"vmax.f32 q0, q0, q1 @max q0,q0,q1\n"
"vmax.f32 q0, q0, q2 @max q0,q0,q2\n"
"vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0, d0,d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"sub %[dr2], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @subs cnt_num,#1\n"
"bne 2b @bne s3_max_loop_mid_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[cnt_num_remain] "+r"(cnt_num_remain)
: "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp = r0[wstart];
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, std::max(r0[i], r1[i]));
tmp = std::max(tmp, r2[i]);
}
dout_ch[w_even >> 1] = tmp;
}
r0 = r2;
r1 = r0 + win;
r2 = r1 + win;
dout_ch += wout;
}
if (h_remain > 0) {
// deal with bottom pad
// first row with zero pad
int hstart = (h >> 1) * stride - padding;
int hend = std::min(std::min(hstart + kernel, hin + padding), hin);
        if (hstart == hend - 1) {  // only one line
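          // the bottom window has collapsed to a single input row: reduce
          // r0 alone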
dout_ch[0] = std::max(r0[0], r0[1]);
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < w_unroll_size; w += 8) {
float32x4_t vmax_1234 = vld1q_f32(&r0[w]);
float32x4_t vmax_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vmax_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&dout_ch[cnt], vmax_123_345);
vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
float32x2_t vmax = vpmax_f32(vget_low_f32(vr0), vget_high_f32(vr0));
vmax = vpmax_f32(vmax, vmax);
dout_ch[cnt] = vget_lane_f32(vmax, 0);
cnt++;
}
#else
dr_out = dout_ch + 1;
dr0 = (r0 + 1);
cnt_num = w_unroll_size >> 3;
cnt_num_remain = w_unroll_remain >> 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d3,dr0\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n"
"vext.f32 q4, q0, q1, #1 @vmax_2345\n"
"vext.f32 q5, q1, q2, #1 @vmax_6789\n"
"vpmax.f32 d12, d0, d1 @vmax_12_34\n"
"vpmax.f32 d14, d2, d3 @vmax_56_78\n"
"vpmax.f32 d13, d8, d9 @vmax_23_45\n"
"vpmax.f32 d15, d10, d11 @vmax_67_89\n"
"vmax.f32 d0, d12, d13 @12_34,23_45\n"
"vmax.f32 d1, d14, d15 @56_78,67_89\n"
"sub %[dr0], #16 @add w,6\n"
"vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"bne 1b @bne s3_max_loop_bot\n"
"3: @loop \n"
"cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vmov.f32 s3,s2 @movs3, s2\n"
"vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"sub %[dr0], #8 @add w,2\n"
"subs %[cnt_num_remain], #1 @subs cnt_num,#1\n"
"bne 2b @bne s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
            float tmp = r0[wstart];
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, r0[i]);
}
dout_ch[w_even >> 1] = tmp;
}
} else { // two lines
dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&dout_ch[cnt], vmax_123_345);
vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
vmax2 = vpmax_f32(vmax2, vmax2);
dout_ch[cnt] = vget_lane_f32(vmax2, 0);
cnt++;
}
#else
dr_out = dout_ch + 1;
dr0 = (r0 + 1);
dr1 = (r1 + 1);
cnt_num = w_unroll_size >> 3;
cnt_num_remain = w_unroll_remain >> 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
"vmax.f32 q6, q0, q3 @max q0,q0,q2 1234\n"
"vmax.f32 q7, q1, q4 @max q1,q1,q3 5678\n"
"vmax.f32 q8, q2, q5 @max q1,q1,q3 9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext q0,2345\n"
"vext.f32 q1, q7, q8, #1 @vext q1,6789\n"
"vpmax.f32 d4, d12, d13 @pmax "
"d4,vmax_1234,vmax_1234\n"
"vpmax.f32 d6, d14, d15 @pmax "
"d6,vmax_5678,vmax_5678\n"
"vpmax.f32 d5, d0, d1 @pmax "
"d5,vmax_2345,vmax_2345\n"
"vpmax.f32 d7, d2, d3 @pmax "
"d7,vmax_6789,vmax_6789\n"
"vmax.f32 d8, d4, d5 @max "
"d2,vmax_12_34,vmax_23_45\n"
"vmax.f32 d9, d6, d7 @max "
"d2,vmax_56_78,vmax_67_89\n"
"sub %[dr0], #16 @add w,8\n"
"sub %[dr1], #16 @add w,8\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"bne 1b @bne s3_max_loop_bot\n"
"3: @loop \n"
"cmp %[cnt_num_remain], #0 @cmp cnt_num,0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vmov.f32 s3,s2 @movs3, s2\n"
"vmov.f32 s7,s6 @movs7, s6\n"
"vmax.f32 q0, q0, q1 @max q0,q0,q1\n"
"vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @subs cnt_num,#1\n"
"bne 2b @bne s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
            float tmp = r0[wstart];
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp = std::max(tmp, std::max(r0[i], r1[i]));
}
dout_ch[w_even >> 1] = tmp;
}
}
}
}
}
}
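// 3x3 stride-2 average pooling with 1-pixel zero padding. The vector body
// sums the three rows, builds the shifted column sums 2345/3456/4567/6789
// with vext, and gathers the stride-2 window sums (123, 345, 567, 789) into
// a single register before scaling by the appropriate normalizer.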
void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive) {
int kernel = 3;
int stride = 2;
int padding = 1;
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
int w_needed = (wout << 1) + 1;
int h_needed = (hout << 1) + 1;
int w_limit = w_needed > win ? win : w_needed;
int h_limit = h_needed > hin ? hin : h_needed;
int w_even = (w_limit >> 1) << 1;
int h_even = (h_limit >> 1) << 1;
int w_unroll_size = ((w_even - 1) >> 3) << 3;
int w_unroll_remain = w_even - 1 - w_unroll_size;
int w_remain = w_needed - w_limit - padding;
int h_remain = h_needed - h_limit - padding;
const float coef = 1.f / 9.f;
const float coef_2 = exclusive ? 1.f / 2.f : coef;
const float coef_3 = exclusive ? 1.f / 3.f : coef;
const float coef_4 = exclusive ? 1.f / 4.f : coef;
const float coef_6 = exclusive ? 1.f / 6.f : coef;
float32x4_t vcoef = vdupq_n_f32(coef);
float32x4_t vcoef_2 = vdupq_n_f32(coef_2);
float32x4_t vcoef_3 = vdupq_n_f32(coef_3);
float32x4_t vcoef_4 = vdupq_n_f32(coef_4);
float32x4_t vcoef_6 = vdupq_n_f32(coef_6);
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_ch = dout_batch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
const float* r0 = din_ch;
const float* r1 = r0 + win;
const float* r2 = r1 + win;
int cnt_num = w_unroll_size >> 3;
int cnt_num_remain = w_unroll_remain >> 1;
float* dr_out = dout_ch;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 1;
int cnt = 1;
float32x4_t vzero = vdupq_n_f32(0.f);
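      // zero is the identity of the sum, so out-of-window lanes are masked
      // with vzero here rather than with minval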
dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4;
// first row with zero pad
#ifdef __aarch64__
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6);
vst1q_f32(&dout_ch[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
vsum2 = vpadd_f32(vsum2, vsum2);
float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6));
dout_ch[cnt] = vget_lane_f32(vrst, 0);
cnt++;
}
#else
dr0 = dr0 + 1;
dr1 = dr1 + 1;
dr_out = dr_out + 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
"vadd.f32 q6, q0, q3 @max r0_1234,r1_1234\n"
"vadd.f32 q7, q1, q4 @max r0_5678,r1_5678\n"
"vadd.f32 q8, q2, q5 @max r0_9101112,r1_9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234, 2345\n"
"vadd.f32 q5, q7, q1 @add 5678, 4567\n"
"vadd.f32 q4, q4, q2 @add 3456, sum1\n"
"vadd.f32 q5, q5, q3 @add 6789, sum2\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s21 @mov\n"
"vmov.f32 s19, s23 @mov\n"
"vmul.f32 q4, q4, %q[vcoef_6] @mul\n"
"sub %[dr0], #16 @add w,8\n"
"sub %[dr1], #16 @add w,8\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s3_max_loop\n"
"3: @loop\n"
"cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n"
"ble 4f @ble exit\n"
"2: @main loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0,q0,q1\n"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0, d0,d0\n"
"vmul.f32 d0, d0, %e[vcoef_6] @mul\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @subs cnt_num,#1\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"bne 2b @bne s3_max_loop_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain),
[vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num), "r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
        float tmp1 = 0.f;
float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef;
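        // exclusive mode renormalizes the trailing partial window by its
        // true tap count: 2 rows x (wend - wstart) columns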
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp1 += (r0[i] + r1[i]);
}
dout_ch[w_even >> 1] = tmp1 * tmp2;
}
r0 = r1;
r1 = r0 + win;
r2 = r1 + win;
dout_ch += wout;
int h = 2;
for (; h < h_even; h += 2) {
// deal with left pad
float sum0 = r0[0] + r0[1];
float sum1 = r1[0] + r1[1];
float sum2 = r2[0] + r2[1];
dout_ch[0] = (sum0 + sum1 + sum2) * coef_6;
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
vst1q_f32(&dout_ch[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
float32x4_t vr2 = vld1q_f32(&r2[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
vr2 = vsetq_lane_f32(0.f, vr2, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
vsum1 = vaddq_f32(vsum1, vr2);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
float32x2_t vsum = vpadd_f32(vsum2, vsum2);
dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef;
cnt++;
}
#else
dr_out = dout_ch + 1;
dr0 = (r0 + 1);
dr1 = (r1 + 1);
dr2 = (r2 + 1);
cnt_num = w_unroll_size >> 3;
cnt_num_remain = w_unroll_remain >> 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, "
"dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d16-d17}, [%[dr2]]! @load d4-d7,dr1\n"
"vadd.f32 q9, q0, q3 @max q0,q0,q2\n"
"vadd.f32 q10, q1, q4 @max q1,q1,q3\n"
"vadd.f32 q11, q2, q5 @max q1,q1,q3\n"
"vadd.f32 q6, q9, q6 @max q0,q0,q2 1234\n"
"vadd.f32 q7, q10, q7 @max q1,q1,q3 5678\n"
"vadd.f32 q8, q11, q8 @max q1,q1,q3 9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234,2345\n"
"vadd.f32 q5, q7, q1 @add 5678,4567\n"
"vadd.f32 q4, q4, q2 @add 3456,sum1\n"
"vadd.f32 q5, q5, q3 @add 6789,sum2\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s21 @mov\n"
"vmov.f32 s19, s23 @mov\n"
"vmul.f32 q4, q4, %q[vcoef] @mul\n"
"sub %[dr0], #16 @add w,8\n"
"sub %[dr1], #16 @add w,8\n"
"sub %[dr2], #16 @add w, 8\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s3_max_loop_mid\n"
"3: @loop\n"
"cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n"
"vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0,q0,q1\n"
"vadd.f32 q0, q0, q2 @add q0,q0,q1\n"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n"
"vmul.f32 d0, d0, %e[vcoef] @mul\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"sub %[dr2], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @cnt_num_remain--\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"bne 2b @bne s3_max_loop_mid_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[cnt_num_remain] "+r"(cnt_num_remain), [vcoef] "+w"(vcoef),
[vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
float tmp1 = 0.f;
float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef;
for (int i = wstart; i < wend; i++) {
tmp1 += (r0[i] + r1[i] + r2[i]);
}
dout_ch[w_even >> 1] = tmp1 * tmp2;
}
r0 = r2;
r1 = r0 + win;
r2 = r1 + win;
dout_ch += wout;
}
if (h_remain > 0) {
// deal with bottom pad
// first row with zero pad
int hstart = (h >> 1) * stride - padding;
int hend = std::min(std::min(hstart + kernel, hin + padding), hin);
if (hstart == hend - 1) { // only one line
dout_ch[0] = (r0[0] + r0[1]) * coef_2;
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < w_unroll_size; w += 8) {
float32x4_t vsum_1234 = vld1q_f32(&r0[w]);
float32x4_t vsum_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vsum_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2),
vsum_123_345, 1);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1),
vsum_123_345, 2);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3),
vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_3);
vst1q_f32(&dout_ch[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
float32x2_t vsum = vpadd_f32(vget_low_f32(vr0), vget_high_f32(vr0));
vsum = vpadd_f32(vsum, vsum);
dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef_3;
cnt++;
}
#else
dr_out = dout_ch + 1;
dr0 = (r0 + 1);
cnt_num = w_unroll_size >> 3;
cnt_num_remain = w_unroll_remain >> 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d12-d15}, [%[dr0]]! @load d0-d3,dr0\n"
"vld1.f32 {d16-d17}, [%[dr0]]! @load d0-d3,dr0\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234,2345\n"
"vadd.f32 q5, q7, q1 @add 5678,4567\n"
"vadd.f32 q4, q4, q2 @add 3456,sum1\n"
"vadd.f32 q5, q5, q3 @add 6789,sum2\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s21 @mov\n"
"vmov.f32 s19, s23 @mov\n"
"vmul.f32 q4, q4, %q[vcoef_3] @mul\n"
"sub %[dr0], #16 @add w,6\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s3_max_loop_bot\n"
"3: @loop\n"
"cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n"
"vmul.f32 d0, d0, %e[vcoef_3] @mul\n"
"sub %[dr0], #8 @add w,2\n"
"subs %[cnt_num_remain], #1 @cnt_num_remain--\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"bne 2b @bne s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num),
[cnt_num_remain] "+r"(cnt_num_remain),
[vcoef_3] "+w"(vcoef_3), [vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
float tmp1 = 0.f;
float tmp2 = exclusive ? 1.0f / (1.f * (wend - wstart)) : coef;
for (int i = wstart; i < wend; i++) {
tmp1 += r0[i];
}
dout_ch[w_even >> 1] = tmp1 * tmp2;
}
} else { // two lines
dout_ch[0] = (r0[0] + r0[1] + r1[0] + r1[1]) * coef_4;
#ifdef __aarch64__
w = 1;
cnt = 1;
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2),
vsum_123_345, 1);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1),
vsum_123_345, 2);
vsum_123_345 = vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3),
vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6);
vst1q_f32(&dout_ch[cnt], vrst);
cnt += 4;
}
for (; w < w_even - 1; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
vsum2 = vpadd_f32(vsum2, vsum2);
float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6));
dout_ch[cnt] = vget_lane_f32(vrst, 0);
cnt++;
}
#else
dr_out = dout_ch + 1;
dr0 = (r0 + 1);
dr1 = (r1 + 1);
cnt_num = w_unroll_size >> 3;
cnt_num_remain = w_unroll_remain >> 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d4-d5}, [%[dr0]]! @load d0-d3,dr0\n"
"vld1.f32 {d10-d11}, [%[dr1]]! @load d4-d7,dr1\n"
"vadd.f32 q6, q0, q3 @add q0,q0,q2 1234\n"
"vadd.f32 q7, q1, q4 @add q1,q1,q3 5678\n"
"vadd.f32 q8, q2, q5 @add q1,q1,q3 9101112\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234,2345\n"
"vadd.f32 q5, q7, q1 @add 5678,4567\n"
"vadd.f32 q4, q4, q2 @add 3456,sum1\n"
"vadd.f32 q5, q5, q3 @add 6789,sum2\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s21 @mov\n"
"vmov.f32 s19, s23 @mov\n"
"vmul.f32 q4, q4, %q[vcoef_6] @mul\n"
"sub %[dr0], #16 @add w,8\n"
"sub %[dr1], #16 @add w,8\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0, dr_out\n"
"bne 1b @bne s3_max_loop_bot\n"
"3: @loop\n"
"cmp %[cnt_num_remain], #0 @cnt_num_remain<=0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0,q0,q1\n"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n"
"vmul.f32 d0, d0, %e[vcoef_6] @mul\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @cnt_num_remain--\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"bne 2b @bne s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num),
[cnt_num_remain] "+r"(cnt_num_remain),
[vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6",
"q7", "q8", "q9");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
float tmp1 = 0.f;
float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef;
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp1 += (r0[i] + r1[i]);
}
dout_ch[w_even >> 1] = tmp1 * tmp2;
}
}
}
}
}
}
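// 3x3 stride-2 max pooling without padding: same inner kernel as the p1
// variant, but the first window is already fully in bounds, so there is no
// special-cased top row or left column.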
void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win) {
int kernel = 3;
int stride = 2;
int padding = 0;
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
int w_needed = (wout << 1) + 1;
int h_needed = (hout << 1) + 1;
int w_limit = w_needed > win ? win : w_needed;
int h_limit = h_needed > hin ? hin : h_needed;
int w_even = ((w_limit - 1) >> 1) << 1;
int h_even = ((h_limit - 1) >> 1) << 1;
int w_unroll_size = (w_even >> 3) << 3;
int w_unroll_remain = w_even - w_unroll_size;
int w_remain = w_needed - w_limit;
int h_remain = h_needed - h_limit;
int w_in_2 = win << 1;
float minval = std::numeric_limits<float>::lowest();
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_ch = dout_batch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
const float* r0 = din_ch;
const float* r1 = r0 + win;
const float* r2 = r1 + win;
float* dr_out = dout_ch;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
int w = 0;
int cnt = 0;
int h = 0;
for (; h < h_even; h += 2) {
        // no left pad for p0: the first output already sees a full window
#ifdef __aarch64__
w = 0;
cnt = 0;
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
vmax_1234 = vmaxq_f32(vmax_1234, vr2_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
vmax_5678 = vmaxq_f32(vmax_5678, vr2_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
vmax_9101112 = vmaxq_f32(vmax_9101112, vr2_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&dout_ch[cnt], vmax_123_345);
vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
float32x4_t vr2 = vld1q_f32(&r2[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
vr2 = vsetq_lane_f32(minval, vr2, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
vmax1 = vmaxq_f32(vmax1, vr2);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
float32x2_t vmax = vpmax_f32(vmax2, vmax2);
dout_ch[cnt] = vget_lane_f32(vmax, 0);
cnt++;
}
#else
      dr_out = dout_ch;
      dr0 = r0;
      dr1 = r1;
      dr2 = r2;
int cnt_num = w_unroll_size >> 3;
int cnt_num_remain = w_unroll_remain >> 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7,dr1\n"
"vld1.f32 {d4}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d16}, [%[dr2]]! @load d4-d7,dr1\n"
"vmax.f32 q9, q0, q3 @max q0,q0,q2\n"
"vmax.f32 q10, q1, q4 @max q1,q1,q3\n"
"vmax.f32 d22, d4, d10 @max q1,q1,q3\n"
"vmax.f32 q0, q9, q6 @max q0,q0,q2 1234\n"
"vmax.f32 q3, q10, q7 @max q1,q1,q3 5678\n"
"vmax.f32 d2, d22, d16 @max q1,q1,q3 9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q4, q0, q3, #1 @vext 2345\n"
"vext.f32 q2, q3, q1, #1 @vext 6789\n"
"vpmax.f32 d10, d0, d1 @pmax "
"d10,vmax_1234,vmax_1234\n"
"vpmax.f32 d12, d6, d7 @pmax "
"d12,vmax_5678,vmax_5678\n"
"vpmax.f32 d11, d8, d9 @pmax "
"d11,vmax_2345,vmax_2345\n"
"vpmax.f32 d13, d4, d5 @pmax "
"d13,vmax_6789,vmax_6789\n"
"vmax.f32 d0, d10, d11 @pmax "
"d0,vmax_12_34,vmax_23_45\n"
"vmax.f32 d1, d12, d13 @pmax "
"d1,vmax_56_78,vmax_67_89\n"
"sub %[dr0], #8 @add w,8\n"
"sub %[dr1], #8 @add w,8\n"
"sub %[dr2], #8 @add w,8\n"
"vst1.f32 d0, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d1, [%[dr_out]]! @vst1 d0,dr_out\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"bne 1b @bne s3_max_loop_mid\n"
"3: @loop\n"
"cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n"
"ble 4f @ble exit1\n"
"2: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n"
"vmov.f32 s3,s2 @movs3,s2\n"
"vmov.f32 s7,s6 @movs7,s6\n"
"vmov.f32 s11,s10 @movs11,s10\n"
"vmax.f32 q0, q0, q1 @max q0,q0,q1\n"
"vmax.f32 q0, q0, q2 @max q0,q0,q2\n"
"vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"sub %[dr2], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @cnt_num_remain--\n"
"bne 2b @bne s3_max_loop_mid_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[cnt_num_remain] "+r"(cnt_num_remain)
: "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9", "q10", "q11", "q12");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
          float tmp = r0[wstart];
for (int i = wstart; i < wend; i++) {
tmp = std::max(tmp, std::max(r0[i], r1[i]));
tmp = std::max(tmp, r2[i]);
}
dout_ch[w_even >> 1] = tmp;
}
r0 = r2;
r1 = r0 + win;
r2 = r1 + win;
dout_ch += wout;
}
if (h_remain > 0) {
// deal with bottom pad
// first row with zero pad
// int hstart = (h >> 1) * stride_h - pad_h;
// int hend = std::min(std::min(hstart + kernel_h, hin + pad_h), hin);
      // dout_ch[0] = std::max(std::max(r0[0], r0[1]), std::max(r1[0], r1[1]));
#ifdef __aarch64__
w = 0;
cnt = 0;
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vmax_1234 = vmaxq_f32(vr0_1234, vr1_1234);
float32x4_t vmax_5678 = vmaxq_f32(vr0_5678, vr1_5678);
float32x4_t vmax_9101112 = vmaxq_f32(vr0_9101112, vr1_9101112);
float32x4_t vmax_2345 = vextq_f32(vmax_1234, vmax_5678, 1);
float32x4_t vmax_6789 = vextq_f32(vmax_5678, vmax_9101112, 1);
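        // Pairwise-max trick: vpmax(1234) yields [max12, max34] and
        // vpmax(2345) yields [max23, max45]; one further vmax merges them
        // into the stride-2 window maxima [max123, max345] (and likewise
        // for 5678/6789).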
float32x2_t vmax_12_34 =
vpmax_f32(vget_low_f32(vmax_1234), vget_high_f32(vmax_1234));
float32x2_t vmax_23_45 =
vpmax_f32(vget_low_f32(vmax_2345), vget_high_f32(vmax_2345));
float32x2_t vmax_56_78 =
vpmax_f32(vget_low_f32(vmax_5678), vget_high_f32(vmax_5678));
float32x2_t vmax_67_89 =
vpmax_f32(vget_low_f32(vmax_6789), vget_high_f32(vmax_6789));
float32x2_t vmax_123_345 = vmax_f32(vmax_12_34, vmax_23_45);
float32x2_t vmax_567_789 = vmax_f32(vmax_56_78, vmax_67_89);
vst1_f32(&dout_ch[cnt], vmax_123_345);
vst1_f32(&dout_ch[cnt + 2], vmax_567_789);
cnt += 4;
}
for (; w < w_even; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(minval, vr0, 3);
vr1 = vsetq_lane_f32(minval, vr1, 3);
float32x4_t vmax1 = vmaxq_f32(vr0, vr1);
float32x2_t vmax2 =
vpmax_f32(vget_low_f32(vmax1), vget_high_f32(vmax1));
vmax2 = vpmax_f32(vmax2, vmax2);
dout_ch[cnt] = vget_lane_f32(vmax2, 0);
cnt++;
}
#else
dr_out = dout_ch; // + 1;
dr0 = r0; // (r0 + 1);
dr1 = r1; // (r1 + 1);
int cnt_num = w_unroll_size >> 3;
int cnt_num_remain = w_unroll_remain >> 1;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 3f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d4}, [%[dr0]]! @load d0-d3,dr0\n"
"vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n"
"vmax.f32 q6, q0, q3 @max q0,q0,q2 1234\n"
"vmax.f32 q7, q1, q4 @max q1,q1,q3 5678\n"
"vmax.f32 d16, d4, d10 @max q1,q1,q3 9101112\n"
//"vmov.f32 s7,s6 @mov s7,s6\n"
"vext.f32 q0, q6, q7, #1 @vext q0,2345\n"
"vext.f32 q1, q7, q8, #1 @vext q1,6789\n"
"vpmax.f32 d4, d12, d13 @pmax "
"d4,vmax_1234,vmax_1234\n"
"vpmax.f32 d6, d14, d15 @pmax "
"d6,vmax_5678,vmax_5678\n"
"vpmax.f32 d5, d0, d1 @pmax "
"d5,vmax_2345,vmax_2345\n"
"vpmax.f32 d7, d2, d3 @pmax "
"d7,vmax_6789,vmax_6789\n"
"vmax.f32 d8, d4, d5 @max "
"d2,vmax_12_34,vmax_23_45\n"
"vmax.f32 d9, d6, d7 @max "
"d2,vmax_56_78,vmax_67_89\n"
"sub %[dr0], #8 @add w,8\n"
"sub %[dr1], #8 @add w,8\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"subs %[cnt_num], #1 @subs cnt_num,#1\n"
"bne 1b @bne s3_max_loop_bot\n"
"3: @loop \n"
"cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n"
"ble 4f @ble exit\n"
"2: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vmov.f32 s3,s2 @movs3,s2\n"
"vmov.f32 s7,s6 @movs7,s6\n"
"vmax.f32 q0, q0, q1 @max q0,q0,q1\n"
"vpmax.f32 d0, d0, d1 @pmax d0,d0,d1\n"
"vpmax.f32 d0, d0, d0 @pmax d0,d0,d0\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @cnt_num_remain--\n"
"bne 2b @bne s3_max_loop_bot_1\n"
"4: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7",
"q8", "q9");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
float tmp = r0[wstart]; // std::numeric_limits<float>::min();
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp = std::max(tmp, std::max(r0[i], r1[i]));
}
dout_ch[w_even >> 1] = tmp;
}
}
}
}
}
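// A minimal scalar reference for pooling3x3s2p0_max (a sketch for testing
// only; the single-channel signature and the *_ref name are assumptions, and
// every window is assumed to lie fully inside the input, i.e.
// hin >= 2 * hout + 1 and win >= 2 * wout + 1, so no border clipping occurs).
inline void pooling3x3s2p0_max_ref(const float* din, float* dout, int hout,
                                   int wout, int win) {
  for (int oh = 0; oh < hout; ++oh) {
    for (int ow = 0; ow < wout; ++ow) {
      float m = din[(oh * 2) * win + ow * 2];
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          m = std::max(m, din[(oh * 2 + kh) * win + (ow * 2 + kw)]);
        }
      }
      dout[oh * wout + ow] = m;
    }
  }
}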
void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive) {
int kernel = 3;
int stride = 2;
int padding = 0;
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
int w_needed = (wout << 1) + 1;
int h_needed = (hout << 1) + 1;
int w_limit = w_needed > win ? win : w_needed;
int h_limit = h_needed > hin ? hin : h_needed;
int w_even = ((w_limit - 1) >> 1) << 1;
int h_even = ((h_limit - 1) >> 1) << 1;
int w_unroll_size = (w_even >> 3) << 3;
int w_unroll_remain = w_even - w_unroll_size;
int w_remain = w_needed - w_limit;
int h_remain = h_needed - h_limit;
int w_in_2 = win << 1;
const float coef = 1.f / 9.f;
const float coef_6 = exclusive ? 1.f / 6.f : coef;
float32x4_t vcoef = vdupq_n_f32(coef);
float32x4_t vcoef_6 = vdupq_n_f32(coef_6);
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
#pragma omp parallel for
for (int c = 0; c < chout; c++) {
float* dout_ch = dout_batch + c * size_channel_out;
const float* din_ch = din_batch + c * size_channel_in;
const float* r0 = din_ch;
const float* r1 = r0 + win;
const float* r2 = r1 + win;
// w = w_in - 8;
float* dr_out = dout_ch;
const float* dr0 = r0;
const float* dr1 = r1;
const float* dr2 = r2;
float32x4_t vzero = vdupq_n_f32(0.f);
int h = 0;
for (; h < h_even; h += 2) {
// LOG(INFO) << "h: " << h <<", dr0:" << r0 << ", dr1: " << r1 <<
// ",dr2: " <<r2; deal with left pad float sum0 = r0[0] + r0[1]; float
// sum1 = r1[0] + r1[1]; float sum2 = r2[0] + r2[1]; dout_channel[0] =
// (sum0 + sum1 + sum2) / 9.f;
#ifdef __aarch64__
int w = 0;
int cnt = 0;
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vr2_1234 = vld1q_f32(&r2[w]);
float32x4_t vr2_5678 = vld1q_f32(&r2[w + 4]);
float32x4_t vr2_9101112 = vld1q_f32(&r2[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
vsum_1234 = vaddq_f32(vsum_1234, vr2_1234);
vsum_5678 = vaddq_f32(vsum_5678, vr2_5678);
vsum_9101112 = vaddq_f32(vsum_9101112, vr2_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
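          // Lane bookkeeping: lanes {0, 2} of vsum_123_345 and lanes {1, 3}
          // of vsum_567_789 hold the four stride-2 window sums; the
          // vsetq/vgetq shuffle below packs them into lanes 0..3 before the
          // multiply by 1/9.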
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef);
vst1q_f32(&dout_ch[cnt], vrst);
cnt += 4;
}
for (; w < w_even; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
float32x4_t vr2 = vld1q_f32(&r2[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
vr2 = vsetq_lane_f32(0.f, vr2, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
vsum1 = vaddq_f32(vsum1, vr2);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
float32x2_t vsum = vpadd_f32(vsum2, vsum2);
dout_ch[cnt] = vget_lane_f32(vsum, 0) * coef;
cnt++;
}
#else
dr_out = dout_ch; // + 1;
dr0 = r0; // (r0 + 1);
dr1 = r1; // (r1 + 1);
dr2 = r2; // (r2 + 1);
int cnt_num = w_unroll_size >> 3;
int cnt_num_remain = w_unroll_remain >> 1;
// LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
// cnt_num_remain;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num, 0\n"
"ble loop3_ave_p0 @ble exit\n"
"s3_ave_loop_mid_p0: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d12-d15}, [%[dr2]]! @load d4-d7, dr2\n"
"vld1.f32 {d4}, [%[dr0]]! @load d0-d5, dr0\n"
"vld1.f32 {d10}, [%[dr1]]! @load d4-d7, dr1\n"
"vld1.f32 {d16}, [%[dr2]]! @load d4-d7, dr2\n"
"vadd.f32 q9, q0, q3 @max q0,q0,q2\n"
"vadd.f32 q10, q1, q4 @max q1,q1,q3\n"
"vadd.f32 d22, d4, d10 @max q1,q1,q3\n"
"vadd.f32 q6, q9, q6 @max q0,q0,q2 1234\n"
"vadd.f32 q7, q10, q7 @max q1,q1,q3 5678\n"
"vadd.f32 d16, d22, d16 @max q1,q1,q3 9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234, 2345\n"
"vadd.f32 q5, q7, q1 @add 5678, 4567\n"
"vadd.f32 q4, q4, q2 @add 3456, sum1\n"
"vadd.f32 q5, q5, q3 @add 6789, sum2\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s21 @mov\n"
"vmov.f32 s19, s23 @mov\n"
"vmul.f32 q4, q4, %q[vcoef] @mul\n"
"sub %[dr0], #8 @add w,8\n"
"sub %[dr1], #8 @add w,8\n"
"sub %[dr2], #8 @add w,8\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne s3_ave_loop_mid_p0 @bne s3_max_loop_mid\n"
"loop3_ave_p0: @loop\n"
"cmp %[cnt_num_remain], #0 @cmp cnt_num_remain,0\n"
"ble exit1_ave_p0 @ble exit1\n"
"s3_ave_loop_mid_1_p0: @mid loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vld1.f32 {d4-d5}, [%[dr2]]! @load d2-d3,dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n"
"vext.f32 q2, %q[vzero], q2, #3 @ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0,q0,q1\n"
"vadd.f32 q0, q0, q2 @add q0,q0,q1\n"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n"
"vmul.f32 d0, d0, %e[vcoef] @mul\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"sub %[dr2], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @cnt_num_remain--\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"bne s3_ave_loop_mid_1_p0 @bne s3_max_loop_mid_1\n"
"exit1_ave_p0: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr2] "+r"(dr2),
[dr_out] "+r"(dr_out), [cnt_num] "+r"(cnt_num),
[cnt_num_remain] "+r"(cnt_num_remain), [vcoef] "+w"(vcoef),
[vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr2), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
"q10", "q11", "q12");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
float tmp1 = 0.f;
float tmp2 = exclusive ? 1.0f / (3.f * (wend - wstart)) : coef;
for (int i = wstart; i < wend; i++) {
tmp1 += (r0[i] + r1[i] + r2[i]);
}
dout_ch[w_even >> 1] = tmp1 * tmp2;
// cnt ++;
}
r0 = r2;
r1 = r0 + win;
r2 = r1 + win;
dout_ch += wout;
}
if (h_remain > 0) {
// deal with bottom pad
// first row with zero pad
// int hstart = (h >> 1) * stride_h - pad_h;
      // int hend = std::min(std::min(hstart + kernel_h, hin + padding_h),
      //                     hin);
      // data_out_channel[0] =
      //     (r0[0] + r0[1] + r0[2] + r1[0] + r1[1] + r1[2]) / 9.f;
#ifdef __aarch64__
int w = 0;
int cnt = 0;
for (; w < w_unroll_size; w += 8) {
float32x4_t vr0_1234 = vld1q_f32(&r0[w]);
float32x4_t vr0_5678 = vld1q_f32(&r0[w + 4]);
float32x4_t vr0_9101112 = vld1q_f32(&r0[w + 8]);
float32x4_t vr1_1234 = vld1q_f32(&r1[w]);
float32x4_t vr1_5678 = vld1q_f32(&r1[w + 4]);
float32x4_t vr1_9101112 = vld1q_f32(&r1[w + 8]);
float32x4_t vsum_1234 = vaddq_f32(vr0_1234, vr1_1234);
float32x4_t vsum_5678 = vaddq_f32(vr0_5678, vr1_5678);
float32x4_t vsum_9101112 = vaddq_f32(vr0_9101112, vr1_9101112);
float32x4_t vsum_2345 = vextq_f32(vsum_1234, vsum_5678, 1);
float32x4_t vsum_3456 = vextq_f32(vsum_1234, vsum_5678, 2);
float32x4_t vsum_4567 = vextq_f32(vsum_1234, vsum_5678, 3);
float32x4_t vsum_6789 = vextq_f32(vsum_5678, vsum_9101112, 1);
float32x4_t vsum_123_345 = vaddq_f32(vsum_1234, vsum_2345);
vsum_123_345 = vaddq_f32(vsum_123_345, vsum_3456);
float32x4_t vsum_567_789 = vaddq_f32(vsum_4567, vsum_5678);
vsum_567_789 = vaddq_f32(vsum_567_789, vsum_6789);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_123_345, 2), vsum_123_345, 1);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 1), vsum_123_345, 2);
vsum_123_345 =
vsetq_lane_f32(vgetq_lane_f32(vsum_567_789, 3), vsum_123_345, 3);
float32x4_t vrst = vmulq_f32(vsum_123_345, vcoef_6);
vst1q_f32(&dout_ch[cnt], vrst);
cnt += 4;
}
for (; w < w_even; w += 2) {
float32x4_t vr0 = vld1q_f32(&r0[w]);
float32x4_t vr1 = vld1q_f32(&r1[w]);
vr0 = vsetq_lane_f32(0.f, vr0, 3);
vr1 = vsetq_lane_f32(0.f, vr1, 3);
float32x4_t vsum1 = vaddq_f32(vr0, vr1);
float32x2_t vsum2 =
vpadd_f32(vget_low_f32(vsum1), vget_high_f32(vsum1));
vsum2 = vpadd_f32(vsum2, vsum2);
float32x2_t vrst = vmul_f32(vsum2, vget_low_f32(vcoef_6));
dout_ch[cnt] = vget_lane_f32(vrst, 0);
cnt++;
}
#else
dr_out = dout_ch; // + 1;
dr0 = r0; // (r0 + 1);
dr1 = r1; // (r1 + 1);
int cnt_num = w_unroll_size >> 3;
int cnt_num_remain = w_unroll_remain >> 1;
// LOG(INFO) << "cnt_num: " << cnt_num << " cnt_num_remain: " <<
// cnt_num_remain;
if (cnt_num > 0 || cnt_num_remain > 0) {
asm volatile(
"cmp %[cnt_num], #0 @cmp cnt_num,0\n"
"ble 2f @ble exit\n"
"1: @main loop\n"
"vld1.f32 {d0-d3}, [%[dr0]]! @load d0-d5,dr0\n"
"vld1.f32 {d6-d9}, [%[dr1]]! @load d4-d7,dr1\n"
"vld1.f32 {d4}, [%[dr0]]! @load d0-d3,dr0\n"
"vld1.f32 {d10}, [%[dr1]]! @load d4-d7,dr1\n"
"vadd.f32 q6, q0, q3 @max q0,q0,q2 1234\n"
"vadd.f32 q7, q1, q4 @max q1,q1,q3 5678\n"
"vadd.f32 d16, d4, d10 @max q1,q1,q3 9101112\n"
//"vmov.f32 s7,s6 @mov s7, s6\n"
"vext.f32 q0, q6, q7, #1 @vext max_2345\n"
"vext.f32 q1, q6, q7, #3 @vext max_4567\n"
"vext.f32 q2, q6, q7, #2 @vext max_3456\n"
"vext.f32 q3, q7, q8, #1 @vext max_6789\n"
"vadd.f32 q4, q6, q0 @add 1234,2345\n"
"vadd.f32 q5, q7, q1 @add 5678,4567\n"
"vadd.f32 q4, q4, q2 @add 3456,sum1\n"
"vadd.f32 q5, q5, q3 @add 6789,sum2\n"
"vmov.f32 s17, s18 @mov\n"
"vmov.f32 s18, s21 @mov\n"
"vmov.f32 s19, s23 @mov\n"
"vmul.f32 q4, q4, %q[vcoef_6] @mul\n"
"sub %[dr0], #8 @add w,8\n"
"sub %[dr1], #8 @add w,8\n"
"subs %[cnt_num], #1 @cnt_num--\n"
"vst1.f32 d8, [%[dr_out]]! @vst1 d0,dr_out\n"
"vst1.f32 d9, [%[dr_out]]! @vst1 d0,dr_out\n"
"bne 1b @bne s3_max_loop_bot\n"
"2: @loop\n"
"cmp %[cnt_num_remain], #0 @cmp cnt_num_remain, 0\n"
"ble 3f @ble exit\n"
"4: @bot loop\n"
"vld1.f32 {d0-d1}, [%[dr0]]! @load d0-d1,dr0\n"
"vld1.f32 {d2-d3}, [%[dr1]]! @load d2-d3,dr1\n"
"vext.f32 q0, %q[vzero], q0, #3 @ext v0_0123\n"
"vext.f32 q1, %q[vzero], q1, #3 @ext v1_0123\n"
"vadd.f32 q0, q0, q1 @add q0,q0,q1\n"
"vpadd.f32 d0, d0, d1 @padd d0,d0,d1\n"
"vpadd.f32 d0, d0, d0 @padd d0,d0,d0\n"
"vmul.f32 d0, d0, %e[vcoef_6] @mul\n"
"sub %[dr0], #8 @add w,6\n"
"sub %[dr1], #8 @add w,6\n"
"subs %[cnt_num_remain], #1 @cnt_num_remain--\n"
"vst1.f32 d0[0], [%[dr_out]]! @vst d0[0],dr_out\n"
"bne 4b @bne s3_max_loop_bot_1\n"
"3: @exit\n"
: [dr0] "+r"(dr0), [dr1] "+r"(dr1), [dr_out] "+r"(dr_out),
[cnt_num] "+r"(cnt_num), [cnt_num_remain] "+r"(cnt_num_remain),
[vcoef_6] "+w"(vcoef_6), [vzero] "+w"(vzero)
: "r"(dr0), "r"(dr1), "r"(dr_out), "r"(cnt_num),
"r"(cnt_num_remain)
: "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9");
}
#endif
if (w_remain > 0) {
// deal with right pad
int wstart = (w_even >> 1) * stride - padding;
int wend = std::min(std::min(wstart + kernel, win + padding), win);
float tmp1 = 0.f;
float tmp2 = exclusive ? 1.0f / (2.f * (wend - wstart)) : coef;
for (int i = wstart; i < wend; i++) { // only run 1 or 2 times
tmp1 += (r0[i] + r1[i]);
}
dout_ch[w_even >> 1] = tmp1 * tmp2;
}
}
}
}
}
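// A minimal scalar reference for pooling3x3s2p0_avg (a sketch for testing
// only; the name and single-channel signature are assumptions, and every
// window is assumed fully inside the input, so the divisor is always 9 and
// the exclusive/inclusive distinction does not arise).
inline void pooling3x3s2p0_avg_ref(const float* din, float* dout, int hout,
                                   int wout, int win) {
  for (int oh = 0; oh < hout; ++oh) {
    for (int ow = 0; ow < wout; ++ow) {
      float sum = 0.f;
      for (int kh = 0; kh < 3; ++kh) {
        for (int kw = 0; kw < 3; ++kw) {
          sum += din[(oh * 2 + kh) * win + (ow * 2 + kw)];
        }
      }
      dout[oh * wout + ow] = sum / 9.f;
    }
  }
}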
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/lite/utils/cp_logging.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
//! pooling fp32 ops
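// Naming convention: pooling{K}x{K}s{S}p{P}_{max|avg} specializes a KxK
// kernel with stride S and padding P; pooling_basic covers the general case.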
void pooling_basic(const float* din, float* dout, int num, int chout, int hout,
int wout, int chin, int hin, int win,
const std::vector<int>& ksize,
const std::vector<int>& strides,
const std::vector<int>& paddings, bool global_pooling,
bool exclusive, bool adaptive, bool ceil_mode,
bool use_quantizer, const std::string& pooling_type);
void pooling_global_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling_global_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling2x2s2_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling2x2s2_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive);
void pooling3x3s1p1_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling3x3s1p1_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive);
void pooling3x3s2p1_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling3x3s2p1_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive);
void pooling3x3s2p0_max(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win);
void pooling3x3s2p0_avg(const float* din, float* dout, int num, int chout,
int hout, int wout, int chin, int hin, int win,
bool exclusive);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/scale.h"
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void scale<float>(const float* din, float* dout, int num, float scale,
float bias) {
int cnt = num >> 4;
int remain = num % 16;
float32x4_t vscale = vdupq_n_f32(scale);
float32x4_t vbias = vdupq_n_f32(bias);
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* din_ptr = din + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
}
if (remain > 0) {
const float* din_ptr = din + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr * scale + bias;
dout_ptr++;
din_ptr++;
}
}
}
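// Scalar reference for the element-wise variant above (a hedged sketch for
// cross-checking; the *_ref name is an assumption). The NEON path computes
// the same y = x * scale + bias, 16 floats per vmla iteration plus a scalar
// tail for num % 16.
inline void scale_ref(const float* din, float* dout, int num, float scale,
                      float bias) {
  for (int i = 0; i < num; ++i) {
    dout[i] = din[i] * scale + bias;
  }
}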
template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
int inner_dim, const float* scale_data,
const float* bias_data) {
int cnt = inner_dim >> 4;
int remain = inner_dim % 16;
int size = inner_dim * scale_dim;
for (int n = 0; n < outer_dim; n++) {
const float* din_ptr_n = din + n * size;
float* dout_ptr_n = dout + n * size;
#pragma omp parallel for
for (int i = 0; i < scale_dim; i++) {
const float* din_ptr = din_ptr_n + i * inner_dim;
float* dout_ptr = dout_ptr_n + i * inner_dim;
float scale = scale_data[i];
float32x4_t vscale = vdupq_n_f32(scale);
float bias = bias_data[i];
float32x4_t vbias = vdupq_n_f32(bias);
for (int j = 0; j < cnt; j++) {
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vsum1 = vmlaq_f32(vbias, din0, vscale);
float32x4_t vsum2 = vmlaq_f32(vbias, din1, vscale);
float32x4_t vsum3 = vmlaq_f32(vbias, din2, vscale);
float32x4_t vsum4 = vmlaq_f32(vbias, din3, vscale);
din_ptr += 16;
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
dout_ptr += 16;
}
for (int j = 0; j < remain; j++) {
*dout_ptr = *din_ptr * scale + bias;
dout_ptr++;
din_ptr++;
}
}
}
}
template <>
void scale<float>(const float* din, float* dout, int outer_dim, int scale_dim,
const float* scale_data, const float* bias_data) {
int cnt = scale_dim >> 4;
int remain = scale_dim % 16;
for (int n = 0; n < outer_dim; n++) {
const float* din_ptr_n = din + n * scale_dim;
float* dout_ptr_n = dout + n * scale_dim;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
int idx = i << 4;
const float* din_ptr = din_ptr_n + idx;
const float* scale_ptr = scale_data + idx;
const float* bias_ptr = bias_data + idx;
float* dout_ptr = dout_ptr_n + idx;
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t vscale0 = vld1q_f32(scale_ptr);
float32x4_t vbias0 = vld1q_f32(bias_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t vscale1 = vld1q_f32(scale_ptr + 4);
float32x4_t vbias1 = vld1q_f32(bias_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t vscale2 = vld1q_f32(scale_ptr + 8);
float32x4_t vbias2 = vld1q_f32(bias_ptr + 8);
float32x4_t vsum1 = vmlaq_f32(vbias0, din0, vscale0);
float32x4_t vsum2 = vmlaq_f32(vbias1, din1, vscale1);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
float32x4_t vscale3 = vld1q_f32(scale_ptr + 12);
float32x4_t vbias3 = vld1q_f32(bias_ptr + 12);
vst1q_f32(dout_ptr, vsum1);
vst1q_f32(dout_ptr + 4, vsum2);
float32x4_t vsum3 = vmlaq_f32(vbias2, din2, vscale2);
float32x4_t vsum4 = vmlaq_f32(vbias3, din3, vscale3);
vst1q_f32(dout_ptr + 8, vsum3);
vst1q_f32(dout_ptr + 12, vsum4);
}
int idx = cnt << 4;
const float* din_ptr = din_ptr_n + idx;
float* dout_ptr = dout_ptr_n + idx;
const float* scale_ptr = scale_data + idx;
const float* bias_ptr = bias_data + idx;
for (int j = 0; j < remain; j++) {
*dout_ptr = *din_ptr * (*scale_ptr) + (*bias_ptr);
dout_ptr++;
din_ptr++;
scale_ptr++;
bias_ptr++;
}
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void scale(const T* din, T* dout, int num, float scale, float bias);
template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim, int inner_dim,
const float* scale_data, const float* bias_data);
template <typename T>
void scale(const T* din, T* dout, int outer_dim, int scale_dim,
const float* scale_data, const float* bias_data);
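// Three overloads: (1) one scalar scale/bias applied to the whole buffer;
// (2) per-channel scale/bias broadcast over inner_dim; (3) per-element
// scale/bias along the last dimension of shape [outer_dim, scale_dim].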
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/softmax.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void softmax_basic<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
#pragma omp parallel for
for (int i = 0; i < compute_size; ++i) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
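// Index layout note: softmax here treats the input as [outer_num, axis_size,
// inner_num] flattened row-major, so element (o, j, k) lives at
// (o * axis_size + j) * inner_num + k -- hence real_index advancing by
// inner_num per step along the softmax axis. A hypothetical helper making
// that explicit:
inline int softmax_index(int outer, int axis, int inner, int axis_size,
                         int inner_num) {
  return (outer * axis_size + axis) * inner_num + inner;
}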
template <>
void softmax_inner8_axis4<float>(const float* din, float* dout,
const int axis_size, const int inner_num,
const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 3;
int remain = compute_size % 8;
float32x4_t vone = vdupq_n_f32(1.0f);
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 8;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float32x4_t vdata01 = vld1q_f32(din_ptr + 4);
float32x4_t vdata11 = vld1q_f32(din_ptr1 + 4);
float32x4_t vdata21 = vld1q_f32(din_ptr2 + 4);
float32x4_t vdata31 = vld1q_f32(din_ptr3 + 4);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float32x4_t vmax11 = vmaxq_f32(vdata01, vdata11);
float32x4_t vmax21 = vmaxq_f32(vdata21, vdata31);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
float32x4_t vmax_1 = vmaxq_f32(vmax11, vmax21);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum01 = exp_ps(vsubq_f32(vdata01, vmax_1));
float32x4_t vsum11 = exp_ps(vsubq_f32(vdata11, vmax_1));
float32x4_t vsum21 = exp_ps(vsubq_f32(vdata21, vmax_1));
float32x4_t vsum31 = exp_ps(vsubq_f32(vdata31, vmax_1));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum_11 = vaddq_f32(vsum01, vsum11);
float32x4_t vsum_21 = vaddq_f32(vsum21, vsum31);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vsum111 = vaddq_f32(vsum_11, vsum_21);
float32x4_t vinf = div_ps(vone, vsum);
float32x4_t vinf1 = div_ps(vone, vsum111);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vsum01 = vmulq_f32(vsum01, vinf1);
vsum11 = vmulq_f32(vsum11, vinf1);
vsum21 = vmulq_f32(vsum21, vinf1);
vsum31 = vmulq_f32(vsum31, vinf1);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
vst1q_f32(dout_ptr0 + 4, vsum01);
vst1q_f32(dout_ptr1 + 4, vsum11);
vst1q_f32(dout_ptr2 + 4, vsum21);
vst1q_f32(dout_ptr3 + 4, vsum31);
}
int i = cmp_cnt * 8;
if (remain > 4) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
i += 4;
}
for (; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner4_axis4<float>(const float* din, float* dout,
const int axis_size, const int inner_num,
const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 2;
int remain = compute_size % 4;
float32x4_t vone = vdupq_n_f32(1.0f);
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 4;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// get max axis_size == 4
const float* din_ptr = din + real_index;
const float* din_ptr1 = din_ptr + inner_num;
const float* din_ptr2 = din_ptr1 + inner_num;
const float* din_ptr3 = din_ptr2 + inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr1);
float32x4_t vdata2 = vld1q_f32(din_ptr2);
float32x4_t vdata3 = vld1q_f32(din_ptr3);
float* dout_ptr0 = dout + real_index;
float* dout_ptr1 = dout_ptr0 + inner_num;
float32x4_t vmax1 = vmaxq_f32(vdata0, vdata1);
float32x4_t vmax2 = vmaxq_f32(vdata2, vdata3);
float* dout_ptr2 = dout_ptr1 + inner_num;
float* dout_ptr3 = dout_ptr2 + inner_num;
float32x4_t vmax = vmaxq_f32(vmax1, vmax2);
// sub, exp and sum
float32x4_t vsum0 = exp_ps(vsubq_f32(vdata0, vmax));
float32x4_t vsum1 = exp_ps(vsubq_f32(vdata1, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax));
float32x4_t vsum3 = exp_ps(vsubq_f32(vdata3, vmax));
float32x4_t vsum_1 = vaddq_f32(vsum0, vsum1);
float32x4_t vsum_2 = vaddq_f32(vsum2, vsum3);
float32x4_t vsum = vaddq_f32(vsum_1, vsum_2);
float32x4_t vinf = div_ps(vone, vsum);
vsum0 = vmulq_f32(vsum0, vinf);
vsum1 = vmulq_f32(vsum1, vinf);
vsum2 = vmulq_f32(vsum2, vinf);
vsum3 = vmulq_f32(vsum3, vinf);
vst1q_f32(dout_ptr0, vsum0);
vst1q_f32(dout_ptr1, vsum1);
vst1q_f32(dout_ptr2, vsum2);
vst1q_f32(dout_ptr3, vsum3);
}
  int i = cmp_cnt * 4;  // the vector loop above consumes 4 elements per pass
for (; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner8<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 3;
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 8;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
const float* din_ptr = din + real_index;
float32x4_t vmax = vld1q_f32(din_ptr);
float32x4_t vmax2 = vld1q_f32(din_ptr + 4);
// get max
for (int j = 1; j < axis_size; ++j) {
din_ptr += inner_num;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vdata2 = vld1q_f32(din_ptr + 4);
vmax = vmaxq_f32(vmax, vdata);
vmax2 = vmaxq_f32(vmax2, vdata2);
}
// sub, exp and sum
din_ptr = din + real_index;
float* dout_ptr = dout + real_index;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vdata2 = vld1q_f32(din_ptr + 4);
float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax));
float32x4_t vsum2 = exp_ps(vsubq_f32(vdata2, vmax2));
din_ptr += inner_num;
vst1q_f32(dout_ptr, vsum);
vst1q_f32(dout_ptr + 4, vsum2);
dout_ptr += inner_num;
for (int j = 1; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(din_ptr);
float32x4_t vdata1 = vld1q_f32(din_ptr + 4);
vdata0 = exp_ps(vsubq_f32(vdata0, vmax));
vdata1 = exp_ps(vsubq_f32(vdata1, vmax2));
din_ptr += inner_num;
vsum = vaddq_f32(vsum, vdata0);
vsum2 = vaddq_f32(vsum2, vdata1);
vst1q_f32(dout_ptr, vdata0);
vst1q_f32(dout_ptr + 4, vdata1);
dout_ptr += inner_num;
}
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
float32x4_t vinf2 = div_ps(vone, vsum2);
dout_ptr = dout + real_index;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(dout_ptr);
float32x4_t vdata1 = vld1q_f32(dout_ptr + 4);
vdata0 = vmulq_f32(vdata0, vinf);
vdata1 = vmulq_f32(vdata1, vinf2);
vst1q_f32(dout_ptr, vdata0);
vst1q_f32(dout_ptr + 4, vdata1);
dout_ptr += inner_num;
}
}
for (int i = cmp_cnt * 8; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner4<float>(const float* din, float* dout, const int axis_size,
const int inner_num, const int outer_num) {
int compute_size = inner_num * outer_num;
int cmp_cnt = compute_size >> 2;
#pragma omp parallel for
for (int c = 0; c < cmp_cnt; ++c) {
int i = c * 4;
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
// float max_data = din[real_index];
const float* din_ptr = din + real_index;
float32x4_t vmax = vld1q_f32(din_ptr);
// get max
for (int j = 1; j < axis_size; ++j) {
din_ptr += inner_num;
float32x4_t vdata = vld1q_f32(din_ptr);
vmax = vmaxq_f32(vmax, vdata);
}
// sub, exp and sum
din_ptr = din + real_index;
float* dout_ptr = dout + real_index;
float32x4_t vdata = vld1q_f32(din_ptr);
float32x4_t vsum = exp_ps(vsubq_f32(vdata, vmax));
din_ptr += inner_num;
vst1q_f32(dout_ptr, vsum);
dout_ptr += inner_num;
for (int j = 1; j < axis_size; ++j) {
// real_index += inner_num;
float32x4_t vdata0 = vld1q_f32(din_ptr);
vdata0 = exp_ps(vsubq_f32(vdata0, vmax));
din_ptr += inner_num;
vsum = vaddq_f32(vsum, vdata0);
vst1q_f32(dout_ptr, vdata0);
dout_ptr += inner_num;
}
float32x4_t vone = vdupq_n_f32(1.0f);
float32x4_t vinf = div_ps(vone, vsum);
dout_ptr = dout + real_index;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
float32x4_t vdata0 = vld1q_f32(dout_ptr);
vdata0 = vmulq_f32(vdata0, vinf);
vst1q_f32(dout_ptr, vdata0);
dout_ptr += inner_num;
}
}
for (int i = cmp_cnt * 4; i < compute_size; i++) {
int idx_inner = i % inner_num;
int idx_outer = (i / inner_num) * axis_size;
int real_index = idx_outer * inner_num + idx_inner;
float max_data = din[real_index];
// get max
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
max_data = din[real_index] > max_data ? din[real_index] : max_data;
}
real_index = idx_outer * inner_num + idx_inner;
// sub, exp and sum
dout[real_index] = expf(din[real_index] - max_data);
float sum_data = dout[real_index];
for (int j = 1; j < axis_size; ++j) {
real_index += inner_num;
dout[real_index] = expf(din[real_index] - max_data);
sum_data += dout[real_index];
}
float sum_inv = 1.f / sum_data;
real_index = idx_outer * inner_num + idx_inner;
// get softmax result
for (int j = 0; j < axis_size; ++j) {
dout[real_index] *= sum_inv;
real_index += inner_num;
}
}
}
template <>
void softmax_inner1_large_axis<float>(const float* din, float* dout,
const int outer_size,
const int axis_size) {
#pragma omp parallel for
for (int i = 0; i < outer_size; ++i) {
const float* din_ptr = din + i * axis_size;
float* dout_ptr = dout + i * axis_size;
const float* din_max_ptr = din_ptr;
int nn = axis_size >> 2;
// get max
float32x4_t vmax = vld1q_f32(din_max_ptr);
din_max_ptr += 4;
int j = 1;
for (; j < nn; ++j) {
vmax = vmaxq_f32(vmax, vld1q_f32(din_max_ptr));
din_max_ptr += 4;
}
float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
for (j = 4 * j; j < axis_size; ++j) {
max_data = std::max(max_data, din_max_ptr[0]);
din_max_ptr++;
}
// sub, exp and sum
const float* din_sum_ptr = din_ptr;
float* dout_sum_ptr = dout_ptr;
vmax = vdupq_n_f32(max_data);
float32x4_t vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
float32x4_t vsum = vsub_exp;
vst1q_f32(dout_sum_ptr, vsub_exp);
din_sum_ptr += 4;
dout_sum_ptr += 4;
j = 1;
for (; j < nn; ++j) {
vsub_exp = exp_ps(vsubq_f32(vld1q_f32(din_sum_ptr), vmax));
vst1q_f32(dout_sum_ptr, vsub_exp);
vsum = vaddq_f32(vsum, vsub_exp);
din_sum_ptr += 4;
dout_sum_ptr += 4;
}
float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
for (j = 4 * j; j < axis_size; ++j) {
dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
sum_data += dout_sum_ptr[0];
din_sum_ptr++;
dout_sum_ptr++;
}
float sum_inv = 1.f / sum_data;
float* dout_res_ptr = dout_ptr;
float32x4_t vinv = vdupq_n_f32(sum_inv);
// get softmax result
j = 0;
for (; j < nn; ++j) {
float32x4_t vout = vld1q_f32(dout_res_ptr);
float32x4_t vres = vmulq_f32(vout, vinv);
vst1q_f32(dout_res_ptr, vres);
dout_res_ptr += 4;
}
for (j = nn * 4; j < axis_size; ++j) {
dout_ptr[j] *= sum_inv;
}
}
}
template <>
void softmax_inner1_small_axis<float>(const float* din, float* dout,
const int outer_size,
const int axis_size) {
#pragma omp parallel for
for (int i = 0; i < outer_size; ++i) {
const float* din_ptr = din + i * axis_size;
float* dout_ptr = dout + i * axis_size;
// get max
float max_data = din_ptr[0];
for (int j = 1; j < axis_size; ++j) {
max_data = std::max(max_data, din_ptr[j]);
}
// sub, exp and sum
float sum_data = 0.f;
for (int j = 0; j < axis_size; ++j) {
dout_ptr[j] = expf(din_ptr[j] - max_data);
sum_data += dout_ptr[j];
}
float sum_inv = 1.f / sum_data;
for (int j = 0; j < axis_size; ++j) {
dout_ptr[j] *= sum_inv;
}
}
}
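// A hedged dispatch sketch (not this file's API -- the real selection logic
// lives in the softmax op; the name and cutoffs below are assumptions). It
// records each specialization's precondition: the inner4/inner8 paths load
// contiguous vectors within one inner block, so inner_num must be divisible
// by the vector width, and the inner1 "large axis" path needs at least one
// full 4-lane block along the axis.
inline void softmax_dispatch(const float* din, float* dout, int axis_size,
                             int inner_num, int outer_num) {
  if (inner_num == 1) {
    if (axis_size >= 4) {
      softmax_inner1_large_axis<float>(din, dout, outer_num, axis_size);
    } else {
      softmax_inner1_small_axis<float>(din, dout, outer_num, axis_size);
    }
  } else if (axis_size == 4 && inner_num % 8 == 0) {
    softmax_inner8_axis4<float>(din, dout, axis_size, inner_num, outer_num);
  } else if (axis_size == 4 && inner_num % 4 == 0) {
    softmax_inner4_axis4<float>(din, dout, axis_size, inner_num, outer_num);
  } else if (inner_num % 8 == 0) {
    softmax_inner8<float>(din, dout, axis_size, inner_num, outer_num);
  } else if (inner_num % 4 == 0) {
    softmax_inner4<float>(din, dout, axis_size, inner_num, outer_num);
  } else {
    softmax_basic<float>(din, dout, axis_size, inner_num, outer_num);
  }
}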
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void softmax_basic(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner8_axis4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner4_axis4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner8(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner4(const T* din, T* dout, const int axis_size,
const int inner_num, const int outer_num);
template <typename T>
void softmax_inner1_large_axis(const T* din, T* dout, const int outer_size,
const int axis_size);
template <typename T>
void softmax_inner1_small_axis(const T* din, T* dout, const int outer_size,
const int axis_size);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/arm/math/split.h"
#include <algorithm>
#include "paddle/fluid/lite/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <>
void split_cpy<float>(const float* din, float* dout, int num) {
int cnt = num >> 4;
int remain = num % 16;
#pragma omp parallel for
for (int i = 0; i < cnt; i++) {
const float* din_ptr = din + (i << 4);
float* dout_ptr = dout + (i << 4);
float32x4_t din0 = vld1q_f32(din_ptr);
float32x4_t din1 = vld1q_f32(din_ptr + 4);
float32x4_t din2 = vld1q_f32(din_ptr + 8);
float32x4_t din3 = vld1q_f32(din_ptr + 12);
vst1q_f32(dout_ptr, din0);
vst1q_f32(dout_ptr + 4, din1);
vst1q_f32(dout_ptr + 8, din2);
vst1q_f32(dout_ptr + 12, din3);
}
if (remain > 0) {
const float* din_ptr = din + (cnt << 4);
float* dout_ptr = dout + (cnt << 4);
for (int i = 0; i < remain; i++) {
*dout_ptr = *din_ptr;
dout_ptr++;
din_ptr++;
}
}
}
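// Note: split_cpy<float> is behaviorally a plain element copy (equivalent to
// memcpy for non-overlapping float buffers); the NEON body just keeps the
// 16-wide load/store pattern consistent with the other kernels here.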
template <>
void split<float>(const float* din, const std::vector<lite::Tensor*>& dout,
const int axis, const std::vector<int>& in_strides) {
int input_offset = 0;
for (auto out : dout) {
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
for (int i = out_dim.size() - 2; i >= 0; --i) {
out_strides[i] = out_strides[i + 1] * out_dim[i];
}
float* out_data = out->mutable_data<float>();
int before = out_strides[0] / out_strides[axis];
int in_after = in_strides[axis];
int out_after = out_strides[axis];
for (int i = 0; i < before; ++i) {
split_cpy(din + input_offset + i * in_after, out_data + i * out_after,
out_after);
}
input_offset += out_strides[axis];
}
}
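// The in_strides argument follows the same convention as the out_strides
// computed inside split: strides[i] is the number of elements spanned by
// axes i..rank-1 (a suffix product that includes dims[i]). A hedged helper
// sketch (the name make_strides is an assumption):
inline std::vector<int> make_strides(const std::vector<int>& dims) {
  std::vector<int> strides(dims.size());
  strides.back() = dims.back();
  for (int i = static_cast<int>(dims.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i];
  }
  return strides;
}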
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
template <typename T>
void split_cpy(const T* din, T* dout, int num);
template <typename T>
void split(const T* din, const std::vector<lite::Tensor*>& dout, const int axis,
const std::vector<int>& in_strides);
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle