Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
d37229fa
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
d37229fa
编写于
7月 03, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(dnn): optimize f23 and f63 nchw44 winograd
GitOrigin-RevId: 8569c9dfc6db1b6853d4aa35bcf5b2bc9b6f89b1
上级
d7c0dd45
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
316 addition
and
132 deletion
+316
-132
dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp
...src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp
+50
-61
dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp
...src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp
+173
-66
dnn/src/arm_common/simd_macro/marm_neon.h
dnn/src/arm_common/simd_macro/marm_neon.h
+91
-3
dnn/test/arm_common/conv_bias.cpp
dnn/test/arm_common/conv_bias.cpp
+2
-2
未找到文件。
dnn/src/arm_common/conv_bias/fp32/strategy_f23_mk4_nchw44.cpp
浏览文件 @
d37229fa
...
...
@@ -32,29 +32,37 @@ constexpr size_t pack_size = 4;
struct
InputTransformF23_NCHW44
{
template
<
bool
inner
>
static
void
prepare
(
const
float
*
input
,
float
*
patch
,
float
*
patchT
,
int
ih_start
,
int
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
ic
,
size_t
IC
)
{
MEGDNN_MARK_USED_VAR
(
patch
);
static
void
transform
(
float
*
patchT
,
const
float
*
input
,
float
*
input_transform_buf
,
size_t
ih_start
,
size_t
iw_start
,
size_t
IH
,
size_t
IW
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
ic
,
size_t
IC
)
{
size_t
IW4
=
IW
*
pack_size
;
size_t
iw4_start
=
iw_start
*
pack_size
;
size_t
icb
=
ic
/
pack_size
;
size_t
iw4_start
=
iw_start
*
pack_size
;
size_t
ICB
=
IC
/
pack_size
;
#define cb(m, n) Vector<float, 4> d##m##n;
UNROLL_CALL_NOWRAPPER_D2
(
4
,
4
,
cb
);
#undef cb
if
(
!
(
inner
&&
ic
+
pack_size
<
IC
))
{
memset
(
patchT
,
0
,
sizeof
(
float
)
*
pack_size
*
alpha
*
alpha
);
}
if
(
inner
)
{
MEGDNN_MARK_USED_VAR
(
patchT
);
const
float
*
input_ptr
=
input
+
icb
*
IH
*
IW4
+
ih_start
*
IW4
+
iw4_start
;
for
(
size_t
ih
=
0
;
ih
<
alpha
;
ih
++
)
{
#define cb(i) auto v##i = vld1q_f32(input_ptr + pack_size * i);
UNROLL_CALL_NOWRAPPER
(
4
,
cb
);
#undef cb
#define cb(i) vst1q_f32(patchT + ih * alpha * pack_size + i * pack_size, v##i);
UNROLL_CALL_NOWRAPPER
(
4
,
cb
);
#define cb(n, m) d##m##n = Vector<float, 4>::load(input_ptr + pack_size * n);
UNROLL_CALL_RAW
(
4
,
cb
,
0
);
input_ptr
+=
IW4
;
UNROLL_CALL_RAW
(
4
,
cb
,
1
);
input_ptr
+=
IW4
;
UNROLL_CALL_RAW
(
4
,
cb
,
2
);
input_ptr
+=
IW4
;
UNROLL_CALL_RAW
(
4
,
cb
,
3
);
#undef cb
input_ptr
+=
IW4
;
}
}
else
{
int
ih0_act
=
std
::
max
<
int
>
(
ih_start
,
0
),
ih1_act
=
std
::
min
<
int
>
(
ih_start
+
alpha
,
IH
),
...
...
@@ -71,19 +79,12 @@ struct InputTransformF23_NCHW44 {
src
);
}
}
}
}
static
void
transform
(
const
float
*
patchT
,
float
*
input_transform_buf
,
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
ic
,
size_t
IC
)
{
// BT * d * B
#define cb(m, n) \
Vector<float, 4> d##m##n = Vector<float, 4>::load( \
patchT + m * alpha * pack_size + n * pack_size);
UNROLL_CALL_NOWRAPPER_D2
(
4
,
4
,
cb
);
#define cb(m, n) \
d##m##n = Vector<float, 4>::load(patchT + m * alpha * pack_size + \
n * pack_size);
UNROLL_CALL_NOWRAPPER_D2
(
4
,
4
,
cb
);
#undef cb
}
//! 1 0 -1 0 d00 d01 d02 d03 1 0 0 0
//! 0 1 1 0 d10 d11 d12 d13 0 1 -1 -1
//! 0 -1 1 0 d20 d21 d22 d23 -1 1 1 0
...
...
@@ -106,8 +107,6 @@ struct InputTransformF23_NCHW44 {
UNROLL_CALL_NOWRAPPER
(
4
,
cb
);
#undef cb
size_t
ICB
=
IC
/
4
;
size_t
icb
=
ic
/
4
;
#define cb(m, n) \
d##m##n.save(input_transform_buf + \
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \
...
...
@@ -273,7 +272,6 @@ void winograd_F23_mk4_f_nchw44::input(const float* input,
// OW = IW + 2 * PW - KERNEL_SIZE + 1
auto
units_w
=
div_ceil
<
size_t
>
(
IW
+
2
*
PW
-
KERNEL_SIZE
+
1
,
OUTPUT_BLOCK_SIZE
);
float
*
patch
=
transform_mid_buf
;
float
*
patchT
=
transform_mid_buf
+
4
*
alpha
*
alpha
;
for
(
size_t
ic
=
0
;
ic
<
IC
;
ic
+=
4
)
{
...
...
@@ -285,20 +283,13 @@ void winograd_F23_mk4_f_nchw44::input(const float* input,
int
iw_start
=
nw
*
OUTPUT_BLOCK_SIZE
-
PW
;
if
(
ih_start
>=
0
&&
ih_start
+
alpha
<=
static_cast
<
int
>
(
IH
)
&&
iw_start
>=
0
&&
iw_start
+
alpha
<=
static_cast
<
int
>
(
IW
))
{
InputTransformF23_NCHW44
::
prepare
<
true
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransformF23_NCHW44
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
InputTransformF23_NCHW44
::
transform
<
true
>
(
patchT
,
input
,
input_transform_buf
,
ih_start
,
iw_start
,
IH
,
IW
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
}
else
{
InputTransformF23_NCHW44
::
prepare
<
false
>
(
input
,
patch
,
patchT
,
ih_start
,
iw_start
,
IH
,
IW
,
ic
,
IC
);
InputTransformF23_NCHW44
::
transform
(
patchT
,
input_transform_buf
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
InputTransformF23_NCHW44
::
transform
<
false
>
(
patchT
,
input
,
input_transform_buf
,
ih_start
,
iw_start
,
IH
,
IW
,
unit_idx
,
nr_units_in_tile
,
ic
,
IC
);
}
}
}
...
...
@@ -311,9 +302,21 @@ void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_start_idx
,
size_t
nr_units_in_tile
)
{
#define cb(_bmode, _nonline_op, ...) \
OutputTransformF23_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
#define cb(_bmode, _nonline_op, ...) \
for (size_t oc = oc_start; oc < oc_end; oc += 4) { \
size_t oc_index = oc - oc_start; \
rep(unit_idx, nr_units_in_tile) { \
size_t index = unit_start_idx + unit_idx; \
auto nh = index / units_w; \
auto nw = index % units_w; \
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \
OutputTransformF23_NCHW44<_bmode, _nonline_op>::transform( \
output_transform_buf, bias, output, transform_mid_buf, \
oh_start, ow_start, OH, OW, oc_start, oc_end, oc_index, \
unit_idx, nr_units_in_tile, src_dtype, dst_dtype); \
} \
}
auto
units_w
=
div_ceil
<
size_t
>
(
OW
,
OUTPUT_BLOCK_SIZE
);
constexpr
size_t
pack_size
=
4
;
...
...
@@ -323,22 +326,8 @@ void winograd_F23_mk4_f_nchw44::output(const float* output_transform_buf,
oc_end
%
pack_size
==
0
,
"NCHW44 Winograd filter transform requires OC is times of 4"
);
for
(
size_t
oc
=
oc_start
;
oc
<
oc_end
;
oc
+=
4
)
{
size_t
oc_index
=
oc
-
oc_start
;
rep
(
unit_idx
,
nr_units_in_tile
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
auto
nh
=
index
/
units_w
;
auto
nw
=
index
%
units_w
;
size_t
oh_start
=
nh
*
OUTPUT_BLOCK_SIZE
;
size_t
ow_start
=
nw
*
OUTPUT_BLOCK_SIZE
;
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_arm_common_winograd_nchw44_fp32_F23_mk4
,
cb
,
float
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
bias
,
output
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start
,
oc_end
,
oc_index
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
}
}
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_arm_common_winograd_nchw44_fp32_F23_mk4
,
cb
,
float
,
float
,
bmode
,
nonline_mode
);
#undef cb
}
...
...
dnn/src/arm_common/conv_bias/fp32/strategy_f63_mk4_nchw44.cpp
浏览文件 @
d37229fa
...
...
@@ -31,6 +31,8 @@ namespace {
constexpr
size_t
alpha
=
6
+
3
-
1
;
constexpr
size_t
pack_size
=
4
;
constexpr
float
input_parameters
[
12
]
=
{
5.25
f
,
4.25
f
,
0.5
f
,
0.25
f
,
2.5
f
,
1.25
f
,
2.0
f
,
4.0
f
,
5.0
f
,
0.0
f
,
0.0
f
,
0.0
f
};
struct
InputTransformF63_NCHW44
{
template
<
bool
inner
>
...
...
@@ -80,12 +82,14 @@ struct InputTransformF63_NCHW44 {
size_t
unit_idx
,
size_t
nr_units_in_tile
,
size_t
ic
,
size_t
IC
)
{
// BT * d * B
#define cb(m, n) \
Vector<float, 4> d##m##n = Vector<float, 4>::load( \
patchT + m * alpha * pack_size + n * pack_size);
UNROLL_CALL_NOWRAPPER_D2
(
8
,
8
,
cb
);
#undef cb
size_t
ICB
=
IC
/
pack_size
;
size_t
icb
=
ic
/
pack_size
;
float32x4_t
d0
,
d1
,
d2
,
d3
,
d4
,
d5
,
d6
,
d7
;
float32x4_t
v0
=
vld1q_f32
(
input_parameters
+
0
);
float32x4_t
v1
=
vld1q_f32
(
input_parameters
+
4
);
float32x4_t
v2
=
vld1q_f32
(
input_parameters
+
8
);
//! B
//! 1 0 0 0 0 0 0 0
...
...
@@ -96,49 +100,147 @@ struct InputTransformF63_NCHW44 {
//! 0 1 -1 2 -2 0.5 -0.5 -5.25
//! -1 1 1 1 1 1 1 0
//! 0 0 0 0 0 0 0 1
#define cb(m) \
auto t0##m = d0##m + (d4##m - d2##m) * 5.25f - d6##m; \
auto t1##m = d1##m + d2##m + d5##m + d6##m - (d3##m + d4##m) * 4.25f; \
auto t2##m = d2##m + d6##m - (d1##m + d5##m) + (d3##m - d4##m) * 4.25f; \
auto t3##m = d1##m * 0.5f + d2##m * 0.25f - d3##m * 2.5f - d4##m * 1.25f + \
d5##m * 2.f + d6##m; \
auto t4##m = d1##m * (-0.5f) + d2##m * 0.25f + d3##m * 2.5f - \
d4##m * 1.25f - d5##m * 2.f + d6##m; \
auto t5##m = d1##m * 2.f + d2##m * 4.f - d3##m * 2.5f - d4##m * 5.f + \
d5##m * 0.5f + d6##m; \
auto t6##m = d1##m * (-2.f) + d2##m * 4.f + d3##m * 2.5f - d4##m * 5.f - \
d5##m * 0.5f + d6##m; \
auto t7##m = (d7##m - d1##m) + (d3##m - d5##m) * 5.25f;
UNROLL_CALL_NOWRAPPER
(
8
,
cb
);
#undef cb
#define cb(m) \
d##m##0 = t##m##0 + (t##m##4 - t##m##2) * 5.25f - t##m##6; \
d##m##1 = t##m##1 + t##m##2 + t##m##5 + t##m##6 - \
(t##m##3 + t##m##4) * 4.25f; \
d##m##2 = t##m##2 + t##m##6 - (t##m##1 + t##m##5) + \
(t##m##3 - t##m##4) * 4.25f; \
d##m##3 = t##m##1 * 0.5f + t##m##2 * 0.25f - t##m##3 * 2.5f - \
t##m##4 * 1.25f + t##m##5 * 2.f + t##m##6; \
d##m##4 = t##m##1 * (-0.5f) + t##m##2 * 0.25f + t##m##3 * 2.5f - \
t##m##4 * 1.25f - t##m##5 * 2.f + t##m##6; \
d##m##5 = t##m##1 * 2.f + t##m##2 * 4.f - t##m##3 * 2.5f - t##m##4 * 5.f + \
t##m##5 * 0.5f + t##m##6; \
d##m##6 = t##m##1 * (-2.f) + t##m##2 * 4.f + t##m##3 * 2.5f - \
t##m##4 * 5.f - t##m##5 * 0.5f + t##m##6; \
d##m##7 = (t##m##7 - t##m##1) + (t##m##3 - t##m##5) * 5.25f;
UNROLL_CALL_NOWRAPPER
(
8
,
cb
);
#define cb(i) \
d1 = vld1q_f32(patchT + i * alpha * pack_size + 1 * pack_size); \
d2 = vld1q_f32(patchT + i * alpha * pack_size + 2 * pack_size); \
d3 = vld1q_f32(patchT + i * alpha * pack_size + 3 * pack_size); \
d4 = vld1q_f32(patchT + i * alpha * pack_size + 4 * pack_size); \
d5 = vld1q_f32(patchT + i * alpha * pack_size + 5 * pack_size); \
d6 = vld1q_f32(patchT + i * alpha * pack_size + 6 * pack_size); \
auto t##i##0 = vld1q_f32(patchT + i * alpha * pack_size + 0 * pack_size); \
auto t##i##7 = vld1q_f32(patchT + i * alpha * pack_size + 7 * pack_size); \
auto t##i##1 = d6; \
auto t##i##2 = d6; \
auto t##i##3 = d6; \
auto t##i##4 = d6; \
auto t##i##5 = d6; \
auto t##i##6 = d6; \
t##i##0 = t##i##0 - d6; \
t##i##1 = t##i##1 + d1; \
t##i##2 = t##i##2 - d1; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d1, v0, 2); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d1, v0, 2); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d1, v1, 2); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d1, v1, 2); \
t##i##7 = t##i##7 - d1; \
t##i##0 = vfmsq_laneq_f32(t##i##0, d2, v0, 0); \
t##i##1 = t##i##1 + d2; \
t##i##2 = t##i##2 + d2; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d2, v0, 3); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d2, v0, 3); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d2, v1, 3); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d2, v1, 3); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d3, v0, 1); \
t##i##2 = vfmaq_laneq_f32(t##i##2, d3, v0, 1); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d3, v1, 0); \
t##i##4 = vfmaq_laneq_f32(t##i##4, d3, v1, 0); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d3, v1, 0); \
t##i##6 = vfmaq_laneq_f32(t##i##6, d3, v1, 0); \
t##i##7 = vfmaq_laneq_f32(t##i##7, d3, v0, 0); \
t##i##0 = vfmaq_laneq_f32(t##i##0, d4, v0, 0); \
t##i##1 = vfmsq_laneq_f32(t##i##1, d4, v0, 1); \
t##i##2 = vfmsq_laneq_f32(t##i##2, d4, v0, 1); \
t##i##3 = vfmsq_laneq_f32(t##i##3, d4, v1, 1); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d4, v1, 1); \
t##i##5 = vfmsq_laneq_f32(t##i##5, d4, v2, 0); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d4, v2, 0); \
t##i##1 = t##i##1 + d5; \
t##i##2 = t##i##2 - d5; \
t##i##3 = vfmaq_laneq_f32(t##i##3, d5, v1, 2); \
t##i##4 = vfmsq_laneq_f32(t##i##4, d5, v1, 2); \
t##i##5 = vfmaq_laneq_f32(t##i##5, d5, v0, 2); \
t##i##6 = vfmsq_laneq_f32(t##i##6, d5, v0, 2); \
t##i##7 = vfmsq_laneq_f32(t##i##7, d5, v0, 0);
UNROLL_CALL_RAW
(
8
,
cb
);
#undef cb
size_t
ICB
=
IC
/
pack_size
;
size_t
icb
=
ic
/
pack_size
;
#define cb(m, n) \
d##m##n.save(input_transform_buf + \
(m * alpha + n) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + unit_idx * pack_size);
UNROLL_CALL_NOWRAPPER_D2
(
8
,
8
,
cb
)
#define cb(i) \
d0 = t0##i; \
d1 = t6##i; \
d2 = t6##i; \
d3 = t6##i; \
d4 = t6##i; \
d5 = t6##i; \
d6 = t6##i; \
d7 = t7##i; \
d0 = d0 - t6##i; \
d1 = d1 + t1##i; \
d2 = d2 - t1##i; \
d3 = vfmaq_laneq_f32(d3, t1##i, v0, 2); \
d4 = vfmsq_laneq_f32(d4, t1##i, v0, 2); \
d5 = vfmaq_laneq_f32(d5, t1##i, v1, 2); \
d6 = vfmsq_laneq_f32(d6, t1##i, v1, 2); \
d7 = d7 - t1##i; \
d0 = vfmsq_laneq_f32(d0, t2##i, v0, 0); \
d1 = d1 + t2##i; \
d2 = d2 + t2##i; \
d3 = vfmaq_laneq_f32(d3, t2##i, v0, 3); \
d4 = vfmaq_laneq_f32(d4, t2##i, v0, 3); \
d5 = vfmaq_laneq_f32(d5, t2##i, v1, 3); \
d6 = vfmaq_laneq_f32(d6, t2##i, v1, 3); \
d1 = vfmsq_laneq_f32(d1, t3##i, v0, 1); \
d2 = vfmaq_laneq_f32(d2, t3##i, v0, 1); \
d3 = vfmsq_laneq_f32(d3, t3##i, v1, 0); \
d4 = vfmaq_laneq_f32(d4, t3##i, v1, 0); \
d5 = vfmsq_laneq_f32(d5, t3##i, v1, 0); \
d6 = vfmaq_laneq_f32(d6, t3##i, v1, 0); \
d7 = vfmaq_laneq_f32(d7, t3##i, v0, 0); \
d0 = vfmaq_laneq_f32(d0, t4##i, v0, 0); \
d1 = vfmsq_laneq_f32(d1, t4##i, v0, 1); \
d2 = vfmsq_laneq_f32(d2, t4##i, v0, 1); \
d3 = vfmsq_laneq_f32(d3, t4##i, v1, 1); \
d4 = vfmsq_laneq_f32(d4, t4##i, v1, 1); \
d5 = vfmsq_laneq_f32(d5, t4##i, v2, 0); \
d6 = vfmsq_laneq_f32(d6, t4##i, v2, 0); \
d1 = d1 + t5##i; \
d2 = d2 - t5##i; \
d3 = vfmaq_laneq_f32(d3, t5##i, v1, 2); \
d4 = vfmsq_laneq_f32(d4, t5##i, v1, 2); \
d5 = vfmaq_laneq_f32(d5, t5##i, v0, 2); \
d6 = vfmsq_laneq_f32(d6, t5##i, v0, 2); \
d7 = vfmsq_laneq_f32(d7, t5##i, v0, 0); \
vst1q_f32(input_transform_buf + \
(0 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d0); \
vst1q_f32(input_transform_buf + \
(1 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d1); \
vst1q_f32(input_transform_buf + \
(2 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d2); \
vst1q_f32(input_transform_buf + \
(3 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d3); \
vst1q_f32(input_transform_buf + \
(4 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d4); \
vst1q_f32(input_transform_buf + \
(5 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d5); \
vst1q_f32(input_transform_buf + \
(6 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d6); \
vst1q_f32(input_transform_buf + \
(7 * alpha + i) * ICB * nr_units_in_tile * pack_size + \
icb * nr_units_in_tile * pack_size + \
unit_idx * pack_size, \
d7);
UNROLL_CALL_RAW
(
8
,
cb
);
#undef cb
}
};
...
...
@@ -178,7 +280,7 @@ struct OutputTransformF63_NCHW44 {
* 1 -2 4 -8 16 -32
* 1 0.5 0.25 0.125 0.0625 0.03125
* 1 -0.5 0.25 -0.125 0.0625 -0.03125
* 0
0.
0 0 0 0 1
* 0
0 0 0 0 1
*/
Vector
<
float
,
4
>
v1addv2
,
v1subv2
,
v3addv4
,
v3subv4
,
v5addv6
,
v5subv6
;
...
...
@@ -378,28 +480,33 @@ void winograd_F63_mk4_f_nchw44::output(const float* output_transform_buf,
size_t
OW
,
size_t
oc_start
,
size_t
oc_end
,
size_t
unit_start_idx
,
size_t
nr_units_in_tile
)
{
constexpr
size_t
pack_size
=
4
;
#define cb(_bmode, _nonline_op, ...) \
OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>::transform( \
__VA_ARGS__);
#define cb(_bmode, _nonline_op, ...) \
for (size_t oc = oc_start; oc < oc_end; oc += pack_size) { \
size_t oc_index = oc - oc_start; \
rep(unit_idx, nr_units_in_tile) { \
size_t index = unit_start_idx + unit_idx; \
auto nh = index / units_w; \
auto nw = index % units_w; \
size_t oh_start = nh * OUTPUT_BLOCK_SIZE; \
size_t ow_start = nw * OUTPUT_BLOCK_SIZE; \
OutputTransformF63_NCHW44<_bmode MEGDNN_COMMA _nonline_op>:: \
transform(output_transform_buf, bias, output, \
transform_mid_buf, oh_start, ow_start, OH, OW, \
oc_start, oc_end, oc_index, unit_idx, \
nr_units_in_tile, src_dtype, dst_dtype); \
} \
}
auto
units_w
=
div_ceil
<
size_t
>
(
OW
,
OUTPUT_BLOCK_SIZE
);
for
(
size_t
oc
=
oc_start
;
oc
<
oc_end
;
oc
+=
pack_size
)
{
size_t
oc_index
=
oc
-
oc_start
;
rep
(
unit_idx
,
nr_units_in_tile
)
{
size_t
index
=
unit_start_idx
+
unit_idx
;
auto
nh
=
index
/
units_w
;
auto
nw
=
index
%
units_w
;
size_t
oh_start
=
nh
*
OUTPUT_BLOCK_SIZE
;
size_t
ow_start
=
nw
*
OUTPUT_BLOCK_SIZE
;
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_arm_common_winograd_fp32_F63_mk4
,
cb
,
float
,
float
,
bmode
,
nonline_mode
,
output_transform_buf
,
bias
,
output
,
transform_mid_buf
,
oh_start
,
ow_start
,
OH
,
OW
,
oc_start
,
oc_end
,
oc_index
,
unit_idx
,
nr_units_in_tile
,
src_dtype
,
dst_dtype
);
}
}
constexpr
size_t
pack_size
=
4
;
size_t
OC
=
oc_end
-
oc_start
;
megdnn_assert
(
OC
%
pack_size
==
0
&&
oc_start
%
pack_size
==
0
&&
oc_end
%
pack_size
==
0
,
"NCHW44 Winograd filter transform requires OC is times of 4"
);
DISPATCH_CONV_WINOGRAD_BIAS
(
megdnn_arm_common_winograd_fp32_F63_mk4
,
cb
,
float
,
float
,
bmode
,
nonline_mode
);
#undef cb
}
...
...
dnn/src/arm_common/simd_macro/marm_neon.h
浏览文件 @
d37229fa
...
...
@@ -538,10 +538,43 @@ struct Vfmaq_laneq_f32_armv7<3> {
return
vmlaq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
1
);
}
};
template
<
int
lane
>
struct
Vfmsq_laneq_f32_armv7
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
);
};
template
<
>
struct
Vfmsq_laneq_f32_armv7
<
0
>
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlsq_lane_f32
(
a
,
b
,
vget_low_f32
(
v
),
0
);
}
};
template
<
>
struct
Vfmsq_laneq_f32_armv7
<
1
>
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlsq_lane_f32
(
a
,
b
,
vget_low_f32
(
v
),
1
);
}
};
template
<
>
struct
Vfmsq_laneq_f32_armv7
<
2
>
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlsq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
0
);
}
};
template
<
>
struct
Vfmsq_laneq_f32_armv7
<
3
>
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
return
vmlsq_lane_f32
(
a
,
b
,
vget_high_f32
(
v
),
1
);
}
};
}
// namespace
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmaq_laneq_f32_armv7<lane>::impl(a, b, v)
#define vfmsq_laneq_f32(a, b, v, lane) \
Vfmsq_laneq_f32_armv7<lane>::impl(a, b, v)
#if __ARM_FEATURE_DOTPROD
namespace
{
template
<
int
lane
>
...
...
@@ -582,7 +615,6 @@ struct Vdotq_laneq_s32_armv7<3> {
//! GCC split fmla with lane to dup+fmla when version < 9
//! https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89101
#if !defined(__clang__) && __GNUC__ < 9
#if MEGDNN_AARCH64
namespace
{
...
...
@@ -630,13 +662,59 @@ struct Vfmaq_laneq_f32_armv8<3> {
return
a
;
}
};
template
<
int
lane
>
struct
Vfmsq_laneq_f32_armv8
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
);
};
template
<
>
struct
Vfmsq_laneq_f32_armv8
<
0
>
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
asm
volatile
(
"fmls %0.4s, %1.4s, %2.s[0]
\n
"
:
"+w"
(
a
)
:
"w"
(
b
),
"w"
(
v
)
:
);
return
a
;
}
};
template
<
>
struct
Vfmsq_laneq_f32_armv8
<
1
>
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
asm
volatile
(
"fmls %0.4s, %1.4s, %2.s[1]
\n
"
:
"+w"
(
a
)
:
"w"
(
b
),
"w"
(
v
)
:
);
return
a
;
}
};
template
<
>
struct
Vfmsq_laneq_f32_armv8
<
2
>
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
asm
volatile
(
"fmls %0.4s, %1.4s, %2.s[2]
\n
"
:
"+w"
(
a
)
:
"w"
(
b
),
"w"
(
v
)
:
);
return
a
;
}
};
template
<
>
struct
Vfmsq_laneq_f32_armv8
<
3
>
{
__ai
float32x4_t
impl
(
float32x4_t
a
,
float32x4_t
b
,
float32x4_t
v
)
{
asm
volatile
(
"fmls %0.4s, %1.4s, %2.s[3]
\n
"
:
"+w"
(
a
)
:
"w"
(
b
),
"w"
(
v
)
:
);
return
a
;
}
};
}
// namespace
#undef vfmaq_laneq_f32
#define vfmaq_laneq_f32(a, b, v, lane) \
Vfmaq_laneq_f32_armv8<lane>::impl(a, b, v)
#endif
#undef vfmsq_laneq_f32
#define vfmsq_laneq_f32(a, b, v, lane) \
Vfmsq_laneq_f32_armv8<lane>::impl(a, b, v)
#endif
__ai
int8x16_t
vld_dup_tbl_s32
(
const
int8_t
*
ptr
,
uint8x16_t
&
idx
)
{
...
...
@@ -678,6 +756,16 @@ __ai int16x8_t vld1_dup_s8_s16(const int8_t* ptr) {
return
vmovl_s8
(
vld1_dup_s8
(
ptr
));
}
//! we add this because we found that cpu=aarch64_android cann't compile fmsq into fmls.
//! it use dup+fmla instead
__ai
float32x4_t
Vfmsq_f32
(
float32x4_t
&
a
,
float32x4_t
&
b
,
float32x4_t
&
v
)
{
asm
volatile
(
"fmls %0.4s, %1.4s, %2.4s
\n
"
:
"+w"
(
a
)
:
"w"
(
b
),
"w"
(
v
)
:
);
return
a
;
}
#undef __ai
#pragma GCC diagnostic pop
...
...
dnn/test/arm_common/conv_bias.cpp
浏览文件 @
d37229fa
...
...
@@ -791,8 +791,8 @@ void benchmark_winograd_nchw_vs_nchw44(const char* algo_name, Handle* handle) {
std
::
vector
<
NLMode
>
nonlinemode
=
{
NLMode
::
IDENTITY
};
for
(
auto
nlmode
:
nonlinemode
)
for
(
size_t
n
:
{
1
,
2
})
for
(
size_t
group
=
1
;
group
<=
2
;
++
group
)
{
for
(
size_t
n
:
{
1
})
for
(
size_t
group
=
1
;
group
<=
1
;
++
group
)
{
pack
(
n
,
512
,
512
,
15
,
15
,
group
,
nlmode
);
pack
(
n
,
512
,
256
,
15
,
15
,
group
,
nlmode
);
pack
(
n
,
256
,
256
,
29
,
29
,
group
,
nlmode
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录