MegEngine / MegEngine · Commit 9e0583e1
Authored May 24, 2022 by Megvii Engine Team
feat(dnn/arm_common): add arm_common chanwise dot 11x11
GitOrigin-RevId: 84e0815a5943d2efcdcb79d32196e7a405e315b0
Parent 115bcbce
Showing 7 changed files with 551 additions and 20 deletions (+551 −20)
dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp  (+8 −4)
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h  (+9 −0)
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s1.cpp  (+240 −0)
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s2.cpp  (+249 −0)
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp  (+8 −10)
dnn/test/arm_common/conv_bias.cpp  (+35 −4)
dnn/test/arm_common/conv_bias_multi_thread.cpp  (+2 −2)
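Note (not part of the commit): the kernels touched here compute a quantized channel-wise (depthwise) convolution. As a reading aid, the scalar sketch below spells out the arithmetic that the NEON dot-product kernels in this commit vectorize, assuming a zero-padded NCHW input plane, a plain row-major filter (not the repacked layout the real kernels read), an int32 per-channel bias, a float requantization scale, and a ReLU lower clamp; names and rounding behaviour are illustrative only.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

// Illustrative scalar reference for one channel of an int8 channel-wise convolution
// with a square FHxFH filter and equal strides SHxSH. "pad_iw" is the padded input
// row stride, following the kernel parameter style used in this commit.
void chanwise_conv_s8_reference(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t FH,
        size_t SH, size_t OH, size_t OW, size_t pad_iw, float scale, int8_t relu_val) {
    for (size_t oh = 0; oh < OH; ++oh) {
        for (size_t ow = 0; ow < OW; ++ow) {
            int32_t acc = bias;
            for (size_t fh = 0; fh < FH; ++fh) {
                for (size_t fw = 0; fw < FH; ++fw) {
                    acc += int32_t(src[(oh * SH + fh) * pad_iw + ow * SH + fw]) *
                           int32_t(weight[fh * FH + fw]);
                }
            }
            // Requantize: scale, round, saturate to int8, then apply the ReLU clamp
            // (the real kernels do this with vcvtq/vmulq_n_f32 and quant_store_s8).
            int32_t q = int32_t(std::lround(acc * scale));
            q = std::min<int32_t>(127, std::max<int32_t>(-128, q));
            dst[oh * OW + ow] = std::max(int8_t(q), relu_val);
        }
    }
}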
dnn/src/arm_common/conv_bias/int8/chanwise_direct_dot.cpp
@@ -21,9 +21,13 @@ public:
     DirectConvRunner(size_t flt_size, size_t stride) {
         if (flt_size == 9 && stride == 1) {
             m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16;
-        } else {
-            megdnn_assert(flt_size == 9 && stride == 2);
+        } else if (flt_size == 9 && stride == 2) {
             m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16;
+        } else if (flt_size == 11 && stride == 1) {
+            m_func = megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16;
+        } else {
+            megdnn_assert(flt_size == 11 && stride == 2);
+            m_func = megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16;
         }
     }
     size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const {
@@ -208,8 +212,8 @@ bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable(
             (bias_mode == BiasMode::NO_BIAS ||
              bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) &&
             fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 &&
-            SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) &&
-            fm.icpg == 1 && fm.ocpg == 1;
+            SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9 || FH == 11) &&
+            fm.icpg == 1 && fm.ocpg == 1;
     return avaible;
 }
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h
@@ -12,4 +12,13 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16(
         size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
         int8_t relu_val);
+void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16(
+        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
+        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
+        int8_t relu_val);
+void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16(
+        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
+        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
+        int8_t relu_val);
#endif
\ No newline at end of file
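Not part of the diff: the _oh4_ow16 suffix in these declarations means one call produces a 4-row x 16-column output tile. The sketch below is only a schematic of how a runner might step such a kernel across one channel, using a function-pointer type derived from the declarations above; the rounding of OH/OW up to tile multiples and the source padding (cf. get_round_fw in DirectConvRunner) are glossed over, so treat it as an assumption-laden outline rather than the actual MegEngine driver.

#include <cstddef>
#include <cstdint>

// Pointer type matching the kernel declarations above.
using large_chanwise_kern = void (*)(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val);

// Schematic driver: walk the output plane of one channel in 4x16 tiles and let the
// selected kernel (e.g. the 11x11 stride-1 or stride-2 variant) fill each tile.
// Assumes OH and OW were already rounded up to multiples of 4 and 16 and that
// padded_src points at a suitably padded input plane.
void run_one_channel(
        large_chanwise_kern kern, const int8_t* padded_src,
        const int8_t* packed_weight, int32_t bias, int8_t* dst, size_t OH, size_t OW,
        size_t pad_iw, float scale, int8_t relu_val) {
    for (size_t oh = 0; oh < OH; oh += 4) {
        for (size_t ow = 0; ow < OW; ow += 16) {
            kern(padded_src, packed_weight, bias, dst, oh, ow, OH, OW, pad_iw, scale,
                 relu_val);
        }
    }
}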
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s1.cpp
new file mode 100644
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h"
#include "src/common/unroll_macro.h"
MEGDNN_ATTRIBUTE_TARGET("dotprod")
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val) {
    //! 4x16
    const size_t SH = 1;
    const size_t SW = 1;
    static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4,
                                            2, 3, 4, 5, 3, 4, 5, 6};
    static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8,
                                            6, 7, 8, 9, 7, 8, 9, 10};
    static const uint8_t tbl_array_2[16] = {8,  9,  10, 11, 9,  10, 11, 12,
                                            10, 11, 12, 13, 11, 12, 13, 14};

    uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
    uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);
    uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]);

    const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;

    //! init
    int32x4_t c[4][4];
#define cb(step) \
c[step][0] = vdupq_n_s32(bias); \
c[step][1] = vdupq_n_s32(bias); \
c[step][2] = vdupq_n_s32(bias); \
c[step][3] = vdupq_n_s32(bias);
    UNROLL_CALL_RAW(4, cb);
#undef cb
#define flt_reg 4
    int8x16_t flt[flt_reg];
    flt[0] = vld1q_s8(weight + 0 * 16);
    flt[1] = vld1q_s8(weight + 1 * 16);
    flt[2] = vld1q_s8(weight + 2 * 16);
    flt[3] = vld1q_s8(weight + 3 * 16);

    //! row 0
    int8x16_t read_w[2];
    read_w[0] = vld1q_s8(src_n + 0 * pad_iw);
    read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16);

    int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
    int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1);
    int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2);
    int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0);

    int8x16_t n0123_1 = n4567_0;
    int8x16_t n4567_1 = n89ab_0;
    int8x16_t n89ab_1 = ncdef_0;
    int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0);

    int8x16_t n0123_2 = n89ab_0;
    int8x16_t n4567_2 = ncdef_0;
    int8x16_t n89ab_2 = ncdef_1;
    int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);
#define CAL_C(oh, flt_start) \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % flt_reg], \
(flt_start + 0) % 4); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % flt_reg], \
(flt_start + 0) % 4); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % flt_reg], \
(flt_start + 0) % 4); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % flt_reg], \
(flt_start + 0) % 4); \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % flt_reg], \
(flt_start + 1) % 4); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % flt_reg], \
(flt_start + 1) % 4); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % flt_reg], \
(flt_start + 1) % 4); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % flt_reg], \
(flt_start + 1) % 4); \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % flt_reg], \
(flt_start + 2) % 4); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % flt_reg], \
(flt_start + 2) % 4); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % flt_reg], \
(flt_start + 2) % 4); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % flt_reg], \
(flt_start + 2) % 4);
    CAL_C(0, 0);
//! row 1
#define LOAD_SRC(row_id) \
read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \
read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \
n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); \
n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); \
ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); \
n0123_1 = n4567_0; \
n4567_1 = n89ab_0; \
n89ab_1 = ncdef_0; \
ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); \
n0123_2 = n89ab_0; \
n4567_2 = ncdef_0; \
n89ab_2 = ncdef_1; \
ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1);
    LOAD_SRC(1);
    CAL_C(0, 3);
    CAL_C(1, 0);

    //! row 2
    LOAD_SRC(2);
    CAL_C(0, 3 * 2);
    CAL_C(1, 3 * 1);
    CAL_C(2, 3 * 0);

    //! row 3
    LOAD_SRC(3);
    CAL_C(0, 3 * 3);
    CAL_C(1, 3 * 2);
    CAL_C(2, 3 * 1);
    CAL_C(3, 3 * 0);

    //! row 4
    LOAD_SRC(4);
    CAL_C(0, 3 * 4);
    CAL_C(1, 3 * 3);
    CAL_C(2, 3 * 2);
    CAL_C(3, 3 * 1);
    //! update flt 4 -> 0
    flt[0] = vld1q_s8(weight + 4 * 16);

    //! row 5
    LOAD_SRC(5);
    CAL_C(0, 3 * 5);
    CAL_C(1, 3 * 4);
    CAL_C(2, 3 * 3);
    CAL_C(3, 3 * 2);
    //! update flt 5 -> 1
    flt[1] = vld1q_s8(weight + 5 * 16);

    //! row 6
    LOAD_SRC(6);
    CAL_C(0, 3 * 6);
    CAL_C(1, 3 * 5);
    CAL_C(2, 3 * 4);
    CAL_C(3, 3 * 3);
    //! update flt 6 -> 2
    flt[2] = vld1q_s8(weight + 6 * 16);

    //! row 7
    LOAD_SRC(7);
    CAL_C(0, 3 * 7);
    CAL_C(1, 3 * 6);
    CAL_C(2, 3 * 5);
    CAL_C(3, 3 * 4);

    //! row 8
    LOAD_SRC(8);
    CAL_C(3, 3 * 5);
    //! update flt 7 -> 3
    flt[3] = vld1q_s8(weight + 7 * 16);
    CAL_C(2, 3 * 6);
    CAL_C(1, 3 * 7);
    CAL_C(0, 3 * 8);

    //! row 9
    LOAD_SRC(9);
    CAL_C(0, 3 * 9);
    CAL_C(1, 3 * 8);
    CAL_C(2, 3 * 7);
    CAL_C(3, 3 * 6);

    //! row 10
    LOAD_SRC(10);
    //! update flt 8 -> 0
    flt[0] = vld1q_s8(weight + 8 * 16);
    CAL_C(3, 3 * 7);
    CAL_C(2, 3 * 8);
    CAL_C(1, 3 * 9);
    CAL_C(0, 3 * 10);

    //! row 11
    LOAD_SRC(11);
    CAL_C(1, 3 * 10);
    CAL_C(2, 3 * 9);
    CAL_C(3, 3 * 8);

    //! row 12
    LOAD_SRC(12);
    CAL_C(2, 3 * 10);
    CAL_C(3, 3 * 9);

    //! row 13
    LOAD_SRC(13);
    CAL_C(3, 3 * 10);

    float32x4_t dst_reg[4][4];
#define cb(step) \
dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \
dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \
dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \
dst_reg[step][3] = vcvtq_f32_s32(c[step][3]);
    UNROLL_CALL_RAW(4, cb);
#undef cb
#define cb(step) \
dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \
dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \
dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \
dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale);
    UNROLL_CALL_RAW(4, cb);
#undef cb
    int8_t* dst_store = dst + oh * OW + ow;
    int8x16_t relu_reg = vdupq_n_s8(relu_val);
#define cb(step) \
quant_store_s8( \
dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \
dst_store + step * OW, relu_reg);
    UNROLL_CALL_RAW(4, cb);
#undef cb
}
#endif
\ No newline at end of file
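A note on the shuffle tables in the stride-1 kernel above (not part of the commit): vqtbl1q_s8 is a byte table lookup, so tbl_array_0/1/2 turn one 16-byte row load into four overlapping 4-pixel windows, letting each vdotq_laneq_s32 accumulate four horizontal taps for four adjacent output columns at once. The stand-alone snippet below just replays that lookup in scalar code to make the windows visible; it is illustrative only.

#include <cstdint>
#include <cstdio>

int main() {
    // Same table as tbl_array_0 in the stride-1 kernel above.
    const uint8_t tbl[16] = {0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6};
    // Pretend these are the first 16 input bytes of one row.
    int8_t row[16];
    for (int i = 0; i < 16; ++i)
        row[i] = int8_t(i * 10);  // distinct values so the windows are easy to see

    // vqtbl1q_s8(row, tbl) picks row[tbl[i]] for each output byte i, i.e. it
    // materializes the windows {0..3}, {1..4}, {2..5}, {3..6}: the four leading
    // taps of the 11-tap row for output pixels 0, 1, 2 and 3.
    for (int win = 0; win < 4; ++win) {
        std::printf("window %d:", win);
        for (int k = 0; k < 4; ++k)
            std::printf(" %d", int(row[tbl[win * 4 + k]]));
        std::printf("\n");
    }
    return 0;
}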
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_11x11s2.cpp
new file mode 100644
#include "megdnn/arch.h"
#if MGB_ENABLE_DOT
#include "src/arm_common/simd_macro/marm_neon.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h"
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h"
#include "src/common/unroll_macro.h"
MEGDNN_ATTRIBUTE_TARGET("dotprod")
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16(
        const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh,
        size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale,
        int8_t relu_val) {
    //! 4x16
    const size_t SH = 2;
    const size_t SW = 2;
    static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5,
                                            4, 5, 6, 7, 6, 7, 8, 9};
    static const uint8_t tbl_array_1[16] = {4, 5, 6,  7,  6,  7,  8,  9,
                                            8, 9, 10, 11, 10, 11, 12, 13};

    uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]);
    uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]);

    const int8_t* src_n = src + oh * SH * pad_iw + ow * SW;

    //! init
    int32x4_t c[4][4];
#define cb(step) \
c[step][0] = vdupq_n_s32(bias); \
c[step][1] = vdupq_n_s32(bias); \
c[step][2] = vdupq_n_s32(bias); \
c[step][3] = vdupq_n_s32(bias);
    UNROLL_CALL_RAW(4, cb);
#undef cb
#define flt_reg 9
#define flt_per_reg 4
    int8x16_t flt[flt_reg];
#define cb(step) flt[step] = vld1q_s8(weight + step * 16);
    UNROLL_CALL_RAW(flt_reg, cb);
#undef cb
#define CAL_C(oh, flt_start) \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
(flt_start + 0) % flt_per_reg); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
(flt_start + 0) % flt_per_reg); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
(flt_start + 0) % flt_per_reg); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \
(flt_start + 0) % flt_per_reg); \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
(flt_start + 1) % flt_per_reg); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
(flt_start + 1) % flt_per_reg); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
(flt_start + 1) % flt_per_reg); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \
(flt_start + 1) % flt_per_reg); \
c[oh][0] = vdotq_laneq_s32( \
c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
(flt_start + 2) % flt_per_reg); \
c[oh][1] = vdotq_laneq_s32( \
c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
(flt_start + 2) % flt_per_reg); \
c[oh][2] = vdotq_laneq_s32( \
c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
(flt_start + 2) % flt_per_reg); \
c[oh][3] = vdotq_laneq_s32( \
c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \
(flt_start + 2) % flt_per_reg);
#define LOAD_SRC(row_id) \
read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \
read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \
read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32); \
ext_8 = vextq_s8(read_w[0], read_w[1], 8); \
ext_24 = vextq_s8(read_w[1], read_w[2], 8); \
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \
n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); \
n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); \
ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); \
n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); \
n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); \
n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); \
ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); \
n0123_2 = n4567_0; \
n4567_2 = n89ab_0; \
n89ab_2 = ncdef_0; \
ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);
//! row 0
    int8x16_t read_w[3];
    read_w[0] = vld1q_s8(src_n);
    read_w[1] = vld1q_s8(src_n + 16);
    read_w[2] = vld1q_s8(src_n + 32);

    int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8);
    int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8);

    int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0);
    int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0);
    int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0);
    int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0);

    int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1);
    int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1);
    int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1);
    int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1);

    int8x16_t n0123_2 = n4567_0;
    int8x16_t n4567_2 = n89ab_0;
    int8x16_t n89ab_2 = ncdef_0;
    int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0);

    CAL_C(0, 0);
//! row 1
    LOAD_SRC(1);
    CAL_C(0, 3 * 1);

    //! row 2
    LOAD_SRC(2);
    CAL_C(0, 3 * 2);
    CAL_C(1, 3 * 0);

    //! row 3
    LOAD_SRC(3);
    CAL_C(0, 3 * 3);
    CAL_C(1, 3 * 1);

    //! row 4
    LOAD_SRC(4);
    CAL_C(0, 3 * 4);
    CAL_C(1, 3 * 2);
    CAL_C(2, 3 * 0);

    //! row 5
    LOAD_SRC(5);
    CAL_C(0, 3 * 5);
    CAL_C(1, 3 * 3);
    CAL_C(2, 3 * 1);

    //! row 6
    LOAD_SRC(6);
    CAL_C(0, 3 * 6);
    CAL_C(1, 3 * 4);
    CAL_C(2, 3 * 2);
    CAL_C(3, 3 * 0);

    //! row 7
    LOAD_SRC(7);
    CAL_C(0, 3 * 7);
    CAL_C(1, 3 * 5);
    CAL_C(2, 3 * 3);
    CAL_C(3, 3 * 1);

    //! row 8
    LOAD_SRC(8);
    CAL_C(0, 3 * 8);
    CAL_C(1, 3 * 6);
    CAL_C(2, 3 * 4);
    CAL_C(3, 3 * 2);

    //! row 9
    LOAD_SRC(9);
    CAL_C(0, 3 * 9);
    CAL_C(1, 3 * 7);
    CAL_C(2, 3 * 5);
    CAL_C(3, 3 * 3);

    //! row 10
    LOAD_SRC(10);
    CAL_C(0, 3 * 10);
    CAL_C(1, 3 * 8);
    CAL_C(2, 3 * 6);
    CAL_C(3, 3 * 4);

    //! row 11
    LOAD_SRC(11);
    CAL_C(1, 3 * 9);
    CAL_C(2, 3 * 7);
    CAL_C(3, 3 * 5);

    //! row 12
    LOAD_SRC(12);
    CAL_C(1, 3 * 10);
    CAL_C(2, 3 * 8);
    CAL_C(3, 3 * 6);

    //! row 13
    LOAD_SRC(13);
    CAL_C(2, 3 * 9);
    CAL_C(3, 3 * 7);

    //! row 14
    LOAD_SRC(14);
    CAL_C(2, 3 * 10);
    CAL_C(3, 3 * 8);

    //! row 15
    LOAD_SRC(15);
    CAL_C(3, 3 * 9);

    //! row 16
    LOAD_SRC(16);
    CAL_C(3, 3 * 10);

    float32x4_t dst_reg[4][4];
#define cb(step) \
dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \
dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \
dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \
dst_reg[step][3] = vcvtq_f32_s32(c[step][3]);
    UNROLL_CALL_RAW(4, cb);
#undef cb
#define cb(step) \
dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \
dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \
dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \
dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale);
    UNROLL_CALL_RAW(4, cb);
#undef cb
    int8_t* dst_store = dst + oh * OW + ow;
    int8x16_t relu_reg = vdupq_n_s8(relu_val);
#define cb(step) \
quant_store_s8( \
dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \
dst_store + step * OW, relu_reg);
    UNROLL_CALL_RAW(4, cb);
#undef cb
}
#endif
\ No newline at end of file
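Also not part of the commit: in this stride-2 variant the tables start consecutive windows two bytes apart, so 16 output columns consume well over 32 input bytes per row, which is why LOAD_SRC reads three 16-byte vectors and builds ext_8/ext_24 with vextq_s8; and because the 11 rows x 3 weight groups fit in nine 16-byte registers, flt_reg is 9 and there is no rolling reload of the filter registers as in the stride-1 kernel. The snippet below merely replays CAL_C's index arithmetic to show which register lane each filter row touches; the implied layout (each 11-tap row repacked into three 4-byte groups, i.e. padded to 12 bytes) is an inference from that arithmetic, not something stated in the diff.

#include <cstdio>

int main() {
    // Mirror of the index arithmetic used by CAL_C above:
    //   flt[(flt_start + k) / flt_per_reg % flt_reg], lane (flt_start + k) % flt_per_reg
    const int flt_per_reg = 4;  // four 4-byte dot lanes per 16-byte register
    const int flt_reg = 9;      // the stride-2 kernel keeps all nine weight registers
    for (int row = 0; row <= 10; ++row) {
        const int flt_start = 3 * row;  // three weight groups per 11-tap filter row
        std::printf("filter row %2d ->", row);
        for (int k = 0; k < 3; ++k) {
            const int idx = flt_start + k;
            std::printf(" flt[%d].lane%d", idx / flt_per_reg % flt_reg,
                        idx % flt_per_reg);
        }
        std::printf("\n");
    }
    return 0;
}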
dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_larget_9x9s2.cpp
@@ -36,16 +36,14 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16(
     UNROLL_CALL_RAW(4, cb);
 #undef cb
-    constexpr int flt_reg = 7;
-    constexpr int flt_per_reg = 4;
-    int8x16_t flt[7];
-    flt[0] = vld1q_s8(weight + 0 * 16);
-    flt[1] = vld1q_s8(weight + 1 * 16);
-    flt[2] = vld1q_s8(weight + 2 * 16);
-    flt[3] = vld1q_s8(weight + 3 * 16);
-    flt[4] = vld1q_s8(weight + 4 * 16);
-    flt[5] = vld1q_s8(weight + 5 * 16);
-    flt[6] = vld1q_s8(weight + 6 * 16);
+#define flt_reg 7
+#define flt_per_reg 4
+    int8x16_t flt[flt_reg];
+#define cb(step) flt[step] = vld1q_s8(weight + step * 16);
+    UNROLL_CALL_RAW(flt_reg, cb);
+#undef cb
 #define CAL_C(oh, flt_start) \
     c[oh][0] = vdotq_laneq_s32( \
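The 9x9 stride-2 kernel is only refactored by this hunk: the hand-unrolled flt[0..6] loads are replaced with the same cb/UNROLL_CALL_RAW pattern that the new 11x11 kernels use. For readers unfamiliar with that helper, the sketch below shows the general shape of such a compile-time unroll; it is a simplified stand-in, not the actual macro from src/common/unroll_macro.h.

#include <cstdio>

// Simplified stand-in for an unroll helper in the spirit of UNROLL_CALL_RAW:
// it expands cb(0) cb(1) ... cb(6) at preprocessing time, so the "loads" below
// end up fully unrolled with literal indices, just like writing them by hand.
#define MY_UNROLL_CALL_RAW_7(cb) cb(0) cb(1) cb(2) cb(3) cb(4) cb(5) cb(6)

int main() {
    int weight[7 * 16];  // dummy packed "weights", 16 entries per register
    for (int i = 0; i < 7 * 16; ++i)
        weight[i] = i;
    int flt[7];
#define cb(step) flt[step] = weight[step * 16];
    MY_UNROLL_CALL_RAW_7(cb)
#undef cb
    for (int i = 0; i < 7; ++i)
        std::printf("flt[%d] = %d\n", i, flt[i]);  // prints 0, 16, 32, ...
    return 0;
}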
dnn/test/arm_common/conv_bias.cpp
@@ -2060,6 +2060,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) {
     benchmark1.set_display(false);
     benchmark1.set_times(RUN);
+    Benchmarker<ConvBias> benchmark2(handle());
+    benchmark2.set_dtype(0, dtype::QuantizedS8(2.5f))
+            .set_dtype(1, dtype::QuantizedS8(2.5f))
+            .set_dtype(2, dtype::QuantizedS32(6.25f))
+            .set_dtype(4, dtype::QuantizedS8(60.25f));
+    benchmark2.set_before_exec_callback(
+            conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("ARMDOTS8"));
+    benchmark2.set_display(false);
+    benchmark2.set_times(RUN);
     for (auto&& arg : args) {
         TensorLayout dst_layout;
         auto opr = handle()->create_operator<ConvBias>();
@@ -2070,6 +2080,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) {
         //! dst.nr_elems * FH * FW * 2
         float computations = dst_layout.total_nr_elems() * arg.filter[3] *
                              arg.filter[4] * 2.0 / 1e6;
+        float computations_5x5 = dst_layout.total_nr_elems() * 5 * 5 * 2.0 / 1e6;
+        float computations_11x11 = dst_layout.total_nr_elems() * 11 * 11 * 2.0 / 1e6;
+        param::ConvBias param_5x5 = arg.param;
+        param_5x5.pad_h = param_5x5.pad_w = 5 / 2;
+        param::ConvBias param_11x11 = arg.param;
+        param_11x11.pad_h = param_11x11.pad_w = 11 / 2;
         auto used0 = benchmark0.set_param(arg.param).exec(
                              {arg.src, arg.filter, arg.bias, {}, {}}) /
@@ -2077,11 +2093,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) {
         auto used1 = benchmark1.set_param(arg.param).exec(
                              {arg.src, arg.filter, arg.bias, {}, {}}) /
                      RUN;
+        TensorShape flt_5x5_shape = arg.filter;
+        flt_5x5_shape[3] = flt_5x5_shape[4] = 5;
+        auto used5x5 = benchmark2.set_param(param_5x5).exec(
+                               {arg.src, flt_5x5_shape, arg.bias, {}, {}}) /
+                       RUN;
+        TensorShape flt_11x11_shape = arg.filter;
+        flt_11x11_shape[3] = flt_11x11_shape[4] = 11;
+        auto used11x11 = benchmark0.set_param(param_11x11)
+                                 .exec({arg.src, flt_11x11_shape, arg.bias, {}, {}}) /
+                         RUN;
-        printf("%s %s: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops "
-               "speedup: %f\n",
-               arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0,
-               computations / used0, used1, computations / used1, used1 / used0);
+        printf("%s %s s %u: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops "
+               "speedup: %f, compare 5x5 %f ms %f GFlops speedup %f, compare 11x11 %f "
+               "ms %f GFops speedup %f\n",
+               arg.src.to_string().c_str(), arg.filter.to_string().c_str(),
+               arg.param.stride_h, used0, computations / used0, used1,
+               computations / used1, used1 / used0, used5x5,
+               computations_5x5 / used5x5, used5x5 / used0, used11x11,
+               computations_11x11 / used11x11, used11x11 / used0);
     }
 }
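For orientation (not part of the diff): the numbers this benchmark prints follow the file's existing convention, where computations is the multiply-accumulate count of the depthwise conv in Mflops (dst elements x FH x FW x 2 / 1e6) and used* is milliseconds per run, so computations / used is directly in Gflop/s. A made-up worked example:

#include <cstdio>

int main() {
    // Hypothetical numbers, only to show the unit bookkeeping of the test above.
    const double dst_elems = 1.0 * 64 * 56 * 56;                  // 1x64x56x56 output
    const double computations = dst_elems * 11 * 11 * 2.0 / 1e6;  // Mflops
    const double used_ms = 3.5;                                   // ms per run
    std::printf("%.2f Gflop/s\n", computations / used_ms);        // Mflops/ms == Gflop/s
    return 0;
}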
dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -612,13 +612,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
 #if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) {
     checker_conv_bias_qint8x8x8(
-            get_channel_wise_args({9}, 1, false, true, true, true), handle(),
+            get_channel_wise_args({9, 11}, 1, false, true, true, true), handle(),
             "ARMDOTS8_DIRECT_CHANWISE_LARGE");
 }
 
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) {
     checker_conv_bias_qint8x8x8(
-            get_channel_wise_args({9}, 2, false, true, true, true), handle(),
+            get_channel_wise_args({9, 11}, 2, false, true, true, true), handle(),
             "ARMDOTS8_DIRECT_CHANWISE_LARGE");
 }