Unverified commit 27d5295c, authored by BUG1989, committed by GitHub

fix bugs of avx2 (#360)

Parent 342f780e
......@@ -40,7 +40,7 @@ int conv_hcl_run(struct ir_tensor* input_tensor, struct ir_tensor* filter_tensor
int conv_hcl_get_shared_mem_size(struct ir_tensor* input_tensor, struct ir_tensor* output_tensor,
struct conv_param* param) __attribute__((weak));
int conv_hcl_get_shared_pack4_mem_size(struct ir_tensor* input_tensor, struct ir_tensor* output_tensor,
struct conv_param* param) __attribute__((weak));
struct conv_param* param) __attribute__((weak));
int conv_hcl_set_shared_mem(struct conv_priv_info* priv_info, void* mem, int mem_size) __attribute__((weak));
int conv_hcl_set_shared_pack4_mem(struct conv_priv_info* priv_info, void* mem, int mem_size) __attribute__((weak));
......
......@@ -115,6 +115,692 @@ static void im2col_ir(struct ir_tensor* input, struct ir_tensor* output, struct
param->pad_h0, param->pad_w0, param->dilation_h, param->dilation_w);
}
#if __AVX__
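/* input_pack4 (AVX path): repack the im2col matrix B (K rows x N columns) into
   panels of 8 consecutive columns, so that sgemm below can read one output tile
   with a single contiguous 256-bit load per k step.
   pB   : source, row-major K x N
   pB_t : destination, N/8 panels of 8*K floats, followed by the N%8 leftover
          columns (each leftover column still occupies its own 8*K slot). */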
void input_pack4(int K, int N, float* pB, float* pB_t, int num_thread)
{
int nn_size = N >> 3;
int remain_size_start = nn_size << 3;
// [ch00, ch10, ch20, ch30, ch40, ch50, ch60, ch70, ch01, ch11, ch21, ch31, ch41, ch51, ch61, ch71 ....]
#pragma omp parallel for num_threads(num_thread)
for (int ii = 0; ii < nn_size; ii++)
{
int i = ii * 8;
const float* img = pB + i;
float* tmp = pB_t + (i / 8) * 8 * K;
for (int j = 0; j < K; j++)
{
#if __AVX__
_mm256_storeu_ps(tmp, _mm256_loadu_ps(img));
#else
tmp[0] = img[0];
tmp[1] = img[1];
tmp[2] = img[2];
tmp[3] = img[3];
tmp[4] = img[4];
tmp[5] = img[5];
tmp[6] = img[6];
tmp[7] = img[7];
#endif // __AVX__
tmp += 8;
img += N;
}
}
// [ch00, ch01, ch02, ch03 ....]
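// Tail columns (N % 8): each remaining column is copied on its own; the index
// (i / 8 + i % 8) * 8 * K skips past the full 8-column panels and gives every
// leftover column its own 8*K slot (only K floats of it are written).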
#pragma omp parallel for num_threads(num_thread)
for (int i = remain_size_start; i < N; i++)
{
const float* img = pB + i;
float* tmp = pB_t + (i / 8 + i % 8) * 8 * K;
for (int j = 0; j < K; j++)
{
tmp[0] = img[0];
tmp += 1;
img += N;
}
}
}
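/* sgemm: C(M x N) = A(M x K) * B(K x N) on the packed operands.
   Output rows (output channels) are processed in tiles of 8, then 4, then 1;
   within a tile, 8 output columns are produced per iteration with AVX FMA,
   and the k loop is unrolled by 4. pA_t / pB_t must already be packed by
   conv_hcl_interleave_pack4 / input_pack4. */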
static void sgemm(int M, int N, int K, float* pA_t, float* pB_t, float* pC, int num_thread)
{
int nn_outch = 0;
int remain_outch_start = 0;
nn_outch = M >> 3;
remain_outch_start = nn_outch << 3;
#pragma omp parallel for num_threads(num_thread)
for (int pp = 0; pp < nn_outch; pp++)
{
int i = pp * 8;
float* output0 = pC + i * N;
float* output1 = pC + (i + 1) * N;
float* output2 = pC + (i + 2) * N;
float* output3 = pC + (i + 3) * N;
float* output4 = pC + (i + 4) * N;
float* output5 = pC + (i + 5) * N;
float* output6 = pC + (i + 6) * N;
float* output7 = pC + (i + 7) * N;
int j = 0;
for (; j + 7 < N; j += 8)
{
float* va = pA_t + (i / 8) * 8 * K;
float* vb = pB_t + (j / 8) * 8 * K;
#if __AVX__
__m256 _sum0 = _mm256_set1_ps(0.0);
__m256 _sum1 = _mm256_set1_ps(0.0);
__m256 _sum2 = _mm256_set1_ps(0.0);
__m256 _sum3 = _mm256_set1_ps(0.0);
__m256 _sum4 = _mm256_set1_ps(0.0);
__m256 _sum5 = _mm256_set1_ps(0.0);
__m256 _sum6 = _mm256_set1_ps(0.0);
__m256 _sum7 = _mm256_set1_ps(0.0);
int k = 0;
for (; k + 3 < K; k = k + 4)
{
// k0
__m256 _va0 = _mm256_broadcast_ss(va);
__m256 _va1 = _mm256_broadcast_ss(va + 1);
__m256 _va2 = _mm256_broadcast_ss(va + 2);
__m256 _va3 = _mm256_broadcast_ss(va + 3);
__m256 _vb0 = _mm256_loadu_ps(vb);
__m256 _vb1 = _mm256_loadu_ps(vb + 8);
__m256 _vb2 = _mm256_loadu_ps(vb + 16);
__m256 _vb3 = _mm256_loadu_ps(vb + 24);
_sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
_sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
_sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
_sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
_va0 = _mm256_broadcast_ss(va + 4);
_va1 = _mm256_broadcast_ss(va + 5);
_va2 = _mm256_broadcast_ss(va + 6);
_va3 = _mm256_broadcast_ss(va + 7);
_sum4 = _mm256_fmadd_ps(_vb0, _va0, _sum4); // sum4 = (a00-a07) * k40
_sum5 = _mm256_fmadd_ps(_vb0, _va1, _sum5); // sum5 = (a00-a07) * k50
_sum6 = _mm256_fmadd_ps(_vb0, _va2, _sum6); // sum6 = (a00-a07) * k60
_sum7 = _mm256_fmadd_ps(_vb0, _va3, _sum7); // sum7 = (a00-a07) * k70
va += 8;
// k1
_va0 = _mm256_broadcast_ss(va);
_va1 = _mm256_broadcast_ss(va + 1);
_va2 = _mm256_broadcast_ss(va + 2);
_va3 = _mm256_broadcast_ss(va + 3);
_sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
_sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
_sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
_sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
_va0 = _mm256_broadcast_ss(va + 4);
_va1 = _mm256_broadcast_ss(va + 5);
_va2 = _mm256_broadcast_ss(va + 6);
_va3 = _mm256_broadcast_ss(va + 7);
_sum4 = _mm256_fmadd_ps(_vb1, _va0, _sum4); // sum4 += (a10-a17) * k41
_sum5 = _mm256_fmadd_ps(_vb1, _va1, _sum5); // sum5 += (a10-a17) * k51
_sum6 = _mm256_fmadd_ps(_vb1, _va2, _sum6); // sum6 += (a10-a17) * k61
_sum7 = _mm256_fmadd_ps(_vb1, _va3, _sum7); // sum7 += (a10-a17) * k71
va += 8;
// k2
_va0 = _mm256_broadcast_ss(va);
_va1 = _mm256_broadcast_ss(va + 1);
_va2 = _mm256_broadcast_ss(va + 2);
_va3 = _mm256_broadcast_ss(va + 3);
_sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
_sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
_sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
_sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
_va0 = _mm256_broadcast_ss(va + 4);
_va1 = _mm256_broadcast_ss(va + 5);
_va2 = _mm256_broadcast_ss(va + 6);
_va3 = _mm256_broadcast_ss(va + 7);
_sum4 = _mm256_fmadd_ps(_vb2, _va0, _sum4); // sum4 += (a20-a27) * k42
_sum5 = _mm256_fmadd_ps(_vb2, _va1, _sum5); // sum5 += (a20-a27) * k52
_sum6 = _mm256_fmadd_ps(_vb2, _va2, _sum6); // sum6 += (a20-a27) * k62
_sum7 = _mm256_fmadd_ps(_vb2, _va3, _sum7); // sum7 += (a20-a27) * k72
va += 8;
// k3
_va0 = _mm256_broadcast_ss(va);
_va1 = _mm256_broadcast_ss(va + 1);
_va2 = _mm256_broadcast_ss(va + 2);
_va3 = _mm256_broadcast_ss(va + 3);
_sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
_sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
_sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
_sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
_va0 = _mm256_broadcast_ss(va + 4);
_va1 = _mm256_broadcast_ss(va + 5);
_va2 = _mm256_broadcast_ss(va + 6);
_va3 = _mm256_broadcast_ss(va + 7);
_sum4 = _mm256_fmadd_ps(_vb3, _va0, _sum4); // sum4 += (a30-a37) * k43
_sum5 = _mm256_fmadd_ps(_vb3, _va1, _sum5); // sum5 += (a30-a37) * k53
_sum6 = _mm256_fmadd_ps(_vb3, _va2, _sum6); // sum6 += (a30-a37) * k63
_sum7 = _mm256_fmadd_ps(_vb3, _va3, _sum7); // sum7 += (a30-a37) * k73
va += 8;
vb += 32;
}
for (; k < K; k++)
{
// k0
__m256 _va0 = _mm256_broadcast_ss(va);
__m256 _va1 = _mm256_broadcast_ss(va + 1);
__m256 _va2 = _mm256_broadcast_ss(va + 2);
__m256 _va3 = _mm256_broadcast_ss(va + 3);
__m256 _va4 = _mm256_broadcast_ss(va + 4);
__m256 _va5 = _mm256_broadcast_ss(va + 5);
__m256 _va6 = _mm256_broadcast_ss(va + 6);
__m256 _va7 = _mm256_broadcast_ss(va + 7);
__m256 _vb0 = _mm256_loadu_ps(vb);
_sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
_sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
_sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
_sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
_sum4 = _mm256_fmadd_ps(_vb0, _va4, _sum4); // sum4 = (a00-a07) * k40
_sum5 = _mm256_fmadd_ps(_vb0, _va5, _sum5); // sum5 = (a00-a07) * k50
_sum6 = _mm256_fmadd_ps(_vb0, _va6, _sum6); // sum6 = (a00-a07) * k60
_sum7 = _mm256_fmadd_ps(_vb0, _va7, _sum7); // sum7 = (a00-a07) * k70
va += 8;
vb += 8;
}
_mm256_storeu_ps(output0, _sum0);
_mm256_storeu_ps(output1, _sum1);
_mm256_storeu_ps(output2, _sum2);
_mm256_storeu_ps(output3, _sum3);
_mm256_storeu_ps(output4, _sum4);
_mm256_storeu_ps(output5, _sum5);
_mm256_storeu_ps(output6, _sum6);
_mm256_storeu_ps(output7, _sum7);
#else
float sum0[8] = {0};
float sum1[8] = {0};
float sum2[8] = {0};
float sum3[8] = {0};
float sum4[8] = {0};
float sum5[8] = {0};
float sum6[8] = {0};
float sum7[8] = {0};
for (int k = 0; k < K; k++)
{
for (int n = 0; n < 8; n++)
{
sum0[n] += va[0] * vb[n];
sum1[n] += va[1] * vb[n];
sum2[n] += va[2] * vb[n];
sum3[n] += va[3] * vb[n];
sum4[n] += va[4] * vb[n];
sum5[n] += va[5] * vb[n];
sum6[n] += va[6] * vb[n];
sum7[n] += va[7] * vb[n];
}
va += 8;
vb += 8;
}
for (int n = 0; n < 8; n++)
{
output0[n] = sum0[n];
output1[n] = sum1[n];
output2[n] = sum2[n];
output3[n] = sum3[n];
output4[n] = sum4[n];
output5[n] = sum5[n];
output6[n] = sum6[n];
output7[n] = sum7[n];
}
#endif // __AVX__
output0 += 8;
output1 += 8;
output2 += 8;
output3 += 8;
output4 += 8;
output5 += 8;
output6 += 8;
output7 += 8;
}
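// Tail columns (N % 8): one column at a time; the eight per-channel sums are
// kept in a single __m256 accumulator and then scattered to output0..output7.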
for (; j < N; j++)
{
float* va = pA_t + (i / 8) * 8 * K;
float* vb = pB_t + (j / 8 + j % 8) * 8 * K;
#if __AVX__
__m256 _sum0_7 = _mm256_set1_ps(0.0);
__m256 _sum0 = _mm256_set1_ps(0.0);
__m256 _sum1 = _mm256_set1_ps(0.0);
__m256 _sum2 = _mm256_set1_ps(0.0);
__m256 _sum3 = _mm256_set1_ps(0.0);
int k = 0;
for (; k + 3 < K; k = k + 4)
{
__m256 _vb0 = _mm256_broadcast_ss(vb);
__m256 _vb1 = _mm256_broadcast_ss(vb + 1);
__m256 _vb2 = _mm256_broadcast_ss(vb + 2);
__m256 _vb3 = _mm256_broadcast_ss(vb + 3);
__m256 _va0 = _mm256_loadu_ps(va);
__m256 _va1 = _mm256_loadu_ps(va + 8);
__m256 _va2 = _mm256_loadu_ps(va + 16);
__m256 _va3 = _mm256_loadu_ps(va + 24);
_sum0 = _mm256_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k70) * a00
_sum1 = _mm256_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k71) * a10
_sum2 = _mm256_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k72) * a20
_sum3 = _mm256_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k73) * a30
va += 32;
vb += 4;
}
_sum0 = _mm256_add_ps(_sum0, _sum1);
_sum2 = _mm256_add_ps(_sum2, _sum3);
_sum0_7 = _mm256_add_ps(_sum0_7, _sum0);
_sum0_7 = _mm256_add_ps(_sum0_7, _sum2);
for (; k < K; k++)
{
__m256 _vb0 = _mm256_broadcast_ss(vb);
__m256 _va = _mm256_loadu_ps(va);
_sum0_7 = _mm256_fmadd_ps(_va, _vb0, _sum0_7); // sum0 += (k00-k70) * a00
va += 8;
vb += 1;
}
float output_sum0_7[8] = {0.f};
_mm256_storeu_ps(output_sum0_7, _sum0_7);
output0[0] = output_sum0_7[0];
output1[0] = output_sum0_7[1];
output2[0] = output_sum0_7[2];
output3[0] = output_sum0_7[3];
output4[0] = output_sum0_7[4];
output5[0] = output_sum0_7[5];
output6[0] = output_sum0_7[6];
output7[0] = output_sum0_7[7];
#else
float sum0 = 0;
float sum1 = 0;
float sum2 = 0;
float sum3 = 0;
float sum4 = 0;
float sum5 = 0;
float sum6 = 0;
float sum7 = 0;
for (int k = 0; k < K; k++)
{
sum0 += va[0] * vb[0];
sum1 += va[1] * vb[0];
sum2 += va[2] * vb[0];
sum3 += va[3] * vb[0];
sum4 += va[4] * vb[0];
sum5 += va[5] * vb[0];
sum6 += va[6] * vb[0];
sum7 += va[7] * vb[0];
va += 8;
vb += 1;
}
output0[0] = sum0;
output1[0] = sum1;
output2[0] = sum2;
output3[0] = sum3;
output4[0] = sum4;
output5[0] = sum5;
output6[0] = sum6;
output7[0] = sum7;
#endif // __AVX__
output0++;
output1++;
output2++;
output3++;
output4++;
output5++;
output6++;
output7++;
}
}
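// Remaining output channels (M % 8) are handled four at a time (this loop runs
// at most once) and then one at a time further below.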
nn_outch = (M - remain_outch_start) >> 2;
for (int pp = 0; pp < nn_outch; pp++)
{
int i = remain_outch_start + pp * 4;
float* output0 = pC + i * N;
float* output1 = pC + (i + 1) * N;
float* output2 = pC + (i + 2) * N;
float* output3 = pC + (i + 3) * N;
int j = 0;
for (; j + 7 < N; j += 8)
{
float* va = pA_t + (i / 8 + (i % 8) / 4) * 8 * K;
float* vb = pB_t + (j / 8) * 8 * K;
#if __AVX__
__m256 _sum0 = _mm256_set1_ps(0.0);
__m256 _sum1 = _mm256_set1_ps(0.0);
__m256 _sum2 = _mm256_set1_ps(0.0);
__m256 _sum3 = _mm256_set1_ps(0.0);
int k = 0;
for (; k + 3 < K; k = k + 4)
{
// k0
__m256 _va0 = _mm256_broadcast_ss(va);
__m256 _va1 = _mm256_broadcast_ss(va + 1);
__m256 _va2 = _mm256_broadcast_ss(va + 2);
__m256 _va3 = _mm256_broadcast_ss(va + 3);
__m256 _vb0 = _mm256_loadu_ps(vb);
__m256 _vb1 = _mm256_loadu_ps(vb + 8);
__m256 _vb2 = _mm256_loadu_ps(vb + 16);
__m256 _vb3 = _mm256_loadu_ps(vb + 24);
_sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
_sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
_sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
_sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
va += 4;
// k1
_va0 = _mm256_broadcast_ss(va);
_va1 = _mm256_broadcast_ss(va + 1);
_va2 = _mm256_broadcast_ss(va + 2);
_va3 = _mm256_broadcast_ss(va + 3);
_sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
_sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
_sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
_sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
va += 4;
// k2
_va0 = _mm256_broadcast_ss(va);
_va1 = _mm256_broadcast_ss(va + 1);
_va2 = _mm256_broadcast_ss(va + 2);
_va3 = _mm256_broadcast_ss(va + 3);
_sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
_sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
_sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
_sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
va += 4;
// k3
_va0 = _mm256_broadcast_ss(va);
_va1 = _mm256_broadcast_ss(va + 1);
_va2 = _mm256_broadcast_ss(va + 2);
_va3 = _mm256_broadcast_ss(va + 3);
_sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
_sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
_sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
_sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
va += 4;
vb += 32;
}
for (; k < K; k++)
{
// k0
__m256 _va0 = _mm256_broadcast_ss(va);
__m256 _va1 = _mm256_broadcast_ss(va + 1);
__m256 _va2 = _mm256_broadcast_ss(va + 2);
__m256 _va3 = _mm256_broadcast_ss(va + 3);
__m256 _vb0 = _mm256_loadu_ps(vb);
_sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
_sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
_sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
_sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
va += 4;
vb += 8;
}
_mm256_storeu_ps(output0, _sum0);
_mm256_storeu_ps(output1, _sum1);
_mm256_storeu_ps(output2, _sum2);
_mm256_storeu_ps(output3, _sum3);
#else
float sum0[8] = {0};
float sum1[8] = {0};
float sum2[8] = {0};
float sum3[8] = {0};
for (int k = 0; k < K; k++)
{
for (int n = 0; n < 8; n++)
{
sum0[n] += va[0] * vb[n];
sum1[n] += va[1] * vb[n];
sum2[n] += va[2] * vb[n];
sum3[n] += va[3] * vb[n];
}
va += 4;
vb += 8;
}
for (int n = 0; n < 8; n++)
{
output0[n] = sum0[n];
output1[n] = sum1[n];
output2[n] = sum2[n];
output3[n] = sum3[n];
}
#endif // __AVX__
output0 += 8;
output1 += 8;
output2 += 8;
output3 += 8;
}
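// Tail columns (N % 8) for the 4-channel tile: one column at a time,
// accumulating the four channel sums in a single __m128 register.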
for (; j < N; j++)
{
float* va = pA_t + (i / 8 + (i % 8) / 4) * 8 * K;
float* vb = pB_t + (j / 8 + j % 8) * 8 * K;
#if __AVX__
__m128 _sum0_3 = _mm_set1_ps(0.0);
__m128 _sum0 = _mm_set1_ps(0.0);
__m128 _sum1 = _mm_set1_ps(0.0);
__m128 _sum2 = _mm_set1_ps(0.0);
__m128 _sum3 = _mm_set1_ps(0.0);
int k = 0;
for (; k + 3 < K; k = k + 4)
{
__m128 _vb0 = _mm_set1_ps(vb[0]);
__m128 _vb1 = _mm_set1_ps(vb[1]);
__m128 _vb2 = _mm_set1_ps(vb[2]);
__m128 _vb3 = _mm_set1_ps(vb[3]);
__m128 _va0 = _mm_loadu_ps(va);
__m128 _va1 = _mm_loadu_ps(va + 4);
__m128 _va2 = _mm_loadu_ps(va + 8);
__m128 _va3 = _mm_loadu_ps(va + 12);
_sum0 = _mm_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k30) * a00
_sum1 = _mm_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k31) * a10
_sum2 = _mm_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k32) * a20
_sum3 = _mm_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k33) * a30
va += 16;
vb += 4;
}
_sum0 = _mm_add_ps(_sum0, _sum1);
_sum2 = _mm_add_ps(_sum2, _sum3);
_sum0_3 = _mm_add_ps(_sum0_3, _sum0);
_sum0_3 = _mm_add_ps(_sum0_3, _sum2);
for (; k < K; k++)
{
__m128 _vb0 = _mm_set1_ps(vb[0]);
__m128 _va = _mm_loadu_ps(va);
_sum0_3 = _mm_fmadd_ps(_va, _vb0, _sum0_3); // sum0 += (k00-k30) * a00
va += 4;
vb += 1;
}
float output_sum0_3[4] = {0.f};
_mm_storeu_ps(output_sum0_3, _sum0_3);
output0[0] = output_sum0_3[0];
output1[0] = output_sum0_3[1];
output2[0] = output_sum0_3[2];
output3[0] = output_sum0_3[3];
#else
float sum0 = 0;
float sum1 = 0;
float sum2 = 0;
float sum3 = 0;
for (int k = 0; k < K; k++)
{
sum0 += va[0] * vb[0];
sum1 += va[1] * vb[0];
sum2 += va[2] * vb[0];
sum3 += va[3] * vb[0];
va += 4;
vb += 1;
}
output0[0] = sum0;
output1[0] = sum1;
output2[0] = sum2;
output3[0] = sum3;
#endif // __AVX__
output0++;
output1++;
output2++;
output3++;
}
}
remain_outch_start += nn_outch << 2;
// remaining output channels, one at a time
for (int i = remain_outch_start; i < M; i++)
{
float* output = pC + i * N;
int j = 0;
for (; j + 7 < N; j += 8)
{
float* va = pA_t + (i / 8 + (i % 8) / 4 + i % 4) * 8 * K;
float* vb = pB_t + (j / 8) * 8 * K;
#if __AVX__
__m256 _sum0 = _mm256_set1_ps(0.0);
int k = 0;
for (; k + 3 < K; k = k + 4)
{
// k0
__m256 _va0 = _mm256_broadcast_ss(va);
__m256 _va1 = _mm256_broadcast_ss(va + 1);
__m256 _va2 = _mm256_broadcast_ss(va + 2);
__m256 _va3 = _mm256_broadcast_ss(va + 3);
__m256 _vb0 = _mm256_loadu_ps(vb);
__m256 _vb1 = _mm256_loadu_ps(vb + 8);
__m256 _vb2 = _mm256_loadu_ps(vb + 16);
__m256 _vb3 = _mm256_loadu_ps(vb + 24);
_sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
_sum0 = _mm256_fmadd_ps(_vb1, _va1, _sum0); // sum0 += (a10-a17) * k01
_sum0 = _mm256_fmadd_ps(_vb2, _va2, _sum0); // sum0 += (a20-a27) * k02
_sum0 = _mm256_fmadd_ps(_vb3, _va3, _sum0); // sum0 += (a30-a37) * k03
va += 4;
vb += 32;
}
for (; k < K; k++)
{
// k0
__m256 _va0 = _mm256_broadcast_ss(va);
__m256 _vb0 = _mm256_loadu_ps(vb);
_sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
va += 1;
vb += 8;
}
_mm256_storeu_ps(output, _sum0);
#else
float sum[8] = {0};
for (int k = 0; k < K; k++)
{
for (int n = 0; n < 8; n++)
{
sum[n] += va[0] * vb[n];
}
va += 1;
vb += 8;
}
for (int n = 0; n < 8; n++)
{
output[n] = sum[n];
}
#endif // __AVX__
output += 8;
}
for (; j < N; j++)
{
float* va = pA_t + (i / 8 + (i % 8) / 4 + i % 4) * 8 * K;
float* vb = pB_t + (j / 8 + j % 8) * 8 * K;
int k = 0;
#if __AVX__
__m128 _sum0 = _mm_set1_ps(0.f);
for (; k + 3 < K; k += 4)
{
__m128 _p0 = _mm_loadu_ps(vb);
__m128 _k0 = _mm_loadu_ps(va);
_sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_p0, _k0));
va += 4;
vb += 4;
}
float sum0 = _sum0[0] + _sum0[1] + _sum0[2] + _sum0[3];
#else
float sum0 = 0.f;
#endif // __AVX__
for (; k < K; k++)
{
sum0 += va[0] * vb[0];
va += 1;
vb += 1;
}
output[0] = sum0;
output++;
}
}
}
#else // SSE2
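/* Non-AVX build: keep the pre-existing 4-wide packing and the matching
   SSE2/scalar sgemm path. */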
void input_pack4(int K, int N, float* pB, float* pB_t, int num_thread)
{
int nn_size = N >> 2;
......@@ -159,7 +845,6 @@ void input_pack4(int K, int N, float* pB, float* pB_t, int num_thread)
}
}
}
// unroll over output M and N, pack 4x4, using intrinsics
static void sgemm(int M, int N, int K, float* pA_t, float* pB_t, float* pC, int num_thread)
{
......@@ -481,7 +1166,7 @@ static void sgemm(int M, int N, int K, float* pA_t, float* pB_t, float* pC, int
}
}
}
#endif // __AVX__
static void sgemm_fp32(struct ir_tensor* input, struct ir_tensor* filter, struct ir_tensor* bias,
struct ir_tensor* output, struct conv_priv_info* priv_info, struct conv_param* param, int n,
int group, int num_thread)
......@@ -587,21 +1272,123 @@ int conv_hcl_get_shared_mem_size(struct ir_tensor* input, struct ir_tensor* outp
return elem_size * output_xy * kernel_size;
}
#if __AVX__
int conv_hcl_get_shared_pack4_mem_size(struct ir_tensor* filter, struct ir_tensor* output, struct conv_param* param)
{
int K = filter->elem_num / filter->dims[0];
int N = output->dims[2] * output->dims[3];
int elem_size = filter->elem_size;
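// AVX path packs 8 columns per panel: N/8 full panels plus N%8 leftover
// columns, each occupying 8*K elements (see input_pack4 above).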
return (4 * K * (N / 4 + N % 4)) * elem_size;
return (8 * K * (N / 8 + N % 8)) * elem_size;
}
int conv_hcl_get_interleave_pack4_size(int M, int K, struct ir_tensor* filter)
{
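// Kernel interleave buffer: M/8 eight-row panels, (M%8)/4 four-row panels and
// M%4 single rows, each slot padded to 8*K elements to match the sgemm tiles.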
int size = 8 * K * (M / 8 + (M % 8) / 4 + M % 4) * filter->elem_size;
return size;
}
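/* conv_hcl_interleave_pack4 (AVX path): interleave the kernel matrix A (M x K)
   row-wise into the tile order used by sgemm: 8 rows, then 4 rows, then single
   rows, so that each k step reads the per-channel weights contiguously. */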
void conv_hcl_interleave_pack4(int M, int K, struct conv_priv_info* priv_info)
{
float* pA = ( float* )priv_info->interleave_buffer;
float* pA_t = ( float* )priv_info->interleave_buffer_pack4;
int nn_outch = M >> 3;
int remain_outch_start = nn_outch << 3;
for (int pp = 0; pp < nn_outch; pp++)
{
int p = pp * 8;
const float* k0 = pA + (p + 0) * K;
const float* k1 = pA + (p + 1) * K;
const float* k2 = pA + (p + 2) * K;
const float* k3 = pA + (p + 3) * K;
const float* k4 = pA + (p + 4) * K;
const float* k5 = pA + (p + 5) * K;
const float* k6 = pA + (p + 6) * K;
const float* k7 = pA + (p + 7) * K;
float* ktmp = pA_t + (p / 8) * 8 * K;
for (int q = 0; q < K; q++)
{
ktmp[0] = k0[0];
ktmp[1] = k1[0];
ktmp[2] = k2[0];
ktmp[3] = k3[0];
ktmp[4] = k4[0];
ktmp[5] = k5[0];
ktmp[6] = k6[0];
ktmp[7] = k7[0];
ktmp += 8;
k0 += 1;
k1 += 1;
k2 += 1;
k3 += 1;
k4 += 1;
k5 += 1;
k6 += 1;
k7 += 1;
}
}
nn_outch = (M - remain_outch_start) >> 2;
for (int pp = 0; pp < nn_outch; pp++)
{
int p = remain_outch_start + pp * 4;
const float* k0 = pA + (p + 0) * K;
const float* k1 = pA + (p + 1) * K;
const float* k2 = pA + (p + 2) * K;
const float* k3 = pA + (p + 3) * K;
float* ktmp = pA_t + (p / 8 + (p % 8) / 4) * 8 * K;
for (int q = 0; q < K; q++)
{
ktmp[0] = k0[0];
ktmp[1] = k1[0];
ktmp[2] = k2[0];
ktmp[3] = k3[0];
ktmp += 4;
k0 += 1;
k1 += 1;
k2 += 1;
k3 += 1;
}
}
remain_outch_start += nn_outch << 2;
for (int p = remain_outch_start; p < M; p++)
{
const float* k0 = pA + (p + 0) * K;
float* ktmp = pA_t + (p / 8 + (p % 8) / 4 + p % 4) * 8 * K;
for (int q = 0; q < K; q++)
{
ktmp[0] = k0[0];
ktmp++;
k0++;
}
}
}
#else
int conv_hcl_get_shared_pack4_mem_size(struct ir_tensor* filter, struct ir_tensor* output, struct conv_param* param)
{
int K = filter->elem_num / filter->dims[0];
int N = output->dims[2] * output->dims[3];
int elem_size = filter->elem_size;
return (4 * K * (N / 4 + N % 4)) * elem_size;
}
int conv_hcl_get_interleave_pack4_size(int M, int K, struct ir_tensor* filter)
{
int size = 4 * K * (M / 4 + M % 4) * filter->elem_size;
return size;
}
void conv_hcl_interleave_pack4(int M, int K, struct conv_priv_info* priv_info)
{
float* pA = ( float* )priv_info->interleave_buffer;
......@@ -650,7 +1437,7 @@ void conv_hcl_interleave_pack4(int M, int K, struct conv_priv_info* priv_info)
}
}
}
#endif
int conv_hcl_prerun(struct ir_tensor* input_tensor, struct ir_tensor* filter_tensor, struct ir_tensor* output_tensor,
struct conv_priv_info* priv_info, struct conv_param* param)
{
......