Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
f249d387
MegEngine
项目概览
MegEngine 天元
/
MegEngine
接近 2 年 前同步成功
通知
414
Star
4708
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
f249d387
编写于
5月 17, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(fallback): imp gi matmul FB_GI_F32_4x12 algo
GitOrigin-RevId: 16255e7a728bf8cffbdc57094534600683af587a
上级
03f78547
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
1253 addition
and
4 deletion
+1253
-4
dnn/src/fallback/conv_bias/opr_impl.cpp
dnn/src/fallback/conv_bias/opr_impl.cpp
+0
-2
dnn/src/fallback/matrix_mul/algos.cpp
dnn/src/fallback/matrix_mul/algos.cpp
+58
-0
dnn/src/fallback/matrix_mul/algos.h
dnn/src/fallback/matrix_mul/algos.h
+11
-0
dnn/src/fallback/matrix_mul/generic_strategy.h
dnn/src/fallback/matrix_mul/generic_strategy.h
+1
-0
dnn/src/fallback/matrix_mul/gi/fp32/common.h
dnn/src/fallback/matrix_mul/gi/fp32/common.h
+221
-0
dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp
dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp
+950
-0
dnn/src/fallback/matrix_mul/opr_impl.cpp
dnn/src/fallback/matrix_mul/opr_impl.cpp
+4
-2
dnn/src/fallback/matrix_mul/opr_impl.h
dnn/src/fallback/matrix_mul/opr_impl.h
+2
-0
dnn/test/fallback/matrix_mul.cpp
dnn/test/fallback/matrix_mul.cpp
+6
-0
未找到文件。
dnn/src/fallback/conv_bias/opr_impl.cpp
浏览文件 @
f249d387
...
...
@@ -138,8 +138,6 @@ public:
}
}
//! TODO: move arm_v7 MatrixMulImpl::AlgoF32 matmul to gi fallback, for nchw
//! prefetch algo, also need update dnn/test/common/conv_bias.cpp:check_winograd
matmul_algos
=
static_cast
<
fallback
::
MatrixMulImpl
*>
(
matmul_opr
)
->
select_algo_type
(
{
AlgoDataType
::
FLOAT32
,
MatmulFormat
::
DEFAULT
});
...
...
dnn/src/fallback/matrix_mul/algos.cpp
浏览文件 @
f249d387
...
...
@@ -15,6 +15,7 @@ MIDOUT_DECL(megdnn_fb_matmul_f32_gemm_gemv_like)
MIDOUT_DECL
(
megdnn_fb_matmul_naive
)
MIDOUT_DECL
(
megdnn_fb_gi_exec_fp32
)
MIDOUT_DECL
(
megdnn_fb_gi_matmul_kern
)
MIDOUT_DECL
(
megdnn_fb_gi_f32_4x12
)
using
namespace
megdnn
;
using
namespace
fallback
;
...
...
@@ -293,4 +294,61 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32GiMK4_4x8::get_kern(
const
KernSizeParam
&
)
const
{
return
gi_f32_mk4_4x8_kern
;
}
/* ===================== F32 algo ===================== */
namespace
{
void
f32_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
MIDOUT_BEGIN
(
megdnn_fb_gi_f32_4x12
,
midout_iv
(
"f32_kern"
_hash
))
{
auto
M
=
kern_param
.
M
,
N
=
kern_param
.
N
,
K
=
kern_param
.
K
;
auto
trA
=
kern_param
.
trA
,
trB
=
kern_param
.
trB
;
auto
LDA
=
kern_param
.
LDA
,
LDB
=
kern_param
.
LDB
,
LDC
=
kern_param
.
LDC
;
auto
A_type
=
kern_param
.
A_type
,
B_type
=
kern_param
.
B_type
,
C_type
=
kern_param
.
C_type
;
const
auto
Aptr
=
kern_param
.
A
<
float
>
(),
Bptr
=
kern_param
.
B
<
float
>
();
auto
Cptr
=
kern_param
.
C
<
float
>
();
matmul
::
fallback
::
gi_sgemm_4x12
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
megdnn
::
matmul
::
GemmInterleaved
<
matmul
::
fallback
::
gi_sgemm_4x12
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
execute
(
Aptr
,
LDA
,
Bptr
,
LDB
,
Cptr
,
LDC
,
kern_param
.
workspace_ptr
);
}
MIDOUT_END
();
}
}
// anonymous namespace
bool
MatrixMulImpl
::
AlgoF32Gi4x12
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
return
kern_size_param
.
compute_mode
==
Param
::
ComputeMode
::
DEFAULT
&&
kern_size_param
.
format
==
param
::
MatrixMul
::
Format
::
DEFAULT
&&
kern_size_param
.
B_type
==
kern_size_param
.
A_type
&&
kern_size_param
.
C_type
==
kern_size_param
.
A_type
&&
kern_size_param
.
A_type
==
dtype
::
Float32
();
}
size_t
MatrixMulImpl
::
AlgoF32Gi4x12
::
get_workspace
(
const
KernSizeParam
&
kern_size_param
)
const
{
MIDOUT_BEGIN
(
megdnn_fb_gi_f32_4x12
,
midout_iv
(
"AlgoF32Gi4x12::get_workspace"
_hash
))
{
auto
M
=
kern_size_param
.
M
,
N
=
kern_size_param
.
N
,
K
=
kern_size_param
.
K
;
auto
trA
=
kern_size_param
.
trA
,
trB
=
kern_size_param
.
trB
;
auto
A_type
=
kern_size_param
.
A_type
,
B_type
=
kern_size_param
.
B_type
,
C_type
=
kern_size_param
.
C_type
;
matmul
::
fallback
::
gi_sgemm_4x12
strategy
(
M
,
N
,
K
,
A_type
,
B_type
,
C_type
);
return
megdnn
::
matmul
::
GemmInterleaved
<
matmul
::
fallback
::
gi_sgemm_4x12
>
(
M
,
N
,
K
,
trA
,
trB
,
strategy
)
.
get_workspace_size
();
}
MIDOUT_END
();
return
0
;
}
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoF32Gi4x12
::
get_kern
(
const
KernSizeParam
&
)
const
{
return
f32_kern
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL
(
AlgoF32Gi4x12
,
megdnn_fb_gi_f32_4x12
,
"AlgoF32Gi4x12Impl"
_hash
,
matmul
::
fallback
::
gi_sgemm_4x12
,
float
,
float
,
AlgoDataType
::
FLOAT32
,
DEFAULT
);
// vim: syntax=cpp.doxygen
dnn/src/fallback/matrix_mul/algos.h
浏览文件 @
f249d387
...
...
@@ -97,6 +97,17 @@ public:
MEGDNN_DECL_ALGO_TYPE
(
FB_GI_F32_MK4_4x8
)
};
class
MatrixMulImpl
::
AlgoF32Gi4x12
final
:
public
AlgoBase
{
public:
AlgoAttribute
attribute
()
const
override
{
return
AlgoAttribute
::
REPRODUCIBLE
;
}
const
char
*
name
()
const
override
{
return
"FB_GI_F32_4x12"
;
}
bool
usable
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
;
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_DECL_ALGO_TYPE
(
FB_GI_F32_4x12
)
};
}
// namespace fallback
}
// namespace megdnn
...
...
dnn/src/fallback/matrix_mul/generic_strategy.h
浏览文件 @
f249d387
...
...
@@ -8,6 +8,7 @@ namespace fallback {
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
8
,
12
,
1
,
false
,
true
,
sgemm_8x12
);
MEGDNN_REG_GEMM_STRATEGY_NOPACK
(
float
,
float
,
float
,
4
,
8
,
1
,
false
,
true
,
gi_sgemm_nopack_4x8
);
MEGDNN_REG_GEMM_STRATEGY
(
float
,
float
,
float
,
4
,
12
,
1
,
false
,
true
,
gi_sgemm_4x12
);
}
// namespace fallback
}
// namespace matmul
...
...
dnn/src/fallback/matrix_mul/gi/fp32/common.h
0 → 100644
浏览文件 @
f249d387
#pragma once
#include "src/fallback/general_intrinsic/gi_float.h"
namespace
megdnn
{
namespace
matmul
{
namespace
fallback
{
/* ======================== transform ======================== */
/**
* interleave_INTERLEAVE_UNROLLK_BATCH_type
*
* BATCH means process BATCH * UNROLL_K cols once, BATCH * sizeof(TYPE) *
* UNROLL_K = 16bytes(128bits, a vector size).
*
* the elements traverse order:
* rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[j, i]
*/
template
<
typename
T
>
static
GI_FORCEINLINE
void
interleave_4x4_1_s
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
T
*&
outptr
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"interleave_4x4_1_s only support sizeof(T) == 4"
);
GI_FLOAT32_t
d0d1
=
GiLoadFloat32
(
inptr0
);
GI_FLOAT32_t
d2d3
=
GiLoadFloat32
(
inptr1
);
GI_FLOAT32_t
d4d5
=
GiLoadFloat32
(
inptr2
);
GI_FLOAT32_t
d6d7
=
GiLoadFloat32
(
inptr3
);
inptr0
+=
4
;
inptr1
+=
4
;
inptr2
+=
4
;
inptr3
+=
4
;
GiStoreFloat32
(
outptr
,
d0d1
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d2d3
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d4d5
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d6d7
);
outptr
+=
4
;
}
template
<
typename
T
>
static
GI_FORCEINLINE
void
interleave_4x12_1_s
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
T
*&
outptr
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"interleave_4x12_1_s only support sizeof(T) == 4"
);
GI_FLOAT32_t
d0d1
=
GiLoadFloat32
(
inptr0
);
inptr0
+=
4
;
GI_FLOAT32_t
d2d3
=
GiLoadFloat32
(
inptr0
);
inptr0
+=
4
;
GI_FLOAT32_t
d4d5
=
GiLoadFloat32
(
inptr0
);
inptr0
+=
4
;
GI_FLOAT32_t
d6d7
=
GiLoadFloat32
(
inptr1
);
inptr1
+=
4
;
GI_FLOAT32_t
d8d9
=
GiLoadFloat32
(
inptr1
);
inptr1
+=
4
;
GI_FLOAT32_t
d10d11
=
GiLoadFloat32
(
inptr1
);
inptr1
+=
4
;
GI_FLOAT32_t
d12d13
=
GiLoadFloat32
(
inptr2
);
inptr2
+=
4
;
GI_FLOAT32_t
d14d15
=
GiLoadFloat32
(
inptr2
);
inptr2
+=
4
;
GI_FLOAT32_t
d16d17
=
GiLoadFloat32
(
inptr2
);
inptr2
+=
4
;
GI_FLOAT32_t
d18d19
=
GiLoadFloat32
(
inptr3
);
inptr3
+=
4
;
GI_FLOAT32_t
d20d21
=
GiLoadFloat32
(
inptr3
);
inptr3
+=
4
;
GI_FLOAT32_t
d22d23
=
GiLoadFloat32
(
inptr3
);
inptr3
+=
4
;
GiStoreFloat32
(
outptr
,
d0d1
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d2d3
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d4d5
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d6d7
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d8d9
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d10d11
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d12d13
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d14d15
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d16d17
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d18d19
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d20d21
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d22d23
);
outptr
+=
4
;
}
template
<
typename
T
>
static
GI_FORCEINLINE
void
interleave_1x12_1_s
(
const
T
*&
inptr0
,
T
*&
outptr
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"interleave_1x12_1_s only support sizeof(T) == 4"
);
GI_FLOAT32_t
d0d1
=
GiLoadFloat32
(
inptr0
);
inptr0
+=
4
;
GI_FLOAT32_t
d2d3
=
GiLoadFloat32
(
inptr0
);
inptr0
+=
4
;
GI_FLOAT32_t
d4d5
=
GiLoadFloat32
(
inptr0
);
inptr0
+=
4
;
GiStoreFloat32
(
outptr
,
d0d1
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d2d3
);
outptr
+=
4
;
GiStoreFloat32
(
outptr
,
d4d5
);
outptr
+=
4
;
}
template
<
typename
T
>
static
GI_FORCEINLINE
void
interleave_1x4_1_s
(
const
T
*&
inptr0
,
T
*&
outptr
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"interleave_1x4_1_s only support sizeof(T) == 4"
);
GI_FLOAT32_t
d0d1
=
GiLoadFloat32
(
inptr0
);
inptr0
+=
4
;
GiStoreFloat32
(
outptr
,
d0d1
);
outptr
+=
4
;
}
template
<
typename
T
>
static
GI_FORCEINLINE
void
interleave_helper
(
const
T
*&
inptr
,
T
*&
outptr
,
int
unroll_k
,
int
ksize
,
T
val
=
0
)
{
int
k
=
0
;
for
(;
k
<
ksize
;
k
++
)
{
*
outptr
++
=
*
inptr
++
;
}
for
(;
k
<
unroll_k
;
k
++
)
{
*
outptr
++
=
val
;
}
}
template
<
typename
T
>
static
GI_FORCEINLINE
void
interleave_1
(
const
T
*&
inptr0
,
T
*&
outptr
,
int
unroll_k
,
int
ksize
,
T
val
=
0
)
{
for
(
int
k
=
0
;
k
<
ksize
;
k
+=
unroll_k
)
{
int
size
=
std
::
min
(
unroll_k
,
ksize
-
k
);
interleave_helper
(
inptr0
,
outptr
,
unroll_k
,
size
,
val
);
}
}
template
<
typename
T
>
static
GI_FORCEINLINE
void
interleave_4
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
T
*&
outptr
,
int
unroll_k
,
int
ksize
,
T
val
=
0
)
{
for
(
int
k
=
0
;
k
<
ksize
;
k
+=
unroll_k
)
{
int
size
=
std
::
min
(
unroll_k
,
ksize
-
k
);
interleave_helper
(
inptr0
,
outptr
,
unroll_k
,
size
,
val
);
interleave_helper
(
inptr1
,
outptr
,
unroll_k
,
size
,
val
);
interleave_helper
(
inptr2
,
outptr
,
unroll_k
,
size
,
val
);
interleave_helper
(
inptr3
,
outptr
,
unroll_k
,
size
,
val
);
}
}
/* ======================== transpose pack B ======================== */
/**
* transpose_INTERLEAVE_UNROLLK_BATCH_type
*
* BATCH means process BATCH * INTERLEAVE cols once, BATCH * sizeof(TYPE) *
* INTERLEAVE = 16bytes(128bits, a vector size).
*
* the elements traverse order:
* rep(j, 0, INTERLEAVE) rep(i, 0, UNROLL_K) *ouptr++ = inptr[i, j]
*/
template
<
typename
T
>
static
GI_FORCEINLINE
void
transpose_4x4_1_s
(
const
T
*&
inptr0
,
const
T
*&
inptr1
,
const
T
*&
inptr2
,
const
T
*&
inptr3
,
T
*&
outptr
,
int
stride
=
16
)
{
static_assert
(
sizeof
(
T
)
==
4
,
"transpose_4x4_1_s only support sizeof(T) == 4"
);
stride
=
stride
/
sizeof
(
float
);
stride
-=
2
;
GI_FLOAT32_t
d0d1
=
GiLoadFloat32
(
inptr0
);
GI_FLOAT32_t
d2d3
=
GiLoadFloat32
(
inptr1
);
GI_FLOAT32_t
d4d5
=
GiLoadFloat32
(
inptr2
);
GI_FLOAT32_t
d6d7
=
GiLoadFloat32
(
inptr3
);
inptr0
+=
4
;
inptr1
+=
4
;
inptr2
+=
4
;
inptr3
+=
4
;
GI_FLOAT32_V2_t
q0q1
=
GiZipqFloat32
(
d0d1
,
d2d3
);
GI_FLOAT32_V2_t
q2q3
=
GiZipqFloat32
(
d4d5
,
d6d7
);
GiSt1Float32
(
outptr
,
GiGetLowFloat32
(
q0q1
.
val
[
0
]));
outptr
+=
2
;
GiSt1Float32
(
outptr
,
GiGetLowFloat32
(
q2q3
.
val
[
0
]));
outptr
+=
stride
;
GiSt1Float32
(
outptr
,
GiGetHighFloat32
(
q0q1
.
val
[
0
]));
outptr
+=
2
;
GiSt1Float32
(
outptr
,
GiGetHighFloat32
(
q2q3
.
val
[
0
]));
outptr
+=
stride
;
GiSt1Float32
(
outptr
,
GiGetLowFloat32
(
q0q1
.
val
[
1
]));
outptr
+=
2
;
GiSt1Float32
(
outptr
,
GiGetLowFloat32
(
q2q3
.
val
[
1
]));
outptr
+=
stride
;
GiSt1Float32
(
outptr
,
GiGetHighFloat32
(
q0q1
.
val
[
1
]));
outptr
+=
2
;
GiSt1Float32
(
outptr
,
GiGetHighFloat32
(
q2q3
.
val
[
1
]));
outptr
+=
stride
;
}
}
// namespace fallback
}
// namespace matmul
}
// namespace megdnn
// vim: syntax=cpp.doxygen
dnn/src/fallback/matrix_mul/gi/fp32/strategy_4x12.cpp
0 → 100644
浏览文件 @
f249d387
#include "src/fallback/matrix_mul/generic_strategy.h"
#include "src/fallback/matrix_mul/gi/fp32/common.h"
using
namespace
megdnn
;
using
namespace
matmul
::
fallback
;
namespace
{
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wuninitialized"
#ifdef __GNUC__
#ifndef __has_warning
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#else
#if __has_warning("-Wmaybe-uninitialized")
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#endif
#endif
void
kern_4x12
(
const
float
*
packA
,
const
float
*
packB
,
int
K
,
float
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
)
{
const
float
*
a_ptr
=
packA
;
const
float
*
b_ptr
=
packB
;
int
oddk
=
(
K
&
1
);
K
=
((
K
+
1
)
/
2
)
-
1
;
float
*
r0
=
output
;
float
*
r1
=
r0
+
LDC
;
float
*
r2
=
r1
+
LDC
;
float
*
r3
=
r2
+
LDC
;
GI_FLOAT32_t
d0d1
,
d2d3
,
d4d5
,
d6d7
,
d8d9
,
d10d11
,
d12d13
,
d14d15
,
d16d17
,
d18d19
,
d20d21
,
d22d23
,
d24d25
,
d26d27
,
d28d29
,
d30d31
;
if
(
is_first_k
)
{
d8d9
=
GiBroadcastFloat32
(
0.0
f
);
d10d11
=
GiBroadcastFloat32
(
0.0
f
);
d12d13
=
GiBroadcastFloat32
(
0.0
f
);
d14d15
=
GiBroadcastFloat32
(
0.0
f
);
d16d17
=
GiBroadcastFloat32
(
0.0
f
);
d18d19
=
GiBroadcastFloat32
(
0.0
f
);
d20d21
=
GiBroadcastFloat32
(
0.0
f
);
d22d23
=
GiBroadcastFloat32
(
0.0
f
);
d24d25
=
GiBroadcastFloat32
(
0.0
f
);
d26d27
=
GiBroadcastFloat32
(
0.0
f
);
d28d29
=
GiBroadcastFloat32
(
0.0
f
);
d30d31
=
GiBroadcastFloat32
(
0.0
f
);
}
else
{
if
(
m_remain
==
4
)
{
d8d9
=
GiLoadFloat32
(
r0
);
d10d11
=
GiLoadFloat32
(
r0
+
4
);
d12d13
=
GiLoadFloat32
(
r0
+
8
);
d14d15
=
GiLoadFloat32
(
r1
);
d16d17
=
GiLoadFloat32
(
r1
+
4
);
d18d19
=
GiLoadFloat32
(
r1
+
8
);
d20d21
=
GiLoadFloat32
(
r2
);
d22d23
=
GiLoadFloat32
(
r2
+
4
);
d24d25
=
GiLoadFloat32
(
r2
+
8
);
d26d27
=
GiLoadFloat32
(
r3
);
d28d29
=
GiLoadFloat32
(
r3
+
4
);
d30d31
=
GiLoadFloat32
(
r3
+
8
);
}
else
if
(
m_remain
==
3
)
{
d8d9
=
GiLoadFloat32
(
r0
);
d10d11
=
GiLoadFloat32
(
r0
+
4
);
d12d13
=
GiLoadFloat32
(
r0
+
8
);
d14d15
=
GiLoadFloat32
(
r1
);
d16d17
=
GiLoadFloat32
(
r1
+
4
);
d18d19
=
GiLoadFloat32
(
r1
+
8
);
d20d21
=
GiLoadFloat32
(
r2
);
d22d23
=
GiLoadFloat32
(
r2
+
4
);
d24d25
=
GiLoadFloat32
(
r2
+
8
);
}
else
if
(
m_remain
==
2
)
{
d8d9
=
GiLoadFloat32
(
r0
);
d10d11
=
GiLoadFloat32
(
r0
+
4
);
d12d13
=
GiLoadFloat32
(
r0
+
8
);
d14d15
=
GiLoadFloat32
(
r1
);
d16d17
=
GiLoadFloat32
(
r1
+
4
);
d18d19
=
GiLoadFloat32
(
r1
+
8
);
}
else
if
(
m_remain
==
1
)
{
d8d9
=
GiLoadFloat32
(
r0
);
d10d11
=
GiLoadFloat32
(
r0
+
4
);
d12d13
=
GiLoadFloat32
(
r0
+
8
);
}
}
d2d3
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
for
(;
K
>
0
;
K
--
)
{
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d0d1
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d4d5
,
d0d1
,
0
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d6d7
,
d0d1
,
0
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d0d1
,
1
);
d16d17
=
GiSimdFmaLane
(
d16d17
,
d4d5
,
d0d1
,
1
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d6d7
,
d0d1
,
1
);
d20d21
=
GiSimdFmaLane
(
d20d21
,
d2d3
,
d0d1
,
2
);
d22d23
=
GiSimdFmaLane
(
d22d23
,
d4d5
,
d0d1
,
2
);
d24d25
=
GiSimdFmaLane
(
d24d25
,
d6d7
,
d0d1
,
2
);
d26d27
=
GiSimdFmaLane
(
d26d27
,
d2d3
,
d0d1
,
3
);
d28d29
=
GiSimdFmaLane
(
d28d29
,
d4d5
,
d0d1
,
3
);
d30d31
=
GiSimdFmaLane
(
d30d31
,
d6d7
,
d0d1
,
3
);
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d2d3
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d0d1
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d4d5
,
d0d1
,
0
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d6d7
,
d0d1
,
0
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d0d1
,
1
);
d16d17
=
GiSimdFmaLane
(
d16d17
,
d4d5
,
d0d1
,
1
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d6d7
,
d0d1
,
1
);
d20d21
=
GiSimdFmaLane
(
d20d21
,
d2d3
,
d0d1
,
2
);
d22d23
=
GiSimdFmaLane
(
d22d23
,
d4d5
,
d0d1
,
2
);
d24d25
=
GiSimdFmaLane
(
d24d25
,
d6d7
,
d0d1
,
2
);
d26d27
=
GiSimdFmaLane
(
d26d27
,
d2d3
,
d0d1
,
3
);
d28d29
=
GiSimdFmaLane
(
d28d29
,
d4d5
,
d0d1
,
3
);
d30d31
=
GiSimdFmaLane
(
d30d31
,
d6d7
,
d0d1
,
3
);
d2d3
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
}
if
(
1
==
oddk
)
{
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d0d1
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d4d5
,
d0d1
,
0
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d6d7
,
d0d1
,
0
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d0d1
,
1
);
d16d17
=
GiSimdFmaLane
(
d16d17
,
d4d5
,
d0d1
,
1
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d6d7
,
d0d1
,
1
);
d20d21
=
GiSimdFmaLane
(
d20d21
,
d2d3
,
d0d1
,
2
);
d22d23
=
GiSimdFmaLane
(
d22d23
,
d4d5
,
d0d1
,
2
);
d24d25
=
GiSimdFmaLane
(
d24d25
,
d6d7
,
d0d1
,
2
);
d26d27
=
GiSimdFmaLane
(
d26d27
,
d2d3
,
d0d1
,
3
);
d28d29
=
GiSimdFmaLane
(
d28d29
,
d4d5
,
d0d1
,
3
);
d30d31
=
GiSimdFmaLane
(
d30d31
,
d6d7
,
d0d1
,
3
);
}
else
{
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d0d1
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d4d5
,
d0d1
,
0
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d6d7
,
d0d1
,
0
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d0d1
,
1
);
d16d17
=
GiSimdFmaLane
(
d16d17
,
d4d5
,
d0d1
,
1
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d6d7
,
d0d1
,
1
);
d20d21
=
GiSimdFmaLane
(
d20d21
,
d2d3
,
d0d1
,
2
);
d22d23
=
GiSimdFmaLane
(
d22d23
,
d4d5
,
d0d1
,
2
);
d24d25
=
GiSimdFmaLane
(
d24d25
,
d6d7
,
d0d1
,
2
);
d26d27
=
GiSimdFmaLane
(
d26d27
,
d2d3
,
d0d1
,
3
);
d28d29
=
GiSimdFmaLane
(
d28d29
,
d4d5
,
d0d1
,
3
);
d30d31
=
GiSimdFmaLane
(
d30d31
,
d6d7
,
d0d1
,
3
);
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d2d3
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d2d3
,
d0d1
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d4d5
,
d0d1
,
0
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d6d7
,
d0d1
,
0
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d2d3
,
d0d1
,
1
);
d16d17
=
GiSimdFmaLane
(
d16d17
,
d4d5
,
d0d1
,
1
);
d18d19
=
GiSimdFmaLane
(
d18d19
,
d6d7
,
d0d1
,
1
);
d20d21
=
GiSimdFmaLane
(
d20d21
,
d2d3
,
d0d1
,
2
);
d22d23
=
GiSimdFmaLane
(
d22d23
,
d4d5
,
d0d1
,
2
);
d24d25
=
GiSimdFmaLane
(
d24d25
,
d6d7
,
d0d1
,
2
);
d26d27
=
GiSimdFmaLane
(
d26d27
,
d2d3
,
d0d1
,
3
);
d28d29
=
GiSimdFmaLane
(
d28d29
,
d4d5
,
d0d1
,
3
);
d30d31
=
GiSimdFmaLane
(
d30d31
,
d6d7
,
d0d1
,
3
);
}
if
(
m_remain
==
4
)
{
GiStoreFloat32
(
r0
,
d8d9
);
GiStoreFloat32
(
r0
+
4
,
d10d11
);
GiStoreFloat32
(
r0
+
8
,
d12d13
);
GiStoreFloat32
(
r1
,
d14d15
);
GiStoreFloat32
(
r1
+
4
,
d16d17
);
GiStoreFloat32
(
r1
+
8
,
d18d19
);
GiStoreFloat32
(
r2
,
d20d21
);
GiStoreFloat32
(
r2
+
4
,
d22d23
);
GiStoreFloat32
(
r2
+
8
,
d24d25
);
GiStoreFloat32
(
r3
,
d26d27
);
GiStoreFloat32
(
r3
+
4
,
d28d29
);
GiStoreFloat32
(
r3
+
8
,
d30d31
);
}
else
if
(
m_remain
==
3
)
{
GiStoreFloat32
(
r0
,
d8d9
);
GiStoreFloat32
(
r0
+
4
,
d10d11
);
GiStoreFloat32
(
r0
+
8
,
d12d13
);
GiStoreFloat32
(
r1
,
d14d15
);
GiStoreFloat32
(
r1
+
4
,
d16d17
);
GiStoreFloat32
(
r1
+
8
,
d18d19
);
GiStoreFloat32
(
r2
,
d20d21
);
GiStoreFloat32
(
r2
+
4
,
d22d23
);
GiStoreFloat32
(
r2
+
8
,
d24d25
);
}
else
if
(
m_remain
==
2
)
{
GiStoreFloat32
(
r0
,
d8d9
);
GiStoreFloat32
(
r0
+
4
,
d10d11
);
GiStoreFloat32
(
r0
+
8
,
d12d13
);
GiStoreFloat32
(
r1
,
d14d15
);
GiStoreFloat32
(
r1
+
4
,
d16d17
);
GiStoreFloat32
(
r1
+
8
,
d18d19
);
}
else
if
(
m_remain
==
1
)
{
GiStoreFloat32
(
r0
,
d8d9
);
GiStoreFloat32
(
r0
+
4
,
d10d11
);
GiStoreFloat32
(
r0
+
8
,
d12d13
);
}
}
void
kern_4x4
(
const
float
*
packA
,
const
float
*
packB
,
int
K
,
float
*
output
,
int
LDC
,
bool
is_first_k
,
int
m_remain
,
int
n_remain
)
{
const
float
*
a_ptr
=
packA
;
const
float
*
b_ptr
=
packB
;
int
oddk
=
(
K
&
1
);
K
=
((
K
+
1
)
/
2
)
-
1
;
float
*
r0
=
output
;
float
*
r1
=
r0
+
LDC
;
float
*
r2
=
r1
+
LDC
;
float
*
r3
=
r2
+
LDC
;
size_t
d_size
=
sizeof
(
float
);
GI_FLOAT32_t
d0d1
,
d2d3
,
d4d5
,
d6d7
,
d8d9
,
d10d11
,
d12d13
,
d14d15
;
float
tmp
[
4
];
if
(
is_first_k
)
{
d8d9
=
GiBroadcastFloat32
(
0.0
f
);
d10d11
=
GiBroadcastFloat32
(
0.0
f
);
d12d13
=
GiBroadcastFloat32
(
0.0
f
);
d14d15
=
GiBroadcastFloat32
(
0.0
f
);
}
else
{
if
(
m_remain
==
4
)
{
if
(
n_remain
==
4
)
{
d8d9
=
GiLoadFloat32
(
r0
);
d10d11
=
GiLoadFloat32
(
r1
);
d12d13
=
GiLoadFloat32
(
r2
);
d14d15
=
GiLoadFloat32
(
r3
);
}
else
if
(
n_remain
==
3
)
{
memcpy
(
tmp
,
r0
,
d_size
*
3
);
r0
+=
3
;
d8d9
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r1
,
d_size
*
3
);
r1
+=
3
;
d10d11
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r2
,
d_size
*
3
);
r2
+=
3
;
d12d13
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r3
,
d_size
*
3
);
r3
+=
3
;
d14d15
=
GiLoadFloat32
(
tmp
);
}
else
if
(
n_remain
==
2
)
{
memcpy
(
tmp
,
r0
,
d_size
*
2
);
r0
+=
2
;
d8d9
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r1
,
d_size
*
2
);
r1
+=
2
;
d10d11
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r2
,
d_size
*
2
);
r2
+=
2
;
d12d13
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r3
,
d_size
*
2
);
r3
+=
2
;
d14d15
=
GiLoadFloat32
(
tmp
);
}
else
if
(
n_remain
==
1
)
{
tmp
[
0
]
=
*
r0
;
r0
++
;
d8d9
=
GiLoadFloat32
(
tmp
);
tmp
[
0
]
=
*
r1
;
r1
++
;
d10d11
=
GiLoadFloat32
(
tmp
);
tmp
[
0
]
=
*
r2
;
r2
++
;
d12d13
=
GiLoadFloat32
(
tmp
);
tmp
[
0
]
=
*
r3
;
r3
++
;
d14d15
=
GiLoadFloat32
(
tmp
);
}
}
else
if
(
m_remain
==
3
)
{
if
(
n_remain
==
4
)
{
d8d9
=
GiLoadFloat32
(
r0
);
d10d11
=
GiLoadFloat32
(
r1
);
d12d13
=
GiLoadFloat32
(
r2
);
}
else
if
(
n_remain
==
3
)
{
memcpy
(
tmp
,
r0
,
d_size
*
3
);
r0
+=
3
;
d8d9
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r1
,
d_size
*
3
);
r1
+=
3
;
d10d11
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r2
,
d_size
*
3
);
r2
+=
3
;
d12d13
=
GiLoadFloat32
(
tmp
);
}
else
if
(
n_remain
==
2
)
{
memcpy
(
tmp
,
r0
,
d_size
*
2
);
r0
+=
2
;
d8d9
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r1
,
d_size
*
2
);
r1
+=
2
;
d10d11
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r2
,
d_size
*
2
);
r2
+=
2
;
d12d13
=
GiLoadFloat32
(
tmp
);
}
else
if
(
n_remain
==
1
)
{
tmp
[
0
]
=
*
r0
;
r0
++
;
d8d9
=
GiLoadFloat32
(
tmp
);
tmp
[
0
]
=
*
r1
;
r1
++
;
d10d11
=
GiLoadFloat32
(
tmp
);
tmp
[
0
]
=
*
r2
;
r2
++
;
d12d13
=
GiLoadFloat32
(
tmp
);
}
}
else
if
(
m_remain
==
2
)
{
if
(
n_remain
==
4
)
{
d8d9
=
GiLoadFloat32
(
r0
);
d10d11
=
GiLoadFloat32
(
r1
);
}
else
if
(
n_remain
==
3
)
{
memcpy
(
tmp
,
r0
,
d_size
*
3
);
r0
+=
3
;
d8d9
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r1
,
d_size
*
3
);
r1
+=
3
;
d10d11
=
GiLoadFloat32
(
tmp
);
}
else
if
(
n_remain
==
2
)
{
memcpy
(
tmp
,
r0
,
d_size
*
2
);
r0
+=
2
;
d8d9
=
GiLoadFloat32
(
tmp
);
memcpy
(
tmp
,
r1
,
d_size
*
2
);
r1
+=
2
;
d10d11
=
GiLoadFloat32
(
tmp
);
}
else
if
(
n_remain
==
1
)
{
tmp
[
0
]
=
*
r0
;
r0
++
;
d8d9
=
GiLoadFloat32
(
tmp
);
tmp
[
0
]
=
*
r1
;
r1
++
;
d10d11
=
GiLoadFloat32
(
tmp
);
}
}
else
if
(
m_remain
==
1
)
{
if
(
n_remain
==
4
)
{
d8d9
=
GiLoadFloat32
(
r0
);
}
else
if
(
n_remain
==
3
)
{
memcpy
(
tmp
,
r0
,
d_size
*
3
);
r0
+=
3
;
d8d9
=
GiLoadFloat32
(
tmp
);
}
else
if
(
n_remain
==
2
)
{
memcpy
(
tmp
,
r0
,
d_size
*
2
);
r0
+=
2
;
d8d9
=
GiLoadFloat32
(
tmp
);
}
else
if
(
n_remain
==
1
)
{
tmp
[
0
]
=
*
r0
;
r0
++
;
d8d9
=
GiLoadFloat32
(
tmp
);
}
}
}
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
for
(;
K
>
0
;
K
--
)
{
d2d3
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d4d5
,
d0d1
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d4d5
,
d0d1
,
1
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d4d5
,
d0d1
,
2
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d4d5
,
d0d1
,
3
);
d0d1
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d4d5
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d6d7
,
d2d3
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d6d7
,
d2d3
,
1
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d6d7
,
d2d3
,
2
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d6d7
,
d2d3
,
3
);
}
if
(
1
==
oddk
)
{
d8d9
=
GiSimdFmaLane
(
d8d9
,
d4d5
,
d0d1
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d4d5
,
d0d1
,
1
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d4d5
,
d0d1
,
2
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d4d5
,
d0d1
,
3
);
}
else
{
d2d3
=
GiLoadFloat32
(
a_ptr
);
a_ptr
=
a_ptr
+
4
;
d6d7
=
GiLoadFloat32
(
b_ptr
);
b_ptr
=
b_ptr
+
4
;
d8d9
=
GiSimdFmaLane
(
d8d9
,
d4d5
,
d0d1
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d4d5
,
d0d1
,
1
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d4d5
,
d0d1
,
2
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d4d5
,
d0d1
,
3
);
d8d9
=
GiSimdFmaLane
(
d8d9
,
d6d7
,
d2d3
,
0
);
d10d11
=
GiSimdFmaLane
(
d10d11
,
d6d7
,
d2d3
,
1
);
d12d13
=
GiSimdFmaLane
(
d12d13
,
d6d7
,
d2d3
,
2
);
d14d15
=
GiSimdFmaLane
(
d14d15
,
d6d7
,
d2d3
,
3
);
}
if
(
m_remain
==
4
)
{
if
(
n_remain
==
4
)
{
GiStoreFloat32
(
r0
,
d8d9
);
r0
=
r0
+
4
;
GiStoreFloat32
(
r1
,
d10d11
);
r1
=
r1
+
4
;
GiStoreFloat32
(
r2
,
d12d13
);
r2
=
r2
+
4
;
GiStoreFloat32
(
r3
,
d14d15
);
r3
=
r3
+
4
;
}
else
if
(
n_remain
==
3
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
memcpy
(
r0
,
tmp
,
d_size
*
3
);
r0
+=
3
;
GiStoreFloat32
(
tmp
,
d10d11
);
memcpy
(
r1
,
tmp
,
d_size
*
3
);
r1
+=
3
;
GiStoreFloat32
(
tmp
,
d12d13
);
memcpy
(
r2
,
tmp
,
d_size
*
3
);
r2
+=
3
;
GiStoreFloat32
(
tmp
,
d14d15
);
memcpy
(
r3
,
tmp
,
d_size
*
3
);
r3
+=
3
;
}
else
if
(
n_remain
==
2
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
memcpy
(
r0
,
tmp
,
d_size
*
2
);
r0
+=
2
;
GiStoreFloat32
(
tmp
,
d10d11
);
memcpy
(
r1
,
tmp
,
d_size
*
2
);
r1
+=
2
;
GiStoreFloat32
(
tmp
,
d12d13
);
memcpy
(
r2
,
tmp
,
d_size
*
2
);
r2
+=
2
;
GiStoreFloat32
(
tmp
,
d14d15
);
memcpy
(
r3
,
tmp
,
d_size
*
2
);
r3
+=
2
;
}
else
if
(
n_remain
==
1
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
*
r0
=
tmp
[
0
];
r0
++
;
GiStoreFloat32
(
tmp
,
d10d11
);
*
r1
=
tmp
[
0
];
r1
++
;
GiStoreFloat32
(
tmp
,
d12d13
);
*
r2
=
tmp
[
0
];
r2
++
;
GiStoreFloat32
(
tmp
,
d14d15
);
*
r3
=
tmp
[
0
];
r3
++
;
}
}
else
if
(
m_remain
==
3
)
{
if
(
n_remain
==
4
)
{
GiStoreFloat32
(
r0
,
d8d9
);
r0
=
r0
+
4
;
GiStoreFloat32
(
r1
,
d10d11
);
r1
=
r1
+
4
;
GiStoreFloat32
(
r2
,
d12d13
);
r2
=
r2
+
4
;
}
else
if
(
n_remain
==
3
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
memcpy
(
r0
,
tmp
,
d_size
*
3
);
r0
+=
3
;
GiStoreFloat32
(
tmp
,
d10d11
);
memcpy
(
r1
,
tmp
,
d_size
*
3
);
r1
+=
3
;
GiStoreFloat32
(
tmp
,
d12d13
);
memcpy
(
r2
,
tmp
,
d_size
*
3
);
r2
+=
3
;
}
else
if
(
n_remain
==
2
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
memcpy
(
r0
,
tmp
,
d_size
*
2
);
r0
+=
2
;
GiStoreFloat32
(
tmp
,
d10d11
);
memcpy
(
r1
,
tmp
,
d_size
*
2
);
r1
+=
2
;
GiStoreFloat32
(
tmp
,
d12d13
);
memcpy
(
r2
,
tmp
,
d_size
*
2
);
r2
+=
2
;
}
else
if
(
n_remain
==
1
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
*
r0
=
tmp
[
0
];
r0
++
;
GiStoreFloat32
(
tmp
,
d10d11
);
*
r1
=
tmp
[
0
];
r1
++
;
GiStoreFloat32
(
tmp
,
d12d13
);
*
r2
=
tmp
[
0
];
r2
++
;
}
}
else
if
(
m_remain
==
2
)
{
if
(
n_remain
==
4
)
{
GiStoreFloat32
(
r0
,
d8d9
);
r0
=
r0
+
4
;
GiStoreFloat32
(
r1
,
d10d11
);
r1
=
r1
+
4
;
}
else
if
(
n_remain
==
3
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
memcpy
(
r0
,
tmp
,
d_size
*
3
);
r0
+=
3
;
GiStoreFloat32
(
tmp
,
d10d11
);
memcpy
(
r1
,
tmp
,
d_size
*
3
);
r1
+=
3
;
}
else
if
(
n_remain
==
2
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
memcpy
(
r0
,
tmp
,
d_size
*
2
);
r0
+=
2
;
GiStoreFloat32
(
tmp
,
d10d11
);
memcpy
(
r1
,
tmp
,
d_size
*
2
);
r1
+=
2
;
}
else
if
(
n_remain
==
1
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
*
r0
=
tmp
[
0
];
r0
++
;
GiStoreFloat32
(
tmp
,
d10d11
);
*
r1
=
tmp
[
0
];
r1
++
;
}
}
else
if
(
m_remain
==
1
)
{
if
(
n_remain
==
4
)
{
GiStoreFloat32
(
r0
,
d8d9
);
r0
=
r0
+
4
;
}
else
if
(
n_remain
==
3
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
memcpy
(
r0
,
tmp
,
d_size
*
3
);
r0
+=
3
;
}
else
if
(
n_remain
==
2
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
memcpy
(
r0
,
tmp
,
d_size
*
2
);
r0
+=
2
;
}
else
if
(
n_remain
==
1
)
{
GiStoreFloat32
(
tmp
,
d8d9
);
*
r0
=
tmp
[
0
];
r0
++
;
}
}
}
#pragma GCC diagnostic pop
void
gi_sgemm_4x12_pack_A_n
(
float
*
outptr
,
const
float
*
inptr
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
)
{
float
zerobuff
[
4
];
std
::
memset
(
zerobuff
,
0
,
sizeof
(
float
)
*
4
);
int
y
=
y0
;
for
(;
y
+
3
<
ymax
;
y
+=
4
)
{
const
float
*
inptr0
=
inptr
+
y
*
ldin
+
k0
;
const
float
*
inptr1
=
inptr0
+
ldin
;
const
float
*
inptr2
=
inptr1
+
ldin
;
const
float
*
inptr3
=
inptr2
+
ldin
;
int
K
=
(
kmax
-
k0
);
for
(;
K
>
3
;
K
-=
4
)
{
transpose_4x4_1_s
(
inptr0
,
inptr1
,
inptr2
,
inptr3
,
outptr
);
}
interleave_4
(
inptr0
,
inptr1
,
inptr2
,
inptr3
,
outptr
,
1
,
K
);
}
for
(;
y
<
ymax
;
y
+=
4
)
{
const
float
*
inptr0
=
inptr
+
y
*
ldin
+
k0
;
const
float
*
inptr1
=
inptr0
+
ldin
;
const
float
*
inptr2
=
inptr1
+
ldin
;
const
float
*
inptr3
=
inptr2
+
ldin
;
int
K
=
(
kmax
-
k0
);
for
(;
K
>
3
;
K
-=
4
)
{
if
((
y
+
3
)
>=
ymax
)
{
switch
((
y
+
3
)
-
ymax
)
{
/* Everything falls through in here */
case
2
:
inptr1
=
zerobuff
;
MEGDNN_FALLTHRU
case
1
:
inptr2
=
zerobuff
;
MEGDNN_FALLTHRU
case
0
:
inptr3
=
zerobuff
;
break
;
default:
megdnn_assert
(
0
);
}
}
transpose_4x4_1_s
(
inptr0
,
inptr1
,
inptr2
,
inptr3
,
outptr
);
}
if
(
K
>
0
)
{
if
((
y
+
3
)
>=
ymax
)
{
switch
((
y
+
3
)
-
ymax
)
{
/* Everything falls through in here */
case
2
:
inptr1
=
zerobuff
;
MEGDNN_FALLTHRU
case
1
:
inptr2
=
zerobuff
;
MEGDNN_FALLTHRU
case
0
:
inptr3
=
zerobuff
;
break
;
default:
megdnn_assert
(
0
);
}
}
interleave_4
(
inptr0
,
inptr1
,
inptr2
,
inptr3
,
outptr
,
1
,
K
);
}
}
}
void
gi_sgemm_4x12_pack_A_t
(
float
*
out
,
const
float
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
)
{
int
ksize
=
kmax
-
k0
;
int
ksize4
=
(
ksize
<<
2
);
float
*
outptr_base
=
out
;
int
k
=
k0
;
for
(;
k
+
3
<
kmax
;
k
+=
4
)
{
const
float
*
inptr
=
in
+
k
*
ldin
+
x0
;
const
float
*
inptr1
=
inptr
+
ldin
;
const
float
*
inptr2
=
inptr1
+
ldin
;
const
float
*
inptr3
=
inptr2
+
ldin
;
int
x
=
x0
;
auto
outptr
=
outptr_base
;
for
(;
x
+
4
<=
xmax
;
x
+=
4
)
{
auto
outptr_interleave
=
outptr
;
interleave_4x4_1_s
(
inptr
,
inptr1
,
inptr2
,
inptr3
,
outptr_interleave
);
outptr
+=
ksize4
;
}
if
(
x
<
xmax
)
{
interleave_4
(
inptr
,
inptr1
,
inptr2
,
inptr3
,
outptr
,
4
,
xmax
-
x
);
}
outptr_base
+=
4
*
4
;
}
for
(;
k
<
kmax
;
k
++
)
{
const
float
*
inptr
=
in
+
k
*
ldin
+
x0
;
int
x
=
x0
;
auto
outptr
=
outptr_base
;
for
(;
x
+
4
<=
xmax
;
x
+=
4
)
{
auto
outptr_interleave
=
outptr
;
interleave_1x4_1_s
(
inptr
,
outptr_interleave
);
outptr
+=
ksize4
;
}
if
(
x
<
xmax
)
{
interleave_1
(
inptr
,
outptr
,
4
,
xmax
-
x
);
}
outptr_base
+=
4
;
}
}
void
gi_sgemm_4x12_pack_B_n
(
float
*
out
,
const
float
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
)
{
int
ksize
=
kmax
-
k0
;
int
ksize12
=
ksize
*
12
;
int
ksize4
=
(
ksize
<<
2
);
float
*
outptr_base
=
out
;
float
*
outptr_base4
=
outptr_base
+
(
xmax
-
x0
)
/
12
*
ksize12
;
int
k
=
k0
;
for
(;
k
+
3
<
kmax
;
k
+=
4
)
{
const
float
*
inptr
=
in
+
k
*
ldin
+
x0
;
const
float
*
inptr1
=
inptr
+
ldin
;
const
float
*
inptr2
=
inptr1
+
ldin
;
const
float
*
inptr3
=
inptr2
+
ldin
;
int
x
=
x0
;
auto
outptr
=
outptr_base
;
for
(;
x
+
12
<=
xmax
;
x
+=
12
)
{
auto
outptr_interleave
=
outptr
;
interleave_4x12_1_s
(
inptr
,
inptr1
,
inptr2
,
inptr3
,
outptr_interleave
);
outptr
+=
ksize12
;
}
outptr
=
outptr_base4
;
for
(;
x
+
4
<=
xmax
;
x
+=
4
)
{
auto
outptr_interleave
=
outptr
;
interleave_4x4_1_s
(
inptr
,
inptr1
,
inptr2
,
inptr3
,
outptr_interleave
);
outptr
+=
ksize4
;
}
if
(
x
<
xmax
)
{
interleave_4
(
inptr
,
inptr1
,
inptr2
,
inptr3
,
outptr
,
4
,
xmax
-
x
);
}
outptr_base
+=
12
*
4
;
outptr_base4
+=
4
*
4
;
}
for
(;
k
<
kmax
;
k
++
)
{
const
float
*
inptr
=
in
+
k
*
ldin
+
x0
;
int
x
=
x0
;
auto
outptr
=
outptr_base
;
for
(;
x
+
12
<=
xmax
;
x
+=
12
)
{
auto
outptr_interleave
=
outptr
;
interleave_1x12_1_s
(
inptr
,
outptr_interleave
);
outptr
+=
ksize12
;
}
outptr
=
outptr_base4
;
for
(;
x
+
4
<=
xmax
;
x
+=
4
)
{
auto
outptr_interleave
=
outptr
;
interleave_1x4_1_s
(
inptr
,
outptr_interleave
);
outptr
+=
ksize4
;
}
if
(
x
<
xmax
)
{
interleave_1
(
inptr
,
outptr
,
4
,
xmax
-
x
);
}
outptr_base
+=
12
;
outptr_base4
+=
4
;
}
}
void
gi_sgemm_4x12_pack_B_t
(
float
*
out
,
const
float
*
in
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
)
{
float
*
outptr
=
out
;
const
float
*
inptr
=
in
;
float
zerobuff
[
4
];
std
::
memset
(
zerobuff
,
0
,
sizeof
(
float
)
*
4
);
int
K12
=
12
*
(
kmax
-
k0
);
int
y
=
y0
;
for
(;
y
+
12
<=
ymax
;
y
+=
12
)
{
int
yi
=
y
;
for
(;
yi
<
y
+
12
;
yi
+=
4
)
{
const
float
*
inptr0
=
inptr
+
yi
*
ldin
+
k0
;
const
float
*
inptr1
=
inptr0
+
ldin
;
const
float
*
inptr2
=
inptr1
+
ldin
;
const
float
*
inptr3
=
inptr2
+
ldin
;
float
*
outptr_inner
=
outptr
+
yi
-
y
;
int
x
=
(
kmax
-
k0
);
for
(;
x
>
3
;
x
-=
4
)
{
transpose_4x4_1_s
(
inptr0
,
inptr1
,
inptr2
,
inptr3
,
outptr_inner
,
48
);
}
for
(;
x
>
0
;
x
--
)
{
*
outptr_inner
++
=
*
inptr0
++
;
*
outptr_inner
++
=
*
inptr1
++
;
*
outptr_inner
++
=
*
inptr2
++
;
*
outptr_inner
++
=
*
inptr3
++
;
outptr_inner
+=
8
;
}
}
outptr
+=
K12
;
}
for
(;
y
<
ymax
;
y
+=
4
)
{
const
float
*
inptr0
=
inptr
+
y
*
ldin
+
k0
;
const
float
*
inptr1
=
inptr0
+
ldin
;
const
float
*
inptr2
=
inptr1
+
ldin
;
const
float
*
inptr3
=
inptr2
+
ldin
;
/* Cope with ragged cases by copying from a buffer of zeroes instead
*/
int
x
=
(
kmax
-
k0
);
for
(;
x
>
3
;
x
-=
4
)
{
if
((
y
+
3
)
>=
ymax
)
{
switch
((
y
+
3
)
-
ymax
)
{
/* Everything falls through in here */
case
2
:
inptr1
=
zerobuff
;
MEGDNN_FALLTHRU
case
1
:
inptr2
=
zerobuff
;
MEGDNN_FALLTHRU
case
0
:
inptr3
=
zerobuff
;
break
;
default:
megdnn_assert
(
0
);
}
}
transpose_4x4_1_s
(
inptr0
,
inptr1
,
inptr2
,
inptr3
,
outptr
);
}
if
(
x
>
0
)
{
if
((
y
+
3
)
>=
ymax
)
{
switch
((
y
+
3
)
-
ymax
)
{
/* Everything falls through in here */
case
2
:
inptr1
=
zerobuff
;
MEGDNN_FALLTHRU
case
1
:
inptr2
=
zerobuff
;
MEGDNN_FALLTHRU
case
0
:
inptr3
=
zerobuff
;
break
;
default:
megdnn_assert
(
0
);
}
}
interleave_4
(
inptr0
,
inptr1
,
inptr2
,
inptr3
,
outptr
,
1
,
x
);
}
}
}
}
// namespace
MEGDNN_REG_GEMM_STRATEGY_IMPL
(
gi_sgemm_4x12
);
void
gi_sgemm_4x12
::
pack_A
(
float
*
out
,
const
float
*
in
,
int
ldin
,
int
y0
,
int
ymax
,
int
k0
,
int
kmax
,
bool
transpose_A
)
const
{
if
(
transpose_A
)
{
gi_sgemm_4x12_pack_A_t
(
out
,
in
,
ldin
,
y0
,
ymax
,
k0
,
kmax
);
}
else
{
gi_sgemm_4x12_pack_A_n
(
out
,
in
,
ldin
,
y0
,
ymax
,
k0
,
kmax
);
}
}
void
gi_sgemm_4x12
::
pack_B
(
float
*
out
,
const
float
*
in
,
int
ldin
,
int
x0
,
int
xmax
,
int
k0
,
int
kmax
,
bool
transpose_B
)
const
{
if
(
transpose_B
)
{
gi_sgemm_4x12_pack_B_t
(
out
,
in
,
ldin
,
x0
,
xmax
,
k0
,
kmax
);
}
else
{
gi_sgemm_4x12_pack_B_n
(
out
,
in
,
ldin
,
x0
,
xmax
,
k0
,
kmax
);
}
}
void
gi_sgemm_4x12
::
kern
(
const
float
*
packA
,
const
float
*
packB
,
size_t
M
,
size_t
N
,
size_t
K
,
float
*
C
,
size_t
LDC
,
bool
is_first_k
,
const
float
*
,
float
*
)
const
{
megdnn_assert
(
A_dtype
.
enumv
()
==
B_dtype
.
enumv
()
&&
A_dtype
.
enumv
()
==
C_dtype
.
enumv
()
&&
A_dtype
.
enumv
()
==
DTypeEnum
::
Float32
);
MEGDNN_MARK_USED_VAR
(
A_dtype
);
MEGDNN_MARK_USED_VAR
(
B_dtype
);
MEGDNN_MARK_USED_VAR
(
C_dtype
);
constexpr
size_t
A_INTERLEAVE
=
4
;
constexpr
size_t
B_INTERLEAVE
=
12
;
const
int
K12
=
K
*
12
;
const
int
K4
=
K
*
4
;
size_t
m
=
0
;
for
(;
m
<
M
;
m
+=
A_INTERLEAVE
)
{
float
*
output
=
C
+
(
m
*
LDC
);
size_t
n
=
0
;
const
float
*
cur_packB
=
packB
;
for
(;
n
+
B_INTERLEAVE
-
1
<
N
;
n
+=
B_INTERLEAVE
)
{
kern_4x12
(
packA
,
cur_packB
,
K
,
output
,
LDC
,
is_first_k
,
std
::
min
<
size_t
>
(
M
-
m
,
4
));
output
+=
B_INTERLEAVE
;
cur_packB
+=
K12
;
}
for
(;
n
<
N
;
n
+=
4
)
{
kern_4x4
(
packA
,
cur_packB
,
K
,
output
,
LDC
,
is_first_k
,
std
::
min
<
size_t
>
(
M
-
m
,
4
),
std
::
min
<
size_t
>
(
N
-
n
,
4
));
output
+=
4
;
cur_packB
+=
K4
;
}
packA
+=
K4
;
}
}
// vim: syntax=cpp.doxygen
dnn/src/fallback/matrix_mul/opr_impl.cpp
浏览文件 @
f249d387
...
...
@@ -28,16 +28,18 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoNaive
naive
;
AlgoF32GiGemvMK4
f32_gemv_mk4
;
AlgoF32GiMK4_4x8
f32_mk4_4x8
;
AlgoF32Gi4x12
f32_4x8
;
SmallVector
<
AlgoBase
*>
m_all_algos
;
AlgoBase
::
Mapper
m_all_algos_map
;
public:
AlgoPack
()
{
m_all_algos
.
emplace_back
(
&
f32_gemv_mk4
);
m_all_algos
.
emplace_back
(
&
f32_mk4_4x8
);
m_all_algos
.
emplace_back
(
&
f32_4x8
);
m_all_algos
.
emplace_back
(
&
gemv
);
m_all_algos
.
emplace_back
(
&
f32_k8x12x1
);
m_all_algos
.
emplace_back
(
&
naive
);
m_all_algos
.
emplace_back
(
&
f32_gemv_mk4
);
m_all_algos
.
emplace_back
(
&
f32_mk4_4x8
);
for
(
auto
&&
algo
:
m_all_algos
)
{
m_all_algos_map
.
emplace
(
algo
->
info
().
desc
,
algo
);
}
...
...
dnn/src/fallback/matrix_mul/opr_impl.h
浏览文件 @
f249d387
...
...
@@ -103,6 +103,7 @@ public:
FB_NAIVE
,
FB_GI_F32_GEMV_MK4
,
FB_GI_F32_MK4_4x8
,
FB_GI_F32_4x12
,
#if MEGDNN_X86
//! x86
...
...
@@ -232,6 +233,7 @@ private:
class
AlgoF32K8x12x1
;
// Fallback F32 Kernel 8x12x1
class
AlgoF32GiGemvMK4
;
// fallback F32 gi Gemv NCHW44
class
AlgoF32GiMK4_4x8
;
// fallback F32 gi Gemm NCHW44
class
AlgoF32Gi4x12
;
// fallback F32 gi Gemm
class
AlgoGemv
;
class
AlgoNaive
;
class
AlgoPack
;
...
...
dnn/test/fallback/matrix_mul.cpp
浏览文件 @
f249d387
...
...
@@ -42,6 +42,12 @@ TEST_F(FALLBACK, MATRIX_MUL_MK4_GI) {
"FB_GI_F32_MK4_4x8"
,
param
::
MatrixMul
::
Format
::
MK4
,
1
);
}
TEST_F
(
FALLBACK
,
MATRIX_MULF_GI_F32_4x12
)
{
matrix_mul
::
check_matrix_mul
(
dtype
::
Float32
{},
dtype
::
Float32
{},
dtype
::
Float32
{},
handle
(),
"FB_GI_F32_4x12"
);
}
TEST_F
(
FALLBACK
,
MATRIX_MUL_RECORD
)
{
TaskRecordChecker
<
MatrixMul
>
checker
(
1
);
using
Param
=
MatrixMul
::
Param
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录