Xiaomi / Mace
Commit 57a3298d
Authored Sep 11, 2018 by 李滨 (Li Bin)

Merge branch 'pack' into 'master'

Pack matmul to improve performance

See merge request !789

Parents: b50a6635, 26592a86
Showing 12 changed files with 1,586 additions and 139 deletions (+1586, -139).
.gitlab-ci.yml                                +1     -1
mace/core/tensor.h                            +5     -1
mace/core/testing/test_benchmark_main.cc      +2     -1
mace/kernels/gemm_test.cc                     +91    -0
mace/kernels/matmul.h                         +22    -25
mace/kernels/matmul_benchmark.cc              +56    -4
mace/kernels/sgemm.cc                         +1126  -48
mace/kernels/sgemm.h                          +96    -52
mace/kernels/sgemm_pack_test.cc               +167   -0
mace/ops/matmul.h                             +5     -1
mace/ops/matmul_benchmark.cc                  +13    -4
mace/ops/winograd_transform_benchmark.cc      +2     -2
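For orientation before the per-file diffs: the core of this merge request is a new packed single-precision GEMM, kernels::SGemm, which packs both operands into contiguous panels before multiplying so the NEON micro-kernels can stream them linearly. A minimal usage sketch, assuming only the mace/kernels/sgemm.h API introduced in this diff (it mirrors the new MatmulBenchmark_Mace_SGemm function added below):

#include <vector>
#include "mace/kernels/sgemm.h"

void SGemmExample(int m, int k, int n) {
  std::vector<float> lhs(m * k), rhs(k * n), result(m * n);
  // batch = 1, row-major operands; the trailing `true` marks constant (weight)
  // data, which lets SGemm cache the packed panels across calls.
  mace::kernels::MatrixMap<const float> a(1, m, k, mace::kernels::RowMajor,
                                          lhs.data(), true);
  mace::kernels::MatrixMap<const float> b(1, k, n, mace::kernels::RowMajor,
                                          rhs.data(), true);
  mace::kernels::MatrixMap<float> c(1, m, n, mace::kernels::RowMajor,
                                    result.data());
  mace::kernels::SGemm sgemm;
  sgemm(a, b, &c);  // pack lhs/rhs (cached when const), multiply, unpack
}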
.gitlab-ci.yml

@@ -85,7 +85,7 @@ ndk_versions_compatible_tests:
     - DEFAULT_NDK_PATH=$ANDROID_NDK_HOME
     - prefix_path=${DEFAULT_NDK_PATH%android-ndk-*}
     - >
-      for ndk in android-ndk-r12b android-ndk-r15c android-ndk-r16 android-ndk-r17b;
+      for ndk in android-ndk-r15c android-ndk-r16 android-ndk-r17b;
       do
         new_ndk_path=${prefix_path}${ndk};
         if [ "$new_ndk_path" != "$DEFAULT_NDK_PATH" ]; then
mace/core/tensor.h

@@ -399,6 +399,10 @@ class Tensor {
     zero_point_ = zero_point;
   }

+  inline void SetIsWeight(bool is_weight) {
+    is_weight_ = is_weight;
+  }
+
  private:
   Allocator *allocator_;
   DataType dtype_;
@@ -409,7 +413,7 @@ class Tensor {
   bool is_buffer_owner_;
   bool unused_;
   std::string name_;
-  const bool is_weight_;
+  bool is_weight_;
   float scale_;
   int32_t zero_point_;
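Dropping the const qualifier on is_weight_ is what makes the new SetIsWeight setter legal: the flag can now be flipped after a tensor has been constructed. The updated benchmarks in this merge request rely on exactly that, e.g.:

  net.GetTensor("A")->SetIsWeight(true);  // allow SGemm to cache A's packed form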
mace/core/testing/test_benchmark_main.cc

@@ -33,7 +33,8 @@ int main(int argc, char **argv) {
   // config runtime
   mace::MaceStatus status = mace::SetOpenMPThreadsAndAffinityPolicy(
       FLAGS_omp_num_threads,
-      static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy));
+      static_cast<mace::CPUAffinityPolicy>(FLAGS_cpu_affinity_policy),
+      true);
   if (status != mace::MACE_SUCCESS) {
     LOG(WARNING) << "Set openmp or cpu affinity failed.";
   }
mace/kernels/gemm_test.cc

@@ -13,11 +13,13 @@
 // limitations under the License.

 #include <gtest/gtest.h>
+#include <vector>
 #include <memory>
 #include <random>

 #include "mace/core/types.h"
 #include "mace/kernels/gemm.h"
+#include "mace/kernels/sgemm.h"

 namespace mace {
@@ -72,6 +74,74 @@ void GemvTest(index_t batch, index_t N, index_t M) {
   }
 }

+void SGemmTest(index_t batch,
+               index_t N,
+               index_t K,
+               index_t M,
+               bool transpose_a,
+               bool transpose_b) {
+  std::unique_ptr<float[]> A(new float[batch * N * K]);
+  std::unique_ptr<float[]> B(new float[batch * K * M]);
+  std::unique_ptr<float[]> C(new float[batch * N * M]);
+  std::unique_ptr<float[]> C_ref(new float[batch * N * M]);
+
+  std::random_device rd;
+  std::mt19937 gen(rd());
+  std::normal_distribution<float> nd(0, 1);
+  std::generate(A.get(), A.get() + batch * N * K,
+                [&gen, &nd] { return nd(gen); });
+  std::generate(B.get(), B.get() + batch * K * M,
+                [&gen, &nd] { return nd(gen); });
+
+  kernels::GemmRef(A.get(), B.get(), batch, N, K, M, C_ref.get(),
+                   transpose_a, transpose_b);
+
+  kernels::MatrixMap<const float> matrix_a;
+  kernels::MatrixMap<const float> matrix_b;
+  if (!transpose_a) {
+    matrix_a = kernels::MatrixMap<const float>(batch, N, K,
+                                               kernels::RowMajor, A.get());
+  } else {
+    matrix_a = kernels::MatrixMap<const float>(batch, K, N,
+                                               kernels::RowMajor, A.get());
+    matrix_a = matrix_a.transpose();
+  }
+  if (!transpose_b) {
+    matrix_b = kernels::MatrixMap<const float>(batch, K, M,
+                                               kernels::RowMajor, B.get());
+  } else {
+    matrix_b = kernels::MatrixMap<const float>(batch, M, K,
+                                               kernels::RowMajor, B.get());
+    matrix_b = matrix_b.transpose();
+  }
+  kernels::MatrixMap<float> matrix_c(batch, N, M, kernels::RowMajor, C.get());
+
+  kernels::SGemm sgemm;
+  sgemm(matrix_a, matrix_b, &matrix_c);
+
+  for (int i = 0; i < N * M; ++i) {
+    EXPECT_NEAR(C_ref[i], C[i], 0.1);
+  }
+}
+
 }  // namespace

 TEST(GEMMTest, AlignedWithoutBatch) {
@@ -114,4 +184,25 @@ TEST(GEMMTest, gemv) {
   GemvTest(3, 17, 63);
 }

+namespace {
+void TestSGemmTranspose(index_t batch, index_t N, index_t K, index_t M) {
+  SGemmTest(batch, N, K, M, false, false);
+  SGemmTest(batch, N, K, M, true, false);
+  SGemmTest(batch, N, K, M, false, true);
+  SGemmTest(batch, N, K, M, true, true);
+}
+}  // namespace
+
+TEST(SGEMMTest, UnalignedWithoutBatch) {
+  std::vector<index_t> tests{1, 5, 14, 31, 47};
+  for (index_t N : tests) {
+    for (index_t K : tests) {
+      for (index_t M : tests) {
+        TestSGemmTranspose(1, N, K, M);
+        TestSGemmTranspose(16, N, K, M);
+      }
+    }
+  }
+}
+
 }  // namespace mace
mace/kernels/matmul.h

@@ -32,6 +32,7 @@
 #include "mace/kernels/kernel.h"
 #include "mace/utils/utils.h"
 #include "mace/kernels/gemmlowp_util.h"
+#include "mace/kernels/sgemm.h"

 #ifdef MACE_ENABLE_OPENCL
 #include "mace/core/runtime/opencl/cl2_header.h"
@@ -83,39 +84,34 @@ struct MatMulFunctor : OpKernel {
     const T *b_ptr_base = B->data<T>();
     T *c_ptr_base = C->mutable_data<T>();

-    memset(c_ptr_base, 0, batch * height * width * sizeof(T));
-    if (height == 1 && width > 1 && B->is_weight()) {
-      // A * B = (B^T * A^T)^T
-      if (!transpose_b) {
-        if (B_transpose_.get() == nullptr) {
-          B_transpose_.reset(new Tensor(context_->device()->allocator(),
-                                        DataTypeToEnum<T>::v()));
-          B_transpose_->Resize({batch, width, K});
-          Tensor::MappingGuard guardbt(B_transpose_.get());
-          T *bt_ptr_base = B_transpose_->mutable_data<T>();
-          Transpose(b_ptr_base, K, width, width, bt_ptr_base);
-        }
-        Tensor::MappingGuard guardbt(B_transpose_.get());
-        T *bt_ptr_base = B_transpose_->mutable_data<T>();
-        Gemv(bt_ptr_base, a_ptr_base, batch, K, width, c_ptr_base);
-      } else {
-        Gemv(b_ptr_base, a_ptr_base, batch, K, width, c_ptr_base);
-      }
-    } else {
-      Gemm(a_ptr_base, b_ptr_base, batch, height, K, width, c_ptr_base,
-           transpose_a, transpose_b);
-    }
+    const index_t height_a = A->dim(rank - 2);
+    const index_t width_a = A->dim(rank - 1);
+    const index_t height_b = B->dim(rank - 2);
+    const index_t width_b = B->dim(rank - 1);
+
+    sgemm_.Run(a_ptr_base,
+               b_ptr_base,
+               batch,
+               height_a,
+               width_a,
+               height_b,
+               width_b,
+               transpose_a,
+               transpose_b,
+               A->is_weight(),
+               B->is_weight(),
+               c_ptr_base,
+               context_->workspace()->GetScratchBuffer(D));

     return MACE_SUCCESS;
   }

-  std::unique_ptr<Tensor> B_transpose_;
+  SGemm sgemm_;
 };

 template <>
 struct MatMulFunctor<CPU, uint8_t> : OpKernel {
   explicit MatMulFunctor(OpKernelContext *context) : OpKernel(context) {}

   template <gemmlowp::MapOrder AOrder, gemmlowp::MapOrder BOrder>
   void MatMulImpl(const Tensor *A,
                   const Tensor *B,
@@ -213,6 +209,7 @@ struct MatMulFunctor<CPU, uint8_t> : OpKernel {
 template <typename T>
 struct MatMulFunctor<DeviceType::GPU, T> : OpKernel {
   explicit MatMulFunctor(OpKernelContext *context) : OpKernel(context) {}

   MaceStatus operator()(const Tensor *A,
                         const Tensor *B,
                         Tensor *C,
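The effect of swapping the cached B_transpose_ tensor for an SGemm member is that packing, not transposition, becomes the cached artifact. A sketch of the per-call cost implied by the packed_/is_const() logic in sgemm.cc (assuming B is a weight and A is an activation):

  // call 1:  PackLhs(A) + PackRhs(B) + RunInternal + UnPack
  // call 2+: PackLhs(A) +              RunInternal + UnPack   (B's panels reused)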
mace/kernels/matmul_benchmark.cc

@@ -22,6 +22,7 @@
 #include "mace/core/testing/test_benchmark.h"
 #include "mace/kernels/gemm.h"
 #include "mace/kernels/gemmlowp_util.h"
+#include "mace/kernels/sgemm.h"

 namespace gemmlowp {
@@ -107,6 +108,28 @@ void MatmulBenchmark_Mace(int iters, int m, int k, int n) {
   }
 }

+void MatmulBenchmark_Mace_SGemm(int iters, int m, int k, int n) {
+  mace::testing::StopTiming();
+  std::vector<float> lhs(m * k);
+  std::vector<float> rhs(k * n);
+  std::vector<float> result(m * n);
+
+  kernels::MatrixMap<const float> matrix_lhs(1, m, k, RowMajor, lhs.data(),
+                                             true);
+  kernels::MatrixMap<const float> matrix_rhs(1, k, n, RowMajor, rhs.data(),
+                                             true);
+  kernels::MatrixMap<float> matrix_result(1, m, n, RowMajor, result.data());
+
+  kernels::SGemm sgemm;
+  sgemm(matrix_lhs, matrix_rhs, &matrix_result);
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    sgemm(matrix_lhs, matrix_rhs, &matrix_result);
+  }
+}
+
 void MatmulBenchmark_Eigen(int iters, int m, int k, int n) {
   mace::testing::StopTiming();
   Eigen::MatrixXf lhs = Eigen::MatrixXf::Random(m, k);
@@ -202,6 +225,7 @@ void MatmulBenchmark_gemmlowp_int32(int iters, int rows, int depth, int cols) {
 #define MACE_BM_MATMUL(M, K, N)                          \
   MACE_BM_MATMUL_FUNC(M, K, N, Mace, float);             \
+  MACE_BM_MATMUL_FUNC(M, K, N, Mace_SGemm, float);       \
   MACE_BM_MATMUL_FUNC(M, K, N, Eigen, float);            \
   MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_uint8, uint8_t); \
   MACE_BM_MATMUL_FUNC(M, K, N, gemmlowp_int32, uint8_t);
@@ -215,15 +239,43 @@ MACE_BM_MATMUL(15, 384, 384);
 MACE_BM_MATMUL(15, 384, 1536);
 MACE_BM_MATMUL(15, 1536, 384);

-MACE_BM_MATMUL(1, 384, 384);
-MACE_BM_MATMUL(1, 384, 1536);
-MACE_BM_MATMUL(1, 1536, 384);
-MACE_BM_MATMUL(1, 384, 44678);
+MACE_BM_MATMUL(1, 256, 256);
+MACE_BM_MATMUL(1, 256, 1536);
+MACE_BM_MATMUL(1, 1536, 256);
+MACE_BM_MATMUL(256, 256, 1);
+MACE_BM_MATMUL(1536, 256, 1);
+MACE_BM_MATMUL(256, 1536, 1);
+MACE_BM_MATMUL(29792, 256, 1);
+MACE_BM_MATMUL(1, 256, 29792);
+MACE_BM_MATMUL(2, 256, 256);
+MACE_BM_MATMUL(2, 256, 1536);
+MACE_BM_MATMUL(2, 1536, 256);
+MACE_BM_MATMUL(3, 256, 256);
+MACE_BM_MATMUL(3, 256, 1536);
+MACE_BM_MATMUL(3, 1536, 256);
+MACE_BM_MATMUL(4, 256, 256);
+MACE_BM_MATMUL(4, 256, 1536);
+MACE_BM_MATMUL(4, 1536, 256);
+MACE_BM_MATMUL(8, 256, 256);
+MACE_BM_MATMUL(8, 256, 1536);
+MACE_BM_MATMUL(8, 1536, 256);
+MACE_BM_MATMUL(10, 256, 256);
+MACE_BM_MATMUL(10, 256, 1536);
+MACE_BM_MATMUL(10, 1536, 256);
+MACE_BM_MATMUL(15, 256, 256);
+MACE_BM_MATMUL(15, 256, 1536);
+MACE_BM_MATMUL(15, 1536, 256);

 // Embedding size 128
 MACE_BM_MATMUL(1, 128, 1536);
 MACE_BM_MATMUL(1, 128, 44678);

+// MobileNet
+MACE_BM_MATMUL(128, 128, 3136);
+MACE_BM_MATMUL(256, 256, 784);
+MACE_BM_MATMUL(512, 512, 196);
+MACE_BM_MATMUL(1024, 1024, 49);
+
 }  // namespace test
 }  // namespace kernels
 }  // namespace mace
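With the Mace_SGemm line added to the macro, every MACE_BM_MATMUL(M, K, N) site now registers five benchmarks over the same shape, so the packed path can be compared directly against the old Gemm path and the third-party kernels. For example, MACE_BM_MATMUL(1, 256, 256) expands to:

  MACE_BM_MATMUL_FUNC(1, 256, 256, Mace, float);
  MACE_BM_MATMUL_FUNC(1, 256, 256, Mace_SGemm, float);
  MACE_BM_MATMUL_FUNC(1, 256, 256, Eigen, float);
  MACE_BM_MATMUL_FUNC(1, 256, 256, gemmlowp_uint8, uint8_t);
  MACE_BM_MATMUL_FUNC(1, 256, 256, gemmlowp_int32, uint8_t);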
mace/kernels/sgemm.cc

@@ -12,80 +12,1158 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <algorithm>
-#include <cstring>
+#include <memory>
+#include <vector>

 #include "mace/kernels/sgemm.h"
+#include "mace/core/runtime/cpu/cpu_runtime.h"

 #if defined(MACE_ENABLE_NEON)
 #include <arm_neon.h>
 #endif

+#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
+#define vaddvq_f32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+#endif
+
 namespace mace {
 namespace kernels {

-void SGemm::operator()(const MatrixMap<float> &lhs,
-                       const MatrixMap<float> &rhs,
-                       MatrixMap<float> *result) {
-  PackedBlock<float> packed_lhs;
-  PackLhs(lhs, &packed_lhs);
-  PackedBlock<float> packed_rhs;
-  PackRhs(rhs, &packed_rhs);
-  PackedBlock<float> packed_result;
-  operator()(packed_lhs,
-             packed_rhs,
-             lhs.row(),
-             lhs.col(),
-             rhs.col(),
-             &packed_result);
-  UnPack(packed_result, result);
-}
+void SGemm::operator()(const MatrixMap<const float> &lhs,
+                       const MatrixMap<const float> &rhs,
+                       MatrixMap<float> *result,
+                       ScratchBuffer *scratch_buffer) {
+  if (rhs.col() < lhs.row()) {
+    MatrixMap<const float> lhs_transpose = lhs.transpose();
+    MatrixMap<const float> rhs_transpose = rhs.transpose();
+    MatrixMap<float> result_transpose = result->transpose();
+    return operator()(rhs_transpose,
+                      lhs_transpose,
+                      &result_transpose,
+                      scratch_buffer);
+  }
+
+  if (scratch_buffer != nullptr) {
+    scratch_buffer->Rewind();
+    index_t total_size = result->size();
+    if (!lhs.is_const()) {
+      total_size += lhs.size();
+    }
+    if (!rhs.is_const()) {
+      total_size += rhs.size();
+    }
+    scratch_buffer->GrowSize(total_size * sizeof(float));
+    scratch_buffer->Rewind();
+
+    if (!lhs.is_const()) {
+      packed_lhs_.reset(new Tensor(scratch_buffer->Scratch(
+          lhs.size() * sizeof(float)), DT_FLOAT));
+    }
+    if (!rhs.is_const()) {
+      packed_rhs_.reset(new Tensor(scratch_buffer->Scratch(
+          rhs.size() * sizeof(float)), DT_FLOAT));
+    }
+    packed_result_.reset(new Tensor(scratch_buffer->Scratch(
+        result->size() * sizeof(float)), DT_FLOAT));
+  }
+
+  if (packed_lhs_.get() == nullptr) {
+    packed_lhs_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_lhs_->Resize({lhs.size()});
+  }
+  if (packed_rhs_.get() == nullptr) {
+    packed_rhs_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_rhs_->Resize({rhs.size()});
+  }
+  if (packed_result_.get() == nullptr) {
+    packed_result_.reset(new Tensor(GetCPUAllocator(), DT_FLOAT));
+    packed_result_->Resize({result->size()});
+  }
+
+  if (!lhs.is_const() || !packed_) {
+    PackLhs(lhs, packed_lhs_.get());
+  }
+  if (!rhs.is_const() || !packed_) {
+    PackRhs(rhs, packed_rhs_.get());
+  }
+  packed_ = true;
+
+  RunInternal(*packed_lhs_,
+              *packed_rhs_,
+              lhs.batch(),
+              lhs.row(),
+              lhs.col(),
+              rhs.col(),
+              packed_result_.get());
+
+  UnPack(*packed_result_, result);
+}
+
+void SGemm::Run(const float *A,
+                const float *B,
+                const index_t batch,
+                const index_t height_a,
+                const index_t width_a,
+                const index_t height_b,
+                const index_t width_b,
+                const bool transpose_a,
+                const bool transpose_b,
+                const bool is_a_weight,
+                const bool is_b_weight,
+                float *C,
+                ScratchBuffer *scratch_buffer) {
+  index_t height_c = height_a;
+  index_t width_c = width_b;
+  if (transpose_a) {
+    height_c = width_a;
+  }
+  if (transpose_b) {
+    width_c = height_b;
+  }
+
+  MatrixMap<const float> matrix_a = MatrixMap<const float>(
+      batch, height_a, width_a, kernels::RowMajor, A, is_a_weight);
+  MatrixMap<const float> matrix_b = kernels::MatrixMap<const float>(
+      batch, height_b, width_b, kernels::RowMajor, B, is_b_weight);
+  if (transpose_a) {
+    matrix_a = matrix_a.transpose();
+  }
+  if (transpose_b) {
+    matrix_b = matrix_b.transpose();
+  }
+  MatrixMap<float> matrix_c(batch, height_c, width_c, kernels::RowMajor, C);
+
+  operator()(matrix_a, matrix_b, &matrix_c, scratch_buffer);
+}

-void SGemm::operator()(const PackedBlock<float> &lhs,
-                       const PackedBlock<float> &rhs,
-                       const index_t height,
-                       const index_t depth,
-                       const index_t width,
-                       PackedBlock<float> *result) {
-  (void)lhs;
-  (void)rhs;
-  (void)result;
-  (void)height;
-  (void)depth;
-  (void)width;
-  // (8, 8) * (8, 4)
-  // (4, 4) * (4, 4)
-  // remain
-}
+#if defined(MACE_ENABLE_NEON)
+#if defined(__aarch64__)
+
+// calculate 8 rows, 4 cols for each depth
+#define MACE_SGEMM_PART_CAL_R8_C4_D1(D, VD, VDN) \
+  c0 = vfmaq_laneq_f32(c0, b##D, a##VD, 0);      \
+  c1 = vfmaq_laneq_f32(c1, b##D, a##VD, 1);      \
+  c2 = vfmaq_laneq_f32(c2, b##D, a##VD, 2);      \
+  c3 = vfmaq_laneq_f32(c3, b##D, a##VD, 3);      \
+  c4 = vfmaq_laneq_f32(c4, b##D, a##VDN, 0);     \
+  c5 = vfmaq_laneq_f32(c5, b##D, a##VDN, 1);     \
+  c6 = vfmaq_laneq_f32(c6, b##D, a##VDN, 2);     \
+  c7 = vfmaq_laneq_f32(c7, b##D, a##VDN, 3);
+
+// calculate 4 rows, 4 cols for each depth
+#define MACE_SGEMM_PART_CAL_R4_C4_D1(D)     \
+  c0 = vfmaq_laneq_f32(c0, b##D, a##D, 0);  \
+  c1 = vfmaq_laneq_f32(c1, b##D, a##D, 1);  \
+  c2 = vfmaq_laneq_f32(c2, b##D, a##D, 2);  \
+  c3 = vfmaq_laneq_f32(c3, b##D, a##D, 3);
+
+// calculate 4 cols for 8 depths for each row
+#define MACE_SGEMM_PART_CAL_R1_C4_D8(R, VR, VRN) \
+  c##R = vfmaq_laneq_f32(c##R, b0, a##VR, 0);    \
+  c##R = vfmaq_laneq_f32(c##R, b1, a##VR, 1);    \
+  c##R = vfmaq_laneq_f32(c##R, b2, a##VR, 2);    \
+  c##R = vfmaq_laneq_f32(c##R, b3, a##VR, 3);    \
+  c##R = vfmaq_laneq_f32(c##R, b4, a##VRN, 0);   \
+  c##R = vfmaq_laneq_f32(c##R, b5, a##VRN, 1);   \
+  c##R = vfmaq_laneq_f32(c##R, b6, a##VRN, 2);   \
+  c##R = vfmaq_laneq_f32(c##R, b7, a##VRN, 3);
+
+// calculate 4 cols for 4 depths for each row
+#define MACE_SGEMM_PART_CAL_R1_C4_D4(R)     \
+  c##R = vfmaq_laneq_f32(c##R, b0, a##R, 0); \
+  c##R = vfmaq_laneq_f32(c##R, b1, a##R, 1); \
+  c##R = vfmaq_laneq_f32(c##R, b2, a##R, 2); \
+  c##R = vfmaq_laneq_f32(c##R, b3, a##R, 3);
+
+// calculate 8 cols for 4 depths for each row
+#define MACE_SGEMM_PART_CAL_R1_C8_D4(VR, VRN, R)  \
+  c##VR = vfmaq_laneq_f32(c##VR, b0, a##R, 0);    \
+  c##VR = vfmaq_laneq_f32(c##VR, b2, a##R, 1);    \
+  c##VR = vfmaq_laneq_f32(c##VR, b4, a##R, 2);    \
+  c##VR = vfmaq_laneq_f32(c##VR, b6, a##R, 3);    \
+  c##VRN = vfmaq_laneq_f32(c##VRN, b1, a##R, 0);  \
+  c##VRN = vfmaq_laneq_f32(c##VRN, b3, a##R, 1);  \
+  c##VRN = vfmaq_laneq_f32(c##VRN, b5, a##R, 2);  \
+  c##VRN = vfmaq_laneq_f32(c##VRN, b7, a##R, 3);
+
+#else
+
+#define MACE_SGEMM_PART_CAL_R8_C4_D1(D, VD, VDN)           \
+  c0 = vmlaq_lane_f32(c0, b##D, vget_low_f32(a##VD), 0);   \
+  c1 = vmlaq_lane_f32(c1, b##D, vget_low_f32(a##VD), 1);   \
+  c2 = vmlaq_lane_f32(c2, b##D, vget_high_f32(a##VD), 0);  \
+  c3 = vmlaq_lane_f32(c3, b##D, vget_high_f32(a##VD), 1);  \
+  c4 = vmlaq_lane_f32(c4, b##D, vget_low_f32(a##VDN), 0);  \
+  c5 = vmlaq_lane_f32(c5, b##D, vget_low_f32(a##VDN), 1);  \
+  c6 = vmlaq_lane_f32(c6, b##D, vget_high_f32(a##VDN), 0); \
+  c7 = vmlaq_lane_f32(c7, b##D, vget_high_f32(a##VDN), 1);
+
+#define MACE_SGEMM_PART_CAL_R4_C4_D1(D)                   \
+  c0 = vmlaq_lane_f32(c0, b##D, vget_low_f32(a##D), 0);   \
+  c1 = vmlaq_lane_f32(c1, b##D, vget_low_f32(a##D), 1);   \
+  c2 = vmlaq_lane_f32(c2, b##D, vget_high_f32(a##D), 0);  \
+  c3 = vmlaq_lane_f32(c3, b##D, vget_high_f32(a##D), 1);
+
+#define MACE_SGEMM_PART_CAL_R1_C4_D8(R, VR, VRN)            \
+  c##R = vmlaq_lane_f32(c##R, b0, vget_low_f32(a##VR), 0);  \
+  c##R = vmlaq_lane_f32(c##R, b1, vget_low_f32(a##VR), 1);  \
+  c##R = vmlaq_lane_f32(c##R, b2, vget_high_f32(a##VR), 0); \
+  c##R = vmlaq_lane_f32(c##R, b3, vget_high_f32(a##VR), 1); \
+  c##R = vmlaq_lane_f32(c##R, b4, vget_low_f32(a##VRN), 0); \
+  c##R = vmlaq_lane_f32(c##R, b5, vget_low_f32(a##VRN), 1); \
+  c##R = vmlaq_lane_f32(c##R, b6, vget_high_f32(a##VRN), 0); \
+  c##R = vmlaq_lane_f32(c##R, b7, vget_high_f32(a##VRN), 1);
+
+#define MACE_SGEMM_PART_CAL_R1_C4_D4(R)                    \
+  c##R = vmlaq_lane_f32(c##R, b0, vget_low_f32(a##R), 0);  \
+  c##R = vmlaq_lane_f32(c##R, b1, vget_low_f32(a##R), 1);  \
+  c##R = vmlaq_lane_f32(c##R, b2, vget_high_f32(a##R), 0); \
+  c##R = vmlaq_lane_f32(c##R, b3, vget_high_f32(a##R), 1);
+
+#endif  // __aarch64__
+#endif  // MACE_ENABLE_NEON
+
+void SGemm::RunInternal(const PackedBlock &lhs,
+                        const PackedBlock &rhs,
+                        const index_t batch,
+                        const index_t height,
+                        const index_t depth,
+                        const index_t width,
+                        PackedBlock *result) {
+  const float *lhs_data = lhs.data<float>();
+  const float *rhs_data = rhs.data<float>();
+  float *result_data = result->mutable_data<float>();
+
+#define MACE_SGEMM_RUN_PER_BATCH                   \
+  for (index_t b = 0; b < batch; ++b) {            \
+    RunPerBatch(lhs_data + b * height * depth,     \
+                rhs_data + b * depth * width,      \
+                height,                            \
+                depth,                             \
+                width,                             \
+                result_data + b * height * width); \
+  }
+
+  if (batch >= MaceOpenMPThreadCount) {
+#pragma omp parallel for
+    MACE_SGEMM_RUN_PER_BATCH
+  } else {
+    MACE_SGEMM_RUN_PER_BATCH
+  }
+#undef MACE_SGEMM_RUN_PER_BATCH
+}
+
+void SGemm::RunPerBatch(const float *lhs_data,
+                        const float *rhs_data,
+                        const index_t height,
+                        const index_t depth,
+                        const index_t width,
+                        float *result_data) {
+#if defined(MACE_ENABLE_NEON)
+  const index_t block_w = width >> 2;
+  const index_t remain_w = width - (block_w << 2);
+#else
+  const index_t remain_w = width;
+#endif
+
+#if defined(MACE_ENABLE_NEON)
+  // TODO(liyin): make better use of l2(l1) cache, try to fit as much lhs data
+  // as possible into cache, by tiling lhs by height and rhs by width.
+
+  // w: 4
+#pragma omp parallel for
+  for (index_t bw = 0; bw < block_w; ++bw) {
+    index_t remain_h = height;
+    index_t block_h = 0;
+
+    const float *lhs_ptr = lhs_data;
+    float *res_ptr = result_data + height * (bw << 2);
+
+#if defined(__aarch64__)
+    block_h = remain_h >> 3;
+    remain_h -= (block_h << 3);
+
+    // h: 8
+    for (index_t bh = 0; bh < block_h; ++bh) {
+      const float *rhs_ptr = rhs_data + depth * (bw << 2);
+
+      index_t remain_d = depth;
+      index_t block_d = remain_d >> 3;
+      remain_d -= (block_d << 3);
+
+      float32x4_t c0, c1, c2, c3, c4, c5, c6, c7;
+      c0 = vdupq_n_f32(0.f);
+      c1 = vdupq_n_f32(0.f);
+      c2 = vdupq_n_f32(0.f);
+      c3 = vdupq_n_f32(0.f);
+      c4 = vdupq_n_f32(0.f);
+      c5 = vdupq_n_f32(0.f);
+      c6 = vdupq_n_f32(0.f);
+      c7 = vdupq_n_f32(0.f);
+
+      // d: 8
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 8.8.4
+        float32x4_t a0, a1, a2, a3, a4, a5, a6, a7,
+            a8, a9, a10, a11, a12, a13, a14, a15;
+        float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+
+        a0 = vld1q_f32(lhs_ptr);
+        a1 = vld1q_f32(lhs_ptr + 4);
+        a2 = vld1q_f32(lhs_ptr + 8);
+        a3 = vld1q_f32(lhs_ptr + 12);
+        a4 = vld1q_f32(lhs_ptr + 16);
+        a5 = vld1q_f32(lhs_ptr + 20);
+        a6 = vld1q_f32(lhs_ptr + 24);
+        a7 = vld1q_f32(lhs_ptr + 28);
+        a8 = vld1q_f32(lhs_ptr + 32);
+        a9 = vld1q_f32(lhs_ptr + 36);
+        a10 = vld1q_f32(lhs_ptr + 40);
+        a11 = vld1q_f32(lhs_ptr + 44);
+        a12 = vld1q_f32(lhs_ptr + 48);
+        a13 = vld1q_f32(lhs_ptr + 52);
+        a14 = vld1q_f32(lhs_ptr + 56);
+        a15 = vld1q_f32(lhs_ptr + 60);
+
+        b0 = vld1q_f32(rhs_ptr);
+        b1 = vld1q_f32(rhs_ptr + 4);
+        b2 = vld1q_f32(rhs_ptr + 8);
+        b3 = vld1q_f32(rhs_ptr + 12);
+        b4 = vld1q_f32(rhs_ptr + 16);
+        b5 = vld1q_f32(rhs_ptr + 20);
+        b6 = vld1q_f32(rhs_ptr + 24);
+        b7 = vld1q_f32(rhs_ptr + 28);
+
+        MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1);  // d = 1
+        MACE_SGEMM_PART_CAL_R8_C4_D1(1, 2, 3);  // d = 2
+        MACE_SGEMM_PART_CAL_R8_C4_D1(2, 4, 5);
+        MACE_SGEMM_PART_CAL_R8_C4_D1(3, 6, 7);
+        MACE_SGEMM_PART_CAL_R8_C4_D1(4, 8, 9);
+        MACE_SGEMM_PART_CAL_R8_C4_D1(5, 10, 11);
+        MACE_SGEMM_PART_CAL_R8_C4_D1(6, 12, 13);
+        MACE_SGEMM_PART_CAL_R8_C4_D1(7, 14, 15);
+
+        lhs_ptr += 64;
+        rhs_ptr += 32;
+      }
+
+      block_d = remain_d >> 2;
+      remain_d -= (block_d << 2);
+
+      // d: 4
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 8.4.4
+        float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
+        float32x4_t b0, b1, b2, b3;
+
+        a0 = vld1q_f32(lhs_ptr);
+        a1 = vld1q_f32(lhs_ptr + 4);
+        a2 = vld1q_f32(lhs_ptr + 8);
+        a3 = vld1q_f32(lhs_ptr + 12);
+        a4 = vld1q_f32(lhs_ptr + 16);
+        a5 = vld1q_f32(lhs_ptr + 20);
+        a6 = vld1q_f32(lhs_ptr + 24);
+        a7 = vld1q_f32(lhs_ptr + 28);
+
+        b0 = vld1q_f32(rhs_ptr);
+        b1 = vld1q_f32(rhs_ptr + 4);
+        b2 = vld1q_f32(rhs_ptr + 8);
+        b3 = vld1q_f32(rhs_ptr + 12);
+
+        MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1);  // d = 1
+        MACE_SGEMM_PART_CAL_R8_C4_D1(1, 2, 3);  // d = 2
+        MACE_SGEMM_PART_CAL_R8_C4_D1(2, 4, 5);
+        MACE_SGEMM_PART_CAL_R8_C4_D1(3, 6, 7);
+
+        lhs_ptr += 32;
+        rhs_ptr += 16;
+      }
+
+      // TODO(liyin): handle remain by each case
+      // d: remain
+      for (index_t d = 0; d < remain_d; ++d) {
+        // 8.1.4
+        float32x4_t a0, a1;
+        float32x4_t b0;
+
+        a0 = vld1q_f32(lhs_ptr);
+        a1 = vld1q_f32(lhs_ptr + 4);
+        b0 = vld1q_f32(rhs_ptr);
+
+        MACE_SGEMM_PART_CAL_R8_C4_D1(0, 0, 1);  // d = 1
+
+        lhs_ptr += 8;
+        rhs_ptr += 4;
+      }
+
+      vst1q_f32(res_ptr, c0);
+      vst1q_f32(res_ptr + 4, c1);
+      vst1q_f32(res_ptr + 8, c2);
+      vst1q_f32(res_ptr + 12, c3);
+      vst1q_f32(res_ptr + 16, c4);
+      vst1q_f32(res_ptr + 20, c5);
+      vst1q_f32(res_ptr + 24, c6);
+      vst1q_f32(res_ptr + 28, c7);
+
+      res_ptr += 32;
+    }  // bh: 8
+#endif  // __aarch64__
+
+    // h: 4
+    block_h = remain_h >> 2;
+    remain_h -= (block_h << 2);
+
+    for (index_t bh = 0; bh < block_h; ++bh) {
+      const float *rhs_ptr = rhs_data + depth * (bw << 2);
+
+      index_t remain_d = depth;
+      index_t block_d = 0;
+
+      float32x4_t c0, c1, c2, c3;
+      c0 = vdupq_n_f32(0.f);
+      c1 = vdupq_n_f32(0.f);
+      c2 = vdupq_n_f32(0.f);
+      c3 = vdupq_n_f32(0.f);
+
+      // d: 8
+      block_d = remain_d >> 3;
+      remain_d -= (block_d << 3);
+
+#if defined(__aarch64__)
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 4.8.4
+        float32x4_t a0, a1, a2, a3, a4, a5, a6, a7;
+        float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+
+        a0 = vld1q_f32(lhs_ptr);
+        a1 = vld1q_f32(lhs_ptr + 4);
+        a2 = vld1q_f32(lhs_ptr + 8);
+        a3 = vld1q_f32(lhs_ptr + 12);
+        a4 = vld1q_f32(lhs_ptr + 16);
+        a5 = vld1q_f32(lhs_ptr + 20);
+        a6 = vld1q_f32(lhs_ptr + 24);
+        a7 = vld1q_f32(lhs_ptr + 28);
+
+        b0 = vld1q_f32(rhs_ptr);
+        b1 = vld1q_f32(rhs_ptr + 4);
+        b2 = vld1q_f32(rhs_ptr + 8);
+        b3 = vld1q_f32(rhs_ptr + 12);
+        b4 = vld1q_f32(rhs_ptr + 16);
+        b5 = vld1q_f32(rhs_ptr + 20);
+        b6 = vld1q_f32(rhs_ptr + 24);
+        b7 = vld1q_f32(rhs_ptr + 28);
+
+        MACE_SGEMM_PART_CAL_R4_C4_D1(0);  // d = 1
+        MACE_SGEMM_PART_CAL_R4_C4_D1(1);  // d = 2
+        MACE_SGEMM_PART_CAL_R4_C4_D1(2);
+        MACE_SGEMM_PART_CAL_R4_C4_D1(3);
+        MACE_SGEMM_PART_CAL_R4_C4_D1(4);
+        MACE_SGEMM_PART_CAL_R4_C4_D1(5);
+        MACE_SGEMM_PART_CAL_R4_C4_D1(6);
+        MACE_SGEMM_PART_CAL_R4_C4_D1(7);
+
+        lhs_ptr += 32;
+        rhs_ptr += 32;
+      }
+#else  // arm v7
+      // 4.8.4
+      if (block_d > 0) {
+        asm volatile(
+            "0:                                 \n"
+            "vld1.f32 {d0-d1}, [%[lhs_ptr]]!    \n"
+            "vld1.f32 {d2-d3}, [%[lhs_ptr]]!    \n"
+            "vld1.f32 {d4-d5}, [%[lhs_ptr]]!    \n"
+            "vld1.f32 {d20-d21}, [%[rhs_ptr]]!  \n"
+            "vld1.f32 {d22-d23}, [%[rhs_ptr]]!  \n"
+            "vld1.f32 {d24-d25}, [%[rhs_ptr]]!  \n"
+            "vmla.f32 %[c0], q10, d0[0]         \n"
+            "vmla.f32 %[c1], q10, d0[1]         \n"
+            "vmla.f32 %[c2], q10, d1[0]         \n"
+            "vmla.f32 %[c3], q10, d1[1]         \n"
+            "vld1.f32 {d6-d7}, [%[lhs_ptr]]!    \n"
+            "vld1.f32 {d26-d27}, [%[rhs_ptr]]!  \n"
+            "vmla.f32 %[c0], q11, d2[0]         \n"
+            "vmla.f32 %[c1], q11, d2[1]         \n"
+            "vmla.f32 %[c2], q11, d3[0]         \n"
+            "vmla.f32 %[c3], q11, d3[1]         \n"
+            "vld1.f32 {d8-d9}, [%[lhs_ptr]]!    \n"
+            "vld1.f32 {d28-d29}, [%[rhs_ptr]]!  \n"
+            "vmla.f32 %[c0], q12, d4[0]         \n"
+            "vmla.f32 %[c1], q12, d4[1]         \n"
+            "vmla.f32 %[c2], q12, d5[0]         \n"
+            "vmla.f32 %[c3], q12, d5[1]         \n"
+            "vld1.f32 {d10-d11}, [%[lhs_ptr]]!  \n"
+            "vld1.f32 {d30-d31}, [%[rhs_ptr]]!  \n"
+            "vmla.f32 %[c0], q13, d6[0]         \n"
+            "vmla.f32 %[c1], q13, d6[1]         \n"
+            "vmla.f32 %[c2], q13, d7[0]         \n"
+            "vmla.f32 %[c3], q13, d7[1]         \n"
+            "vld1.f32 {d0-d1}, [%[lhs_ptr]]!    \n"
+            "vld1.f32 {d2-d3}, [%[lhs_ptr]]!    \n"
+            "vld1.f32 {d20-d21}, [%[rhs_ptr]]!  \n"
+            "vld1.f32 {d22-d23}, [%[rhs_ptr]]!  \n"
+            "vmla.f32 %[c0], q14, d8[0]         \n"
+            "vmla.f32 %[c1], q14, d8[1]         \n"
+            "vmla.f32 %[c2], q14, d9[0]         \n"
+            "vmla.f32 %[c3], q14, d9[1]         \n"
+            "vmla.f32 %[c0], q15, d10[0]        \n"
+            "vmla.f32 %[c1], q15, d10[1]        \n"
+            "vmla.f32 %[c2], q15, d11[0]        \n"
+            "vmla.f32 %[c3], q15, d11[1]        \n"
+            "vmla.f32 %[c0], q10, d0[0]         \n"
+            "vmla.f32 %[c1], q10, d0[1]         \n"
+            "vmla.f32 %[c2], q10, d1[0]         \n"
+            "vmla.f32 %[c3], q10, d1[1]         \n"
+            "subs %[block_d], %[block_d], #1    \n"
+            "vmla.f32 %[c0], q11, d2[0]         \n"
+            "vmla.f32 %[c1], q11, d2[1]         \n"
+            "vmla.f32 %[c2], q11, d3[0]         \n"
+            "vmla.f32 %[c3], q11, d3[1]         \n"
+            "bne 0b                             \n"
+            :  // outputs
+            [lhs_ptr] "+r"(lhs_ptr),
+            [rhs_ptr] "+r"(rhs_ptr),
+            [res_ptr] "+r"(res_ptr),
+            [block_d] "+r"(block_d),
+            [c0] "+w"(c0),
+            [c1] "+w"(c1),
+            [c2] "+w"(c2),
+            [c3] "+w"(c3)
+            :  // inputs
+            :  // clobbers
+            "cc", "memory", "q0", "q1", "q2", "q3", "q4", "q5",
+            "q10", "q11", "q12", "q13", "q14", "q15");
+      }
+#endif  // __aarch64__
+
+      // d: 4
+      block_d = remain_d >> 2;
+      remain_d -= (block_d << 2);
+
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 4.4.4
+        float32x4_t a0, a1, a2, a3;
+        float32x4_t b0, b1, b2, b3;
+
+        a0 = vld1q_f32(lhs_ptr);
+        a1 = vld1q_f32(lhs_ptr + 4);
+        a2 = vld1q_f32(lhs_ptr + 8);
+        a3 = vld1q_f32(lhs_ptr + 12);
+
+        b0 = vld1q_f32(rhs_ptr);
+        b1 = vld1q_f32(rhs_ptr + 4);
+        b2 = vld1q_f32(rhs_ptr + 8);
+        b3 = vld1q_f32(rhs_ptr + 12);
+
+        MACE_SGEMM_PART_CAL_R4_C4_D1(0);  // d = 1
+        MACE_SGEMM_PART_CAL_R4_C4_D1(1);  // d = 2
+        MACE_SGEMM_PART_CAL_R4_C4_D1(2);
+        MACE_SGEMM_PART_CAL_R4_C4_D1(3);
+
+        lhs_ptr += 16;
+        rhs_ptr += 16;
+      }
+
+      // d: remain
+      for (index_t d = 0; d < remain_d; ++d) {
+        // 4.1.4
+        float32x4_t a0;
+        float32x4_t b0;
+
+        a0 = vld1q_f32(lhs_ptr);
+        b0 = vld1q_f32(rhs_ptr);
+
+        MACE_SGEMM_PART_CAL_R4_C4_D1(0);  // d = 1
+
+        lhs_ptr += 4;
+        rhs_ptr += 4;
+      }
+
+      vst1q_f32(res_ptr, c0);
+      vst1q_f32(res_ptr + 4, c1);
+      vst1q_f32(res_ptr + 8, c2);
+      vst1q_f32(res_ptr + 12, c3);
+
+      res_ptr += 16;
+    }  // bh: 4
+
+    // h: 1
+    for (index_t h = 0; h < remain_h; ++h) {
+      const float *rhs_ptr = rhs_data + depth * (bw << 2);
+
+      index_t remain_d = depth;
+      index_t block_d = 0;
+
+      float32x4_t c0 = vdupq_n_f32(0.f);
+
+      // d: 8
+      block_d = remain_d >> 3;
+      remain_d -= (block_d << 3);
+
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 1.8.4
+        float32x4_t a0, a1;
+        float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+
+        a0 = vld1q_f32(lhs_ptr);
+        a1 = vld1q_f32(lhs_ptr + 4);
+
+        b0 = vld1q_f32(rhs_ptr);
+        b1 = vld1q_f32(rhs_ptr + 4);
+        b2 = vld1q_f32(rhs_ptr + 8);
+        b3 = vld1q_f32(rhs_ptr + 12);
+        b4 = vld1q_f32(rhs_ptr + 16);
+        b5 = vld1q_f32(rhs_ptr + 20);
+        b6 = vld1q_f32(rhs_ptr + 24);
+        b7 = vld1q_f32(rhs_ptr + 28);
+
+        MACE_SGEMM_PART_CAL_R1_C4_D8(0, 0, 1);
+
+        lhs_ptr += 8;
+        rhs_ptr += 32;
+      }
+
+      block_d = remain_d >> 2;
+      remain_d -= (block_d << 2);
+
+      // d: 4
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 1.4.4
+        float32x4_t a0;
+        float32x4_t b0, b1, b2, b3;
+
+        a0 = vld1q_f32(lhs_ptr);
+
+        b0 = vld1q_f32(rhs_ptr);
+        b1 = vld1q_f32(rhs_ptr + 4);
+        b2 = vld1q_f32(rhs_ptr + 8);
+        b3 = vld1q_f32(rhs_ptr + 12);
+
+        MACE_SGEMM_PART_CAL_R1_C4_D4(0);
+
+        lhs_ptr += 4;
+        rhs_ptr += 16;
+      }
+
+      // d: remain
+      float s0 = 0;
+      float s1 = 0;
+      float s2 = 0;
+      float s3 = 0;
+      for (index_t d = 0; d < remain_d; ++d) {
+        // 1.1.4
+        s0 += lhs_ptr[0] * rhs_ptr[0];
+        s1 += lhs_ptr[0] * rhs_ptr[1];
+        s2 += lhs_ptr[0] * rhs_ptr[2];
+        s3 += lhs_ptr[0] * rhs_ptr[3];
+        lhs_ptr += 1;
+        rhs_ptr += 4;
+      }
+      float32x4_t c0_remain = {s0, s1, s2, s3};
+      c0 += c0_remain;
+
+      vst1q_f32(res_ptr, c0);
+      res_ptr += 4;
+    }  // bh: remain
+  }  // bw
+#endif  // MACE_ENABLE_NEON
+
+  // ========================== remain width ===========================
+  result_data += (width - remain_w) * height;
+  rhs_data += (width - remain_w) * depth;
+
+  // w: 1
+#pragma omp parallel for
+  for (index_t bw = 0; bw < remain_w; ++bw) {
+    index_t remain_h = height;
+    const float *lhs_ptr = lhs_data;
+    float *res_ptr = result_data + height * bw;
+
+#if defined(MACE_ENABLE_NEON)
+    index_t block_h = 0;
+#if defined(__aarch64__)
+    block_h = remain_h >> 3;
+    remain_h -= (block_h << 3);
+
+    // h: 8
+    for (index_t bh = 0; bh < block_h; ++bh) {
+      const float *rhs_ptr = rhs_data + depth * bw;
+
+      index_t remain_d = depth;
+
+      float32x4_t c0, c1;
+      c0 = vdupq_n_f32(0.f);
+      c1 = vdupq_n_f32(0.f);
+
+      index_t block_d = remain_d >> 2;
+      remain_d -= (block_d << 2);
+
+      // d: 4
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 8.4.1
+        float32x4_t b0, b1, b2, b3, b4, b5, b6, b7;
+        float32x4_t a0;
+
+        b0 = vld1q_f32(lhs_ptr);
+        b1 = vld1q_f32(lhs_ptr + 4);
+        b2 = vld1q_f32(lhs_ptr + 8);
+        b3 = vld1q_f32(lhs_ptr + 12);
+        b4 = vld1q_f32(lhs_ptr + 16);
+        b5 = vld1q_f32(lhs_ptr + 20);
+        b6 = vld1q_f32(lhs_ptr + 24);
+        b7 = vld1q_f32(lhs_ptr + 28);
+
+        a0 = vld1q_f32(rhs_ptr);
+
+        MACE_SGEMM_PART_CAL_R1_C8_D4(0, 1, 0);
+
+        lhs_ptr += 32;
+        rhs_ptr += 4;
+      }
+
+      // d: remain
+      for (index_t d = 0; d < remain_d; ++d) {
+        // 8.1.1
+        float32x4_t b0, b1;
+        float32x4_t a0 = vdupq_n_f32(rhs_ptr[0]);
+
+        b0 = vld1q_f32(lhs_ptr);
+        b1 = vld1q_f32(lhs_ptr + 4);
+
+        c0 = vfmaq_laneq_f32(c0, b0, a0, 0);
+        c1 = vfmaq_laneq_f32(c1, b1, a0, 0);
+
+        lhs_ptr += 8;
+        rhs_ptr += 1;
+      }
+      vst1q_f32(res_ptr, c0);
+      vst1q_f32(res_ptr + 4, c1);
+
+      res_ptr += 8;
+    }  // bh: 8
+#endif
+
+    // h: 4
+    block_h = remain_h >> 2;
+    remain_h -= (block_h << 2);
+
+    for (index_t bh = 0; bh < block_h; ++bh) {
+      const float *rhs_ptr = rhs_data + depth * bw;
+
+      index_t remain_d = depth;
+      index_t block_d = 0;
+
+      float32x4_t c0 = vdupq_n_f32(0.f);
+
+      block_d = remain_d >> 2;
+      remain_d -= (block_d << 2);
+
+      // d: 4
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 4.4.1
+        float32x4_t b0, b1, b2, b3;
+        float32x4_t a0;
+
+        b0 = vld1q_f32(lhs_ptr);
+        b1 = vld1q_f32(lhs_ptr + 4);
+        b2 = vld1q_f32(lhs_ptr + 8);
+        b3 = vld1q_f32(lhs_ptr + 12);
+
+        a0 = vld1q_f32(rhs_ptr);
+
+        MACE_SGEMM_PART_CAL_R1_C4_D4(0);
+
+        lhs_ptr += 16;
+        rhs_ptr += 4;
+      }
+
+      // d: remain
+      for (index_t d = 0; d < remain_d; ++d) {
+        // 4.1.1
+        float32x4_t b0, b1;
+        float32x2_t a0 = vdup_n_f32(rhs_ptr[0]);
+
+        b0 = vld1q_f32(lhs_ptr);
+
+        c0 = vmlaq_lane_f32(c0, b0, a0, 0);
+
+        lhs_ptr += 4;
+        rhs_ptr += 1;
+      }
+      vst1q_f32(res_ptr, c0);
+
+      res_ptr += 4;
+    }  // bh: 4
+#endif  // MACE_ENABLE_NEON
+
+    // h: 1
+    for (index_t h = 0; h < remain_h; ++h) {
+      const float *rhs_ptr = rhs_data + depth * bw;
+
+      index_t remain_d = depth;
+
+      float sum = 0.f;
+#if defined(MACE_ENABLE_NEON)
+      index_t block_d = 0;
+
+      float32x4_t c0, c1;
+      c0 = vdupq_n_f32(0.f);
+      c1 = vdupq_n_f32(0.f);
+
+      block_d = remain_d >> 3;
+      remain_d -= (block_d << 3);
+
+      // d: 8
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 1.8.1
+        float32x4_t a0, a1;
+        float32x4_t b0, b1;
+
+        a0 = vld1q_f32(lhs_ptr);
+        a1 = vld1q_f32(lhs_ptr + 4);
+        b0 = vld1q_f32(rhs_ptr);
+        b1 = vld1q_f32(rhs_ptr + 4);
+
+        c0 = vmlaq_f32(c0, a0, b0);
+        c1 = vmlaq_f32(c1, a1, b1);
+
+        lhs_ptr += 8;
+        rhs_ptr += 8;
+      }
+
+      block_d = remain_d >> 2;
+      remain_d -= (block_d << 2);
+
+      // d: 4
+      for (index_t bd = 0; bd < block_d; ++bd) {
+        // 1.4.1
+        float32x4_t a0;
+        float32x4_t b0;
+
+        a0 = vld1q_f32(lhs_ptr);
+        b0 = vld1q_f32(rhs_ptr);
+
+        c0 = vmlaq_f32(c0, a0, b0);
+
+        lhs_ptr += 4;
+        rhs_ptr += 4;
+      }
+      sum += vaddvq_f32(c0);
+      sum += vaddvq_f32(c1);
+#endif  // MACE_ENABLE_NEON
+
+      // d: remain
+      for (index_t d = 0; d < remain_d; ++d) {
+        // 1.1.1
+        sum += lhs_ptr[0] * rhs_ptr[0];
+        lhs_ptr += 1;
+        rhs_ptr += 1;
+      }
+      *res_ptr = sum;
+      ++res_ptr;
+    }  // bh: remain
+  }  // bw
+}

-void SGemm::PackLhs(const MatrixMap<float> &lhs,
-                    PackedBlock<float> *packed_block) {
+void SGemm::PackLhs(const MatrixMap<const float> &lhs,
+                    PackedBlock *packed_block) {
   Pack(lhs, PackOrder::ColMajor, packed_block);
 }

-void SGemm::PackRhs(const MatrixMap<float> &rhs,
-                    PackedBlock<float> *packed_block) {
+void SGemm::PackRhs(const MatrixMap<const float> &rhs,
+                    PackedBlock *packed_block) {
   Pack(rhs, PackOrder::RowMajor, packed_block);
 }

-void SGemm::Pack(const MatrixMap<float> &src,
-                 const PackOrder order,
-                 PackedBlock<float> *packed_block) {
-  (void)src;
-  (void)order;
-  (void)packed_block;
-}
+void SGemm::Pack(const MatrixMap<const float> &src,
+                 const PackOrder order,
+                 PackedBlock *packed_block) {
+  MACE_CHECK_NOTNULL(packed_block);
+
+  const index_t height = src.row();
+  const index_t width = src.col();
+  auto packed_data = packed_block->mutable_data<float>();
+
+#define MACE_SGEMM_PACK_PER_BATCH                                  \
+  for (index_t b = 0; b < src.batch(); ++b) {                      \
+    PackPerBatch(src, order, b, packed_data + b * height * width); \
+  }
+  if (src.batch() >= MaceOpenMPThreadCount) {
+#pragma omp parallel for
+    MACE_SGEMM_PACK_PER_BATCH
+  } else {
+    MACE_SGEMM_PACK_PER_BATCH
+  }
+#undef MACE_SGEMM_PACK_PER_BATCH
+}

-void SGemm::UnPack(const PackedBlock<float> &packed_result,
-                   MatrixMap<float> *matrix_map) {
-  (void)packed_result;
-  (void)matrix_map;
-}
+void SGemm::UnPack(const PackedBlock &packed_result,
+                   MatrixMap<float> *matrix_map) {
+  MACE_CHECK_NOTNULL(matrix_map);
+
+  const index_t height = matrix_map->row();
+  const index_t width = matrix_map->col();
+  auto packed_data = packed_result.data<float>();
+
+#define MACE_SGEMM_UNPACK_PER_BATCH                                  \
+  for (index_t b = 0; b < matrix_map->batch(); ++b) {                \
+    UnPackPerBatch(packed_data + b * height * width, b, matrix_map); \
+  }
+  if (matrix_map->batch() >= MaceOpenMPThreadCount) {
+#pragma omp parallel for
+    MACE_SGEMM_UNPACK_PER_BATCH
+  } else {
+    MACE_SGEMM_UNPACK_PER_BATCH
+  }
+#undef MACE_SGEMM_UNPACK_PER_BATCH
+}
+
+void SGemm::PackPerBatch(const MatrixMap<const float> &src,
+                         const PackOrder order,
+                         const index_t batch_index,
+                         float *packed_data) {
+  MACE_CHECK_NOTNULL(packed_data);
+
+  const index_t height = src.row();
+  const index_t width = src.col();
+  auto src_data = src.batch_data(batch_index);
+
+  if (src.major() == Major::RowMajor && order == PackOrder::ColMajor) {
+    // This is for packing no-transpose lhs.
+    index_t h = 0;
+#if defined(MACE_ENABLE_NEON)
+#if defined(__aarch64__)
+#pragma omp parallel for
+    for (index_t ih = h; ih <= height - 8; ih += 8) {
+      const float *src_data_ptr = src_data + ih * width;
+      float *packed_data_ptr = packed_data + ih * width;
+      for (index_t w = 0; w < width; ++w) {
+        const index_t src_offset = w;
+        const index_t packed_offset = w * 8;
+        float32x4_t vs0 = {src_data_ptr[src_offset],
+                           src_data_ptr[src_offset + width],
+                           src_data_ptr[src_offset + 2 * width],
+                           src_data_ptr[src_offset + 3 * width]};
+        float32x4_t vs1 = {src_data_ptr[src_offset + 4 * width],
+                           src_data_ptr[src_offset + 5 * width],
+                           src_data_ptr[src_offset + 6 * width],
+                           src_data_ptr[src_offset + 7 * width]};
+        vst1q_f32(packed_data_ptr + packed_offset, vs0);
+        vst1q_f32(packed_data_ptr + packed_offset + 4, vs1);
+      }
+    }
+    h += (height - h) / 8 * 8;
+#endif
+#pragma omp parallel for
+    for (index_t ih = h; ih <= height - 4; ih += 4) {
+      const float *src_data_ptr = src_data + ih * width;
+      float *packed_data_ptr = packed_data + ih * width;
+      for (index_t w = 0; w < width; ++w) {
+        const index_t src_offset = w;
+        const index_t packed_offset = w * 4;
+        float32x4_t vs = {src_data_ptr[src_offset],
+                          src_data_ptr[src_offset + width],
+                          src_data_ptr[src_offset + 2 * width],
+                          src_data_ptr[src_offset + 3 * width]};
+        vst1q_f32(packed_data_ptr + packed_offset, vs);
+      }
+    }
+    h += (height - h) / 4 * 4;
+#endif
+#pragma omp parallel for
+    for (index_t ih = h; ih < height; ++ih) {
+      std::copy_n(src_data + ih * width, width, packed_data + ih * width);
+    }
+  } else if (src.major() == Major::ColMajor && order == PackOrder::ColMajor) {
+    // This is for packing transpose-needed lhs.
+    index_t h = 0;
+#if defined(MACE_ENABLE_NEON)
+#if defined(__aarch64__)
+#pragma omp parallel for
+    for (index_t ih = h; ih <= height - 8; ih += 8) {
+      const float *src_data_ptr = src_data + ih;
+      float *packed_data_ptr = packed_data + ih * width;
+      for (index_t w = 0; w < width; ++w) {
+        const index_t src_offset = w * height;
+        const index_t packed_offset = w * 8;
+        float32x4_t vs0 = vld1q_f32(src_data_ptr + src_offset);
+        float32x4_t vs1 = vld1q_f32(src_data_ptr + src_offset + 4);
+        vst1q_f32(packed_data_ptr + packed_offset, vs0);
+        vst1q_f32(packed_data_ptr + packed_offset + 4, vs1);
+      }
+    }
+    h += (height - h) / 8 * 8;
+#endif
+#pragma omp parallel for
+    for (index_t ih = h; ih <= height - 4; ih += 4) {
+      const float *src_data_ptr = src_data + ih;
+      float *packed_data_ptr = packed_data + ih * width;
+      for (index_t w = 0; w < width; ++w) {
+        const index_t src_offset = w * height;
+        const index_t packed_offset = w * 4;
+        float32x4_t vs = vld1q_f32(src_data_ptr + src_offset);
+        vst1q_f32(packed_data_ptr + packed_offset, vs);
+      }
+    }
+    h += (height - h) / 4 * 4;
+#endif
+#pragma omp parallel for
+    for (index_t ih = h; ih < height; ++ih) {
+      const float *src_data_ptr = src_data + ih;
+      float *packed_data_ptr = packed_data + ih * width;
+      for (index_t w = 0; w < width; ++w) {
+        packed_data_ptr[w] = src_data_ptr[w * height];
+      }
+    }
+  } else if (src.major() == Major::RowMajor && order == PackOrder::RowMajor) {
+    // This is for packing no-transpose rhs.
+    index_t w = 0;
+#if defined(MACE_ENABLE_NEON)
+#pragma omp parallel for
+    for (index_t iw = w; iw <= width - 4; iw += 4) {
+      const float *src_data_ptr = src_data + iw;
+      float *packed_data_ptr = packed_data + iw * height;
+      for (index_t h = 0; h < height; ++h) {
+        const index_t src_offset = h * width;
+        const index_t packed_offset = h * 4;
+        float32x4_t vs = vld1q_f32(src_data_ptr + src_offset);
+        vst1q_f32(packed_data_ptr + packed_offset, vs);
+      }
+    }
+    w += (width - w) / 4 * 4;
+#endif
+#pragma omp parallel for
+    for (index_t iw = w; iw < width; ++iw) {
+      const float *src_data_ptr = src_data + iw;
+      float *packed_data_ptr = packed_data + iw * height;
+      for (index_t h = 0; h < height; ++h) {
+        packed_data_ptr[h] = src_data_ptr[h * width];
+      }
+    }
+  } else if (src.major() == Major::ColMajor && order == PackOrder::RowMajor) {
+    // This is for packing transpose-needed rhs.
+    index_t w = 0;
+#if defined(MACE_ENABLE_NEON)
+#pragma omp parallel for
+    for (index_t iw = w; iw <= width - 4; iw += 4) {
+      const float *src_data_ptr = src_data + iw * height;
+      float *packed_data_ptr = packed_data + iw * height;
+      for (index_t h = 0; h < height; ++h) {
+        const index_t src_offset = h;
+        const index_t packed_offset = h * 4;
+        float32x4_t vs = {src_data_ptr[src_offset],
+                          src_data_ptr[src_offset + height],
+                          src_data_ptr[src_offset + 2 * height],
+                          src_data_ptr[src_offset + 3 * height]};
+        vst1q_f32(packed_data_ptr + packed_offset, vs);
+      }
+    }
+    w += (width - w) / 4 * 4;
+#endif
+#pragma omp parallel for
+    for (index_t iw = w; iw < width; ++iw) {
+      std::copy_n(src_data + iw * height, height, packed_data + iw * height);
+    }
+  }
+}
+
+void SGemm::UnPackPerBatch(const float *packed_data,
+                           const index_t batch_index,
+                           MatrixMap<float> *matrix_map) {
+  MACE_CHECK_NOTNULL(matrix_map);
+
+  const index_t height = matrix_map->row();
+  const index_t width = matrix_map->col();
+  auto unpacked_data = matrix_map->batch_data(batch_index);
+
+  if (matrix_map->major() == Major::RowMajor) {
+    // This is for non-transposed result
+    index_t w = 0;
+#if defined(MACE_ENABLE_NEON)
+#pragma omp parallel for
+    for (index_t iw = w; iw <= width - 4; iw += 4) {
+      const float *packed_data_ptr = packed_data + iw * height;
+      float *unpacked_data_ptr = unpacked_data + iw;
+      for (index_t h = 0; h < height; ++h) {
+        const index_t packed_offset = h * 4;
+        const index_t unpacked_offset = h * width;
+        float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
+        vst1q_f32(unpacked_data_ptr + unpacked_offset, vs);
+      }
+    }
+    w += (width - w) / 4 * 4;
+#endif
+#pragma omp parallel for
+    for (index_t iw = w; iw < width; ++iw) {
+      const float *packed_data_ptr = packed_data + iw * height;
+      float *unpacked_data_ptr = unpacked_data + iw;
+      for (index_t h = 0; h < height; ++h) {
+        unpacked_data_ptr[h * width] = packed_data_ptr[h];
+      }
+    }
+  } else {
+    // This is for transposed result
+    index_t w = 0;
+#if defined(MACE_ENABLE_NEON)
+#pragma omp parallel for
+    for (index_t iw = w; iw <= width - 4; iw += 4) {
+      const float *packed_data_ptr = packed_data + iw * height;
+      float *unpacked_data_ptr = unpacked_data + iw * height;
+      for (index_t h = 0; h < height; ++h) {
+        const index_t packed_offset = h * 4;
+        const index_t unpacked_offset = h;
+        float32x4_t vs = vld1q_f32(packed_data_ptr + packed_offset);
+        unpacked_data_ptr[unpacked_offset] = vs[0];
+        unpacked_data_ptr[unpacked_offset + height] = vs[1];
+        unpacked_data_ptr[unpacked_offset + 2 * height] = vs[2];
+        unpacked_data_ptr[unpacked_offset + 3 * height] = vs[3];
+      }
+    }
+    w += (width - w) / 4 * 4;
+#endif
+#pragma omp parallel for
+    for (index_t iw = w; iw < width; ++iw) {
+      std::copy_n(packed_data + iw * height, height,
+                  unpacked_data + iw * height);
+    }
+  }
+}

 }  // namespace kernels
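To make the panel layout concrete: for a row-major lhs, PackLhs stores each column of a 4-row block (8-row on aarch64) contiguously, so one vld1q_f32 in the compute loop fetches a whole sub-column. A small self-contained sketch of the 4-row panel, matching the expected outputs in sgemm_pack_test.cc below:

#include <cstdio>

int main() {
  const int height = 4, width = 4;
  const float src[16] = {1,  2,  3,  4,    // row-major 4x4 source
                         5,  6,  7,  8,
                         9,  10, 11, 12,
                         13, 14, 15, 16};
  float packed[16];
  // ColMajor pack of a RowMajor source: column w of the 4-row block becomes
  // a contiguous quad -> 1,5,9,13, 2,6,10,14, 3,7,11,15, 4,8,12,16
  for (int w = 0; w < width; ++w) {
    for (int r = 0; r < 4; ++r) {
      packed[w * 4 + r] = src[r * width + w];
    }
  }
  for (int i = 0; i < height * width; ++i) printf("%g ", packed[i]);
  printf("\n");
  return 0;
}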
mace/kernels/sgemm.h

@@ -15,6 +15,9 @@
 #ifndef MACE_KERNELS_SGEMM_H_
 #define MACE_KERNELS_SGEMM_H_

+#include <memory>
+#include <utility>
+
 #if defined(MACE_ENABLE_NEON)
 #include <arm_neon.h>
 #endif
@@ -34,22 +37,29 @@ enum Major {
 template <typename T>
 class MatrixMap {
  public:
-  MatrixMap(const index_t row,
+  MatrixMap() {}
+
+  MatrixMap(const index_t batch,
+            const index_t row,
             const index_t col,
             const Major major,
-            T *data) :
+            T *data,
+            const bool is_const = false) :
+      batch_(batch),
       row_(row),
       col_(col),
       stride_(major == RowMajor ? col : row),
       major_(major),
-      data_(data) {}
+      data_(data),
+      is_const_(is_const) {}

-  MatrixMap<T> transpose(const MatrixMap<T> &matrix_map) {
-    Major transpose_major = matrix_map.major_ == RowMajor ? ColMajor : RowMajor;
-    return MatrixMap<T>(matrix_map.col_,
-                        matrix_map.row_,
-                        transpose_major,
-                        matrix_map.data_);
-  }
+  MatrixMap transpose() const {
+    Major transpose_major = major_ == RowMajor ? ColMajor : RowMajor;
+    return MatrixMap(batch_, col_, row_, transpose_major, data_, is_const_);
+  }
+
+  index_t batch() const {
+    return batch_;
+  }

   index_t row() const {
@@ -72,66 +82,100 @@ class MatrixMap {
     return data_;
   }

-  T *data(int row, int col) const {
-    return data_ + row * stride_ + col;
+  T *batch_data(index_t batch) const {
+    return data_ + batch * row_ * col_;
+  }
+
+  index_t size() const {
+    return batch_ * row_ * col_;
+  }
+
+  bool is_const() const {
+    return is_const_;
   }

  private:
+  index_t batch_;
   index_t row_;
   index_t col_;
   index_t stride_;
   Major major_;
   T *data_;
+  bool is_const_;
 };

 typedef Major PackOrder;
-
-template <typename T>
-class PackedBlock {
- public:
-  PackedBlock() :
-      data_tensor_(GetCPUAllocator(), DataTypeToEnum<T>::v()) {}
-
-  const T *data() {
-    return data_tensor_.data<T>();
-  }
-
-  T *mutable_data() {
-    return data_tensor_.mutable_data<T>();
-  }
-
-  Tensor *tensor() {
-    return &data_tensor_;
-  }
-
- private:
-  Tensor data_tensor_;
-};
+typedef Tensor PackedBlock;

 class SGemm {
  public:
-  void operator()(const MatrixMap<float> &lhs,
-                  const MatrixMap<float> &rhs,
-                  MatrixMap<float> *result);
-
-  void operator()(const PackedBlock<float> &lhs,
-                  const PackedBlock<float> &rhs,
-                  const index_t height,
-                  const index_t depth,
-                  const index_t width,
-                  PackedBlock<float> *result);
-
-  void PackLhs(const MatrixMap<float> &lhs,
-               PackedBlock<float> *packed_block);
-
-  void PackRhs(const MatrixMap<float> &rhs,
-               PackedBlock<float> *packed_block);
-
-  void UnPack(const PackedBlock<float> &packed_result,
-              MatrixMap<float> *matrix_map);
+  SGemm()
+      : packed_lhs_(nullptr),
+        packed_rhs_(nullptr),
+        packed_(false) {}
+
+  void operator()(const MatrixMap<const float> &lhs,
+                  const MatrixMap<const float> &rhs,
+                  MatrixMap<float> *result,
+                  ScratchBuffer *scratch_buffer = nullptr);
+
+  void Run(const float *A,
+           const float *B,
+           const index_t batch,
+           const index_t height_a,
+           const index_t width_a,
+           const index_t height_b,
+           const index_t width_b,
+           const bool transpose_a,
+           const bool transpose_b,
+           const bool is_a_weight,
+           const bool is_b_weight,
+           float *C,
+           ScratchBuffer *scratch_buffer = nullptr);
+
+  void PackLhs(const MatrixMap<const float> &lhs,
+               PackedBlock *packed_block);
+
+  void PackRhs(const MatrixMap<const float> &rhs,
+               PackedBlock *packed_block);
+
+  void UnPack(const PackedBlock &packed_result,
+              MatrixMap<float> *matrix_map);

  private:
-  void Pack(const MatrixMap<float> &src,
+  void Pack(const MatrixMap<const float> &src,
             const PackOrder order,
-            PackedBlock<float> *packed_block);
+            PackedBlock *packed_block);
+
+  void PackPerBatch(const MatrixMap<const float> &src,
+                    const PackOrder order,
+                    const index_t batch_index,
+                    float *packed_data);
+
+  void UnPackPerBatch(const float *packed_data,
+                      const index_t batch_index,
+                      MatrixMap<float> *matrix_map);
+
+  void RunInternal(const PackedBlock &lhs,
+                   const PackedBlock &rhs,
+                   const index_t batch,
+                   const index_t height,
+                   const index_t depth,
+                   const index_t width,
+                   PackedBlock *result);
+
+  void RunPerBatch(const float *lhs,
+                   const float *rhs,
+                   const index_t height,
+                   const index_t depth,
+                   const index_t width,
+                   float *result);
+
+  std::unique_ptr<Tensor> packed_lhs_;
+  std::unique_ptr<Tensor> packed_rhs_;
+  std::unique_ptr<Tensor> packed_result_;
+  bool packed_;
 };

 }  // namespace kernels
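MatrixMap is a thin, non-owning view, so the rewritten transpose() just swaps row/col and flips the major order; no data moves until Pack runs. A short sketch, assuming this header:

  float buf[2 * 3 * 5];
  mace::kernels::MatrixMap<float> m(2, 3, 5, mace::kernels::RowMajor, buf);
  auto t = m.transpose();
  // t.row() == 5, t.col() == 3, t.batch() == 2, now ColMajor;
  // same underlying buf, and t.size() == 2 * 3 * 5 as before.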
mace/kernels/sgemm_pack_test.cc (new file, mode 100644)

// Copyright 2018 Xiaomi, Inc.  All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>
#include <algorithm>
#include <random>
#include <vector>

#include "mace/kernels/sgemm.h"

namespace mace {
namespace kernels {
namespace test {

namespace {

void TestPack(const std::vector<float> &data,
              const std::vector<float> &expected_data,
              const index_t height,
              const index_t width,
              Major src_order,
              PackOrder pack_order) {
  SGemm sg;
  MatrixMap<const float> src_matrix(1, height, width, src_order, data.data());
  PackedBlock packed;
  packed.Resize({height, width});
  if (pack_order == PackOrder::ColMajor) {
    sg.PackLhs(src_matrix, &packed);
  } else {
    sg.PackRhs(src_matrix, &packed);
  }
  auto packed_data = packed.data<float>();
  for (index_t i = 0; i < packed.size(); ++i) {
    EXPECT_EQ(expected_data[i], packed_data[i]);
  }
}

void TestUnPack(const index_t height,
                const index_t width,
                Major src_order,
                PackOrder pack_order) {
  static auto seed = static_cast<unsigned int>(time(nullptr));
  const index_t matrix_size = height * width;
  std::vector<float> data(matrix_size);
  for (int i = 0; i < matrix_size; ++i) {
    data[i] = rand_r(&seed);
  }

  MatrixMap<const float> src_matrix(1, height, width, src_order, data.data());
  PackedBlock packed;
  packed.Resize({height, width});
  SGemm sg;
  if (pack_order == PackOrder::ColMajor) {
    sg.PackLhs(src_matrix, &packed);
  } else {
    sg.PackRhs(src_matrix, &packed);
  }

  std::vector<float> unpacked(matrix_size);
  MatrixMap<float> unpacked_matrix(1, height, width, src_order,
                                   unpacked.data());
  sg.UnPack(packed, &unpacked_matrix);
  auto unpacked_data = unpacked.data();
  for (index_t i = 0; i < packed.size(); ++i) {
    EXPECT_EQ(data[i], unpacked_data[i]);
  }
}

}  // namespace

TEST(SGemmPackTest, Pack) {
  std::vector<float> data =
      {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
       19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36};

  // For no-transpose lhs
  TestPack(data,
           {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
           3, 4, Major::RowMajor, PackOrder::ColMajor);
#if defined(MACE_ENABLE_NEON)
  TestPack(data,
           {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
           4, 4, Major::RowMajor, PackOrder::ColMajor);
  TestPack(data,
           {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16,
            17, 18, 19, 20},
           5, 4, Major::RowMajor, PackOrder::ColMajor);
#if defined(__aarch64__)
  TestPack(data,
           {1, 5, 9, 13, 17, 21, 25, 29, 2, 6, 10, 14, 18, 22, 26, 30,
            3, 7, 11, 15, 19, 23, 27, 31, 4, 8, 12, 16, 20, 24, 28, 32,
            33, 34, 35, 36},
           9, 4, Major::RowMajor, PackOrder::ColMajor);
#endif
#endif
  // For transpose-needed lhs
  TestPack(data,
           {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
           3, 4, Major::ColMajor, PackOrder::ColMajor);
#if defined(MACE_ENABLE_NEON)
  TestPack(data,
           {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
           4, 4, Major::ColMajor, PackOrder::ColMajor);
  TestPack(data,
           {1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19,
            5, 10, 15, 20},
           5, 4, Major::ColMajor, PackOrder::ColMajor);
#if defined(__aarch64__)
  TestPack(data,
           {1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17,
            19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 32, 33, 34, 35,
            9, 18, 27, 36},
           9, 4, Major::ColMajor, PackOrder::ColMajor);
#endif
#endif
  // For no-transpose rhs
  TestPack(data,
           {1, 4, 7, 10, 2, 5, 8, 11, 3, 6, 9, 12},
           4, 3, Major::RowMajor, PackOrder::RowMajor);
#if defined(MACE_ENABLE_NEON)
  TestPack(data,
           {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
           4, 4, Major::RowMajor, PackOrder::RowMajor);
  TestPack(data,
           {1, 2, 3, 4, 6, 7, 8, 9, 11, 12, 13, 14, 16, 17, 18, 19,
            5, 10, 15, 20},
           4, 5, Major::RowMajor, PackOrder::RowMajor);
#endif
  // For transpose-needed rhs
  TestPack(data,
           {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
           4, 3, Major::ColMajor, PackOrder::RowMajor);
#if defined(MACE_ENABLE_NEON)
  TestPack(data,
           {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16},
           4, 4, Major::ColMajor, PackOrder::RowMajor);
  TestPack(data,
           {1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, 4, 8, 12, 16,
            17, 18, 19, 20},
           4, 5, Major::ColMajor, PackOrder::RowMajor);
#endif
}

TEST(SGemmPackTest, UnPack) {
  TestUnPack(4, 3, Major::RowMajor, PackOrder::RowMajor);
  TestUnPack(4, 4, Major::RowMajor, PackOrder::RowMajor);
  TestUnPack(4, 5, Major::RowMajor, PackOrder::RowMajor);
  TestUnPack(4, 100, Major::RowMajor, PackOrder::RowMajor);
  TestUnPack(4, 3, Major::ColMajor, PackOrder::RowMajor);
  TestUnPack(4, 4, Major::ColMajor, PackOrder::RowMajor);
  TestUnPack(4, 5, Major::ColMajor, PackOrder::RowMajor);
  TestUnPack(4, 100, Major::ColMajor, PackOrder::RowMajor);
}

}  // namespace test
}  // namespace kernels
}  // namespace mace
mace/ops/matmul.h

@@ -40,7 +40,11 @@ class MatMulOp : public Operator<D, T> {
                "than or equal to 2");
     index_t rank = A->dim_size();
     for (index_t i = 0; i < rank - 2; ++i) {
-      MACE_CHECK(A->dim(i) == B->dim(i), "batch dimensions are not equal");
+      MACE_CHECK(A->dim(i) == B->dim(i),
+                 "batch dimensions are not equal: ",
+                 A->dim(i),
+                 " vs. ",
+                 B->dim(i));
     }
     index_t ak = transpose_a_ ? A->dim(rank - 2) : A->dim(rank - 1);
     index_t bk = transpose_b_ ? B->dim(rank - 1) : B->dim(rank - 2);
mace/ops/matmul_benchmark.cc

@@ -33,13 +33,15 @@ void MatMulBenchmark(
   // Add input data
   net.AddRandomInput<D, T>("A", {batch, height, channels});
   net.AddRandomInput<D, T>("B", {batch, channels, out_width});
+  net.GetTensor("A")->SetIsWeight(true);
+  net.GetTensor("B")->SetIsWeight(true);
   if (DataTypeToEnum<T>::value == DT_UINT8) {
     net.GetTensor("A")->SetScale(0.1);
     net.GetTensor("B")->SetScale(0.1);
   }
   if (D == DeviceType::GPU) {
     BufferToImage<D, T>(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);
     BufferToImage<D, T>(&net, "B", "BImage",
                         kernels::BufferType::IN_OUT_HEIGHT);
@@ -71,7 +73,7 @@ void MatMulBenchmark(
   mace::testing::StartTiming();
   while (iters--) {
-    net.RunOp(D);
+    net.Run();
   }
   net.Sync();
 }
@@ -86,6 +88,8 @@ void MatMulTransposeBenchmark(
   // Add input data
   net.AddRandomInput<D, T>("A", {batch, height, channels});
   net.AddRandomInput<D, T>("B", {batch, out_width, channels});
+  net.GetTensor("A")->SetIsWeight(true);
+  net.GetTensor("B")->SetIsWeight(true);
   if (DataTypeToEnum<T>::value == DT_UINT8) {
     net.GetTensor("A")->SetScale(0.1);
     net.GetTensor("B")->SetScale(0.1);
@@ -116,7 +120,7 @@ void MatMulTransposeBenchmark(
   mace::testing::StartTiming();
   while (iters--) {
-    net.RunOp(D);
+    net.Run();
   }
   net.Sync();
 }
@@ -154,10 +158,15 @@ void MatMulTransposeBenchmark(
   MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, float, CPU); \
   MACE_BM_MATMUL_TRANSPOSE_MACRO(N, H, C, W, uint8_t, CPU);

+MACE_BM_MATMUL(1, 128, 128, 49);
+MACE_BM_MATMUL(2, 128, 128, 49);
+MACE_BM_MATMUL(3, 128, 128, 49);
+MACE_BM_MATMUL(4, 128, 128, 49);
 MACE_BM_MATMUL(16, 32, 128, 49);
 MACE_BM_MATMUL(16, 32, 128, 961);
 MACE_BM_MATMUL(16, 32, 128, 3969);
 MACE_BM_MATMUL(16, 128, 128, 49);
+MACE_BM_MATMUL(16, 49, 128, 128);
 MACE_BM_MATMUL(16, 128, 128, 961);
 MACE_BM_MATMUL(16, 128, 128, 3969);
mace/ops/winograd_transform_benchmark.cc

@@ -211,8 +211,8 @@ void WinoMatMulBenchmark(
   const index_t round_w = (width + block_size - 1) / block_size;
   const index_t out_width = round_h * round_w;
   // Add input data
-  net.AddRandomInput<D, float>("A", {batch, out_channels, in_channels, 1});
-  net.AddRandomInput<D, float>("B", {batch, in_channels, out_width, 1});
+  net.AddRandomInput<D, float>("A", {batch, out_channels, in_channels});
+  net.AddRandomInput<D, float>("B", {batch, in_channels, out_width});
   if (D == DeviceType::GPU) {
     BufferToImage<D, T>(&net, "A", "AImage", kernels::BufferType::IN_OUT_WIDTH);