Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
5d950063
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
5d950063
编写于
6月 10, 2020
作者:
M
Megvii Engine Team
提交者:
Xu Xinran
6月 19, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): refactor dot gemv for both aarch64 and aarch32
GitOrigin-RevId: 2b98867e4563bffe69f0340676e1a6c9ca8d0a2d
上级
53c288a3
变更
17
显示空白变更内容
内联
并排
Showing
17 changed file
with
122 addition
and
239 deletion
+122
-239
dnn/src/aarch64/matrix_mul/algos.cpp
dnn/src/aarch64/matrix_mul/algos.cpp
+0
-34
dnn/src/aarch64/matrix_mul/algos.h
dnn/src/aarch64/matrix_mul/algos.h
+0
-19
dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp
dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp
+0
-116
dnn/src/aarch64/matrix_mul/int8_dot/gemv.h
dnn/src/aarch64/matrix_mul/int8_dot/gemv.h
+0
-34
dnn/src/aarch64/matrix_mul/opr_impl.cpp
dnn/src/aarch64/matrix_mul/opr_impl.cpp
+0
-4
dnn/src/aarch64/matrix_mul/opr_impl.h
dnn/src/aarch64/matrix_mul/opr_impl.h
+0
-2
dnn/src/arm_common/matrix_mul/algos.cpp
dnn/src/arm_common/matrix_mul/algos.cpp
+0
-3
dnn/src/arm_common/matrix_mul/algos.h
dnn/src/arm_common/matrix_mul/algos.h
+0
-5
dnn/src/arm_common/matrix_mul/int8/gemv.cpp
dnn/src/arm_common/matrix_mul/int8/gemv.cpp
+76
-3
dnn/src/arm_common/matrix_mul/int8/gemv.h
dnn/src/arm_common/matrix_mul/int8/gemv.h
+1
-2
dnn/src/arm_common/matrix_mul/opr_impl.cpp
dnn/src/arm_common/matrix_mul/opr_impl.cpp
+2
-1
dnn/src/arm_common/matrix_mul/opr_impl.h
dnn/src/arm_common/matrix_mul/opr_impl.h
+0
-2
dnn/src/arm_common/simd_macro/marm_neon.h
dnn/src/arm_common/simd_macro/marm_neon.h
+13
-0
dnn/src/armv7/matrix_mul/algos.h
dnn/src/armv7/matrix_mul/algos.h
+0
-5
dnn/src/armv7/matrix_mul/opr_impl.cpp
dnn/src/armv7/matrix_mul/opr_impl.cpp
+0
-6
dnn/src/armv7/matrix_mul/opr_impl.h
dnn/src/armv7/matrix_mul/opr_impl.h
+0
-3
dnn/test/arm_common/matrix_mul.cpp
dnn/test/arm_common/matrix_mul.cpp
+30
-0
未找到文件。
dnn/src/aarch64/matrix_mul/algos.cpp
浏览文件 @
5d950063
...
@@ -14,7 +14,6 @@
...
@@ -14,7 +14,6 @@
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/fp32/strategy.h"
#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/aarch64/matrix_mul/int16/strategy.h"
#include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/int8/strategy.h"
#include "src/aarch64/matrix_mul/int8_dot/gemv.h"
#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#include "src/aarch64/matrix_mul/int8_dot/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/int8x8x16/strategy.h"
#include "src/aarch64/matrix_mul/quint8/strategy.h"
#include "src/aarch64/matrix_mul/quint8/strategy.h"
...
@@ -441,39 +440,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
...
@@ -441,39 +440,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
"AlgoInt8x8x32K8x12x4DotProdImpl"
_hash
,
"AlgoInt8x8x32K8x12x4DotProdImpl"
_hash
,
aarch64
::
matmul
::
gemm_s8_8x12
,
int8_t
,
aarch64
::
matmul
::
gemm_s8_8x12
,
int8_t
,
int32_t
);
int32_t
);
/* ===================== Int8x8x32 Gemv DotProd algo ===================== */
namespace
{
void
int8x8x32_gemv_dotprod_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
auto
M
=
kern_param
.
M
,
N
=
kern_param
.
N
,
K
=
kern_param
.
K
;
auto
LDA
=
kern_param
.
LDA
,
LDB
=
kern_param
.
LDB
,
LDC
=
kern_param
.
LDC
;
const
auto
Aptr
=
kern_param
.
A
<
dt_int8
>
(),
Bptr
=
kern_param
.
B
<
dt_int8
>
();
auto
Cptr
=
kern_param
.
C
<
dt_int32
>
();
aarch64
::
matmul
::
gemv_like_int8
(
Aptr
,
Bptr
,
Cptr
,
M
,
N
,
K
,
LDA
,
LDB
,
LDC
);
}
}
// anonymous namespace
bool
MatrixMulImpl
::
AlgoInt8x8x32GemvDotProd
::
usable
(
const
KernSizeParam
&
kern_size_param
)
const
{
return
can_be_treated_as_int8x8x32
(
kern_size_param
)
&&
!
kern_size_param
.
trA
&&
!
kern_size_param
.
trB
&&
kern_size_param
.
N
==
1
&&
kern_size_param
.
LDB
==
1
;
}
bool
MatrixMulImpl
::
AlgoInt8x8x32GemvDotProd
::
preferred
(
const
KernSizeParam
&
kern_size_param
)
const
{
auto
N
=
kern_size_param
.
N
,
LDB
=
kern_size_param
.
LDB
;
return
(
N
==
1
&&
LDB
==
1
);
}
MatrixMulImpl
::
kern_t
MatrixMulImpl
::
AlgoInt8x8x32GemvDotProd
::
get_kern
(
const
KernSizeParam
&
)
const
{
MIDOUT_BEGIN
(
megdnn_aarch64_matmul_kern
,
midout_iv
(
"AlgoInt8x8x32GemvDotProd::get_kern"
_hash
))
{
return
int8x8x32_gemv_dotprod_kern
;
}
MIDOUT_END
();
return
nullptr
;
}
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
namespace
{
namespace
{
...
...
dnn/src/aarch64/matrix_mul/algos.h
浏览文件 @
5d950063
...
@@ -104,21 +104,6 @@ public:
...
@@ -104,21 +104,6 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
};
class
MatrixMulImpl
::
AlgoInt8x8x32GemvDotProd
final
:
public
AlgoBase
{
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"AARCH64_INT8X8X32_GEMV_DOTPROD"
;
}
bool
usable
(
const
KernSizeParam
&
)
const
override
;
bool
preferred
(
const
KernSizeParam
&
)
const
override
;
size_t
get_workspace
(
const
KernSizeParam
&
)
const
override
{
return
0
;
}
kern_t
get_kern
(
const
KernSizeParam
&
)
const
override
;
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
};
class
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x12x4DotProd
final
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoInt8x8x32MK4_8x12x4DotProd
final
:
public
AlgoBase
{
public:
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
bool
is_reproducible
()
const
override
{
return
true
;
}
...
@@ -174,10 +159,6 @@ public:
...
@@ -174,10 +159,6 @@ public:
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
void
*
type
()
const
override
{
return
sm_arm_common_algo_type
;
}
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
};
class
MatrixMulImpl
::
AlgoInt8x8x32Gemv
final
:
public
arm_common
::
MatrixMulImpl
::
AlgoInt8x8x32Gemv
{};
#endif
#endif
class
MatrixMulImpl
::
AlgoInt8x8x16K8x8x8
final
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoInt8x8x16K8x8x8
final
:
public
AlgoBase
{
...
...
dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp
已删除
100644 → 0
浏览文件 @
53c288a3
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/aarch64/matrix_mul/int8_dot/gemv.h"
#include <cstddef>
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/common/utils.h"
#include "src/common/unroll_macro.h"
#if __ARM_FEATURE_DOTPROD
namespace
{
void
gemv_naive_n
(
const
int8_t
*
__restrict
A
,
const
int8_t
*
__restrict
B
,
int32_t
*
__restrict
C
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
Astride
,
size_t
Bstride
,
size_t
Cstride
)
{
megdnn_assert
(
N
==
1
&&
Bstride
==
1
);
size_t
m
=
0
;
for
(;
m
+
2
<=
M
;
m
+=
2
)
{
int32_t
acc
[
4
];
int32x4_t
acc_neon
=
vdupq_n_s32
(
0
);
size_t
k
=
0
;
for
(;
k
+
16
<=
K
;
k
+=
16
)
{
int64x2_t
a0
=
vreinterpretq_s64_s8
(
vld1q_s8
(
A
+
m
*
Astride
+
k
));
int64x2_t
a1
=
vreinterpretq_s64_s8
(
vld1q_s8
(
A
+
(
m
+
1
)
*
Astride
+
k
));
//! the first 8 elements is m, the last 8 elements is m + 1
int8x16_t
a2
=
vreinterpretq_s8_s64
(
vzip1q_s64
(
a0
,
a1
));
int8x16_t
a3
=
vreinterpretq_s8_s64
(
vzip2q_s64
(
a0
,
a1
));
int64x2_t
b0
=
vreinterpretq_s64_s8
(
vld1q_s8
(
B
+
k
));
int8x16_t
b2
=
vreinterpretq_s8_s64
(
vzip1q_s64
(
b0
,
b0
));
int8x16_t
b3
=
vreinterpretq_s8_s64
(
vzip2q_s64
(
b0
,
b0
));
acc_neon
=
vdotq_s32
(
acc_neon
,
a2
,
b2
);
acc_neon
=
vdotq_s32
(
acc_neon
,
a3
,
b3
);
}
vst1q_s32
(
acc
,
acc_neon
);
for
(;
k
+
8
<=
K
;
k
+=
8
)
{
int8x8_t
a0
=
vld1_s8
(
A
+
m
*
Astride
+
k
);
int8x8_t
a1
=
vld1_s8
(
A
+
(
m
+
1
)
*
Astride
+
k
);
int8x8_t
b0
=
vld1_s8
(
B
+
k
);
uint32x2_t
zero
=
vdup_n_s32
(
0
);
acc
[
0
]
+=
vaddv_s32
(
vdot_s32
(
zero
,
a0
,
b0
));
zero
=
vdup_n_s32
(
0
);
acc
[
3
]
+=
vaddv_s32
(
vdot_s32
(
zero
,
a1
,
b0
));
}
for
(;
k
<
K
;
++
k
)
{
acc
[
0
]
+=
static_cast
<
int32_t
>
(
A
[
m
*
Astride
+
k
])
*
B
[
k
];
acc
[
3
]
+=
static_cast
<
int32_t
>
(
A
[(
m
+
1
)
*
Astride
+
k
])
*
B
[
k
];
}
C
[
m
*
Cstride
]
=
acc
[
0
]
+
acc
[
1
];
C
[(
m
+
1
)
*
Cstride
]
=
acc
[
2
]
+
acc
[
3
];
}
for
(;
m
<
M
;
++
m
)
{
int32_t
acc
[
4
];
int32x4_t
acc_neon
=
vdupq_n_s32
(
0
);
size_t
k
=
0
;
for
(;
k
+
16
<=
K
;
k
+=
16
)
{
int8x16_t
a0
=
vld1q_s8
(
A
+
m
*
Astride
+
k
);
int8x16_t
b0
=
vld1q_s8
(
B
+
k
);
acc_neon
=
vdotq_s32
(
acc_neon
,
a0
,
b0
);
}
vst1q_s32
(
acc
,
acc_neon
);
for
(;
k
+
8
<=
K
;
k
+=
8
)
{
int8x8_t
a0
=
vld1_s8
(
A
+
m
*
Astride
+
k
);
int8x8_t
b0
=
vld1_s8
(
B
+
k
);
uint32x2_t
zero
=
vdup_n_s32
(
0
);
acc
[
0
]
+=
vaddv_s32
(
vdot_s32
(
zero
,
a0
,
b0
));
}
for
(;
k
<
K
;
++
k
)
{
acc
[
0
]
+=
static_cast
<
int32_t
>
(
A
[
m
*
Astride
+
k
])
*
B
[
k
];
}
C
[
m
*
Cstride
]
=
acc
[
0
]
+
acc
[
1
]
+
acc
[
2
]
+
acc
[
3
];
}
}
}
// namespace
bool
megdnn
::
aarch64
::
matmul
::
is_gemv_like_preferred_int8
(
bool
transposeA
,
bool
transposeB
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
/* LDA */
,
size_t
LDB
,
size_t
/* LDC */
)
{
if
(
transposeA
)
return
false
;
if
(
transposeB
)
return
false
;
MEGDNN_MARK_USED_VAR
(
K
);
MEGDNN_MARK_USED_VAR
(
M
);
return
(
N
==
1
&&
LDB
==
1
);
}
void
megdnn
::
aarch64
::
matmul
::
gemv_like_int8
(
const
int8_t
*
__restrict
A
,
const
int8_t
*
__restrict
B
,
int32_t
*
__restrict
C
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
Astride
,
size_t
Bstride
,
size_t
Cstride
)
{
megdnn_assert
(
N
==
1
);
return
gemv_naive_n
(
A
,
B
,
C
,
M
,
N
,
K
,
Astride
,
Bstride
,
Cstride
);
}
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/int8_dot/gemv.h
已删除
100644 → 0
浏览文件 @
53c288a3
/**
* \file dnn/src/aarch64/matrix_mul/int8_dot/gemv.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include <cstddef>
#include <cstdint>
#if __ARM_FEATURE_DOTPROD
namespace
megdnn
{
namespace
aarch64
{
namespace
matmul
{
bool
is_gemv_like_preferred_int8
(
bool
transposeA
,
bool
transposeB
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
LDA
,
size_t
LDB
,
size_t
LDC
);
void
gemv_like_int8
(
const
int8_t
*
__restrict
A
,
const
int8_t
*
__restrict
B
,
int32_t
*
__restrict
C
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
Astride
,
size_t
Bstride
,
size_t
Cstride
);
}
// namespace matmul
}
// namespace aarch64
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
dnn/src/aarch64/matrix_mul/opr_impl.cpp
浏览文件 @
5d950063
...
@@ -28,13 +28,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
...
@@ -28,13 +28,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#endif
#endif
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
AlgoInt8x8x32K8x12x4DotProd
int8x8x32_k8x12x4_dotprod
;
AlgoInt8x8x32K8x12x4DotProd
int8x8x32_k8x12x4_dotprod
;
AlgoInt8x8x32GemvDotProd
int8x8x32_gemv_dotprod
;
AlgoInt8x8x32MK4_8x12x4DotProd
int8x8x32_mk4_8x12x4_dotprod
;
AlgoInt8x8x32MK4_8x12x4DotProd
int8x8x32_mk4_8x12x4_dotprod
;
#else
#else
AlgoInt8x8x32MK4_4x4x16
int8x8x32_mk4_4x4x16
;
AlgoInt8x8x32MK4_4x4x16
int8x8x32_mk4_4x4x16
;
AlgoInt8x8x32K4x4x16
int8x8x32_k4x4x16
;
AlgoInt8x8x32K4x4x16
int8x8x32_k4x4x16
;
AlgoInt8x8x32K8x8x8
int8x8x32_k8x8x8
;
AlgoInt8x8x32K8x8x8
int8x8x32_k8x8x8
;
AlgoInt8x8x32Gemv
int8x8x32_gemv
;
#endif
#endif
AlgoInt8x8x16K8x8x8
int8x8x16_k8x8x8
;
AlgoInt8x8x16K8x8x8
int8x8x16_k8x8x8
;
AlgoInt8x8x16K4x4x16
int8x8x16_k4x4x16
;
AlgoInt8x8x16K4x4x16
int8x8x16_k4x4x16
;
...
@@ -63,11 +61,9 @@ public:
...
@@ -63,11 +61,9 @@ public:
all_algos
.
emplace_back
(
&
f16_mk8_8x8
);
all_algos
.
emplace_back
(
&
f16_mk8_8x8
);
#endif
#endif
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
all_algos
.
emplace_back
(
&
int8x8x32_gemv_dotprod
);
all_algos
.
emplace_back
(
&
int8x8x32_k8x12x4_dotprod
);
all_algos
.
emplace_back
(
&
int8x8x32_k8x12x4_dotprod
);
all_algos
.
emplace_back
(
&
int8x8x32_mk4_8x12x4_dotprod
);
all_algos
.
emplace_back
(
&
int8x8x32_mk4_8x12x4_dotprod
);
#else
#else
all_algos
.
emplace_back
(
&
int8x8x32_gemv
);
all_algos
.
emplace_back
(
&
int8x8x32_k4x4x16
);
all_algos
.
emplace_back
(
&
int8x8x32_k4x4x16
);
all_algos
.
emplace_back
(
&
int8x8x32_k8x8x8
);
all_algos
.
emplace_back
(
&
int8x8x32_k8x8x8
);
all_algos
.
emplace_back
(
&
int8x8x32_mk4_4x4x16
);
all_algos
.
emplace_back
(
&
int8x8x32_mk4_4x4x16
);
...
...
dnn/src/aarch64/matrix_mul/opr_impl.h
浏览文件 @
5d950063
...
@@ -34,14 +34,12 @@ private:
...
@@ -34,14 +34,12 @@ private:
#if __ARM_FEATURE_DOTPROD
#if __ARM_FEATURE_DOTPROD
class
AlgoInt8x8x32K8x12x4DotProd
;
// Aarch64 Int8x8x32 Kernel
class
AlgoInt8x8x32K8x12x4DotProd
;
// Aarch64 Int8x8x32 Kernel
// 8x12x4 DotProduct
// 8x12x4 DotProduct
class
AlgoInt8x8x32GemvDotProd
;
// Aarch64 Int8x8x32 Gemv DotProduct
class
AlgoInt8x8x32MK4_8x12x4DotProd
;
// Aarch64 nchw44 Int8x8x32 Kernel
class
AlgoInt8x8x32MK4_8x12x4DotProd
;
// Aarch64 nchw44 Int8x8x32 Kernel
// 8x12x4 DotProduct
// 8x12x4 DotProduct
#else
#else
class
AlgoInt8x8x32MK4_4x4x16
;
// Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
class
AlgoInt8x8x32MK4_4x4x16
;
// Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
class
AlgoInt8x8x32K4x4x16
;
// Aarch64 Int8x8x32 Kernel 4x4x16
class
AlgoInt8x8x32K4x4x16
;
// Aarch64 Int8x8x32 Kernel 4x4x16
class
AlgoInt8x8x32K8x8x8
;
// Aarch64 Int8x8x32 Kernel 8x8x8
class
AlgoInt8x8x32K8x8x8
;
// Aarch64 Int8x8x32 Kernel 8x8x8
class
AlgoInt8x8x32Gemv
;
// Aarch64 Int8x8x32 Gemv
#endif
#endif
class
AlgoInt8x8x16K8x8x8
;
// Aarch64 Int8x8x16 Kernel 8x8x8
class
AlgoInt8x8x16K8x8x8
;
// Aarch64 Int8x8x16 Kernel 8x8x8
class
AlgoInt8x8x16K4x4x16
;
// Aarch64 Int8x8x16 Kernel 4x4x16
class
AlgoInt8x8x16K4x4x16
;
// Aarch64 Int8x8x16 Kernel 4x4x16
...
...
dnn/src/arm_common/matrix_mul/algos.cpp
浏览文件 @
5d950063
...
@@ -72,7 +72,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
...
@@ -72,7 +72,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16::get_kern(
return
exec_int_8x8x16
;
return
exec_int_8x8x16
;
}
}
#if !__ARM_FEATURE_DOTPROD
/* ===================== Int8x8x32 Gemv algo ===================== */
/* ===================== Int8x8x32 Gemv algo ===================== */
namespace
{
namespace
{
void
int8x8x32_gemv_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
void
int8x8x32_gemv_kern
(
const
MatrixMulImpl
::
KernParam
&
kern_param
)
{
...
@@ -102,7 +101,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
...
@@ -102,7 +101,6 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Gemv::get_kern(
const
KernSizeParam
&
)
const
{
const
KernSizeParam
&
)
const
{
return
int8x8x32_gemv_kern
;
return
int8x8x32_gemv_kern
;
}
}
#endif
/* ===================== F32 Gemv algo ===================== */
/* ===================== F32 Gemv algo ===================== */
namespace
{
namespace
{
...
@@ -112,7 +110,6 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
...
@@ -112,7 +110,6 @@ void f32_gemv_kern(const MatrixMulImpl::KernParam& kern_param) {
const
auto
Aptr
=
kern_param
.
A
<
dt_float32
>
(),
const
auto
Aptr
=
kern_param
.
A
<
dt_float32
>
(),
Bptr
=
kern_param
.
B
<
dt_float32
>
();
Bptr
=
kern_param
.
B
<
dt_float32
>
();
auto
Cptr
=
kern_param
.
C
<
dt_float32
>
();
auto
Cptr
=
kern_param
.
C
<
dt_float32
>
();
arm_common
::
sgemm_sgemv_like
(
Aptr
,
Bptr
,
Cptr
,
M
,
N
,
K
,
LDA
,
LDB
,
LDC
);
arm_common
::
sgemm_sgemv_like
(
Aptr
,
Bptr
,
Cptr
,
M
,
N
,
K
,
LDA
,
LDB
,
LDC
);
}
}
}
// anonymous namespace
}
// anonymous namespace
...
...
dnn/src/arm_common/matrix_mul/algos.h
浏览文件 @
5d950063
...
@@ -27,11 +27,7 @@ public:
...
@@ -27,11 +27,7 @@ public:
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
};
};
#if !__ARM_FEATURE_DOTPROD
class
MatrixMulImpl
::
AlgoInt8x8x32Gemv
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoInt8x8x32Gemv
:
public
AlgoBase
{
protected:
~
AlgoInt8x8x32Gemv
()
=
default
;
public:
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
bool
is_reproducible
()
const
override
{
return
true
;
}
const
char
*
name
()
const
override
{
return
"ARM_COMMON_INT8X8X32_GEMV"
;
}
const
char
*
name
()
const
override
{
return
"ARM_COMMON_INT8X8X32_GEMV"
;
}
...
@@ -43,7 +39,6 @@ public:
...
@@ -43,7 +39,6 @@ public:
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
AlgoSet
algoset
()
const
override
{
return
AlgoSet
::
ALGO_TYPE_GEMV
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
PackMode
packmode
()
const
override
{
return
PackMode
::
NO_PACK
;
}
};
};
#endif
class
MatrixMulImpl
::
AlgoF32Gemv
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoF32Gemv
:
public
AlgoBase
{
protected:
protected:
...
...
dnn/src/arm_common/matrix_mul/int8/gemv.cpp
浏览文件 @
5d950063
...
@@ -9,8 +9,6 @@
...
@@ -9,8 +9,6 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
*/
#if !__ARM_FEATURE_DOTPROD
#include <cstddef>
#include <cstddef>
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/arm_common/matrix_mul/int8/gemv.h"
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/simd_macro/marm_neon.h"
...
@@ -23,6 +21,8 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv)
...
@@ -23,6 +21,8 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv)
using
namespace
megdnn
;
using
namespace
megdnn
;
using
namespace
arm_common
;
using
namespace
arm_common
;
#if !__ARM_FEATURE_DOTPROD
namespace
{
namespace
{
void
gemv_naive_n
(
const
int8_t
*
__restrict
A
,
const
int8_t
*
__restrict
B
,
void
gemv_naive_n
(
const
int8_t
*
__restrict
A
,
const
int8_t
*
__restrict
B
,
...
@@ -95,8 +95,82 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
...
@@ -95,8 +95,82 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
C
[
m
*
Cstride
]
=
acc0
;
C
[
m
*
Cstride
]
=
acc0
;
}
}
}
}
}
// namespace
#endif
#if __ARM_FEATURE_DOTPROD
namespace
{
void
gemv_naive_n
(
const
int8_t
*
__restrict
A
,
const
int8_t
*
__restrict
B
,
int32_t
*
__restrict
C
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
Astride
,
size_t
Bstride
,
size_t
Cstride
)
{
megdnn_assert
(
N
==
1
&&
Bstride
==
1
);
size_t
m
=
0
;
for
(;
m
+
2
<=
M
;
m
+=
2
)
{
int32_t
acc
[
4
];
int32x4_t
acc_neon
=
vdupq_n_s32
(
0
);
size_t
k
=
0
;
for
(;
k
+
16
<=
K
;
k
+=
16
)
{
int64x2_t
a0
=
vreinterpretq_s64_s8
(
vld1q_s8
(
A
+
m
*
Astride
+
k
));
int64x2_t
a1
=
vreinterpretq_s64_s8
(
vld1q_s8
(
A
+
(
m
+
1
)
*
Astride
+
k
));
//! the first 8 elements is m, the last 8 elements is m + 1
int8x16_t
a2
=
vreinterpretq_s8_s64
(
vzip1q_s64
(
a0
,
a1
));
int8x16_t
a3
=
vreinterpretq_s8_s64
(
vzip2q_s64
(
a0
,
a1
));
int64x2_t
b0
=
vreinterpretq_s64_s8
(
vld1q_s8
(
B
+
k
));
int8x16_t
b2
=
vreinterpretq_s8_s64
(
vzip1q_s64
(
b0
,
b0
));
int8x16_t
b3
=
vreinterpretq_s8_s64
(
vzip2q_s64
(
b0
,
b0
));
acc_neon
=
vdotq_s32
(
acc_neon
,
a2
,
b2
);
acc_neon
=
vdotq_s32
(
acc_neon
,
a3
,
b3
);
}
vst1q_s32
(
acc
,
acc_neon
);
for
(;
k
+
8
<=
K
;
k
+=
8
)
{
int8x8_t
a0
=
vld1_s8
(
A
+
m
*
Astride
+
k
);
int8x8_t
a1
=
vld1_s8
(
A
+
(
m
+
1
)
*
Astride
+
k
);
int8x8_t
b0
=
vld1_s8
(
B
+
k
);
uint32x2_t
zero
=
vdup_n_s32
(
0
);
acc
[
0
]
+=
vaddv_s32
(
vdot_s32
(
zero
,
a0
,
b0
));
zero
=
vdup_n_s32
(
0
);
acc
[
3
]
+=
vaddv_s32
(
vdot_s32
(
zero
,
a1
,
b0
));
}
for
(;
k
<
K
;
++
k
)
{
acc
[
0
]
+=
static_cast
<
int32_t
>
(
A
[
m
*
Astride
+
k
])
*
B
[
k
];
acc
[
3
]
+=
static_cast
<
int32_t
>
(
A
[(
m
+
1
)
*
Astride
+
k
])
*
B
[
k
];
}
C
[
m
*
Cstride
]
=
acc
[
0
]
+
acc
[
1
];
C
[(
m
+
1
)
*
Cstride
]
=
acc
[
2
]
+
acc
[
3
];
}
for
(;
m
<
M
;
++
m
)
{
int32_t
acc
[
4
];
int32x4_t
acc_neon
=
vdupq_n_s32
(
0
);
size_t
k
=
0
;
for
(;
k
+
16
<=
K
;
k
+=
16
)
{
int8x16_t
a0
=
vld1q_s8
(
A
+
m
*
Astride
+
k
);
int8x16_t
b0
=
vld1q_s8
(
B
+
k
);
acc_neon
=
vdotq_s32
(
acc_neon
,
a0
,
b0
);
}
vst1q_s32
(
acc
,
acc_neon
);
for
(;
k
+
8
<=
K
;
k
+=
8
)
{
int8x8_t
a0
=
vld1_s8
(
A
+
m
*
Astride
+
k
);
int8x8_t
b0
=
vld1_s8
(
B
+
k
);
uint32x2_t
zero
=
vdup_n_s32
(
0
);
acc
[
0
]
+=
vaddv_s32
(
vdot_s32
(
zero
,
a0
,
b0
));
}
for
(;
k
<
K
;
++
k
)
{
acc
[
0
]
+=
static_cast
<
int32_t
>
(
A
[
m
*
Astride
+
k
])
*
B
[
k
];
}
C
[
m
*
Cstride
]
=
acc
[
0
]
+
acc
[
1
]
+
acc
[
2
]
+
acc
[
3
];
}
}
}
// namespace
}
// namespace
#endif
bool
matmul
::
is_gemv_like_preferred_int8
(
bool
transposeA
,
bool
transposeB
,
bool
matmul
::
is_gemv_like_preferred_int8
(
bool
transposeA
,
bool
transposeB
,
size_t
M
,
size_t
N
,
size_t
K
,
size_t
M
,
size_t
N
,
size_t
K
,
...
@@ -124,6 +198,5 @@ void matmul::gemv_like_int8(const int8_t* __restrict A,
...
@@ -124,6 +198,5 @@ void matmul::gemv_like_int8(const int8_t* __restrict A,
}
MIDOUT_END
();
}
MIDOUT_END
();
}
}
#endif
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/arm_common/matrix_mul/int8/gemv.h
浏览文件 @
5d950063
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
#include <cstddef>
#include <cstddef>
#include <cstdint>
#include <cstdint>
#if !__ARM_FEATURE_DOTPROD
namespace
megdnn
{
namespace
megdnn
{
namespace
arm_common
{
namespace
arm_common
{
namespace
matmul
{
namespace
matmul
{
...
@@ -28,6 +27,6 @@ void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
...
@@ -28,6 +27,6 @@ void gemv_like_int8(const int8_t* __restrict A, const int8_t* __restrict B,
}
// namespace matmul
}
// namespace matmul
}
// namespace arm_common
}
// namespace arm_common
}
// namespace megdnn
}
// namespace megdnn
#endif
// vim: syntax=cpp.doxygen
// vim: syntax=cpp.doxygen
dnn/src/arm_common/matrix_mul/opr_impl.cpp
浏览文件 @
5d950063
...
@@ -27,13 +27,14 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
...
@@ -27,13 +27,14 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
AlgoF16Gemv
f16gemv
;
AlgoF16Gemv
f16gemv
;
#endif
#endif
AlgoInt8x8x32Gemv
int8x8x32_gemv
;
public:
public:
AlgoPack
()
{
AlgoPack
()
{
all_algos
.
emplace_back
(
&
int8x8x16
);
all_algos
.
emplace_back
(
&
int8x8x16
);
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
all_algos
.
emplace_back
(
&
f16gemv
);
all_algos
.
emplace_back
(
&
f16gemv
);
#endif
#endif
all_algos
.
emplace_back
(
&
int8x8x32_gemv
);
}
}
SmallVector
<
AlgoBase
*>
all_algos
;
SmallVector
<
AlgoBase
*>
all_algos
;
};
};
...
...
dnn/src/arm_common/matrix_mul/opr_impl.h
浏览文件 @
5d950063
...
@@ -25,9 +25,7 @@ public:
...
@@ -25,9 +25,7 @@ public:
protected:
protected:
static
void
*
const
sm_arm_common_algo_type
;
static
void
*
const
sm_arm_common_algo_type
;
#if !__ARM_FEATURE_DOTPROD
class
AlgoInt8x8x32Gemv
;
// Arm_common Int 8x8x32 Gemv
class
AlgoInt8x8x32Gemv
;
// Arm_common Int 8x8x32 Gemv
#endif
class
AlgoF32Gemv
;
// Arm_common F32 Gemv
class
AlgoF32Gemv
;
// Arm_common F32 Gemv
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
class
AlgoF16Gemv
;
class
AlgoF16Gemv
;
...
...
dnn/src/arm_common/simd_macro/marm_neon.h
浏览文件 @
5d950063
...
@@ -388,6 +388,19 @@ __ai int64x2_t vmovl_high_s32(int32x4_t __p0) {
...
@@ -388,6 +388,19 @@ __ai int64x2_t vmovl_high_s32(int32x4_t __p0) {
__ai
uint64x2_t
vmovl_high_u32
(
uint32x4_t
__p0
)
{
__ai
uint64x2_t
vmovl_high_u32
(
uint32x4_t
__p0
)
{
return
vmovl_u32
(
vget_high_u32
(
__p0
));
return
vmovl_u32
(
vget_high_u32
(
__p0
));
}
}
__ai
int64x2_t
vzip1q_s64
(
int64x2_t
&
a
,
int64x2_t
&
b
)
{
return
vcombine_s64
(
vget_low_s64
(
a
),
vget_low_s64
(
b
));
}
__ai
int64x2_t
vzip2q_s64
(
int64x2_t
&
a
,
int64x2_t
&
b
)
{
return
vcombine_s64
(
vget_high_s64
(
a
),
vget_high_s64
(
b
));
}
__ai
int32_t
vaddv_s32
(
int32x2_t
a
)
{
return
vget_lane_s32
(
a
,
0
)
+
vget_lane_s32
(
a
,
1
);
}
#endif // MEGDNN_ARMV7
#endif // MEGDNN_ARMV7
//! pack vmovl_low_xx() on armv7 and armv8
//! pack vmovl_low_xx() on armv7 and armv8
...
...
dnn/src/armv7/matrix_mul/algos.h
浏览文件 @
5d950063
...
@@ -134,11 +134,6 @@ public:
...
@@ -134,11 +134,6 @@ public:
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL
();
};
};
#if !__ARM_FEATURE_DOTPROD
class
MatrixMulImpl
::
AlgoInt8x8x32Gemv
final
:
public
arm_common
::
MatrixMulImpl
::
AlgoInt8x8x32Gemv
{};
#endif
class
MatrixMulImpl
::
AlgoQuint8K4x8x8
final
:
public
AlgoBase
{
class
MatrixMulImpl
::
AlgoQuint8K4x8x8
final
:
public
AlgoBase
{
public:
public:
bool
is_reproducible
()
const
override
{
return
true
;
}
bool
is_reproducible
()
const
override
{
return
true
;
}
...
...
dnn/src/armv7/matrix_mul/opr_impl.cpp
浏览文件 @
5d950063
...
@@ -35,9 +35,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
...
@@ -35,9 +35,6 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
AlgoInt8x8x32MK4_4x2x16
int8x8x32_mk4_4x2x16
;
AlgoInt8x8x32MK4_4x2x16
int8x8x32_mk4_4x2x16
;
AlgoInt8x8x32K4x2x16
int8x8x32_k4x2x16
;
AlgoInt8x8x32K4x2x16
int8x8x32_k4x2x16
;
AlgoInt8x8x32K4x8x8
int8x8x32_k4x8x8
;
AlgoInt8x8x32K4x8x8
int8x8x32_k4x8x8
;
#if !__ARM_FEATURE_DOTPROD
AlgoInt8x8x32Gemv
int8x8x32_gemv
;
#endif
AlgoQuint8K4x8x8
quint8_k4x8x8
;
AlgoQuint8K4x8x8
quint8_k4x8x8
;
AlgoInt8x8x16K4x2x16
int8x8x16_k4x2x16
;
AlgoInt8x8x16K4x2x16
int8x8x16_k4x2x16
;
AlgoInt8x8x16K4x8x8
int8x8x16_k4x8x8
;
AlgoInt8x8x16K4x8x8
int8x8x16_k4x8x8
;
...
@@ -60,9 +57,6 @@ public:
...
@@ -60,9 +57,6 @@ public:
all_algos
.
emplace_back
(
&
int8x8x32_mk4_8x4x4_dotprod
);
all_algos
.
emplace_back
(
&
int8x8x32_mk4_8x4x4_dotprod
);
all_algos
.
emplace_back
(
&
int8_k6x8x4
);
all_algos
.
emplace_back
(
&
int8_k6x8x4
);
all_algos
.
emplace_back
(
&
quint8_k4x8x4
);
all_algos
.
emplace_back
(
&
quint8_k4x8x4
);
#endif
#if !__ARM_FEATURE_DOTPROD
all_algos
.
emplace_back
(
&
int8x8x32_gemv
);
#endif
#endif
all_algos
.
emplace_back
(
&
int8x8x32_mk4_4x2x16
);
all_algos
.
emplace_back
(
&
int8x8x32_mk4_4x2x16
);
all_algos
.
emplace_back
(
&
int8x8x32_k4x2x16
);
all_algos
.
emplace_back
(
&
int8x8x32_k4x2x16
);
...
...
dnn/src/armv7/matrix_mul/opr_impl.h
浏览文件 @
5d950063
...
@@ -27,9 +27,6 @@ private:
...
@@ -27,9 +27,6 @@ private:
class
AlgoInt8x8x32K4x8x8
;
// Armv7 Int8x8x32 Kernel 4x8x8
class
AlgoInt8x8x32K4x8x8
;
// Armv7 Int8x8x32 Kernel 4x8x8
class
AlgoInt8x8x32K4x2x16
;
// Armv7 Int8x8x32 Kernel 4x2x16
class
AlgoInt8x8x32K4x2x16
;
// Armv7 Int8x8x32 Kernel 4x2x16
class
AlgoInt8x8x32MK4_4x2x16
;
// Armv7 Int8x8x32 Kernel MK4 4x2x16
class
AlgoInt8x8x32MK4_4x2x16
;
// Armv7 Int8x8x32 Kernel MK4 4x2x16
#if !__ARM_FEATURE_DOTPROD
class
AlgoInt8x8x32Gemv
;
// Armv7 Int8x8x32 Gemv
#endif
class
AlgoQuint8K4x8x8
;
// Armv7 Quint8 Kernel 4x8x8
class
AlgoQuint8K4x8x8
;
// Armv7 Quint8 Kernel 4x8x8
class
AlgoInt8x8x16K4x2x16
;
// Armv7 Int8x8x16 Kernel 4x2x16
class
AlgoInt8x8x16K4x2x16
;
// Armv7 Int8x8x16 Kernel 4x2x16
class
AlgoInt8x8x16K4x8x8
;
// Armv7 Int8x8x16 Kernel 4x8x8
class
AlgoInt8x8x16K4x8x8
;
// Armv7 Int8x8x16 Kernel 4x8x8
...
...
dnn/test/arm_common/matrix_mul.cpp
浏览文件 @
5d950063
...
@@ -133,6 +133,36 @@ TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) {
...
@@ -133,6 +133,36 @@ TEST_F(ARM_COMMON, MATRIX_MUL_FP16_TEST) {
}
}
#endif
#endif
TEST_F
(
ARM_COMMON
,
QINT8x8x32_GEMV
)
{
Checker
<
MatrixMul
>
checker
(
handle
());
using
Param
=
MatrixMul
::
Param
;
checker
.
set_before_exec_callback
(
AlgoChecker
<
MatrixMul
>
(
"ARM_COMMON_INT8X8X32_GEMV"
));
std
::
unique_ptr
<
RNG
>
rng
=
std
::
make_unique
<
UniformIntRNG
>
(
-
127
,
127
);
checker
.
set_rng
(
0
,
rng
.
get
()).
set_rng
(
1
,
rng
.
get
());
auto
run
=
[
&
](
size_t
M
,
size_t
K
,
size_t
N
)
{
Param
param
;
param
.
transposeA
=
false
;
param
.
transposeB
=
false
;
TensorShape
A
,
B
;
A
=
TensorShape
{
M
,
K
};
B
=
TensorShape
{
K
,
N
};
checker
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
1
,
dtype
::
QuantizedS8
(
2.5
f
))
.
set_dtype
(
2
,
dtype
::
QuantizedS32
(
6.25
f
))
.
execs
({
A
,
B
,
{}});
};
// N = 1
for
(
size_t
M
:
{
1
,
10
,
16
,
33
,
64
})
for
(
size_t
K
:
{
7
,
512
,
1024
})
for
(
size_t
N
:
{
1
})
run
(
M
,
K
,
N
);
}
#if MEGDNN_WITH_BENCHMARK
#if MEGDNN_WITH_BENCHMARK
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录