Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
a2ca646b
Mace
项目概览
Xiaomi
/
Mace
通知
107
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
a2ca646b
编写于
5月 21, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reorder winograd compute order
上级
f15a122d
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
323 addition
and
282 deletion
+323
-282
mace/kernels/arm/conv_winograd.cc
mace/kernels/arm/conv_winograd.cc
+319
-282
mace/ops/conv_2d_benchmark.cc
mace/ops/conv_2d_benchmark.cc
+4
-0
未找到文件。
mace/kernels/arm/conv_winograd.cc
浏览文件 @
a2ca646b
...
@@ -24,7 +24,7 @@ namespace mace {
...
@@ -24,7 +24,7 @@ namespace mace {
namespace
kernels
{
namespace
kernels
{
namespace
{
namespace
{
// NCHW =>
TN
CB (T: in tile pixels, B: tile indices)
// NCHW =>
NT
CB (T: in tile pixels, B: tile indices)
void
TransformInput4x4
(
const
float
*
input
,
void
TransformInput4x4
(
const
float
*
input
,
const
index_t
batch
,
const
index_t
batch
,
const
index_t
in_height
,
const
index_t
in_height
,
...
@@ -32,12 +32,15 @@ void TransformInput4x4(const float *input,
...
@@ -32,12 +32,15 @@ void TransformInput4x4(const float *input,
const
index_t
in_channels
,
const
index_t
in_channels
,
const
index_t
tile_count
,
const
index_t
tile_count
,
float
*
output
)
{
float
*
output
)
{
const
index_t
stride
=
batch
*
in_channels
*
tile_count
;
const
index_t
stride
=
in_channels
*
tile_count
;
const
index_t
in_height_width
=
in_height
*
in_width
;
const
index_t
in_height_width
=
in_height
*
in_width
;
const
index_t
input_batch_size
=
in_height_width
*
in_channels
;
const
index_t
output_batch_size
=
16
*
in_channels
*
tile_count
;
#pragma omp parallel for
#pragma omp parallel for collapse(2)
for
(
index_t
nc
=
0
;
nc
<
batch
*
in_channels
;
++
nc
)
{
for
(
index_t
n
=
0
;
n
<
batch
;
++
n
)
{
index_t
tile_index
=
nc
*
tile_count
;
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
index_t
tile_index
=
0
;
for
(
index_t
h
=
0
;
h
<
in_height
-
2
;
h
+=
2
)
{
for
(
index_t
h
=
0
;
h
<
in_height
-
2
;
h
+=
2
)
{
for
(
index_t
w
=
0
;
w
<
in_width
-
2
;
w
+=
2
)
{
for
(
index_t
w
=
0
;
w
<
in_width
-
2
;
w
+=
2
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
,
d8
,
d9
,
d10
,
d11
,
d12
,
d13
,
d14
,
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
,
d8
,
d9
,
d10
,
d11
,
d12
,
d13
,
d14
,
...
@@ -46,26 +49,28 @@ void TransformInput4x4(const float *input,
...
@@ -46,26 +49,28 @@ void TransformInput4x4(const float *input,
s15
;
s15
;
// load tile data
// load tile data
const
index_t
tile_offset
=
nc
*
in_height_width
+
h
*
in_width
+
w
;
const
float
*
input_ptr
=
d0
=
input
[
tile_offset
];
input
+
n
*
input_batch_size
+
c
*
in_height_width
+
h
*
in_width
d1
=
input
[
tile_offset
+
1
];
+
w
;
d2
=
input
[
tile_offset
+
2
];
d0
=
input_ptr
[
0
];
d3
=
input
[
tile_offset
+
3
];
d1
=
input_ptr
[
1
];
d2
=
input_ptr
[
2
];
d4
=
input
[
tile_offset
+
in_width
];
d3
=
input_ptr
[
3
];
d5
=
input
[
tile_offset
+
in_width
+
1
];
d6
=
input
[
tile_offset
+
in_width
+
2
];
d4
=
input_ptr
[
in_width
];
d7
=
input
[
tile_offset
+
in_width
+
3
];
d5
=
input_ptr
[
in_width
+
1
];
d6
=
input_ptr
[
in_width
+
2
];
d8
=
input
[
tile_offset
+
2
*
in_width
];
d7
=
input_ptr
[
in_width
+
3
];
d9
=
input
[
tile_offset
+
2
*
in_width
+
1
];
d10
=
input
[
tile_offset
+
2
*
in_width
+
2
];
d8
=
input_ptr
[
2
*
in_width
];
d11
=
input
[
tile_offset
+
2
*
in_width
+
3
];
d9
=
input_ptr
[
2
*
in_width
+
1
];
d10
=
input_ptr
[
2
*
in_width
+
2
];
d12
=
input
[
tile_offset
+
3
*
in_width
];
d11
=
input_ptr
[
2
*
in_width
+
3
];
d13
=
input
[
tile_offset
+
3
*
in_width
+
1
];
d14
=
input
[
tile_offset
+
3
*
in_width
+
2
];
d12
=
input_ptr
[
3
*
in_width
];
d15
=
input
[
tile_offset
+
3
*
in_width
+
3
];
d13
=
input_ptr
[
3
*
in_width
+
1
];
d14
=
input_ptr
[
3
*
in_width
+
2
];
d15
=
input_ptr
[
3
*
in_width
+
3
];
// s = BT * d * B
// s = BT * d * B
s0
=
(
d0
-
d8
)
-
(
d2
-
d10
);
s0
=
(
d0
-
d8
)
-
(
d2
-
d10
);
...
@@ -86,33 +91,36 @@ void TransformInput4x4(const float *input,
...
@@ -86,33 +91,36 @@ void TransformInput4x4(const float *input,
s15
=
(
d5
-
d13
)
-
(
d7
-
d15
);
s15
=
(
d5
-
d13
)
-
(
d7
-
d15
);
// store output
// store output
output
[
tile_index
+
0
*
stride
]
=
s0
;
float
*
output_ptr
=
output
[
tile_index
+
1
*
stride
]
=
s1
;
output
+
n
*
output_batch_size
+
c
*
tile_count
+
tile_index
;
output
[
tile_index
+
2
*
stride
]
=
s2
;
output_ptr
[
0
]
=
s0
;
output
[
tile_index
+
3
*
stride
]
=
s3
;
output_ptr
[
1
*
stride
]
=
s1
;
output_ptr
[
2
*
stride
]
=
s2
;
output
[
tile_index
+
4
*
stride
]
=
s4
;
output_ptr
[
3
*
stride
]
=
s3
;
output
[
tile_index
+
5
*
stride
]
=
s5
;
output
[
tile_index
+
6
*
stride
]
=
s6
;
output_ptr
[
4
*
stride
]
=
s4
;
output
[
tile_index
+
7
*
stride
]
=
s7
;
output_ptr
[
5
*
stride
]
=
s5
;
output_ptr
[
6
*
stride
]
=
s6
;
output
[
tile_index
+
8
*
stride
]
=
s8
;
output_ptr
[
7
*
stride
]
=
s7
;
output
[
tile_index
+
9
*
stride
]
=
s9
;
output
[
tile_index
+
10
*
stride
]
=
s10
;
output_ptr
[
8
*
stride
]
=
s8
;
output
[
tile_index
+
11
*
stride
]
=
s11
;
output_ptr
[
9
*
stride
]
=
s9
;
output_ptr
[
10
*
stride
]
=
s10
;
output
[
tile_index
+
12
*
stride
]
=
s12
;
output_ptr
[
11
*
stride
]
=
s11
;
output
[
tile_index
+
13
*
stride
]
=
s13
;
output
[
tile_index
+
14
*
stride
]
=
s14
;
output_ptr
[
12
*
stride
]
=
s12
;
output
[
tile_index
+
15
*
stride
]
=
s15
;
output_ptr
[
13
*
stride
]
=
s13
;
output_ptr
[
14
*
stride
]
=
s14
;
output_ptr
[
15
*
stride
]
=
s15
;
++
tile_index
;
++
tile_index
;
}
}
}
}
}
}
}
}
}
// NCHW =>
TN
CB (T: in tile pixels, B: tile indices)
// NCHW =>
NT
CB (T: in tile pixels, B: tile indices)
/**
/**
* BT =
* BT =
⎡1 0 -21/4 0 21/4 0 -1 0⎤
⎡1 0 -21/4 0 21/4 0 -1 0⎤
...
@@ -146,26 +154,32 @@ void TransformInput8x8(const float *input,
...
@@ -146,26 +154,32 @@ void TransformInput8x8(const float *input,
const
index_t
in_channels
,
const
index_t
in_channels
,
const
index_t
tile_count
,
const
index_t
tile_count
,
float
*
output
)
{
float
*
output
)
{
const
index_t
stride
=
batch
*
in_channels
*
tile_count
;
const
index_t
stride
=
in_channels
*
tile_count
;
const
index_t
in_height_width
=
in_height
*
in_width
;
const
index_t
in_height_width
=
in_height
*
in_width
;
const
index_t
input_batch_size
=
in_height_width
*
in_channels
;
const
index_t
output_batch_size
=
64
*
in_channels
*
tile_count
;
#pragma omp parallel for
#pragma omp parallel for collapse(2)
for
(
index_t
nc
=
0
;
nc
<
batch
*
in_channels
;
++
nc
)
{
for
(
index_t
n
=
0
;
n
<
batch
;
++
n
)
{
index_t
tile_index
=
nc
*
tile_count
;
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
index_t
tile_index
=
0
;
float
s
[
8
][
8
];
float
s
[
8
][
8
];
for
(
index_t
h
=
0
;
h
<
in_height
-
2
;
h
+=
6
)
{
for
(
index_t
h
=
0
;
h
<
in_height
-
2
;
h
+=
6
)
{
for
(
index_t
w
=
0
;
w
<
in_width
-
2
;
w
+=
6
)
{
for
(
index_t
w
=
0
;
w
<
in_width
-
2
;
w
+=
6
)
{
index_t
tile_offset
=
nc
*
in_height_width
+
h
*
in_width
+
w
;
const
float
*
input_ptr
=
input
+
n
*
input_batch_size
+
c
*
in_height_width
+
h
*
in_width
+
w
;
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
d0
=
input
[
tile_offset
];
d0
=
input_ptr
[
0
];
d1
=
input
[
tile_offset
+
1
];
d1
=
input_ptr
[
1
];
d2
=
input
[
tile_offset
+
2
];
d2
=
input_ptr
[
2
];
d3
=
input
[
tile_offset
+
3
];
d3
=
input_ptr
[
3
];
d4
=
input
[
tile_offset
+
4
];
d4
=
input_ptr
[
4
];
d5
=
input
[
tile_offset
+
5
];
d5
=
input_ptr
[
5
];
d6
=
input
[
tile_offset
+
6
];
d6
=
input_ptr
[
6
];
d7
=
input
[
tile_offset
+
7
];
d7
=
input_ptr
[
7
];
s
[
i
][
0
]
=
d0
-
d6
+
(
d4
-
d2
)
*
5.25
;
s
[
i
][
0
]
=
d0
-
d6
+
(
d4
-
d2
)
*
5.25
;
s
[
i
][
7
]
=
d7
-
d1
+
(
d3
-
d5
)
*
5.25
;
s
[
i
][
7
]
=
d7
-
d1
+
(
d3
-
d5
)
*
5.25
;
...
@@ -185,9 +199,11 @@ void TransformInput8x8(const float *input,
...
@@ -185,9 +199,11 @@ void TransformInput8x8(const float *input,
s
[
i
][
5
]
=
u
+
v
;
s
[
i
][
5
]
=
u
+
v
;
s
[
i
][
6
]
=
u
-
v
;
s
[
i
][
6
]
=
u
-
v
;
tile_offset
+=
in_width
;
input_ptr
+=
in_width
;
}
}
float
*
output_ptr
=
output
+
n
*
output_batch_size
+
c
*
tile_count
+
tile_index
;
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
d0
=
s
[
0
][
i
];
d0
=
s
[
0
][
i
];
...
@@ -199,32 +215,33 @@ void TransformInput8x8(const float *input,
...
@@ -199,32 +215,33 @@ void TransformInput8x8(const float *input,
d6
=
s
[
6
][
i
];
d6
=
s
[
6
][
i
];
d7
=
s
[
7
][
i
];
d7
=
s
[
7
][
i
];
output
[
tile_index
+
i
*
stride
]
=
d0
-
d6
+
(
d4
-
d2
)
*
5.25
;
output_ptr
[
i
*
stride
]
=
d0
-
d6
+
(
d4
-
d2
)
*
5.25
;
output
[
tile_index
+
(
56
+
i
)
*
stride
]
=
d7
-
d1
+
(
d3
-
d5
)
*
5.25
;
output_ptr
[
(
56
+
i
)
*
stride
]
=
d7
-
d1
+
(
d3
-
d5
)
*
5.25
;
float
u
=
d2
+
d6
-
d4
*
4.25
;
float
u
=
d2
+
d6
-
d4
*
4.25
;
float
v
=
d1
+
d5
-
d3
*
4.25
;
float
v
=
d1
+
d5
-
d3
*
4.25
;
output
[
tile_index
+
(
8
+
i
)
*
stride
]
=
u
+
v
;
output_ptr
[
(
8
+
i
)
*
stride
]
=
u
+
v
;
output
[
tile_index
+
(
16
+
i
)
*
stride
]
=
u
-
v
;
output_ptr
[
(
16
+
i
)
*
stride
]
=
u
-
v
;
u
=
d6
+
d2
*
0.25
-
d4
*
1.25
;
u
=
d6
+
d2
*
0.25
-
d4
*
1.25
;
v
=
d1
*
0.5
-
d3
*
2.5
+
d5
*
2
;
v
=
d1
*
0.5
-
d3
*
2.5
+
d5
*
2
;
output
[
tile_index
+
(
24
+
i
)
*
stride
]
=
u
+
v
;
output_ptr
[
(
24
+
i
)
*
stride
]
=
u
+
v
;
output
[
tile_index
+
(
32
+
i
)
*
stride
]
=
u
-
v
;
output_ptr
[
(
32
+
i
)
*
stride
]
=
u
-
v
;
u
=
d6
+
(
d2
-
d4
*
1.25
)
*
4
;
u
=
d6
+
(
d2
-
d4
*
1.25
)
*
4
;
v
=
d1
*
2
-
d3
*
2.5
+
d5
*
0.5
;
v
=
d1
*
2
-
d3
*
2.5
+
d5
*
0.5
;
output
[
tile_index
+
(
40
+
i
)
*
stride
]
=
u
+
v
;
output_ptr
[
(
40
+
i
)
*
stride
]
=
u
+
v
;
output
[
tile_index
+
(
48
+
i
)
*
stride
]
=
u
-
v
;
output_ptr
[
(
48
+
i
)
*
stride
]
=
u
-
v
;
}
}
++
tile_index
;
++
tile_index
;
}
}
}
}
}
}
}
}
}
// TOC *
TNCB => TN
OB
// TOC *
NTCB => NT
OB
void
BatchGemm
(
const
float
*
input
,
void
BatchGemm
(
const
float
*
input
,
const
float
*
filter
,
const
float
*
filter
,
index_t
batch
,
index_t
batch
,
...
@@ -233,12 +250,13 @@ void BatchGemm(const float *input,
...
@@ -233,12 +250,13 @@ void BatchGemm(const float *input,
index_t
tile_count
,
index_t
tile_count
,
int
out_tile_size
,
int
out_tile_size
,
float
*
output
)
{
float
*
output
)
{
const
index_t
in_stride
=
batch
*
in_channels
*
tile_count
;
const
index_t
in_channels_tile_count
=
in_channels
*
tile_count
;
const
index_t
filter_stride
=
out_channels
*
in_channels
;
const
index_t
filter_stride
=
out_channels
*
in_channels
;
const
index_t
out_stride
=
batch
*
out_channels
*
tile_count
;
const
index_t
out_channels_tile_count
=
out_channels
*
tile_count
;
const
int
in_tile_area
=
(
out_tile_size
+
2
)
*
(
out_tile_size
+
2
);
const
int
in_tile_area
=
(
out_tile_size
+
2
)
*
(
out_tile_size
+
2
);
const
index_t
in_batch_size
=
in_tile_area
*
in_channels
*
tile_count
;
const
index_t
in_stride
=
in_channels
*
tile_count
;
const
index_t
out_batch_size
=
in_tile_area
*
out_channels
*
tile_count
;
const
index_t
out_stride
=
out_channels
*
tile_count
;
if
(
batch
==
1
)
{
if
(
batch
==
1
)
{
Gemm
(
filter
,
Gemm
(
filter
,
input
,
input
,
...
@@ -248,12 +266,13 @@ void BatchGemm(const float *input,
...
@@ -248,12 +266,13 @@ void BatchGemm(const float *input,
tile_count
,
tile_count
,
output
);
output
);
}
else
{
}
else
{
for
(
int
i
=
0
;
i
<
in_tile_area
;
++
i
)
{
#pragma omp parallel for collapse(2)
for
(
int
b
=
0
;
b
<
batch
;
++
b
)
{
for
(
int
b
=
0
;
b
<
batch
;
++
b
)
{
for
(
int
i
=
0
;
i
<
in_tile_area
;
++
i
)
{
const
float
const
float
*
in_ptr
=
input
+
i
*
in_stride
+
b
*
in_channels_tile_count
;
*
in_ptr
=
input
+
b
*
in_batch_size
+
i
*
in_stride
;
const
float
*
filter_ptr
=
filter
+
i
*
filter_stride
;
const
float
*
filter_ptr
=
filter
+
i
*
filter_stride
;
float
*
out_ptr
=
output
+
i
*
out_stride
+
b
*
out_channels_tile_count
;
float
*
out_ptr
=
output
+
b
*
out_batch_size
+
i
*
out_stride
;
Gemm
(
filter_ptr
,
Gemm
(
filter_ptr
,
in_ptr
,
in_ptr
,
1
,
1
,
...
@@ -266,7 +285,7 @@ void BatchGemm(const float *input,
...
@@ -266,7 +285,7 @@ void BatchGemm(const float *input,
}
}
}
}
//
TNOB => ToN
OB => NOHoWo
//
NTOB => NTo
OB => NOHoWo
void
TransformOutput4x4
(
const
float
*
input
,
void
TransformOutput4x4
(
const
float
*
input
,
index_t
batch
,
index_t
batch
,
index_t
out_height
,
index_t
out_height
,
...
@@ -274,11 +293,15 @@ void TransformOutput4x4(const float *input,
...
@@ -274,11 +293,15 @@ void TransformOutput4x4(const float *input,
index_t
out_channels
,
index_t
out_channels
,
index_t
tile_count
,
index_t
tile_count
,
float
*
output
)
{
float
*
output
)
{
const
index_t
in_stride
=
batch
*
out_channels
*
tile_count
;
const
index_t
stride
=
out_channels
*
tile_count
;
const
index_t
input_batch_size
=
16
*
stride
;
const
index_t
out_image_size
=
out_height
*
out_width
;
const
index_t
output_batch_size
=
out_channels
*
out_image_size
;
#pragma omp parallel for
#pragma omp parallel for collapse(2)
for
(
index_t
nm
=
0
;
nm
<
batch
*
out_channels
;
++
nm
)
{
for
(
index_t
n
=
0
;
n
<
batch
;
++
n
)
{
index_t
tile_offset
=
nm
*
tile_count
;
for
(
index_t
m
=
0
;
m
<
out_channels
;
++
m
)
{
index_t
tile_offset
=
0
;
for
(
index_t
h
=
0
;
h
<
out_height
;
h
+=
2
)
{
for
(
index_t
h
=
0
;
h
<
out_height
;
h
+=
2
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
w
+=
2
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
w
+=
2
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
,
d8
,
d9
,
d10
,
d11
,
d12
,
d13
,
d14
,
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
,
d8
,
d9
,
d10
,
d11
,
d12
,
d13
,
d14
,
...
@@ -286,25 +309,27 @@ void TransformOutput4x4(const float *input,
...
@@ -286,25 +309,27 @@ void TransformOutput4x4(const float *input,
float
s0
,
s1
,
s2
,
s3
,
s4
,
s5
,
s6
,
s7
;
float
s0
,
s1
,
s2
,
s3
,
s4
,
s5
,
s6
,
s7
;
float
v0
,
v1
,
v2
,
v3
;
float
v0
,
v1
,
v2
,
v3
;
d0
=
input
[
tile_offset
+
0
*
in_stride
];
const
float
*
input_ptr
=
d1
=
input
[
tile_offset
+
1
*
in_stride
];
input
+
n
*
input_batch_size
+
m
*
tile_count
+
tile_offset
;
d2
=
input
[
tile_offset
+
2
*
in_stride
];
d0
=
input_ptr
[
0
];
d3
=
input
[
tile_offset
+
3
*
in_stride
];
d1
=
input_ptr
[
1
*
stride
];
d2
=
input_ptr
[
2
*
stride
];
d3
=
input_ptr
[
3
*
stride
];
d4
=
input
[
tile_offset
+
4
*
in_
stride
];
d4
=
input_ptr
[
4
*
stride
];
d5
=
input
[
tile_offset
+
5
*
in_
stride
];
d5
=
input_ptr
[
5
*
stride
];
d6
=
input
[
tile_offset
+
6
*
in_
stride
];
d6
=
input_ptr
[
6
*
stride
];
d7
=
input
[
tile_offset
+
7
*
in_
stride
];
d7
=
input_ptr
[
7
*
stride
];
d8
=
input
[
tile_offset
+
8
*
in_
stride
];
d8
=
input_ptr
[
8
*
stride
];
d9
=
input
[
tile_offset
+
9
*
in_
stride
];
d9
=
input_ptr
[
9
*
stride
];
d10
=
input
[
tile_offset
+
10
*
in_
stride
];
d10
=
input_ptr
[
10
*
stride
];
d11
=
input
[
tile_offset
+
11
*
in_
stride
];
d11
=
input_ptr
[
11
*
stride
];
d12
=
input
[
tile_offset
+
12
*
in_
stride
];
d12
=
input_ptr
[
12
*
stride
];
d13
=
input
[
tile_offset
+
13
*
in_
stride
];
d13
=
input_ptr
[
13
*
stride
];
d14
=
input
[
tile_offset
+
14
*
in_
stride
];
d14
=
input_ptr
[
14
*
stride
];
d15
=
input
[
tile_offset
+
15
*
in_
stride
];
d15
=
input_ptr
[
15
*
stride
];
s0
=
d0
+
d1
+
d2
;
s0
=
d0
+
d1
+
d2
;
s1
=
d1
-
d2
-
d3
;
s1
=
d1
-
d2
-
d3
;
...
@@ -320,19 +345,22 @@ void TransformOutput4x4(const float *input,
...
@@ -320,19 +345,22 @@ void TransformOutput4x4(const float *input,
v2
=
s2
-
s4
-
s6
;
v2
=
s2
-
s4
-
s6
;
v3
=
s3
-
s5
-
s7
;
v3
=
s3
-
s5
-
s7
;
index_t
out_offset
=
nm
*
out_height
*
out_width
+
h
*
out_width
+
w
;
float
*
output_ptr
=
output
[
out_offset
]
=
v0
;
output
+
n
*
output_batch_size
+
m
*
out_image_size
+
h
*
out_width
output
[
out_offset
+
1
]
=
v1
;
+
w
;
output
[
out_offset
+
out_width
]
=
v2
;
output_ptr
[
0
]
=
v0
;
output
[
out_offset
+
out_width
+
1
]
=
v3
;
output_ptr
[
1
]
=
v1
;
output_ptr
[
out_width
]
=
v2
;
output_ptr
[
out_width
+
1
]
=
v3
;
++
tile_offset
;
++
tile_offset
;
}
}
}
}
}
}
}
}
}
//
TNOB => ToN
OB => NOHoWo
//
NTOB => NTo
OB => NOHoWo
/**
/**
* AT =
* AT =
⎡1 1 1 1 1 32 32 0⎤
⎡1 1 1 1 1 32 32 0⎤
...
@@ -362,25 +390,31 @@ void TransformOutput8x8(const float *input,
...
@@ -362,25 +390,31 @@ void TransformOutput8x8(const float *input,
index_t
out_channels
,
index_t
out_channels
,
index_t
tile_count
,
index_t
tile_count
,
float
*
output
)
{
float
*
output
)
{
const
index_t
in_stride
=
batch
*
out_channels
*
tile_count
;
const
index_t
stride
=
out_channels
*
tile_count
;
const
index_t
input_batch_size
=
64
*
stride
;
const
index_t
out_image_size
=
out_height
*
out_width
;
const
index_t
output_batch_size
=
out_channels
*
out_image_size
;
#pragma omp parallel for
#pragma omp parallel for collapse(2)
for
(
index_t
nm
=
0
;
nm
<
batch
*
out_channels
;
++
nm
)
{
for
(
index_t
n
=
0
;
n
<
batch
;
++
n
)
{
index_t
tile_offset
=
nm
*
tile_count
;
for
(
index_t
m
=
0
;
m
<
out_channels
;
++
m
)
{
index_t
tile_offset
=
0
;
float
s
[
8
][
6
];
float
s
[
8
][
6
];
for
(
index_t
h
=
0
;
h
<
out_height
;
h
+=
6
)
{
for
(
index_t
h
=
0
;
h
<
out_height
;
h
+=
6
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
w
+=
6
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
w
+=
6
)
{
index_t
tile_offset_tmp
=
tile_offset
;
const
float
*
input_ptr
=
input
+
n
*
input_batch_size
+
m
*
tile_count
+
tile_offset
;
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
d0
=
input
[
tile_offset_tmp
+
0
*
in_stride
];
d1
=
input
[
tile_offset_tmp
+
1
*
in_stride
];
d0
=
input_ptr
[
0
];
d2
=
input
[
tile_offset_tmp
+
2
*
in_stride
];
d1
=
input_ptr
[
1
*
stride
];
d3
=
input
[
tile_offset_tmp
+
3
*
in_stride
];
d2
=
input_ptr
[
2
*
stride
];
d4
=
input
[
tile_offset_tmp
+
4
*
in_stride
];
d3
=
input_ptr
[
3
*
stride
];
d5
=
input
[
tile_offset_tmp
+
5
*
in_stride
];
d4
=
input_ptr
[
4
*
stride
];
d6
=
input
[
tile_offset_tmp
+
6
*
in_stride
];
d5
=
input_ptr
[
5
*
stride
];
d7
=
input
[
tile_offset_tmp
+
7
*
in_stride
];
d6
=
input_ptr
[
6
*
stride
];
d7
=
input_ptr
[
7
*
stride
];
float
u
=
d1
+
d2
;
float
u
=
d1
+
d2
;
float
v
=
d1
-
d2
;
float
v
=
d1
-
d2
;
...
@@ -396,10 +430,12 @@ void TransformOutput8x8(const float *input,
...
@@ -396,10 +430,12 @@ void TransformOutput8x8(const float *input,
s
[
i
][
4
]
=
u
+
w
*
16
+
y
+
y
;
s
[
i
][
4
]
=
u
+
w
*
16
+
y
+
y
;
s
[
i
][
5
]
=
v
+
x
*
32
+
z
+
d7
;
s
[
i
][
5
]
=
v
+
x
*
32
+
z
+
d7
;
tile_offset_tmp
+=
8
*
in_
stride
;
input_ptr
+=
8
*
stride
;
}
}
index_t
out_offset
=
nm
*
out_height
*
out_width
+
h
*
out_width
+
w
;
float
*
output_ptr
=
output
+
n
*
output_batch_size
+
m
*
out_image_size
+
h
*
out_width
+
w
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
...
@@ -419,18 +455,19 @@ void TransformOutput8x8(const float *input,
...
@@ -419,18 +455,19 @@ void TransformOutput8x8(const float *input,
float
y
=
d5
+
d6
;
float
y
=
d5
+
d6
;
float
z
=
d5
-
d6
;
float
z
=
d5
-
d6
;
output
[
out_offset
+
0
*
out_width
+
i
]
=
d0
+
u
+
w
+
y
*
32
;
output_ptr
[
i
]
=
d0
+
u
+
w
+
y
*
32
;
output
[
out_offset
+
1
*
out_width
+
i
]
=
v
+
x
+
x
+
z
*
16
;
output_ptr
[
1
*
out_width
+
i
]
=
v
+
x
+
x
+
z
*
16
;
output
[
out_offset
+
2
*
out_width
+
i
]
=
u
+
w
*
4
+
y
*
8
;
output_ptr
[
2
*
out_width
+
i
]
=
u
+
w
*
4
+
y
*
8
;
output
[
out_offset
+
3
*
out_width
+
i
]
=
v
+
x
*
8
+
z
*
4
;
output_ptr
[
3
*
out_width
+
i
]
=
v
+
x
*
8
+
z
*
4
;
output
[
out_offset
+
4
*
out_width
+
i
]
=
u
+
w
*
16
+
y
+
y
;
output_ptr
[
4
*
out_width
+
i
]
=
u
+
w
*
16
+
y
+
y
;
output
[
out_offset
+
5
*
out_width
+
i
]
=
v
+
x
*
32
+
z
+
d7
;
output_ptr
[
5
*
out_width
+
i
]
=
v
+
x
*
32
+
z
+
d7
;
}
}
++
tile_offset
;
++
tile_offset
;
}
}
}
}
}
}
}
}
}
}
// namespace
}
// namespace
...
...
mace/ops/conv_2d_benchmark.cc
浏览文件 @
a2ca646b
...
@@ -165,6 +165,10 @@ BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32);
...
@@ -165,6 +165,10 @@ BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32);
BM_CONV_2D
(
1
,
128
,
56
,
56
,
1
,
1
,
1
,
1
,
SAME
,
128
);
BM_CONV_2D
(
1
,
128
,
56
,
56
,
1
,
1
,
1
,
1
,
SAME
,
128
);
BM_CONV_2D
(
1
,
1024
,
7
,
7
,
1
,
1
,
1
,
1
,
SAME
,
1024
);
BM_CONV_2D
(
1
,
1024
,
7
,
7
,
1
,
1
,
1
,
1
,
SAME
,
1024
);
BM_CONV_2D
(
64
,
32
,
34
,
34
,
3
,
3
,
1
,
1
,
VALID
,
32
);
BM_CONV_2D
(
1
,
32
,
34
,
34
,
3
,
3
,
1
,
1
,
VALID
,
32
);
}
// namespace test
}
// namespace test
}
// namespace ops
}
// namespace ops
}
// namespace mace
}
// namespace mace
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录