Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Xiaomi
Mace
提交
e9bcf721
Mace
项目概览
Xiaomi
/
Mace
通知
106
Star
40
Fork
27
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e9bcf721
编写于
4月 08, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement winograd (6x6, 3x3)
上级
5fc5fbb2
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
412 addition
and
54 deletion
+412
-54
mace/kernels/arm/conv_2d.cc
mace/kernels/arm/conv_2d.cc
+12
-7
mace/kernels/arm/conv_winograd.cc
mace/kernels/arm/conv_winograd.cc
+396
-46
mace/kernels/arm/conv_winograd.h
mace/kernels/arm/conv_winograd.h
+2
-0
mace/kernels/arm/conv_winograd_test.cc
mace/kernels/arm/conv_winograd_test.cc
+2
-1
未找到文件。
mace/kernels/arm/conv_2d.cc
浏览文件 @
e9bcf721
...
...
@@ -7,6 +7,7 @@
// winograd is always superior to neon impl during benchmark
#define USE_WINOGRAD 1
#define WINOGRAD_OUT_TILE_SIZE 6
namespace
mace
{
namespace
kernels
{
...
...
@@ -164,9 +165,9 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
&&
stride_w
==
1
&&
dilation_h
==
1
&&
dilation_w
==
1
&&
input_channels
>=
8
&&
channels
>=
8
)
{
extra_output_height
=
RoundUp
<
index_t
>
(
height
,
2
);
extra_output_height
=
RoundUp
<
index_t
>
(
height
,
WINOGRAD_OUT_TILE_SIZE
);
extra_input_height
=
std
::
max
(
padded_input_height
,
extra_output_height
+
2
);
extra_output_width
=
RoundUp
<
index_t
>
(
width
,
2
);
extra_output_width
=
RoundUp
<
index_t
>
(
width
,
WINOGRAD_OUT_TILE_SIZE
);
extra_input_width
=
std
::
max
(
padded_input_width
,
extra_output_width
+
2
);
if
(
extra_input_height
!=
padded_input_height
)
{
pad_bottom
+=
(
extra_input_height
-
padded_input_height
);
...
...
@@ -175,12 +176,15 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
pad_right
+=
(
extra_input_width
-
padded_input_width
);
}
index_t
tile_height_count
=
(
extra_output_height
+
1
)
/
2
;
index_t
tile_width_count
=
(
extra_output_width
+
1
)
/
2
;
index_t
tile_height_count
=
extra_output_height
/
WINOGRAD_OUT_TILE_SIZE
;
index_t
tile_width_count
=
extra_output_width
/
WINOGRAD_OUT_TILE_SIZE
;
index_t
tile_count
=
tile_height_count
*
tile_width_count
;
transformed_input_
.
Resize
({
16
,
batch
,
input_channels
,
tile_count
});
transformed_filter_
.
Resize
({
16
,
channels
,
input_channels
});
transformed_output_
.
Resize
({
16
,
batch
,
channels
,
tile_count
});
index_t
in_tile_area
=
(
WINOGRAD_OUT_TILE_SIZE
+
2
)
*
(
WINOGRAD_OUT_TILE_SIZE
+
2
);
transformed_input_
.
Resize
({
in_tile_area
,
batch
,
input_channels
,
tile_count
});
transformed_filter_
.
Resize
({
in_tile_area
,
channels
,
input_channels
});
transformed_output_
.
Resize
({
in_tile_area
,
batch
,
channels
,
tile_count
});
conv_func
=
[
=
](
const
float
*
pad_input
,
float
*
pad_output
)
{
WinoGradConv3x3s1
(
pad_input
,
...
...
@@ -190,6 +194,7 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
extra_input_width
,
input_channels
,
channels
,
WINOGRAD_OUT_TILE_SIZE
,
transformed_input_
.
mutable_data
<
float
>
(),
transformed_filter_
.
mutable_data
<
float
>
(),
transformed_output_
.
mutable_data
<
float
>
(),
...
...
mace/kernels/arm/conv_winograd.cc
浏览文件 @
e9bcf721
...
...
@@ -8,19 +8,20 @@
#include "mace/kernels/arm/conv_winograd.h"
#include "mace/kernels/gemm.h"
#include "mace/utils/utils.h"
#include "mace/utils/logging.h"
namespace
mace
{
namespace
kernels
{
namespace
{
// NCHW => TNCB (T: in tile pixels, B: tile indices)
void
TransformInput
(
const
float
*
input
,
const
index_t
batch
,
const
index_t
in_height
,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
tile_count
,
float
*
output
)
{
void
TransformInput
4x4
(
const
float
*
input
,
const
index_t
batch
,
const
index_t
in_height
,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
tile_count
,
float
*
output
)
{
const
index_t
stride
=
batch
*
in_channels
*
tile_count
;
const
index_t
in_height_width
=
in_height
*
in_width
;
...
...
@@ -101,12 +102,124 @@ void TransformInput(const float *input,
}
}
// NCHW => TNCB (T: in tile pixels, B: tile indices)
/**
* BT =
⎡1 0 -21/4 0 21/4 0 -1 0⎤
⎢ ⎥
⎢0 1 1 -17/4 -17/4 1 1 0⎥
⎢ ⎥
⎢0 -1 1 17/4 -17/4 -1 1 0⎥
⎢ ⎥
⎢0 1/2 1/4 -5/2 -5/4 2 1 0⎥
⎢ ⎥
⎢0 -1/2 1/4 5/2 -5/4 -2 1 0⎥
⎢ ⎥
⎢0 2 4 -5/2 -5 1/2 1 0⎥
⎢ ⎥
⎢0 -2 4 5/2 -5 -1/2 1 0⎥
⎢ ⎥
⎣0 -1 0 21/4 0 -21/4 0 1⎦
* @param input
* @param batch
* @param in_height
* @param in_width
* @param in_channels
* @param tile_count
* @param output
*/
void
TransformInput8x8
(
const
float
*
input
,
const
index_t
batch
,
const
index_t
in_height
,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
tile_count
,
float
*
output
)
{
const
index_t
stride
=
batch
*
in_channels
*
tile_count
;
const
index_t
in_height_width
=
in_height
*
in_width
;
#pragma omp parallel for
for
(
index_t
nc
=
0
;
nc
<
batch
*
in_channels
;
++
nc
)
{
index_t
tile_index
=
nc
*
tile_count
;
float
s
[
8
][
8
];
for
(
index_t
h
=
0
;
h
<
in_height
-
2
;
h
+=
6
)
{
for
(
index_t
w
=
0
;
w
<
in_width
-
2
;
w
+=
6
)
{
index_t
tile_offset
=
nc
*
in_height_width
+
h
*
in_width
+
w
;
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
d0
=
input
[
tile_offset
];
d1
=
input
[
tile_offset
+
1
];
d2
=
input
[
tile_offset
+
2
];
d3
=
input
[
tile_offset
+
3
];
d4
=
input
[
tile_offset
+
4
];
d5
=
input
[
tile_offset
+
5
];
d6
=
input
[
tile_offset
+
6
];
d7
=
input
[
tile_offset
+
7
];
s
[
i
][
0
]
=
d0
-
d6
+
(
d4
-
d2
)
*
5.25
;
s
[
i
][
7
]
=
d7
-
d1
+
(
d3
-
d5
)
*
5.25
;
float
u
=
d2
+
d6
-
d4
*
4.25
;
float
v
=
d1
+
d5
-
d3
*
4.25
;
s
[
i
][
1
]
=
u
+
v
;
s
[
i
][
2
]
=
u
-
v
;
u
=
d6
+
d2
*
0.25
-
d4
*
1.25
;
v
=
d1
*
0.5
-
d3
*
2.5
+
d5
*
2
;
s
[
i
][
3
]
=
u
+
v
;
s
[
i
][
4
]
=
u
-
v
;
u
=
d6
+
(
d2
-
d4
*
1.25
)
*
4
;
v
=
d1
*
2
-
d3
*
2.5
+
d5
*
0.5
;
s
[
i
][
5
]
=
u
+
v
;
s
[
i
][
6
]
=
u
-
v
;
tile_offset
+=
in_width
;
}
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
d0
=
s
[
0
][
i
];
d1
=
s
[
1
][
i
];
d2
=
s
[
2
][
i
];
d3
=
s
[
3
][
i
];
d4
=
s
[
4
][
i
];
d5
=
s
[
5
][
i
];
d6
=
s
[
6
][
i
];
d7
=
s
[
7
][
i
];
output
[
tile_index
+
i
*
stride
]
=
d0
-
d6
+
(
d4
-
d2
)
*
5.25
;
output
[
tile_index
+
(
56
+
i
)
*
stride
]
=
d7
-
d1
+
(
d3
-
d5
)
*
5.25
;
float
u
=
d2
+
d6
-
d4
*
4.25
;
float
v
=
d1
+
d5
-
d3
*
4.25
;
output
[
tile_index
+
(
8
+
i
)
*
stride
]
=
u
+
v
;
output
[
tile_index
+
(
16
+
i
)
*
stride
]
=
u
-
v
;
u
=
d6
+
d2
*
0.25
-
d4
*
1.25
;
v
=
d1
*
0.5
-
d3
*
2.5
+
d5
*
2
;
output
[
tile_index
+
(
24
+
i
)
*
stride
]
=
u
+
v
;
output
[
tile_index
+
(
32
+
i
)
*
stride
]
=
u
-
v
;
u
=
d6
+
(
d2
-
d4
*
1.25
)
*
4
;
v
=
d1
*
2
-
d3
*
2.5
+
d5
*
0.5
;
output
[
tile_index
+
(
40
+
i
)
*
stride
]
=
u
+
v
;
output
[
tile_index
+
(
48
+
i
)
*
stride
]
=
u
-
v
;
}
++
tile_index
;
}
}
}
}
// OCHW => TOC
// no need to optimize, it will exist in converter
void
TransformFilter
(
const
float
*
filter
,
const
index_t
in_channels
,
const
index_t
out_channels
,
float
*
output
)
{
void
TransformFilter
4x4
(
const
float
*
filter
,
const
index_t
in_channels
,
const
index_t
out_channels
,
float
*
output
)
{
const
index_t
stride
=
out_channels
*
in_channels
;
#pragma omp parallel for collapse(2)
...
...
@@ -171,6 +284,83 @@ void TransformFilter(const float *filter,
}
}
// OCHW => TOC
// no need to optimize, it will exist in converter
/**
* G =
⎡ 1 0 0 ⎤
⎢ ⎥
⎢-2/9 -2/9 -2/9 ⎥
⎢ ⎥
⎢-2/9 2/9 -2/9 ⎥
⎢ ⎥
⎢1/90 1/45 2/45 ⎥
⎢ ⎥
⎢1/90 -1/45 2/45 ⎥
⎢ ⎥
⎢1/45 1/90 1/180⎥
⎢ ⎥
⎢1/45 -1/90 1/180⎥
⎢ ⎥
⎣ 0 0 1 ⎦
*
* @param filter
* @param in_channels
* @param out_channels
* @param output
*/
void
TransformFilter8x8
(
const
float
*
filter
,
const
index_t
in_channels
,
const
index_t
out_channels
,
float
*
output
)
{
const
index_t
stride
=
out_channels
*
in_channels
;
const
float
G
[
8
][
3
]
=
{
{
1.0
f
,
0.0
f
,
0.0
f
},
{
-
2.0
f
/
9
,
-
2.0
f
/
9
,
-
2.0
f
/
9
},
{
-
2.0
f
/
9
,
2.0
f
/
9
,
-
2.0
f
/
9
},
{
1.0
f
/
90
,
1.0
f
/
45
,
2.0
f
/
45
},
{
1.0
f
/
90
,
-
1.0
f
/
45
,
2.0
f
/
45
},
{
1.0
f
/
45
,
1.0
f
/
90
,
1.0
f
/
180
},
{
1.0
f
/
45
,
-
1.0
f
/
90
,
1.0
f
/
180
},
{
0.0
f
,
0.0
f
,
1.0
f
}
};
#pragma omp parallel for collapse(2)
for
(
index_t
m
=
0
;
m
<
out_channels
;
++
m
)
{
for
(
index_t
c
=
0
;
c
<
in_channels
;
++
c
)
{
// load filter
index_t
filter_offset
=
(
m
*
in_channels
+
c
)
*
9
;
float
g0
,
g1
,
g2
,
g3
,
g4
,
g5
,
g6
,
g7
,
g8
;
g0
=
filter
[
filter_offset
];
g1
=
filter
[
filter_offset
+
1
];
g2
=
filter
[
filter_offset
+
2
];
g3
=
filter
[
filter_offset
+
3
];
g4
=
filter
[
filter_offset
+
4
];
g5
=
filter
[
filter_offset
+
5
];
g6
=
filter
[
filter_offset
+
6
];
g7
=
filter
[
filter_offset
+
7
];
g8
=
filter
[
filter_offset
+
8
];
float
s
[
3
][
8
];
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
s
[
0
][
i
]
=
g0
*
G
[
i
][
0
]
+
g1
*
G
[
i
][
1
]
+
g2
*
G
[
i
][
2
];
s
[
1
][
i
]
=
g3
*
G
[
i
][
0
]
+
g4
*
G
[
i
][
1
]
+
g5
*
G
[
i
][
2
];
s
[
2
][
i
]
=
g6
*
G
[
i
][
0
]
+
g7
*
G
[
i
][
1
]
+
g8
*
G
[
i
][
2
];
}
// store output
index_t
output_offset
=
m
*
in_channels
+
c
;
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
for
(
int
j
=
0
;
j
<
8
;
++
j
)
{
output
[
output_offset
+
(
i
*
8
+
j
)
*
stride
]
=
G
[
i
][
0
]
*
s
[
0
][
j
]
+
G
[
i
][
1
]
*
s
[
1
][
j
]
+
G
[
i
][
2
]
*
s
[
2
][
j
];
}
}
}
}
}
// TOC * TNCB => TNOB
void
BatchGemm
(
const
float
*
input
,
const
float
*
filter
,
...
...
@@ -178,17 +368,24 @@ void BatchGemm(const float *input,
index_t
in_channels
,
index_t
out_channels
,
index_t
tile_count
,
int
out_tile_size
,
float
*
output
)
{
const
index_t
in_stride
=
batch
*
in_channels
*
tile_count
;
const
index_t
in_channels_tile_count
=
in_channels
*
tile_count
;
const
index_t
filter_stride
=
out_channels
*
in_channels
;
const
index_t
out_stride
=
batch
*
out_channels
*
tile_count
;
const
index_t
out_channels_tile_count
=
out_channels
*
tile_count
;
const
int
in_tile_area
=
(
out_tile_size
+
2
)
*
(
out_tile_size
+
2
);
if
(
batch
==
1
)
{
Gemm
(
filter
,
input
,
16
,
out_channels
,
in_channels
,
tile_count
,
output
);
Gemm
(
filter
,
input
,
in_tile_area
,
out_channels
,
in_channels
,
tile_count
,
output
);
}
else
{
for
(
int
i
=
0
;
i
<
16
;
++
i
)
{
for
(
int
i
=
0
;
i
<
in_tile_area
;
++
i
)
{
for
(
int
b
=
0
;
b
<
batch
;
++
b
)
{
const
float
*
in_ptr
=
input
+
i
*
in_stride
+
b
*
in_channels_tile_count
;
...
...
@@ -207,13 +404,13 @@ void BatchGemm(const float *input,
}
// TNOB => ToNOB => NOHoWo
void
TransformOutput
(
const
float
*
input
,
index_t
batch
,
index_t
out_height
,
index_t
out_width
,
index_t
out_channels
,
index_t
tile_count
,
float
*
output
)
{
void
TransformOutput
4x4
(
const
float
*
input
,
index_t
batch
,
index_t
out_height
,
index_t
out_width
,
index_t
out_channels
,
index_t
tile_count
,
float
*
output
)
{
const
index_t
in_stride
=
batch
*
out_channels
*
tile_count
;
#pragma omp parallel for
...
...
@@ -271,6 +468,107 @@ void TransformOutput(const float *input,
}
}
}
// TNOB => ToNOB => NOHoWo
/**
* AT =
⎡1 1 1 1 1 32 32 0⎤
⎢ ⎥
⎢0 1 -1 2 -2 16 -16 0⎥
⎢ ⎥
⎢0 1 1 4 4 8 8 0⎥
⎢ ⎥
⎢0 1 -1 8 -8 4 -4 0⎥
⎢ ⎥
⎢0 1 1 16 16 2 2 0⎥
⎢ ⎥
⎣0 1 -1 32 -32 1 -1 1⎦
*
* @param input
* @param batch
* @param out_height
* @param out_width
* @param out_channels
* @param tile_count
* @param output
*/
void
TransformOutput8x8
(
const
float
*
input
,
index_t
batch
,
index_t
out_height
,
index_t
out_width
,
index_t
out_channels
,
index_t
tile_count
,
float
*
output
)
{
const
index_t
in_stride
=
batch
*
out_channels
*
tile_count
;
#pragma omp parallel for
for
(
index_t
nm
=
0
;
nm
<
batch
*
out_channels
;
++
nm
)
{
index_t
tile_offset
=
nm
*
tile_count
;
float
s
[
8
][
6
];
for
(
index_t
h
=
0
;
h
<
out_height
;
h
+=
6
)
{
for
(
index_t
w
=
0
;
w
<
out_width
;
w
+=
6
)
{
index_t
tile_offset_tmp
=
tile_offset
;
for
(
int
i
=
0
;
i
<
8
;
++
i
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
d0
=
input
[
tile_offset_tmp
+
0
*
in_stride
];
d1
=
input
[
tile_offset_tmp
+
1
*
in_stride
];
d2
=
input
[
tile_offset_tmp
+
2
*
in_stride
];
d3
=
input
[
tile_offset_tmp
+
3
*
in_stride
];
d4
=
input
[
tile_offset_tmp
+
4
*
in_stride
];
d5
=
input
[
tile_offset_tmp
+
5
*
in_stride
];
d6
=
input
[
tile_offset_tmp
+
6
*
in_stride
];
d7
=
input
[
tile_offset_tmp
+
7
*
in_stride
];
float
u
=
d1
+
d2
;
float
v
=
d1
-
d2
;
float
w
=
d3
+
d4
;
float
x
=
d3
-
d4
;
float
y
=
d5
+
d6
;
float
z
=
d5
-
d6
;
s
[
i
][
0
]
=
d0
+
u
+
w
+
y
*
32
;
s
[
i
][
1
]
=
v
+
x
+
x
+
z
*
16
;
s
[
i
][
2
]
=
u
+
w
*
4
+
y
*
8
;
s
[
i
][
3
]
=
v
+
x
*
8
+
z
*
4
;
s
[
i
][
4
]
=
u
+
w
*
16
+
y
+
y
;
s
[
i
][
5
]
=
v
+
x
*
32
+
z
+
d7
;
tile_offset_tmp
+=
8
*
in_stride
;
}
index_t
out_offset
=
nm
*
out_height
*
out_width
+
h
*
out_width
+
w
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
float
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
d0
=
s
[
0
][
i
];
d1
=
s
[
1
][
i
];
d2
=
s
[
2
][
i
];
d3
=
s
[
3
][
i
];
d4
=
s
[
4
][
i
];
d5
=
s
[
5
][
i
];
d6
=
s
[
6
][
i
];
d7
=
s
[
7
][
i
];
float
u
=
d1
+
d2
;
float
v
=
d1
-
d2
;
float
w
=
d3
+
d4
;
float
x
=
d3
-
d4
;
float
y
=
d5
+
d6
;
float
z
=
d5
-
d6
;
output
[
out_offset
+
0
*
out_width
+
i
]
=
d0
+
u
+
w
+
y
*
32
;
output
[
out_offset
+
1
*
out_width
+
i
]
=
v
+
x
+
x
+
z
*
16
;
output
[
out_offset
+
2
*
out_width
+
i
]
=
u
+
w
*
4
+
y
*
8
;
output
[
out_offset
+
3
*
out_width
+
i
]
=
v
+
x
*
8
+
z
*
4
;
output
[
out_offset
+
4
*
out_width
+
i
]
=
u
+
w
*
16
+
y
+
y
;
output
[
out_offset
+
5
*
out_width
+
i
]
=
v
+
x
*
32
+
z
+
d7
;
}
++
tile_offset
;
}
}
}
}
}
// namespace
void
WinoGradConv3x3s1
(
const
float
*
input
,
...
...
@@ -280,6 +578,7 @@ void WinoGradConv3x3s1(const float *input,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_channels
,
const
int
out_tile_size
,
float
*
transformed_input
,
float
*
transformed_filter
,
float
*
transformed_output
,
...
...
@@ -287,22 +586,52 @@ void WinoGradConv3x3s1(const float *input,
float
*
output
)
{
index_t
out_height
=
in_height
-
2
;
index_t
out_width
=
in_width
-
2
;
index_t
tile_height_count
=
(
out_height
+
1
)
/
2
;
index_t
tile_width_count
=
(
out_width
+
1
)
/
2
;
index_t
tile_height_count
=
RoundUpDiv
(
out_height
,
static_cast
<
index_t
>
(
out_tile_size
));
index_t
tile_width_count
=
RoundUpDiv
(
out_width
,
static_cast
<
index_t
>
(
out_tile_size
));
index_t
tile_count
=
tile_height_count
*
tile_width_count
;
TransformInput
(
input
,
batch
,
in_height
,
in_width
,
in_channels
,
tile_count
,
transformed_input
);
switch
(
out_tile_size
)
{
case
2
:
TransformInput4x4
(
input
,
batch
,
in_height
,
in_width
,
in_channels
,
tile_count
,
transformed_input
);
break
;
case
6
:
TransformInput8x8
(
input
,
batch
,
in_height
,
in_width
,
in_channels
,
tile_count
,
transformed_input
);
break
;
default:
MACE_NOT_IMPLEMENTED
;
}
// TODO(liyin): put it in model converter, but do not worry, it is fast and
// will only do once
if
(
!
is_filter_transformed
)
{
TransformFilter
(
filter
,
in_channels
,
out_channels
,
transformed_filter
);
switch
(
out_tile_size
)
{
case
2
:
TransformFilter4x4
(
filter
,
in_channels
,
out_channels
,
transformed_filter
);
break
;
case
6
:
TransformFilter8x8
(
filter
,
in_channels
,
out_channels
,
transformed_filter
);
break
;
default:
MACE_NOT_IMPLEMENTED
;
}
}
BatchGemm
(
transformed_input
,
...
...
@@ -311,15 +640,30 @@ void WinoGradConv3x3s1(const float *input,
in_channels
,
out_channels
,
tile_count
,
out_tile_size
,
transformed_output
);
TransformOutput
(
transformed_output
,
batch
,
out_height
,
out_width
,
out_channels
,
tile_count
,
output
);
switch
(
out_tile_size
)
{
case
2
:
TransformOutput4x4
(
transformed_output
,
batch
,
out_height
,
out_width
,
out_channels
,
tile_count
,
output
);
break
;
case
6
:
TransformOutput8x8
(
transformed_output
,
batch
,
out_height
,
out_width
,
out_channels
,
tile_count
,
output
);
break
;
default:
MACE_NOT_IMPLEMENTED
;
}
}
void
WinoGradConv3x3s1
(
const
float
*
input
,
...
...
@@ -329,16 +673,21 @@ void WinoGradConv3x3s1(const float *input,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_channels
,
const
int
out_tile_size
,
float
*
output
)
{
index_t
out_height
=
in_height
-
2
;
index_t
out_width
=
in_width
-
2
;
index_t
tile_height_count
=
(
out_height
+
1
)
/
2
;
index_t
tile_width_count
=
(
out_width
+
1
)
/
2
;
index_t
tile_height_count
=
RoundUpDiv
(
out_height
,
static_cast
<
index_t
>
(
out_tile_size
));
index_t
tile_width_count
=
RoundUpDiv
(
out_width
,
static_cast
<
index_t
>
(
out_tile_size
));
index_t
tile_count
=
tile_height_count
*
tile_width_count
;
index_t
transformed_input_size
=
16
*
batch
*
in_channels
*
tile_count
;
index_t
transformed_filter_size
=
16
*
out_channels
*
in_channels
;
index_t
transformed_output_size
=
16
*
batch
*
out_channels
*
tile_count
;
index_t
in_tile_area
=
(
out_tile_size
+
2
)
*
(
out_tile_size
+
2
);
index_t
transformed_input_size
=
in_tile_area
*
batch
*
in_channels
*
tile_count
;
index_t
transformed_filter_size
=
in_tile_area
*
out_channels
*
in_channels
;
index_t
transformed_output_size
=
in_tile_area
*
batch
*
out_channels
*
tile_count
;
float
*
transformed_input
=
new
float
[
transformed_input_size
];
// TNCB
float
*
transformed_filter
=
new
float
[
transformed_filter_size
];
// TOC
...
...
@@ -351,6 +700,7 @@ void WinoGradConv3x3s1(const float *input,
in_width
,
in_channels
,
out_channels
,
out_tile_size
,
transformed_input
,
transformed_filter
,
transformed_output
,
...
...
@@ -362,7 +712,6 @@ void WinoGradConv3x3s1(const float *input,
delete
[]
transformed_output
;
}
void
ConvRef3x3s1
(
const
float
*
input
,
const
float
*
filter
,
const
index_t
batch
,
...
...
@@ -391,7 +740,8 @@ void ConvRef3x3s1(const float *input,
((
b
*
in_channels
+
c
)
*
in_height
+
ih
)
*
in_width
+
iw
;
index_t
filter_offset
=
(((
m
*
in_channels
)
+
c
)
*
3
+
kh
)
*
3
+
kw
;
output
[
out_offset
]
+=
input
[
in_offset
]
*
filter
[
filter_offset
];
output
[
out_offset
]
+=
input
[
in_offset
]
*
filter
[
filter_offset
];
}
}
}
...
...
mace/kernels/arm/conv_winograd.h
浏览文件 @
e9bcf721
...
...
@@ -21,6 +21,7 @@ void WinoGradConv3x3s1(const float *input,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_channels
,
const
int
out_tile_size
,
float
*
output
);
void
WinoGradConv3x3s1
(
const
float
*
input
,
...
...
@@ -30,6 +31,7 @@ void WinoGradConv3x3s1(const float *input,
const
index_t
in_width
,
const
index_t
in_channels
,
const
index_t
out_channels
,
const
int
out_tile_size
,
float
*
transformed_input
,
float
*
transformed_filter
,
float
*
transformed_output
,
...
...
mace/kernels/arm/conv_winograd_test.cc
浏览文件 @
e9bcf721
...
...
@@ -58,11 +58,12 @@ TEST(ConvWinogradTest, winograd) {
in_width
,
in_channels
,
out_channels
,
6
,
output_data
);
// test
for
(
index_t
i
=
0
;
i
<
output_size
;
++
i
)
{
EXPECT_NEAR
(
output_data_ref
[
i
],
output_data
[
i
],
0.1
);
EXPECT_NEAR
(
output_data_ref
[
i
],
output_data
[
i
],
0.1
)
<<
" with index "
<<
i
;
}
delete
[]
input_data
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录