Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
ef032272
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ef032272
编写于
3月 16, 2018
作者:
L
Liangliang He
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' into 'master'
Refactor tiling code. See merge request !303
上级
f06ddc29
a310ca31
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
117 additions
and
344 deletions
+117
-344
mace/kernels/conv_2d.h
mace/kernels/conv_2d.h
+51
-300
mace/kernels/matmul.h
mace/kernels/matmul.h
+66
-44
未找到文件。
mace/kernels/conv_2d.h
浏览文件 @
ef032272
...
...
@@ -197,6 +197,56 @@ struct Conv2dFunctorBase {
const
float
relux_max_limit_
;
};
// Expands to a single call of Conv2dKernelFunc<T, inc_tile_size, CC, CH, CW>.
// CC, CH and CW become the last three template arguments; the dispatch macros
// in this file pass the compile-time c/h/w tile counts through them.
// T, inc_tile_size and every call argument (input_ptr ... padded_width) are
// not macro parameters: they are picked up from the scope in which the macro
// is expanded. The expansion already ends with ';'.
#define MACE_DO_CONV2D(CC, CH, CW) \
  Conv2dKernelFunc<T, inc_tile_size, CC, CH, CW>( \
      input_ptr, filter_data, bias_data, output_ptr, \
      h_offset, w_offset, c_offset, \
      kernel_h, kernel_w, \
      stride_h, stride_w, \
      dilation_h, dilation_w, \
      channels, input_channels, width, padded_width);
// Maps the runtime value of w_count (read from the expansion scope) onto
// MACE_DO_CONV2D with CW fixed to the matching compile-time constant, while
// CC and CH are forwarded unchanged.
// Handled values: 1 and 2. Anything else is reported through LOG(FATAL).
// NOTE(review): confirm the caller's width tile size never exceeds 2.
#define MACE_CASE_W_CONV2D(CC, CH) \
  switch (w_count) { \
    case 1: { MACE_DO_CONV2D(CC, CH, 1); break; } \
    case 2: { MACE_DO_CONV2D(CC, CH, 2); break; } \
    default: \
      LOG(FATAL) << "Unsupported w tile: " << w_count; \
  }
// Maps the runtime value of h_count (read from the expansion scope) onto
// MACE_CASE_W_CONV2D with CH fixed to the matching compile-time constant,
// while CC is forwarded unchanged.
// Handled values: 1 and 2. Anything else is reported through LOG(FATAL).
// NOTE(review): confirm the caller's height tile size never exceeds 2.
#define MACE_CASE_H_CONV2D(CC) \
  switch (h_count) { \
    case 1: { MACE_CASE_W_CONV2D(CC, 1); break; } \
    case 2: { MACE_CASE_W_CONV2D(CC, 2); break; } \
    default: \
      LOG(FATAL) << "Unsupported h tile: " << h_count; \
  }
// Entry point of the conv2d tile dispatch: maps the runtime value of c_count
// (read from the expansion scope) onto MACE_CASE_H_CONV2D with the matching
// compile-time constant.
// Handled values: 1 through 4. Anything else is reported through LOG(FATAL).
#define MACE_CASE_C_CONV2D \
  switch (c_count) { \
    case 1: { MACE_CASE_H_CONV2D(1); break; } \
    case 2: { MACE_CASE_H_CONV2D(2); break; } \
    case 3: { MACE_CASE_H_CONV2D(3); break; } \
    case 4: { MACE_CASE_H_CONV2D(4); break; } \
    default: \
      LOG(FATAL) << "Unsupported c tile: " << c_count; \
  }
template
<
DeviceType
D
,
typename
T
>
struct
Conv2dFunctor
:
Conv2dFunctorBase
{
Conv2dFunctor
(
const
int
*
strides
,
...
...
@@ -312,306 +362,7 @@ struct Conv2dFunctor : Conv2dFunctorBase {
const
int
w_count
=
std
::
min
(
w_tile_size
,
width
-
w_offset
);
const
int
c_count
=
std
::
min
(
c_tile_size
,
channels
-
c_offset
);
switch
(
c_count
)
{
case
1
:
switch
(
h_count
)
{
case
1
:
switch
(
w_count
)
{
case
1
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
1
,
1
,
1
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
2
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
1
,
1
,
2
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
3
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
1
,
1
,
3
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
4
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
1
,
1
,
4
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported width tile: "
<<
w_count
;
}
break
;
case
2
:
switch
(
w_count
)
{
case
1
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
1
,
2
,
1
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
2
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
1
,
2
,
2
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
3
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
1
,
2
,
3
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
4
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
1
,
2
,
4
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported width tile: "
<<
w_count
;
}
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported height tile: "
<<
h_count
;
}
break
;
case
2
:
switch
(
h_count
)
{
case
1
:
switch
(
w_count
)
{
case
1
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
2
,
1
,
1
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
2
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
2
,
1
,
2
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
3
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
2
,
1
,
3
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
4
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
2
,
1
,
4
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported width tile: "
<<
w_count
;
}
break
;
case
2
:
switch
(
w_count
)
{
case
1
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
2
,
2
,
1
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
2
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
2
,
2
,
2
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
3
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
2
,
2
,
3
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
4
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
2
,
2
,
4
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported width tile: "
<<
w_count
;
}
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported height tile: "
<<
h_count
;
}
break
;
case
3
:
switch
(
h_count
)
{
case
1
:
switch
(
w_count
)
{
case
1
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
3
,
1
,
1
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
2
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
3
,
1
,
2
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
3
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
3
,
1
,
3
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
4
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
3
,
1
,
4
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported width tile: "
<<
w_count
;
}
break
;
case
2
:
switch
(
w_count
)
{
case
1
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
3
,
2
,
1
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
2
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
3
,
2
,
2
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
3
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
3
,
2
,
3
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
4
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
3
,
2
,
4
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported width tile: "
<<
w_count
;
}
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported height tile: "
<<
h_count
;
}
break
;
case
4
:
switch
(
h_count
)
{
case
1
:
switch
(
w_count
)
{
case
1
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
4
,
1
,
1
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
2
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
4
,
1
,
2
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
3
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
4
,
1
,
3
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
4
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
4
,
1
,
4
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported width tile: "
<<
w_count
;
}
break
;
case
2
:
switch
(
w_count
)
{
case
1
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
4
,
2
,
1
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
2
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
4
,
2
,
2
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
3
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
4
,
2
,
3
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
case
4
:
Conv2dKernelFunc
<
T
,
inc_tile_size
,
4
,
2
,
4
>
(
input_ptr
,
filter_data
,
bias_data
,
output_ptr
,
h_offset
,
w_offset
,
c_offset
,
kernel_h
,
kernel_w
,
stride_h
,
stride_w
,
dilation_h
,
dilation_w
,
channels
,
input_channels
,
width
,
padded_width
);
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported width tile: "
<<
w_count
;
}
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported height tile: "
<<
h_count
;
}
break
;
default:
LOG
(
FATAL
)
<<
"Unsupported channel tile: "
<<
c_count
;
}
MACE_CASE_C_CONV2D
;
}
}
}
...
...
mace/kernels/matmul.h
浏览文件 @
ef032272
...
...
@@ -89,42 +89,71 @@ inline void MatMulKernelFunc(const T *A,
}
}
// namespace
// Emits one `case KC:` arm and therefore must be expanded inside a switch
// body. The arm calls MatMulKernelFunc<T, register_tile_size, HC, WC, KC> and
// then breaks out of the enclosing switch.
// T, register_tile_size and all call arguments (a_ptr_batch_base ... K) are
// taken from the scope in which the macro is expanded.
#define CASE_K_MATMUL(HC, WC, KC) \
  case KC: { \
    MatMulKernelFunc<T, register_tile_size, HC, WC, KC>( \
        a_ptr_batch_base, b_ptr_batch_base, c_ptr_batch_base, \
        ih, iw, ik, \
        height, width, K); \
    break; \
  }
// Emits one `case WC:` arm and therefore must be expanded inside a switch
// body. Inside the arm, the runtime value of k_count (read from the expansion
// scope) selects one of the CASE_K_MATMUL arms for 1 through 4; any other
// value is reported through LOG(FATAL). The arm then breaks out of the
// enclosing switch.
#define CASE_W_MATMUL(HC, WC) \
  case WC: { \
    switch (k_count) { \
      CASE_K_MATMUL(HC, WC, 1); \
      CASE_K_MATMUL(HC, WC, 2); \
      CASE_K_MATMUL(HC, WC, 3); \
      CASE_K_MATMUL(HC, WC, 4); \
      default: \
        LOG(FATAL) << "Unsupported k tile: " << k_count; \
    } \
    break; \
  }
// Emits one `case HC:` arm and therefore must be expanded inside a switch
// body. Inside the arm, the runtime value of w_count (read from the expansion
// scope) selects one of the CASE_W_MATMUL arms for 1 through 4; any other
// value is reported through LOG(FATAL). The arm then breaks out of the
// enclosing switch.
// The failure message streams w_count, the value this switch actually
// dispatches on, so the log names the offending tile size.
#define CASE_H_MATMUL(HC) \
  case HC: \
    switch (w_count) { \
      CASE_W_MATMUL(HC, 1); \
      CASE_W_MATMUL(HC, 2); \
      CASE_W_MATMUL(HC, 3); \
      CASE_W_MATMUL(HC, 4); \
      default: \
        LOG(FATAL) << "Unsupported w tile: " << w_count; \
    } \
    break;
// Expands to a single call of
// MatMulKernelFunc<T, register_tile_size, HC, WC, KC>.
// HC, WC and KC become the last three template arguments; the dispatch macros
// in this file pass the compile-time h/w/k tile counts through them.
// T, register_tile_size and every call argument (a_ptr_batch_base ... K) are
// picked up from the scope in which the macro is expanded. The expansion
// already ends with ';'.
#define MACE_DO_MATMUL(HC, WC, KC) \
  MatMulKernelFunc<T, register_tile_size, HC, WC, KC>( \
      a_ptr_batch_base, b_ptr_batch_base, c_ptr_batch_base, \
      ih, iw, ik, \
      height, width, K);
// Maps the runtime value of k_count (read from the expansion scope) onto
// MACE_DO_MATMUL with KC fixed to the matching compile-time constant, while
// HC and WC are forwarded unchanged.
// Handled values: 1 through 4. Anything else is reported through LOG(FATAL).
#define MACE_CASE_K_MATMUL(HC, WC) \
  switch (k_count) { \
    case 1: { MACE_DO_MATMUL(HC, WC, 1); break; } \
    case 2: { MACE_DO_MATMUL(HC, WC, 2); break; } \
    case 3: { MACE_DO_MATMUL(HC, WC, 3); break; } \
    case 4: { MACE_DO_MATMUL(HC, WC, 4); break; } \
    default: \
      LOG(FATAL) << "Unsupported k tile: " << k_count; \
  }
// Maps the runtime value of w_count (read from the expansion scope) onto
// MACE_CASE_K_MATMUL with WC fixed to the matching compile-time constant,
// while HC is forwarded unchanged.
// Handled values: 1 through 4. Anything else is reported through LOG(FATAL).
#define MACE_CASE_W_MATMUL(HC) \
  switch (w_count) { \
    case 1: { MACE_CASE_K_MATMUL(HC, 1); break; } \
    case 2: { MACE_CASE_K_MATMUL(HC, 2); break; } \
    case 3: { MACE_CASE_K_MATMUL(HC, 3); break; } \
    case 4: { MACE_CASE_K_MATMUL(HC, 4); break; } \
    default: \
      LOG(FATAL) << "Unsupported w tile: " << w_count; \
  }
// Entry point of the matmul tile dispatch: maps the runtime value of h_count
// (read from the expansion scope) onto MACE_CASE_W_MATMUL with the matching
// compile-time constant.
// Handled values: 1 through 4. Anything else is reported through LOG(FATAL).
#define MACE_CASE_H_MATMUL \
  switch (h_count) { \
    case 1: { MACE_CASE_W_MATMUL(1); break; } \
    case 2: { MACE_CASE_W_MATMUL(2); break; } \
    case 3: { MACE_CASE_W_MATMUL(3); break; } \
    case 4: { MACE_CASE_W_MATMUL(4); break; } \
    default: \
      LOG(FATAL) << "Unsupported h tile: " << h_count; \
  }
template
<
DeviceType
D
,
typename
T
>
struct
MatMulFunctor
{
...
...
@@ -196,14 +225,7 @@ struct MatMulFunctor {
const
int
w_count
=
std
::
min
(
register_tile_size
,
iw_end
-
iw
);
const
int
k_count
=
std
::
min
(
register_tile_size
,
ik_end
-
ik
);
switch
(
h_count
)
{
CASE_H_MATMUL
(
1
);
CASE_H_MATMUL
(
2
);
CASE_H_MATMUL
(
3
);
CASE_H_MATMUL
(
4
);
default:
LOG
(
FATAL
)
<<
"Unsupported height tile: "
<<
h_count
;
}
MACE_CASE_H_MATMUL
;
}
// ik
}
// iw
}
// ih
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录