Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
OAID
Tengine
提交
27d5295c
T
Tengine
项目概览
OAID
/
Tengine
11 个月 前同步成功
通知
53
Star
4429
Fork
1032
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tengine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
27d5295c
编写于
8月 03, 2020
作者:
B
BUG1989
提交者:
GitHub
8月 03, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bugs of avx2 (#360)
上级
342f780e
变更
2
显示空白变更内容
内联
并排
Showing
2 changed files
with
793 additions
and
6 deletions
+793
-6
src/dev/cpu/op/conv/conv_hcl_kernel.h
src/dev/cpu/op/conv/conv_hcl_kernel.h
+1
-1
src/dev/cpu/op/conv/x86/conv_kernel_x86.c
src/dev/cpu/op/conv/x86/conv_kernel_x86.c
+792
-5
未找到文件。
src/dev/cpu/op/conv/conv_hcl_kernel.h
浏览文件 @
27d5295c
src/dev/cpu/op/conv/x86/conv_kernel_x86.c
浏览文件 @
27d5295c
...
...
@@ -115,6 +115,692 @@ static void im2col_ir(struct ir_tensor* input, struct ir_tensor* output, struct
param
->
pad_h0
,
param
->
pad_w0
,
param
->
dilation_h
,
param
->
dilation_w
);
}
#if __AVX__
/* Re-pack the im2col matrix B (K rows x N cols, row-major) into pB_t so the
 * AVX sgemm can stream it.  Full groups of 8 columns are stored
 * column-interleaved ([b00..b07, b10..b17, ...]); each leftover column is
 * stored contiguously (K floats) in its own 8*K-float slot after the groups,
 * matching sgemm's vb offset (j/8 + j%8) * 8 * K. */
void input_pack4(int K, int N, float* pB, float* pB_t, int num_thread)
{
    int block_count = N >> 3;            /* number of full 8-column groups    */
    int remain_start = block_count << 3; /* first column not in a full group  */

    /* full 8-wide groups: [ch00, ch10, ch20, ch30, ch01, ch11, ch21, ch31, ...] */
#pragma omp parallel for num_threads(num_thread)
    for (int b = 0; b < block_count; b++)
    {
        int col = b * 8;
        const float* src = pB + col;
        float* dst = pB_t + (col / 8) * 8 * K;

        for (int row = 0; row < K; row++)
        {
#if __AVX__
            _mm256_storeu_ps(dst, _mm256_loadu_ps(src));
#else
            for (int lane = 0; lane < 8; lane++)
                dst[lane] = src[lane];
#endif // __AVX__
            dst += 8;
            src += N; /* next row of the same 8 columns */
        }
    }

    /* leftover columns, one at a time: [ch00, ch01, ch02, ch03, ...] */
#pragma omp parallel for num_threads(num_thread)
    for (int col = remain_start; col < N; col++)
    {
        const float* src = pB + col;
        float* dst = pB_t + (col / 8 + col % 8) * 8 * K;

        for (int row = 0; row < K; row++)
        {
            dst[0] = src[0];
            dst += 1;
            src += N;
        }
    }
}
/* sgemm: pC(M x N) = pA_t * pB_t, where pA_t is the pack8/pack4 interleaved
 * kernel matrix (see conv_hcl_interleave_pack4) and pB_t the pack8 im2col
 * matrix (see input_pack4).  Output rows are produced 8 at a time, then 4 at
 * a time, then singly; within a row group, 8 output columns per iteration
 * with a scalar tail.  The AVX path 4x-unrolls the K loop with FMA. */
static void sgemm(int M, int N, int K, float* pA_t, float* pB_t, float* pC, int num_thread)
{
    int nn_outch = 0;
    int remain_outch_start = 0;

    /* ---- 8 output rows at a time ---- */
    nn_outch = M >> 3;
    remain_outch_start = nn_outch << 3;

#pragma omp parallel for num_threads(num_thread)
    for (int pp = 0; pp < nn_outch; pp++)
    {
        int i = pp * 8;

        float* output0 = pC + (i) * N;
        float* output1 = pC + (i + 1) * N;
        float* output2 = pC + (i + 2) * N;
        float* output3 = pC + (i + 3) * N;
        float* output4 = pC + (i + 4) * N;
        float* output5 = pC + (i + 5) * N;
        float* output6 = pC + (i + 6) * N;
        float* output7 = pC + (i + 7) * N;

        int j = 0;
        for (; j + 7 < N; j += 8)
        {
            float* va = pA_t + (i / 8) * 8 * K;
            float* vb = pB_t + (j / 8) * 8 * K;
#if __AVX__
            __m256 _sum0 = _mm256_set1_ps(0.0);
            __m256 _sum1 = _mm256_set1_ps(0.0);
            __m256 _sum2 = _mm256_set1_ps(0.0);
            __m256 _sum3 = _mm256_set1_ps(0.0);
            __m256 _sum4 = _mm256_set1_ps(0.0);
            __m256 _sum5 = _mm256_set1_ps(0.0);
            __m256 _sum6 = _mm256_set1_ps(0.0);
            __m256 _sum7 = _mm256_set1_ps(0.0);

            int k = 0;
            for (; k + 3 < K; k = k + 4)
            {
                /* k0 */
                __m256 _va0 = _mm256_broadcast_ss(va);
                __m256 _va1 = _mm256_broadcast_ss(va + 1);
                __m256 _va2 = _mm256_broadcast_ss(va + 2);
                __m256 _va3 = _mm256_broadcast_ss(va + 3);
                __m256 _vb0 = _mm256_loadu_ps(vb);
                __m256 _vb1 = _mm256_loadu_ps(vb + 8);
                __m256 _vb2 = _mm256_loadu_ps(vb + 16);
                __m256 _vb3 = _mm256_loadu_ps(vb + 24);
                _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
                _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
                _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
                _va0 = _mm256_broadcast_ss(va + 4);
                _va1 = _mm256_broadcast_ss(va + 5);
                _va2 = _mm256_broadcast_ss(va + 6);
                _va3 = _mm256_broadcast_ss(va + 7);
                _sum4 = _mm256_fmadd_ps(_vb0, _va0, _sum4); // sum4 = (a00-a07) * k40
                _sum5 = _mm256_fmadd_ps(_vb0, _va1, _sum5); // sum5 = (a00-a07) * k50
                _sum6 = _mm256_fmadd_ps(_vb0, _va2, _sum6); // sum6 = (a00-a07) * k60
                _sum7 = _mm256_fmadd_ps(_vb0, _va3, _sum7); // sum7 = (a00-a07) * k70
                va += 8;

                /* k1 */
                _va0 = _mm256_broadcast_ss(va);
                _va1 = _mm256_broadcast_ss(va + 1);
                _va2 = _mm256_broadcast_ss(va + 2);
                _va3 = _mm256_broadcast_ss(va + 3);
                _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
                _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
                _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
                _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
                _va0 = _mm256_broadcast_ss(va + 4);
                _va1 = _mm256_broadcast_ss(va + 5);
                _va2 = _mm256_broadcast_ss(va + 6);
                _va3 = _mm256_broadcast_ss(va + 7);
                _sum4 = _mm256_fmadd_ps(_vb1, _va0, _sum4); // sum4 += (a10-a17) * k41
                _sum5 = _mm256_fmadd_ps(_vb1, _va1, _sum5); // sum5 += (a10-a17) * k51
                _sum6 = _mm256_fmadd_ps(_vb1, _va2, _sum6); // sum6 += (a10-a17) * k61
                _sum7 = _mm256_fmadd_ps(_vb1, _va3, _sum7); // sum7 += (a10-a17) * k71
                va += 8;

                /* k2 */
                _va0 = _mm256_broadcast_ss(va);
                _va1 = _mm256_broadcast_ss(va + 1);
                _va2 = _mm256_broadcast_ss(va + 2);
                _va3 = _mm256_broadcast_ss(va + 3);
                _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
                _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
                _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
                _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
                _va0 = _mm256_broadcast_ss(va + 4);
                _va1 = _mm256_broadcast_ss(va + 5);
                _va2 = _mm256_broadcast_ss(va + 6);
                _va3 = _mm256_broadcast_ss(va + 7);
                _sum4 = _mm256_fmadd_ps(_vb2, _va0, _sum4); // sum4 += (a20-a27) * k42
                _sum5 = _mm256_fmadd_ps(_vb2, _va1, _sum5); // sum5 += (a20-a27) * k52
                _sum6 = _mm256_fmadd_ps(_vb2, _va2, _sum6); // sum6 += (a20-a27) * k62
                _sum7 = _mm256_fmadd_ps(_vb2, _va3, _sum7); // sum7 += (a20-a27) * k72
                va += 8;

                /* k3 */
                _va0 = _mm256_broadcast_ss(va);
                _va1 = _mm256_broadcast_ss(va + 1);
                _va2 = _mm256_broadcast_ss(va + 2);
                _va3 = _mm256_broadcast_ss(va + 3);
                _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
                _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
                _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
                _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
                _va0 = _mm256_broadcast_ss(va + 4);
                _va1 = _mm256_broadcast_ss(va + 5);
                _va2 = _mm256_broadcast_ss(va + 6);
                _va3 = _mm256_broadcast_ss(va + 7);
                _sum4 = _mm256_fmadd_ps(_vb3, _va0, _sum4); // sum4 += (a30-a37) * k43
                _sum5 = _mm256_fmadd_ps(_vb3, _va1, _sum5); // sum5 += (a30-a37) * k53
                _sum6 = _mm256_fmadd_ps(_vb3, _va2, _sum6); // sum6 += (a30-a37) * k63
                _sum7 = _mm256_fmadd_ps(_vb3, _va3, _sum7); // sum7 += (a30-a37) * k73
                va += 8;
                vb += 32;
            }
            for (; k < K; k++)
            {
                __m256 _va0 = _mm256_broadcast_ss(va);
                __m256 _va1 = _mm256_broadcast_ss(va + 1);
                __m256 _va2 = _mm256_broadcast_ss(va + 2);
                __m256 _va3 = _mm256_broadcast_ss(va + 3);
                __m256 _va4 = _mm256_broadcast_ss(va + 4);
                __m256 _va5 = _mm256_broadcast_ss(va + 5);
                __m256 _va6 = _mm256_broadcast_ss(va + 6);
                __m256 _va7 = _mm256_broadcast_ss(va + 7);
                __m256 _vb0 = _mm256_loadu_ps(vb);
                _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0);
                _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1);
                _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2);
                _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3);
                _sum4 = _mm256_fmadd_ps(_vb0, _va4, _sum4);
                _sum5 = _mm256_fmadd_ps(_vb0, _va5, _sum5);
                _sum6 = _mm256_fmadd_ps(_vb0, _va6, _sum6);
                _sum7 = _mm256_fmadd_ps(_vb0, _va7, _sum7);
                va += 8;
                vb += 8;
            }
            _mm256_storeu_ps(output0, _sum0);
            _mm256_storeu_ps(output1, _sum1);
            _mm256_storeu_ps(output2, _sum2);
            _mm256_storeu_ps(output3, _sum3);
            _mm256_storeu_ps(output4, _sum4);
            _mm256_storeu_ps(output5, _sum5);
            _mm256_storeu_ps(output6, _sum6);
            _mm256_storeu_ps(output7, _sum7);
#else
            float sum0[8] = {0};
            float sum1[8] = {0};
            float sum2[8] = {0};
            float sum3[8] = {0};
            float sum4[8] = {0};
            float sum5[8] = {0};
            float sum6[8] = {0};
            float sum7[8] = {0};

            for (int k = 0; k < K; k++)
            {
                for (int n = 0; n < 8; n++)
                {
                    sum0[n] += va[0] * vb[n];
                    sum1[n] += va[1] * vb[n];
                    sum2[n] += va[2] * vb[n];
                    sum3[n] += va[3] * vb[n];
                    sum4[n] += va[4] * vb[n];
                    sum5[n] += va[5] * vb[n];
                    sum6[n] += va[6] * vb[n];
                    sum7[n] += va[7] * vb[n];
                }
                va += 8;
                vb += 8;
            }
            for (int n = 0; n < 8; n++)
            {
                output0[n] = sum0[n];
                output1[n] = sum1[n];
                output2[n] = sum2[n];
                output3[n] = sum3[n];
                output4[n] = sum4[n];
                output5[n] = sum5[n];
                output6[n] = sum6[n];
                output7[n] = sum7[n];
            }
#endif // __AVX__
            output0 += 8;
            output1 += 8;
            output2 += 8;
            output3 += 8;
            output4 += 8;
            output5 += 8;
            output6 += 8;
            output7 += 8;
        }
        /* scalar column tail for the 8-row group */
        for (; j < N; j++)
        {
            float* va = pA_t + (i / 8) * 8 * K;
            float* vb = pB_t + (j / 8 + j % 8) * 8 * K;
#if __AVX__
            __m256 _sum0_7 = _mm256_set1_ps(0.0);
            __m256 _sum0 = _mm256_set1_ps(0.0);
            __m256 _sum1 = _mm256_set1_ps(0.0);
            __m256 _sum2 = _mm256_set1_ps(0.0);
            __m256 _sum3 = _mm256_set1_ps(0.0);

            int k = 0;
            for (; k + 3 < K; k = k + 4)
            {
                __m256 _vb0 = _mm256_broadcast_ss(vb);
                __m256 _vb1 = _mm256_broadcast_ss(vb + 1);
                __m256 _vb2 = _mm256_broadcast_ss(vb + 2);
                __m256 _vb3 = _mm256_broadcast_ss(vb + 3);
                __m256 _va0 = _mm256_loadu_ps(va);
                __m256 _va1 = _mm256_loadu_ps(va + 8);
                __m256 _va2 = _mm256_loadu_ps(va + 16);
                __m256 _va3 = _mm256_loadu_ps(va + 24);
                _sum0 = _mm256_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k70) * a00
                _sum1 = _mm256_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k71) * a10
                _sum2 = _mm256_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k72) * a20
                _sum3 = _mm256_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k73) * a30
                va += 32;
                vb += 4;
            }
            _sum0 = _mm256_add_ps(_sum0, _sum1);
            _sum2 = _mm256_add_ps(_sum2, _sum3);
            _sum0_7 = _mm256_add_ps(_sum0_7, _sum0);
            _sum0_7 = _mm256_add_ps(_sum0_7, _sum2);
            for (; k < K; k++)
            {
                __m256 _vb0 = _mm256_broadcast_ss(vb);
                __m256 _va = _mm256_loadu_ps(va);
                _sum0_7 = _mm256_fmadd_ps(_va, _vb0, _sum0_7); // sum0 += (k00-k70) * a00
                va += 8;
                vb += 1;
            }
            float output_sum0_7[8] = {0.f};
            _mm256_storeu_ps(output_sum0_7, _sum0_7);
            output0[0] = output_sum0_7[0];
            output1[0] = output_sum0_7[1];
            output2[0] = output_sum0_7[2];
            output3[0] = output_sum0_7[3];
            output4[0] = output_sum0_7[4];
            output5[0] = output_sum0_7[5];
            output6[0] = output_sum0_7[6];
            output7[0] = output_sum0_7[7];
#else
            float sum0 = 0;
            float sum1 = 0;
            float sum2 = 0;
            float sum3 = 0;
            float sum4 = 0;
            float sum5 = 0;
            float sum6 = 0;
            float sum7 = 0;

            for (int k = 0; k < K; k++)
            {
                sum0 += va[0] * vb[0];
                sum1 += va[1] * vb[0];
                sum2 += va[2] * vb[0];
                sum3 += va[3] * vb[0];
                sum4 += va[4] * vb[0];
                sum5 += va[5] * vb[0];
                sum6 += va[6] * vb[0];
                sum7 += va[7] * vb[0];
                va += 8;
                vb += 1;
            }
            output0[0] = sum0;
            output1[0] = sum1;
            output2[0] = sum2;
            output3[0] = sum3;
            output4[0] = sum4;
            output5[0] = sum5;
            output6[0] = sum6;
            output7[0] = sum7;
#endif // __AVX__
            output0++;
            output1++;
            output2++;
            output3++;
            output4++;
            output5++;
            output6++;
            output7++;
        }
    }

    /* ---- 4 output rows at a time ---- */
    nn_outch = (M - remain_outch_start) >> 2;

    for (int pp = 0; pp < nn_outch; pp++)
    {
        int i = remain_outch_start + pp * 4;

        float* output0 = pC + (i) * N;
        float* output1 = pC + (i + 1) * N;
        float* output2 = pC + (i + 2) * N;
        float* output3 = pC + (i + 3) * N;

        int j = 0;
        for (; j + 7 < N; j += 8)
        {
            float* va = pA_t + (i / 8 + (i % 8) / 4) * 8 * K;
            float* vb = pB_t + (j / 8) * 8 * K;
#if __AVX__
            __m256 _sum0 = _mm256_set1_ps(0.0);
            __m256 _sum1 = _mm256_set1_ps(0.0);
            __m256 _sum2 = _mm256_set1_ps(0.0);
            __m256 _sum3 = _mm256_set1_ps(0.0);

            int k = 0;
            for (; k + 3 < K; k = k + 4)
            {
                /* k0 */
                __m256 _va0 = _mm256_broadcast_ss(va);
                __m256 _va1 = _mm256_broadcast_ss(va + 1);
                __m256 _va2 = _mm256_broadcast_ss(va + 2);
                __m256 _va3 = _mm256_broadcast_ss(va + 3);
                __m256 _vb0 = _mm256_loadu_ps(vb);
                __m256 _vb1 = _mm256_loadu_ps(vb + 8);
                __m256 _vb2 = _mm256_loadu_ps(vb + 16);
                __m256 _vb3 = _mm256_loadu_ps(vb + 24);
                _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1); // sum1 = (a00-a07) * k10
                _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2); // sum2 = (a00-a07) * k20
                _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3); // sum3 = (a00-a07) * k30
                va += 4;

                /* k1 */
                _va0 = _mm256_broadcast_ss(va);
                _va1 = _mm256_broadcast_ss(va + 1);
                _va2 = _mm256_broadcast_ss(va + 2);
                _va3 = _mm256_broadcast_ss(va + 3);
                _sum0 = _mm256_fmadd_ps(_vb1, _va0, _sum0); // sum0 += (a10-a17) * k01
                _sum1 = _mm256_fmadd_ps(_vb1, _va1, _sum1); // sum1 += (a10-a17) * k11
                _sum2 = _mm256_fmadd_ps(_vb1, _va2, _sum2); // sum2 += (a10-a17) * k21
                _sum3 = _mm256_fmadd_ps(_vb1, _va3, _sum3); // sum3 += (a10-a17) * k31
                va += 4;

                /* k2 */
                _va0 = _mm256_broadcast_ss(va);
                _va1 = _mm256_broadcast_ss(va + 1);
                _va2 = _mm256_broadcast_ss(va + 2);
                _va3 = _mm256_broadcast_ss(va + 3);
                _sum0 = _mm256_fmadd_ps(_vb2, _va0, _sum0); // sum0 += (a20-a27) * k02
                _sum1 = _mm256_fmadd_ps(_vb2, _va1, _sum1); // sum1 += (a20-a27) * k12
                _sum2 = _mm256_fmadd_ps(_vb2, _va2, _sum2); // sum2 += (a20-a27) * k22
                _sum3 = _mm256_fmadd_ps(_vb2, _va3, _sum3); // sum3 += (a20-a27) * k32
                va += 4;

                /* k3 */
                _va0 = _mm256_broadcast_ss(va);
                _va1 = _mm256_broadcast_ss(va + 1);
                _va2 = _mm256_broadcast_ss(va + 2);
                _va3 = _mm256_broadcast_ss(va + 3);
                _sum0 = _mm256_fmadd_ps(_vb3, _va0, _sum0); // sum0 += (a30-a37) * k03
                _sum1 = _mm256_fmadd_ps(_vb3, _va1, _sum1); // sum1 += (a30-a37) * k13
                _sum2 = _mm256_fmadd_ps(_vb3, _va2, _sum2); // sum2 += (a30-a37) * k23
                _sum3 = _mm256_fmadd_ps(_vb3, _va3, _sum3); // sum3 += (a30-a37) * k33
                va += 4;
                vb += 32;
            }
            for (; k < K; k++)
            {
                __m256 _va0 = _mm256_broadcast_ss(va);
                __m256 _va1 = _mm256_broadcast_ss(va + 1);
                __m256 _va2 = _mm256_broadcast_ss(va + 2);
                __m256 _va3 = _mm256_broadcast_ss(va + 3);
                __m256 _vb0 = _mm256_loadu_ps(vb);
                _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0);
                _sum1 = _mm256_fmadd_ps(_vb0, _va1, _sum1);
                _sum2 = _mm256_fmadd_ps(_vb0, _va2, _sum2);
                _sum3 = _mm256_fmadd_ps(_vb0, _va3, _sum3);
                va += 4;
                vb += 8;
            }
            _mm256_storeu_ps(output0, _sum0);
            _mm256_storeu_ps(output1, _sum1);
            _mm256_storeu_ps(output2, _sum2);
            _mm256_storeu_ps(output3, _sum3);
#else
            float sum0[8] = {0};
            float sum1[8] = {0};
            float sum2[8] = {0};
            float sum3[8] = {0};

            for (int k = 0; k < K; k++)
            {
                for (int n = 0; n < 8; n++)
                {
                    sum0[n] += va[0] * vb[n];
                    sum1[n] += va[1] * vb[n];
                    sum2[n] += va[2] * vb[n];
                    sum3[n] += va[3] * vb[n];
                }
                va += 4;
                vb += 8;
            }
            for (int n = 0; n < 8; n++)
            {
                output0[n] = sum0[n];
                output1[n] = sum1[n];
                output2[n] = sum2[n];
                output3[n] = sum3[n];
            }
#endif // __AVX__
            output0 += 8;
            output1 += 8;
            output2 += 8;
            output3 += 8;
        }
        /* scalar column tail for the 4-row group */
        for (; j < N; j++)
        {
            float* va = pA_t + (i / 8 + (i % 8) / 4) * 8 * K;
            float* vb = pB_t + (j / 8 + j % 8) * 8 * K;
#if __AVX__
            __m128 _sum0_3 = _mm_set1_ps(0.0);
            __m128 _sum0 = _mm_set1_ps(0.0);
            __m128 _sum1 = _mm_set1_ps(0.0);
            __m128 _sum2 = _mm_set1_ps(0.0);
            __m128 _sum3 = _mm_set1_ps(0.0);

            int k = 0;
            for (; k + 3 < K; k = k + 4)
            {
                __m128 _vb0 = _mm_set1_ps(vb[0]);
                __m128 _vb1 = _mm_set1_ps(vb[1]);
                __m128 _vb2 = _mm_set1_ps(vb[2]);
                __m128 _vb3 = _mm_set1_ps(vb[3]);
                __m128 _va0 = _mm_loadu_ps(va);
                __m128 _va1 = _mm_loadu_ps(va + 4);
                __m128 _va2 = _mm_loadu_ps(va + 8);
                __m128 _va3 = _mm_loadu_ps(va + 12);
                _sum0 = _mm_fmadd_ps(_va0, _vb0, _sum0); // sum0 += (k00-k30) * a00
                _sum1 = _mm_fmadd_ps(_va1, _vb1, _sum1); // sum1 += (k01-k31) * a10
                _sum2 = _mm_fmadd_ps(_va2, _vb2, _sum2); // sum2 += (k02-k32) * a20
                _sum3 = _mm_fmadd_ps(_va3, _vb3, _sum3); // sum3 += (k03-k33) * a30
                va += 16;
                vb += 4;
            }
            _sum0 = _mm_add_ps(_sum0, _sum1);
            _sum2 = _mm_add_ps(_sum2, _sum3);
            _sum0_3 = _mm_add_ps(_sum0_3, _sum0);
            _sum0_3 = _mm_add_ps(_sum0_3, _sum2);
            for (; k < K; k++)
            {
                __m128 _vb0 = _mm_set1_ps(vb[0]);
                __m128 _va = _mm_loadu_ps(va);
                _sum0_3 = _mm_fmadd_ps(_va, _vb0, _sum0_3); // sum0 += (k00-k30) * a00
                va += 4;
                vb += 1;
            }
            float output_sum0_3[4] = {0.f};
            _mm_storeu_ps(output_sum0_3, _sum0_3);
            output0[0] = output_sum0_3[0];
            output1[0] = output_sum0_3[1];
            output2[0] = output_sum0_3[2];
            output3[0] = output_sum0_3[3];
#else
            float sum0 = 0;
            float sum1 = 0;
            float sum2 = 0;
            float sum3 = 0;

            for (int k = 0; k < K; k++)
            {
                sum0 += va[0] * vb[0];
                sum1 += va[1] * vb[0];
                sum2 += va[2] * vb[0];
                sum3 += va[3] * vb[0];
                va += 4;
                vb += 1;
            }
            output0[0] = sum0;
            output1[0] = sum1;
            output2[0] = sum2;
            output3[0] = sum3;
#endif // __AVX__
            output0++;
            output1++;
            output2++;
            output3++;
        }
    }

    remain_outch_start += nn_outch << 2;

    /* ---- remaining single output rows ---- */
    for (int i = remain_outch_start; i < M; i++)
    {
        float* output = pC + i * N;

        int j = 0;
        for (; j + 7 < N; j += 8)
        {
            float* va = pA_t + (i / 8 + (i % 8) / 4 + i % 4) * 8 * K;
            float* vb = pB_t + (j / 8) * 8 * K;
#if __AVX__
            __m256 _sum0 = _mm256_set1_ps(0.0);

            int k = 0;
            for (; k + 3 < K; k = k + 4)
            {
                __m256 _va0 = _mm256_broadcast_ss(va);
                __m256 _va1 = _mm256_broadcast_ss(va + 1);
                __m256 _va2 = _mm256_broadcast_ss(va + 2);
                __m256 _va3 = _mm256_broadcast_ss(va + 3);
                __m256 _vb0 = _mm256_loadu_ps(vb);
                __m256 _vb1 = _mm256_loadu_ps(vb + 8);
                __m256 _vb2 = _mm256_loadu_ps(vb + 16);
                __m256 _vb3 = _mm256_loadu_ps(vb + 24);
                _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                _sum0 = _mm256_fmadd_ps(_vb1, _va1, _sum0); // sum0 += (a10-a17) * k01
                _sum0 = _mm256_fmadd_ps(_vb2, _va2, _sum0); // sum0 += (a20-a27) * k02
                _sum0 = _mm256_fmadd_ps(_vb3, _va3, _sum0); // sum0 += (a30-a37) * k03
                va += 4;
                vb += 32;
            }
            for (; k < K; k++)
            {
                __m256 _va0 = _mm256_broadcast_ss(va);
                __m256 _vb0 = _mm256_loadu_ps(vb);
                _sum0 = _mm256_fmadd_ps(_vb0, _va0, _sum0); // sum0 = (a00-a07) * k00
                va += 1;
                vb += 8;
            }
            _mm256_storeu_ps(output, _sum0);
#else
            float sum[8] = {0};

            for (int k = 0; k < K; k++)
            {
                for (int n = 0; n < 8; n++)
                {
                    sum[n] += va[0] * vb[n];
                }
                va += 1;
                vb += 8;
            }
            for (int n = 0; n < 8; n++)
            {
                output[n] = sum[n];
            }
#endif // __AVX__
            output += 8;
        }
        /* scalar column tail for the single row */
        for (; j < N; j++)
        {
            float* va = pA_t + (i / 8 + (i % 8) / 4 + i % 4) * 8 * K;
            float* vb = pB_t + (j / 8 + j % 8) * 8 * K;

            int k = 0;
#if __AVX__
            __m128 _sum0 = _mm_set1_ps(0.f);

            for (; k + 3 < K; k += 4)
            {
                __m128 _p0 = _mm_loadu_ps(vb);
                __m128 _k0 = _mm_loadu_ps(va);
                _sum0 = _mm_add_ps(_sum0, _mm_mul_ps(_p0, _k0));
                va += 4;
                vb += 4;
            }
            float sum0 = _sum0[0] + _sum0[1] + _sum0[2] + _sum0[3];
#else
            float sum0 = 0.f;
#endif // __AVX__
            for (; k < K; k++)
            {
                sum0 += va[0] * vb[0];
                va += 1;
                vb += 1;
            }
            output[0] = sum0;
            output++;
        }
    }
}
#else // SSE2
void
input_pack4
(
int
K
,
int
N
,
float
*
pB
,
float
*
pB_t
,
int
num_thread
)
{
int
nn_size
=
N
>>
2
;
...
...
@@ -159,7 +845,6 @@ void input_pack4(int K, int N, float* pB, float* pB_t, int num_thread)
}
}
}
// unloop output M, unloop N, packet 4x4, using intrinsic
static
void
sgemm
(
int
M
,
int
N
,
int
K
,
float
*
pA_t
,
float
*
pB_t
,
float
*
pC
,
int
num_thread
)
{
...
...
@@ -481,7 +1166,7 @@ static void sgemm(int M, int N, int K, float* pA_t, float* pB_t, float* pC, int
}
}
}
#endif // __AVX2__
static
void
sgemm_fp32
(
struct
ir_tensor
*
input
,
struct
ir_tensor
*
filter
,
struct
ir_tensor
*
bias
,
struct
ir_tensor
*
output
,
struct
conv_priv_info
*
priv_info
,
struct
conv_param
*
param
,
int
n
,
int
group
,
int
num_thread
)
...
...
@@ -587,21 +1272,123 @@ int conv_hcl_get_shared_mem_size(struct ir_tensor* input, struct ir_tensor* outp
return
elem_size
*
output_xy
*
kernel_size
;
}
#if __AVX__
/* Size in bytes of the scratch buffer that holds the packed im2col matrix
 * consumed by the AVX sgemm (pB_t as written by input_pack4):
 * N/8 full 8-column groups plus one 8*K-element slot per leftover column.
 *
 * The previous pack4 sizing (4*K*(N/4 + N%4)) was left in as a first,
 * unreachable-making return and is too small for the pack8 layout
 * (input_pack4 writes leftover column i at offset (i/8 + i%8)*8*K), which
 * would overflow the buffer — only the pack8 sizing is kept.
 * 'param' is unused but kept for interface compatibility with the SSE path. */
int conv_hcl_get_shared_pack4_mem_size(struct ir_tensor* filter, struct ir_tensor* output, struct conv_param* param)
{
    int K = filter->elem_num / filter->dims[0]; /* in_c * kh * kw            */
    int N = output->dims[2] * output->dims[3];  /* out_h * out_w             */
    int elem_size = filter->elem_size;

    return (8 * K * (N / 8 + N % 8)) * elem_size;
}
/* Size in bytes of the buffer holding the interleaved kernel matrix
 * (pA_t as written by conv_hcl_interleave_pack4): rows are grouped 8-wide,
 * then 4-wide, then singly, each group padded to an 8*K-element slot. */
int conv_hcl_get_interleave_pack4_size(int M, int K, struct ir_tensor* filter)
{
    int slot_count = M / 8 + (M % 8) / 4 + M % 4;
    return 8 * K * slot_count * filter->elem_size;
}
/* Re-pack the interleave buffer (M kernel rows x K cols, row-major) into the
 * pack4/pack8 layout expected by the AVX sgemm: rows are grouped 8 at a time
 * (column-interleaved), the remainder 4 at a time, and finally one by one —
 * at the same offsets sgemm uses for va. */
void conv_hcl_interleave_pack4(int M, int K, struct conv_priv_info* priv_info)
{
    float* pA = (float*)priv_info->interleave_buffer;
    float* pA_t = (float*)priv_info->interleave_buffer_pack4;

    int nn_outch = M >> 3;
    int remain_outch_start = nn_outch << 3;

    /* groups of 8 rows: [r0[q], r1[q], ..., r7[q]] for q = 0..K-1 */
    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = pp * 8;

        const float* row[8];
        for (int r = 0; r < 8; r++)
            row[r] = pA + (p + r) * K;

        float* ktmp = pA_t + (p / 8) * 8 * K;
        for (int q = 0; q < K; q++)
        {
            for (int r = 0; r < 8; r++)
            {
                ktmp[r] = row[r][0];
                row[r] += 1;
            }
            ktmp += 8;
        }
    }

    /* groups of 4 rows */
    nn_outch = (M - remain_outch_start) >> 2;
    for (int pp = 0; pp < nn_outch; pp++)
    {
        int p = remain_outch_start + pp * 4;

        const float* row[4];
        for (int r = 0; r < 4; r++)
            row[r] = pA + (p + r) * K;

        float* ktmp = pA_t + (p / 8 + (p % 8) / 4) * 8 * K;
        for (int q = 0; q < K; q++)
        {
            for (int r = 0; r < 4; r++)
            {
                ktmp[r] = row[r][0];
                row[r] += 1;
            }
            ktmp += 4;
        }
    }

    /* leftover single rows, copied contiguously */
    remain_outch_start += nn_outch << 2;
    for (int p = remain_outch_start; p < M; p++)
    {
        const float* k0 = pA + (p + 0) * K;
        float* ktmp = pA_t + (p / 8 + (p % 8) / 4 + p % 4) * 8 * K;

        for (int q = 0; q < K; q++)
        {
            ktmp[0] = k0[0];
            ktmp++;
            k0++;
        }
    }
}
#else
/* SSE fallback: size in bytes of the scratch buffer for the pack4 im2col
 * matrix — N/4 full 4-column groups plus one 4*K slot per leftover column.
 * 'param' is unused but kept for interface compatibility. */
int conv_hcl_get_shared_pack4_mem_size(struct ir_tensor* filter, struct ir_tensor* output, struct conv_param* param)
{
    int K = filter->elem_num / filter->dims[0]; /* in_c * kh * kw */
    int N = output->dims[2] * output->dims[3];  /* out_h * out_w  */

    return (4 * K * (N / 4 + N % 4)) * filter->elem_size;
}
/* SSE fallback: size in bytes of the interleaved kernel buffer — rows are
 * grouped 4 at a time, then singly, each group padded to a 4*K-element slot. */
int conv_hcl_get_interleave_pack4_size(int M, int K, struct ir_tensor* filter)
{
    int slot_count = M / 4 + M % 4;
    return 4 * K * slot_count * filter->elem_size;
}
void
conv_hcl_interleave_pack4
(
int
M
,
int
K
,
struct
conv_priv_info
*
priv_info
)
{
float
*
pA
=
(
float
*
)
priv_info
->
interleave_buffer
;
...
...
@@ -650,7 +1437,7 @@ void conv_hcl_interleave_pack4(int M, int K, struct conv_priv_info* priv_info)
}
}
}
#endif
int
conv_hcl_prerun
(
struct
ir_tensor
*
input_tensor
,
struct
ir_tensor
*
filter_tensor
,
struct
ir_tensor
*
output_tensor
,
struct
conv_priv_info
*
priv_info
,
struct
conv_param
*
param
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录