PaddlePaddle / Paddle-Lite — commit ce0f2bfd
Authored Feb 27, 2019 by hjchen2
Optimize gemm data packing; this brings a 22% speedup for the ocr detection model.
Parent: c070770c

Showing 2 changed files with 448 additions and 449 deletions:

src/operators/math/gemm.cpp (+441, -430)
src/operators/math/gemm.h (+7, -19)
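For context on what this commit optimizes: "packing" in a blocked GEMM copies tiles of A and B into contiguous scratch buffers laid out exactly as the micro-kernel will read them. Below is a minimal scalar model of the 6-row A-panel layout these routines produce; it is an illustrative sketch (names like `pack_a_6r` are hypothetical, not the project's code), assuming a row-major A with leading dimension `lda`.

```cpp
#include <vector>

// Pack a (rows x k) row-major block of A into 6-row panels: the packed
// buffer stores column j of rows i..i+5 contiguously, then column j+1, ...
// so the micro-kernel can stream one 6-float column of the block per step.
void pack_a_6r(int rows, int k, const float *A, int lda, float *buffer) {
  for (int i = 0; i < rows; i += 6) {
    for (int j = 0; j < k; ++j) {
      for (int r = 0; r < 6; ++r) {
        // Rows beyond the block edge are padded with zeros.
        *buffer++ = (i + r < rows) ? A[(i + r) * lda + j] : 0.0f;
      }
    }
  }
}

int main() {
  const int rows = 7, k = 4;
  std::vector<float> A(rows * k), packed(((rows + 5) / 6) * 6 * k);
  for (int i = 0; i < rows * k; ++i) A[i] = static_cast<float>(i);
  pack_a_6r(rows, k, A.data(), k, packed.data());
  return 0;
}
```

The commit's point is that this copy is pure memory movement, so vectorizing it (and reusing one implementation for serial and OpenMP callers) directly cuts per-GEMM overhead.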
src/operators/math/gemm.cpp
@@ -27,390 +27,418 @@ namespace paddle_mobile {

namespace operators {
namespace math {

This hunk replaces the old scalar packing routines, `PackMatrixA_4r`, `PackMatrixA_6r`, `PackMatrixB_8c`, and their `PackMatrixA_omp_6r` / `PackMatrixA_omp_8r` / `PackMatrixB_omp_8c` duplicates (which differed only by an unconditional `#pragma omp parallel for`), with single NEON-optimized versions that take a trailing `const bool parallel` flag. `PackMatrixA_8r` gains the same parameter. A helper is added for masked loads:

```cpp
inline float32x4_t vandq_f32(float32x4_t x, uint32x4_t mask) {
  return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(x), mask));
}
```

The new `PackMatrixA_6r` packs six rows at a time and four columns per step, transposing each 6x4 tile with `vtrnq_f32`/`vcombine_f32` on arm64 and with `vtrn.32`/`vswp.32` inline assembly on armv7. The row tail (`m % 6`) and the column tail (`k & 0x3`) are zero-padded through `vcltq_u32` lane masks instead of the old switch-on-`m_tail` scalar copies:

```cpp
// Copy blocks of matrix A into contiguous memory (RowMajor).
void Gemm::PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
                          float *buffer, const bool parallel) {
  uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 4, 5};
  int remain_k = k & 0x3;
  uint32x4_t vzero = vdupq_n_u32(0);
  uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k));

#pragma omp parallel for if (parallel)
  for (int i = 0; i < m - 5; i += 6) {
    const float *a0 = A + i * lda;
    const float *a1 = A + (i + 1) * lda;
    const float *a2 = A + (i + 2) * lda;
    const float *a3 = A + (i + 3) * lda;
    const float *a4 = A + (i + 4) * lda;
    const float *a5 = A + (i + 5) * lda;
    float *out_ptr = buffer + i * k;

    int loops = k >> 2;
    if (loops > 0) {
#if __aarch64__
      for (int l = 0; l < loops; ++l) {
        float32x4_t _d0 = vld1q_f32(a0);
        float32x4_t _d1 = vld1q_f32(a1);
        float32x4_t _d2 = vld1q_f32(a2);
        float32x4_t _d3 = vld1q_f32(a3);
        float32x4_t _d4 = vld1q_f32(a4);
        float32x4_t _d5 = vld1q_f32(a5);

        float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
        float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
        float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
        _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0]));
        _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
        _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));
        _d3 = vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1]));

        vst1q_f32(out_ptr, _d0);
        vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0]));
        vst1q_f32(out_ptr + 6, _d1);
        vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1]));
        vst1q_f32(out_ptr + 12, _d2);
        vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0]));
        vst1q_f32(out_ptr + 18, _d3);
        vst1_f32(out_ptr + 22, vget_high_f32(_q3.val[1]));

        a0 += 4;
        a1 += 4;
        a2 += 4;
        a3 += 4;
        a4 += 4;
        a5 += 4;
        out_ptr += 24;
      }
#else
      asm volatile(
          "loop_4k_%=:                  \n"
          "vld1.32 {d0-d1}, [%[a0]]!    \n"
          "vld1.32 {d2-d3}, [%[a1]]!    \n"
          "vld1.32 {d4-d5}, [%[a2]]!    \n"
          "vld1.32 {d6-d7}, [%[a3]]!    \n"
          "vld1.32 {d8-d9}, [%[a4]]!    \n"
          "vld1.32 {d10-d11}, [%[a5]]!  \n"
          "vtrn.32 q0, q1               \n"
          "vtrn.32 q2, q3               \n"
          "vtrn.32 q4, q5               \n"
          "vswp.32 d1, d4               \n"
          "vswp.32 d3, d6               \n"
          "vst1.32 {q0}, [%[out]]!      \n"
          "vst1.32 {d8}, [%[out]]!      \n"
          "vst1.32 {q1}, [%[out]]!      \n"
          "vst1.32 {d10}, [%[out]]!     \n"
          "vst1.32 {q2}, [%[out]]!      \n"
          "vst1.32 {d9}, [%[out]]!      \n"
          "vst1.32 {q3}, [%[out]]!      \n"
          "vst1.32 {d11}, [%[out]]!     \n"
          "subs %[loops], #1            \n"
          "bne loop_4k_%=               \n"
          : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2),
            [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops)
          :
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5");
#endif
    }
    if (remain_k > 0) {
      float32x4_t _d0 = vld1q_f32(a0);
      float32x4_t _d1 = vld1q_f32(a1);
      float32x4_t _d2 = vld1q_f32(a2);
      float32x4_t _d3 = vld1q_f32(a3);
      float32x4_t _d4 = vld1q_f32(a4);
      float32x4_t _d5 = vld1q_f32(a5);

      _d0 = vandq_f32(_d0, vmask1);
      _d1 = vandq_f32(_d1, vmask1);
      _d2 = vandq_f32(_d2, vmask1);
      _d3 = vandq_f32(_d3, vmask1);
      _d4 = vandq_f32(_d4, vmask1);
      _d5 = vandq_f32(_d5, vmask1);

      float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
      float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
      float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
      _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0]));
      _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
      _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));

      switch (remain_k) {
        case 3:
          vst1q_f32(out_ptr + 12, _d2);
          vst1_f32(out_ptr + 16, vget_high_f32(_q3.val[0]));
        case 2:
          vst1q_f32(out_ptr + 6, _d1);
          vst1_f32(out_ptr + 10, vget_low_f32(_q3.val[1]));
        case 1:
          vst1q_f32(out_ptr, _d0);
          vst1_f32(out_ptr + 4, vget_low_f32(_q3.val[0]));
        default:
          break;
      }
    }
  }

  int remain_m = m % 6;
  if (remain_m) {
    int remain_m_start = m - remain_m;
    const float *a0 = A + remain_m_start * lda;
    const float *a1 = a0 + lda;
    const float *a2 = a0 + 2 * lda;
    const float *a3 = a0 + 3 * lda;
    const float *a4 = a0 + 4 * lda;
    const float *a5 = a0 + 5 * lda;
    float *out_ptr = buffer + remain_m_start * k;

    uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_m));
    uint32x4_t vmask3 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(remain_m));

    int loops = k >> 2;
    if (loops > 0) {
#if __aarch64__
      for (int l = 0; l < loops; ++l) {
        float32x4_t _d0 = vld1q_f32(a0);
        float32x4_t _d1 = vld1q_f32(a1);
        float32x4_t _d2 = vld1q_f32(a2);
        float32x4_t _d3 = vld1q_f32(a3);
        float32x4_t _d4 = vld1q_f32(a4);
        float32x4_t _d5 = vld1q_f32(a5);

        float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
        float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
        float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
        _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0]));
        _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
        _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));
        _d3 = vcombine_f32(vget_high_f32(_q0.val[1]), vget_high_f32(_q1.val[1]));

        _d0 = vandq_f32(_d0, vmask2);
        _d1 = vandq_f32(_d1, vmask2);
        _d2 = vandq_f32(_d2, vmask2);
        _d3 = vandq_f32(_d3, vmask2);
        _d4 = vandq_f32(_q3.val[0], vmask3);
        _d5 = vandq_f32(_q3.val[1], vmask3);

        vst1q_f32(out_ptr, _d0);
        vst1_f32(out_ptr + 4, vget_low_f32(_d4));
        vst1q_f32(out_ptr + 6, _d1);
        vst1_f32(out_ptr + 10, vget_low_f32(_d5));
        vst1q_f32(out_ptr + 12, _d2);
        vst1_f32(out_ptr + 16, vget_high_f32(_d4));
        vst1q_f32(out_ptr + 18, _d3);
        vst1_f32(out_ptr + 22, vget_high_f32(_d5));

        a0 += 4;
        a1 += 4;
        a2 += 4;
        a3 += 4;
        a4 += 4;
        a5 += 4;
        out_ptr += 24;
      }
#else
      asm volatile(
          "loop_4k_%=:                     \n"
          "vld1.32 {d0-d1}, [%[a0]]!       \n"
          "vld1.32 {d2-d3}, [%[a1]]!       \n"
          "vld1.32 {d4-d5}, [%[a2]]!       \n"
          "vld1.32 {d6-d7}, [%[a3]]!       \n"
          "vld1.32 {d8-d9}, [%[a4]]!       \n"
          "vld1.32 {d10-d11}, [%[a5]]!     \n"
          "vtrn.32 q0, q1                  \n"
          "vtrn.32 q2, q3                  \n"
          "vtrn.32 q4, q5                  \n"
          "vswp.32 d1, d4                  \n"
          "vswp.32 d3, d6                  \n"
          "vbif q0, %q[vzero], %q[vmask2]  \n"
          "vbif q1, %q[vzero], %q[vmask2]  \n"
          "vbif q2, %q[vzero], %q[vmask2]  \n"
          "vbif q3, %q[vzero], %q[vmask2]  \n"
          "vbif q4, %q[vzero], %q[vmask3]  \n"
          "vbif q5, %q[vzero], %q[vmask3]  \n"
          "vst1.32 {q0}, [%[out]]!         \n"
          "vst1.32 {d8}, [%[out]]!         \n"
          "vst1.32 {q1}, [%[out]]!         \n"
          "vst1.32 {d10}, [%[out]]!        \n"
          "vst1.32 {q2}, [%[out]]!         \n"
          "vst1.32 {d9}, [%[out]]!         \n"
          "vst1.32 {q3}, [%[out]]!         \n"
          "vst1.32 {d11}, [%[out]]!        \n"
          "subs %[loops], #1               \n"
          "bne loop_4k_%=                  \n"
          : [out] "+r"(out_ptr), [a0] "+r"(a0), [a1] "+r"(a1), [a2] "+r"(a2),
            [a3] "+r"(a3), [a4] "+r"(a4), [a5] "+r"(a5), [loops] "+r"(loops)
          : [vmask2] "w"(vmask2), [vmask3] "w"(vmask3), [vzero] "w"(vzero)
          : "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5");
#endif
    }
    if (remain_k > 0) {
      float32x4_t _d0 = vld1q_f32(a0);
      float32x4_t _d1 = vld1q_f32(a1);
      float32x4_t _d2 = vld1q_f32(a2);
      float32x4_t _d3 = vld1q_f32(a3);
      float32x4_t _d4 = vld1q_f32(a4);
      float32x4_t _d5 = vld1q_f32(a5);

      _d0 = vandq_f32(_d0, vmask1);
      _d1 = vandq_f32(_d1, vmask1);
      _d2 = vandq_f32(_d2, vmask1);
      _d3 = vandq_f32(_d3, vmask1);
      _d4 = vandq_f32(_d4, vmask1);
      _d5 = vandq_f32(_d5, vmask1);

      float32x4x2_t _q0 = vtrnq_f32(_d0, _d1);
      float32x4x2_t _q1 = vtrnq_f32(_d2, _d3);
      float32x4x2_t _q3 = vtrnq_f32(_d4, _d5);
      _d0 = vcombine_f32(vget_low_f32(_q0.val[0]), vget_low_f32(_q1.val[0]));
      _d1 = vcombine_f32(vget_low_f32(_q0.val[1]), vget_low_f32(_q1.val[1]));
      _d2 = vcombine_f32(vget_high_f32(_q0.val[0]), vget_high_f32(_q1.val[0]));
      // _d3 = vcombine_f32(vget_high_f32(_q0.val[1]),
      //                    vget_high_f32(_q1.val[1]));

      _d0 = vandq_f32(_d0, vmask2);
      _d1 = vandq_f32(_d1, vmask2);
      _d2 = vandq_f32(_d2, vmask2);
      // _d3 = vandq_f32(_d3, vmask2);
      _d4 = vandq_f32(_q3.val[0], vmask3);
      _d5 = vandq_f32(_q3.val[1], vmask3);

      switch (remain_k) {
        case 3:
          vst1q_f32(out_ptr + 12, _d2);
          vst1_f32(out_ptr + 16, vget_high_f32(_d4));
        case 2:
          vst1q_f32(out_ptr + 6, _d1);
          vst1_f32(out_ptr + 10, vget_low_f32(_d5));
        case 1:
          vst1q_f32(out_ptr, _d0);
          vst1_f32(out_ptr + 4, vget_low_f32(_d4));
        default:
          break;
      }
    }
  }
}
```

The new `PackMatrixB_8c` is restructured the same way: the row loop over `k` becomes the (optionally parallel) outer loop, the column loop is unrolled 32, 16 and 8 columns at a time with straight `ld1`/`st1` (arm64) or `vld1`/`vst1` (armv7) copies, and the ragged last panel is zero-padded with `vbif`/`BIF` lane masks instead of a scalar fill:

```cpp
// Copy blocks of matrix B into contiguous memory (RowMajor).
void Gemm::PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
                          float *buffer, const bool parallel) {
  const int j_length = n - n_tail;

#pragma omp parallel for if (parallel)
  for (int i = 0; i < k; ++i) {
    int j = 0;
    for (; j < j_length - 31; j += 32) {
      float *local_buffer0 = buffer + j * k + i * NR;
      float *local_buffer1 = buffer + (j + 8) * k + i * NR;
      float *local_buffer2 = buffer + (j + 16) * k + i * NR;
      float *local_buffer3 = buffer + (j + 24) * k + i * NR;
      const float *b0 = B + i * ldb + j;
#if __aarch64__
      asm volatile(
          "prfm pldl1keep, [%[b0]]                      \n"
          "ld1 {v0.4s, v1.4s}, [%[b0]], #32             \n"
          "ld1 {v2.4s, v3.4s}, [%[b0]], #32             \n"
          "ld1 {v4.4s, v5.4s}, [%[b0]], #32             \n"
          "ld1 {v6.4s, v7.4s}, [%[b0]]                  \n"
          "st1 {v0.4s, v1.4s}, [%[local_buffer0]], #32  \n"
          "st1 {v2.4s, v3.4s}, [%[local_buffer1]], #32  \n"
          "st1 {v4.4s, v5.4s}, [%[local_buffer2]], #32  \n"
          "st1 {v6.4s, v7.4s}, [%[local_buffer3]], #32  \n"
          : [local_buffer0] "+r"(local_buffer0),
            [local_buffer1] "+r"(local_buffer1),
            [local_buffer2] "+r"(local_buffer2),
            [local_buffer3] "+r"(local_buffer3), [b0] "+r"(b0)
          :
          : "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7");
#else
      asm volatile(
          // "pld [%[b]]                          \n"
          "vld1.32 {q0, q1}, [%[b0]]!             \n"
          "vld1.32 {q2, q3}, [%[b0]]!             \n"
          "vld1.32 {q4, q5}, [%[b0]]!             \n"
          "vld1.32 {q6, q7}, [%[b0]]!             \n"
          "vst1.32 {q0, q1}, [%[local_buffer0]]!  \n"
          "vst1.32 {q2, q3}, [%[local_buffer1]]!  \n"
          "vst1.32 {q4, q5}, [%[local_buffer2]]!  \n"
          "vst1.32 {q6, q7}, [%[local_buffer3]]!  \n"
          : [local_buffer0] "+r"(local_buffer0),
            [local_buffer1] "+r"(local_buffer1),
            [local_buffer2] "+r"(local_buffer2),
            [local_buffer3] "+r"(local_buffer3), [b0] "+r"(b0)
          :
          : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7");
#endif  // __aarch64__
    }
    for (; j < j_length - 15; j += 16) {
      float *local_buffer0 = buffer + j * k + i * NR;
      float *local_buffer1 = buffer + (j + 8) * k + i * NR;
      const float *b0 = &B(i, j);
#if __aarch64__
      asm volatile(
          "prfm pldl1keep, [%[b0]]                      \n"
          "ld1 {v0.4s, v1.4s}, [%[b0]], #32             \n"
          "ld1 {v2.4s, v3.4s}, [%[b0]]                  \n"
          "st1 {v0.4s, v1.4s}, [%[local_buffer0]], #32  \n"
          "st1 {v2.4s, v3.4s}, [%[local_buffer1]], #32  \n"
          : [local_buffer0] "+r"(local_buffer0),
            [local_buffer1] "+r"(local_buffer1), [b0] "+r"(b0)
          :
          : "memory", "v0", "v1", "v2", "v3");
#else
      asm volatile(
          // "pld [%[b0]]                         \n"
          "vld1.32 {q0, q1}, [%[b0]]!             \n"
          "vld1.32 {q2, q3}, [%[b0]]              \n"
          "vst1.32 {q0, q1}, [%[local_buffer0]]!  \n"
          "vst1.32 {q2, q3}, [%[local_buffer1]]!  \n"
          : [local_buffer0] "+r"(local_buffer0),
            [local_buffer1] "+r"(local_buffer1), [b0] "+r"(b0)
          :
          : "memory", "q0", "q1", "q2", "q3");
#endif  // __aarch64__
    }
    for (; j < j_length; j += NR) {
      float *local_buffer = buffer + j * k + i * NR;
      const float *b0 = &B(i, j);
#if __ARM_NEON
#if __aarch64__
      asm volatile(
          "prfm pldl1keep, [%[b0]]                     \n"
          "ld1 {v0.4s, v1.4s}, [%[b0]]                 \n"
          "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32  \n"
          : [local_buffer] "+r"(local_buffer)
          : [b0] "r"(b0)
          : "memory", "v0", "v1");
#else
      asm volatile(
          // "pld [%[b]]                        \n"
          "vld1.32 {q0, q1}, [%[b0]]            \n"
          "vst1.32 {q0, q1}, [%[local_buffer]]! \n"
          : [local_buffer] "+r"(local_buffer)
          : [b0] "r"(b0)
          : "memory", "q0", "q1");
#endif  // __aarch64__
#else
      *local_buffer++ = *b0++;
      *local_buffer++ = *b0++;
      *local_buffer++ = *b0++;
      *local_buffer++ = *b0++;
      *local_buffer++ = *b0++;
      *local_buffer++ = *b0++;
      *local_buffer++ = *b0++;
      *local_buffer++ = *b0++;
#endif  // __ARM_NEON
    }
  }
  if (n_tail != 0) {
    uint32_t mask[8] = {0, 1, 2, 3, 4, 5, 6, 7};
    uint32x4_t vzero = vdupq_n_u32(0);
    uint32x4_t vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(n_tail));
    uint32x4_t vmask2 = vcltq_u32(vld1q_u32(mask + 4), vdupq_n_u32(n_tail));
    float *local_buffer = buffer + j_length * k;
    for (int i = 0; i < k; ++i) {
      const float *b0 = &B(i, j_length);
#if __aarch64__
      asm volatile(
          "prfm pldl1keep, [%[b0]]                     \n"
          "ld1 {v0.4s, v1.4s}, [%[b0]]                 \n"
          "BIF v0.8b, %[vzero].8b, %[vmask1].8b        \n"
          "BIF v1.8b, %[vzero].8b, %[vmask2].8b        \n"
          "st1 {v0.4s, v1.4s}, [%[local_buffer]], #32  \n"
          : [local_buffer] "+r"(local_buffer)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero),
            [b0] "r"(b0)
          : "memory", "v0", "v1");
#else
      asm volatile(
          "vld1.32 {q0, q1}, [%[b0]]             \n"
          "vbif q0, %q[vzero], %q[vmask1]        \n"
          "vbif q1, %q[vzero], %q[vmask2]        \n"
          "vst1.32 {q0, q1}, [%[local_buffer]]!  \n"
          : [local_buffer] "+r"(local_buffer)
          : [vmask1] "w"(vmask1), [vmask2] "w"(vmask2), [vzero] "w"(vzero),
            [b0] "r"(b0)
          : "memory", "q0", "q1");
#endif
    }
  }
}
```
...
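The `vcltq_u32` masks above replace per-lane branching in the tails. A portable scalar model of that trick (an assumed reading of the semantics, not code from the repo): build a lane mask "index < remain" once, then bitwise-AND every loaded vector with it, which leaves valid lanes intact and turns out-of-range lanes into +0.0f.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  const uint32_t lane_index[4] = {0, 1, 2, 3};
  const int remain_k = 3;  // only 3 of the 4 loaded columns are valid

  // vmask1 = vcltq_u32(vld1q_u32(mask), vdupq_n_u32(remain_k))
  uint32_t vmask1[4];
  for (int l = 0; l < 4; ++l)
    vmask1[l] = lane_index[l] < static_cast<uint32_t>(remain_k) ? ~0u : 0u;

  // _d0 = vandq_f32(_d0, vmask1): AND the raw float bits with the mask.
  float d0[4] = {1.5f, 2.5f, 3.5f, 4.5f};
  for (int l = 0; l < 4; ++l) {
    uint32_t bits;
    std::memcpy(&bits, &d0[l], sizeof(bits));
    bits &= vmask1[l];
    std::memcpy(&d0[l], &bits, sizeof(bits));
  }
  assert(d0[2] == 3.5f && d0[3] == 0.0f);
  return 0;
}
```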
@@ -418,39 +446,10 @@ void Gemm::PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,

#if __ARM_NEON
#if __aarch64__

`PackMatrixB_12c` gets the same `const bool parallel` flag with an `if (parallel)` OpenMP clause, and `PackMatrixB_omp_12c` (whose body was identical except for an unconditional `#pragma omp parallel for`) is deleted:

```cpp
void Gemm::PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
                           float *buffer, const bool parallel) {
  const int j_length = n - n_tail;
#pragma omp parallel for if (parallel)
  for (int j = 0; j < j_length; j += NR) {
    float *local_buffer = buffer + j * k;
    for (int i = 0; i < k; ++i) {
      const float *b0 = &B(i, j);
      asm volatile(
          "prfm pldl2keep, [%[b0], #64]                       \n\t"
          "ld1 {v0.4s, v1.4s, v2.4s}, [%[b0]]                 \n\t"
          "st1 {v0.4s, v1.4s, v2.4s}, [%[local_buffer]], #48  \n\t"
          : [local_buffer] "+r"(local_buffer)
          : [b0] "r"(b0)
          : "memory", "v0", "v1", "v2");
    }
  }
  if (n_tail != 0) {
    float *local_buffer = buffer + j_length * k;
    for (int i = 0; i < k; ++i) {
      const float *b0 = &B(i, j_length);
      for (int j = j_length; j < n; ++j) {
        *local_buffer++ = *b0++;
      }
      for (int j = n; j < j_length + NR; ++j) {
        *local_buffer++ = 0;
      }
    }
  }
}
```
...
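The B-side layout is column panels of width `NR`, stored row after row, with the ragged last panel zero-padded. A minimal scalar sketch of that layout (illustrative; `pack_b_nr` is a hypothetical name, and the real routine uses the `ld1`/`st1` copies above):

```cpp
#include <vector>

constexpr int NR = 12;  // panel width of the 12c variant

void pack_b_nr(int k, int n, const float *B, int ldb, float *buffer) {
  const int n_tail = n % NR;
  const int j_length = n - n_tail;
  for (int j = 0; j < j_length; j += NR) {
    for (int i = 0; i < k; ++i)
      for (int c = 0; c < NR; ++c) *buffer++ = B[i * ldb + j + c];
  }
  if (n_tail != 0) {
    for (int i = 0; i < k; ++i) {
      for (int j = j_length; j < n; ++j) *buffer++ = B[i * ldb + j];
      for (int j = n; j < j_length + NR; ++j) *buffer++ = 0.0f;  // pad
    }
  }
}

int main() {
  const int k = 2, n = 14;
  std::vector<float> B(k * n, 1.0f), packed(((n + NR - 1) / NR) * NR * k);
  pack_b_nr(k, n, B.data(), n, packed.data());
  return 0;
}
```

Padding to a full panel lets the micro-kernel always consume `NR` columns without edge checks.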
@@ -479,39 +478,10 @@ void Gemm::PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B,

Likewise, `PackMatrixB_16c` gains the `parallel` flag and `PackMatrixB_omp_16c` is deleted:

```cpp
void Gemm::PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
                           float *buffer, const bool parallel) {
  const int j_length = n - n_tail;
#pragma omp parallel for if (parallel)
  for (int j = 0; j < n - n_tail; j += NR) {
    float *local_buffer = buffer + j * k;
    for (int i = 0; i < k; ++i) {
      const float *b0 = &B(i, j);
      asm volatile(
          "prfm pldl2keep, [%[b0], #64]                              \n\t"
          "ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[b0]]                 \n\t"
          "st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [%[local_buffer]], #64  \n\t"
          : [local_buffer] "+r"(local_buffer)
          : [b0] "r"(b0)
          : "memory", "v0", "v1", "v2", "v3");
    }
  }
  if (n_tail != 0) {
    float *local_buffer = buffer + j_length * k;
    for (int i = 0; i < k; ++i) {
      const float *b0 = &B(i, j_length);
      for (int j = j_length; j < n; ++j) {
        *local_buffer++ = *b0++;
      }
      for (int j = n; j < j_length + NR; ++j) {
        *local_buffer++ = 0;
      }
    }
  }
}
```
...
`VecWriteBasic`, which previously delegated to `memcpy`, is rewritten as a NEON copy loop; the tail backs both pointers up by `nc3` bytes so the final 4-float transfer ends exactly at element `n`:

```diff
@@ -2971,7 +2941,48 @@ void Gemm::WriteWithBnAddRelu(int mc, int nc, float *c, float *C, int ldc,
 // C = A * B
 void Gemm::VecWriteBasic(int n, float *c, float *C, int ldc) {
-  memcpy(C, c, n * sizeof(float));
+  int nc1 = n / 16;
+  int _nc1 = n % 16;
+  int nc2 = _nc1 / 4;
+  int nc3 = 16 - 4 * (_nc1 % 4);
+
+  asm volatile(
+      "subs %[nc1], %[nc1], #1   \n\t"
+      "blt end_nc1_%=            \n\t"
+      "loop_nc1_%=:              \n\t"
+
+      "vld1.32 {q0, q1}, [%[c]]! \n\t"
+      "vst1.32 {q0, q1}, [%[C]]! \n\t"
+
+      "vld1.32 {q2, q3}, [%[c]]! \n\t"
+      "vst1.32 {q2, q3}, [%[C]]! \n\t"
+
+      "subs %[nc1], %[nc1], #1   \n\t"
+      "bge loop_nc1_%=           \n\t"
+      "end_nc1_%=:               \n\t"
+
+      "subs %[nc2], %[nc2], #1   \n\t"
+      "blt end_nc2_%=            \n\t"
+      "loop_nc2_%=:              \n\t"
+
+      "vld1.32 {q4}, [%[c]]!     \n\t"
+      "vst1.32 {q4}, [%[C]]!     \n\t"
+
+      "subs %[nc2], %[nc2], #1   \n\t"
+      "bge loop_nc2_%=           \n\t"
+      "end_nc2_%=:               \n\t"
+
+      "cmp %[nc3], #16           \n\t"
+      "beq end_nc3_%=            \n\t"
+
+      "sub %[c], %[c], %[nc3]    \n\t"
+      "sub %[C], %[C], %[nc3]    \n\t"
+
+      "vld1.32 {q5}, [%[c]]!     \n\t"
+      "vst1.32 {q5}, [%[C]]!     \n\t"
+      "end_nc3_%=:               \n\t"
+
+      :
+      : [C] "r"(C), [c] "r"(c), [nc1] "r"(nc1), [nc2] "r"(nc2), [nc3] "r"(nc3)
+      : "memory", "q0", "q1", "q2", "q3", "q4", "q5");
 }
 
 // C = alpha * A * B + beta * C
```
...
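The `nc3` tail above appears to use an overlapping-copy trick: rather than moving the last 1-3 floats one by one, both pointers are backed up so a full 4-float transfer ends exactly at element `n`; the re-copied overlap is redundant but harmless because it was already written. A scalar model of that reading (my interpretation of the asm, not project code; like the asm, it assumes `n >= 4` whenever there is a remainder):

```cpp
#include <cassert>
#include <cstring>
#include <vector>

void vec_write(int n, const float *c, float *C) {
  int quads = n / 4;
  std::memcpy(C, c, quads * 4 * sizeof(float));  // whole quads first
  int rem = n % 4;
  if (rem != 0) {
    int back = 4 - rem;  // matches nc3 / sizeof(float) in the asm
    // Back up so a full 4-float copy ends at element n (needs n >= 4).
    std::memcpy(C + n - rem - back, c + n - rem - back, 4 * sizeof(float));
  }
}

int main() {
  std::vector<float> c(7), C(7, -1.0f);
  for (int i = 0; i < 7; ++i) c[i] = static_cast<float>(i);
  vec_write(7, c.data(), C.data());
  for (int i = 0; i < 7; ++i) assert(C[i] == c[i]);
  return 0;
}
```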
The remaining hunks thread the new `parallel` argument through every call site. The single-threaded `Sgemm` path packs serially (`false`):

```diff
@@ -3252,17 +3263,17 @@ void Gemm::Sgemm(int m, int n, int k, float alpha, const float *A, int lda,
     nc = s_min(n - j, NC);
 #if __aarch64__
     // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
-    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
 #else
-    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
 #endif
     for (int i = 0; i < m; i += MC) {
       mc = s_min(m - i, MC);
 #if __aarch64__
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
       // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
 #else
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
 #endif
       if (bias == nullptr) {
         InnerKernelWithBias(mc, nc, alpha, packedA, packedB, beta, packedC,
```
...
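For orientation, the loop nest being patched here is the standard MC/NC-blocked GEMM driver: pack an NC-wide strip of B, then for each MC-tall strip of A pack it and run the inner kernel. A simplified, self-contained sketch (scalar stand-ins for the packing and the AddDot micro-kernel; the MC/NC values are illustrative, not the tuned ones):

```cpp
#include <algorithm>
#include <vector>

void blocked_gemm(int m, int n, int k, const float *A, const float *B,
                  float *C) {
  const int MC = 128, NC = 256;  // block sizes; real values are cache-tuned
  std::vector<float> packedA(MC * k), packedB(k * NC);
  for (int j = 0; j < n; j += NC) {
    int nc = std::min(n - j, NC);
    for (int p = 0; p < k; ++p)  // stand-in for PackMatrixB_*(..., false)
      std::copy(B + p * n + j, B + p * n + j + nc, packedB.begin() + p * nc);
    for (int i = 0; i < m; i += MC) {
      int mc = std::min(m - i, MC);
      for (int r = 0; r < mc; ++r)  // stand-in for PackMatrixA_6r(..., false)
        std::copy(A + (i + r) * k, A + (i + r) * k + k,
                  packedA.begin() + r * k);
      // Inner kernel: C(i:i+mc, j:j+nc) += packedA * packedB.
      for (int r = 0; r < mc; ++r)
        for (int c = 0; c < nc; ++c) {
          float acc = 0.0f;
          for (int p = 0; p < k; ++p)
            acc += packedA[r * k + p] * packedB[p * nc + c];
          C[(i + r) * n + (j + c)] += acc;
        }
    }
  }
}

int main() {
  const int m = 3, n = 4, k = 2;
  std::vector<float> A(m * k, 1.0f), B(k * n, 2.0f), C(m * n, 0.0f);
  blocked_gemm(m, n, k, A.data(), B.data(), C.data());
  return C[0] == 4.0f ? 0 : 1;
}
```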
`SgemmWithBn` gets the same serial packing:

```diff
@@ -3325,17 +3336,17 @@ void Gemm::SgemmWithBn(int m, int n, int k, float alpha, const float *A,
     nc = s_min(n - j, NC);
 #if __aarch64__
     // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
-    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
 #else
-    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
 #endif
     for (int i = 0; i < m; i += MC) {
       mc = s_min(m - i, MC);
 #if __aarch64__
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
       // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
 #else
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
 #endif
       if (bias == nullptr) {
         InnerKernelWithBn(mc, nc, alpha, packedA, packedB, beta, packedC,
```
...
`SgemmWithPRelu` likewise:

```diff
@@ -3401,17 +3412,17 @@ void Gemm::SgemmWithPRelu(int m, int n, int k, const float *A, int lda,
     nc = s_min(n - j, NC);
 #if __aarch64__
     // PackMatrixB_12c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
-    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_16c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
 #else
-    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB);
+    PackMatrixB_8c(KC, nc, nc % NR, &B(0, j), ldb, packedB, false);
 #endif
     for (int i = 0; i < m; i += MC) {
       mc = s_min(m - i, MC);
 #if __aarch64__
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
       // PackMatrixA_8r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
 #else
-      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA);
+      PackMatrixA_6r(mc, KC, mc % MR, &A(i, 0), lda, packedA, false);
 #endif
       if (bias1 == nullptr) {
         InnerKernelWithPRelu(mc, nc, packedA, packedB, packedC, &C(i, j), ldc,
```
...
In the OpenMP drivers, the `_omp_` packing variants disappear from the function-pointer setup; the shared matrix is packed in parallel (`true`):

```diff
@@ -3465,17 +3476,17 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
 #if __aarch64__
     procPackA = &Gemm::PackMatrixA_6r;
-    procPackB = &Gemm::PackMatrixB_omp_16c;
+    procPackB = &Gemm::PackMatrixB_16c;
     procAddDot = &Gemm::AddDot6x16;
 #else
     procPackA = &Gemm::PackMatrixA_6r;
-    procPackB = &Gemm::PackMatrixB_omp_8c;
+    procPackB = &Gemm::PackMatrixB_8c;
     procAddDot = &Gemm::AddDot6x8;
 #endif
     packedB = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
-    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
+    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true);
     packedA = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
   } else {
```
...
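The mechanism that makes one function serve both drivers is OpenMP's `if` clause: the same loop runs threaded or serial depending on a runtime condition. A minimal sketch (hypothetical example, not project code; compile with `-fopenmp`):

```cpp
#include <cstdio>

void pack(int rows, const float *src, float *dst, bool parallel) {
  // With parallel == false the clause disables threading, so the call is
  // safe even from inside an already-parallel region.
#pragma omp parallel for if (parallel)
  for (int i = 0; i < rows; ++i) {
    dst[i] = src[i] * 2.0f;  // stand-in for one row's packing work
  }
}

int main() {
  float src[8] = {0, 1, 2, 3, 4, 5, 6, 7}, dst[8];
  pack(8, src, dst, true);   // shared matrix: parallel packing pays off
  pack(8, src, dst, false);  // per-thread panel: stay serial
  std::printf("dst[7] = %g\n", dst[7]);
  return 0;
}
```

That is exactly the split visible in the hunks: the whole-matrix pack passes `true`, while the per-thread panel packs inside the parallel loop pass `false`.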
```diff
@@ -3492,19 +3503,19 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
     MC = (m + MR - 1) / MR * MR;
 #if __aarch64__
-    procPackA = &Gemm::PackMatrixA_omp_6r;
+    procPackA = &Gemm::PackMatrixA_6r;
     procPackB = &Gemm::PackMatrixB_16c;
     procAddDot = &Gemm::AddDot6x16;
 #else
-    procPackA = &Gemm::PackMatrixA_omp_6r;
+    procPackA = &Gemm::PackMatrixA_6r;
     procPackB = &Gemm::PackMatrixB_8c;
     procAddDot = &Gemm::AddDot6x8;
 #endif
     packedA = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
-    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
+    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true);
     packedB = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
   }
```
...
```diff
@@ -3524,7 +3535,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
       mc = s_min(m - i, MC);
       float *local_A = packedA + MC * KC * local_threads;
       float *local_C = packedC + MC * NC * local_threads;
-      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A);
+      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false);
       if (bias == nullptr) {
         InnerKernelWithBias(mc, n, alpha, local_A, packedB, beta, local_C,
                             &C(i, 0), ldc, relu, nullptr);
```
...
```diff
@@ -3546,7 +3557,7 @@ void Gemm::Sgemm_omp(int m, int n, int k, float alpha, const float *A, int lda,
       nc = s_min(n - j, NC);
       float *local_B = packedB + KC * NC * local_threads;
       float *local_C = packedC + MC * NC * local_threads;
-      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
+      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false);
       InnerKernelWithBias(m, nc, alpha, packedA, local_B, beta, local_C,
                           &C(0, j), ldc, relu, bias);
     }
```
...
`SgemmWithBn_omp` follows the same pattern:

```diff
@@ -3587,17 +3598,17 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
 #if __aarch64__
     procPackA = &Gemm::PackMatrixA_6r;
-    procPackB = &Gemm::PackMatrixB_omp_16c;
+    procPackB = &Gemm::PackMatrixB_16c;
     procAddDot = &Gemm::AddDot6x16;
 #else
     procPackA = &Gemm::PackMatrixA_6r;
-    procPackB = &Gemm::PackMatrixB_omp_8c;
+    procPackB = &Gemm::PackMatrixB_8c;
     procAddDot = &Gemm::AddDot6x8;
 #endif
     packedB = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
-    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
+    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true);
     packedA = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
   } else {
```
...
```diff
@@ -3614,18 +3625,18 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
     MC = (m + MR - 1) / MR * MR;
 #if __aarch64__
-    procPackA = &Gemm::PackMatrixA_omp_6r;
+    procPackA = &Gemm::PackMatrixA_6r;
     procPackB = &Gemm::PackMatrixB_16c;
     procAddDot = &Gemm::AddDot6x16;
 #else
-    procPackA = &Gemm::PackMatrixA_omp_6r;
+    procPackA = &Gemm::PackMatrixA_6r;
     procPackB = &Gemm::PackMatrixB_8c;
     procAddDot = &Gemm::AddDot6x8;
 #endif
     packedA = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
-    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
+    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true);
     packedB = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
   }
```
...
```diff
@@ -3645,7 +3656,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
       mc = s_min(m - i, MC);
       float *local_A = packedA + MC * KC * local_threads;
       float *local_C = packedC + MC * NC * local_threads;
-      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A);
+      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false);
       if (bias == nullptr) {
         InnerKernelWithBn(mc, n, alpha, local_A, packedB, beta, local_C,
                           &C(i, 0), ldc, relu, new_scale + i, new_bias + i);
```
...
```diff
@@ -3668,7 +3679,7 @@ void Gemm::SgemmWithBn_omp(int m, int n, int k, float alpha, const float *A,
       nc = s_min(n - j, NC);
       float *local_B = packedB + KC * NC * local_threads;
       float *local_C = packedC + MC * NC * local_threads;
-      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
+      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false);
       if (bias == nullptr) {
         InnerKernelWithBn(m, nc, alpha, packedA, local_B, beta, local_C,
                           &C(0, j), ldc, relu, new_scale, new_bias);
```
...
And `SgemmWithPRelu_omp`:

```diff
@@ -3715,17 +3726,17 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
 #if __aarch64__
     procPackA = &Gemm::PackMatrixA_6r;
-    procPackB = &Gemm::PackMatrixB_omp_16c;
+    procPackB = &Gemm::PackMatrixB_16c;
     procAddDot = &Gemm::AddDot6x16;
 #else
     procPackA = &Gemm::PackMatrixA_6r;
-    procPackB = &Gemm::PackMatrixB_omp_8c;
+    procPackB = &Gemm::PackMatrixB_8c;
     procAddDot = &Gemm::AddDot6x8;
 #endif
     packedB = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * KC * NC));
-    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB);
+    (*this.*procPackB)(KC, n, n % NR, B, ldb, packedB, true);
     packedA = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * MC * KC * max_threads));
   } else {
```
...
```diff
@@ -3742,18 +3753,18 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
     MC = (m + MR - 1) / MR * MR;
 #if __aarch64__
-    procPackA = &Gemm::PackMatrixA_omp_6r;
+    procPackA = &Gemm::PackMatrixA_6r;
     procPackB = &Gemm::PackMatrixB_16c;
     procAddDot = &Gemm::AddDot6x16;
 #else
-    procPackA = &Gemm::PackMatrixA_omp_6r;
+    procPackA = &Gemm::PackMatrixA_6r;
     procPackB = &Gemm::PackMatrixB_8c;
     procAddDot = &Gemm::AddDot6x8;
 #endif
     packedA = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * MC * KC));
-    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA);
+    (*this.*procPackA)(m, KC, m % MR, A, lda, packedA, true);
     packedB = static_cast<float *>(
         paddle_mobile::memory::Alloc(sizeof(float) * KC * NC * max_threads));
   }
```
...
```diff
@@ -3773,7 +3784,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
       mc = s_min(m - i, MC);
       float *local_A = packedA + MC * KC * local_threads;
       float *local_C = packedC + MC * NC * local_threads;
-      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A);
+      (*this.*procPackA)(mc, KC, mc % MR, &A(i, 0), lda, local_A, false);
       if (bias1 == nullptr) {
         InnerKernelWithPRelu(mc, n, local_A, packedB, local_C, &C(i, 0), ldc,
                              p + i, mode, bias + i, nullptr);
```
...
```diff
@@ -3795,7 +3806,7 @@ void Gemm::SgemmWithPRelu_omp(int m, int n, int k, const float *A, int lda,
       nc = s_min(n - j, NC);
       float *local_B = packedB + KC * NC * local_threads;
       float *local_C = packedC + MC * NC * local_threads;
-      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B);
+      (*this.*procPackB)(KC, nc, nc % NR, &B(0, j), ldb, local_B, false);
       if (bias1 == nullptr) {
         InnerKernelWithPRelu(m, nc, packedA, local_B, local_C, &C(0, j), ldc, p,
                              mode, bias, nullptr);
```
...
src/operators/math/gemm.h
The header drops the `_omp_` pack declarations, gives every remaining pack routine the `const bool parallel` parameter, and widens the `FnPack` member-pointer typedef to match:

```diff
@@ -46,37 +46,25 @@ namespace math {
 class Gemm {
  public:
-  typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *);
+  typedef void (Gemm::*FnPack)(int, int, int, const float *, int, float *,
+                               const bool);
   typedef void (Gemm::*FnAddDot)(int, const float *, const float *, float *,
                                  int);
   FnPack procPackA;
   FnPack procPackB;
   FnAddDot procAddDot;
 
   // Copy blocks of the A/B matrices into contiguous memory (RowMajor).
-  void PackMatrixA_4r(int m, int k, int m_tail, const float *A, int lda,
-                      float *buffer);
   void PackMatrixA_6r(int m, int k, int m_tail, const float *A, int lda,
-                      float *buffer);
-  void PackMatrixA_omp_6r(int m, int k, int m_tail, const float *A, int lda,
-                          float *buffer);
+                      float *buffer, const bool parallel);
   void PackMatrixA_8r(int m, int k, int m_tail, const float *A, int lda,
-                      float *buffer);
-  void PackMatrixA_omp_8r(int m, int k, int m_tail, const float *A, int lda,
-                          float *buffer);
+                      float *buffer, const bool parallel);
   void PackMatrixB_8c(int k, int n, int n_tail, const float *B, int ldb,
-                      float *buffer);
-  void PackMatrixB_omp_8c(int k, int n, int n_tail, const float *B, int ldb,
-                          float *buffer);
+                      float *buffer, const bool parallel);
 #if __aarch64__
   void PackMatrixB_12c(int k, int n, int n_tail, const float *B, int ldb,
-                       float *buffer);
-  void PackMatrixB_omp_12c(int k, int n, int n_tail, const float *B, int ldb,
-                           float *buffer);
+                       float *buffer, const bool parallel);
   void PackMatrixB_16c(int k, int n, int n_tail, const float *B, int ldb,
-                       float *buffer);
-  void PackMatrixB_omp_16c(int k, int n, int n_tail, const float *B, int ldb,
-                           float *buffer);
+                       float *buffer, const bool parallel);
 #endif
 
   // Blocked matrix multiplication.
```
...
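The `FnPack` typedef is what forces every call site to change: a pointer-to-member-function must match the full signature, so adding `const bool` to the typedef requires the trailing argument everywhere. A self-contained sketch of this dispatch pattern with illustrative names (not the project's declarations):

```cpp
#include <cstdio>

class Gemm {
 public:
  // One signature for all pack routines, ending in the parallel flag.
  typedef void (Gemm::*FnPack)(int, float *, const bool);
  FnPack procPackB = nullptr;

  void PackB_8c(int n, float *buf, const bool parallel) {
    std::printf("8c  n=%d parallel=%d\n", n, static_cast<int>(parallel));
  }
  void PackB_16c(int n, float *buf, const bool parallel) {
    std::printf("16c n=%d parallel=%d\n", n, static_cast<int>(parallel));
  }

  void Run(int n, float *buf) {
    // Pick a routine at runtime, then invoke through the member pointer
    // using the same (*this.*proc)(...) form as the .cpp file.
    procPackB = (n % 16 == 0) ? &Gemm::PackB_16c : &Gemm::PackB_8c;
    (*this.*procPackB)(n, buf, true);
  }
};

int main() {
  float buf[4] = {};
  Gemm g;
  g.Run(32, buf);
  g.Run(24, buf);
  return 0;
}
```

Folding the serial and `_omp_` variants into one signature halves the number of routines the pointers can target while keeping the dispatch sites unchanged in shape.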