PaddlePaddle / Paddle-Lite — commit f30ae5ff

Authored Apr 20, 2020 by chenjiaoAngel
add gemm+relu6
Parent: a4770bd7

Showing 2 changed files with 396 additions and 27 deletions:
  lite/backends/arm/math/conv_impl.cc            +5   -2
  lite/backends/arm/math/gemm_prepacked_int8.cc  +391 -25
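The activations this commit fuses into the int8 GEMM epilogue are the standard element-wise definitions. As a reading aid only (not part of the diff), a minimal scalar sketch of what the fused relu, relu6 and leaky-relu steps compute per output value, assuming `coef` comes from `ActivationParam::Relu_clipped_coef` and `alpha` from `ActivationParam::Leaky_relu_alpha`:

#include <algorithm>

// Scalar reference of the fused activations (illustrative sketch only).
inline float act_relu(float x) { return std::max(x, 0.f); }
inline float act_relu6(float x, float coef) {         // clip at coef (typically 6.f)
  return std::min(std::max(x, 0.f), coef);
}
inline float act_leaky_relu(float x, float alpha) {   // scale negative lanes by alpha
  return x >= 0.f ? x : x * alpha;
}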
lite/backends/arm/math/conv_impl.cc

@@ -264,6 +264,7 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
   }
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
+  auto act_param = param.activation_param;
   //! use gemv when the output channel size = 1
   for (int b = 0; b < num; ++b) {
     // dC

@@ -294,9 +295,9 @@ void conv1x1s1_gemm_int8(const int8_t* i_data,
                           n,
                           k,
                           flag_bias,
-                          flag_relu,
                           false,
                           scale_group,
+                          act_param,
                           ctx);
     }
   }

@@ -474,6 +475,8 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
   bool flag_relu = param.fuse_relu;
   bool flag_bias = param.bias != nullptr;
+  auto act_param = param.activation_param;
   int hblock = get_hblock_int8(ctx);
   int k_roundup = ROUNDUP(k, KBLOCK_INT8);
   int m_roundup = ROUNDUP(m, hblock);

@@ -534,9 +537,9 @@ void conv_im2col_gemm_int8(const int8_t* i_data,
                           n,
                           k,
                           flag_bias,
-                          flag_relu,
                           false,
                           scale_group,
+                          act_param,
                           ctx);
     }
   }
lite/backends/arm/math/gemm_prepacked_int8.cc

@@ -195,7 +195,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              Dtype*& c_ptr2,  // NOLINT
                              Dtype*& c_ptr3,  // NOLINT
                              const float* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem);
 // clang-format off
@@ -483,7 +484,10 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
 #define GEMM_INT8_RELU                                  \
   /* do relu */                                         \
-  "cbz %w[is_relu], 9f\n"            /* skip relu */    \
+  "cmp %w[flag_act], #0\n"           /* skip relu */    \
+  "beq 9f \n"                        /* no act end */   \
+  "cmp %w[flag_act], #1\n"           /* skip relu */    \
+  "beq 10f \n"                       /* other act */    \
   "movi v0.4s, #0\n"                 /* for relu */     \
   "fmax v16.4s, v16.4s, v0.4s\n"     /* relu */         \
   "fmax v17.4s, v17.4s, v0.4s\n"     /* relu */         \
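The single `cbz %w[is_relu]` check above is replaced by a small branch ladder keyed on the new `flag_act` operand, with label 9 as the no-activation exit, 10 as the relu6 path and 11 as the leaky-relu path. A rough scalar equivalent of the epilogue dispatch, following the `// relu: 1, relu6: 2, leakey: 3` encoding set up later in `gemm_prepack_int8` (illustrative sketch, not code from the diff):

#include <algorithm>

// Sketch of the dispatch driven by flag_act and the 4-lane alpha vector.
void apply_act(float* c, int n, int flag_act, const float* alpha) {
  if (flag_act == 0) return;                        // no activation
  for (int i = 0; i < n; ++i) {
    if (flag_act == 1) {                            // relu
      c[i] = std::max(c[i], 0.f);
    } else if (flag_act == 2) {                     // relu6: alpha[0..3] hold the clip value
      c[i] = std::min(std::max(c[i], 0.f), alpha[0]);
    } else {                                        // leaky relu: alpha[0..3] hold the slope
      c[i] = c[i] >= 0.f ? c[i] : c[i] * alpha[0];
    }
  }
}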
@@ -501,6 +505,102 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
   "fmax v29.4s, v29.4s, v0.4s\n"     /* relu */         \
   "fmax v30.4s, v30.4s, v0.4s\n"     /* relu */         \
   "fmax v31.4s, v31.4s, v0.4s\n"     /* relu */         \
+  "b 9f \n"                          /* relu end */
+#define GEMM_INT8_RELU6                                 \
+  /* do relu6 */                                        \
+  "10: \n"                                              \
+  "cmp %w[flag_act], #2 \n"          /* check relu6 */  \
+  "beq 11f \n"                       /* no act end */   \
+  "movi v0.4s, #0\n"                 /* for relu6 */    \
+  "fmax v16.4s, v16.4s, v0.4s\n"     /* relu */         \
+  "fmax v17.4s, v17.4s, v0.4s\n"     /* relu */         \
+  "fmax v18.4s, v18.4s, v0.4s\n"     /* relu */         \
+  "fmax v19.4s, v19.4s, v0.4s\n"     /* relu */         \
+  "fmax v20.4s, v20.4s, v0.4s\n"     /* relu */         \
+  "ld1 {v1.4s}, [%[alpha]] \n"       /* relu6 alpha */  \
+  "fmax v21.4s, v21.4s, v0.4s\n"     /* relu */         \
+  "fmax v22.4s, v22.4s, v0.4s\n"     /* relu */         \
+  "fmax v23.4s, v23.4s, v0.4s\n"     /* relu */         \
+  "fmax v24.4s, v24.4s, v0.4s\n"     /* relu */         \
+  "fmax v25.4s, v25.4s, v0.4s\n"     /* relu */         \
+  "fmax v26.4s, v26.4s, v0.4s\n"     /* relu */         \
+  "fmax v27.4s, v27.4s, v0.4s\n"     /* relu */         \
+  "fmax v28.4s, v28.4s, v0.4s\n"     /* relu */         \
+  "fmax v29.4s, v29.4s, v0.4s\n"     /* relu */         \
+  "fmax v30.4s, v30.4s, v0.4s\n"     /* relu */         \
+  "fmax v31.4s, v31.4s, v0.4s\n"     /* relu */         \
+  "fmin v16.4s, v16.4s, v1.4s\n"     /* relu6 */        \
+  "fmin v17.4s, v17.4s, v1.4s\n"     /* relu6 */        \
+  "fmin v18.4s, v18.4s, v1.4s\n"     /* relu6 */        \
+  "fmin v19.4s, v19.4s, v1.4s\n"     /* relu6 */        \
+  "fmin v20.4s, v20.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v21.4s, v21.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v22.4s, v22.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v23.4s, v23.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v24.4s, v24.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v25.4s, v25.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v26.4s, v26.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v27.4s, v27.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v28.4s, v28.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v29.4s, v29.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v30.4s, v30.4s, v0.4s\n"     /* relu6 */        \
+  "fmin v31.4s, v31.4s, v0.4s\n"     /* relu6 */        \
+  "b 9f \n"                          /* relu end */
+#define GEMM_INT8_LEAKY_RELU                            \
+  /* do relu */                                         \
+  "11: \n"                                              \
+  "movi v0.4s, #0\n"                 /* for relu6 */        \
+  "ld1 {v1.4s}, [%[alpha]] \n"       /* leakey relu alpha */ \
+  "fcmge v2.4s, v16.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v3.4s, v16.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v4.4s, v17.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v5.4s, v17.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v6.4s, v18.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v7.4s, v18.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v8.4s, v19.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v9.4s, v19.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "bif v16.16b, v3.16b, v2.16b \n"   /* choose*/        \
+  "bif v17.16b, v5.16b, v4.16b \n"   /* choose*/        \
+  "bif v18.16b, v7.16b, v6.16b \n"   /* choose*/        \
+  "bif v19.16b, v9.16b, v8.16b \n"   /* choose*/        \
+  "fcmge v2.4s, v20.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v3.4s, v20.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v4.4s, v21.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v5.4s, v21.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v6.4s, v22.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v7.4s, v22.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v8.4s, v23.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v9.4s, v23.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "bif v20.16b, v3.16b, v2.16b \n"   /* choose*/        \
+  "bif v21.16b, v5.16b, v4.16b \n"   /* choose*/        \
+  "bif v22.16b, v7.16b, v6.16b \n"   /* choose*/        \
+  "bif v23.16b, v9.16b, v8.16b \n"   /* choose*/        \
+  "fcmge v2.4s, v24.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v3.4s, v24.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v4.4s, v25.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v5.4s, v25.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v6.4s, v26.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v7.4s, v26.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v8.4s, v27.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v9.4s, v27.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "bif v24.16b, v3.16b, v2.16b \n"   /* choose*/        \
+  "bif v25.16b, v5.16b, v4.16b \n"   /* choose*/        \
+  "bif v26.16b, v7.16b, v6.16b \n"   /* choose*/        \
+  "bif v27.16b, v9.16b, v8.16b \n"   /* choose*/        \
+  "fcmge v2.4s, v28.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v3.4s, v28.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v4.4s, v29.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v5.4s, v29.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v6.4s, v30.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v7.4s, v30.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "fcmge v8.4s, v31.4s, v0.4s \n"    /* vcgeq_f32 */    \
+  "fmul v9.4s, v31.4s, v1.4s \n"     /* vmulq_f32 */    \
+  "bif v28.16b, v3.16b, v2.16b \n"   /* choose*/        \
+  "bif v29.16b, v5.16b, v4.16b \n"   /* choose*/        \
+  "bif v30.16b, v7.16b, v6.16b \n"   /* choose*/        \
+  "bif v31.16b, v9.16b, v8.16b \n"   /* choose*/        \
   "9:\n"
 #define GEMM_TRANS_INT32_TO_FP32                        \
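The leaky-relu block uses the compare/multiply/bitwise-select idiom the `vcgeq_f32` / `vmulq_f32` comments point at: build a "lane >= 0" mask, compute lane * alpha, then keep the original lane where the mask is set. A hedged NEON-intrinsics sketch of one 4-lane step (for illustration only; the kernel itself stays in inline assembly):

#include <arm_neon.h>

// One float32x4 step of the GEMM_INT8_LEAKY_RELU select (illustrative sketch).
static inline float32x4_t leaky_relu_q(float32x4_t x, float32x4_t alpha) {
  uint32x4_t ge_zero = vcgeq_f32(x, vdupq_n_f32(0.f));  // mask: lane >= 0  ("fcmge")
  float32x4_t scaled = vmulq_f32(x, alpha);             // lane * alpha     ("fmul")
  return vbslq_f32(ge_zero, x, scaled);                 // pick x or scaled ("bif")
}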
@@ -559,6 +659,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
 #define GEMM_INT8_FP32_OUT                   \
   GEMM_TRANS_INT32_TO_FP32                   \
   GEMM_INT8_RELU                             \
+  GEMM_INT8_RELU6                            \
+  GEMM_INT8_LEAKY_RELU                       \
   /* store result */                         \
   "stp q16, q17, [%[c_ptr0]], #32\n"         \
   "stp q18, q19, [%[c_ptr0]], #32\n"         \

@@ -571,7 +673,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
 #define GEMM_INT8_INT8_OUT                   \
   GEMM_TRANS_INT32_TO_FP32                   \
   GEMM_INT8_RELU                             \
+  GEMM_INT8_RELU6                            \
+  GEMM_INT8_LEAKY_RELU                       \
   "ld1 {v8.4s}, [%[vmax]] \n"   /* v8 = -127 */ \
   /* data >= -127 */                         \
   "fcmge v0.4s, v16.4s, v8.4s\n"             \
@@ -665,7 +768,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              float32_t*& c_ptr2,  // NOLINT
                              float32_t*& c_ptr3,  // NOLINT
                              const float32_t* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem) {
   // clang-format off

@@ -678,6 +782,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
         [c_ptr3] "+r"(c_ptr3),
         [k] "+r"(k)
       : [is_relu] "r"(is_relu),
+        [alpha] "r"(alpha),
         [bias] "r"(bias),
         [rem] "r"(rem),
         [scale] "r"(scale)
@@ -698,7 +803,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              int8_t*& c_ptr2,  // NOLINT
                              int8_t*& c_ptr3,  // NOLINT
                              const float32_t* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem) {
   // clang-format off

@@ -712,6 +818,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
         [c_ptr3] "+r"(c_ptr3),
         [k] "+r"(k)
       : [is_relu] "r"(is_relu),
+        [alpha] "r"(alpha),
        [bias] "r"(bias),
        [rem] "r"(rem),
        [scale] "r"(scale),
@@ -739,7 +846,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
                                   Dtype*& c_ptr6,  // NOLINT
                                   Dtype*& c_ptr7,  // NOLINT
                                   const float32_t* scale,
-                                  bool is_relu,
+                                  const float32_t* alpha,
+                                  int is_relu,
                                   int k,
                                   int rem);
 #if 0
@@ -1099,12 +1207,47 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
 #endif
 #define GEMM_SDOT_RELU                               \
-  "cbz %w[relu], 12f\n"            /* skip relu */   \
+  "cmp %w[relu], #0\n"             /* skip relu */   \
+  "beq 12f\n"                                        \
+  "cmp %w[relu], #1\n"             /* skip relu */   \
+  "beq 13f\n"                      /* other act */   \
+  "movi v2.4s, #0\n"               /* for relu*/     \
+  "fmax v8.4s, v8.4s, v2.4s\n"     /* relu*/         \
+  "fmax v9.4s, v9.4s, v2.4s\n"     /* relu*/         \
+  "fmax v10.4s, v10.4s, v2.4s\n"   /* relu*/         \
+  "fmax v11.4s, v11.4s, v2.4s\n"   /* relu*/         \
+  "fmax v12.4s, v12.4s, v2.4s\n"   /* relu*/         \
+  "fmax v13.4s, v13.4s, v2.4s\n"   /* relu*/         \
+  "fmax v14.4s, v14.4s, v2.4s\n"   /* relu*/         \
+  "fmax v15.4s, v15.4s, v2.4s\n"   /* relu*/         \
+  "fmax v16.4s,v16.4s,v2.4s\n"     /* relu*/         \
+  "fmax v17.4s,v17.4s,v2.4s\n"     /* relu*/         \
+  "fmax v18.4s, v18.4s, v2.4s\n"   /* relu*/         \
+  "fmax v19.4s, v19.4s, v2.4s\n"   /* relu*/         \
+  "fmax v20.4s, v20.4s, v2.4s\n"   /* relu*/         \
+  "fmax v21.4s, v21.4s, v2.4s\n"   /* relu*/         \
+  "fmax v22.4s, v22.4s, v2.4s\n"   /* relu*/         \
+  "fmax v23.4s, v23.4s, v2.4s\n"   /* relu*/         \
+  "fmax v24.4s, v24.4s, v2.4s\n"   /* relu*/         \
+  "fmax v25.4s, v25.4s, v2.4s\n"   /* relu*/         \
+  "fmax v26.4s, v26.4s, v2.4s\n"   /* relu*/         \
+  "fmax v27.4s, v27.4s, v2.4s\n"   /* relu*/         \
+  "fmax v28.4s, v28.4s, v2.4s\n"   /* relu*/         \
+  "fmax v29.4s, v29.4s, v2.4s\n"   /* relu*/         \
+  "fmax v30.4s, v30.4s, v2.4s\n"   /* relu*/         \
+  "fmax v31.4s, v31.4s, v2.4s\n"   /* relu*/         \
+  "b 12f \n"                       /* relu end */
+#define GEMM_SDOT_RELU6                              \
+  "13: \n"                                           \
+  "cmp %w[relu], #2\n"             /* skip relu6 */  \
+  "beq 14f\n"                                        \
   "movi v2.4s, #0\n"               /* for relu*/     \
   "fmax v8.4s, v8.4s, v2.4s\n"     /* relu*/         \
   "fmax v9.4s, v9.4s, v2.4s\n"     /* relu*/         \
   "fmax v10.4s, v10.4s, v2.4s\n"   /* relu*/         \
   "fmax v11.4s, v11.4s, v2.4s\n"   /* relu*/         \
+  "ld1 {v3.4s}, [%[alpha]] \n"     /* relu6 alpha */ \
   "fmax v12.4s, v12.4s, v2.4s\n"   /* relu*/         \
   "fmax v13.4s, v13.4s, v2.4s\n"   /* relu*/         \
   "fmax v14.4s, v14.4s, v2.4s\n"   /* relu*/         \
@@ -1125,6 +1268,108 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "fmax v29.4s, v29.4s, v2.4s\n"   /* relu*/         \
   "fmax v30.4s, v30.4s, v2.4s\n"   /* relu*/         \
   "fmax v31.4s, v31.4s, v2.4s\n"   /* relu*/         \
+  "fmin v8.4s, v8.4s, v3.4s\n"     /* relu6*/        \
+  "fmin v9.4s, v9.4s, v3.4s\n"     /* relu6*/        \
+  "fmin v10.4s, v10.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v11.4s, v11.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v12.4s, v12.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v13.4s, v13.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v14.4s, v14.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v15.4s, v15.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v16.4s, v16.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v17.4s, v17.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v18.4s, v18.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v19.4s, v19.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v20.4s, v20.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v21.4s, v21.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v22.4s, v22.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v23.4s, v23.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v24.4s, v24.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v25.4s, v25.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v26.4s, v26.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v27.4s, v27.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v28.4s, v28.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v29.4s, v29.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v30.4s, v30.4s, v3.4s\n"   /* relu6*/        \
+  "fmin v31.4s, v31.4s, v3.4s\n"   /* relu6*/        \
+  "b 12f \n"                       /* relu end */
+#define GEMM_SDOT_LEAKY_RELU                         \
+  "14: \n"                                           \
+  "movi v2.4s, #0\n"               /* for leakyrelu*/    \
+  "ld1 {v3.4s}, [%[alpha]]\n"      /* leakyrelu alpha */ \
+  "fcmge v4.4s, v8.4s, v2.4s \n"   /* vcgeq_f32 */   \
+  "fmul v5.4s, v8.4s, v3.4s \n"    /* vmulq_f32 */   \
+  "fcmge v6.4s, v9.4s, v2.4s \n"   /* vcgeq_f32 */   \
+  "fmul v7.4s, v9.4s, v3.4s \n"    /* vmulq_f32 */   \
+  "bif v8.16b, v5.16b, v4.16b \n"  /* choose*/       \
+  "bif v9.16b, v7.16b, v6.16b \n"  /* choose*/       \
+  "fcmge v4.4s, v10.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v10.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v11.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v11.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v10.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v11.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v12.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v12.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v13.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v13.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v12.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v13.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v14.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v14.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v15.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v15.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v14.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v15.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v16.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v16.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v17.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v17.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v16.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v17.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v18.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v18.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v19.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v19.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v18.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v19.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v20.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v20.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v21.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v21.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v20.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v21.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v22.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v22.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v23.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v23.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v22.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v23.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v24.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v24.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v25.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v25.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v24.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v25.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v26.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v26.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v27.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v27.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v26.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v27.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v28.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v28.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v29.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v29.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v28.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v29.16b, v7.16b, v6.16b \n" /* choose*/       \
+  "fcmge v4.4s, v30.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v5.4s, v30.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "fcmge v6.4s, v31.4s, v2.4s \n"  /* vcgeq_f32 */   \
+  "fmul v7.4s, v31.4s, v3.4s \n"   /* vmulq_f32 */   \
+  "bif v30.16b, v5.16b, v4.16b \n" /* choose*/       \
+  "bif v31.16b, v7.16b, v6.16b \n" /* choose*/       \
   "12: \n"
 #define GEMM_SDOT_CVT_INT32_TO_FP32                  \
@@ -1206,6 +1451,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
 #define GEMM_SDOT_FP32_OUT                           \
   GEMM_SDOT_CVT_INT32_TO_FP32                        \
   GEMM_SDOT_RELU                                     \
+  GEMM_SDOT_RELU6                                    \
+  GEMM_SDOT_LEAKY_RELU                               \
   "st1 {v8.4s, v9.4s, v10.4s},[%[c_ptr0]], #48\n"    /* store r0 */ \
   "st1 {v11.4s, v12.4s, v13.4s},[%[c_ptr1]], #48\n"  /* store r1 */ \
   "st1 {v14.4s, v15.4s, v16.4s},[%[c_ptr2]], #48\n"  /* store r2 */ \

@@ -1218,6 +1465,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
 #define GEMM_SDOT_INT8_OUT                           \
   GEMM_SDOT_CVT_INT32_TO_FP32                        \
   GEMM_SDOT_RELU                                     \
+  GEMM_SDOT_RELU6                                    \
+  GEMM_SDOT_LEAKY_RELU                               \
   "ld1 {v6.4s}, [%[vmax]]\n"     /* v8 = -127.f */   \
   /* data >= -127 */                                 \
   "fcmge v0.4s, v8.4s, v6.4s\n"                      \
@@ -1371,7 +1620,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
                                   float32_t*& c_ptr6,  // NOLINT
                                   float32_t*& c_ptr7,  // NOLINT
                                   const float32_t* scale,
-                                  bool is_relu,
+                                  const float32_t* alpha,
+                                  int is_relu,
                                   int k,
                                   int tail) {
   // clang-format off

@@ -1389,7 +1639,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
         [c_ptr5] "+r"(c_ptr5),
         [c_ptr6] "+r"(c_ptr6),
         [c_ptr7] "+r"(c_ptr7)
-      : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu)
+      : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu),
+        [alpha] "r"(alpha)
       : "cc", "memory", "v0", "v1", "v2",
         "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
         "v11", "v12", "v13", "v14", "v15", "v16", "v17",
@@ -1410,7 +1661,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
                                   int8_t*& c_ptr6,  // NOLINT
                                   int8_t*& c_ptr7,  // NOLINT
                                   const float32_t* scale,
-                                  bool is_relu,
+                                  const float32_t* alpha,
+                                  int is_relu,
                                   int k,
                                   int tail) {
   // clang-format off

@@ -1428,7 +1680,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
         [c_ptr5] "+r"(c_ptr5),
         [c_ptr6] "+r"(c_ptr6),
         [c_ptr7] "+r"(c_ptr7)
-      : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu), [vmax] "r"(vmax)
+      : [bias_ptr] "r"(bias), [scale] "r"(scale), [relu] "r"(is_relu), [vmax] "r"(vmax),
+        [alpha] "r"(alpha)
       : "cc", "memory", "v0", "v1", "v2", "v3",
         "v4", "v5", "v6", "v7", "v8", "v9", "v10",
         "v11", "v12", "v13", "v14", "v15", "v16", "v17",
@@ -1654,6 +1907,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   /* do relu */                                      \
   "cmp %[is_relu], #0\n"     /* skip relu */         \
   "beq 9f\n"                 /* skip relu */         \
+  "cmp %[is_relu], #1\n"     /* check if has relu6 */ \
+  "beq 10f\n"                /* skip relu */         \
   "vmov.i32 q15, #0\n"       /* for relu */          \
   "vmax.f32 q8, q8, q15\n"   /* relu */              \
   "vmax.f32 q9, q9, q15\n"   /* relu */              \
@@ -1663,12 +1918,69 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
   "vmax.f32 q3,q3, q15\n"    /* relu */              \
   "vmax.f32 q4,q4, q15\n"    /* relu */              \
   "vmax.f32 q5,q5, q15\n"    /* relu */              \
-  "9: \n"
+  "b 9f\n"
+#define GEMM_INT8_RELU6                              \
+  /* do relu6 */                                     \
+  "10: \n"                                           \
+  "cmp %[is_relu], #2\n"     /* check if has relu6 */ \
+  "beq 11f\n"                /* skip relu */         \
+  "vmov.i32 q15, #0\n"       /* for relu */          \
+  "vmax.f32 q8, q8, q15\n"   /* relu */              \
+  "vmax.f32 q9, q9, q15\n"   /* relu */              \
+  "vmax.f32 q0,q0, q15\n"    /* relu */              \
+  "vmax.f32 q1,q1, q15\n"    /* relu */              \
+  "vld1.f32 {d28-d29}, [%[alpha]] @ load relu6 alpha\n" \
+  "vmax.f32 q2,q2, q15\n"    /* relu */              \
+  "vmax.f32 q3,q3, q15\n"    /* relu */              \
+  "vmax.f32 q4,q4, q15\n"    /* relu */              \
+  "vmax.f32 q5,q5, q15\n"    /* relu */              \
+  "vmin.f32 q8, q8, q14\n"   /* relu6 */             \
+  "vmin.f32 q9, q9, q14\n"   /* relu6 */             \
+  "vmin.f32 q0,q0, q14\n"    /* relu6 */             \
+  "vmin.f32 q1,q1, q14\n"    /* relu6 */             \
+  "vmin.f32 q2,q2, q14\n"    /* relu6 */             \
+  "vmin.f32 q3,q3, q14\n"    /* relu6 */             \
+  "vmin.f32 q4,q4, q14\n"    /* relu6 */             \
+  "vmin.f32 q5,q5, q14\n"    /* relu6 */             \
+  "b 9f\n"
+#define GEMM_INT8_LEAKY_RELU                         \
+  /* do relu6 */                                     \
+  "11: \n"                                           \
+  "vmov.i32 q15, #0\n"       /* for relu */          \
+  "vld1.f32 {d28-d29}, [%[alpha]] @ load relu6 alpha\n" \
+  "vcge.f32 q6, q8, q15 @ vcgeq_u32 \n"              \
+  "vmul.f32 q7, q8, q14 @ vmulq_f32 \n"              \
+  "vcge.f32 q10, q9, q15 @ vcgeq_u32 \n"             \
+  "vmul.f32 q11, q9, q14 @ vmulq_f32 \n"             \
+  "vcge.f32 q12, q0, q15 @ vcgeq_u32 \n"             \
+  "vmul.f32 q13, q0, q14 @ vmulq_f32 \n"             \
+  "vbif q8, q7, q6 @ choose \n"                      \
+  "vbif q9, q11, q10 @ choose \n"                    \
+  "vbif q0, q13, q12 @ choose \n"                    \
+  "vcge.f32 q6, q1, q15 @ vcgeq_u32 \n"              \
+  "vmul.f32 q7, q1, q14 @ vmulq_f32 \n"              \
+  "vcge.f32 q10, q2, q15 @ vcgeq_u32 \n"             \
+  "vmul.f32 q11, q2, q14 @ vmulq_f32 \n"             \
+  "vcge.f32 q12, q3, q15 @ vcgeq_u32 \n"             \
+  "vmul.f32 q13, q3, q14 @ vmulq_f32 \n"             \
+  "vbif q1, q7, q6 @ choose \n"                      \
+  "vbif q2, q11, q10 @ choose \n"                    \
+  "vbif q3, q13, q12 @ choose \n"                    \
+  "vcge.f32 q6, q4, q15 @ vcgeq_u32 \n"              \
+  "vmul.f32 q7, q4, q14 @ vmulq_f32 \n"              \
+  "vcge.f32 q10, q5, q15 @ vcgeq_u32 \n"             \
+  "vmul.f32 q11, q5, q14 @ vmulq_f32 \n"             \
+  "vbif q4, q7, q6 @ choose \n"                      \
+  "vbif q5, q11, q10 @ choose \n"                    \
+  "9: \n"
 #define GEMM_INT8_FP32_OUT                           \
   GEMM_INT8_TRANS_INT32_TO_FP32                      \
   GEMM_INT8_RELU                                     \
+  GEMM_INT8_RELU6                                    \
+  GEMM_INT8_LEAKY_RELU                               \
   "vst1.32 {d16-d19}, [%[c_ptr0]]!\n"   /* write r0, float32x4 x2 */ \
   "vst1.32 {d0-d3}, [%[c_ptr1]]!\n"     /* write r1, float32x4 x2 */ \
   "vst1.32 {d4-d7}, [%[c_ptr2]]!\n"     /* write r2, float32x4 x2 */ \
@@ -1678,6 +1990,8 @@ inline void gemm_sdot_int8_kernel(const int8_t* a_ptr,
 #define GEMM_INT8_INT8_OUT                           \
   GEMM_INT8_TRANS_INT32_TO_FP32                      \
   GEMM_INT8_RELU                                     \
+  GEMM_INT8_RELU6                                    \
+  GEMM_INT8_LEAKY_RELU                               \
   "vmov.f32 q7, #-0.5\n"     /* neg offset */        \
   "vmov.f32 q10, #0.5\n"     /* pos offset */        \
   "vmov.f32 q11, #0.5\n"     /* pos offset */        \
@@ -1765,7 +2079,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              float32_t*& c_ptr2,  // NOLINT
                              float32_t*& c_ptr3,  // NOLINT
                              const float32_t* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem) {
   asm volatile(GEMM_INT8_KERNEL GEMM_INT8_FP32_OUT

@@ -1778,6 +2093,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
         [k] "+r"(k)
       : [is_relu] "r"(is_relu),
         [bias] "r"(bias),
+        [alpha] "r"(alpha),
         [rem] "r"(rem),
         [scale] "r"(scale)
       : "q0",
@@ -1810,7 +2126,8 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
                              int8_t*& c_ptr2,  // NOLINT
                              int8_t*& c_ptr3,  // NOLINT
                              const float32_t* scale,
-                             bool is_relu,
+                             const float32_t* alpha,
+                             int is_relu,
                              int k,
                              int rem) {
   float vmax[4] = {-127.0, -127.0, -127.0, -127.0};

@@ -1823,6 +2140,7 @@ inline void gemm_int8_kernel(const int8_t* a_ptr,
         [c_ptr3] "+r"(c_ptr3),
         [k] "+r"(k)
       : [is_relu] "r"(is_relu),
+        [alpha] "r"(alpha),
         [bias] "r"(bias),
         [rem] "r"(rem),
         [vmax] "r"(vmax),
@@ -1859,9 +2177,10 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
...
@@ -1859,9 +2177,10 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
int
N
,
int
N
,
int
K
,
int
K
,
bool
is_bias
,
bool
is_bias
,
bool
is_relu
,
int
flag_act
,
bool
is_transB
,
bool
is_transB
,
const
float
*
scale
,
const
float
*
scale
,
const
float
*
alpha
,
ARMContext
*
ctx
)
{
ARMContext
*
ctx
)
{
const
int
KUP
=
ROUNDUP
(
K
,
KBLOCK_INT8
);
const
int
KUP
=
ROUNDUP
(
K
,
KBLOCK_INT8
);
size_t
llc_size
=
ctx
->
llc_size
()
/
4
;
size_t
llc_size
=
ctx
->
llc_size
()
/
4
;
...
@@ -1969,7 +2288,8 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
...
@@ -1969,7 +2288,8 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
c_ptr2
,
c_ptr2
,
c_ptr3
,
c_ptr3
,
scale_local
,
scale_local
,
is_relu
,
alpha
,
flag_act
,
k
,
k
,
k_rem
);
k_rem
);
if
(
flag_rem
&&
(
xb
==
bblocks
-
1
))
{
if
(
flag_rem
&&
(
xb
==
bblocks
-
1
))
{
...
@@ -3090,9 +3410,10 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed,
...
@@ -3090,9 +3410,10 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed,
int
N
,
int
N
,
int
K
,
int
K
,
bool
is_bias
,
bool
is_bias
,
bool
is_relu
,
int
is_relu
,
bool
is_transB
,
bool
is_transB
,
const
float
*
scale
,
const
float
*
scale
,
const
float
*
alpha
,
ARMContext
*
ctx
)
{
ARMContext
*
ctx
)
{
size_t
llc_size
=
ctx
->
llc_size
()
/
4
;
size_t
llc_size
=
ctx
->
llc_size
()
/
4
;
auto
workspace
=
ctx
->
workspace_data
<
int8_t
>
();
auto
workspace
=
ctx
->
workspace_data
<
int8_t
>
();
...
@@ -3250,6 +3571,7 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed,
...
@@ -3250,6 +3571,7 @@ void gemm_prepack_sdot_int8(const int8_t* A_packed,
c_ptr6
,
c_ptr6
,
c_ptr7
,
c_ptr7
,
scale_local
,
scale_local
,
alpha
,
is_relu
,
is_relu
,
k
,
k
,
tail
);
tail
);
...
@@ -3871,21 +4193,43 @@ void gemm_prepack_int8(const int8_t* A_packed,
...
@@ -3871,21 +4193,43 @@ void gemm_prepack_int8(const int8_t* A_packed,
int
N
,
int
N
,
int
K
,
int
K
,
bool
is_bias
,
bool
is_bias
,
bool
is_relu
,
bool
is_transB
,
bool
is_transB
,
const
float
*
scale
,
const
float
*
scale
,
const
operators
::
ActivationParam
act_param
,
ARMContext
*
ctx
)
{
ARMContext
*
ctx
)
{
auto
act_type
=
act_param
.
active_type
;
float
alpha
[
4
]
=
{
0.
f
,
0.
f
,
0.
f
,
0.
f
};
int
flag_act
=
0x00
;
// relu: 1, relu6: 2, leakey: 3
if
(
act_param
.
has_active
)
{
if
(
act_type
==
lite_api
::
ActivationType
::
kRelu
)
{
flag_act
=
0x01
;
}
else
if
(
act_type
==
lite_api
::
ActivationType
::
kRelu6
)
{
flag_act
=
0x02
;
float
local_alpha
=
act_param
.
Relu_clipped_coef
;
alpha
[
0
]
=
local_alpha
;
alpha
[
1
]
=
local_alpha
;
alpha
[
2
]
=
local_alpha
;
alpha
[
3
]
=
local_alpha
;
}
else
if
(
act_type
==
lite_api
::
ActivationType
::
kLeakyRelu
)
{
flag_act
=
0x03
;
float
local_alpha
=
act_param
.
Leaky_relu_alpha
;
alpha
[
0
]
=
local_alpha
;
alpha
[
1
]
=
local_alpha
;
alpha
[
2
]
=
local_alpha
;
alpha
[
3
]
=
local_alpha
;
}
}
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
#if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
if
(
ctx
->
has_dot
())
{
if
(
ctx
->
has_dot
())
{
gemm_prepack_sdot_int8
<
float32_t
>
(
gemm_prepack_sdot_int8
<
float32_t
>
(
A_packed
,
B
,
bias
,
C
,
M
,
N
,
K
,
is_bias
,
is_relu
,
is_transB
,
scale
,
ctx
);
A_packed
,
B
,
bias
,
C
,
M
,
N
,
K
,
is_bias
,
flag_act
,
is_transB
,
scale
,
alpha
,
ctx
);
}
else
{
}
else
{
gemm_prepack_oth_int8
<
float32_t
>
(
gemm_prepack_oth_int8
<
float32_t
>
(
A_packed
,
B
,
bias
,
C
,
M
,
N
,
K
,
is_bias
,
is_relu
,
is_transB
,
scale
,
ctx
);
A_packed
,
B
,
bias
,
C
,
M
,
N
,
K
,
is_bias
,
flag_act
,
is_transB
,
scale
,
alpha
,
ctx
);
}
}
#else
#else
gemm_prepack_oth_int8
<
float32_t
>
(
gemm_prepack_oth_int8
<
float32_t
>
(
A_packed
,
B
,
bias
,
C
,
M
,
N
,
K
,
is_bias
,
is_relu
,
is_transB
,
scale
,
ctx
);
A_packed
,
B
,
bias
,
C
,
M
,
N
,
K
,
is_bias
,
flag_act
,
is_transB
,
scale
,
alpha
,
ctx
);
#endif
#endif
}
}
@@ -3898,21 +4242,43 @@ void gemm_prepack_int8(const int8_t* A_packed,
                        int N,
                        int K,
                        bool is_bias,
-                       bool is_relu,
                        bool is_transB,
                        const float* scale,
+                       const operators::ActivationParam act_param,
                        ARMContext* ctx) {
+  auto act_type = act_param.active_type;
+  float alpha[4] = {0.f, 0.f, 0.f, 0.f};
+  int flag_act = 0x00;  // relu: 1, relu6: 2, leakey: 3
+  if (act_param.has_active) {
+    if (act_type == lite_api::ActivationType::kRelu) {
+      flag_act = 0x01;
+    } else if (act_type == lite_api::ActivationType::kRelu6) {
+      flag_act = 0x02;
+      float local_alpha = act_param.Relu_clipped_coef;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    } else if (act_type == lite_api::ActivationType::kLeakyRelu) {
+      flag_act = 0x03;
+      float local_alpha = act_param.Leaky_relu_alpha;
+      alpha[0] = local_alpha;
+      alpha[1] = local_alpha;
+      alpha[2] = local_alpha;
+      alpha[3] = local_alpha;
+    }
+  }
 #if defined(__aarch64__) && defined(WITH_ARM_DOTPROD)
   if (ctx->has_dot()) {
     gemm_prepack_sdot_int8<int8_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
   } else {
     gemm_prepack_oth_int8<int8_t>(
-        A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+        A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
   }
 #else
   gemm_prepack_oth_int8<int8_t>(
-      A_packed, B, bias, C, M, N, K, is_bias, is_relu, is_transB, scale, ctx);
+      A_packed, B, bias, C, M, N, K, is_bias, flag_act, is_transB, scale, alpha, ctx);
 #endif
 }
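With the new `gemm_prepack_int8` signature, callers pass an `operators::ActivationParam` instead of the old `is_relu` flag, and the wrapper translates it into the `flag_act`/`alpha` pair consumed by the kernels. A hedged usage sketch for a relu6-fused call; the buffers and sizes (`A_packed`, `B`, `bias`, `C`, `M`, `N`, `K`, `scale`, `ctx`) are assumed to be prepared the same way as before this commit and are illustrative placeholders here:

// Illustrative only: request a fused relu6 epilogue from the prepacked int8 GEMM.
operators::ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = lite_api::ActivationType::kRelu6;
act_param.Relu_clipped_coef = 6.f;  // broadcast into alpha[0..3] by gemm_prepack_int8

gemm_prepack_int8(A_packed,            // prepacked int8 A (M x K)
                  B,                   // int8 B (K x N)
                  bias,                // bias pointer or nullptr
                  C,                   // float32 output (M x N)
                  M, N, K,
                  /*is_bias=*/bias != nullptr,
                  /*is_transB=*/false,
                  scale,               // per-channel dequant scales
                  act_param,
                  ctx);                // ARMContext*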