Commit f99c34c8

add winograd f23 implement (#2584)

Authored Dec 11, 2019 by TianXiaogang; committed by yiicy on Dec 11, 2019
Parent: fbb0d3b5

Showing 7 changed files with 1461 additions and 181 deletions (+1461, -181)
lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc   +817   -77
lite/backends/arm/math/conv_impl.h                    +29    -1
lite/backends/arm/math/packed_sgemm_c4.cc            +534    -1
lite/backends/arm/math/packed_sgemm_c4.h               +7    -0
lite/kernels/arm/conv_compute.cc                       +3   -13
lite/kernels/arm/conv_winograd.cc                     +68   -88
lite/kernels/arm/conv_winograd.h                       +3    -1
lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc (+817, -77)

(This diff is collapsed in the page view; its contents are not shown here.)
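The collapsed file holds the new F(2,3) kernels themselves. For orientation only, below is a minimal scalar sketch of 1-D Winograd F(2,3), the algorithm the commit title's "f23" refers to: two outputs of a 3-tap filter computed with 4 multiplies instead of 6. The transforms are the standard Lavin-Gray ones; the function and variable names are mine, not from the patch.

#include <cstdio>

// Minimal scalar sketch of 1-D Winograd F(2,3): two outputs of a 3-tap
// correlation from a 4-sample input tile, with 4 multiplies instead of 6.
void winograd_f23_1d(const float d[4], const float g[3], float y[2]) {
  // Weight transform U = G g (precomputable once per filter).
  float u0 = g[0];
  float u1 = 0.5f * (g[0] + g[1] + g[2]);
  float u2 = 0.5f * (g[0] - g[1] + g[2]);
  float u3 = g[2];
  // Input transform V = B^T d.
  float v0 = d[0] - d[2];
  float v1 = d[1] + d[2];
  float v2 = d[2] - d[1];
  float v3 = d[1] - d[3];
  // Elementwise product: the only multiplies.
  float m0 = u0 * v0, m1 = u1 * v1, m2 = u2 * v2, m3 = u3 * v3;
  // Output transform y = A^T m.
  y[0] = m0 + m1 + m2;
  y[1] = m1 - m2 - m3;
}

int main() {
  float d[4] = {1, 2, 3, 4}, g[3] = {0.5f, -1, 2}, y[2];
  winograd_f23_1d(d, g, y);
  // Direct correlation for comparison: y[i] = sum_k d[i+k] * g[k].
  printf("winograd: %g %g\n", y[0], y[1]);
  printf("direct:   %g %g\n",
         d[0] * g[0] + d[1] * g[1] + d[2] * g[2],
         d[1] * g[0] + d[2] * g[1] + d[3] * g[2]);
  return 0;
}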
lite/backends/arm/math/conv_impl.h (+29, -1)

@@ -316,7 +316,9 @@ void fill_bias_int8(int* tensor,
                     int channel_size);
 
 // new winograd
-void weight_trans_c4(
+void weight_trans_c4_8x8(
     float* dest, const float* src, int ic, int oc, void* workspace);
+void weight_trans_c4_4x4(
+    float* dest, const float* src, int ic, int oc, void* workspace);
 void conv_compute_6x6_3x3(const float* input,
                           float* output,
@@ -331,6 +333,32 @@ void conv_compute_6x6_3x3(const float* input,
                           const float* bias,
                           const operators::ConvParam& param,
                           ARMContext* ctx);
+void conv_compute_2x2_3x3(const float* input,
+                          float* output,
+                          int num,
+                          int chout,
+                          int hout,
+                          int wout,
+                          int chin,
+                          int hin,
+                          int win,
+                          const float* weight,
+                          const float* bias,
+                          const operators::ConvParam& param,
+                          ARMContext* ctx);
+void conv_compute_2x2_3x3_small(const float* input,
+                                float* output,
+                                int num,
+                                int chout,
+                                int hout,
+                                int wout,
+                                int chin,
+                                int hin,
+                                int win,
+                                const float* weight,
+                                const float* bias,
+                                const operators::ConvParam& param,
+                                ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
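A note on the naming in these declarations: conv_compute_6x6_3x3 is Winograd F(6x6, 3x3), working on 8x8 input tiles (matching the wino_iw = 8 default in conv_winograd.h below), while the new conv_compute_2x2_3x3 pair is F(2x2, 3x3) on 4x4 tiles (wino_iw = 4). The trade-off is standard Winograd arithmetic; the snippet below is illustrative and not from the patch.

#include <cstdio>

// Multiplies per output tile for 2-D Winograd F(m x m, 3x3): the
// transformed tile is (m + 2) x (m + 2), versus m*m*9 for direct conv.
void tile_cost(int m) {
  int t = m + 2;           // input tile edge (wino_iw)
  int wino = t * t;        // elementwise multiplies per tile
  int direct = m * m * 9;  // direct 3x3 multiplies per tile
  printf("F(%dx%d,3x3): tile %dx%d, %d vs %d muls (%.2fx)\n",
         m, m, t, t, wino, direct, (float)direct / wino);
}

int main() {
  tile_cost(2);  // the new small kernel: 16 vs 36  (2.25x)
  tile_cost(6);  // the existing kernel:  64 vs 324 (5.06x)
  return 0;
}

F(6x6) saves more multiplies per tile but needs larger transforms and more workspace, so for small outputs the 2x2 variant can win; that is what the dispatch in conv_winograd.cc below implements.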
lite/backends/arm/math/packed_sgemm_c4.cc (+534, -1)

@@ -695,7 +695,6 @@ void sgemm_prepack_c4_common(int M,
       }
     }
   }
-
 void sgemm_prepack_c4_small(int M,
                             int N,
                             int K,
@@ -1146,6 +1145,540 @@ void sgemm_prepack_c4_small(int M,
   }
 }
+void sgemm_prepack_c4_small(int M,
+                            int N,
+                            int K,
+                            const float* A_packed,
+                            const float* B,
+                            float* C,
+                            ARMContext* ctx) {
+  const int m_round = (M + 3) / 4 * 4;
+  const int k_round = (K + 3) / 4 * 4;
+  const int mloop = m_round >> 2;
+  const int lda = 4 * k_round;
+  const int ldb_byte = 4 * N * sizeof(float);
+  const int kcnt = k_round >> 2;
+#ifdef __aarch64__
+  float32x4_t vzero = vdupq_n_f32(0.f);
+#endif
+  for (int m = 0; m < mloop; ++m) {
+    const float* b = B;
+    int n = N;
+#ifdef __aarch64__
+    for (; n > 7; n -= 8) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      // clang-format off
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          /* load b0, b1 */
+          "ld1 {v0.4s, v1.4s}, [%[b]], #32\n"
+          /* load b2, b3 */
+          "ld1 {v2.4s, v3.4s}, [%[b]], #32\n"
+          /* load a2, a3 */
+          "fmul v8.4s, v16.4s, v0.s[0]\n"
+          "fmul v9.4s, v16.4s, v1.s[0]\n"
+          "fmul v10.4s, v16.4s, v2.s[0]\n"
+          "fmul v11.4s, v16.4s, v3.s[0]\n"
+          "ld1 {v18.4s, v19.4s}, [%[a]], #32\n"
+          "prfm pldl1keep, [%[b]]\n"
+          "fmla v8.4s, v17.4s, v0.s[1]\n"
+          "fmla v9.4s, v17.4s, v1.s[1]\n"
+          "fmla v10.4s, v17.4s, v2.s[1]\n"
+          "fmla v11.4s, v17.4s, v3.s[1]\n"
+          /* load b4, b5 */
+          "ld1 {v4.4s, v5.4s}, [%[b]], #32\n"
+          "fmla v8.4s, v18.4s, v0.s[2]\n"
+          "fmla v9.4s, v18.4s, v1.s[2]\n"
+          "fmla v10.4s, v18.4s, v2.s[2]\n"
+          "fmla v11.4s, v18.4s, v3.s[2]\n"
+          /* load b6, b7 */
+          "ld1 {v6.4s, v7.4s}, [%[b]], #32\n"
+          "fmla v8.4s, v19.4s, v0.s[3]\n"
+          "fmla v9.4s, v19.4s, v1.s[3]\n"
+          "fmla v10.4s, v19.4s, v2.s[3]\n"
+          "fmla v11.4s, v19.4s, v3.s[3]\n"
+          "sub %[b], %[b], #128\n"
+          "fmul v12.4s, v16.4s, v4.s[0]\n"
+          "fmul v13.4s, v16.4s, v5.s[0]\n"
+          "fmul v14.4s, v16.4s, v6.s[0]\n"
+          "fmul v15.4s, v16.4s, v7.s[0]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "fmla v12.4s, v17.4s, v4.s[1]\n"
+          "fmla v13.4s, v17.4s, v5.s[1]\n"
+          "fmla v14.4s, v17.4s, v6.s[1]\n"
+          "fmla v15.4s, v17.4s, v7.s[1]\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          "fmla v12.4s, v18.4s, v4.s[2]\n"
+          "fmla v13.4s, v18.4s, v5.s[2]\n"
+          "fmla v14.4s, v18.4s, v6.s[2]\n"
+          "fmla v15.4s, v18.4s, v7.s[2]\n"
+          /* load b0, b1 */
+          "ld1 {v0.4s, v1.4s}, [%[b]], #32\n"
+          "fmla v12.4s, v19.4s, v4.s[3]\n"
+          "fmla v13.4s, v19.4s, v5.s[3]\n"
+          "fmla v14.4s, v19.4s, v6.s[3]\n"
+          "fmla v15.4s, v19.4s, v7.s[3]\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b2, b3 */
+          "ld1 {v2.4s, v3.4s}, [%[b]], #32\n"
+          "fmla v8.4s, v16.4s, v0.s[0]\n"
+          "fmla v9.4s, v16.4s, v1.s[0]\n"
+          "fmla v10.4s, v16.4s, v2.s[0]\n"
+          "fmla v11.4s, v16.4s, v3.s[0]\n"
+          /* load a2, a3 */
+          "ld1 {v18.4s, v19.4s}, [%[a]], #32\n"
+          "prfm pldl1keep, [%[b]]\n"
+          "fmla v8.4s, v17.4s, v0.s[1]\n"
+          "fmla v9.4s, v17.4s, v1.s[1]\n"
+          "fmla v10.4s, v17.4s, v2.s[1]\n"
+          "fmla v11.4s, v17.4s, v3.s[1]\n"
+          /* load b4, b5 */
+          "ld1 {v4.4s, v5.4s}, [%[b]], #32\n"
+          "fmla v8.4s, v18.4s, v0.s[2]\n"
+          "fmla v9.4s, v18.4s, v1.s[2]\n"
+          "fmla v10.4s, v18.4s, v2.s[2]\n"
+          "fmla v11.4s, v18.4s, v3.s[2]\n"
+          /* load b6, b7 */
+          "ld1 {v6.4s, v7.4s}, [%[b]], #32\n"
+          "fmla v8.4s, v19.4s, v0.s[3]\n"
+          "fmla v9.4s, v19.4s, v1.s[3]\n"
+          "fmla v10.4s, v19.4s, v2.s[3]\n"
+          "fmla v11.4s, v19.4s, v3.s[3]\n"
+          "sub %[b], %[b], #128\n"
+          "fmla v12.4s, v16.4s, v4.s[0]\n"
+          "fmla v13.4s, v16.4s, v5.s[0]\n"
+          "fmla v14.4s, v16.4s, v6.s[0]\n"
+          "fmla v15.4s, v16.4s, v7.s[0]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "fmla v12.4s, v17.4s, v4.s[1]\n"
+          "fmla v13.4s, v17.4s, v5.s[1]\n"
+          "fmla v14.4s, v17.4s, v6.s[1]\n"
+          "fmla v15.4s, v17.4s, v7.s[1]\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          "fmla v12.4s, v18.4s, v4.s[2]\n"
+          "fmla v13.4s, v18.4s, v5.s[2]\n"
+          "fmla v14.4s, v18.4s, v6.s[2]\n"
+          "fmla v15.4s, v18.4s, v7.s[2]\n"
+          /* load b0, b1 */
+          "ld1 {v0.4s, v1.4s}, [%[b]], #32\n"
+          "fmla v12.4s, v19.4s, v4.s[3]\n"
+          "fmla v13.4s, v19.4s, v5.s[3]\n"
+          "fmla v14.4s, v19.4s, v6.s[3]\n"
+          "fmla v15.4s, v19.4s, v7.s[3]\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "bne 1b\n"
+          "2:\n"
+          "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64\n"
+          "st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [%[c]], #64\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte), [vzero] "w"(vzero)
+          : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
+            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
+            "v19", "cc", "memory");
+      b += 4 * 8;
+    }
+    for (; n > 3; n -= 4) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          /* load b0-b3 */
+          "ld1 {v0.4s, v1.4s}, [%[b]], #32\n"
+          "ld1 {v2.4s, v3.4s}, [%[b]], #32\n"
+          "fmul v8.4s, v16.4s, v0.s[0]\n"
+          "fmul v9.4s, v16.4s, v1.s[0]\n"
+          "fmul v10.4s, v16.4s, v2.s[0]\n"
+          "fmul v11.4s, v16.4s, v3.s[0]\n"
+          /* load a2, a3 */
+          "ld1 {v18.4s, v19.4s}, [%[a]], #32\n"
+          "sub %[b], %[b], #64\n"
+          "fmla v8.4s, v17.4s, v0.s[1]\n"
+          "fmla v9.4s, v17.4s, v1.s[1]\n"
+          "fmla v10.4s, v17.4s, v2.s[1]\n"
+          "fmla v11.4s, v17.4s, v3.s[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "fmla v8.4s, v18.4s, v0.s[2]\n"
+          "fmla v9.4s, v18.4s, v1.s[2]\n"
+          "fmla v10.4s, v18.4s, v2.s[2]\n"
+          "fmla v11.4s, v18.4s, v3.s[2]\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          "fmla v8.4s, v19.4s, v0.s[3]\n"
+          "fmla v9.4s, v19.4s, v1.s[3]\n"
+          "fmla v10.4s, v19.4s, v2.s[3]\n"
+          "fmla v11.4s, v19.4s, v3.s[3]\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b0-b3 */
+          "ld1 {v0.4s, v1.4s}, [%[b]], #32\n"
+          "ld1 {v2.4s, v3.4s}, [%[b]], #32\n"
+          "fmla v8.4s, v16.4s, v0.s[0]\n"
+          "fmla v9.4s, v16.4s, v1.s[0]\n"
+          "fmla v10.4s, v16.4s, v2.s[0]\n"
+          "fmla v11.4s, v16.4s, v3.s[0]\n"
+          /* load a2, a3 */
+          "ld1 {v18.4s, v19.4s}, [%[a]], #32\n"
+          "sub %[b], %[b], #64\n"
+          "fmla v8.4s, v17.4s, v0.s[1]\n"
+          "fmla v9.4s, v17.4s, v1.s[1]\n"
+          "fmla v10.4s, v17.4s, v2.s[1]\n"
+          "fmla v11.4s, v17.4s, v3.s[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "fmla v8.4s, v18.4s, v0.s[2]\n"
+          "fmla v9.4s, v18.4s, v1.s[2]\n"
+          "fmla v10.4s, v18.4s, v2.s[2]\n"
+          "fmla v11.4s, v18.4s, v3.s[2]\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          "fmla v8.4s, v19.4s, v0.s[3]\n"
+          "fmla v9.4s, v19.4s, v1.s[3]\n"
+          "fmla v10.4s, v19.4s, v2.s[3]\n"
+          "fmla v11.4s, v19.4s, v3.s[3]\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "bne 1b\n"
+          "2:\n"
+          "st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [%[c]], #64\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte), [vzero] "w"(vzero)
+          : "v0", "v1", "v2", "v3", "v8", "v9", "v10", "v11",
+            "v16", "v17", "v18", "v19", "cc", "memory");
+      b += 4 * 4;
+    }
+    for (; n > 0; n--) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          /* load b0 */
+          "ld1 {v0.4s}, [%[b]], #16\n"
+          "fmul v8.4s, v16.4s, v0.s[0]\n"
+          "fmul v9.4s, v17.4s, v0.s[1]\n"
+          /* load a2, a3 */
+          "ld1 {v18.4s, v19.4s}, [%[a]], #32\n"
+          "sub %[b], %[b], #16\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "add %[b], %[b], %[ldb]\n"
+          "fmla v8.4s, v18.4s, v0.s[2]\n"
+          "fmla v9.4s, v19.4s, v0.s[3]\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b0 */
+          "ld1 {v0.4s}, [%[b]], #16\n"
+          "fmla v8.4s, v16.4s, v0.s[0]\n"
+          "fmla v9.4s, v17.4s, v0.s[1]\n"
+          /* load a2, a3 */
+          "ld1 {v18.4s, v19.4s}, [%[a]], #32\n"
+          "sub %[b], %[b], #16\n"
+          "subs %w[cnt], %w[cnt], #1\n"
+          "add %[b], %[b], %[ldb]\n"
+          "fmla v8.4s, v18.4s, v0.s[2]\n"
+          "fmla v9.4s, v19.4s, v0.s[3]\n"
+          /* load a0, a1 */
+          "ld1 {v16.4s, v17.4s}, [%[a]], #32\n"
+          "bne 1b\n"
+          "fadd v8.4s, v8.4s, v9.4s\n"
+          "2:\n"
+          "st1 {v8.4s}, [%[c]], #16\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte), [vzero] "w"(vzero)
+          : "v0", "v8", "v9", "v16", "v17", "v18", "v19", "cc", "memory");
+      b += 4;
+    }
+#else
+    for (; n > 7; n -= 8) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      // clang-format off
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          /* load b2, b3 */
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmul.f32 q8, q4, d0[0]\n"
+          "vmul.f32 q9, q4, d2[0]\n"
+          "vmul.f32 q10, q4, d4[0]\n"
+          "vmul.f32 q11, q4, d6[0]\n"
+          /* load a2, a3 */
+          "vld1.32 {d12-d15}, [%[a]]!\n"
+          "pld [%[b]]\n"
+          "vmla.f32 q8, q5, d0[1]\n"
+          "vmla.f32 q9, q5, d2[1]\n"
+          "vmla.f32 q10, q5, d4[1]\n"
+          "vmla.f32 q11, q5, d6[1]\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "vmla.f32 q8, q6, d1[0]\n"
+          "vmla.f32 q9, q6, d3[0]\n"
+          "vmla.f32 q10, q6, d5[0]\n"
+          "vmla.f32 q11, q6, d7[0]\n"
+          "pld [%[b], #64]\n"
+          "vmla.f32 q8, q7, d1[1]\n"
+          "vmla.f32 q9, q7, d3[1]\n"
+          /* load b4, b5 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vmla.f32 q10, q7, d5[1]\n"
+          "vmla.f32 q11, q7, d7[1]\n"
+          /* load b6, b7 */
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmul.f32 q12, q4, d0[0]\n"
+          "vmul.f32 q13, q4, d2[0]\n"
+          "vmul.f32 q14, q4, d4[0]\n"
+          "vmul.f32 q15, q4, d6[0]\n"
+          "sub %[b], %[b], #128\n"
+          "vmla.f32 q12, q5, d0[1]\n"
+          "vmla.f32 q13, q5, d2[1]\n"
+          "vmla.f32 q14, q5, d4[1]\n"
+          "vmla.f32 q15, q5, d6[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q12, q6, d1[0]\n"
+          "vmla.f32 q13, q6, d3[0]\n"
+          "vmla.f32 q14, q6, d5[0]\n"
+          "vmla.f32 q15, q6, d7[0]\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vmla.f32 q12, q7, d1[1]\n"
+          "vmla.f32 q13, q7, d3[1]\n"
+          /* load b0, b1 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vmla.f32 q14, q7, d5[1]\n"
+          "vmla.f32 q15, q7, d7[1]\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b2, b3 */
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmla.f32 q8, q4, d0[0]\n"
+          "vmla.f32 q9, q4, d2[0]\n"
+          "vmla.f32 q10, q4, d4[0]\n"
+          "vmla.f32 q11, q4, d6[0]\n"
+          /* load a2, a3 */
+          "vld1.32 {d12-d15}, [%[a]]!\n"
+          "pld [%[b]]\n"
+          "vmla.f32 q8, q5, d0[1]\n"
+          "vmla.f32 q9, q5, d2[1]\n"
+          "vmla.f32 q10, q5, d4[1]\n"
+          "vmla.f32 q11, q5, d6[1]\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "vmla.f32 q8, q6, d1[0]\n"
+          "vmla.f32 q9, q6, d3[0]\n"
+          "vmla.f32 q10, q6, d5[0]\n"
+          "vmla.f32 q11, q6, d7[0]\n"
+          "pld [%[b], #64]\n"
+          "vmla.f32 q8, q7, d1[1]\n"
+          "vmla.f32 q9, q7, d3[1]\n"
+          /* load b4, b5 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vmla.f32 q10, q7, d5[1]\n"
+          "vmla.f32 q11, q7, d7[1]\n"
+          /* load b6, b7 */
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmla.f32 q12, q4, d0[0]\n"
+          "vmla.f32 q13, q4, d2[0]\n"
+          "vmla.f32 q14, q4, d4[0]\n"
+          "vmla.f32 q15, q4, d6[0]\n"
+          "sub %[b], %[b], #128\n"
+          "vmla.f32 q12, q5, d0[1]\n"
+          "vmla.f32 q13, q5, d2[1]\n"
+          "vmla.f32 q14, q5, d4[1]\n"
+          "vmla.f32 q15, q5, d6[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q12, q6, d1[0]\n"
+          "vmla.f32 q13, q6, d3[0]\n"
+          "vmla.f32 q14, q6, d5[0]\n"
+          "vmla.f32 q15, q6, d7[0]\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vmla.f32 q12, q7, d1[1]\n"
+          "vmla.f32 q13, q7, d3[1]\n"
+          /* load b0, b1 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vmla.f32 q14, q7, d5[1]\n"
+          "vmla.f32 q15, q7, d7[1]\n"
+          "bne 1b\n"
+          "2:\n"
+          "vst1.32 {d16-d19}, [%[c]]!\n"
+          "vst1.32 {d20-d23}, [%[c]]!\n"
+          "vst1.32 {d24-d27}, [%[c]]!\n"
+          "vst1.32 {d28-d31}, [%[c]]!\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte)
+          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+            "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
+      b += 4 * 8;
+    }
+    for (; n > 3; n -= 4) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          /* load b0-b3 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmul.f32 q8, q4, d0[0]\n"
+          "vmul.f32 q9, q4, d2[0]\n"
+          "vmul.f32 q10, q4, d4[0]\n"
+          "vmul.f32 q11, q4, d6[0]\n"
+          /* load a2, a3 */
+          "vld1.32 {d12-d15}, [%[a]]!\n"
+          "sub %[b], %[b], #64\n"
+          "vmla.f32 q8, q5, d0[1]\n"
+          "vmla.f32 q9, q5, d2[1]\n"
+          "vmla.f32 q10, q5, d4[1]\n"
+          "vmla.f32 q11, q5, d6[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q8, q6, d1[0]\n"
+          "vmla.f32 q9, q6, d3[0]\n"
+          "vmla.f32 q10, q6, d5[0]\n"
+          "vmla.f32 q11, q6, d7[0]\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vmla.f32 q8, q7, d1[1]\n"
+          "vmla.f32 q9, q7, d3[1]\n"
+          "vmla.f32 q10, q7, d5[1]\n"
+          "vmla.f32 q11, q7, d7[1]\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b0-b3 */
+          "vld1.32 {d0-d3}, [%[b]]!\n"
+          "vld1.32 {d4-d7}, [%[b]]!\n"
+          "vmla.f32 q8, q4, d0[0]\n"
+          "vmla.f32 q9, q4, d2[0]\n"
+          "vmla.f32 q10, q4, d4[0]\n"
+          "vmla.f32 q11, q4, d6[0]\n"
+          /* load a2, a3 */
+          "vld1.32 {d12-d15}, [%[a]]!\n"
+          "sub %[b], %[b], #64\n"
+          "vmla.f32 q8, q5, d0[1]\n"
+          "vmla.f32 q9, q5, d2[1]\n"
+          "vmla.f32 q10, q5, d4[1]\n"
+          "vmla.f32 q11, q5, d6[1]\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q8, q6, d1[0]\n"
+          "vmla.f32 q9, q6, d3[0]\n"
+          "vmla.f32 q10, q6, d5[0]\n"
+          "vmla.f32 q11, q6, d7[0]\n"
+          /* load a0, a1 */
+          "vld1.32 {d8-d11}, [%[a]]!\n"
+          "vmla.f32 q8, q7, d1[1]\n"
+          "vmla.f32 q9, q7, d3[1]\n"
+          "vmla.f32 q10, q7, d5[1]\n"
+          "vmla.f32 q11, q7, d7[1]\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "bne 1b\n"
+          "2:\n"
+          "vst1.32 {d16-d19}, [%[c]]!\n"
+          "vst1.32 {d20-d23}, [%[c]]!\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte)
+          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
+            "q10", "q11", "q12", "q13", "cc", "memory");
+      b += 4 * 4;
+    }
+    for (; n > 0; n--) {
+      int cnt = kcnt;
+      const float* a_ptr = A_packed;
+      const float* b_ptr = b;
+      asm volatile(
+          "0:\n"
+          /* load a0, a1 */
+          "vld1.32 {d2-d5}, [%[a]]!\n"
+          /* load b0 */
+          "vld1.32 {d0-d1}, [%[b]]!\n"
+          "vmul.f32 q5, q1, d0[0]\n"
+          "vmul.f32 q6, q2, d0[1]\n"
+          /* load a2, a3 */
+          "vld1.32 {d6-d9}, [%[a]]!\n"
+          "sub %[b], %[b], #16\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q5, q3, d1[0]\n"
+          "vmla.f32 q6, q4, d1[1]\n"
+          /* load a0, a1 */
+          "vld1.32 {d2-d5}, [%[a]]!\n"
+          "beq 2f\n"
+          "1:\n"
+          /* load b0 */
+          "vld1.32 {d0-d1}, [%[b]]!\n"
+          "vmla.f32 q5, q1, d0[0]\n"
+          "vmla.f32 q6, q2, d0[1]\n"
+          /* load a2, a3 */
+          "vld1.32 {d6-d9}, [%[a]]!\n"
+          "sub %[b], %[b], #16\n"
+          "subs %[cnt], %[cnt], #1\n"
+          "add %[b], %[b], %[ldb]\n"
+          "vmla.f32 q5, q3, d1[0]\n"
+          "vmla.f32 q6, q4, d1[1]\n"
+          /* load a0, a1 */
+          "vld1.32 {d2-d5}, [%[a]]!\n"
+          "bne 1b\n"
+          "vadd.f32 q5, q5, q6\n"
+          "2:\n"
+          "vst1.32 {d10-d11}, [%[c]]!\n"
+          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
+          : [ldb] "r"(ldb_byte)
+          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8",
+            "cc", "memory");
+      // clang-format on
+      b += 4;
+    }
+#endif
+    A_packed += lda;
+  }
+}
 void sgemm_prepack_c4(int M,
                       int N,
                       int K,
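The new sgemm_prepack_c4_small overload above is a bias-free, activation-free C4-block GEMM. As I read the pointer arithmetic in the assembly (this layout description is an inference, not documented in the patch), A_packed holds 4-row panels of k_round x 4 values, B holds N contiguous 4-float column groups per k-block (ldb_byte = 4 * N * sizeof(float) steps one k-block), and C is written as 4-float column groups per panel. A scalar reference under that assumption:

#include <cstdio>

// Scalar reference for the C4-blocked GEMM above (layout inferred from the
// assembly; treat as an assumption). M and K round up to multiples of 4.
void sgemm_c4_ref(int M, int N, int K,
                  const float* A_packed, const float* B, float* C) {
  int m_round = (M + 3) / 4 * 4;
  int k_round = (K + 3) / 4 * 4;
  for (int mb = 0; mb < m_round / 4; ++mb) {
    const float* panel = A_packed + mb * 4 * k_round;  // one 4-row panel of A
    for (int n = 0; n < N; ++n) {
      float acc[4] = {0.f, 0.f, 0.f, 0.f};
      for (int kb = 0; kb < k_round / 4; ++kb) {
        const float* a = panel + kb * 16;          // 4x4 block of A
        const float* bp = B + kb * 4 * N + n * 4;  // 4 k-values of column n
        for (int j = 0; j < 4; ++j) {
          for (int i = 0; i < 4; ++i) {
            acc[i] += a[j * 4 + i] * bp[j];
          }
        }
      }
      float* c = C + (mb * N + n) * 4;  // 4-float column group of C
      for (int i = 0; i < 4; ++i) c[i] = acc[i];
    }
  }
}

int main() {
  // Tiny smoke test: M = K = 4, N = 2, A = identity in c4 layout.
  float A[16] = {0};
  for (int i = 0; i < 4; ++i) A[i * 4 + i] = 1.f;
  float B[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float C[8];
  sgemm_c4_ref(4, 2, 4, A, B, C);
  for (int i = 0; i < 8; ++i) printf("%g ", C[i]);  // expect B echoed back
  printf("\n");
  return 0;
}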
lite/backends/arm/math/packed_sgemm_c4.h (+7, -0)

@@ -47,6 +47,13 @@ void sgemm_prepack_c4_small(int M,
                             bool has_bias,
                             bool has_relu,
                             ARMContext* ctx);
+void sgemm_prepack_c4_small(int M,
+                            int N,
+                            int K,
+                            const float* A_packed,
+                            const float* B,
+                            float* C,
+                            ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
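For callers of this new overload, the rounding inside sgemm_prepack_c4_small (see the .cc above) implies the minimum buffer sizes. A small sketch of that arithmetic, using my own helper and the layout reading given earlier (both assumptions, not part of the patch):

#include <cstdio>

// Buffer sizes implied by the m_round/k_round rounding in the .cc above.
void c4_gemm_sizes(int M, int N, int K) {
  int m_round = (M + 3) / 4 * 4;
  int k_round = (K + 3) / 4 * 4;
  printf("A_packed: %d floats (m_round/4 panels of 4*k_round)\n",
         m_round * k_round);
  printf("B: %d floats (k_round/4 blocks of 4*N)\n", k_round * N);
  printf("C: %d floats (m_round/4 panels of 4*N)\n", m_round * N);
}

int main() {
  c4_gemm_sizes(6, 10, 9);  // rounds to m_round = 8, k_round = 12
  return 0;
}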
lite/kernels/arm/conv_compute.cc (+3, -13)

@@ -68,19 +68,9 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     VLOG(3) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && kps_equal &&
              no_dilation) {
-    bool use_winograd =
-        (threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6 &&
-         pads_equal) ||
-        (oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16 && pads_equal);
-    if (use_winograd) {
-      /// winograd conv impl
-      impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
-      VLOG(3) << "invoking winograd conv";
-    } else {
-      /// direct conv impl
-      impl_ = new DirectConv<PRECISION(kFloat), PRECISION(kFloat)>;
-      VLOG(3) << "invoking direct conv";
-    }
+    /// winograd conv impl
+    impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
+    VLOG(3) << "invoking winograd conv";
   } else if (param.groups == 1 && kw == 3 && stride == 2 &&
              chin * chout < 4 * hin * win && kps_equal && no_dilation) {
     /// direct conv impl
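Net effect of this hunk: the size heuristic that previously picked between WinogradConv and DirectConv for group-1 3x3 stride-1 convolutions is gone, and such convolutions now always take the Winograd path; the F(6,3)-versus-F(2,3) choice moves into the kernel itself (conv_winograd.cc below). A toy model of the behavior change, paraphrasing the diff rather than quoting Paddle-Lite code (pads_equal omitted for brevity):

#include <cstdio>

// Toy model of the dispatch change. Old: a size heuristic gated Winograd.
// New: every group-1 3x3 stride-1 conv takes the Winograd path.
const char* old_dispatch(int threads, int ic, int oc, int hout, int wout) {
  bool use_winograd =
      (threads == 1 && oc >= 4 && ic >= 4 && hout >= 6 && wout >= 6) ||
      (oc >= 32 && ic >= 32 && hout >= 16 && wout >= 16);
  return use_winograd ? "winograd" : "direct";
}

int main() {
  // A small 8-channel 8x8 conv on 4 threads: previously direct, now winograd.
  printf("old: %s, new: winograd\n", old_dispatch(4, 8, 8, 8, 8));
  return 0;
}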
lite/kernels/arm/conv_winograd.cc (+68, -88)

@@ -43,79 +43,47 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::ReInitWhenNeeded() {
   int oh = o_dims[2];
   int ow = o_dims[3];
   int tile_block = 8;
 #ifdef __aarch64__
   tile_block = 16;
 #endif
   int parallel_threads =
       (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
-  if (threads <= 2 && parallel_threads >= threads) {
-    if (last_kernel_is_c4_ == 1) {
+  choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
+  if (choose_small_) {
+    wino_iw = 4;
+    if (last_function_ == 0) {
       return;
     }
-    last_kernel_is_c4_ = 1;
-    auto pad = *(param.paddings);
-    int pad_h = pad[0];
-    int pad_w = pad[2];
-    int oc_pad = (oc + 3) / 4 * 4;
-    int ic_pad = (ic + 3) / 4 * 4;
-    const int new_input_size =
-        (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
-    const int temp_size =
-        (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 256 + 512) * threads;
-    ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
-    weights_.Resize({1, 1, 1, 64 * oc_pad * ic_pad});
-    ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
-    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
-    auto weights_data_ = weights_.mutable_data<float>();
-    lite::arm::math::weight_trans_c4(
-        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
-    free(trans_tmp_ptr);
+    last_function_ = 0;
   } else {
-    if (last_kernel_is_c4_ == 0) {
+    wino_iw = 8;
+    if (last_function_ == 1) {
       return;
     }
-    last_kernel_is_c4_ = 0;
-    int tile_w = (ow + 5) / 6;
-    int tile_h = (oh + 5) / 6;
-    int size_tile = tile_h * tile_w;
-    int size_trans_channel = 8 * 8 * size_tile;
-    int max_ch = ic > oc ? ic : oc;
-    const int n_wino = size_tile;
-    ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
-                        sizeof(float));
-    const int m_wino = oc;
-    int hblock = lite::arm::math::get_hblock(&ctx);
-    int m_round = hblock * ((m_wino + hblock - 1) / hblock);
-    weights_.Resize({1, 1, 1, 8 * 8 * m_round * ic});
-    ctx.ExtendWorkspace((size_trans_channel * max_ch * 2 + n_wino) *
-                        sizeof(float));
-    auto weights_wino =
-        static_cast<float*>(malloc(sizeof(float) * 8 * 8 * oc * ic));
-    void* trans_tmp_ptr = malloc(sizeof(float) * 8 * 8 * oc * ic);
-    lite::arm::math::winograd_transform_weights(
-        weights_wino, param.filter->data<float>(), oc, ic, trans_tmp_ptr);
-    auto weights_trans = weights_.mutable_data<float>();
-    for (int i = 0; i < 64; ++i) {
-      float* packed_weights = weights_trans + i * m_round * ic;
-      const float* weights_wino_ptr = weights_wino + i * oc * ic;
-      lite::arm::math::prepackA(packed_weights,
-                                weights_wino_ptr,
-                                1.f,
-                                ic,
-                                0,
-                                m_wino,
-                                0,
-                                ic,
-                                false,
-                                &ctx);
-    }
-    free(trans_tmp_ptr);
-    free(weights_wino);
+    last_function_ = 1;
   }
+  auto pad = *(param.paddings);
+  int pad_h = pad[0];
+  int pad_w = pad[2];
+  int oc_pad = (oc + 3) / 4 * 4;
+  int ic_pad = (ic + 3) / 4 * 4;
+  const int new_input_size =
+      (ic + 3) / 4 * 4 * (ih + pad_h * 2) * (iw + pad_w * 2);
+  const int temp_size =
+      (tile_block * ((ic + 3) / 4 + (oc + 3) / 4) * 4 * wino_iw * wino_iw +
+       8 * wino_iw * wino_iw) *
+      threads;
+  ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
+  weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
+  ctx.ExtendWorkspace((temp_size + new_input_size) * sizeof(float));
+  void* trans_tmp_ptr = malloc(sizeof(float) * wino_iw * wino_iw * oc * ic);
+  auto weights_data_ = weights_.mutable_data<float>();
+  if (!choose_small_) {
+    lite::arm::math::weight_trans_c4_8x8(
+        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
+  } else {
+    lite::arm::math::weight_trans_c4_4x4(
+        weights_data_, param.filter->data<float>(), ic, oc, trans_tmp_ptr);
+  }
+  free(trans_tmp_ptr);
   last_shape_ = x_dims;
 }
@@ -145,14 +113,7 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
   int ow = o_dims[3];
   int oc = o_dims[1];
-  int tile_block = 8;
-#ifdef __aarch64__
-  tile_block = 16;
-#endif
-  int threads = ctx.threads();
-  int parallel_threads =
-      (((ow + 5) / 6) * ((oh + 5) / 6) + tile_block - 1) / tile_block;
-  if (threads <= 2 && parallel_threads >= threads) {
+  if (!choose_small_) {
     lite::arm::math::conv_compute_6x6_3x3(i_data,
                                           o_data,
                                           bs,
@@ -167,19 +128,38 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
                                           param,
                                           &ctx);
   } else {
-    lite::arm::math::conv_winograd3x3(i_data,
-                                      o_data,
-                                      bs,
-                                      oc,
-                                      oh,
-                                      ow,
-                                      ic,
-                                      ih,
-                                      iw,
-                                      w_data,
-                                      b_data,
-                                      param,
-                                      &ctx);
+    int tile_block = 8;
+    int block_count =
+        (((ow + 1) / 2) * ((oh + 1) / 2) + tile_block - 1) / tile_block;
+    if (block_count != 1) {
+      lite::arm::math::conv_compute_2x2_3x3(i_data,
+                                            o_data,
+                                            bs,
+                                            oc,
+                                            oh,
+                                            ow,
+                                            ic,
+                                            ih,
+                                            iw,
+                                            w_data,
+                                            b_data,
+                                            param,
+                                            &ctx);
+    } else {
+      lite::arm::math::conv_compute_2x2_3x3_small(i_data,
+                                                  o_data,
+                                                  bs,
+                                                  oc,
+                                                  oh,
+                                                  ow,
+                                                  ic,
+                                                  ih,
+                                                  iw,
+                                                  w_data,
+                                                  b_data,
+                                                  param,
+                                                  &ctx);
+    }
   }
 }
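Pulling the selection logic out of the two hunks above: ReInitWhenNeeded() now chooses the small F(2,3) kernel when the output has few pixels per (tile_block * threads), and Run() further splits F(2,3) into the blocked and single-block ("small") implementations. A standalone restatement of those two predicates (same arithmetic as the diff; the wrapper functions are mine):

#include <cstdio>

// Restatement of the two selection predicates in the diff above.
bool choose_small(int ow, int oh, int tile_block, int threads) {
  // From ReInitWhenNeeded(): few output pixels per tile block and thread
  // => use F(2,3) with 4x4 tiles (wino_iw = 4) instead of F(6,3).
  return ow * oh / (tile_block * threads) < 36;
}

bool use_single_block_f23(int ow, int oh) {
  // From Run(): with 2x2 output tiles and tile_block = 8, fall back to
  // conv_compute_2x2_3x3_small when everything fits in one block.
  int tile_block = 8;
  int block_count =
      (((ow + 1) / 2) * ((oh + 1) / 2) + tile_block - 1) / tile_block;
  return block_count == 1;
}

int main() {
  // e.g. a 16x16 output on 1 thread with tile_block = 16 (aarch64):
  printf("choose_small(16,16,16,1) = %d\n", choose_small(16, 16, 16, 1));
  // 16x16 => 64 2x2 tiles => 8 blocks => the blocked F(2,3) kernel runs.
  printf("single block: %d\n", use_single_block_f23(16, 16));
  return 0;
}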
lite/kernels/arm/conv_winograd.h (+3, -1)

@@ -40,7 +40,9 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
   Tensor weights_;
   DDim last_shape_;
   int workspace_size_{0};
-  int last_kernel_is_c4_{-1};
+  int last_function_{-1};
+  bool choose_small_{false};
+  int wino_iw{8};
 };
 }  // namespace arm