Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
项目经理老王
Mace
提交
379c730d
Mace
项目概览
项目经理老王
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
379c730d
编写于
9月 07, 2018
作者:
李
李滨
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'gemm_tile' into 'master'
Optimize gemm tiling See merge request !787
上级
170251df
81202cd4
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
285 addition
and
233 deletion
+285
-233
mace/core/allocator.h
mace/core/allocator.h
+2
-2
mace/core/runtime/cpu/cpu_runtime.cc
mace/core/runtime/cpu/cpu_runtime.cc
+4
-0
mace/core/runtime/cpu/cpu_runtime.h
mace/core/runtime/cpu/cpu_runtime.h
+2
-0
mace/core/tensor.h
mace/core/tensor.h
+15
-3
mace/core/workspace.cc
mace/core/workspace.cc
+2
-2
mace/kernels/gemm.cc
mace/kernels/gemm.cc
+207
-213
mace/kernels/gemm.h
mace/kernels/gemm.h
+6
-0
mace/kernels/gemm_test.cc
mace/kernels/gemm_test.cc
+2
-0
mace/kernels/matmul.h
mace/kernels/matmul.h
+23
-5
mace/kernels/transpose.h
mace/kernels/transpose.h
+15
-3
mace/ops/resize_bicubic_test.cc
mace/ops/resize_bicubic_test.cc
+4
-4
mace/ops/transpose_benchmark.cc
mace/ops/transpose_benchmark.cc
+3
-0
mace/ops/unstack_test.cc
mace/ops/unstack_test.cc
+0
-1
未找到文件。
mace/core/allocator.h
浏览文件 @
379c730d
...
...
@@ -34,8 +34,8 @@ namespace mace {
#if defined(__hexagon__)
constexpr
size_t
kMaceAlignment
=
128
;
#elif defined(__ANDROID__)
//
16 bytes = 128 bits = 32 * 4 (Neon)
constexpr
size_t
kMaceAlignment
=
16
;
//
arm cache line
constexpr
size_t
kMaceAlignment
=
64
;
#else
// 32 bytes = 256 bits (AVX512)
constexpr
size_t
kMaceAlignment
=
32
;
...
...
mace/core/runtime/cpu/cpu_runtime.cc
浏览文件 @
379c730d
...
...
@@ -35,6 +35,8 @@
namespace
mace
{
int
MaceOpenMPThreadCount
=
1
;
namespace
{
int
GetCPUCount
()
{
...
...
@@ -136,6 +138,8 @@ MaceStatus GetCPUBigLittleCoreIDs(std::vector<int> *big_core_ids,
MaceStatus
SetOpenMPThreadsAndAffinityCPUs
(
int
omp_num_threads
,
const
std
::
vector
<
int
>
&
cpu_ids
)
{
MaceOpenMPThreadCount
=
omp_num_threads
;
#ifdef MACE_ENABLE_OPENMP
VLOG
(
1
)
<<
"Set OpenMP threads number: "
<<
omp_num_threads
<<
", CPU core IDs: "
<<
MakeString
(
cpu_ids
);
...
...
mace/core/runtime/cpu/cpu_runtime.h
浏览文件 @
379c730d
...
...
@@ -22,6 +22,8 @@
namespace
mace
{
extern
int
MaceOpenMPThreadCount
;
MaceStatus
GetCPUBigLittleCoreIDs
(
std
::
vector
<
int
>
*
big_core_ids
,
std
::
vector
<
int
>
*
little_core_ids
);
...
...
mace/core/tensor.h
浏览文件 @
379c730d
...
...
@@ -100,31 +100,38 @@ enum DataFormat { NHWC = 0, NCHW = 1, HWOI = 2, OIHW = 3, HWIO = 4, OHWI = 5 };
class
Tensor
{
public:
Tensor
(
Allocator
*
alloc
,
DataType
type
)
Tensor
(
Allocator
*
alloc
,
DataType
type
,
bool
is_weight
=
false
)
:
allocator_
(
alloc
),
dtype_
(
type
),
buffer_
(
nullptr
),
is_buffer_owner_
(
true
),
unused_
(
false
),
name_
(
""
),
is_weight_
(
is_weight
),
scale_
(
0.
f
),
zero_point_
(
0
)
{}
Tensor
(
BufferBase
*
buffer
,
DataType
dtype
)
Tensor
(
BufferBase
*
buffer
,
DataType
dtype
,
bool
is_weight
=
false
)
:
dtype_
(
dtype
),
buffer_
(
buffer
),
is_buffer_owner_
(
false
),
unused_
(
false
),
name_
(
""
),
is_weight_
(
is_weight
),
scale_
(
0.
f
),
zero_point_
(
0
)
{}
Tensor
(
const
BufferSlice
&
buffer_slice
,
DataType
dtype
)
Tensor
(
const
BufferSlice
&
buffer_slice
,
DataType
dtype
,
bool
is_weight
=
false
)
:
dtype_
(
dtype
),
buffer_slice_
(
buffer_slice
),
is_buffer_owner_
(
false
),
unused_
(
false
),
name_
(
""
),
is_weight_
(
is_weight
),
scale_
(
0.
f
),
zero_point_
(
0
)
{
buffer_
=
&
buffer_slice_
;
...
...
@@ -373,6 +380,10 @@ class Tensor {
MACE_DISABLE_COPY_AND_ASSIGN
(
MappingGuard
);
};
inline
bool
is_weight
()
const
{
return
is_weight_
;
}
inline
float
scale
()
const
{
return
scale_
;
}
...
...
@@ -399,6 +410,7 @@ class Tensor {
bool
is_buffer_owner_
;
bool
unused_
;
std
::
string
name_
;
const
bool
is_weight_
;
float
scale_
;
int32_t
zero_point_
;
...
...
mace/core/workspace.cc
浏览文件 @
379c730d
...
...
@@ -105,7 +105,7 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
std
::
unique_ptr
<
Tensor
>
tensor
(
new
Tensor
(
GetDeviceAllocator
(
type
),
const_tensor
.
data_type
()));
const_tensor
.
data_type
()
,
true
));
tensor
->
Resize
(
dims
);
MACE_CHECK
(
tensor
->
size
()
==
const_tensor
.
data_size
(),
...
...
@@ -159,7 +159,7 @@ MaceStatus Workspace::LoadModelTensor(const NetDef &net_def,
tensor_buffer_
.
get
(),
const_tensor
.
offset
(),
const_tensor
.
data_size
()
*
GetEnumTypeSize
(
const_tensor
.
data_type
())),
const_tensor
.
data_type
()));
const_tensor
.
data_type
()
,
true
));
tensor
->
Reshape
(
dims
);
tensor
->
SetScale
(
const_tensor
.
scale
());
...
...
mace/kernels/gemm.cc
浏览文件 @
379c730d
...
...
@@ -14,8 +14,10 @@
#include <algorithm>
#include <cstring>
#include <vector>
#include "mace/core/tensor.h"
#include "mace/core/runtime/cpu/cpu_runtime.h"
#include "mace/kernels/gemm.h"
/**
...
...
@@ -329,37 +331,6 @@ inline void Gemm644(const float *a_ptr,
#endif
}
inline
void
GemmX44
(
const
float
*
a_ptr
,
const
float
*
b_ptr
,
const
index_t
stride_a
,
const
index_t
stride_b
,
const
index_t
stride_c
,
float
*
c_ptr
,
int
row
)
{
switch
(
row
)
{
case
1
:
Gemm144
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
2
:
Gemm244
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
3
:
Gemm344
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
4
:
Gemm444
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
5
:
Gemm544
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
6
:
Gemm644
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
default:
MACE_NOT_IMPLEMENTED
;
}
}
inline
void
Gemm884
(
const
float
*
a_ptr
,
const
float
*
b_ptr
,
const
index_t
stride_a
,
...
...
@@ -770,43 +741,6 @@ inline void Gemm784(const float *a_ptr,
#endif
}
inline
void
GemmX84
(
const
float
*
a_ptr
,
const
float
*
b_ptr
,
const
index_t
stride_a
,
const
index_t
stride_b
,
const
index_t
stride_c
,
float
*
c_ptr
,
int
row
)
{
switch
(
row
)
{
case
1
:
Gemm184
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
2
:
Gemm284
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
3
:
Gemm384
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
4
:
Gemm484
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
5
:
Gemm584
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
6
:
Gemm684
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
7
:
Gemm784
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
case
8
:
Gemm884
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
break
;
default:
MACE_NOT_IMPLEMENTED
;
}
}
inline
void
GemmTile
(
const
float
*
A
,
const
float
*
B
,
const
index_t
height
,
...
...
@@ -873,6 +807,8 @@ inline void GemmTile(const float *A,
float
*
c_ptr7
=
C
+
(
h
+
7
)
*
stride_c
;
asm
volatile
(
"0:
\n
"
"prfm pldl1keep, [%9, #128]
\n
"
"ld1 {v16.4s}, [%9], #16
\n
"
...
...
@@ -882,8 +818,6 @@ inline void GemmTile(const float *A,
"prfm pldl1keep, [%2, #128]
\n
"
"ld1 {v19.4s}, [%2]
\n
"
"0:
\n
"
"prfm pldl1keep, [%3, #128]
\n
"
"ld1 {v20.4s}, [%3]
\n
"
"prfm pldl1keep, [%4, #128]
\n
"
...
...
@@ -1002,19 +936,13 @@ inline void GemmTile(const float *A,
"fmla v24.4s, v17.4s, %48.s[3]
\n
"
"fmla v25.4s, v17.4s, %49.s[3]
\n
"
"subs %w0, %w0, #1
\n
"
"st1 {v22.4s}, [%5], #16
\n
"
"st1 {v23.4s}, [%6], #16
\n
"
"st1 {v24.4s}, [%7], #16
\n
"
"st1 {v25.4s}, [%8], #16
\n
"
"prfm pldl1keep, [%9, #128]
\n
"
"ld1 {v16.4s}, [%9], #16
\n
"
"prfm pldl1keep, [%1, #128]
\n
"
"ld1 {v18.4s}, [%1]
\n
"
"prfm pldl1keep, [%2, #128]
\n
"
"ld1 {v19.4s}, [%2]
\n
"
"subs %w0, %w0, #1
\n
"
"bne 0b
\n
"
:
"=r"
(
nw
),
// 0
"=r"
(
c_ptr0
),
// 1
...
...
@@ -1102,6 +1030,8 @@ inline void GemmTile(const float *A,
float
*
c_ptr5
=
C
+
(
h
+
5
)
*
stride_c
;
asm
volatile
(
"0:
\n
"
"pld [%7, #128]
\n
"
"vld1.f32 {d12-d13}, [%7]!
\n
"
"pld [%1, #128]
\n
"
...
...
@@ -1109,8 +1039,6 @@ inline void GemmTile(const float *A,
"pld [%2, #128]
\n
"
"vld1.f32 {d18-d19}, [%2]
\n
"
"0:
\n
"
"pld [%3, #128]
\n
"
"vld1.f32 {d20-d21}, [%3]
\n
"
"pld [%4, #128]
\n
"
...
...
@@ -1159,22 +1087,11 @@ inline void GemmTile(const float *A,
"vst1.f32 {d16-d17}, [%1]!
\n
"
"vst1.f32 {d18-d19}, [%2]!
\n
"
"pld [%7, #128]
\n
"
"vld1.f32 {d12-d13}, [%7]!
\n
"
"vst1.f32 {d20-d21}, [%3]!
\n
"
"vst1.f32 {d22-d23}, [%4]!
\n
"
"pld [%1, #128]
\n
"
"vld1.f32 {d16-d17}, [%1]
\n
"
"vst1.f32 {d24-d25}, [%5]!
\n
"
"vst1.f32 {d26-d27}, [%6]!
\n
"
"pld [%2, #128]
\n
"
"vld1.f32 {d18-d19}, [%2]
\n
"
"subs %0, #1
\n
"
"bne 0b
\n
"
:
"=r"
(
nw
),
// 0
...
...
@@ -1228,17 +1145,69 @@ inline void GemmTile(const float *A,
}
if
(
h
<
height
)
{
index_t
remain_h
=
height
-
h
;
auto
gemm_fn
=
Gemm184
;
switch
(
remain_h
)
{
case
1
:
#if defined(__aarch64__)
gemm_fn
=
Gemm184
;
#else
gemm_fn
=
Gemm144
;
#endif
break
;
case
2
:
#if defined(__aarch64__)
gemm_fn
=
Gemm284
;
#else
gemm_fn
=
Gemm244
;
#endif
break
;
case
3
:
#if defined(__aarch64__)
gemm_fn
=
Gemm384
;
#else
gemm_fn
=
Gemm344
;
#endif
break
;
case
4
:
#if defined(__aarch64__)
gemm_fn
=
Gemm484
;
#else
gemm_fn
=
Gemm444
;
#endif
break
;
case
5
:
#if defined(__aarch64__)
gemm_fn
=
Gemm584
;
#else
gemm_fn
=
Gemm544
;
#endif
break
;
case
6
:
#if defined(__aarch64__)
gemm_fn
=
Gemm684
;
#else
LOG
(
FATAL
)
<<
"remain_h should < 6"
;
#endif
break
;
case
7
:
#if defined(__aarch64__)
gemm_fn
=
Gemm784
;
#else
LOG
(
FATAL
)
<<
"remain_h should < 6"
;
#endif
break
;
default:
LOG
(
FATAL
)
<<
"remain_h should < 8"
;
}
for
(
k
=
0
;
k
<
K
-
reg_K_tile
;
k
+=
reg_K_tile
)
{
const
float
*
a_ptr
=
A
+
(
h
*
stride_a
+
k
);
index_t
w
;
for
(
w
=
0
;
w
+
3
<
width
;
w
+=
4
)
{
const
float
*
b_ptr
=
B
+
(
k
*
stride_b
+
w
);
float
*
c_ptr
=
C
+
(
h
*
stride_c
+
w
);
#if defined(__aarch64__)
GemmX84
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
,
remain_h
);
#else
GemmX44
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
,
remain_h
);
#endif
gemm_fn
(
a_ptr
,
b_ptr
,
stride_a
,
stride_b
,
stride_c
,
c_ptr
);
}
if
(
w
<
width
)
{
const
float
*
b_ptr
=
B
+
(
k
*
stride_b
+
w
);
...
...
@@ -1260,20 +1229,27 @@ inline void GemmTile(const float *A,
#endif // MACE_ENABLE_NEON
}
}
// namespace
void
Transpose
(
const
float
*
src
,
index_t
height
,
index_t
width
,
index_t
stride_w
,
float
*
dst
)
{
for
(
index_t
h
=
0
;
h
<
height
;
++
h
)
{
for
(
index_t
w
=
0
;
w
<
width
;
++
w
)
{
dst
[
w
*
height
+
h
]
=
src
[
h
*
stride_w
+
w
];
index_t
tile_size
=
height
>
512
||
width
>
512
?
64
:
32
;
for
(
index_t
i
=
0
;
i
<
height
;
i
+=
tile_size
)
{
for
(
index_t
j
=
0
;
j
<
width
;
j
+=
tile_size
)
{
index_t
end_i
=
std
::
min
(
i
+
tile_size
,
height
);
index_t
end_j
=
std
::
min
(
j
+
tile_size
,
width
);
for
(
index_t
tile_i
=
i
;
tile_i
<
end_i
;
++
tile_i
)
{
for
(
index_t
tile_j
=
j
;
tile_j
<
end_j
;
++
tile_j
)
{
dst
[
tile_j
*
height
+
tile_i
]
=
src
[
tile_i
*
stride_w
+
tile_j
];
}
}
}
}
}
}
// namespace
// A: height x K, B: K x width, C: height x width
void
Gemm
(
const
float
*
A
,
const
float
*
B
,
...
...
@@ -1284,7 +1260,7 @@ void Gemm(const float *A,
float
*
C
,
const
bool
transpose_a
,
const
bool
transpose_b
)
{
if
(
width
==
1
)
{
if
(
width
==
1
&&
!
transpose_a
)
{
for
(
index_t
b
=
0
;
b
<
batch
;
++
b
)
{
Gemv
(
A
+
b
*
height
*
K
,
B
+
b
*
K
,
1
,
K
,
height
,
C
+
b
*
height
);
}
...
...
@@ -1292,45 +1268,78 @@ void Gemm(const float *A,
}
memset
(
C
,
0
,
sizeof
(
float
)
*
batch
*
height
*
width
);
// It is better to use large block size if it fits for fast cache.
// Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
// the block size should be sqrt(32k / sizeof(T) / 3).
// As number of input channels of convolution is normally power of 2, and
// we have not optimized tiling remains, we use the following magic number
const
index_t
block_size
=
64
;
const
index_t
block_tile_height
=
RoundUpDiv
(
height
,
block_size
);
const
index_t
block_tile_width
=
RoundUpDiv
(
width
,
block_size
);
const
index_t
block_tile_k
=
RoundUpDiv
(
K
,
block_size
);
const
index_t
block_tile
[
3
]
=
{
block_tile_height
,
block_tile_width
,
block_tile_k
};
const
index_t
remain_height
=
height
%
block_size
;
const
index_t
remain_width
=
width
%
block_size
;
const
index_t
remain_k
=
K
%
block_size
;
const
index_t
remain
[
3
]
=
{
remain_height
,
remain_width
,
remain_k
};
std
::
vector
<
index_t
>
block_size_dims
{
height
,
width
,
K
};
index_t
thread_count
=
MaceOpenMPThreadCount
;
MACE_CHECK
(
thread_count
>=
1
,
"thread should be ge 1"
);
// TODO(liyin): apply gcd ?
if
(
height
%
thread_count
==
0
)
{
block_size_dims
[
0
]
=
height
/
thread_count
;
}
else
if
(
thread_count
==
4
&&
(
height
&
1
)
==
0
&&
(
width
&
1
)
==
0
)
{
block_size_dims
[
0
]
=
height
>>
1
;
block_size_dims
[
1
]
=
width
>>
1
;
}
else
if
(
width
%
thread_count
==
0
)
{
block_size_dims
[
1
]
=
width
/
thread_count
;
}
else
{
if
(
height
>=
thread_count
)
{
block_size_dims
[
0
]
=
height
/
thread_count
;
}
else
{
thread_count
=
std
::
min
(
thread_count
,
height
*
width
);
index_t
thread_h
=
height
;
index_t
thread_w
=
RoundUpDiv
(
thread_count
,
thread_h
);
block_size_dims
[
0
]
=
1
;
block_size_dims
[
1
]
=
std
::
max
(
static_cast
<
index_t
>
(
1
),
width
/
thread_w
);
}
}
const
index_t
block_tile
[
3
]
=
{
height
/
block_size_dims
[
0
],
width
/
block_size_dims
[
1
],
K
/
block_size_dims
[
2
]};
block_size_dims
[
0
]
=
height
/
block_tile
[
0
];
block_size_dims
[
1
]
=
width
/
block_tile
[
1
];
block_size_dims
[
2
]
=
K
/
block_tile
[
2
];
const
index_t
remain
[
3
]
=
{
height
%
block_tile
[
0
],
width
%
block_tile
[
1
],
K
%
block_tile
[
2
]};
#pragma omp parallel for collapse(3)
for
(
index_t
n
=
0
;
n
<
batch
;
++
n
)
{
for
(
index_t
bh
=
0
;
bh
<
block_tile
[
0
];
++
bh
)
{
for
(
index_t
bw
=
0
;
bw
<
block_tile
[
1
];
++
bw
)
{
const
index_t
remain_height
=
remain
[
0
];
const
index_t
remain_width
=
remain
[
1
];
const
index_t
remain_k
=
remain
[
2
];
const
index_t
block_size_height
=
block_size_dims
[
0
];
const
index_t
block_size_width
=
block_size_dims
[
1
];
const
index_t
block_size_k
=
block_size_dims
[
2
];
const
index_t
this_block_size_height
=
block_size_height
+
(
bh
<
remain_height
?
1
:
0
);
const
index_t
this_block_size_width
=
block_size_width
+
(
bw
<
remain_width
?
1
:
0
);
const
float
*
a_base
=
A
+
n
*
height
*
K
;
const
float
*
b_base
=
B
+
n
*
K
*
width
;
float
*
c_base
=
C
+
n
*
height
*
width
;
const
index_t
ih_begin
=
bh
*
block_size
;
const
index_t
ih_end
=
bh
*
block_size
+
(
bh
==
block_tile
[
0
]
-
1
&&
remain
[
0
]
>
0
?
remain
[
0
]
:
block_size
);
const
index_t
iw_begin
=
bw
*
block_size
;
const
index_t
iw_end
=
bw
*
block_size
+
(
bw
==
block_tile
[
1
]
-
1
&&
remain
[
1
]
>
0
?
remain
[
1
]
:
block_size
);
const
index_t
ih_begin
=
bh
*
block_size_height
+
(
bh
<
remain_height
?
bh
:
remain_height
);
const
index_t
ih_end
=
std
::
min
(
height
,
ih_begin
+
this_block_size_height
);
const
index_t
iw_begin
=
bw
*
block_size_width
+
(
bw
<
remain_width
?
bw
:
remain_width
);
const
index_t
iw_end
=
std
::
min
(
width
,
iw_begin
+
this_block_size_width
);
for
(
index_t
bk
=
0
;
bk
<
block_tile
[
2
];
++
bk
)
{
const
index_t
ik_begin
=
bk
*
block_size
;
const
index_t
ik_end
=
bk
*
block_size
+
(
bk
==
block_tile
[
2
]
-
1
&&
remain
[
2
]
>
0
?
remain
[
2
]
:
block_size
);
const
index_t
this_block_size_k
=
block_size_k
+
(
bk
<
remain_k
?
1
:
0
);
const
index_t
ik_begin
=
bk
*
block_size_k
+
(
bk
<
remain_k
?
bk
:
remain_k
);
const
index_t
ik_end
=
std
::
min
(
K
,
ik_begin
+
this_block_size_k
);
Tensor
trans_a
;
Tensor
trans_b
;
...
...
@@ -1342,7 +1351,7 @@ void Gemm(const float *A,
index_t
stride_c
=
width
;
if
(
transpose_a
)
{
trans_a
.
Resize
({
block_size
,
block_size
});
trans_a
.
Resize
({
this_block_size_height
,
this_block_size_k
});
float
*
trans_a_data
=
trans_a
.
mutable_data
<
float
>
();
// A[K, H] -> A[H, K]
Transpose
(
a_base
+
(
ik_begin
*
height
+
ih_begin
),
...
...
@@ -1356,7 +1365,7 @@ void Gemm(const float *A,
}
if
(
transpose_b
)
{
trans_b
.
Resize
({
block_size
,
block_size
});
trans_b
.
Resize
({
this_block_size_k
,
this_block_size_width
});
float
*
trans_b_data
=
trans_b
.
mutable_data
<
float
>
();
// B[W, K] -> B[K, W]
Transpose
(
b_base
+
(
iw_begin
*
K
+
ik_begin
),
iw_end
-
iw_begin
,
...
...
@@ -1449,7 +1458,6 @@ void GemvRef(const float *m_ptr,
}
}
// TODO(liyin): batched gemv can be transformed to gemm (w/ transpose)
void
Gemv
(
const
float
*
m_ptr
,
const
float
*
v_ptr
,
const
index_t
batch
,
...
...
@@ -1457,88 +1465,74 @@ void Gemv(const float *m_ptr,
const
index_t
height
,
float
*
out_ptr
)
{
#if defined(MACE_ENABLE_NEON)
// TODO(liyin/wch): try height tiling = 8
#pragma omp parallel for collapse(2)
for
(
index_t
b
=
0
;
b
<
batch
;
++
b
)
{
for
(
index_t
h
=
0
;
h
<
height
;
h
+=
4
)
{
if
(
h
+
3
<
height
)
{
const
float
*
m_ptr0
=
m_ptr
+
h
*
width
;
const
float
*
m_ptr1
=
m_ptr0
+
width
;
const
float
*
m_ptr2
=
m_ptr1
+
width
;
const
float
*
m_ptr3
=
m_ptr2
+
width
;
const
float
*
v_ptr0
=
v_ptr
+
b
*
width
;
float
*
out_ptr0
=
out_ptr
+
b
*
height
+
h
;
float32x4_t
vm0
,
vm1
,
vm2
,
vm3
;
float32x4_t
vv
;
float32x4_t
vsum0
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum1
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum2
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum3
=
vdupq_n_f32
(
0.
f
);
index_t
w
;
for
(
w
=
0
;
w
+
3
<
width
;
w
+=
4
)
{
vm0
=
vld1q_f32
(
m_ptr0
);
vm1
=
vld1q_f32
(
m_ptr1
);
vm2
=
vld1q_f32
(
m_ptr2
);
vm3
=
vld1q_f32
(
m_ptr3
);
vv
=
vld1q_f32
(
v_ptr0
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm0
,
vv
);
vsum1
=
vmlaq_f32
(
vsum1
,
vm1
,
vv
);
vsum2
=
vmlaq_f32
(
vsum2
,
vm2
,
vv
);
vsum3
=
vmlaq_f32
(
vsum3
,
vm3
,
vv
);
m_ptr0
+=
4
;
m_ptr1
+=
4
;
m_ptr2
+=
4
;
m_ptr3
+=
4
;
v_ptr0
+=
4
;
}
float
sum0
=
vaddvq_f32
(
vsum0
);
float
sum1
=
vaddvq_f32
(
vsum1
);
float
sum2
=
vaddvq_f32
(
vsum2
);
float
sum3
=
vaddvq_f32
(
vsum3
);
// handle remaining w
for
(;
w
<
width
;
++
w
)
{
sum0
+=
m_ptr0
[
0
]
*
v_ptr0
[
0
];
sum1
+=
m_ptr1
[
0
]
*
v_ptr0
[
0
];
sum2
+=
m_ptr2
[
0
]
*
v_ptr0
[
0
];
sum3
+=
m_ptr3
[
0
]
*
v_ptr0
[
0
];
m_ptr0
++
;
m_ptr1
++
;
m_ptr2
++
;
m_ptr3
++
;
v_ptr0
++
;
}
*
out_ptr0
++
=
sum0
;
*
out_ptr0
++
=
sum1
;
*
out_ptr0
++
=
sum2
;
*
out_ptr0
++
=
sum3
;
}
else
{
for
(
index_t
hh
=
h
;
hh
<
height
;
++
hh
)
{
float32x4_t
vsum0
=
vdupq_n_f32
(
0.
f
);
const
float
*
m_ptr0
=
m_ptr
+
hh
*
width
;
const
float
*
v_ptr0
=
v_ptr
+
b
*
width
;
index_t
w
;
for
(
w
=
0
;
w
+
3
<
width
;
w
+=
4
)
{
float32x4_t
vm
=
vld1q_f32
(
m_ptr0
);
float32x4_t
vv
=
vld1q_f32
(
v_ptr0
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm
,
vv
);
m_ptr0
+=
4
;
v_ptr0
+=
4
;
}
float
sum
=
vaddvq_f32
(
vsum0
);
for
(;
w
<
width
;
++
w
)
{
sum
+=
m_ptr0
[
0
]
*
v_ptr0
[
0
];
m_ptr0
++
;
v_ptr0
++
;
}
out_ptr
[
b
*
height
+
hh
]
=
sum
;
}
}
// if
for
(
index_t
h
=
0
;
h
<
height
;
++
h
)
{
const
float
*
m_ptr0
=
m_ptr
+
h
*
width
;
const
float
*
v_ptr0
=
v_ptr
+
b
*
width
;
float
*
out_ptr0
=
out_ptr
+
b
*
height
+
h
;
float32x4_t
vm0
,
vm1
,
vm2
,
vm3
;
float32x4_t
vv0
,
vv1
,
vv2
,
vv3
;
float32x4_t
vsum0
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum1
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum2
=
vdupq_n_f32
(
0.
f
);
float32x4_t
vsum3
=
vdupq_n_f32
(
0.
f
);
index_t
w
;
for
(
w
=
0
;
w
+
15
<
width
;
w
+=
16
)
{
vm0
=
vld1q_f32
(
m_ptr0
);
vv0
=
vld1q_f32
(
v_ptr0
);
vm1
=
vld1q_f32
(
m_ptr0
+
4
);
vv1
=
vld1q_f32
(
v_ptr0
+
4
);
vm2
=
vld1q_f32
(
m_ptr0
+
8
);
vv2
=
vld1q_f32
(
v_ptr0
+
8
);
vm3
=
vld1q_f32
(
m_ptr0
+
12
);
vv3
=
vld1q_f32
(
v_ptr0
+
12
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm0
,
vv0
);
vsum1
=
vmlaq_f32
(
vsum1
,
vm1
,
vv1
);
vsum2
=
vmlaq_f32
(
vsum2
,
vm2
,
vv2
);
vsum3
=
vmlaq_f32
(
vsum3
,
vm3
,
vv3
);
m_ptr0
+=
16
;
v_ptr0
+=
16
;
}
for
(;
w
+
7
<
width
;
w
+=
8
)
{
vm0
=
vld1q_f32
(
m_ptr0
);
vv0
=
vld1q_f32
(
v_ptr0
);
vm1
=
vld1q_f32
(
m_ptr0
+
4
);
vv1
=
vld1q_f32
(
v_ptr0
+
4
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm0
,
vv0
);
vsum1
=
vmlaq_f32
(
vsum1
,
vm1
,
vv1
);
m_ptr0
+=
8
;
v_ptr0
+=
8
;
}
for
(;
w
+
3
<
width
;
w
+=
4
)
{
vm0
=
vld1q_f32
(
m_ptr0
);
vv0
=
vld1q_f32
(
v_ptr0
);
vsum0
=
vmlaq_f32
(
vsum0
,
vm0
,
vv0
);
m_ptr0
+=
4
;
v_ptr0
+=
4
;
}
vsum0
+=
vsum1
;
vsum2
+=
vsum3
;
vsum0
+=
vsum2
;
float
sum0
=
vaddvq_f32
(
vsum0
);
// handle remaining w
for
(;
w
<
width
;
++
w
)
{
sum0
+=
m_ptr0
[
0
]
*
v_ptr0
[
0
];
m_ptr0
++
;
v_ptr0
++
;
}
*
out_ptr0
++
=
sum0
;
}
// h
}
// b
#else
...
...
mace/kernels/gemm.h
浏览文件 @
379c730d
...
...
@@ -66,6 +66,12 @@ void GemvRef(const float *m_ptr,
const
index_t
height
,
float
*
out_ptr
);
void
Transpose
(
const
float
*
src
,
index_t
height
,
index_t
width
,
index_t
stride_w
,
float
*
dst
);
}
// namespace kernels
}
// namespace mace
...
...
mace/kernels/gemm_test.cc
浏览文件 @
379c730d
...
...
@@ -83,6 +83,8 @@ TEST(GEMMTest, AlignedWithoutBatch) {
GemmTest
(
1
,
6
,
64
,
128
,
false
,
true
);
GemmTest
(
1
,
7
,
64
,
128
,
true
,
false
);
GemmTest
(
1
,
17
,
64
,
128
,
true
,
true
);
GemmTest
(
1
,
256
,
128
,
4096
,
false
,
false
);
GemmTest
(
1
,
256
,
128
,
4104
,
false
,
false
);
}
TEST
(
GEMMTest
,
UnalignedWithoutBatch
)
{
...
...
mace/kernels/matmul.h
浏览文件 @
379c730d
...
...
@@ -81,16 +81,34 @@ struct MatMulFunctor {
const
T
*
b_ptr_base
=
B
->
data
<
T
>
();
T
*
c_ptr_base
=
C
->
mutable_data
<
T
>
();
// It is better to use large block size if it fits for fast cache.
// Assume l1 cache size is 32k, we load three blocks at a time (A, B, C),
// the block size should be sqrt(32k / sizeof(T) / 3).
memset
(
c_ptr_base
,
0
,
batch
*
height
*
width
*
sizeof
(
T
));
Gemm
(
a_ptr_base
,
b_ptr_base
,
batch
,
height
,
K
,
width
,
c_ptr_base
,
transpose_a
,
transpose_b
);
if
(
height
==
1
&&
width
>
1
&&
B
->
is_weight
())
{
// A * B = (B^T * A^T)^T
if
(
!
transpose_b
)
{
if
(
B_transpose_
.
get
()
==
nullptr
)
{
B_transpose_
.
reset
(
new
Tensor
(
GetDeviceAllocator
(
D
),
DataTypeToEnum
<
T
>::
v
()));
B_transpose_
->
Resize
({
batch
,
width
,
K
});
Tensor
::
MappingGuard
guardbt
(
B_transpose_
.
get
());
T
*
bt_ptr_base
=
B_transpose_
->
mutable_data
<
T
>
();
Transpose
(
b_ptr_base
,
K
,
width
,
width
,
bt_ptr_base
);
}
Tensor
::
MappingGuard
guardbt
(
B_transpose_
.
get
());
T
*
bt_ptr_base
=
B_transpose_
->
mutable_data
<
T
>
();
Gemv
(
bt_ptr_base
,
a_ptr_base
,
batch
,
K
,
width
,
c_ptr_base
);
}
else
{
Gemv
(
b_ptr_base
,
a_ptr_base
,
batch
,
K
,
width
,
c_ptr_base
);
}
}
else
{
Gemm
(
a_ptr_base
,
b_ptr_base
,
batch
,
height
,
K
,
width
,
c_ptr_base
,
transpose_a
,
transpose_b
);
}
return
MACE_SUCCESS
;
}
std
::
unique_ptr
<
Tensor
>
B_transpose_
;
};
template
<
>
...
...
mace/kernels/transpose.h
浏览文件 @
379c730d
...
...
@@ -20,6 +20,7 @@
#endif
#include <vector>
#include <algorithm>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
...
...
@@ -122,9 +123,20 @@ struct TransposeFunctor {
MACE_CHECK
(
dims_
[
0
]
==
1
&&
dims_
[
1
]
==
0
,
"no need transform"
);
index_t
stride_i
=
input_shape
[
0
];
index_t
stride_j
=
input_shape
[
1
];
for
(
int
i
=
0
;
i
<
input_shape
[
0
];
++
i
)
{
for
(
int
j
=
0
;
j
<
input_shape
[
1
];
++
j
)
{
output_data
[
j
*
stride_i
+
i
]
=
input_data
[
i
*
stride_j
+
j
];
index_t
tile_size
=
input_shape
[
0
]
>
512
||
input_shape
[
1
]
>
512
?
64
:
32
;
#pragma omp parallel for collapse(2)
for
(
index_t
i
=
0
;
i
<
input_shape
[
0
];
i
+=
tile_size
)
{
for
(
index_t
j
=
0
;
j
<
input_shape
[
1
];
j
+=
tile_size
)
{
index_t
end_i
=
std
::
min
(
i
+
tile_size
,
input_shape
[
0
]);
index_t
end_j
=
std
::
min
(
j
+
tile_size
,
input_shape
[
1
]);
for
(
index_t
tile_i
=
i
;
tile_i
<
end_i
;
++
tile_i
)
{
for
(
index_t
tile_j
=
j
;
tile_j
<
end_j
;
++
tile_j
)
{
output_data
[
tile_j
*
stride_i
+
tile_i
]
=
input_data
[
tile_i
*
stride_j
+
tile_j
];
}
}
}
}
}
else
if
(
input
->
dim_size
()
==
4
)
{
...
...
mace/ops/resize_bicubic_test.cc
浏览文件 @
379c730d
...
...
@@ -50,7 +50,7 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) {
// Check
auto
expected
=
CreateTensor
<
float
>
({
1
,
1
,
2
,
3
},
{
0
,
1
,
2
,
6
,
7
,
8
});
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-
5
);
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-
2
);
}
TEST_F
(
ResizeBicubicTest
,
CPUResizeBicubicWOAlignCornersFloat
)
{
...
...
@@ -82,7 +82,7 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) {
8.223037
,
9.223036
,
10.223037
,
24.
,
25.
,
26.
,
28.110298
,
29.1103
,
30.110298
,
32.223038
,
33.223038
,
34.223038
});
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-
5
);
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-
2
);
}
TEST_F
(
ResizeBicubicTest
,
ResizeBicubicWAlignCorners
)
{
...
...
@@ -112,7 +112,7 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) {
// Check
auto
expected
=
CreateTensor
<
float
>
({
1
,
1
,
2
,
3
},
{
0
,
1
,
2
,
9
,
10
,
11
});
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-
5
);
ExpectTensorNear
<
float
>
(
*
expected
,
*
net
.
GetOutput
(
"Output"
),
1e-
2
);
}
namespace
{
...
...
@@ -168,7 +168,7 @@ void TestRandomResizeBicubic() {
kernels
::
BufferType
::
IN_OUT_CHANNEL
);
}
// Check
ExpectTensorNear
<
float
>
(
expected
,
*
net
.
GetOutput
(
"DeviceOutput"
),
1e-
5
,
ExpectTensorNear
<
float
>
(
expected
,
*
net
.
GetOutput
(
"DeviceOutput"
),
1e-
2
,
1e-4
);
}
}
...
...
mace/ops/transpose_benchmark.cc
浏览文件 @
379c730d
...
...
@@ -90,6 +90,9 @@ MACE_BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2);
MACE_BM_TRANSPOSE4D
(
1
,
512
,
64
,
64
,
0
,
2
,
3
,
1
);
MACE_BM_TRANSPOSE2D
(
128
,
128
);
MACE_BM_TRANSPOSE2D
(
512
,
512
);
MACE_BM_TRANSPOSE2D
(
1024
,
1024
);
MACE_BM_TRANSPOSE2D
(
512
,
2048
);
MACE_BM_TRANSPOSE2D
(
2048
,
512
);
}
// namespace test
}
// namespace ops
...
...
mace/ops/unstack_test.cc
浏览文件 @
379c730d
...
...
@@ -43,7 +43,6 @@ void TestUnstack(const std::vector<index_t> &input_shape,
net
.
RunOp
();
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
LOG
(
INFO
)
<<
MakeString
(
"Output"
,
i
);
net
.
AddInputFromArray
<
CPU
,
float
>
(
"ExpectedOutput"
,
output_shape
,
outputs
[
i
]);
ExpectTensorNear
<
float
>
(
*
net
.
GetOutput
(
"ExpectedOutput"
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录