Commit f99c34c8

add winograd f23 implement (#2584)

Authored Dec 11, 2019 by TianXiaogang; committed by yiicy on Dec 11, 2019.
Parent: fbb0d3b5

Showing 7 changed files with 1461 additions and 181 deletions (+1461 −181).
lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc  +817 −77
lite/backends/arm/math/conv_impl.h                  +29  −1
lite/backends/arm/math/packed_sgemm_c4.cc           +534 −1
lite/backends/arm/math/packed_sgemm_c4.h            +7   −0
lite/kernels/arm/conv_compute.cc                    +3   −13
lite/kernels/arm/conv_winograd.cc                   +68  −88
lite/kernels/arm/conv_winograd.h                    +3   −1
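Background for the diffs below: F(2,3) Winograd convolution computes each 2x2 output tile of a 3x3 convolution as a transform-multiply-transform. The summary here is assembled by the editor from this commit's own code, not text from the commit: B^T and A^T are quoted from the BT/AT comments in conv3x3_winograd_fp32_c4.cc, and G is the coeff table in weight_trans_c4_4x4.

\[
Y = A^{T}\!\left[(G\,g\,G^{T}) \odot (B^{T} d\, B)\right]\! A,\qquad
B^{T}=\begin{bmatrix}1&0&-1&0\\0&1&1&0\\0&-1&1&0\\0&1&0&-1\end{bmatrix},\quad
G=\begin{bmatrix}1&0&0\\ \tfrac12&\tfrac12&\tfrac12\\ \tfrac12&-\tfrac12&\tfrac12\\ 0&0&1\end{bmatrix},\quad
A^{T}=\begin{bmatrix}1&1&1&0\\0&1&-1&-1\end{bmatrix}
\]

Here g is a 3x3 kernel tile, d a 4x4 input tile, and ⊙ the element-wise product. In the new code, weight_trans_c4_4x4 precomputes G g G^T, input_trans_c4_4x4 applies B^T d B, the per-tile products are batched into sgemm_prepack_c4_small, and output_trans_c4_post_2x4 applies the final A^T [...] A together with bias and optional ReLU.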
lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc
@@ -24,29 +24,48 @@ namespace paddle {
 namespace lite {
 namespace arm {
 namespace math {
-void input_trans_c4(
-    const float* src, int src_stride, float* dest, int dest_stride);
-void output_trans_c4(
-    const float* src, int src_stride, float* dest, int dest_stride);
-void output_trans_c4_post(const float* src,
-                          int src_stride,
-                          float* dest,
-                          int dest_stride,
-                          float* bias_value,
-                          bool has_relu);
-void weight_trans_c4(
+void input_trans_c4_8x8(
+    const float* src, int src_stride, float* dest, int dest_stride);
+void output_trans_c4_6x8(
+    const float* src, int src_stride, float* dest, int dest_stride);
+void output_trans_c4_post_6x8(const float* src,
+                              int src_stride,
+                              float* dest,
+                              int dest_stride,
+                              float* bias_value,
+                              bool has_relu);
+void input_trans_c4_4x4(const float* src,
+                        int src_stride,
+                        int src_h_stride,
+                        float* dest,
+                        int dest_stride,
+                        int dest_h_stride);
+void output_trans_c4_post_2x4(const float* src,
+                              int src_stride,
+                              int src_h_stride,
+                              float* dest,
+                              int dest_stride,
+                              int dest_h_stride,
+                              float* bias_value,
+                              bool has_relu);
+void weight_trans_c4_8x8(
+    float* dest, const float* src, int ic, int oc, void* workspace);
+void weight_trans_c4_4x4(
     float* dest, const float* src, int ic, int oc, void* workspace);
 /*
-*The following function conv_compute_6x6_3x3 is base on
+*The following function conv_compute_6x6_3x3 and conv_compute_2x2_3x3[_small] is
+*base on
 *MNN[https://github.com/alibaba/MNN]
 *
 *Copyright © 2018, Alibaba Group Holding Limited
 */
+// F(6,3)
 void conv_compute_6x6_3x3(const float* input,
                           float* output,
                           int num,
@@ -75,11 +94,14 @@ void conv_compute_6x6_3x3(const float* input,
   int tile_w = (wout + 5) / 6;
   int tile_h = (hout + 5) / 6;
   int size_tile = tile_h * tile_w;
-  float zero_ptr[8];
-  memset(zero_ptr, 0, 8 * sizeof(float));
   int w_pad = win + pad_w * 2;
   int h_pad = hin + pad_h * 2;
+  const int zero_len = w_pad;
+  float zero_ptr[zero_len];  // NOLINT
+  memset(zero_ptr, 0, zero_len * sizeof(float));
   float* input_c4 = tmp_work_space;
   int new_h_stride = w_pad * 4;
   int new_c_stride = new_h_stride * h_pad;
@@ -88,9 +110,6 @@ void conv_compute_6x6_3x3(const float* input,
   int oc_4_stride = wout * hout * 4;
   int tile_block = 8;
-#ifdef __aarch64__
-  tile_block = 16;
-#endif
   int block_count = (size_tile + tile_block - 1) / tile_block;
   int threads = ctx->threads();
@@ -102,7 +121,8 @@ void conv_compute_6x6_3x3(const float* input,
   // begin compute
   for (int ni = 0; ni < num; ++ni) {
     // trans input to c4
+#pragma omp parallel for num_threads(threads)
     for (int i = 0; i < ic_4; ++i) {
       prepack_input_nxwc4_dw(input + ni * in_n_stride,
                              input_c4 + i * new_c_stride,
@@ -161,14 +181,14 @@ void conv_compute_6x6_3x3(const float* input,
           const float* src_ci = src_ptr + ci * ic_4_stride;
           for (int i = 0; i < 8; ++i) {
             const float* ci_ptr = src_ci + i * w_pad * 4;
-            input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32);
+            input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32);
           }
           float* dst_ci = dst_ptr + ci * tile_count * 4;
           for (int i = 0; i < 8; ++i) {
-            input_trans_c4(trans_tmp_data + i * 32, 4,
-                           dst_ci + i * b_gi_stride * 8, b_gi_stride);
+            input_trans_c4_8x8(trans_tmp_data + i * 32, 4,
+                               dst_ci + i * b_gi_stride * 8, b_gi_stride);
           }
         }
       } else {
@@ -189,14 +209,14 @@ void conv_compute_6x6_3x3(const float* input,
           // trans
           for (int i = 0; i < 8; ++i) {
             float* ci_ptr = trans_remain_tmp_data + i * 32;
-            input_trans_c4(ci_ptr, 4, trans_tmp_data + i * 4, 32);
+            input_trans_c4_8x8(ci_ptr, 4, trans_tmp_data + i * 4, 32);
           }
           float* dst_ci = dst_ptr + ci * tile_count * 4;
           for (int i = 0; i < 8; ++i) {
-            input_trans_c4(trans_tmp_data + i * 32, 4,
-                           dst_ci + i * b_gi_stride * 8, b_gi_stride);
+            input_trans_c4_8x8(trans_tmp_data + i * 32, 4,
+                               dst_ci + i * b_gi_stride * 8, b_gi_stride);
           }
         }  // for ci_4
       }
@@ -213,16 +233,8 @@ void conv_compute_6x6_3x3(const float* input,
         float* origin_C = dst_temp_data + gi * c_gi_stride;
         float* origin_B = b_ptr + gi * b_gi_stride;
         const float* origin_A = weight + gi * w_gi_stride;
-        sgemm_prepack_c4_small(oc_4 * 4,
-                               tile_count,
-                               ic_4 * 4,
-                               origin_A,
-                               origin_B,
-                               origin_C,
-                               nullptr,
-                               false,
-                               false,
-                               ctx);
+        sgemm_prepack_c4_small(
+            oc_4 * 4, tile_count, ic_4 * 4, origin_A, origin_B, origin_C, ctx);
       }
       //*/
       //*
@@ -258,18 +270,18 @@ void conv_compute_6x6_3x3(const float* input,
           float* dst_ci = dst_ptr + ci * oc_4_stride;
           float* src_ci = src_ptr + ci * tile_count * 4;
           for (int i = 0; i < 8; ++i) {
-            output_trans_c4(src_ci + i * c_gi_stride * 8, c_gi_stride,
-                            trans_tmp_data + i * 4, 32);
+            output_trans_c4_6x8(src_ci + i * c_gi_stride * 8, c_gi_stride,
+                                trans_tmp_data + i * 4, 32);
           }
           for (int i = 0; i < ey; ++i) {
-            output_trans_c4_post(trans_tmp_data + i * 32, 4,
-                                 trans_remain_tmp_data + i * 24, 4,
-                                 bias_value, param.fuse_relu);
+            output_trans_c4_post_6x8(trans_tmp_data + i * 32, 4,
+                                     trans_remain_tmp_data + i * 24, 4,
+                                     bias_value, param.fuse_relu);
           }
           write_to_output_c4_fp32(trans_remain_tmp_data,
                                   output_ptr,
@@ -297,18 +309,18 @@ void conv_compute_6x6_3x3(const float* input,
           float* dst_ci = dst_ptr + ci * oc_4_stride;
           float* src_ci = src_ptr + ci * tile_count * 4;
           for (int i = 0; i < 8; ++i) {
-            output_trans_c4(src_ci + i * c_gi_stride * 8, c_gi_stride,
-                            trans_tmp_data + i * 4, 32);
+            output_trans_c4_6x8(src_ci + i * c_gi_stride * 8, c_gi_stride,
+                                trans_tmp_data + i * 4, 32);
           }
           for (int i = 0; i < ey; ++i) {
-            output_trans_c4_post(trans_tmp_data + i * 32, 4,
-                                 trans_remain_tmp_data + i * 24, 4,
-                                 bias_value, param.fuse_relu);
+            output_trans_c4_post_6x8(trans_tmp_data + i * 32, 4,
+                                     trans_remain_tmp_data + i * 24, 4,
+                                     bias_value, param.fuse_relu);
           }
           // copy to dest
           memset(trans_tmp_data, 0, 144 * sizeof(float));
@@ -338,10 +350,522 @@ void conv_compute_6x6_3x3(const float* input,
   }  // for num
 }  // conv_compute
-void output_trans_c4(const float* src,
-                     int src_stride,
-                     float* dest,
-                     int dest_stride) {
+// F(2,3)
+void conv_compute_2x2_3x3(const float* input,
+                          float* output,
+                          int num, int chout, int hout, int wout,
+                          int chin, int hin, int win,
+                          const float* weight,
+                          const float* bias,
+                          const operators::ConvParam& param,
+                          ARMContext* ctx) {
+  const int pad_h = (*param.paddings)[0];
+  const int pad_w = (*param.paddings)[2];
+  float* tmp_work_space =
+      ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
+  int in_n_stride = chin * hin * win;
+  int out_n_stride = chout * hout * wout;
+  int ic_stride = win * hin;
+  int oc_stride = wout * hout;
+  int ic_4 = (chin + 3) / 4;
+  int oc_4 = (chout + 3) / 4;
+  int tile_w = (wout + 1) / 2;
+  int tile_h = (hout + 1) / 2;
+  int size_tile = tile_h * tile_w;
+  int w_pad = win + pad_w * 2;
+  int h_pad = hin + pad_h * 2;
+  const int zero_len = w_pad;
+  float zero_ptr[zero_len];  // NOLINT
+  memset(zero_ptr, 0, zero_len * sizeof(float));
+  float* input_c4 = tmp_work_space;
+  int new_h_stride = w_pad * 4;
+  int new_c_stride = new_h_stride * h_pad;
+  int ic_4_stride = w_pad * h_pad * 4;
+  int oc_4_stride = wout * hout * 4;
+  int tile_block = 8;
+  int block_count = (size_tile + tile_block - 1) / tile_block;
+  int threads = ctx->threads();
+  float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride;
+  int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64;
+  memset(g_tmp_data, 0, threads * tmp_data_thread_stride * sizeof(float));
+  float* g_trans_tmp_data = g_tmp_data + threads * tmp_data_thread_stride;
+  float* g_trans_remain_tmp_data = g_trans_tmp_data + threads * 64;
+  // begin compute
+  for (int ni = 0; ni < num; ++ni) {
+    // trans input to c4
+#pragma omp parallel for num_threads(threads)
+    for (int i = 0; i < ic_4; ++i) {
+      prepack_input_nxwc4_dw(input + ni * in_n_stride,
+                             input_c4 + i * new_c_stride, i * 4,
+                             -pad_h, hin + pad_h, -pad_w, win + pad_w,
+                             chin, win, hin, zero_ptr);
+    }
+    float* output_ptr = output + ni * out_n_stride;
+    const float* weight_ptr = weight;
+    const float* bias_ptr = bias;
+#pragma omp parallel for num_threads(threads)
+    for (int tbi = 0; tbi < block_count; ++tbi) {
+#ifdef ARM_WITH_OMP
+      float* tmp_data =
+          g_tmp_data + omp_get_thread_num() * tmp_data_thread_stride;
+      float* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 64;
+      float* trans_remain_tmp_data =
+          g_trans_remain_tmp_data + omp_get_thread_num() * 64;
+#else
+      float* tmp_data = g_tmp_data;
+      float* trans_tmp_data = g_trans_tmp_data;
+      float* trans_remain_tmp_data = g_trans_remain_tmp_data;
+#endif
+      int tile_index = tbi * tile_block;
+      int tile_remain = size_tile - tile_index;
+      int tile_count = tile_remain > tile_block ? tile_block : tile_remain;
+      // input trans
+      int c_gi_stride = tile_count * oc_4 * 4;
+      int b_gi_stride = tile_count * ic_4 * 4;
+      //*
+      for (int ti = 0; ti < tile_count; ++ti) {
+        int index = tile_index + ti;
+        int tw_index = index % tile_w;
+        int th_index = index / tile_w;
+        int src_x = tw_index + tw_index;
+        int src_y = th_index + th_index;
+        int ex = src_x + 4 > w_pad ? w_pad - src_x : 4;
+        int ey = src_y + 4 > h_pad ? h_pad - src_y : 4;
+        float* dst_ptr = tmp_data + ti * 4;
+        const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4;
+        if (ex == 4 && ey == 4) {
+          // trans input
+          for (int ci = 0; ci < ic_4; ++ci) {
+            const float* src_ci = src_ptr + ci * ic_4_stride;
+            float* dst_ci = dst_ptr + ci * tile_count * 4;
+            input_trans_c4_4x4(src_ci, 4, w_pad * 4, dst_ci,
+                               b_gi_stride, b_gi_stride * 4);
+          }
+        } else {
+          // trans remain input
+          int x_size = ex;
+          for (int ci = 0; ci < ic_4; ++ci) {
+            const float* src_ci = src_ptr + ci * ic_4_stride;
+            // pad
+            memset(trans_remain_tmp_data, 0, 64 * sizeof(float));
+            if (x_size > 0) {
+              for (int yi = 0; yi < ey; ++yi) {
+                float* dst_yi = trans_remain_tmp_data + yi * 16;
+                const float* src_yi = src_ci + w_pad * yi * 4;
+                memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4);
+              }
+            }
+            // trans
+            float* dst_ci = dst_ptr + ci * tile_count * 4;
+            input_trans_c4_4x4(trans_remain_tmp_data, 4, 16, dst_ci,
+                               b_gi_stride, b_gi_stride * 4);
+          }
+        }  // for ci_4
+      }
+      //*/
+      // input trans end
+      // *begin compute dot
+      // *
+      //*
+      float* dst_temp_data = tmp_data + tile_block * ic_4 * 64;
+      float* b_ptr = tmp_data;
+      int w_gi_stride = ic_4 * oc_4 * 16;
+      for (int gi = 0; gi < 16; ++gi) {
+        float* origin_C = dst_temp_data + gi * c_gi_stride;
+        float* origin_B = b_ptr + gi * b_gi_stride;
+        const float* origin_A = weight + gi * w_gi_stride;
+        sgemm_prepack_c4_small(oc_4 * 4, tile_count, ic_4 * 4,
+                               origin_A, origin_B, origin_C, ctx);
+      }
+      //*/
+      //*
+      // output trans
+      float bias_value[4];
+      memset(bias_value, 0, 4 * sizeof(float));
+      for (int ti = 0; ti < tile_count; ++ti) {
+        int index = tile_index + ti;
+        int tw_index = index % tile_w;
+        int th_index = index / tile_w;
+        int dst_x = tw_index * 2;
+        int dst_y = th_index * 2;
+        int ex = dst_x + 2 > wout ? wout - dst_x : 2;
+        int ey = dst_y + 2 > hout ? hout - dst_y : 2;
+        float* dst_ptr = output + (dst_y * wout + dst_x) * 4;
+        float* src_ptr = dst_temp_data + ti * 4;
+        if (ex == 2) {
+          // trans output
+          for (int ci = 0; ci < oc_4; ++ci) {
+            if (param.bias) {
+              bias_value[0] = bias[ci * 4];
+              bias_value[1] = bias[ci * 4 + 1];
+              bias_value[2] = bias[ci * 4 + 2];
+              bias_value[3] = bias[ci * 4 + 3];
+            }
+            float* dst_ci = dst_ptr + ci * oc_4_stride;
+            float* src_ci = src_ptr + ci * tile_count * 4;
+            output_trans_c4_post_2x4(src_ci, c_gi_stride, c_gi_stride * 4,
+                                     trans_remain_tmp_data, 4, 8,
+                                     bias_value, param.fuse_relu);
+            write_to_output_c4_fp32(trans_remain_tmp_data, output_ptr,
+                                    ci * 4, ci * 4 + 4, dst_y, dst_y + ey,
+                                    dst_x, dst_x + ex, chout, hout, wout,
+                                    false, zero_ptr);
+          }
+        } else {
+          for (int ci = 0; ci < oc_4; ++ci) {
+            if (param.bias) {
+              bias_value[0] = bias[ci * 4];
+              bias_value[1] = bias[ci * 4 + 1];
+              bias_value[2] = bias[ci * 4 + 2];
+              bias_value[3] = bias[ci * 4 + 3];
+            }
+            // trans output
+            float* dst_ci = dst_ptr + ci * oc_4_stride;
+            float* src_ci = src_ptr + ci * tile_count * 4;
+            output_trans_c4_post_2x4(src_ci, c_gi_stride, c_gi_stride * 4,
+                                     trans_remain_tmp_data, 4, 8,
+                                     bias_value, param.fuse_relu);
+            // copy to dest
+            memset(trans_tmp_data, 0, 16 * sizeof(float));
+            for (int i = 0; i < ey; ++i) {
+              memcpy(trans_tmp_data + i * ex * 4,
+                     trans_remain_tmp_data + i * 8,
+                     ex * sizeof(float) * 4);
+            }
+            write_to_output_c4_fp32(trans_tmp_data, output_ptr,
+                                    ci * 4, ci * 4 + 4, dst_y, dst_y + ey,
+                                    dst_x, dst_x + ex, chout, hout, wout,
+                                    false, zero_ptr);
+          }
+        }
+      }
+      //*/
+    }  // for block_count
+  }  // for num
+}  // conv_compute
+void conv_compute_2x2_3x3_small(const float* input,
+                                float* output,
+                                int num, int chout, int hout, int wout,
+                                int chin, int hin, int win,
+                                const float* weight,
+                                const float* bias,
+                                const operators::ConvParam& param,
+                                ARMContext* ctx) {
+  const int pad_h = (*param.paddings)[0];
+  const int pad_w = (*param.paddings)[2];
+  float* tmp_work_space =
+      ctx->workspace_data<float>() + ctx->llc_size() / sizeof(float);
+  int in_n_stride = chin * hin * win;
+  int out_n_stride = chout * hout * wout;
+  int ic_stride = win * hin;
+  int oc_stride = wout * hout;
+  int ic_4 = (chin + 3) / 4;
+  int oc_4 = (chout + 3) / 4;
+  int tile_w = (wout + 1) / 2;
+  int tile_h = (hout + 1) / 2;
+  int size_tile = tile_h * tile_w;
+  int w_pad = win + pad_w * 2;
+  int h_pad = hin + pad_h * 2;
+  const int zero_len = w_pad;
+  float zero_ptr[zero_len];  // NOLINT
+  memset(zero_ptr, 0, zero_len * sizeof(float));
+  float* input_c4 = tmp_work_space;
+  int new_h_stride = w_pad * 4;
+  int new_c_stride = new_h_stride * h_pad;
+  int ic_4_stride = w_pad * h_pad * 4;
+  int oc_4_stride = wout * hout * 4;
+  int tile_block = 8;
+  int block_count = (size_tile + tile_block - 1) / tile_block;
+  int threads = ctx->threads();
+  float* g_tmp_data = tmp_work_space + ic_4 * new_c_stride;
+  int tmp_data_thread_stride = tile_block * (oc_4 + ic_4) * 64;
+  memset(g_tmp_data, 0, tmp_data_thread_stride * sizeof(float));
+  float* g_trans_tmp_data = g_tmp_data + tmp_data_thread_stride;
+  float* g_trans_remain_tmp_data = g_trans_tmp_data + 64;
+  // begin compute
+  for (int ni = 0; ni < num; ++ni) {
+    // trans input to c4
+#pragma omp parallel for num_threads(threads)
+    for (int i = 0; i < ic_4; ++i) {
+      prepack_input_nxwc4_dw(input + ni * in_n_stride,
+                             input_c4 + i * new_c_stride, i * 4,
+                             -pad_h, hin + pad_h, -pad_w, win + pad_w,
+                             chin, win, hin, zero_ptr);
+    }
+    float* output_ptr = output + ni * out_n_stride;
+    const float* weight_ptr = weight;
+    const float* bias_ptr = bias;
+    for (int tbi = 0; tbi < block_count; ++tbi) {
+      float* tmp_data = g_tmp_data;
+      float* trans_tmp_data = g_trans_tmp_data;
+      float* trans_remain_tmp_data = g_trans_remain_tmp_data;
+      int tile_index = tbi * tile_block;
+      int tile_remain = size_tile - tile_index;
+      int tile_count = tile_remain > tile_block ? tile_block : tile_remain;
+      // input trans
+      int c_gi_stride = tile_count * oc_4 * 4;
+      int b_gi_stride = tile_count * ic_4 * 4;
+      //*
+      for (int ti = 0; ti < tile_count; ++ti) {
+        int index = tile_index + ti;
+        int tw_index = index % tile_w;
+        int th_index = index / tile_w;
+        int src_x = tw_index + tw_index;
+        int src_y = th_index + th_index;
+        int ex = src_x + 4 > w_pad ? w_pad - src_x : 4;
+        int ey = src_y + 4 > h_pad ? h_pad - src_y : 4;
+        float* dst_ptr = tmp_data + ti * 4;
+        const float* src_ptr = input_c4 + (src_y * w_pad + src_x) * 4;
+        if (ex == 4 && ey == 4) {
+          // trans input
+          for (int ci = 0; ci < ic_4; ++ci) {
+            const float* src_ci = src_ptr + ci * ic_4_stride;
+            float* dst_ci = dst_ptr + ci * tile_count * 4;
+            input_trans_c4_4x4(src_ci, 4, w_pad * 4, dst_ci,
+                               b_gi_stride, b_gi_stride * 4);
+          }
+        } else {
+          // trans remain input
+          int x_size = ex;
+          for (int ci = 0; ci < ic_4; ++ci) {
+            const float* src_ci = src_ptr + ci * ic_4_stride;
+            // pad
+            memset(trans_remain_tmp_data, 0, 64 * sizeof(float));
+            if (x_size > 0) {
+              for (int yi = 0; yi < ey; ++yi) {
+                float* dst_yi = trans_remain_tmp_data + yi * 16;
+                const float* src_yi = src_ci + w_pad * yi * 4;
+                memcpy(dst_yi, src_yi, x_size * sizeof(float) * 4);
+              }
+            }
+            float* dst_ci = dst_ptr + ci * tile_count * 4;
+            input_trans_c4_4x4(trans_remain_tmp_data, 4, 16, dst_ci,
+                               b_gi_stride, b_gi_stride * 4);
+          }
+        }  // for ci_4
+      }
+      //*/
+      // input trans end
+      // *begin compute dot
+      // *
+      //*
+      float* dst_temp_data = tmp_data + tile_block * ic_4 * 64;
+      float* b_ptr = tmp_data;
+      int w_gi_stride = ic_4 * oc_4 * 16;
+#pragma omp parallel for num_threads(threads)
+      for (int gi = 0; gi < 16; ++gi) {
+        float* origin_C = dst_temp_data + gi * c_gi_stride;
+        float* origin_B = b_ptr + gi * b_gi_stride;
+        const float* origin_A = weight + gi * w_gi_stride;
+        sgemm_prepack_c4_small(oc_4 * 4, tile_count, ic_4 * 4,
+                               origin_A, origin_B, origin_C, ctx);
+      }
+      //*/
+      //*
+      // output trans
+      float bias_value[4];
+      memset(bias_value, 0, 4 * sizeof(float));
+      for (int ti = 0; ti < tile_count; ++ti) {
+        int index = tile_index + ti;
+        int tw_index = index % tile_w;
+        int th_index = index / tile_w;
+        int dst_x = tw_index * 2;
+        int dst_y = th_index * 2;
+        int ex = dst_x + 2 > wout ? wout - dst_x : 2;
+        int ey = dst_y + 2 > hout ? hout - dst_y : 2;
+        float* dst_ptr = output + (dst_y * wout + dst_x) * 4;
+        float* src_ptr = dst_temp_data + ti * 4;
+        if (ex == 2) {
+          // trans output
+          for (int ci = 0; ci < oc_4; ++ci) {
+            if (param.bias) {
+              bias_value[0] = bias[ci * 4];
+              bias_value[1] = bias[ci * 4 + 1];
+              bias_value[2] = bias[ci * 4 + 2];
+              bias_value[3] = bias[ci * 4 + 3];
+            }
+            float* dst_ci = dst_ptr + ci * oc_4_stride;
+            float* src_ci = src_ptr + ci * tile_count * 4;
+            output_trans_c4_post_2x4(src_ci, c_gi_stride, c_gi_stride * 4,
+                                     trans_remain_tmp_data, 4, 8,
+                                     bias_value, param.fuse_relu);
+            write_to_output_c4_fp32(trans_remain_tmp_data, output_ptr,
+                                    ci * 4, ci * 4 + 4, dst_y, dst_y + ey,
+                                    dst_x, dst_x + ex, chout, hout, wout,
+                                    false, zero_ptr);
+          }
+        } else {
+          for (int ci = 0; ci < oc_4; ++ci) {
+            if (param.bias) {
+              bias_value[0] = bias[ci * 4];
+              bias_value[1] = bias[ci * 4 + 1];
+              bias_value[2] = bias[ci * 4 + 2];
+              bias_value[3] = bias[ci * 4 + 3];
+            }
+            // trans output
+            float* dst_ci = dst_ptr + ci * oc_4_stride;
+            float* src_ci = src_ptr + ci * tile_count * 4;
+            output_trans_c4_post_2x4(src_ci, c_gi_stride, c_gi_stride * 4,
+                                     trans_remain_tmp_data, 4, 8,
+                                     bias_value, param.fuse_relu);
+            // copy to dest
+            memset(trans_tmp_data, 0, 16 * sizeof(float));
+            for (int i = 0; i < ey; ++i) {
+              memcpy(trans_tmp_data + i * ex * 4,
+                     trans_remain_tmp_data + i * 8,
+                     ex * sizeof(float) * 4);
+            }
+            write_to_output_c4_fp32(trans_tmp_data, output_ptr,
+                                    ci * 4, ci * 4 + 4, dst_y, dst_y + ey,
+                                    dst_x, dst_x + ex, chout, hout, wout,
+                                    false, zero_ptr);
+          }
+        }
+      }
+      //*/
+    }  // for block_count
+  }  // for num
+}  // conv_compute
+void output_trans_c4_6x8(const float* src,
+                         int src_stride,
+                         float* dest,
+                         int dest_stride) {
   const float32x4_t src0 = vld1q_f32(src);
   const float32x4_t src1 = vld1q_f32(src + src_stride);
   const float32x4_t src2 = vld1q_f32(src + src_stride * 2);
@@ -381,12 +905,13 @@ void output_trans_c4(const float* src,
   vst1q_f32(dest + dest_stride * 4, dest4);
   vst1q_f32(dest + dest_stride * 5, dest5);
 }
-void output_trans_c4_post(const float* src,
-                          int src_stride,
-                          float* dest,
-                          int dest_stride,
-                          float* bias_value,
-                          bool has_relu = false) {
+void output_trans_c4_post_6x8(const float* src,
+                              int src_stride,
+                              float* dest,
+                              int dest_stride,
+                              float* bias_value,
+                              bool has_relu = false) {
   const float32x4_t src0 = vld1q_f32(src);
   const float32x4_t src1 = vld1q_f32(src + src_stride);
   const float32x4_t src2 = vld1q_f32(src + src_stride * 2);
@@ -447,10 +972,10 @@ void output_trans_c4_post(const float* src,
   vst1q_f32(dest + dest_stride * 5, dest5);
 }
-void input_trans_c4(const float* src,
-                    int src_stride,
-                    float* dest,
-                    int dest_stride) {
+void input_trans_c4_8x8(const float* src,
+                        int src_stride,
+                        float* dest,
+                        int dest_stride) {
   float32x4_t src0 = vld1q_f32(src);
   float32x4_t src1 = vld1q_f32(src + src_stride);
   float32x4_t src2 = vld1q_f32(src + src_stride * 2);
@@ -497,7 +1022,165 @@ void input_trans_c4(const float* src,
   vst1q_f32(dest + dest_stride * 6, dst6);
   vst1q_f32(dest + dest_stride * 7, dst7);
 }
-void weight_trans_c4(
+// BT=[1, 0, -1, 0,
+//     0, 1, 1, 0,
+//     0, -1, 1, 0,
+//     0, 1, 0, -1]
+void input_trans_c4_4x4(const float* src,
+                        int src_stride,
+                        int src_h_stride,
+                        float* dest,
+                        int dest_stride,
+                        int dest_h_stride) {
+  float32x4_t src00 = vld1q_f32(src);
+  float32x4_t src01 = vld1q_f32(src + src_stride);
+  float32x4_t src02 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src03 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src10 = vld1q_f32(src);
+  float32x4_t src11 = vld1q_f32(src + src_stride);
+  float32x4_t src12 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src20 = vld1q_f32(src);
+  float32x4_t src21 = vld1q_f32(src + src_stride);
+  float32x4_t src22 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src30 = vld1q_f32(src);
+  float32x4_t src31 = vld1q_f32(src + src_stride);
+  float32x4_t src32 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  float32x4_t dst00 = vsubq_f32(src00, src02);
+  float32x4_t dst10 = vaddq_f32(src01, src02);
+  float32x4_t dst20 = vsubq_f32(src02, src01);
+  float32x4_t dst30 = vsubq_f32(src01, src03);
+  float32x4_t dst01 = vsubq_f32(src10, src12);
+  float32x4_t dst11 = vaddq_f32(src11, src12);
+  float32x4_t dst21 = vsubq_f32(src12, src11);
+  float32x4_t dst31 = vsubq_f32(src11, src13);
+  float32x4_t dst02 = vsubq_f32(src20, src22);
+  float32x4_t dst12 = vaddq_f32(src21, src22);
+  float32x4_t dst22 = vsubq_f32(src22, src21);
+  float32x4_t dst32 = vsubq_f32(src21, src23);
+  float32x4_t dst03 = vsubq_f32(src30, src32);
+  float32x4_t dst13 = vaddq_f32(src31, src32);
+  float32x4_t dst23 = vsubq_f32(src32, src31);
+  float32x4_t dst33 = vsubq_f32(src31, src33);
+  float32x4_t dest00 = vsubq_f32(dst00, dst02);
+  float32x4_t dest10 = vaddq_f32(dst01, dst02);
+  float32x4_t dest20 = vsubq_f32(dst02, dst01);
+  float32x4_t dest30 = vsubq_f32(dst01, dst03);
+  float32x4_t dest01 = vsubq_f32(dst10, dst12);
+  float32x4_t dest11 = vaddq_f32(dst11, dst12);
+  float32x4_t dest21 = vsubq_f32(dst12, dst11);
+  float32x4_t dest31 = vsubq_f32(dst11, dst13);
+  float32x4_t dest02 = vsubq_f32(dst20, dst22);
+  float32x4_t dest12 = vaddq_f32(dst21, dst22);
+  float32x4_t dest22 = vsubq_f32(dst22, dst21);
+  float32x4_t dest32 = vsubq_f32(dst21, dst23);
+  float32x4_t dest03 = vsubq_f32(dst30, dst32);
+  float32x4_t dest13 = vaddq_f32(dst31, dst32);
+  float32x4_t dest23 = vsubq_f32(dst32, dst31);
+  float32x4_t dest33 = vsubq_f32(dst31, dst33);
+  vst1q_f32(dest, dest00);
+  vst1q_f32(dest + dest_stride, dest10);
+  vst1q_f32(dest + dest_stride + dest_stride, dest20);
+  vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest30);
+  dest += dest_h_stride;
+  vst1q_f32(dest, dest01);
+  vst1q_f32(dest + dest_stride, dest11);
+  vst1q_f32(dest + dest_stride + dest_stride, dest21);
+  vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest31);
+  dest += dest_h_stride;
+  vst1q_f32(dest, dest02);
+  vst1q_f32(dest + dest_stride, dest12);
+  vst1q_f32(dest + dest_stride + dest_stride, dest22);
+  vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest32);
+  dest += dest_h_stride;
+  vst1q_f32(dest, dest03);
+  vst1q_f32(dest + dest_stride, dest13);
+  vst1q_f32(dest + dest_stride + dest_stride, dest23);
+  vst1q_f32(dest + dest_stride + dest_stride + dest_stride, dest33);
+}
+// AT=[1, 1, 1, 0,
+//     0, 1, -1, -1]
+void output_trans_c4_post_2x4(const float* src,
+                              int src_stride,
+                              int src_h_stride,
+                              float* dest,
+                              int dest_stride,
+                              int dest_h_stride,
+                              float* bias_value,
+                              bool has_relu) {
+  float32x4_t src00 = vld1q_f32(src);
+  float32x4_t src01 = vld1q_f32(src + src_stride);
+  float32x4_t src02 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src03 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src10 = vld1q_f32(src);
+  float32x4_t src11 = vld1q_f32(src + src_stride);
+  float32x4_t src12 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src13 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src20 = vld1q_f32(src);
+  float32x4_t src21 = vld1q_f32(src + src_stride);
+  float32x4_t src22 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src23 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  src += src_h_stride;
+  float32x4_t src30 = vld1q_f32(src);
+  float32x4_t src31 = vld1q_f32(src + src_stride);
+  float32x4_t src32 = vld1q_f32(src + src_stride + src_stride);
+  float32x4_t src33 = vld1q_f32(src + src_stride + src_stride + src_stride);
+  float32x4_t dst00 = vaddq_f32(vaddq_f32(src00, src01), src02);
+  float32x4_t dst10 = vsubq_f32(vsubq_f32(src01, src02), src03);
+  float32x4_t dst01 = vaddq_f32(vaddq_f32(src10, src11), src12);
+  float32x4_t dst11 = vsubq_f32(vsubq_f32(src11, src12), src13);
+  float32x4_t dst02 = vaddq_f32(vaddq_f32(src20, src21), src22);
+  float32x4_t dst12 = vsubq_f32(vsubq_f32(src21, src22), src23);
+  float32x4_t dst03 = vaddq_f32(vaddq_f32(src30, src31), src32);
+  float32x4_t dst13 = vsubq_f32(vsubq_f32(src31, src32), src33);
+  float32x4_t dest00 = vaddq_f32(vaddq_f32(dst00, dst01), dst02);
+  float32x4_t dest10 = vsubq_f32(vsubq_f32(dst01, dst02), dst03);
+  float32x4_t dest01 = vaddq_f32(vaddq_f32(dst10, dst11), dst12);
+  float32x4_t dest11 = vsubq_f32(vsubq_f32(dst11, dst12), dst13);
+  if (bias_value) {
+    float32x4_t bias = vld1q_f32(bias_value);
+    dest00 = vaddq_f32(dest00, bias);
+    dest10 = vaddq_f32(dest10, bias);
+    dest01 = vaddq_f32(dest01, bias);
+    dest11 = vaddq_f32(dest11, bias);
+  }
+  if (has_relu) {
+    float32x4_t zeros = vdupq_n_f32(0);
+    dest00 = vmaxq_f32(dest00, zeros);
+    dest10 = vmaxq_f32(dest10, zeros);
+    dest01 = vmaxq_f32(dest01, zeros);
+    dest11 = vmaxq_f32(dest11, zeros);
+  }
+  vst1q_f32(dest, dest00);
+  vst1q_f32(dest + dest_stride, dest10);
+  dest += dest_h_stride;
+  vst1q_f32(dest, dest01);
+  vst1q_f32(dest + dest_stride, dest11);
+}
+void weight_trans_c4_8x8(
     float* dest, const float* din, int ch_in, int ch_out, void* workspace) {
   const float coeff[8][3] = {{1.0f, 0.0f, 0.0f},
                              {-2.0f / 9, -2.0f / 9, -2.0f / 9},
@@ -558,6 +1241,63 @@ void weight_trans_c4(
   }
 }
+void weight_trans_c4_4x4(
+    float* dest, const float* din, int ch_in, int ch_out, void* workspace) {
+  const float coeff[4][3] = {{1.0f, 0.0f, 0.0f},
+                             {0.5f, 0.5f, 0.5f},
+                             {0.5f, -0.5f, 0.5f},
+                             {0.0f, 0.0f, 1.0f}};
+  float* ptr_out = static_cast<float*>(workspace);
+  for (int i = 0; i < ch_out; i++) {
+    for (int j = 0; j < ch_in; j++) {
+      const float* kernel0 =
+          static_cast<const float*>(din) + (i * ch_in + j) * 9;
+      float* ptr_channel = ptr_out + (i * ch_in + j) * 16;
+      //! transform kernel, transposed
+      const float* k0 = kernel0;
+      const float* k1 = kernel0 + 3;
+      const float* k2 = kernel0 + 6;
+      //! h
+      float tmp[4][3];
+      for (int i = 0; i < 4; i++) {
+        tmp[i][0] =
+            k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
+        tmp[i][1] =
+            k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
+        tmp[i][2] =
+            k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
+      }
+      //! v
+      for (int j = 0; j < 4; j++) {
+        float* tmpp = &tmp[j][0];
+        for (int i = 0; i < 4; i++) {
+          ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] +
+                                   tmpp[1] * coeff[i][1] +
+                                   tmpp[2] * coeff[i][2];
+        }
+      }
+    }
+  }
+  int oc_pad = (ch_out + 3) / 4 * 4;
+  int ic_pad = (ch_in + 3) / 4 * 4;
+  int c_stride = ic_pad * oc_pad;
+  for (int i = 0; i < ch_out * ch_in * 16; ++i) {
+    int new_c = i % 16;
+    int new_oc = i / ch_in / 16 / 4;
+    int new_ic = i / 16 % (ch_in * 4) % ch_in;
+    int new_inner = i / ch_in / 16 % 4;
+    int dest_ind =
+        new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
+    dest[dest_ind] = ptr_out[i];
+  }
+}
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
lite/backends/arm/math/conv_impl.h
@@ -316,7 +316,9 @@ void fill_bias_int8(int* tensor,
                     int channel_size);
 // new winograd
-void weight_trans_c4(
+void weight_trans_c4_8x8(
+    float* dest, const float* src, int ic, int oc, void* workspace);
+void weight_trans_c4_4x4(
     float* dest, const float* src, int ic, int oc, void* workspace);
 void conv_compute_6x6_3x3(const float* input,
                           float* output,
@@ -331,6 +333,32 @@ void conv_compute_6x6_3x3(const float* input,
                           const float* bias,
                           const operators::ConvParam& param,
                           ARMContext* ctx);
+void conv_compute_2x2_3x3(const float* input,
+                          float* output,
+                          int num, int chout, int hout, int wout,
+                          int chin, int hin, int win,
+                          const float* weight,
+                          const float* bias,
+                          const operators::ConvParam& param,
+                          ARMContext* ctx);
+void conv_compute_2x2_3x3_small(const float* input,
+                                float* output,
+                                int num, int chout, int hout, int wout,
+                                int chin, int hin, int win,
+                                const float* weight,
+                                const float* bias,
+                                const operators::ConvParam& param,
+                                ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
lite/backends/arm/math/packed_sgemm_c4.cc
@@ -695,7 +695,6 @@ void sgemm_prepack_c4_common(int M,
     }
   }
 }
-
 void sgemm_prepack_c4_small(int M,
                             int N,
                             int K,
@@ -1146,6 +1145,540 @@ void sgemm_prepack_c4_small(int M,
     }
   }
 }
+void sgemm_prepack_c4_small(int M,
+                            int N,
+                            int K,
+                            const float* A_packed,
+                            const float* B,
+                            float* C,
+                            ARMContext* ctx) {
+  const int m_round = (M + 3) / 4 * 4;
+  const int k_round = (K + 3) / 4 * 4;
+  const int mloop = m_round >> 2;
+  const int lda = 4 * k_round;
+  const int ldb_byte = 4 * N * sizeof(float);
+  const int kcnt = k_round >> 2;
+#ifdef __aarch64__
+  float32x4_t vzero = vdupq_n_f32(0.f);
+#endif
+  for (int m = 0; m < mloop; ++m) {
+    const float* b = B;
+    int n = N;
+#ifdef __aarch64__
+    for (; n > 7; n -= 8) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      // clang-format off
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          /* load b0, b1 */
+          "ld1  {v0.4s, v1.4s},   [%[b]], #32\n"
+          /* load b2, b3 */
+          "ld1  {v2.4s, v3.4s},   [%[b]], #32\n"
+          /* load a2, a3 */
+          "fmul v8.4s,  v16.4s, v0.s[0]\n"
+          "fmul v9.4s,  v16.4s, v1.s[0]\n"
+          "fmul v10.4s, v16.4s, v2.s[0]\n"
+          "fmul v11.4s, v16.4s, v3.s[0]\n"
+          "ld1  {v18.4s, v19.4s}, [%[a]], #32\n"
+          "prfm pldl1keep, [%[b]]\n"
+          "fmla v8.4s,  v17.4s, v0.s[1]\n"
+          "fmla v9.4s,  v17.4s, v1.s[1]\n"
+          "fmla v10.4s, v17.4s, v2.s[1]\n"
+          "fmla v11.4s, v17.4s, v3.s[1]\n"
+          /* load b4, b5 */
+          "ld1  {v4.4s, v5.4s},   [%[b]], #32\n"
+          "fmla v8.4s,  v18.4s, v0.s[2]\n"
+          "fmla v9.4s,  v18.4s, v1.s[2]\n"
+          "fmla v10.4s, v18.4s, v2.s[2]\n"
+          "fmla v11.4s, v18.4s, v3.s[2]\n"
+          /* load b6, b7 */
+          "ld1  {v6.4s, v7.4s},   [%[b]], #32\n"
+          "fmla v8.4s,  v19.4s, v0.s[3]\n"
+          "fmla v9.4s,  v19.4s, v1.s[3]\n"
+          "fmla v10.4s, v19.4s, v2.s[3]\n"
+          "fmla v11.4s, v19.4s, v3.s[3]\n"
+          "sub  %[b], %[b], #128\n"
+          "fmul v12.4s, v16.4s, v4.s[0]\n"
+          "fmul v13.4s, v16.4s, v5.s[0]\n"
+          "fmul v14.4s, v16.4s, v6.s[0]\n"
+          "fmul v15.4s, v16.4s, v7.s[0]\n"
+          "add  %[b], %[b], %[ldb]\n"
+          "fmla v12.4s, v17.4s, v4.s[1]\n"
+          "fmla v13.4s, v17.4s, v5.s[1]\n"
+          "fmla v14.4s, v17.4s, v6.s[1]\n"
+          "fmla v15.4s, v17.4s, v7.s[1]\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          "fmla v12.4s, v18.4s, v4.s[2]\n"
+          "fmla v13.4s, v18.4s, v5.s[2]\n"
+          "fmla v14.4s, v18.4s, v6.s[2]\n"
+          "fmla v15.4s, v18.4s, v7.s[2]\n"
+          /* load b0, b1 */
+          "ld1  {v0.4s, v1.4s},   [%[b]], #32\n"
+          "fmla v12.4s, v19.4s, v4.s[3]\n"
+          "fmla v13.4s, v19.4s, v5.s[3]\n"
+          "fmla v14.4s, v19.4s, v6.s[3]\n"
+          "fmla v15.4s, v19.4s, v7.s[3]\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b2, b3 */
+          "ld1  {v2.4s, v3.4s},   [%[b]], #32\n"
+          "fmla v8.4s,  v16.4s, v0.s[0]\n"
+          "fmla v9.4s,  v16.4s, v1.s[0]\n"
+          "fmla v10.4s, v16.4s, v2.s[0]\n"
+          "fmla v11.4s, v16.4s, v3.s[0]\n"
+          /* load a2, a3 */
+          "ld1  {v18.4s, v19.4s}, [%[a]], #32\n"
+          "prfm pldl1keep, [%[b]]\n"
+          "fmla v8.4s,  v17.4s, v0.s[1]\n"
+          "fmla v9.4s,  v17.4s, v1.s[1]\n"
+          "fmla v10.4s, v17.4s, v2.s[1]\n"
+          "fmla v11.4s, v17.4s, v3.s[1]\n"
+          /* load b4, b5 */
+          "ld1  {v4.4s, v5.4s},   [%[b]], #32\n"
+          "fmla v8.4s,  v18.4s, v0.s[2]\n"
+          "fmla v9.4s,  v18.4s, v1.s[2]\n"
+          "fmla v10.4s, v18.4s, v2.s[2]\n"
+          "fmla v11.4s, v18.4s, v3.s[2]\n"
+          /* load b6, b7 */
+          "ld1  {v6.4s, v7.4s},   [%[b]], #32\n"
+          "fmla v8.4s,  v19.4s, v0.s[3]\n"
+          "fmla v9.4s,  v19.4s, v1.s[3]\n"
+          "fmla v10.4s, v19.4s, v2.s[3]\n"
+          "fmla v11.4s, v19.4s, v3.s[3]\n"
+          "sub  %[b], %[b], #128\n"
+          "fmla v12.4s, v16.4s, v4.s[0]\n"
+          "fmla v13.4s, v16.4s, v5.s[0]\n"
+          "fmla v14.4s, v16.4s, v6.s[0]\n"
+          "fmla v15.4s, v16.4s, v7.s[0]\n"
+          "add  %[b], %[b], %[ldb]\n"
+          "fmla v12.4s, v17.4s, v4.s[1]\n"
+          "fmla v13.4s, v17.4s, v5.s[1]\n"
+          "fmla v14.4s, v17.4s, v6.s[1]\n"
+          "fmla v15.4s, v17.4s, v7.s[1]\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          "fmla v12.4s, v18.4s, v4.s[2]\n"
+          "fmla v13.4s, v18.4s, v5.s[2]\n"
+          "fmla v14.4s, v18.4s, v6.s[2]\n"
+          "fmla v15.4s, v18.4s, v7.s[2]\n"
+          /* load b0, b1 */
+          "ld1  {v0.4s, v1.4s},   [%[b]], #32\n"
+          "fmla v12.4s, v19.4s, v4.s[3]\n"
+          "fmla v13.4s, v19.4s, v5.s[3]\n"
+          "fmla v14.4s, v19.4s, v6.s[3]\n"
+          "fmla v15.4s, v19.4s, v7.s[3]\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "bne 1b\n"
+          "2:\n"
+          "st1  {v8.4s,  v9.4s,  v10.4s, v11.4s}, [%[c]], #64\n"
+          "st1  {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte), [vzero] "w"(vzero)
+          : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+            "v19", "cc", "memory");
+      b += 4 * 8;
+    }
+    for (; n > 3; n -= 4) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          /* load b0-b3 */
+          "ld1  {v0.4s, v1.4s},   [%[b]], #32\n"
+          "ld1  {v2.4s, v3.4s},   [%[b]], #32\n"
+          "fmul v8.4s,  v16.4s, v0.s[0]\n"
+          "fmul v9.4s,  v16.4s, v1.s[0]\n"
+          "fmul v10.4s, v16.4s, v2.s[0]\n"
+          "fmul v11.4s, v16.4s, v3.s[0]\n"
+          /* load a2, a3 */
+          "ld1  {v18.4s, v19.4s}, [%[a]], #32\n"
+          "sub  %[b], %[b], #64\n"
+          "fmla v8.4s,  v17.4s, v0.s[1]\n"
+          "fmla v9.4s,  v17.4s, v1.s[1]\n"
+          "fmla v10.4s, v17.4s, v2.s[1]\n"
+          "fmla v11.4s, v17.4s, v3.s[1]\n"
+          "add  %[b], %[b], %[ldb]\n"
+          "fmla v8.4s,  v18.4s, v0.s[2]\n"
+          "fmla v9.4s,  v18.4s, v1.s[2]\n"
+          "fmla v10.4s, v18.4s, v2.s[2]\n"
+          "fmla v11.4s, v18.4s, v3.s[2]\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          "fmla v8.4s,  v19.4s, v0.s[3]\n"
+          "fmla v9.4s,  v19.4s, v1.s[3]\n"
+          "fmla v10.4s, v19.4s, v2.s[3]\n"
+          "fmla v11.4s, v19.4s, v3.s[3]\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b0-b3 */
+          "ld1  {v0.4s, v1.4s},   [%[b]], #32\n"
+          "ld1  {v2.4s, v3.4s},   [%[b]], #32\n"
+          "fmla v8.4s,  v16.4s, v0.s[0]\n"
+          "fmla v9.4s,  v16.4s, v1.s[0]\n"
+          "fmla v10.4s, v16.4s, v2.s[0]\n"
+          "fmla v11.4s, v16.4s, v3.s[0]\n"
+          /* load a2, a3 */
+          "ld1  {v18.4s, v19.4s}, [%[a]], #32\n"
+          "sub  %[b], %[b], #64\n"
+          "fmla v8.4s,  v17.4s, v0.s[1]\n"
+          "fmla v9.4s,  v17.4s, v1.s[1]\n"
+          "fmla v10.4s, v17.4s, v2.s[1]\n"
+          "fmla v11.4s, v17.4s, v3.s[1]\n"
+          "add  %[b], %[b], %[ldb]\n"
+          "fmla v8.4s,  v18.4s, v0.s[2]\n"
+          "fmla v9.4s,  v18.4s, v1.s[2]\n"
+          "fmla v10.4s, v18.4s, v2.s[2]\n"
+          "fmla v11.4s, v18.4s, v3.s[2]\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          "fmla v8.4s,  v19.4s, v0.s[3]\n"
+          "fmla v9.4s,  v19.4s, v1.s[3]\n"
+          "fmla v10.4s, v19.4s, v2.s[3]\n"
+          "fmla v11.4s, v19.4s, v3.s[3]\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "bne 1b\n"
+          "2:\n"
+          "st1  {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte), [vzero] "w"(vzero)
+          : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11",
+            "v16", "v17", "v18", "v19", "cc", "memory");
+      b += 4 * 4;
+    }
+    for (; n > 0; n--) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          /* load b0 */
+          "ld1  {v0.4s}, [%[b]], #16\n"
+          "fmul v8.4s, v16.4s, v0.s[0]\n"
+          "fmul v9.4s, v17.4s, v0.s[1]\n"
+          /* load a2, a3 */
+          "ld1  {v18.4s, v19.4s}, [%[a]], #32\n"
+          "sub  %[b], %[b], #16\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "add  %[b], %[b], %[ldb]\n"
+          "fmla v8.4s, v18.4s, v0.s[2]\n"
+          "fmla v9.4s, v19.4s, v0.s[3]\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b0 */
+          "ld1  {v0.4s}, [%[b]], #16\n"
+          "fmla v8.4s, v16.4s, v0.s[0]\n"
+          "fmla v9.4s, v17.4s, v0.s[1]\n"
+          /* load a2, a3 */
+          "ld1  {v18.4s, v19.4s}, [%[a]], #32\n"
+          "sub  %[b], %[b], #16\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "add  %[b], %[b], %[ldb]\n"
+          "fmla v8.4s, v18.4s, v0.s[2]\n"
+          "fmla v9.4s, v19.4s, v0.s[3]\n"
+          /* load a0, a1 */
+          "ld1  {v16.4s, v17.4s}, [%[a]], #32\n"
+          "bne 1b\n"
+          "fadd v8.4s, v8.4s, v9.4s\n"
+          "2:\n"
+          "st1  {v8.4s}, [%[c]], #16\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte), [vzero] "w"(vzero)
+          : "v0", "v8", "v9", "v16", "v17", "v18", "v19", "cc", "memory");
+      b += 4;
+    }
+#else
+    for (; n > 7; n -= 8) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      // clang-format off
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11},  [%[a]]!\n"
+          "vld1.32 {d0-d3},   [%[b]]!\n"
+          /* load b2, b3 */
+          "vld1.32 {d4-d7},   [%[b]]!\n"
+          "vmul.f32 q8,  q4, d0[0]\n"
+          "vmul.f32 q9,  q4, d2[0]\n"
+          "vmul.f32 q10, q4, d4[0]\n"
+          "vmul.f32 q11, q4, d6[0]\n"
+          /* load a2, a3 */
+          "vld1.32 {d12-d15}, [%[a]]!\n"
+          "pld [%[b]]\n"
+          "vmla.f32 q8,  q5, d0[1]\n"
+          "vmla.f32 q9,  q5, d2[1]\n"
+          "vmla.f32 q10, q5, d4[1]\n"
+          "vmla.f32 q11, q5, d6[1]\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "vmla.f32 q8,  q6, d1[0]\n"
+          "vmla.f32 q9,  q6, d3[0]\n"
+          "vmla.f32 q10, q6, d5[0]\n"
+          "vmla.f32 q11, q6, d7[0]\n"
+          "pld [%[b], #64]\n"
+          "vmla.f32 q8,  q7, d1[1]\n"
+          "vmla.f32 q9,  q7, d3[1]\n"
+          /* load b4, b5 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vmla.f32 q10, q7, d5[1]\n"
+          "vmla.f32 q11, q7, d7[1]\n"
+          /* load b6, b7 */
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmul.f32 q12, q4, d0[0]\n"
+          "vmul.f32 q13, q4, d2[0]\n"
+          "vmul.f32 q14, q4, d4[0]\n"
+          "vmul.f32 q15, q4, d6[0]\n"
+          "sub %[b], %[b], #128\n"
+          "vmla.f32 q12, q5, d0[1]\n"
+          "vmla.f32 q13, q5, d2[1]\n"
+          "vmla.f32 q14, q5, d4[1]\n"
+          "vmla.f32 q15, q5, d6[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q12, q6, d1[0]\n"
+          "vmla.f32 q13, q6, d3[0]\n"
+          "vmla.f32 q14, q6, d5[0]\n"
+          "vmla.f32 q15, q6, d7[0]\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vmla.f32 q12, q7, d1[1]\n"
+          "vmla.f32 q13, q7, d3[1]\n"
+          /* load b0, b1 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vmla.f32 q14, q7, d5[1]\n"
+          "vmla.f32 q15, q7, d7[1]\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b2, b3 */
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmla.f32 q8,  q4, d0[0]\n"
+          "vmla.f32 q9,  q4, d2[0]\n"
+          "vmla.f32 q10, q4, d4[0]\n"
+          "vmla.f32 q11, q4, d6[0]\n"
+          /* load a2, a3 */
+          "vld1.32 {d12-d15}, [%[a]]!\n"
+          "pld [%[b]]\n"
+          "vmla.f32 q8,  q5, d0[1]\n"
+          "vmla.f32 q9,  q5, d2[1]\n"
+          "vmla.f32 q10, q5, d4[1]\n"
+          "vmla.f32 q11, q5, d6[1]\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "vmla.f32 q8,  q6, d1[0]\n"
+          "vmla.f32 q9,  q6, d3[0]\n"
+          "vmla.f32 q10, q6, d5[0]\n"
+          "vmla.f32 q11, q6, d7[0]\n"
+          "pld [%[b], #64]\n"
+          "vmla.f32 q8,  q7, d1[1]\n"
+          "vmla.f32 q9,  q7, d3[1]\n"
+          /* load b4, b5 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vmla.f32 q10, q7, d5[1]\n"
+          "vmla.f32 q11, q7, d7[1]\n"
+          /* load b6, b7 */
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmla.f32 q12, q4, d0[0]\n"
+          "vmla.f32 q13, q4, d2[0]\n"
+          "vmla.f32 q14, q4, d4[0]\n"
+          "vmla.f32 q15, q4, d6[0]\n"
+          "sub %[b], %[b], #128\n"
+          "vmla.f32 q12, q5, d0[1]\n"
+          "vmla.f32 q13, q5, d2[1]\n"
+          "vmla.f32 q14, q5, d4[1]\n"
+          "vmla.f32 q15, q5, d6[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q12, q6, d1[0]\n"
+          "vmla.f32 q13, q6, d3[0]\n"
+          "vmla.f32 q14, q6, d5[0]\n"
+          "vmla.f32 q15, q6, d7[0]\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vmla.f32 q12, q7, d1[1]\n"
+          "vmla.f32 q13, q7, d3[1]\n"
+          /* load b0, b1 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vmla.f32 q14, q7, d5[1]\n"
+          "vmla.f32 q15, q7, d7[1]\n"
+          "bne 1b\n"
+          "2:\n"
+          "vst1.32 {d16-d19}, [%[c]]!\n"
+          "vst1.32 {d20-d23}, [%[c]]!\n"
+          "vst1.32 {d24-d27}, [%[c]]!\n"
+          "vst1.32 {d28-d31}, [%[c]]!\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte)
+          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+            "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
+      b += 4 * 8;
+    }
+    for (; n > 3; n -= 4) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          /* load b0-b3 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmul.f32 q8,  q4, d0[0]\n"
+          "vmul.f32 q9,  q4, d2[0]\n"
+          "vmul.f32 q10, q4, d4[0]\n"
+          "vmul.f32 q11, q4, d6[0]\n"
+          /* load a2, a3 */
+          "vld1.32 {d12-d15}, [%[a]]!\n"
+          "sub %[b], %[b], #64\n"
+          "vmla.f32 q8,  q5, d0[1]\n"
+          "vmla.f32 q9,  q5, d2[1]\n"
+          "vmla.f32 q10, q5, d4[1]\n"
+          "vmla.f32 q11, q5, d6[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q8,  q6, d1[0]\n"
+          "vmla.f32 q9,  q6, d3[0]\n"
+          "vmla.f32 q10, q6, d5[0]\n"
+          "vmla.f32 q11, q6, d7[0]\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vmla.f32 q8,  q7, d1[1]\n"
+          "vmla.f32 q9,  q7, d3[1]\n"
+          "vmla.f32 q10, q7, d5[1]\n"
+          "vmla.f32 q11, q7, d7[1]\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b0-b3 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmla.f32 q8,  q4, d0[0]\n"
+          "vmla.f32 q9,  q4, d2[0]\n"
+          "vmla.f32 q10, q4, d4[0]\n"
+          "vmla.f32 q11, q4, d6[0]\n"
+          /* load a2, a3 */
+          "vld1.32 {d12-d15}, [%[a]]!\n"
+          "sub %[b], %[b], #64\n"
+          "vmla.f32 q8,  q5, d0[1]\n"
+          "vmla.f32 q9,  q5, d2[1]\n"
+          "vmla.f32 q10, q5, d4[1]\n"
+          "vmla.f32 q11, q5, d6[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q8,  q6, d1[0]\n"
+          "vmla.f32 q9,  q6, d3[0]\n"
+          "vmla.f32 q10, q6, d5[0]\n"
+          "vmla.f32 q11, q6, d7[0]\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vmla.f32 q8,  q7, d1[1]\n"
+          "vmla.f32 q9,  q7, d3[1]\n"
+          "vmla.f32 q10, q7, d5[1]\n"
+          "vmla.f32 q11, q7, d7[1]\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "bne 1b\n"
+          "2:\n"
+          "vst1.32 {d16-d19}, [%[c]]!\n"
+          "vst1.32 {d20-d23}, [%[c]]!\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte)
+          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+            "q10", "q11", "q12", "q13", "cc", "memory");
+      b += 4 * 4;
+    }
+    for (; n > 0; n--) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "vld1.32 {d2-d5}, [%[a]]!\n"
+          /* load b0 */
+          "vld1.32 {d0-d1}, [%[b]]!\n"
+          "vmul.f32 q5, q1, d0[0]\n"
+          "vmul.f32 q6, q2, d0[1]\n"
+          /* load a2, a3 */
+          "vld1.32 {d6-d9}, [%[a]]!\n"
+          "sub %[b], %[b], #16\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q5, q3, d1[0]\n"
+          "vmla.f32 q6, q4, d1[1]\n"
+          /* load a0, a1 */
+          "vld1.32 {d2-d5}, [%[a]]!\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b0 */
+          "vld1.32 {d0-d1}, [%[b]]!\n"
+          "vmla.f32 q5, q1, d0[0]\n"
+          "vmla.f32 q6, q2, d0[1]\n"
+          /* load a2, a3 */
+          "vld1.32 {d6-d9}, [%[a]]!\n"
+          "sub %[b], %[b], #16\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q5, q3, d1[0]\n"
+          "vmla.f32 q6, q4, d1[1]\n"
+          /* load a0, a1 */
+          "vld1.32 {d2-d5}, [%[a]]!\n"
+          "bne 1b\n"
+          "vadd.f32 q5, q5, q6\n"
+          "2:\n"
+          "vst1.32 {d10-d11}, [%[c]]!\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte)
+          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+            "cc", "memory");
+      // clang-format on
+      b += 4;
+    }
+#endif
+    A_packed += lda;
+  }
+}
 void sgemm_prepack_c4(int M,
                       int N,
                       int K,
lite/backends/arm/math/packed_sgemm_c4.h
@@ -47,6 +47,13 @@ void sgemm_prepack_c4_small(int M,
                             bool has_bias,
                             bool has_relu,
                             ARMContext* ctx);
+void sgemm_prepack_c4_small(int M,
+                            int N,
+                            int K,
+                            const float* A_packed,
+                            const float* B,
+                            float* C,
+                            ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
lite/kernels/arm/conv_compute.cc
@@ -68,19 +68,9 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     VLOG(3) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
              no_dilation) {
-    bool use_winograd =
-        (threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 &&
-         pads_equal) ||
-        (oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal);
-    if (use_winograd) {
-      /// winograd conv impl
-      impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
-      VLOG(3) << "invoking winograd conv";
-    } else {
-      /// direct conv impl
-      impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
-      VLOG(3) << "invoking direct conv";
-    }
+    /// winograd conv impl
+    impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
+    VLOG(3) << "invoking winograd conv";
   } else if (param.groups == 1 && kw == 3 && stride == 2 &&
              chin * chout < 4 * hin * win && kps_equal && no_dilation) {
     /// direct conv impl
lite/kernels/arm/conv_winograd.cc
@@ -43,79 +43,47 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
   int oh = o_dims[2];
   int ow = o_dims[3];
+  int tile_block = 8;
+#ifdef __aarch64__
+  tile_block = 16;
+#endif
+  int parallel_threads =
+      (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
+  if (threads <= 2 && parallel_threads >= threads) {
-  if (last_kernel_is_c4_ == 1) {
+    choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
+    if (choose_small_) {
+      wino_iw = 4;
+      if (last_function_ == 0) {
+        return;
+      }
-    last_kernel_is_c4_ = 1;
-    auto pad = *(param.paddings);
-    int pad_h = pad[0];
-    int pad_w = pad[2];
-    int oc_pad = (oc + 3) / 4 * 4;
-    int ic_pad = (ic + 3) / 4 * 4;
-    const int new_input_size =
-        (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
-    const int temp_size =
-        (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads;
-    ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
-    weights_.Resize({1, 1, 1, 64 * oc_pad * ic_pad});
-    ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
-    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
-    auto weights_data_ = weights_.mutable_data<float>();
-    lite::arm::math::weight_trans_c4(
-        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
-    free(trans_tmp_ptr);
+      last_function_ = 0;
-  } else {
-    if (last_kernel_is_c4_ == 0) {
+    } else {
+      wino_iw = 8;
+      if (last_function_ == 1) {
+        return;
+      }
-      last_kernel_is_c4_ = 0;
-    int tile_w = (ow + 5) / 6;
-    int tile_h = (oh + 5) / 6;
-    int size_tile = tile_h * tile_w;
-    int size_trans_channel = 8 * 8 * size_tile;
-    int max_ch = ic > oc ? ic : oc;
-    const int n_wino = size_tile;
-    ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
-                        sizeof(float));
-    const int m_wino = oc;
-    int hblock = lite::arm::math::get_hblock(&ctx);
-    int m_round = hblock * ((m_wino + hblock - 1) / hblock);
-    weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
-    ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
-                        sizeof(float));
-    auto weights_wino =
-        static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
-    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
-    lite::arm::math::winograd_transform_weights(
-        weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
-    auto weights_trans = weights_.mutable_data<float>();
-    for (int i = 0; i < 64; ++i) {
-      float* packed_weights = weights_trans + i * m_round * ic;
-      const float* weights_wino_ptr = weights_wino + i * oc * ic;
-      lite::arm::math::prepackA(packed_weights, weights_wino_ptr, 1.f, ic,
-                                0, m_wino, 0, ic, false, &ctx);
-    }
-    free(trans_tmp_ptr);
-    free(weights_wino);
+      last_function_ = 1;
+    }
   }
+  auto pad = *(param.paddings);
+  int pad_h = pad[0];
+  int pad_w = pad[2];
+  int oc_pad = (oc + 3) / 4 * 4;
+  int ic_pad = (ic + 3) / 4 * 4;
+  const int new_input_size =
+      (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
+  const int temp_size =
+      (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
+       8 * wino_iw * wino_iw) * threads;
+  ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
+  weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
+  ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
+  void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic);
+  auto weights_data_ = weights_.mutable_data<float>();
+  if (!choose_small_) {
+    lite::arm::math::weight_trans_c4_8x8(
+        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
+  } else {
+    lite::arm::math::weight_trans_c4_4x4(
+        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
+  }
+  free(trans_tmp_ptr);
   last_shape_ = x_dims;
 }
@@ -145,14 +113,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   int ow = o_dims[3];
   int oc = o_dims[1];
-  int tile_block = 8;
-#ifdef __aarch64__
-  tile_block = 16;
-#endif
-  int threads = ctx.threads();
-  int parallel_threads =
-      (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
-  if (threads <= 2 && parallel_threads >= threads) {
+  if (!choose_small_) {
     lite::arm::math::conv_compute_6x6_3x3(i_data,
                                           o_data,
                                           bs,
@@ -167,19 +128,38 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
                                           param,
                                           &ctx);
   } else {
-    lite::arm::math::conv_winograd3x3(i_data, o_data, bs, oc, oh, ow,
-                                      ic, ih, iw, w_data, b_data,
-                                      param, &ctx);
+    int tile_block = 8;
+    int block_count =
+        (((ow + 1) / 2) * ((oh + 1) / 2) + tile_block - 1) / tile_block;
+    if (block_count != 1) {
+      lite::arm::math::conv_compute_2x2_3x3(i_data, o_data, bs, oc, oh, ow,
+                                            ic, ih, iw, w_data, b_data,
+                                            param, &ctx);
+    } else {
+      lite::arm::math::conv_compute_2x2_3x3_small(i_data, o_data, bs, oc,
+                                                  oh, ow, ic, ih, iw,
+                                                  w_data, b_data, param,
+                                                  &ctx);
+    }
   }
 }
lite/kernels/arm/conv_winograd.h
@@ -40,7 +40,9 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
   Tensor weights_;
   DDim last_shape_;
   int workspace_size_{0};
-  int last_kernel_is_c4_{-1};
+  int last_function_{-1};
+  bool choose_small_{false};
+  int wino_iw{8};
 };
 }  // namespace arm
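Editorial note on the motivation, based on standard Winograd arithmetic (Lavin & Gray) rather than anything stated in the commit itself: a 2x2 output tile computed directly with a 3x3 kernel needs 2*2*3*3 = 36 multiply-accumulates per input channel, while F(2x2,3x3) needs only the 4x4 = 16 element-wise products, a 2.25x reduction. The existing F(6x6,3x3) path saves more per tile (6*6*3*3 = 324 vs 8*8 = 64, about 5.06x) but requires larger tiles and transform buffers, which is consistent with the new dispatch logic choosing conv_compute_2x2_3x3[_small] when the output is small or few threads are available.

\[
\frac{2\cdot 2\cdot 3\cdot 3}{4\cdot 4}=\frac{36}{16}=2.25,
\qquad
\frac{6\cdot 6\cdot 3\cdot 3}{8\cdot 8}=\frac{324}{64}\approx 5.06
\]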