Commit 9e361a4d (PaddlePaddle / Paddle-Lite)
Authored on June 03, 2020 by yiicy; committed via GitHub on June 03, 2020.
Parent commit: cba42f0d

[ARM] int8 direct_conv, dw_conv add relu6 and leaky relu fusion, test=develop (#3737)

int8 direct_conv, dw_conv add relu6 and leaky relu fusion
Showing 13 changed files with 636 additions and 547 deletions (+636 / -547).
lite/backends/arm/math/conv3x3s1_depthwise_int8.cc   +8    -4
lite/backends/arm/math/conv3x3s1_direct_int8.cc      +25   -2
lite/backends/arm/math/conv3x3s2_depthwise_int8.cc   +8    -4
lite/backends/arm/math/conv3x3s2_direct_int8.cc      +50   -4
lite/backends/arm/math/conv5x5s1_depthwise_int8.cc   +8    -4
lite/backends/arm/math/conv5x5s2_depthwise_int8.cc   +8    -4
lite/backends/arm/math/conv_block_utils.h            +361  -476
lite/backends/arm/math/conv_depthwise.h              +8    -4
lite/backends/arm/math/conv_impl.cc                  +108  -12
lite/backends/arm/math/gemm_prepacked_int8.cc        +12   -12
lite/kernels/arm/conv_depthwise.cc                   +6    -0
lite/kernels/arm/conv_direct.h                       +19   -13
lite/tests/math/conv_int8_compute_test.cc            +15   -8
lite/backends/arm/math/conv3x3s1_depthwise_int8.cc

@@ -36,7 +36,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
                                const float* scale,
                                const float* bias,
                                bool flag_bias,
-                               bool flag_relu,
+                               int flag_act,
+                               float* alpha,
                                int num,
                                int chin,
                                int hin,
@@ -434,7 +435,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
                      chout,
                      hout,
                      wout,
-                     flag_relu,
+                     flag_act,
+                     alpha,
                      bias_local,
                      flag_bias,
                      ptr_write,

The explicit instantiations for int8_t and float outputs (@@ -450,7 +452,8 @@ and
@@ -467,7 +470,8 @@) receive the same signature change: the parameter
"bool flag_relu," is replaced by "int flag_act, float* alpha,".
lite/backends/arm/math/conv3x3s1_direct_int8.cc

@@ -42,8 +42,30 @@ void conv_3x3s1_direct_int8(const int8_t* din,
                             Context<TARGET(kARM)>* ctx,
                             const float* scale) {
   auto paddings = *param.paddings;
-  bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias;
+  auto act_param = param.activation_param;
+  auto act_type = act_param.active_type;
+  int flag_act = 0;  // relu: 1, relu6: 2, leakey: 3
+  float alpha[4] = {0.f, 0.f, 0.f, 0.f};
+  if (act_param.has_active) {
+    if (act_type == lite_api::ActivationType::kRelu) {
+      flag_act = 1;
+    } else if (act_type == lite_api::ActivationType::kRelu6) {
+      flag_act = 2;
+      float local_alpha = act_param.Relu_clipped_coef;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    } else if (act_type == lite_api::ActivationType::kLeakyRelu) {
+      flag_act = 3;
+      float local_alpha = act_param.Leaky_relu_alpha;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    }
+  }
   int pad_h = paddings[0];
   int pad_w = paddings[2];
@@ -442,7 +464,8 @@ void conv_3x3s1_direct_int8(const int8_t* din,
                    chout,
                    hout,
                    wout,
-                   flag_relu,
+                   flag_act,
+                   alpha,
                    bias_local,
                    flag_bias,
                    ptr_write,
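For reference, the activation dispatch that this commit pastes into every conv entry point can be factored into a small helper. The sketch below is illustrative only; ActParamLike is a hypothetical stand-in for Paddle-Lite's operators::ActivationParam, and the flag encoding follows the comment in the diff (0 = none, 1 = relu, 2 = relu6, 3 = leaky relu).

```cpp
#include <array>

// Hypothetical, simplified stand-ins for lite_api::ActivationType and
// operators::ActivationParam, used only to illustrate the mapping.
enum class ActivationType { kRelu, kRelu6, kLeakyRelu };
struct ActParamLike {
  bool has_active{false};
  ActivationType active_type{ActivationType::kRelu};
  float Relu_clipped_coef{6.f};   // relu6 clip value
  float Leaky_relu_alpha{0.01f};  // leaky relu slope
};

// Returns flag_act (0: none, 1: relu, 2: relu6, 3: leaky relu) and fills all
// four alpha lanes with the clip value or slope, mirroring what the patched
// conv entry points do before calling the NEON kernels.
inline int make_act_flag(const ActParamLike& act, std::array<float, 4>* alpha) {
  alpha->fill(0.f);
  if (!act.has_active) return 0;
  if (act.active_type == ActivationType::kRelu) return 1;
  if (act.active_type == ActivationType::kRelu6) {
    alpha->fill(act.Relu_clipped_coef);
    return 2;
  }
  alpha->fill(act.Leaky_relu_alpha);
  return 3;
}
```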
lite/backends/arm/math/conv3x3s2_depthwise_int8.cc

@@ -36,7 +36,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
                                const float* scale,
                                const float* bias,
                                bool flag_bias,
-                               bool flag_relu,
+                               int flag_act,
+                               float* alpha,
                                int num,
                                int chin,
                                int hin,
@@ -447,7 +448,8 @@ void conv_depthwise_3x3s2_int8(Dtype* dout,
                      chout,
                      hout,
                      wout,
-                     flag_relu,
+                     flag_act,
+                     alpha,
                      bias_local,
                      flag_bias,
                      ptr_write,

The explicit instantiations for int8_t and float outputs (@@ -463,7 +465,8 @@ and
@@ -480,7 +483,8 @@) receive the same signature change: "bool flag_relu," becomes
"int flag_act, float* alpha,".
lite/backends/arm/math/conv3x3s2_direct_int8.cc

@@ -47,8 +47,30 @@ void conv_3x3s2_direct_int8(const int8_t* din,
   //! prepack input to tmp buffer
   //! write output to tmp buffer
   auto paddings = *param.paddings;
-  bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias;
+  (the same flag_act / alpha parsing block shown above for conv_3x3s1_direct_int8)
   int pad_h = paddings[0];
   int pad_w = paddings[2];
@@ -442,7 +464,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
                    chout,
                    hout,
                    wout,
-                   flag_relu,
+                   flag_act,
+                   alpha,
                    bias_local,
                    flag_bias,
                    ptr_write,
@@ -474,8 +497,30 @@ void conv_3x3s2_direct_int8(const int8_t* din,   (second variant)
   //! prepack input to tmp buffer
   //! write output to tmp buffer
   auto paddings = *param.paddings;
-  bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias;
+  (the same flag_act / alpha parsing block)
   int pad_h = paddings[0];
   int pad_w = paddings[2];
   const int threads = ctx->threads();
@@ -698,7 +743,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
                    chout,
                    hout,
                    wout,
-                   flag_relu,
+                   flag_act,
+                   alpha,
                    bias_local,
                    flag_bias,
                    ptr_write,
lite/backends/arm/math/conv5x5s1_depthwise_int8.cc

@@ -36,7 +36,8 @@ void conv_depthwise_5x5s1_int8(Dtype* dout,
                                const float* scale,
                                const float* bias,
                                bool flag_bias,
-                               bool flag_relu,
+                               int flag_act,
+                               float* alpha,
                                int num,
                                int chin,
                                int hin,
@@ -726,7 +727,8 @@ void conv_depthwise_5x5s1_int8(Dtype* dout,
                      chout,
                      hout,
                      wout,
-                     flag_relu,
+                     flag_act,
+                     alpha,
                      bias_local,
                      flag_bias,
                      ptr_write,

The explicit instantiations for int8_t and float outputs (@@ -742,7 +744,8 @@ and
@@ -759,7 +762,8 @@) receive the same signature change: "bool flag_relu," becomes
"int flag_act, float* alpha,".
lite/backends/arm/math/conv5x5s2_depthwise_int8.cc

@@ -36,7 +36,8 @@ void conv_depthwise_5x5s2_int8(Dtype* dout,
                                const float* scale,
                                const float* bias,
                                bool flag_bias,
-                               bool flag_relu,
+                               int flag_act,
+                               float* alpha,
                                int num,
                                int chin,
                                int hin,
@@ -746,7 +747,8 @@ void conv_depthwise_5x5s2_int8(Dtype* dout,
                      chout,
                      hout,
                      wout,
-                     flag_relu,
+                     flag_act,
+                     alpha,
                      bias_local,
                      flag_bias,
                      ptr_write,

The explicit instantiations for int8_t and float outputs (@@ -762,7 +764,8 @@ and
@@ -779,7 +782,8 @@) receive the same signature change: "bool flag_relu," becomes
"int flag_act, float* alpha,".
lite/backends/arm/math/conv_block_utils.h

@@ -2643,48 +2643,81 @@ inline void int32_nchwc4_kernel(Dtype*& dout0, // NOLINT
                                 int cnt,
                                 float32x4_t scale,
                                 float32x4_t bias,
-                                bool is_relu);
+                                int flag_act,
+                                float* alpha);
 #ifdef __aarch64__

The NCHWC4_TRANS_INT32 macro keeps its load / transpose / int32-to-fp32 / bias / scale
body ("ldp q0, q1, [%[ptr_din]], #32" ... "fmla v19.4s, v7.4s, %[scale].s[3]"). The old
version additionally set "movi v20.4s, #0" before the loop label and ended in a single
conditional relu:

-  "cbz %w[relu], 2f\n"                 \
-  "fmax v16.4s, v16.4s, v20.4s \n"     \
-  "fmax v17.4s, v17.4s, v20.4s \n"     \
-  "fmax v18.4s, v18.4s, v20.4s \n"     \
-  "fmax v19.4s, v19.4s, v20.4s \n"     \
-  "2:\n"

The new version replaces that tail with a three-way dispatch on %[flag_act]:

+  "cmp %w[flag_act], #1\n"                            \
+  "bne 12f \n"                                        \
+  "movi v20.4s, #0 \n"              /* for relu */    \
+  "fmax v16.4s, v16.4s, v20.4s \n"                    \
+  "fmax v17.4s, v17.4s, v20.4s \n"                    \
+  "fmax v18.4s, v18.4s, v20.4s \n"                    \
+  "fmax v19.4s, v19.4s, v20.4s \n"                    \
+  "b 2f \n"                         /* relu end */    \
+  "12: \n"                          /* no relu */     \
+  "cmp %w[flag_act], #0 \n"         /* check no act */\
+  "beq 2f \n"                       /* no act end */  \
+  "cmp %w[flag_act], #2 \n"         /* check relu6 */ \
+  "bne 13f \n"                      /* jump no relu6*/\
+  "movi v8.4s, #0 \n"               /* for relu6 */   \
+  "ld1 {v9.4s}, [%[alpha]] \n"      /* relu6 alpha */ \
+  "fmax v16.4s, v16.4s, v8.4s \n"   /* relu6 */       \
+  "fmax v17.4s, v17.4s, v8.4s \n"   /* relu6 */       \
+  "fmax v18.4s, v18.4s, v8.4s \n"   /* relu6 */       \
+  "fmax v19.4s, v19.4s, v8.4s \n"   /* relu6 */       \
+  "fmin v16.4s, v16.4s, v9.4s \n"   /* relu6 */       \
+  "fmin v17.4s, v17.4s, v9.4s \n"   /* relu6 */       \
+  "fmin v18.4s, v18.4s, v9.4s \n"   /* relu6 */       \
+  "fmin v19.4s, v19.4s, v9.4s \n"   /* relu6 */       \
+  "b 2f \n"                         /* relu6 end */   \
+  "13: \n"                          /* leakey relu */ \
+  "movi v12.4s, #0 \n"                                \
+  "ld1 {v13.4s}, [%[alpha]] \n"     /* leakey alpha */\
+  "fcmge v4.4s, v16.4s, v12.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v16.4s, v13.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v17.4s, v12.4s \n"                    \
+  "fmul v7.4s, v17.4s, v13.4s \n"                     \
+  "fcmge v8.4s, v18.4s, v12.4s \n"                    \
+  "fmul v9.4s, v18.4s, v13.4s \n"                     \
+  "fcmge v10.4s, v19.4s, v12.4s \n"                   \
+  "fmul v11.4s, v19.4s, v13.4s \n"                    \
+  "bif v16.16b, v5.16b, v4.16b \n"   /* choose */     \
+  "bif v17.16b, v7.16b, v6.16b \n"   /* choose */     \
+  "bif v18.16b, v9.16b, v8.16b \n"   /* choose */     \
+  "bif v19.16b, v11.16b, v10.16b \n" /* choose */     \
+  "2: \n"                            /* act end */

 #else  (armv7 NCHWC4_TRANS_INT32: vld1 / vtrn transpose preamble unchanged)
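The leaky-relu branch added to the aarch64 macro relies on the fcmge / fmul / bif idiom: build a lane mask for x >= 0, compute alpha * x, then select per lane. A minimal C-intrinsics sketch of the same select, assuming NEON is available (an illustration of the idiom, not code from the patch):

```cpp
#ifdef __ARM_NEON
#include <arm_neon.h>

// Equivalent of the fcmge / fmul / bif triplet in NCHWC4_TRANS_INT32:
// keep x where x >= 0, otherwise take alpha * x.
static inline float32x4_t leaky_relu_f32x4(float32x4_t x, float32x4_t valpha) {
  uint32x4_t ge_zero = vcgeq_f32(x, vdupq_n_f32(0.f));  // per-lane mask: x >= 0
  float32x4_t neg = vmulq_f32(x, valpha);               // alpha * x
  return vbslq_f32(ge_zero, x, neg);                    // lane-wise select
}
#endif
```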
@@ -2701,13 +2734,44 @@ inline void int32_nchwc4_kernel(Dtype*& dout0, // NOLINT
   "vmla.f32 q10, q6, %e[scale][0]\n"   \
   "vmla.f32 q11, q7, %e[scale][1]\n"   \
   "vmla.f32 q12, q8, %f[scale][0]\n"   \
-  "vmla.f32 q13, q9, %f[scale][1]\n" /* relu */ \
-  "cmp %[relu], #0\n"                  \
-  "beq 2f\n"                           \
-  "vmax.f32 q10, q10, q15\n"           \
-  "vmax.f32 q11, q11, q15\n"           \
-  "vmax.f32 q12, q12, q15\n"           \
-  "vmax.f32 q13, q13, q15\n"           \
+  "vmla.f32 q13, q9, %f[scale][1]\n"   \
+  "vmov.u32 q15, #0 \n"                \
+  "cmp %[flag_act], #1 \n"             \
+  "bne 12f \n"                         \
+  "vmax.f32 q10, q10, q15 \n"          \
+  "vmax.f32 q11, q11, q15 \n"          \
+  "vmax.f32 q12, q12, q15 \n"          \
+  "vmax.f32 q13, q13, q15 \n"          \
+  "b 2f \n"                            \
+  "12: \n"                             \
+  "cmp %[flag_act], #0 \n"             \
+  "beq 2f \n"                          \
+  "cmp %[flag_act], #2 \n"             \
+  "bne 13f \n"                         \
+  "vld1.f32 {d14-d15}, [%[alpha]] \n"  \
+  "vmax.f32 q10, q10, q15 \n"          \
+  "vmax.f32 q11, q11, q15 \n"          \
+  "vmax.f32 q12, q12, q15 \n"          \
+  "vmax.f32 q13, q13, q15 \n"          \
+  "vmin.f32 q10, q10, q7 \n"           \
+  "vmin.f32 q11, q11, q7 \n"           \
+  "vmin.f32 q12, q12, q7 \n"           \
+  "vmin.f32 q13, q13, q7 \n"           \
+  "b 2f \n"                            \
+  "13: \n"                             \
+  "vld1.f32 {d6-d7}, [%[alpha]] \n"    \
+  "vcge.f32 q6, q10, q15 \n"           \
+  "vmul.f32 q7, q10, q3 \n"            \
+  "vcge.f32 q8, q11, q15 \n"           \
+  "vmul.f32 q9, q11, q3 \n"            \
+  "vbif q10, q7, q6 \n"                \
+  "vbif q11, q9, q8 \n"                \
+  "vcge.f32 q6, q12, q15 \n"           \
+  "vmul.f32 q7, q12, q3 \n"            \
+  "vcge.f32 q8, q13, q15 \n"           \
+  "vmul.f32 q9, q13, q3 \n"            \
+  "vbif q12, q7, q6 \n"                \
+  "vbif q13, q9, q8 \n"                \
   "2:\n"
 #endif
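The relu6 branch in both macros is simply a clamp: max against zero, then min against the clip value loaded from alpha. A stand-alone intrinsics sketch of that step, again only as an illustration:

```cpp
#ifdef __ARM_NEON
#include <arm_neon.h>

// C-intrinsics view of the relu6 branch: clamp(x, 0, six), where vsix holds
// the clip value broadcast from the alpha buffer.
static inline float32x4_t relu6_f32x4(float32x4_t x, float32x4_t vsix) {
  return vminq_f32(vmaxq_f32(x, vdupq_n_f32(0.f)), vsix);
}
#endif
```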
@@ -2721,7 +2785,8 @@ inline void int32_nchwc4_kernel(float*& dout0, // NOLINT
                                 int cnt,
                                 float32x4_t scale,
                                 float32x4_t bias,
-                                bool is_relu) {
+                                int flag_act,
+                                float* alpha) {
 #ifdef __aarch64__
   asm volatile(NCHWC4_TRANS_INT32
       "subs %w[cnt], %w[cnt], #1 \n"
@@ -2737,7 +2802,10 @@ (aarch64 operand list)
         [doutc3r0] "+r"(dout3), [ptr_din] "+r"(din), [cnt] "+r"(cnt)
-      : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu)
+      : [scale] "w"(scale), [bias] "w"(bias),
+        [flag_act] "r"(flag_act), [alpha] "r"(alpha)
       : "cc", "memory", "v0", ...
@@ -2779,7 +2847,10 @@ (armv7 operand list)
         [doutc3r0] "+r"(dout3), [ptr_din] "+r"(din), [cnt] "+r"(cnt)
-      : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu)
+      : [scale] "w"(scale), [bias] "w"(bias),
+        [flag_act] "r"(flag_act), [alpha] "r"(alpha)
       : "cc", "memory", "q2", ...
@@ -2808,7 +2879,8 @@ inline void int32_nchwc4_kernel(int8_t*& dout0, // NOLINT
                                 int cnt,
                                 float32x4_t scale,
                                 float32x4_t bias,
-                                bool is_relu) {
+                                int flag_act,
+                                float* alpha) {
 #ifdef __aarch64__
   float32x4_t vmax = vdupq_n_f32(-127.f);
   asm volatile(NCHWC4_TRANS_INT32
@@ -2852,7 +2924,8 @@ (aarch64 operand list of the int8 output variant)
-      : [scale] "w"(scale), [vmax] "w"(vmax), [bias] "w"(bias), [relu] "r"(is_relu)
+      : [scale] "w"(scale), [vmax] "w"(vmax), [bias] "w"(bias),
+        [flag_act] "r"(flag_act), [alpha] "r"(alpha)
       : "cc", "memory", "v0", ...
@@ -2942,8 +3015,9 @@ (armv7 operand list of the int8 output variant)
         [cnt] "+r"(cnt)
-      : [scale] "w"(scale), [bias] "w"(bias), [relu] "r"(is_relu), [vmax] "r"(vmax)
+      : [scale] "w"(scale), [bias] "w"(bias), [vmax] "r"(vmax),
+        [flag_act] "r"(flag_act), [alpha] "r"(alpha)
       : "cc", "memory", "q2", ...
 #endif
 }
The int32_t specialization of int32_nchwc4_kernel, a pure transpose path with an
optional integer relu (gated by a bool is_relu operand, "smax ... v20.4s" on aarch64
and "vmax.s32 ... q15" on armv7), is removed entirely:

-template <>
-inline void int32_nchwc4_kernel(int32_t*& dout0,  // NOLINT
-                                int32_t*& dout1,  // NOLINT
-                                int32_t*& dout2,  // NOLINT
-                                int32_t*& dout3,  // NOLINT
-                                const int32_t*& din,  // NOLINT
-                                int cnt,
-                                float32x4_t scale,
-                                float32x4_t bias,
-                                bool is_relu) {
-#ifdef __aarch64__
-  asm volatile(/* ldp / trn1 / trn2 transpose loop, optional smax relu against
-                  v20, str stores of q16-q19, loop on %w[cnt] */ ...);
-#else
-  asm volatile(/* vld1 / vtrn / vswp transpose loop, optional vmax.s32 relu
-                  against q15, vst1 stores, loop on %[cnt] */ ...);
-#endif
-}
 template <typename Dtype>
-inline Dtype cvt_kernel(int din, float scale, float bias, bool flag_relu);
+inline Dtype cvt_kernel(int din, float scale, float bias, int flag_act, float alpha);

 template <>
-inline float cvt_kernel(int din, float scale, float bias, bool flag_relu) {
-  if (flag_relu) {
-    return LITEMAX(din * scale + bias, 0);
-  }
-  return din * scale + bias;
-}
+inline float cvt_kernel(int din, float scale, float bias, int flag_act, float alpha) {
+  if (flag_act == 1) {
+    return LITEMAX(din * scale + bias, 0);
+  } else if (flag_act == 0) {
+    return din * scale + bias;
+  } else if (flag_act == 2) {
+    float max = LITEMAX(din * scale + bias, 0);
+    return LITEMIN(max, alpha);
+  } else {
+    float result = din * scale + bias;
+    return result > 0 ? result : alpha * result;
+  }
+}

 template <>
-inline int8_t cvt_kernel(int din, float scale, float bias, bool flag_relu) {
-  if (flag_relu) {
-    return saturate_cast<int8_t>(round(LITEMAX(din * scale + bias, 0)));
-  } else {
-    return saturate_cast<int8_t>(round(din * scale + bias));
-  }
-}
+inline int8_t cvt_kernel(int din, float scale, float bias, int flag_act, float alpha) {
+  if (flag_act == 1) {
+    auto tmp = saturate_cast<int8_t>(round(LITEMAX(din * scale + bias, 0)));
+    return tmp < -127 ? -127 : tmp;
+  } else if (flag_act == 0) {
+    auto tmp = saturate_cast<int8_t>(round(din * scale + bias));
+    return tmp < -127 ? -127 : tmp;
+  } else if (flag_act == 2) {
+    float max = LITEMAX(din * scale + bias, 0);
+    float relu6_result = LITEMIN(max, alpha);
+    auto tmp = saturate_cast<int8_t>(round(relu6_result));
+    return tmp < -127 ? -127 : tmp;
+  } else {
+    float result = din * scale + bias;
+    float leaky_result = result > 0 ? result : alpha * result;
+    auto tmp = saturate_cast<int8_t>(round(leaky_result));
+    return tmp < -127 ? -127 : tmp;
+  }
+}

-template <>
-inline int32_t cvt_kernel(int din, float scale, float bias, bool flag_relu) {
-  if (flag_relu) {
-    return LITEMAX(din, 0);
-  }
-  return din;
-}
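The scalar cvt_kernel above defines the reference behaviour of every activation path. Below is a self-contained equivalent without the LITEMAX / LITEMIN macros, usable as a quick sanity check; the numeric example in main is hypothetical.

```cpp
#include <algorithm>
#include <cstdio>

// Scalar reference of the new cvt_kernel<float> semantics: dequantize,
// add bias, then apply flag_act (0 none, 1 relu, 2 relu6, 3 leaky relu).
static float dequant_act(int din, float scale, float bias, int flag_act, float alpha) {
  float v = din * scale + bias;
  if (flag_act == 1) return std::max(v, 0.f);
  if (flag_act == 2) return std::min(std::max(v, 0.f), alpha);
  if (flag_act == 3) return v > 0.f ? v : alpha * v;
  return v;
}

int main() {
  // -40 * 0.25 = -10: leaky relu with slope 0.1 gives -1.
  //  40 * 0.25 =  10: relu6 with clip 6 gives 6.
  std::printf("%f %f\n", dequant_act(-40, 0.25f, 0.f, 3, 0.1f),
              dequant_act(40, 0.25f, 0.f, 2, 6.f));
  return 0;
}
```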
 template <typename Dtype>
 inline void write_int32_nchwc4_to_nchw(const int* din,
                                        Dtype* dout,
@@ -3108,7 +3091,8 @@ inline void write_int32_nchwc4_to_nchw(const int* din,
                                        int channel,
                                        int height,
                                        int width,
-                                       bool flag_relu,
+                                       int flag_act,
+                                       float* alpha,
                                        float* bias,
                                        bool flag_bias,
                                        Dtype* trash_ptr,
@@ -3160,21 +3144,22 @@ inline void write_int32_nchwc4_to_nchw(const int* din,
                           cnt,
                           w_scale,
                           w_bias,
-                          flag_relu);
+                          flag_act,
+                          alpha);
     }
     if (we > width) {
       int offset = 16 * (valid_w / 4 - 1);
       din_hei_ptr = din + index + offset;
       int j = we - 4;
       for (; j < width; ++j) {
-        *(doutc0_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
-        *(doutc1_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[1], scale[1], bias[1], flag_relu);
-        *(doutc2_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[2], scale[2], bias[2], flag_relu);
-        *(doutc3_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[3], scale[3], bias[3], flag_relu);
+        *(doutc0_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_act, alpha[0]);
+        *(doutc1_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[1], scale[1], bias[1], flag_act, alpha[0]);
+        *(doutc2_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[2], scale[2], bias[2], flag_act, alpha[0]);
+        *(doutc3_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[3], scale[3], bias[3], flag_act, alpha[0]);
         din_hei_ptr += 4;
       }
     }
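Putting the pieces together, the nchwc4 writer now performs dequantization, bias add, activation and layout conversion in one pass. A plain-C++ reference of what it computes for one output row, assuming the c4 layout interleaves four channels per pixel and using the same flag_act encoding (illustrative only, not the library routine):

```cpp
#include <algorithm>

// din_c4: one row of `width` pixels, channel-blocked as [c0,c1,c2,c3] per pixel.
// dout:   four planar output rows, one per channel.
// alpha:  relu6 clip (flag_act == 2) or leaky-relu slope (flag_act == 3).
static void nchwc4_row_to_nchw_ref(const int* din_c4, float* dout[4],
                                   const float* scale, const float* bias,
                                   int width, int flag_act, float alpha) {
  for (int w = 0; w < width; ++w) {
    for (int c = 0; c < 4; ++c) {
      float v = din_c4[4 * w + c] * scale[c] + bias[c];
      if (flag_act == 1) v = std::max(v, 0.f);
      else if (flag_act == 2) v = std::min(std::max(v, 0.f), alpha);
      else if (flag_act == 3) v = v > 0.f ? v : alpha * v;
      dout[c][w] = v;
    }
  }
}
```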
@@ -3196,7 +3181,8 @@ inline void int32_nchwc8_kernel(Dtype*& dout0, // NOLINT
                                 float32x4_t scale1,
                                 float32x4_t bias0,
                                 float32x4_t bias1,
-                                bool is_relu);
+                                int flag_act,
+                                float* alpha);
 // clang-format off
 #ifdef __aarch64__
@@ -3205,7 +3191,6 @@ (INT32_NCHWC8_TO_NCHW_FP32, aarch64)
   "ldp q2, q3, [%[ptr_din]], #32\n" /* load r02, r03 to q2, q3 */ \
   "ldp q4, q5, [%[ptr_din]], #32\n"                               \
   "ldp q6, q7, [%[ptr_din]], #32\n"                               \
-  "movi v31.4s, #0\n"                                             \
   "1:\n"                                                          \
   "trn1 v8.4s, v0.4s, v2.4s\n"                                    \
   "trn2 v9.4s, v0.4s, v2.4s\n"                                    \
@@ -3256,17 +3241,71 @@ (INT32_NCHWC8_TO_NCHW_FP32, aarch64)
   "fmla v9.4s, v11.4s, %[scale1].s[2]\n"  \
   "fmla v12.4s, v14.4s, %[scale1].s[1]\n" \
   "fmla v13.4s, v15.4s, %[scale1].s[3]\n" \
-  "cbz %w[relu], 2f\n"                                                        \
-  (fmax of v16-v19, v8, v9, v12, v13 against v31)                             \
-  "2:\n"
+  "cmp %w[flag_act], #1\n"  "bne 12f \n"  "movi v31.4s, #0 \n"                \
+  (relu: fmax of v16-v19, v8, v9, v12, v13 against v31, then "b 2f")          \
+  "12: \n"  "cmp %w[flag_act], #0 \n"  "beq 2f \n"                            \
+  "cmp %w[flag_act], #2 \n"  "bne 13f \n"                                     \
+  "movi v20.4s, #0 \n"  "ld1 {v21.4s}, [%[alpha]] \n"                         \
+  (relu6: fmax against v20 then fmin against v21 for the same eight regs)     \
+  "b 2f \n"                                                                   \
+  "13: \n"  "movi v20.4s, #0 \n"  "ld1 {v21.4s}, [%[alpha]] \n"               \
+  (leakey relu: fcmge / fmul into v10, v11, v14, v15, v22-v25, then bif       \
+   selects back into v16-v19, v8, v9, v12, v13)                               \
+  "2: \n"  /* act end */
 #else
@@ -3312,18 +3351,68 @@ (INT32_NCHWC8_TO_NCHW_FP32, armv7)
   "vswp d5, d12\n" /* q2: b0-b3, q6: d0-d3 */ \
   "vswp d3, d10\n" /* q1: e0-e3, q5: g0-g3 */ \
   "vswp d7, d14\n" /* q3: f0-f3, q7: h0-h3 */ \
-  "vmov.i32 q8, #0\n"  "cmp %[relu], #0\n"  "beq 2f\n"          \
-  (relu: vmax.f32 of q0-q7 against q8)                          \
+  "vmov.u32 q8, #0 \n"  "cmp %[flag_act], #1 \n"  "bne 12f \n"  \
+  (relu: vmax.f32 of q0-q7 against q8, then "b 2f")             \
+  "12: \n"  "cmp %[flag_act], #0 \n"  "beq 2f \n"               \
+  "cmp %[flag_act], #2 \n"  "bne 13f \n"                        \
+  "vld1.f32 {d18-d19}, [%[alpha]] \n"                           \
+  (relu6: vmax against q8 then vmin against q9 for q0-q7)       \
+  "b 2f \n"                                                     \
+  "13: \n"  "vld1.f32 {d18-d19}, [%[alpha]] \n"                 \
+  (leakey relu: for each of q0-q7, "vcge.f32 q10, qX, q8",      \
+   "vmul.f32 q11, qX, q9", "vbif qX, q11, q10")                 \
   "2:\n"
 #endif
@@ -3344,7 +3433,9 @@ inline void int32_nchwc8_kernel(float*& dout0, // NOLINT
                                 float32x4_t scale1,
                                 float32x4_t bias0,
                                 float32x4_t bias1,
-                                bool is_relu) {
+                                int flag_act,
+                                float* alpha) {
   // clang-format off
 #ifdef __aarch64__
   asm volatile(INT32_NCHWC8_TO_NCHW_FP32
       "subs %w[cnt], %w[cnt], #1 \n"  /* loop count -1 */
@@ -3371,31 +3462,13 @@ (aarch64 operand and clobber lists)
-      : [scale0] "w"(scale0), [scale1] "w"(scale1), [bias0] "w"(bias0),
-        [bias1] "w"(bias1), [relu] "r"(is_relu)
-      : "cc", "memory", "v0", ..., "v20", "v31");
+      : [scale0] "w"(scale0), [scale1] "w"(scale1), [bias0] "w"(bias0),
+        [bias1] "w"(bias1), [flag_act] "r"(flag_act), [alpha] "r"(alpha)
+      : "cc", "memory", "v0", ..., "v20", "v21", "v22", "v23", "v24", "v25", "v31");
 #else
   asm volatile(INT32_NCHWC8_TO_NCHW_FP32
       "subs %[cnt], #1 \n"  /* loop count -1 */
@@ -3422,22 +3495,13 @@ (armv7 operand list)
-      : ..., [bias1] "w"(bias1), [relu] "r"(is_relu)
+      : ..., [bias1] "w"(bias1), [flag_act] "r"(flag_act), [alpha] "r"(alpha)
       : "cc", "memory", "q0", ..., "q11");
 #endif
   // clang-format on
 }
@@ -3455,7 +3519,9 @@ inline void int32_nchwc8_kernel(int8_t*& dout0, // NOLINT
                                 float32x4_t scale1,
                                 float32x4_t bias0,
                                 float32x4_t bias1,
-                                bool is_relu) {
+                                int flag_act,
+                                float* alpha) {
   // clang-format off
 #ifdef __aarch64__
   float32x4_t vmax = vdupq_n_f32(-127.f);
   asm volatile(INT32_NCHWC8_TO_NCHW_FP32  /* fp32-int32 */
@@ -3529,34 +3595,13 @@ (aarch64 operand and clobber lists of the int8 output variant)
-      : ..., [bias0] "w"(bias0), [bias1] "w"(bias1), [vmax] "w"(vmax), [relu] "r"(is_relu)
-      : "cc", "memory", "v0", ..., "v23", "v31");
+      : ..., [bias0] "w"(bias0), [bias1] "w"(bias1), [vmax] "w"(vmax),
+        [flag_act] "r"(flag_act), [alpha] "r"(alpha)
+      : "cc", "memory", "v0", ..., "v23", "v24", "v25", "v31");
 #else
   float vmax[4] = {-127.f, -127.f, -127.f, -127.f};
   asm volatile(INT32_NCHWC8_TO_NCHW_FP32  /* set +-0.5 offset */
@@ -3669,175 +3714,13 @@ (armv7 operand list; in the same hunk the int32_t
specialization of int32_nchwc8_kernel, the transpose-only path with an optional
integer relu, is deleted in the same way as its nchwc4 counterpart)
-      : ..., [bias1] "w"(bias1), [vmax] "r"(vmax), [relu] "r"(is_relu)
+      : ..., [bias1] "w"(bias1), [vmax] "r"(vmax),
+        [flag_act] "r"(flag_act), [alpha] "r"(alpha)
       : "cc", "memory", "q0", ..., "q11");
 #endif
   // clang-format on
 }
 /*wirte result in outputs ... */
@@ -3855,7 +3738,8 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
                                        int channel,
                                        int height,
                                        int width,
-                                       bool flag_relu,
+                                       int flag_act,
+                                       float* alpha,
                                        float* bias,
                                        bool flag_bias,
                                        Dtype* trash_ptr,
@@ -3931,46 +3815,47 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
                           w_scale1,
                           w_bias0,
                           w_bias1,
-                          flag_relu);
+                          flag_act,
+                          alpha);
     }
     if (we > width) {
       int offset = 32 * cnt;
      din_hei_ptr = ptr_din + offset;
       for (int j = ws + cnt * 4; j < width; ++j) {
         if (flag_bias) {
-          *(doutc0_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
-          (channels 1 through 7 likewise)
+          *(doutc0_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_act, alpha[0]);
+          (channels 1 through 7 likewise, each passing flag_act and alpha[0])
         } else {
-          *(doutc0_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], 0.f, flag_relu);
-          (channels 1 through 7 likewise)
+          *(doutc0_ptr++) = cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], 0.f, flag_act, alpha[0]);
+          (channels 1 through 7 likewise)
         }
         din_hei_ptr += 8;
       }
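One detail worth noting in the int8 output paths above: results saturate to [-127, 127] rather than the full int8 range, matching the symmetric quantization used elsewhere in the backend. A minimal stand-alone version of that rounding and saturation step (a sketch, not the library's saturate_cast):

```cpp
#include <cmath>
#include <cstdint>

// Round to nearest and saturate symmetrically, mirroring the
// "tmp < -127 ? -127 : tmp" guard added to cvt_kernel<int8_t>.
static int8_t quantize_symmetric(float v) {
  float r = std::round(v);
  if (r > 127.f) r = 127.f;
  if (r < -127.f) r = -127.f;
  return static_cast<int8_t>(r);
}
```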
lite/backends/arm/math/conv_depthwise.h

@@ -94,7 +94,8 @@ void conv_depthwise_3x3s1_int8(Dtype* dout,
                                const float* scale,
                                const float* bias,
                                bool flag_bias,
-                               bool flag_relu,
+                               int flag_act,
+                               float* alpha,
                                int num,
                                int chin,
                                int hin,

The declarations of conv_depthwise_3x3s2_int8 (@@ -112,7 +113,8 @@),
conv_depthwise_5x5s1_int8 (@@ -178,7 +180,8 @@) and conv_depthwise_5x5s2_int8
(@@ -196,7 +199,8 @@) receive the identical change: "bool flag_relu," becomes
"int flag_act, float* alpha,".
lite/backends/arm/math/conv_impl.cc

@@ -790,8 +790,30 @@ void conv_depthwise_3x3_int8_fp32(const void* din,
   int pad_h = paddings[0];
   int pad_w = paddings[2];
   int stride = param.strides[1];
-  bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
+  (the same flag_act / alpha parsing block shown above for conv_3x3s1_direct_int8)
   if (stride == 1) {
     conv_depthwise_3x3s1_int8(reinterpret_cast<float*>(dout),
                               reinterpret_cast<const int8_t*>(din),
@@ -799,7 +821,8 @@ and @@ -816,7 +839,8 @@ (the stride-1 and stride-2 call sites):
                               scale,
                               bias,
                               flag_bias,
-                              flag_relu,
+                              flag_act,
+                              alpha,
                               num,
                               ch_in,
                               h_in,

The same three edits (drop "bool flag_relu = param.fuse_relu;", add the flag_act /
alpha parsing block, and pass "flag_act, alpha" instead of "flag_relu" to both the
stride-1 and stride-2 kernels) are repeated in conv_depthwise_3x3_int8_int8
(@@ -849,8 +873,30 @@, @@ -858,7 +904,8 @@, @@ -875,7 +922,8 @@),
conv_depthwise_5x5_int8_fp32 (@@ -908,8 +956,30 @@, @@ -917,7 +987,8 @@,
@@ -934,7 +1005,8 @@) and conv_depthwise_5x5_int8_int8 (@@ -967,8 +1039,30 @@,
@@ -976,7 +1070,8 @@, @@ -993,7 +1088,8 @@).
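The wrappers in conv_impl.cc read pad_h from paddings[0] and pad_w from paddings[2]; this is consistent with a four-element padding vector ordered {top, bottom, left, right}, which is an assumption inferred from the indices rather than stated in the diff. A tiny helper making that reading explicit:

```cpp
#include <vector>

// Assumed ordering of the padding vector: {top, bottom, left, right}.
static inline void get_pads(const std::vector<int>& paddings, int* pad_h, int* pad_w) {
  *pad_h = paddings[0];  // height (top) padding
  *pad_w = paddings[2];  // width (left) padding
}
```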
lite/backends/arm/math/gemm_prepacked_int8.cc

This hunk fixes the relu6 clamp in the int8 GEMM kernel: accumulators v20 through v31
were being min-clamped against v0 instead of v1, the register that holds the relu6
threshold and that v16 through v19 already use.

@@ -534,18 +534,18 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
   "fmin v17.4s, v17.4s, v1.4s\n" /* relu6 */ \
   "fmin v18.4s, v18.4s, v1.4s\n" /* relu6 */ \
   "fmin v19.4s, v19.4s, v1.4s\n" /* relu6 */ \
-  "fmin v20.4s, v20.4s, v0.4s\n" /* relu6 */ \
-  "fmin v21.4s, v21.4s, v0.4s\n" /* relu6 */ \
-  (v22 through v30 likewise against v0)      \
-  "fmin v31.4s, v31.4s, v0.4s\n" /* relu6 */ \
+  "fmin v20.4s, v20.4s, v1.4s\n" /* relu6 */ \
+  "fmin v21.4s, v21.4s, v1.4s\n" /* relu6 */ \
+  (v22 through v30 likewise against v1)      \
+  "fmin v31.4s, v31.4s, v1.4s\n" /* relu6 */ \
   "b 9f \n"                      /* relu end */
 #define GEMM_INT8_LEAKY_RELU \
lite/kernels/arm/conv_depthwise.cc

@@ -169,6 +169,12 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
     }
     flag_trans_bias_ = true;
   }
+  //! update relu6 parameter
+  if (param.activation_param.has_active &&
+      param.activation_param.active_type == lite_api::ActivationType::kRelu6) {
+    param.activation_param.Relu_clipped_coef =
+        param.activation_param.Relu_clipped_coef / param.output_scale;
+  }
   /// select dw conv kernel
   if (kw == 3) {
     // trans weights
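The division by output_scale exists because the int8-output kernels clamp values that are already expressed in output-quantized units (real value divided by output_scale), so a float clip of, say, 6.0 has to be rescaled into the same units. A sketch of that bookkeeping with illustrative names (not Paddle-Lite API):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

// acc_in_output_units is the kernel-internal value, i.e. real_value / output_scale.
// The float clip `six` must therefore be divided by output_scale before comparing.
static int8_t relu6_quantized(float acc_in_output_units, float six, float output_scale) {
  float clip = six / output_scale;  // what the kernel actually compares against
  float v = std::min(std::max(acc_in_output_units, 0.f), clip);
  float r = std::round(v);
  return static_cast<int8_t>(std::max(-127.f, std::min(127.f, r)));
}
```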
lite/kernels/arm/conv_direct.h

@@ -39,7 +39,8 @@ inline bool direct_conv_trans_weights(
     const std::vector<float>& w_scale,
     float in_scale,
     float out_scale,
-    std::vector<float>& merge_scale) {  // NOLINT
+    std::vector<float>& merge_scale,  // NOLINT
+    float* relu_clipped_coef) {
   constexpr int cblock = 4;
   int oc = win->dims()[0];
   int ic = win->dims()[1];
@@ -64,7 +65,8 @@ inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kFloat)>(
     const std::vector<float>& w_scale,
     float in_scale,
     float out_scale,
-    std::vector<float>& merge_scale) {  // NOLINT
+    std::vector<float>& merge_scale,  // NOLINT
+    float* relu_clipped_coef) {
   int cblock = 4;
   if (stride == 2) {
     cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num();
@@ -103,7 +105,8 @@ inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kInt8)>(
     const std::vector<float>& w_scale,
     float in_scale,
     float out_scale,
-    std::vector<float>& merge_scale) {  // NOLINT
+    std::vector<float>& merge_scale,  // NOLINT
+    float* relu_clipped_coef) {
   int cblock = 4;
   if (stride == 2) {
     cblock = lite::arm::math::conv_3x3s2_direct_int8_c_num();
@@ -130,6 +133,8 @@ inline bool direct_conv_trans_weights<PRECISION(kInt8), PRECISION(kInt8)>(
       merge_scale[i] = w_scale[i] * scale;
     }
   }
+  /// update relu_clipped_coef
+  *relu_clipped_coef /= out_scale;
   /// update bias
   if (bin) {
     bout->Resize(bin->dims());
@@ -167,16 +172,17 @@ class DirectConv : public KernelLite<TARGET(kARM), Ptype> {
         << "direct conv only support conv3x3s1 and conv3x3s2";
     CHECK(kw == 3 && kh == 3)
         << "direct conv only support conv3x3s1 and conv3x3s2";
-    flag_trans_bias_ = direct_conv_trans_weights<Ptype, OutType>(param.filter,
-                                                                 &weights_,
-                                                                 param.bias,
-                                                                 &bias_,
-                                                                 sw,
-                                                                 param.weight_scale,
-                                                                 param.input_scale,
-                                                                 param.output_scale,
-                                                                 w_scale_);
+    flag_trans_bias_ =
+        direct_conv_trans_weights<Ptype, OutType>(param.filter,
+                                                  &weights_,
+                                                  param.bias,
+                                                  &bias_,
+                                                  sw,
+                                                  param.weight_scale,
+                                                  param.input_scale,
+                                                  param.output_scale,
+                                                  w_scale_,
+                                                  &param.activation_param.Relu_clipped_coef);
   }
   virtual void Run();
lite/tests/math/conv_int8_compute_test.cc

@@ -56,7 +56,7 @@ DEFINE_int32(dila_w, 1, "dilation width");
 DEFINE_bool(flag_act, true, "do act");
 DEFINE_bool(flag_bias, true, "with bias");
 DEFINE_double(clipped_coef, 1.0, "clipped relu coef");
-DEFINE_double(leakey_relu_alpha, 8.88, "leakey relu alpha");
+DEFINE_double(leakey_relu_alpha, 2.22, "leakey relu alpha");
 typedef paddle::lite::DDim DDim;
 typedef paddle::lite::Tensor Tensor;
@@ -188,7 +188,14 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
   }
   std::vector<float> scale_in{1.f / 127};
-  std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f};
+  std::vector<float> scale_out(1, weight_dim.count(1, 4) / 127.f);
+  if (flag_act == 2) {
+    scale_out[0] = six / 127.f;
+  } else if (flag_act == 4) {
+    if (std::abs(alpha) > 1) {
+      scale_out[0] *= std::abs(alpha);
+    }
+  }
   std::vector<float> scale_w(weight_dim[0], 1.f / 127);
   param_int8_out.input_scale = scale_in[0];
@@ -484,7 +491,7 @@ TEST(TestConv3x3DWInt8, test_conv3x3_depthwise) {
   for (auto& stride : {1, 2}) {
     for (auto& pad : {0, 1}) {
       for (auto& flag_bias : {false, true}) {
-        for (auto& flag_act : {0, 1}) {
+        for (auto& flag_act : {0, 1, 2, 4}) {
           for (auto& c : {1, 3, 5, 8, 16, 32}) {
             std::vector<DDim> dims;
             DDim weights_dim({c, 1, 3, 3});
@@ -520,7 +527,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
   for (auto& stride : {1, 2}) {
     for (auto& pad : {0, 1, 2, 3, 4}) {
       for (auto& flag_bias : {false, true}) {
-        for (auto& flag_act : {0, 1}) {
+        for (auto& flag_act : {0, 1, 2, 4}) {
           for (auto& c : {1, 5, 15, 33}) {
             std::vector<DDim> dims;
             DDim weights_dim({c, 1, 5, 5});
@@ -553,7 +560,7 @@ TEST(TestConv5x5DWInt8, test_conv5x5_depthwise) {
 #if 1  /// conv1x1s1
 TEST(TestConv1x1s1Int8, test_conv1x1s1) {
   if (FLAGS_basic_test) {
-    for (auto& cin : {1, 3, 8, 32}) {
+    for (auto& cin : {1, 3, 8, 33}) {
       for (auto& cout : {1, 5, 17}) {
         for (auto& g : {1, 2}) {
           for (auto& flag_bias : {false, true}) {
@@ -599,7 +606,7 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
   for (auto& pad_left : {1, 2}) {
     for (auto& pad_right : {1, 2}) {
       for (auto& flag_bias : {false, true}) {
-        for (auto& flag_act : {0, 1}) {
+        for (auto& flag_act : {0, 1, 2, 4}) {
           std::vector<DDim> dims;
           DDim weights_dim({cout, cin, 3, 3});
           for (auto& batch : {1, 2}) {
@@ -641,7 +648,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2)
   (the same change: the flag_act loop {0, 1} becomes {0, 1, 2, 4})
@@ -673,7 +680,7 @@ TEST(TestConv3x3s2Int8, test_conv_3x3s2) {
 }
 #endif  /// conv3x3s2
-#if 0  /// random param conv
+#if 1  /// random param conv
 TEST(TestConvRandInt8, test_conv_rand) {
   if (FLAGS_basic_test) {
     for (auto& cin : {1, 17}) {
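The test sweeps flag_act over {0, 1, 2, 4}; in the test harness these appear to mean none, relu, relu6 (using clipped_coef) and leaky relu (using leakey_relu_alpha). A reference activation of that shape, as one might use to check the kernels (an assumption about the harness convention, not code taken from this diff):

```cpp
#include <algorithm>

// Test-harness style activation reference: flag_act 0 none, 1 relu,
// 2 relu6 with clip `six`, 4 leaky relu with slope `alpha` (assumed mapping).
static float act_ref(float x, int flag_act, float six, float alpha) {
  switch (flag_act) {
    case 1: return std::max(x, 0.f);                 // relu
    case 2: return std::min(std::max(x, 0.f), six);  // relu6
    case 4: return x > 0.f ? x : alpha * x;          // leaky relu
    default: return x;                               // no activation
  }
}
```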