PaddlePaddle / Paddle-Lite

Commit 34f7b509
Authored May 12, 2020 by MyPandaShaoxiang

    feat: add winograd int8 kernel

    test=develop

Parent: b0c58df2

Showing 13 changed files with 2045 additions and 20 deletions
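Background (editor's sketch, not part of the commit): the new kernel implements Winograd F(2,3) convolution on 8-channel (c8) blocks in the int8/int16 domain. With the transform matrices that appear below — B^T and A^T as comments in conv3x3_winograd_int8.cc, and G as the coeff table in weight_trans_c8_4x4_int8 — a 4x4 input tile d and a 3x3 filter g produce a 2x2 output tile:

$$ Y \;=\; A^{\mathsf T}\!\left[\,(G\,g\,G^{\mathsf T}) \odot (B^{\mathsf T} d\,B)\,\right] A $$

where $\odot$ is the elementwise product, computed here as 16 batched int16 GEMMs (one per transformed-domain position, see the `for (int gi = 0; gi < 16; ++gi)` loop).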
lite/backends/arm/math/CMakeLists.txt                  +1    -0
lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc     +2    -2
lite/backends/arm/math/conv3x3_winograd_int8.cc        +570  -0
lite/backends/arm/math/conv_block_utils.h              +4    -3
lite/backends/arm/math/conv_impl.h                     +29   -0
lite/backends/arm/math/packed_sgemm_c4.cc              +906  -0
lite/backends/arm/math/packed_sgemm_c4.h               +7    -0
lite/kernels/arm/conv_compute.cc                       +6    -7
lite/kernels/arm/conv_winograd.cc                      +180  -1
lite/kernels/arm/conv_winograd.h                       +21   -1
lite/tests/math/sgemm_c4_compute_test.cc               +186  -6
lite/tests/utils/naive_math_impl.h                     +120  -0
lite/tests/utils/tensor_utils.h                        +13   -0
lite/backends/arm/math/CMakeLists.txt

@@ -83,6 +83,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
             conv5x5s2_depthwise_int8.cc
             conv5x5s2_depthwise_fp32.cc
             conv3x3_winograd_fp32_c4.cc
+            conv3x3_winograd_int8.cc
             conv_winograd_3x3.cc
             conv_impl.cc
             softmax.cc
lite/backends/arm/math/conv3x3_winograd_fp32_c4.cc

@@ -1245,7 +1245,7 @@ void weight_trans_c4_8x8(
   for (int i = 0; i < ch_out * ch_in * 64; ++i) {
     int new_c = i % 64;
     int new_oc = i / ch_in / 64 / 4;
-    int new_ic = i / 64 % (ch_in * 4) % ch_in;
+    int new_ic = i / 64 % ch_in;
     int new_inner = i / ch_in / 64 % 4;
     int dest_ind =
         new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
@@ -1302,7 +1302,7 @@ void weight_trans_c4_4x4(
   for (int i = 0; i < ch_out * ch_in * 16; ++i) {
     int new_c = i % 16;
     int new_oc = i / ch_in / 16 / 4;
-    int new_ic = i / 16 % (ch_in * 4) % ch_in;
+    int new_ic = i / 16 % ch_in;
     int new_inner = i / ch_in / 16 % 4;
     int dest_ind =
         new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
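Reading note (editor's check, not part of the diff): since ch_in * 4 is a multiple of ch_in, (x % (ch_in * 4)) % ch_in == x % ch_in for any non-negative x, so these two hunks only simplify the index arithmetic without changing results. A quick self-contained verification:

    // Exhaustive check that the old and new index expressions agree.
    #include <cassert>
    int main() {
      for (int ch_in = 1; ch_in <= 32; ++ch_in)
        for (int x = 0; x < 4096; ++x)
          assert(x % (ch_in * 4) % ch_in == x % ch_in);
      return 0;
    }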
lite/backends/arm/math/conv3x3_winograd_int8.cc (new file, mode 100644)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/arm/math/conv_block_utils.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm_c4.h"
#ifdef ARM_WITH_OMP
#include <omp.h>
#endif
#include <arm_neon.h>

namespace paddle {
namespace lite {
namespace arm {
namespace math {

void input_trans_c8_4x4_int8(const int8_t* src,
                             int src_stride,
                             int src_h_stride,
                             int16_t* dest,
                             int dest_stride,
                             int dest_h_stride);
void output_trans_c8_post_2x4_int8(const int32_t* src,
                                   int src_stride,
                                   int src_h_stride,
                                   int32_t* dest,
                                   int dest_stride,
                                   int dest_h_stride);
void weight_trans_c8_4x4_int8(
    int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);

// F(2,3)
template <typename Dtype>
void conv_compute_2x2_3x3_int8(const int8_t* input,
                               Dtype* output,
                               int num,
                               int chout,
                               int hout,
                               int wout,
                               int chin,
                               int hin,
                               int win,
                               const int16_t* weight,
                               const float* bias,
                               float* scale,
                               const operators::ConvParam& param,
                               ARMContext* ctx) {
  auto act_param = param.activation_param;
  const int pad_h0 = (*param.paddings)[0];
  const int pad_h1 = (*param.paddings)[1];
  const int pad_w0 = (*param.paddings)[2];
  const int pad_w1 = (*param.paddings)[3];
  int8_t* tmp_work_space =
      ctx->workspace_data<int8_t>() + ctx->llc_size() / sizeof(int8_t);
  int in_n_stride = chin * hin * win;
  int out_n_stride = chout * hout * wout;
  int ic_stride = win * hin;
  int oc_stride = wout * hout;
  int ic_8 = (chin + 7) / 8;
  int oc_8 = (chout + 7) / 8;
  int tile_w = (wout + 1) / 2;
  int tile_h = (hout + 1) / 2;
  int size_tile = tile_h * tile_w;
  int w_pad = win + pad_w0 + pad_w1;
  int h_pad = hin + pad_h0 + pad_h1;
  const int zero_len = (w_pad + 3) / 4 * 4;
  Dtype zero_ptr[zero_len];  // NOLINT
  memset(zero_ptr, 0, zero_len * sizeof(Dtype));
  int8_t* input_c8 = tmp_work_space;
  int new_h_stride = w_pad * 8;
  int new_c_stride = new_h_stride * h_pad;
  int ic_8_stride = w_pad * h_pad * 8;
  int oc_8_stride = wout * hout * 8;
  int tile_block = 8;
  int block_count = (size_tile + tile_block - 1) / tile_block;
  int threads = ctx->threads();
  int16_t* g_tmp_data =
      reinterpret_cast<int16_t*>(tmp_work_space + ic_8 * ic_8_stride +
                                 oc_8 * oc_8_stride * sizeof(int32_t));
  int tmp_input_thread_stride = tile_block * ic_8 * 128;
  int tmp_output_thread_stride = tile_block * oc_8 * 128;
  int tmp_data_thread_stride_size =
      tmp_input_thread_stride * sizeof(int16_t) +
      tmp_output_thread_stride * sizeof(int32_t);
  memset(g_tmp_data, 0, tmp_data_thread_stride_size);
  int8_t* g_trans_remain_tmp_data = reinterpret_cast<int8_t*>(
      g_tmp_data +
      threads * (tmp_input_thread_stride +
                 tmp_output_thread_stride * sizeof(int32_t) / sizeof(int16_t)));
  int32_t* g_trans_tmp_data =
      reinterpret_cast<int32_t*>(g_trans_remain_tmp_data + threads * 128);
  // begin compute
  for (int ni = 0; ni < num; ++ni) {
    // trans input to c4
    for (int i = 0; i < ic_8; ++i) {
      prepack_input_nxwc8_int8_dw(input + ni * in_n_stride,
                                  input_c8 + i * new_c_stride,
                                  i * 8,
                                  -pad_h0,
                                  hin + pad_h1,
                                  -pad_w0,
                                  win + pad_w1,
                                  chin,
                                  win,
                                  hin);
    }
    int32_t* output_c8 =
        reinterpret_cast<int32_t*>(input_c8 + ic_8 * ic_8_stride);
    Dtype* output_ptr = output + ni * out_n_stride;
    const int16_t* weight_ptr = weight;
#pragma omp parallel for num_threads(threads)
    for (int tbi = 0; tbi < block_count; ++tbi) {
#ifdef ARM_WITH_OMP
      int16_t* tmp_data =
          g_tmp_data +
          omp_get_thread_num() * tmp_data_thread_stride_size / sizeof(int16_t);
      int32_t* trans_tmp_data = g_trans_tmp_data + omp_get_thread_num() * 32;
      int8_t* trans_remain_tmp_data =
          g_trans_remain_tmp_data + omp_get_thread_num() * 128;
#else
      int16_t* tmp_data = g_tmp_data;
      int32_t* trans_tmp_data = g_trans_tmp_data;
      int8_t* trans_remain_tmp_data = g_trans_remain_tmp_data;
#endif
      int tile_index = tbi * tile_block;
      int tile_remain = size_tile - tile_index;
      int tile_count = tile_remain > tile_block ? tile_block : tile_remain;
      auto t0 = std::chrono::steady_clock::now();
      // input trans
      int c_gi_stride = tile_count * oc_8 * 8;
      int b_gi_stride = tile_count * ic_8 * 8;
      //*
      for (int ti = 0; ti < tile_count; ++ti) {
        int index = tile_index + ti;
        int tw_index = index % tile_w;
        int th_index = index / tile_w;
        int src_x = tw_index + tw_index;
        int src_y = th_index + th_index;
        int ex = src_x + 4 > w_pad ? w_pad - src_x : 4;
        int ey = src_y + 4 > h_pad ? h_pad - src_y : 4;
        int16_t* dst_ptr = tmp_data + ti * 8;
        const int8_t* src_ptr = input_c8 + (src_y * w_pad + src_x) * 8;
        if (ex == 4 && ey == 4) {
          // trans input
          for (int ci = 0; ci < ic_8; ++ci) {
            const int8_t* src_ci = src_ptr + ci * ic_8_stride;
            int16_t* dst_ci = dst_ptr + ci * tile_count * 8;
            input_trans_c8_4x4_int8(
                src_ci, 8, w_pad * 8, dst_ci, b_gi_stride, b_gi_stride * 4);
          }
        } else {
          // trans remain input
          int x_size = ex;
          for (int ci = 0; ci < ic_8; ++ci) {
            const int8_t* src_ci = src_ptr + ci * ic_8_stride;
            // pad
            memset(trans_remain_tmp_data, 0, 128 * sizeof(int8_t));
            if (x_size > 0) {
              for (int yi = 0; yi < ey; ++yi) {
                int8_t* dst_yi = trans_remain_tmp_data + yi * 32;
                const int8_t* src_yi = src_ci + w_pad * yi * 8;
                memcpy(dst_yi, src_yi, x_size * sizeof(int8_t) * 8);
              }
            }
            // trans
            int16_t* dst_ci = dst_ptr + ci * tile_count * 8;
            input_trans_c8_4x4_int8(trans_remain_tmp_data,
                                    8,
                                    32,
                                    dst_ci,
                                    b_gi_stride,
                                    b_gi_stride * 4);
          }
        }  // for ci_4
      }
      //*/
      // input trans end
      // *begin compute dot
      // *
      //*
      int32_t* dst_temp_data =
          reinterpret_cast<int32_t*>(tmp_data + tmp_input_thread_stride);
      int16_t* b_ptr = tmp_data;
      int w_gi_stride = ic_8 * oc_8 * 64;
      for (int gi = 0; gi < 16; ++gi) {
        int32_t* origin_C = dst_temp_data + gi * c_gi_stride;
        int16_t* origin_B = b_ptr + gi * b_gi_stride;
        const int16_t* origin_A = weight + gi * w_gi_stride;
        sgemm_prepack_c8_int16_small(
            oc_8 * 8, tile_count, ic_8 * 8, origin_A, origin_B, origin_C, ctx);
      }
      //*/
      //*
      // output trans
      for (int ti = 0; ti < tile_count; ++ti) {
        int index = tile_index + ti;
        int tw_index = index % tile_w;
        int th_index = index / tile_w;
        int dst_x = tw_index * 2;
        int dst_y = th_index * 2;
        int ex = dst_x + 2 > wout ? wout - dst_x : 2;
        int ey = dst_y + 2 > hout ? hout - dst_y : 2;
        int32_t* src_ptr = dst_temp_data + ti * 8;
        int32_t* trans_remain_tmp_i32_data =
            reinterpret_cast<int32_t*>(trans_remain_tmp_data);
        int32_t* dst_ptr = output_c8 + (dst_y * wout + dst_x) * 8;
        if (ex == 2 && ey == 2) {
          // trans output
          for (int ci = 0; ci < oc_8; ++ci) {
            int cur_ind = ci * 8;
            int32_t* src_ci = src_ptr + ci * tile_count * 8;
            int32_t* dst_ci = dst_ptr + ci * oc_8_stride;
            output_trans_c8_post_2x4_int8(
                src_ci, c_gi_stride, c_gi_stride * 4, dst_ci, 8, wout * 8);
          }
        } else {
          for (int ci = 0; ci < oc_8; ++ci) {
            int cur_ind = ci * 8;
            // trans output
            int32_t* src_ci = src_ptr + ci * tile_count * 8;
            output_trans_c8_post_2x4_int8(src_ci,
                                          c_gi_stride,
                                          c_gi_stride * 4,
                                          trans_remain_tmp_i32_data,
                                          8,
                                          16);
            // copy to dest
            int32_t* dst_ci = dst_ptr + ci * oc_8_stride;
            for (int i = 0; i < ey; ++i) {
              memcpy(dst_ci + i * wout * 8,
                     trans_remain_tmp_i32_data + i * 16,
                     ex * sizeof(int32_t) * 8);
            }
          }
        }
      }
      //*/
    }  // for block_count
    for (int ci = 0; ci < oc_8; ++ci) {
      write_int32_nchwc8_to_nchw(output_c8 + ci * oc_8_stride,
                                 output_ptr,
                                 ci * 8,
                                 ci * 8 + 8,
                                 0,
                                 hout,
                                 0,
                                 wout,
                                 chout,
                                 hout,
                                 wout,
                                 param.fuse_relu,
                                 bias + ci * 8,
                                 param.bias,
                                 zero_ptr,
                                 scale + ci * 8);
    }
  }  // for num
}  // conv compute

template void conv_compute_2x2_3x3_int8<int8_t>(
    const int8_t* input,
    int8_t* output,
    int num,
    int chout,
    int hout,
    int wout,
    int chin,
    int hin,
    int win,
    const int16_t* weight,
    const float* bias,
    float* scale,
    const operators::ConvParam& param,
    ARMContext* ctx);

template void conv_compute_2x2_3x3_int8<float>(
    const int8_t* input,
    float* output,
    int num,
    int chout,
    int hout,
    int wout,
    int chin,
    int hin,
    int win,
    const int16_t* weight,
    const float* bias,
    float* scale,
    const operators::ConvParam& param,
    ARMContext* ctx);
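Editor's aside: a small self-contained sketch of the tile bookkeeping used by conv_compute_2x2_3x3_int8 above — F(2,3) emits 2x2 outputs per 4x4 input tile, and tiles are batched into blocks of tile_block for the GEMM stage. The concrete sizes here are hypothetical:

    #include <cstdio>
    int main() {
      int wout = 14, hout = 14;           // hypothetical output size
      int tile_w = (wout + 1) / 2;        // 2x2 outputs per tile -> 7
      int tile_h = (hout + 1) / 2;
      int size_tile = tile_h * tile_w;    // 49 tiles
      int tile_block = 8;                 // tiles per GEMM batch, as above
      int block_count = (size_tile + tile_block - 1) / tile_block;  // 7
      std::printf("%d tiles -> %d blocks of %d\n",
                  size_tile, block_count, tile_block);
      return 0;
    }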
// BT=[1, 0, -1, 0,
//     0, 1,  1, 0,
//     0, -1, 1, 0,
//     0, 1,  0, -1]
void input_trans_c8_4x4_int8(const int8_t* src,
                             int src_stride,
                             int src_h_stride,
                             int16_t* dest,
                             int dest_stride,
                             int dest_h_stride) {
  int8x8_t src00 = vld1_s8(src);
  int8x8_t src01 = vld1_s8(src + src_stride);
  int8x8_t src02 = vld1_s8(src + src_stride + src_stride);
  int8x8_t src03 = vld1_s8(src + src_stride + src_stride + src_stride);
  src += src_h_stride;
  int8x8_t src10 = vld1_s8(src);
  int8x8_t src11 = vld1_s8(src + src_stride);
  int8x8_t src12 = vld1_s8(src + src_stride + src_stride);
  int8x8_t src13 = vld1_s8(src + src_stride + src_stride + src_stride);
  src += src_h_stride;
  int8x8_t src20 = vld1_s8(src);
  int8x8_t src21 = vld1_s8(src + src_stride);
  int8x8_t src22 = vld1_s8(src + src_stride + src_stride);
  int8x8_t src23 = vld1_s8(src + src_stride + src_stride + src_stride);
  src += src_h_stride;
  int8x8_t src30 = vld1_s8(src);
  int8x8_t src31 = vld1_s8(src + src_stride);
  int8x8_t src32 = vld1_s8(src + src_stride + src_stride);
  int8x8_t src33 = vld1_s8(src + src_stride + src_stride + src_stride);

  int16x8_t dst00 = vsubl_s8(src00, src02);
  int16x8_t dst10 = vaddl_s8(src01, src02);
  int16x8_t dst20 = vsubl_s8(src02, src01);
  int16x8_t dst30 = vsubl_s8(src01, src03);
  int16x8_t dst01 = vsubl_s8(src10, src12);
  int16x8_t dst11 = vaddl_s8(src11, src12);
  int16x8_t dst21 = vsubl_s8(src12, src11);
  int16x8_t dst31 = vsubl_s8(src11, src13);
  int16x8_t dst02 = vsubl_s8(src20, src22);
  int16x8_t dst12 = vaddl_s8(src21, src22);
  int16x8_t dst22 = vsubl_s8(src22, src21);
  int16x8_t dst32 = vsubl_s8(src21, src23);
  int16x8_t dst03 = vsubl_s8(src30, src32);
  int16x8_t dst13 = vaddl_s8(src31, src32);
  int16x8_t dst23 = vsubl_s8(src32, src31);
  int16x8_t dst33 = vsubl_s8(src31, src33);

  int16x8_t dest00 = vsubq_s16(dst00, dst02);
  int16x8_t dest10 = vaddq_s16(dst01, dst02);
  int16x8_t dest20 = vsubq_s16(dst02, dst01);
  int16x8_t dest30 = vsubq_s16(dst01, dst03);
  int16x8_t dest01 = vsubq_s16(dst10, dst12);
  int16x8_t dest11 = vaddq_s16(dst11, dst12);
  int16x8_t dest21 = vsubq_s16(dst12, dst11);
  int16x8_t dest31 = vsubq_s16(dst11, dst13);
  int16x8_t dest02 = vsubq_s16(dst20, dst22);
  int16x8_t dest12 = vaddq_s16(dst21, dst22);
  int16x8_t dest22 = vsubq_s16(dst22, dst21);
  int16x8_t dest32 = vsubq_s16(dst21, dst23);
  int16x8_t dest03 = vsubq_s16(dst30, dst32);
  int16x8_t dest13 = vaddq_s16(dst31, dst32);
  int16x8_t dest23 = vsubq_s16(dst32, dst31);
  int16x8_t dest33 = vsubq_s16(dst31, dst33);

  vst1q_s16(dest, dest00);
  vst1q_s16(dest + dest_stride, dest10);
  vst1q_s16(dest + dest_stride + dest_stride, dest20);
  vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest30);
  dest += dest_h_stride;
  vst1q_s16(dest, dest01);
  vst1q_s16(dest + dest_stride, dest11);
  vst1q_s16(dest + dest_stride + dest_stride, dest21);
  vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest31);
  dest += dest_h_stride;
  vst1q_s16(dest, dest02);
  vst1q_s16(dest + dest_stride, dest12);
  vst1q_s16(dest + dest_stride + dest_stride, dest22);
  vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest32);
  dest += dest_h_stride;
  vst1q_s16(dest, dest03);
  vst1q_s16(dest + dest_stride, dest13);
  vst1q_s16(dest + dest_stride + dest_stride, dest23);
  vst1q_s16(dest + dest_stride + dest_stride + dest_stride, dest33);
}
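For readers of the NEON above: each int8x8_t lane carries one of the 8 packed channels, and per lane the function computes B^T d B with the BT in the comment (the same 1-D pattern is applied along both tile dimensions). A scalar reference for a single channel, under that assumption (the helper name is mine):

    #include <cstdint>
    // Scalar sketch of one lane of input_trans_c8_4x4_int8: out = BT * d * B.
    static void input_trans_4x4_ref(const int8_t d[4][4], int16_t out[4][4]) {
      int16_t t[4][4];
      for (int j = 0; j < 4; ++j) {   // first pass: t = BT * d
        t[0][j] = d[0][j] - d[2][j];
        t[1][j] = d[1][j] + d[2][j];
        t[2][j] = d[2][j] - d[1][j];
        t[3][j] = d[1][j] - d[3][j];
      }
      for (int i = 0; i < 4; ++i) {   // second pass along the other dimension
        out[i][0] = t[i][0] - t[i][2];
        out[i][1] = t[i][1] + t[i][2];
        out[i][2] = t[i][2] - t[i][1];
        out[i][3] = t[i][1] - t[i][3];
      }
    }

The widening vsubl_s8/vaddl_s8 in the NEON version match the int8-to-int16 promotion here: a sum of two int8 values always fits in int16.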
// AT=[1, 1,  1,  0,
//     0, 1, -1, -1]
void output_trans_c8_post_2x4_int8(const int32_t* src,
                                   int src_stride,
                                   int src_h_stride,
                                   int32_t* dest,
                                   int dest_stride,
                                   int dest_h_stride) {
  int32x4_t src400 = vld1q_s32(src);
  int32x4_t src800 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src401 = vld1q_s32(src);
  int32x4_t src801 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src402 = vld1q_s32(src);
  int32x4_t src802 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src403 = vld1q_s32(src);
  int32x4_t src803 = vld1q_s32(src + 4);
  src += src_h_stride - 3 * src_stride;
  int32x4_t src410 = vld1q_s32(src);
  int32x4_t src810 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src411 = vld1q_s32(src);
  int32x4_t src811 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src412 = vld1q_s32(src);
  int32x4_t src812 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src413 = vld1q_s32(src);
  int32x4_t src813 = vld1q_s32(src + 4);
  src += src_h_stride - 3 * src_stride;
  int32x4_t src420 = vld1q_s32(src);
  int32x4_t src820 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src421 = vld1q_s32(src);
  int32x4_t src821 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src422 = vld1q_s32(src);
  int32x4_t src822 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src423 = vld1q_s32(src);
  int32x4_t src823 = vld1q_s32(src + 4);
  src += src_h_stride - 3 * src_stride;
  int32x4_t src430 = vld1q_s32(src);
  int32x4_t src830 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src431 = vld1q_s32(src);
  int32x4_t src831 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src432 = vld1q_s32(src);
  int32x4_t src832 = vld1q_s32(src + 4);
  src += src_stride;
  int32x4_t src433 = vld1q_s32(src);
  int32x4_t src833 = vld1q_s32(src + 4);

  int32x4_t dst400 = vaddq_s32(vaddq_s32(src400, src401), src402);
  int32x4_t dst410 = vsubq_s32(vsubq_s32(src401, src402), src403);
  int32x4_t dst401 = vaddq_s32(vaddq_s32(src410, src411), src412);
  int32x4_t dst411 = vsubq_s32(vsubq_s32(src411, src412), src413);
  int32x4_t dst402 = vaddq_s32(vaddq_s32(src420, src421), src422);
  int32x4_t dst412 = vsubq_s32(vsubq_s32(src421, src422), src423);
  int32x4_t dst403 = vaddq_s32(vaddq_s32(src430, src431), src432);
  int32x4_t dst413 = vsubq_s32(vsubq_s32(src431, src432), src433);
  int32x4_t dst800 = vaddq_s32(vaddq_s32(src800, src801), src802);
  int32x4_t dst810 = vsubq_s32(vsubq_s32(src801, src802), src803);
  int32x4_t dst801 = vaddq_s32(vaddq_s32(src810, src811), src812);
  int32x4_t dst811 = vsubq_s32(vsubq_s32(src811, src812), src813);
  int32x4_t dst802 = vaddq_s32(vaddq_s32(src820, src821), src822);
  int32x4_t dst812 = vsubq_s32(vsubq_s32(src821, src822), src823);
  int32x4_t dst803 = vaddq_s32(vaddq_s32(src830, src831), src832);
  int32x4_t dst813 = vsubq_s32(vsubq_s32(src831, src832), src833);

  int32x4_t dest400 = vaddq_s32(vaddq_s32(dst400, dst401), dst402);
  int32x4_t dest410 = vsubq_s32(vsubq_s32(dst401, dst402), dst403);
  int32x4_t dest401 = vaddq_s32(vaddq_s32(dst410, dst411), dst412);
  int32x4_t dest411 = vsubq_s32(vsubq_s32(dst411, dst412), dst413);
  int32x4_t dest800 = vaddq_s32(vaddq_s32(dst800, dst801), dst802);
  int32x4_t dest810 = vsubq_s32(vsubq_s32(dst801, dst802), dst803);
  int32x4_t dest801 = vaddq_s32(vaddq_s32(dst810, dst811), dst812);
  int32x4_t dest811 = vsubq_s32(vsubq_s32(dst811, dst812), dst813);

  vst1q_s32(dest, dest400);
  vst1q_s32(dest + 4, dest800);
  dest += dest_stride;
  vst1q_s32(dest, dest410);
  vst1q_s32(dest + 4, dest810);
  dest += dest_h_stride - dest_stride;
  vst1q_s32(dest, dest401);
  vst1q_s32(dest + 4, dest801);
  dest += dest_stride;
  vst1q_s32(dest, dest411);
  vst1q_s32(dest + 4, dest811);
}
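The matching scalar reference for one channel lane of the output transform, Y = AT * m * A with the AT in the comment (sketch under the same lane-per-channel assumption; helper name is mine):

    #include <cstdint>
    // Scalar sketch of one lane of output_trans_c8_post_2x4_int8.
    static void output_trans_2x4_ref(const int32_t m[4][4], int32_t y[2][2]) {
      int32_t t[2][4];
      for (int j = 0; j < 4; ++j) {   // t = AT * m
        t[0][j] = m[0][j] + m[1][j] + m[2][j];
        t[1][j] = m[1][j] - m[2][j] - m[3][j];
      }
      for (int i = 0; i < 2; ++i) {   // y = t * A
        y[i][0] = t[i][0] + t[i][1] + t[i][2];
        y[i][1] = t[i][1] - t[i][2] - t[i][3];
      }
    }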
void weight_trans_c8_4x4_int8(
    int16_t* dest, const int8_t* din, int ch_in, int ch_out, void* workspace) {
  const int16_t coeff[4][3] = {{2, 0, 0}, {1, 1, 1}, {1, -1, 1}, {0, 0, 2}};
  int16_t* ptr_out = static_cast<int16_t*>(workspace);
  for (int i = 0; i < ch_out; i++) {
    for (int j = 0; j < ch_in; j++) {
      const int8_t* kernel0 =
          static_cast<const int8_t*>(din) + (i * ch_in + j) * 9;
      int16_t* ptr_channel = ptr_out + (i * ch_in + j) * 16;

      //! transform kernel, transposed
      const int8_t* k0 = kernel0;
      const int8_t* k1 = kernel0 + 3;
      const int8_t* k2 = kernel0 + 6;

      //! h
      int16_t tmp[4][3];
      for (int i = 0; i < 4; i++) {
        tmp[i][0] =
            k0[0] * coeff[i][0] + k0[1] * coeff[i][1] + k0[2] * coeff[i][2];
        tmp[i][1] =
            k1[0] * coeff[i][0] + k1[1] * coeff[i][1] + k1[2] * coeff[i][2];
        tmp[i][2] =
            k2[0] * coeff[i][0] + k2[1] * coeff[i][1] + k2[2] * coeff[i][2];
      }

      //! v
      for (int j = 0; j < 4; j++) {
        int16_t* tmpp = &tmp[j][0];
        for (int i = 0; i < 4; i++) {
          ptr_channel[j * 4 + i] = tmpp[0] * coeff[i][0] +
                                   tmpp[1] * coeff[i][1] +
                                   tmpp[2] * coeff[i][2];
        }
      }
    }
  }
  int oc_pad = (ch_out + 7) / 8 * 8;
  int ic_pad = (ch_in + 7) / 8 * 8;
  int c_stride = ic_pad * oc_pad;
  for (int i = 0; i < ch_out * ch_in * 16; ++i) {
    int new_c = i % 16;
    int new_oc = i / ch_in / 16 / 8;
    int new_ic = i / 16 % ch_in;
    int new_inner = i / ch_in / 16 % 8;
    int dest_ind =
        new_c * c_stride + new_oc * ic_pad * 8 + new_ic * 8 + new_inner;
    dest[dest_ind] = ptr_out[i];
  }
}

}  // namespace math
}  // namespace arm
}  // namespace lite
}  // namespace paddle
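Editor's observation on the coeff table above: it is the standard F(2,3) filter transform G scaled by 2,

$$ G \;=\; 2\,G_0,\qquad G_0=\begin{bmatrix}1&0&0\\ \tfrac12&\tfrac12&\tfrac12\\ \tfrac12&-\tfrac12&\tfrac12\\ 0&0&1\end{bmatrix},\qquad G\,g\,G^{\mathsf T}\;=\;4\,G_0\,g\,G_0^{\mathsf T}, $$

which keeps the transformed weights integral for the int16 GEMM; the extra factor of 4 is undone by the w_fact = 0.25 scale adjustment in lite/kernels/arm/conv_winograd.cc further below.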
lite/backends/arm/math/conv_block_utils.h

@@ -3856,7 +3856,7 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
                                        int height,
                                        int width,
                                        bool flag_relu,
-                                       float* bias,
+                                       const float* bias,
                                        bool flag_bias,
                                        Dtype* trash_ptr,
                                        const float* scale) {
@@ -3878,6 +3878,7 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
   int w_stride = we - ws;
   int valid_w = (we > width ? width : we) - ws;
   int cnt = valid_w / 4;
+  int remain = valid_w & 3;
   float32x4_t w_scale0 = vld1q_f32(scale);
   float32x4_t w_scale1 = vld1q_f32(scale + 4);
@@ -3933,10 +3934,10 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
                  w_bias1,
                  flag_relu);
   }
-  if (we > width) {
+  if (remain > 0) {
     int offset = 32 * cnt;
     din_hei_ptr = ptr_din + offset;
-    for (int j = ws + cnt * 4; j < width; ++j) {
+    for (int j = 0; j < remain; ++j) {
       if (flag_bias) {
         *(doutc0_ptr++) =
             cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
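Reading note (editor's assumption, not stated in the diff): the old guard `we > width` only triggered the scalar tail for right-edge tiles, whereas the tail actually needs to run whenever valid_w is not a multiple of 4; `remain = valid_w & 3` (equivalent to `valid_w % 4` for non-negative values) makes the tail loop cover exactly the leftover columns.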
lite/backends/arm/math/conv_impl.h

@@ -359,6 +359,35 @@ void conv_compute_2x2_3x3_small(const float* input,
                                 const float* bias,
                                 const operators::ConvParam& param,
                                 ARMContext* ctx);

+void input_trans_c8_4x4_int8(const int8_t* src,
+                             int src_stride,
+                             int src_h_stride,
+                             int16_t* dest,
+                             int dest_stride,
+                             int dest_h_stride);
+void output_trans_c8_post_2x4_int8(const int32_t* src,
+                                   int src_stride,
+                                   int src_h_stride,
+                                   int32_t* dest,
+                                   int dest_stride,
+                                   int dest_h_stride);
+void weight_trans_c8_4x4_int8(
+    int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
+template <typename Dtype>
+void conv_compute_2x2_3x3_int8(const int8_t* input,
+                               Dtype* output,
+                               int num,
+                               int chout,
+                               int hout,
+                               int wout,
+                               int chin,
+                               int hin,
+                               int win,
+                               const int16_t* weight,
+                               const float* bias,
+                               float* scale,
+                               const operators::ConvParam& param,
+                               ARMContext* ctx);
 template <typename Dtype>
 void im2col(const Dtype* data_im,
lite/backends/arm/math/packed_sgemm_c4.cc

@@ -1679,6 +1679,912 @@ void sgemm_prepack_c4_small(int M,
   }
 }

void sgemm_prepack_c8_int16_small(int M,
                                  int N,
                                  int K,
                                  const int16_t* A_packed,
                                  const int16_t* B,
                                  int32_t* C,
                                  ARMContext* ctx) {
  const int m_round = (M + 7) / 8 * 8;
  const int k_round = (K + 7) / 8 * 8;
  const int mloop = m_round >> 3;
  const int lda = 8 * k_round;
  const int ldb_byte = 8 * N * sizeof(int16_t);
  const int kcnt = k_round >> 3;
#ifdef __aarch64__
  float32x4_t vzero = vdupq_n_f32(0.f);
#endif
  for (int m = 0; m < mloop; ++m) {
    const int16_t* b = B;
    int n = N;
#ifdef __aarch64__
    for (; n > 7; n -= 8) {
      int cnt = kcnt;
      const int16_t* a_ptr = A_packed;
      const int16_t* b_ptr = b;
      // clang-format off
      asm volatile(
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"    // load a0, a1
          "ld1 {v4.8h, v5.8h}, [%[b]], #32\n"    // load b0, b1
          "ld1 {v6.8h, v7.8h}, [%[b]], #32\n"    // load b2, b3
          "smull v20.4s, v0.4h, v4.h[0]\n"
          "smull v21.4s, v0.4h, v5.h[0]\n"
          "smull v22.4s, v0.4h, v6.h[0]\n"
          "smull v23.4s, v0.4h, v7.h[0]\n"
          "ld1 {v8.8h, v9.8h}, [%[b]], #32\n"    // load b0, b1
          "ld1 {v10.8h, v11.8h}, [%[b]], #32\n"  // load b2, b3
          "smull2 v24.4s, v0.8h, v4.h[0]\n"
          "smull2 v25.4s, v0.8h, v5.h[0]\n"
          "smull2 v26.4s, v0.8h, v6.h[0]\n"
          "smull2 v27.4s, v0.8h, v7.h[0]\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"    // load a2, a3
          "smlal v20.4s, v1.4h, v4.h[1]\n"
          "smlal v21.4s, v1.4h, v5.h[1]\n"
          "smlal v22.4s, v1.4h, v6.h[1]\n"
          "smlal v23.4s, v1.4h, v7.h[1]\n"
          "smlal2 v24.4s, v1.8h, v4.h[1]\n"
          "smlal2 v25.4s, v1.8h, v5.h[1]\n"
          "smlal2 v26.4s, v1.8h, v6.h[1]\n"
          "smlal2 v27.4s, v1.8h, v7.h[1]\n"
          "smull v12.4s, v0.4h, v8.h[0]\n"
          "smull v13.4s, v0.4h, v9.h[0]\n"
          "smull v14.4s, v0.4h, v10.h[0]\n"
          "smull v15.4s, v0.4h, v11.h[0]\n"
          "smull2 v16.4s, v0.8h, v8.h[0]\n"
          "smull2 v17.4s, v0.8h, v9.h[0]\n"
          "smull2 v18.4s, v0.8h, v10.h[0]\n"
          "smull2 v19.4s, v0.8h, v11.h[0]\n"
          "smlal v12.4s, v1.4h, v8.h[1]\n"
          "smlal v13.4s, v1.4h, v9.h[1]\n"
          "smlal v14.4s, v1.4h, v10.h[1]\n"
          "smlal v15.4s, v1.4h, v11.h[1]\n"
          "smlal2 v16.4s, v1.8h, v8.h[1]\n"
          "smlal2 v17.4s, v1.8h, v9.h[1]\n"
          "smlal2 v18.4s, v1.8h, v10.h[1]\n"
          "smlal2 v19.4s, v1.8h, v11.h[1]\n"
          "smlal v20.4s, v2.4h, v4.h[2]\n"
          "smlal v21.4s, v2.4h, v5.h[2]\n"
          "smlal v22.4s, v2.4h, v6.h[2]\n"
          "smlal v23.4s, v2.4h, v7.h[2]\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"    // load a0, a1
          "smlal2 v24.4s, v2.8h, v4.h[2]\n"
          "smlal2 v25.4s, v2.8h, v5.h[2]\n"
          "smlal2 v26.4s, v2.8h, v6.h[2]\n"
          "smlal2 v27.4s, v2.8h, v7.h[2]\n"
          "smlal v12.4s, v2.4h, v8.h[2]\n"
          "smlal v13.4s, v2.4h, v9.h[2]\n"
          "smlal v14.4s, v2.4h, v10.h[2]\n"
          "smlal v15.4s, v2.4h, v11.h[2]\n"
          "smlal2 v16.4s, v2.8h, v8.h[2]\n"
          "smlal2 v17.4s, v2.8h, v9.h[2]\n"
          "smlal2 v18.4s, v2.8h, v10.h[2]\n"
          "smlal2 v19.4s, v2.8h, v11.h[2]\n"
          "smlal v20.4s, v3.4h, v4.h[3]\n"
          "smlal v21.4s, v3.4h, v5.h[3]\n"
          "smlal v22.4s, v3.4h, v6.h[3]\n"
          "smlal v23.4s, v3.4h, v7.h[3]\n"
          "smlal2 v24.4s, v3.8h, v4.h[3]\n"
          "smlal2 v25.4s, v3.8h, v5.h[3]\n"
          "smlal2 v26.4s, v3.8h, v6.h[3]\n"
          "smlal2 v27.4s, v3.8h, v7.h[3]\n"
          "smlal v12.4s, v3.4h, v8.h[3]\n"
          "smlal v13.4s, v3.4h, v9.h[3]\n"
          "smlal v14.4s, v3.4h, v10.h[3]\n"
          "smlal v15.4s, v3.4h, v11.h[3]\n"
          "smlal2 v16.4s, v3.8h, v8.h[3]\n"
          "smlal2 v17.4s, v3.8h, v9.h[3]\n"
          "smlal2 v18.4s, v3.8h, v10.h[3]\n"
          "smlal2 v19.4s, v3.8h, v11.h[3]\n"
          "smlal v20.4s, v0.4h, v4.h[4]\n"
          "smlal v21.4s, v0.4h, v5.h[4]\n"
          "smlal v22.4s, v0.4h, v6.h[4]\n"
          "smlal v23.4s, v0.4h, v7.h[4]\n"
          "smlal2 v24.4s, v0.8h, v4.h[4]\n"
          "smlal2 v25.4s, v0.8h, v5.h[4]\n"
          "smlal2 v26.4s, v0.8h, v6.h[4]\n"
          "smlal2 v27.4s, v0.8h, v7.h[4]\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"    // load a2, a3
          "smlal v20.4s, v1.4h, v4.h[5]\n"
          "smlal v21.4s, v1.4h, v5.h[5]\n"
          "smlal v22.4s, v1.4h, v6.h[5]\n"
          "smlal v23.4s, v1.4h, v7.h[5]\n"
          "smlal2 v24.4s, v1.8h, v4.h[5]\n"
          "smlal2 v25.4s, v1.8h, v5.h[5]\n"
          "smlal2 v26.4s, v1.8h, v6.h[5]\n"
          "smlal2 v27.4s, v1.8h, v7.h[5]\n"
          "smlal v12.4s, v0.4h, v8.h[4]\n"
          "smlal v13.4s, v0.4h, v9.h[4]\n"
          "smlal v14.4s, v0.4h, v10.h[4]\n"
          "smlal v15.4s, v0.4h, v11.h[4]\n"
          "smlal2 v16.4s, v0.8h, v8.h[4]\n"
          "smlal2 v17.4s, v0.8h, v9.h[4]\n"
          "smlal2 v18.4s, v0.8h, v10.h[4]\n"
          "smlal2 v19.4s, v0.8h, v11.h[4]\n"
          "smlal v12.4s, v1.4h, v8.h[5]\n"
          "smlal v13.4s, v1.4h, v9.h[5]\n"
          "smlal v14.4s, v1.4h, v10.h[5]\n"
          "smlal v15.4s, v1.4h, v11.h[5]\n"
          "smlal2 v16.4s, v1.8h, v8.h[5]\n"
          "smlal2 v17.4s, v1.8h, v9.h[5]\n"
          "smlal2 v18.4s, v1.8h, v10.h[5]\n"
          "smlal2 v19.4s, v1.8h, v11.h[5]\n"
          "smlal v20.4s, v2.4h, v4.h[6]\n"
          "smlal v21.4s, v2.4h, v5.h[6]\n"
          "smlal v22.4s, v2.4h, v6.h[6]\n"
          "smlal v23.4s, v2.4h, v7.h[6]\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"    // load a0, a1
          "smlal2 v24.4s, v2.8h, v4.h[6]\n"
          "smlal2 v25.4s, v2.8h, v5.h[6]\n"
          "smlal2 v26.4s, v2.8h, v6.h[6]\n"
          "smlal2 v27.4s, v2.8h, v7.h[6]\n"
          "sub %[b], %[b], #128\n"
          "add %[b], %[b], %[ldb]\n"
          "smlal v20.4s, v3.4h, v4.h[7]\n"
          "smlal v21.4s, v3.4h, v5.h[7]\n"
          "smlal v22.4s, v3.4h, v6.h[7]\n"
          "smlal v23.4s, v3.4h, v7.h[7]\n"
          "smlal2 v24.4s, v3.8h, v4.h[7]\n"
          "smlal2 v25.4s, v3.8h, v5.h[7]\n"
          "smlal2 v26.4s, v3.8h, v6.h[7]\n"
          "smlal2 v27.4s, v3.8h, v7.h[7]\n"
          "ld1 {v4.8h, v5.8h}, [%[b]], #32\n"    // load b0, b1
          "ld1 {v6.8h, v7.8h}, [%[b]], #32\n"    // load b2, b3
          "smlal v12.4s, v2.4h, v8.h[6]\n"
          "smlal v13.4s, v2.4h, v9.h[6]\n"
          "smlal v14.4s, v2.4h, v10.h[6]\n"
          "smlal v15.4s, v2.4h, v11.h[6]\n"
          "smlal2 v16.4s, v2.8h, v8.h[6]\n"
          "smlal2 v17.4s, v2.8h, v9.h[6]\n"
          "smlal2 v18.4s, v2.8h, v10.h[6]\n"
          "smlal2 v19.4s, v2.8h, v11.h[6]\n"
          "subs %w[cnt], %w[cnt], #1\n"
          "smlal v12.4s, v3.4h, v8.h[7]\n"
          "smlal v13.4s, v3.4h, v9.h[7]\n"
          "smlal v14.4s, v3.4h, v10.h[7]\n"
          "smlal v15.4s, v3.4h, v11.h[7]\n"
          "smlal2 v16.4s, v3.8h, v8.h[7]\n"
          "smlal2 v17.4s, v3.8h, v9.h[7]\n"
          "smlal2 v18.4s, v3.8h, v10.h[7]\n"
          "smlal2 v19.4s, v3.8h, v11.h[7]\n"
          "beq 2f\n"
          "1:\n"
          "smlal v20.4s, v0.4h, v4.h[0]\n"
          "smlal v21.4s, v0.4h, v5.h[0]\n"
          "smlal v22.4s, v0.4h, v6.h[0]\n"
          "smlal v23.4s, v0.4h, v7.h[0]\n"
          "ld1 {v8.8h, v9.8h}, [%[b]], #32\n"    // load b0, b1
          "ld1 {v10.8h, v11.8h}, [%[b]], #32\n"  // load b2, b3
          "smlal2 v24.4s, v0.8h, v4.h[0]\n"
          "smlal2 v25.4s, v0.8h, v5.h[0]\n"
          "smlal2 v26.4s, v0.8h, v6.h[0]\n"
          "smlal2 v27.4s, v0.8h, v7.h[0]\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"    // load a2, a3
          "smlal v20.4s, v1.4h, v4.h[1]\n"
          "smlal v21.4s, v1.4h, v5.h[1]\n"
          "smlal v22.4s, v1.4h, v6.h[1]\n"
          "smlal v23.4s, v1.4h, v7.h[1]\n"
          "smlal2 v24.4s, v1.8h, v4.h[1]\n"
          "smlal2 v25.4s, v1.8h, v5.h[1]\n"
          "smlal2 v26.4s, v1.8h, v6.h[1]\n"
          "smlal2 v27.4s, v1.8h, v7.h[1]\n"
          "smlal v12.4s, v0.4h, v8.h[0]\n"
          "smlal v13.4s, v0.4h, v9.h[0]\n"
          "smlal v14.4s, v0.4h, v10.h[0]\n"
          "smlal v15.4s, v0.4h, v11.h[0]\n"
          "smlal2 v16.4s, v0.8h, v8.h[0]\n"
          "smlal2 v17.4s, v0.8h, v9.h[0]\n"
          "smlal2 v18.4s, v0.8h, v10.h[0]\n"
          "smlal2 v19.4s, v0.8h, v11.h[0]\n"
          "smlal v12.4s, v1.4h, v8.h[1]\n"
          "smlal v13.4s, v1.4h, v9.h[1]\n"
          "smlal v14.4s, v1.4h, v10.h[1]\n"
          "smlal v15.4s, v1.4h, v11.h[1]\n"
          "smlal2 v16.4s, v1.8h, v8.h[1]\n"
          "smlal2 v17.4s, v1.8h, v9.h[1]\n"
          "smlal2 v18.4s, v1.8h, v10.h[1]\n"
          "smlal2 v19.4s, v1.8h, v11.h[1]\n"
          "smlal v20.4s, v2.4h, v4.h[2]\n"
          "smlal v21.4s, v2.4h, v5.h[2]\n"
          "smlal v22.4s, v2.4h, v6.h[2]\n"
          "smlal v23.4s, v2.4h, v7.h[2]\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"    // load a0, a1
          "smlal2 v24.4s, v2.8h, v4.h[2]\n"
          "smlal2 v25.4s, v2.8h, v5.h[2]\n"
          "smlal2 v26.4s, v2.8h, v6.h[2]\n"
          "smlal2 v27.4s, v2.8h, v7.h[2]\n"
          "smlal v12.4s, v2.4h, v8.h[2]\n"
          "smlal v13.4s, v2.4h, v9.h[2]\n"
          "smlal v14.4s, v2.4h, v10.h[2]\n"
          "smlal v15.4s, v2.4h, v11.h[2]\n"
          "smlal2 v16.4s, v2.8h, v8.h[2]\n"
          "smlal2 v17.4s, v2.8h, v9.h[2]\n"
          "smlal2 v18.4s, v2.8h, v10.h[2]\n"
          "smlal2 v19.4s, v2.8h, v11.h[2]\n"
          "smlal v20.4s, v3.4h, v4.h[3]\n"
          "smlal v21.4s, v3.4h, v5.h[3]\n"
          "smlal v22.4s, v3.4h, v6.h[3]\n"
          "smlal v23.4s, v3.4h, v7.h[3]\n"
          "smlal2 v24.4s, v3.8h, v4.h[3]\n"
          "smlal2 v25.4s, v3.8h, v5.h[3]\n"
          "smlal2 v26.4s, v3.8h, v6.h[3]\n"
          "smlal2 v27.4s, v3.8h, v7.h[3]\n"
          "smlal v12.4s, v3.4h, v8.h[3]\n"
          "smlal v13.4s, v3.4h, v9.h[3]\n"
          "smlal v14.4s, v3.4h, v10.h[3]\n"
          "smlal v15.4s, v3.4h, v11.h[3]\n"
          "smlal2 v16.4s, v3.8h, v8.h[3]\n"
          "smlal2 v17.4s, v3.8h, v9.h[3]\n"
          "smlal2 v18.4s, v3.8h, v10.h[3]\n"
          "smlal2 v19.4s, v3.8h, v11.h[3]\n"
          "smlal v20.4s, v0.4h, v4.h[4]\n"
          "smlal v21.4s, v0.4h, v5.h[4]\n"
          "smlal v22.4s, v0.4h, v6.h[4]\n"
          "smlal v23.4s, v0.4h, v7.h[4]\n"
          "smlal2 v24.4s, v0.8h, v4.h[4]\n"
          "smlal2 v25.4s, v0.8h, v5.h[4]\n"
          "smlal2 v26.4s, v0.8h, v6.h[4]\n"
          "smlal2 v27.4s, v0.8h, v7.h[4]\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"    // load a2, a3
          "smlal v20.4s, v1.4h, v4.h[5]\n"
          "smlal v21.4s, v1.4h, v5.h[5]\n"
          "smlal v22.4s, v1.4h, v6.h[5]\n"
          "smlal v23.4s, v1.4h, v7.h[5]\n"
          "smlal2 v24.4s, v1.8h, v4.h[5]\n"
          "smlal2 v25.4s, v1.8h, v5.h[5]\n"
          "smlal2 v26.4s, v1.8h, v6.h[5]\n"
          "smlal2 v27.4s, v1.8h, v7.h[5]\n"
          "smlal v12.4s, v0.4h, v8.h[4]\n"
          "smlal v13.4s, v0.4h, v9.h[4]\n"
          "smlal v14.4s, v0.4h, v10.h[4]\n"
          "smlal v15.4s, v0.4h, v11.h[4]\n"
          "smlal2 v16.4s, v0.8h, v8.h[4]\n"
          "smlal2 v17.4s, v0.8h, v9.h[4]\n"
          "smlal2 v18.4s, v0.8h, v10.h[4]\n"
          "smlal2 v19.4s, v0.8h, v11.h[4]\n"
          "smlal v12.4s, v1.4h, v8.h[5]\n"
          "smlal v13.4s, v1.4h, v9.h[5]\n"
          "smlal v14.4s, v1.4h, v10.h[5]\n"
          "smlal v15.4s, v1.4h, v11.h[5]\n"
          "smlal2 v16.4s, v1.8h, v8.h[5]\n"
          "smlal2 v17.4s, v1.8h, v9.h[5]\n"
          "smlal2 v18.4s, v1.8h, v10.h[5]\n"
          "smlal2 v19.4s, v1.8h, v11.h[5]\n"
          "smlal v20.4s, v2.4h, v4.h[6]\n"
          "smlal v21.4s, v2.4h, v5.h[6]\n"
          "smlal v22.4s, v2.4h, v6.h[6]\n"
          "smlal v23.4s, v2.4h, v7.h[6]\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"    // load a0, a1
          "smlal2 v24.4s, v2.8h, v4.h[6]\n"
          "smlal2 v25.4s, v2.8h, v5.h[6]\n"
          "smlal2 v26.4s, v2.8h, v6.h[6]\n"
          "smlal2 v27.4s, v2.8h, v7.h[6]\n"
          "sub %[b], %[b], #128\n"
          "add %[b], %[b], %[ldb]\n"
          "smlal v20.4s, v3.4h, v4.h[7]\n"
          "smlal v21.4s, v3.4h, v5.h[7]\n"
          "smlal v22.4s, v3.4h, v6.h[7]\n"
          "smlal v23.4s, v3.4h, v7.h[7]\n"
          "smlal2 v24.4s, v3.8h, v4.h[7]\n"
          "smlal2 v25.4s, v3.8h, v5.h[7]\n"
          "smlal2 v26.4s, v3.8h, v6.h[7]\n"
          "smlal2 v27.4s, v3.8h, v7.h[7]\n"
          "ld1 {v4.8h, v5.8h}, [%[b]], #32\n"    // load b0, b1
          "ld1 {v6.8h, v7.8h}, [%[b]], #32\n"    // load b2, b3
          "smlal v12.4s, v2.4h, v8.h[6]\n"
          "smlal v13.4s, v2.4h, v9.h[6]\n"
          "smlal v14.4s, v2.4h, v10.h[6]\n"
          "smlal v15.4s, v2.4h, v11.h[6]\n"
          "smlal2 v16.4s, v2.8h, v8.h[6]\n"
          "smlal2 v17.4s, v2.8h, v9.h[6]\n"
          "smlal2 v18.4s, v2.8h, v10.h[6]\n"
          "smlal2 v19.4s, v2.8h, v11.h[6]\n"
          "subs %w[cnt], %w[cnt], #1\n"
          "smlal v12.4s, v3.4h, v8.h[7]\n"
          "smlal v13.4s, v3.4h, v9.h[7]\n"
          "smlal v14.4s, v3.4h, v10.h[7]\n"
          "smlal v15.4s, v3.4h, v11.h[7]\n"
          "smlal2 v16.4s, v3.8h, v8.h[7]\n"
          "smlal2 v17.4s, v3.8h, v9.h[7]\n"
          "smlal2 v18.4s, v3.8h, v10.h[7]\n"
          "smlal2 v19.4s, v3.8h, v11.h[7]\n"
          "bne 1b\n"
          "2:\n"
          "stp q20, q24, [%[c]], #32\n"
          "stp q21, q25, [%[c]], #32\n"
          "stp q22, q26, [%[c]], #32\n"
          "stp q23, q27, [%[c]], #32\n"
          "stp q12, q16, [%[c]], #32\n"
          "stp q13, q17, [%[c]], #32\n"
          "stp q14, q18, [%[c]], #32\n"
          "stp q15, q19, [%[c]], #32\n"
          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
          : [ldb] "r"(ldb_byte)
          : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
            "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
            "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27",
            "cc", "memory");
      // clang-format on
      b += 64;
    }
    for (; n > 3; n -= 4) {
      int cnt = kcnt;
      const int16_t* a_ptr = A_packed;
      const int16_t* b_ptr = b;
      // clang-format off
      asm volatile(
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"
          "ld1 {v4.8h, v5.8h}, [%[b]], #32\n"
          "smull v8.4s, v0.4h, v4.h[0]\n"
          "smull v9.4s, v0.4h, v5.h[0]\n"
          "ld1 {v6.8h, v7.8h}, [%[b]], #32\n"
          "smull2 v10.4s, v0.8h, v4.h[0]\n"
          "smull2 v11.4s, v0.8h, v5.h[0]\n"
          "smlal v8.4s, v1.4h, v4.h[1]\n"
          "smlal v9.4s, v1.4h, v5.h[1]\n"
          "smlal2 v10.4s, v1.8h, v4.h[1]\n"
          "smlal2 v11.4s, v1.8h, v5.h[1]\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"
          "smull v12.4s, v0.4h, v6.h[0]\n"
          "smull v13.4s, v0.4h, v7.h[0]\n"
          "smull2 v14.4s, v0.8h, v6.h[0]\n"
          "smull2 v15.4s, v0.8h, v7.h[0]\n"
          "smlal v12.4s, v1.4h, v6.h[1]\n"
          "smlal v13.4s, v1.4h, v7.h[1]\n"
          "smlal2 v14.4s, v1.8h, v6.h[1]\n"
          "smlal2 v15.4s, v1.8h, v7.h[1]\n"
          "smlal v8.4s, v2.4h, v4.h[2]\n"
          "smlal v9.4s, v2.4h, v5.h[2]\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"
          "smlal2 v10.4s, v2.8h, v4.h[2]\n"
          "smlal2 v11.4s, v2.8h, v5.h[2]\n"
          "smlal v8.4s, v3.4h, v4.h[3]\n"
          "smlal v9.4s, v3.4h, v5.h[3]\n"
          "smlal2 v10.4s, v3.8h, v4.h[3]\n"
          "smlal2 v11.4s, v3.8h, v5.h[3]\n"
          "smlal v12.4s, v2.4h, v6.h[2]\n"
          "smlal v13.4s, v2.4h, v7.h[2]\n"
          "smlal2 v14.4s, v2.8h, v6.h[2]\n"
          "smlal2 v15.4s, v2.8h, v7.h[2]\n"
          "smlal v12.4s, v3.4h, v6.h[3]\n"
          "smlal v13.4s, v3.4h, v7.h[3]\n"
          "smlal2 v14.4s, v3.8h, v6.h[3]\n"
          "smlal2 v15.4s, v3.8h, v7.h[3]\n"
          "smlal v8.4s, v0.4h, v4.h[4]\n"
          "smlal v9.4s, v0.4h, v5.h[4]\n"
          "smlal2 v10.4s, v0.8h, v4.h[4]\n"
          "smlal2 v11.4s, v0.8h, v5.h[4]\n"
          "smlal v8.4s, v1.4h, v4.h[5]\n"
          "smlal v9.4s, v1.4h, v5.h[5]\n"
          "smlal2 v10.4s, v1.8h, v4.h[5]\n"
          "smlal2 v11.4s, v1.8h, v5.h[5]\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"
          "smlal v12.4s, v0.4h, v6.h[4]\n"
          "smlal v13.4s, v0.4h, v7.h[4]\n"
          "smlal2 v14.4s, v0.8h, v6.h[4]\n"
          "smlal2 v15.4s, v0.8h, v7.h[4]\n"
          "smlal v12.4s, v1.4h, v6.h[5]\n"
          "smlal v13.4s, v1.4h, v7.h[5]\n"
          "smlal2 v14.4s, v1.8h, v6.h[5]\n"
          "smlal2 v15.4s, v1.8h, v7.h[5]\n"
          "smlal v8.4s, v2.4h, v4.h[6]\n"
          "smlal v9.4s, v2.4h, v5.h[6]\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"
          "smlal2 v10.4s, v2.8h, v4.h[6]\n"
          "smlal2 v11.4s, v2.8h, v5.h[6]\n"
          "smlal v8.4s, v3.4h, v4.h[7]\n"
          "smlal v9.4s, v3.4h, v5.h[7]\n"
          "smlal2 v10.4s, v3.8h, v4.h[7]\n"
          "smlal2 v11.4s, v3.8h, v5.h[7]\n"
          "sub %[b], %[b], #64\n"
          "add %[b], %[b], %[ldb]\n"
          "smlal v12.4s, v2.4h, v6.h[6]\n"
          "smlal v13.4s, v2.4h, v7.h[6]\n"
          "subs %w[cnt], %w[cnt], #1\n"
          "ld1 {v4.8h, v5.8h}, [%[b]], #32\n"
          "smlal2 v14.4s, v2.8h, v6.h[6]\n"
          "smlal2 v15.4s, v2.8h, v7.h[6]\n"
          "smlal v12.4s, v3.4h, v6.h[7]\n"
          "smlal v13.4s, v3.4h, v7.h[7]\n"
          "smlal2 v14.4s, v3.8h, v6.h[7]\n"
          "smlal2 v15.4s, v3.8h, v7.h[7]\n"
          "beq 2f\n"
          "1:\n"
          "smlal v8.4s, v0.4h, v4.h[0]\n"
          "smlal v9.4s, v0.4h, v5.h[0]\n"
          "ld1 {v6.8h, v7.8h}, [%[b]], #32\n"
          "smlal2 v10.4s, v0.8h, v4.h[0]\n"
          "smlal2 v11.4s, v0.8h, v5.h[0]\n"
          "smlal v8.4s, v1.4h, v4.h[1]\n"
          "smlal v9.4s, v1.4h, v5.h[1]\n"
          "smlal2 v10.4s, v1.8h, v4.h[1]\n"
          "smlal2 v11.4s, v1.8h, v5.h[1]\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"
          "smlal v12.4s, v0.4h, v6.h[0]\n"
          "smlal v13.4s, v0.4h, v7.h[0]\n"
          "smlal2 v14.4s, v0.8h, v6.h[0]\n"
          "smlal2 v15.4s, v0.8h, v7.h[0]\n"
          "smlal v12.4s, v1.4h, v6.h[1]\n"
          "smlal v13.4s, v1.4h, v7.h[1]\n"
          "smlal2 v14.4s, v1.8h, v6.h[1]\n"
          "smlal2 v15.4s, v1.8h, v7.h[1]\n"
          "smlal v8.4s, v2.4h, v4.h[2]\n"
          "smlal v9.4s, v2.4h, v5.h[2]\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"
          "smlal2 v10.4s, v2.8h, v4.h[2]\n"
          "smlal2 v11.4s, v2.8h, v5.h[2]\n"
          "smlal v8.4s, v3.4h, v4.h[3]\n"
          "smlal v9.4s, v3.4h, v5.h[3]\n"
          "smlal2 v10.4s, v3.8h, v4.h[3]\n"
          "smlal2 v11.4s, v3.8h, v5.h[3]\n"
          "smlal v12.4s, v2.4h, v6.h[2]\n"
          "smlal v13.4s, v2.4h, v7.h[2]\n"
          "smlal2 v14.4s, v2.8h, v6.h[2]\n"
          "smlal2 v15.4s, v2.8h, v7.h[2]\n"
          "smlal v12.4s, v3.4h, v6.h[3]\n"
          "smlal v13.4s, v3.4h, v7.h[3]\n"
          "smlal2 v14.4s, v3.8h, v6.h[3]\n"
          "smlal2 v15.4s, v3.8h, v7.h[3]\n"
          "smlal v8.4s, v0.4h, v4.h[4]\n"
          "smlal v9.4s, v0.4h, v5.h[4]\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"
          "smlal2 v10.4s, v0.8h, v4.h[4]\n"
          "smlal2 v11.4s, v0.8h, v5.h[4]\n"
          "smlal v8.4s, v1.4h, v4.h[5]\n"
          "smlal v9.4s, v1.4h, v5.h[5]\n"
          "smlal2 v10.4s, v1.8h, v4.h[5]\n"
          "smlal2 v11.4s, v1.8h, v5.h[5]\n"
          "smlal v12.4s, v0.4h, v6.h[4]\n"
          "smlal v13.4s, v0.4h, v7.h[4]\n"
          "smlal2 v14.4s, v0.8h, v6.h[4]\n"
          "smlal2 v15.4s, v0.8h, v7.h[4]\n"
          "smlal v12.4s, v1.4h, v6.h[5]\n"
          "smlal v13.4s, v1.4h, v7.h[5]\n"
          "smlal2 v14.4s, v1.8h, v6.h[5]\n"
          "smlal2 v15.4s, v1.8h, v7.h[5]\n"
          "smlal v8.4s, v2.4h, v4.h[6]\n"
          "smlal v9.4s, v2.4h, v5.h[6]\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"
          "smlal2 v10.4s, v2.8h, v4.h[6]\n"
          "smlal2 v11.4s, v2.8h, v5.h[6]\n"
          "smlal v8.4s, v3.4h, v4.h[7]\n"
          "smlal v9.4s, v3.4h, v5.h[7]\n"
          "smlal2 v10.4s, v3.8h, v4.h[7]\n"
          "smlal2 v11.4s, v3.8h, v5.h[7]\n"
          "sub %[b], %[b], #64\n"
          "add %[b], %[b], %[ldb]\n"
          "smlal v12.4s, v2.4h, v6.h[6]\n"
          "smlal v13.4s, v2.4h, v7.h[6]\n"
          "subs %w[cnt], %w[cnt], #1\n"
          "ld1 {v4.8h, v5.8h}, [%[b]], #32\n"
          "smlal2 v14.4s, v2.8h, v6.h[6]\n"
          "smlal2 v15.4s, v2.8h, v7.h[6]\n"
          "smlal v12.4s, v3.4h, v6.h[7]\n"
          "smlal v13.4s, v3.4h, v7.h[7]\n"
          "smlal2 v14.4s, v3.8h, v6.h[7]\n"
          "smlal2 v15.4s, v3.8h, v7.h[7]\n"
          "bne 1b\n"
          "2:\n"
          "stp q8, q10, [%[c]], #32\n"
          "stp q9, q11, [%[c]], #32\n"
          "stp q12, q14, [%[c]], #32\n"
          "stp q13, q15, [%[c]], #32\n"
          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
          : [ldb] "r"(ldb_byte)
          : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9",
            "v10", "v11", "v12", "v13", "v14", "v15", "cc", "memory");
      // clang-format on
      b += 32;
    }
    for (; n > 0; --n) {
      int cnt = kcnt;
      const int16_t* a_ptr = A_packed;
      const int16_t* b_ptr = b;
      // clang-format off
      asm volatile(
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"
          "ld1 {v4.8h}, [%[b]], #16\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"
          "smull v5.4s, v0.4h, v4.h[0]\n"
          "smull2 v6.4s, v0.8h, v4.h[0]\n"
          "ld1 {v10.8h, v11.8h}, [%[a]], #32\n"
          "smlal v5.4s, v1.4h, v4.h[1]\n"
          "smlal2 v6.4s, v1.8h, v4.h[1]\n"
          "ld1 {v12.8h, v13.8h}, [%[a]], #32\n"
          "smlal v5.4s, v2.4h, v4.h[2]\n"
          "smlal2 v6.4s, v2.8h, v4.h[2]\n"
          "smlal v5.4s, v3.4h, v4.h[3]\n"
          "smlal2 v6.4s, v3.8h, v4.h[3]\n"
          "sub %[b], %[b], #16\n"
          "add %[b], %[b], %[ldb]\n"
          "smlal v5.4s, v10.4h, v4.h[4]\n"
          "smlal2 v6.4s, v10.8h, v4.h[4]\n"
          "smlal v5.4s, v11.4h, v4.h[5]\n"
          "smlal2 v6.4s, v11.8h, v4.h[5]\n"
          "subs %w[cnt], %w[cnt], #1\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"
          "smlal v5.4s, v12.4h, v4.h[6]\n"
          "smlal2 v6.4s, v12.8h, v4.h[6]\n"
          "smlal v5.4s, v13.4h, v4.h[7]\n"
          "smlal2 v6.4s, v13.8h, v4.h[7]\n"
          "beq 2f\n"
          "1:\n"
          "ld1 {v4.8h}, [%[b]], #16\n"
          "ld1 {v2.8h, v3.8h}, [%[a]], #32\n"
          "smlal v5.4s, v0.4h, v4.h[0]\n"
          "smlal2 v6.4s, v0.8h, v4.h[0]\n"
          "ld1 {v10.8h, v11.8h}, [%[a]], #32\n"
          "smlal v5.4s, v1.4h, v4.h[1]\n"
          "smlal2 v6.4s, v1.8h, v4.h[1]\n"
          "ld1 {v12.8h, v13.8h}, [%[a]], #32\n"
          "smlal v5.4s, v2.4h, v4.h[2]\n"
          "smlal2 v6.4s, v2.8h, v4.h[2]\n"
          "smlal v5.4s, v3.4h, v4.h[3]\n"
          "smlal2 v6.4s, v3.8h, v4.h[3]\n"
          "sub %[b], %[b], #16\n"
          "add %[b], %[b], %[ldb]\n"
          "smlal v5.4s, v10.4h, v4.h[4]\n"
          "smlal2 v6.4s, v10.8h, v4.h[4]\n"
          "smlal v5.4s, v11.4h, v4.h[5]\n"
          "smlal2 v6.4s, v11.8h, v4.h[5]\n"
          "subs %w[cnt], %w[cnt], #1\n"
          "ld1 {v0.8h, v1.8h}, [%[a]], #32\n"
          "smlal v5.4s, v12.4h, v4.h[6]\n"
          "smlal2 v6.4s, v12.8h, v4.h[6]\n"
          "smlal v5.4s, v13.4h, v4.h[7]\n"
          "smlal2 v6.4s, v13.8h, v4.h[7]\n"
          "bne 1b\n"
          "2:\n"
          "st1 {v5.4s, v6.4s}, [%[c]], #32\n"
          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
          : [ldb] "r"(ldb_byte)
          : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v10", "v11", "v12",
            "v13", "cc", "memory");
      // clang-format on
      b += 8;
    }
#else
    for (; n > 3; n -= 4) {
      int cnt = kcnt;
      const int16_t* a_ptr = A_packed;
      const int16_t* b_ptr = b;
      // clang-format off
      asm volatile(
          "vld1.16 {d0-d3}, [%[b]]!\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vld1.16 {d4-d7}, [%[b]]!\n"
          "vmull.s16 q8, d8, d0[0]\n"
          "vmull.s16 q9, d8, d2[0]\n"
          "vld1.16 {d12-d15}, [%[a]]!\n"
          "vmull.s16 q10, d9, d0[0]\n"
          "vmull.s16 q11, d9, d2[0]\n"
          "vmlal.s16 q8, d10, d0[1]\n"
          "vmlal.s16 q9, d10, d2[1]\n"
          "vmlal.s16 q10, d11, d0[1]\n"
          "vmlal.s16 q11, d11, d2[1]\n"
          "vmull.s16 q12, d8, d4[0]\n"
          "vmull.s16 q13, d8, d6[0]\n"
          "vmull.s16 q14, d9, d4[0]\n"
          "vmull.s16 q15, d9, d6[0]\n"
          "vmlal.s16 q12, d10, d4[1]\n"
          "vmlal.s16 q13, d10, d6[1]\n"
          "vmlal.s16 q14, d11, d4[1]\n"
          "vmlal.s16 q15, d11, d6[1]\n"
          "vmlal.s16 q8, d12, d0[2]\n"
          "vmlal.s16 q9, d12, d2[2]\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vmlal.s16 q10, d13, d0[2]\n"
          "vmlal.s16 q11, d13, d2[2]\n"
          "vmlal.s16 q8, d14, d0[3]\n"
          "vmlal.s16 q9, d14, d2[3]\n"
          "vmlal.s16 q10, d15, d0[3]\n"
          "vmlal.s16 q11, d15, d2[3]\n"
          "vmlal.s16 q12, d12, d4[2]\n"
          "vmlal.s16 q13, d12, d6[2]\n"
          "vmlal.s16 q14, d13, d4[2]\n"
          "vmlal.s16 q15, d13, d6[2]\n"
          "vmlal.s16 q12, d14, d4[3]\n"
          "vmlal.s16 q13, d14, d6[3]\n"
          "vmlal.s16 q14, d15, d4[3]\n"
          "vmlal.s16 q15, d15, d6[3]\n"
          "sub %[b], %[b], #64\n"
          "add %[b], %[b], %[ldb]\n"
          "vld1.16 {d12-d15}, [%[a]]!\n"
          "vmlal.s16 q8, d8, d1[0]\n"
          "vmlal.s16 q9, d8, d3[0]\n"
          "vmlal.s16 q10, d9, d1[0]\n"
          "vmlal.s16 q11, d9, d3[0]\n"
          "vmlal.s16 q8, d10, d1[1]\n"
          "vmlal.s16 q9, d10, d3[1]\n"
          "vmlal.s16 q10, d11, d1[1]\n"
          "vmlal.s16 q11, d11, d3[1]\n"
          "vmlal.s16 q8, d12, d1[2]\n"
          "vmlal.s16 q9, d12, d3[2]\n"
          "vmlal.s16 q10, d13, d1[2]\n"
          "vmlal.s16 q11, d13, d3[2]\n"
          "vmlal.s16 q8, d14, d1[3]\n"
          "vmlal.s16 q9, d14, d3[3]\n"
          "vmlal.s16 q10, d15, d1[3]\n"
          "vmlal.s16 q11, d15, d3[3]\n"
          "vld1.16 {d0-d3}, [%[b]]!\n"
          "vmlal.s16 q12, d8, d5[0]\n"
          "vmlal.s16 q13, d8, d7[0]\n"
          "vmlal.s16 q14, d9, d5[0]\n"
          "vmlal.s16 q15, d9, d7[0]\n"
          "vmlal.s16 q12, d10, d5[1]\n"
          "vmlal.s16 q13, d10, d7[1]\n"
          "subs %[cnt], %[cnt], #1\n"
          "vmlal.s16 q14, d11, d5[1]\n"
          "vmlal.s16 q15, d11, d7[1]\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vmlal.s16 q12, d12, d5[2]\n"
          "vmlal.s16 q13, d12, d7[2]\n"
          "vmlal.s16 q14, d13, d5[2]\n"
          "vmlal.s16 q15, d13, d7[2]\n"
          "vmlal.s16 q12, d14, d5[3]\n"
          "vmlal.s16 q13, d14, d7[3]\n"
          "vmlal.s16 q14, d15, d5[3]\n"
          "vmlal.s16 q15, d15, d7[3]\n"
          "beq 2f\n"
          "1:\n"
          "vld1.16 {d4-d7}, [%[b]]!\n"
          "vmlal.s16 q8, d8, d0[0]\n"
          "vmlal.s16 q9, d8, d2[0]\n"
          "vld1.16 {d12-d15}, [%[a]]!\n"
          "vmlal.s16 q10, d9, d0[0]\n"
          "vmlal.s16 q11, d9, d2[0]\n"
          "vmlal.s16 q8, d10, d0[1]\n"
          "vmlal.s16 q9, d10, d2[1]\n"
          "vmlal.s16 q10, d11, d0[1]\n"
          "vmlal.s16 q11, d11, d2[1]\n"
          "vmlal.s16 q12, d8, d4[0]\n"
          "vmlal.s16 q13, d8, d6[0]\n"
          "vmlal.s16 q14, d9, d4[0]\n"
          "vmlal.s16 q15, d9, d6[0]\n"
          "vmlal.s16 q12, d10, d4[1]\n"
          "vmlal.s16 q13, d10, d6[1]\n"
          "vmlal.s16 q14, d11, d4[1]\n"
          "vmlal.s16 q15, d11, d6[1]\n"
          "vmlal.s16 q8, d12, d0[2]\n"
          "vmlal.s16 q9, d12, d2[2]\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vmlal.s16 q10, d13, d0[2]\n"
          "vmlal.s16 q11, d13, d2[2]\n"
          "vmlal.s16 q8, d14, d0[3]\n"
          "vmlal.s16 q9, d14, d2[3]\n"
          "vmlal.s16 q10, d15, d0[3]\n"
          "vmlal.s16 q11, d15, d2[3]\n"
          "vmlal.s16 q12, d12, d4[2]\n"
          "vmlal.s16 q13, d12, d6[2]\n"
          "vmlal.s16 q14, d13, d4[2]\n"
          "vmlal.s16 q15, d13, d6[2]\n"
          "vmlal.s16 q12, d14, d4[3]\n"
          "vmlal.s16 q13, d14, d6[3]\n"
          "vmlal.s16 q14, d15, d4[3]\n"
          "vmlal.s16 q15, d15, d6[3]\n"
          "sub %[b], %[b], #64\n"
          "add %[b], %[b], %[ldb]\n"
          "vld1.16 {d12-d15}, [%[a]]!\n"
          "vmlal.s16 q8, d8, d1[0]\n"
          "vmlal.s16 q9, d8, d3[0]\n"
          "vmlal.s16 q10, d9, d1[0]\n"
          "vmlal.s16 q11, d9, d3[0]\n"
          "vmlal.s16 q8, d10, d1[1]\n"
          "vmlal.s16 q9, d10, d3[1]\n"
          "vmlal.s16 q10, d11, d1[1]\n"
          "vmlal.s16 q11, d11, d3[1]\n"
          "vmlal.s16 q8, d12, d1[2]\n"
          "vmlal.s16 q9, d12, d3[2]\n"
          "vmlal.s16 q10, d13, d1[2]\n"
          "vmlal.s16 q11, d13, d3[2]\n"
          "vmlal.s16 q8, d14, d1[3]\n"
          "vmlal.s16 q9, d14, d3[3]\n"
          "vmlal.s16 q10, d15, d1[3]\n"
          "vmlal.s16 q11, d15, d3[3]\n"
          "vld1.16 {d0-d3}, [%[b]]!\n"
          "vmlal.s16 q12, d8, d5[0]\n"
          "vmlal.s16 q13, d8, d7[0]\n"
          "vmlal.s16 q14, d9, d5[0]\n"
          "vmlal.s16 q15, d9, d7[0]\n"
          "vmlal.s16 q12, d10, d5[1]\n"
          "vmlal.s16 q13, d10, d7[1]\n"
          "subs %[cnt], %[cnt], #1\n"
          "vmlal.s16 q14, d11, d5[1]\n"
          "vmlal.s16 q15, d11, d7[1]\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vmlal.s16 q12, d12, d5[2]\n"
          "vmlal.s16 q13, d12, d7[2]\n"
          "vmlal.s16 q14, d13, d5[2]\n"
          "vmlal.s16 q15, d13, d7[2]\n"
          "vmlal.s16 q12, d14, d5[3]\n"
          "vmlal.s16 q13, d14, d7[3]\n"
          "vmlal.s16 q14, d15, d5[3]\n"
          "vmlal.s16 q15, d15, d7[3]\n"
          "bne 1b\n"
          "2:\n"
          "vst1.32 {d16-d17}, [%[c]]!\n"
          "vst1.32 {d20-d21}, [%[c]]!\n"
          "vst1.32 {d18-d19}, [%[c]]!\n"
          "vst1.32 {d22-d23}, [%[c]]!\n"
          "vst1.32 {d24-d25}, [%[c]]!\n"
          "vst1.32 {d28-d29}, [%[c]]!\n"
          "vst1.32 {d26-d27}, [%[c]]!\n"
          "vst1.32 {d30-d31}, [%[c]]!\n"
          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
          : [ldb] "r"(ldb_byte)
          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
            "q10", "q11", "q12", "q13", "q14", "q15", "cc", "memory");
      // clang-format on
      b += 32;
    }
    for (; n > 0; --n) {
      int cnt = kcnt;
      const int16_t* a_ptr = A_packed;
      const int16_t* b_ptr = b;
      // clang-format off
      asm volatile(
          "vld1.16 {d0-d1}, [%[b]]!\n"
          "vld1.16 {d4-d7}, [%[a]]!\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vmull.s16 q8, d4, d0[0]\n"
          "vmull.s16 q9, d5, d0[0]\n"
          "sub %[b], %[b], #16\n"
          "vmlal.s16 q8, d6, d0[1]\n"
          "vmlal.s16 q9, d7, d0[1]\n"
          "add %[b], %[b], %[ldb]\n"
          "subs %[cnt], %[cnt], #1\n"
          "vld1.16 {d4-d7}, [%[a]]!\n"
          "vmlal.s16 q8, d8, d0[2]\n"
          "vmlal.s16 q9, d9, d0[2]\n"
          "vmlal.s16 q8, d10, d0[3]\n"
          "vmlal.s16 q9, d11, d0[3]\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vmlal.s16 q8, d4, d1[0]\n"
          "vmlal.s16 q9, d5, d1[0]\n"
          "vmlal.s16 q8, d6, d1[1]\n"
          "vmlal.s16 q9, d7, d1[1]\n"
          "vld1.16 {d4-d7}, [%[a]]!\n"
          "vmlal.s16 q8, d8, d1[2]\n"
          "vmlal.s16 q9, d9, d1[2]\n"
          "vmlal.s16 q8, d10, d1[3]\n"
          "vmlal.s16 q9, d11, d1[3]\n"
          "beq 2f\n"
          "1:\n"
          "vld1.16 {d0-d1}, [%[b]]!\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vmlal.s16 q8, d4, d0[0]\n"
          "vmlal.s16 q9, d5, d0[0]\n"
          "sub %[b], %[b], #16\n"
          "vmlal.s16 q8, d6, d0[1]\n"
          "vmlal.s16 q9, d7, d0[1]\n"
          "add %[b], %[b], %[ldb]\n"
          "subs %[cnt], %[cnt], #1\n"
          "vld1.16 {d4-d7}, [%[a]]!\n"
          "vmlal.s16 q8, d8, d0[2]\n"
          "vmlal.s16 q9, d9, d0[2]\n"
          "vmlal.s16 q8, d10, d0[3]\n"
          "vmlal.s16 q9, d11, d0[3]\n"
          "vld1.16 {d8-d11}, [%[a]]!\n"
          "vmlal.s16 q8, d4, d1[0]\n"
          "vmlal.s16 q9, d5, d1[0]\n"
          "vmlal.s16 q8, d6, d1[1]\n"
          "vmlal.s16 q9, d7, d1[1]\n"
          "vld1.16 {d4-d7}, [%[a]]!\n"
          "vmlal.s16 q8, d8, d1[2]\n"
          "vmlal.s16 q9, d9, d1[2]\n"
          "vmlal.s16 q8, d10, d1[3]\n"
          "vmlal.s16 q9, d11, d1[3]\n"
          "bne 1b\n"
          "2:\n"
          "vst1.32 {d16-d19}, [%[c]]!\n"
          : [a] "+r"(a_ptr), [b] "+r"(b_ptr), [c] "+r"(C), [cnt] "+r"(cnt)
          : [ldb] "r"(ldb_byte)
          : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9",
            "cc", "memory");
      // clang-format on
      b += 8;
    }
#endif
    A_packed += lda;
  }
}

 void sgemm_prepack_c4(int M,
                       int N,
                       int K,
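A scalar reference of the contract the assembly above appears to implement, inferred from the pointer arithmetic (A packed in 8-row blocks with the 8 rows contiguous per k step; B in k-panels of 8, each column's 8 k-values contiguous, panel stride 8*N; C written as 8-row column blocks). The helper name and layout reading are my assumptions, not part of the diff:

    #include <cstdint>
    // Scalar sketch of sgemm_prepack_c8_int16_small's assumed semantics.
    static void sgemm_c8_int16_ref(int M, int N, int K,
                                   const int16_t* A_packed,
                                   const int16_t* B,
                                   int32_t* C) {
      const int k_round = (K + 7) / 8 * 8;   // K padded to a multiple of 8
      const int m_blocks = (M + 7) / 8;      // 8-row blocks of A
      for (int mb = 0; mb < m_blocks; ++mb) {
        const int16_t* a_blk = A_packed + mb * 8 * k_round;  // 8 x k_round
        for (int n = 0; n < N; ++n) {
          int32_t acc[8] = {0};
          for (int k = 0; k < k_round; ++k) {
            // B: panel (k/8), column n, in-panel offset (k%8)
            int16_t bv = B[(k / 8) * 8 * N + n * 8 + (k % 8)];
            for (int r = 0; r < 8; ++r)
              acc[r] += static_cast<int32_t>(a_blk[k * 8 + r]) * bv;
          }
          for (int r = 0; r < 8; ++r)
            C[(mb * N + n) * 8 + r] = acc[r];  // c8 output layout
        }
      }
    }

Zero-padding of A and B up to k_round is assumed to be done by the packing stage, matching the kcnt = k_round >> 3 loop count above.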
lite/backends/arm/math/packed_sgemm_c4.h

@@ -54,6 +54,13 @@ void sgemm_prepack_c4_small(int M,
                             const float* B,
                             float* C,
                             ARMContext* ctx);
+void sgemm_prepack_c8_int16_small(int M,
+                                  int N,
+                                  int K,
+                                  const int16_t* A_packed,
+                                  const int16_t* B,
+                                  int32_t* C,
+                                  ARMContext* ctx);
 }  // namespace math
 }  // namespace arm
 }  // namespace lite
lite/kernels/arm/conv_compute.cc

@@ -73,7 +73,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
     // VLOG(3) << "invoking dw conv";
   } else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal &&
              no_dilation) {
-    // TODO(MyPandaShaoxiang): winograd conv support any pad
     impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
     // VLOG(3) << "invoking winograd conv";
   } else if (param.groups == 1 && kw == 3 && stride == 2 &&
@@ -121,9 +120,9 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
              no_dilation && flag_dw) {
     impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
     // VLOG(3) << "Run DepthwiseConv Int8";
-  } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
-             no_dilation && kps_equal) {
-    impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
+  } else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
+             pads_equal) {
+    impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
     // VLOG(3) << "Run DirectConv Int8";
   } else {
     impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
@@ -166,9 +165,9 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
              no_dilation && flag_dw) {
     impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
     // VLOG(3) << "Run DepthwiseConv Int8";
-  } else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
-             no_dilation && kps_equal) {
-    impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
+  } else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
+             pads_equal) {
+    impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
     // VLOG(3) << "Run DirectConv Int8";
   } else {
     impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
lite/kernels/arm/conv_winograd.cc
浏览文件 @
34f7b509
...
@@ -13,7 +13,6 @@
...
@@ -13,7 +13,6 @@
// limitations under the License.
// limitations under the License.
#include "lite/kernels/arm/conv_winograd.h"
#include "lite/kernels/arm/conv_winograd.h"
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm.h"
#include "lite/backends/arm/math/packed_sgemm.h"
...
@@ -166,6 +165,186 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
...
@@ -166,6 +165,186 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
}
}
}
}
+template <PrecisionType OutType>
+void WinogradConv<PRECISION(kInt8), OutType>::ReInitWhenNeeded() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<ARMContext>();
+  int threads = ctx.threads();
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+  if (last_shape_ == x_dims) {
+    return;
+  }
+  last_shape_ = x_dims;
+  //! update workspace size
+  int ic = x_dims[1];
+  int ih = x_dims[2];
+  int iw = x_dims[3];
+  int oc = o_dims[1];
+  int oh = o_dims[2];
+  int ow = o_dims[3];
+  int tile_block = 8;
+  auto pad = *(param.paddings);
+  int pad_h0 = pad[0];
+  int pad_h1 = pad[1];
+  int pad_w0 = pad[2];
+  int pad_w1 = pad[3];
+  int oc_pad = (oc + 7) / 8 * 8;
+  int ic_pad = (ic + 7) / 8 * 8;
+  const int new_input_size =
+      ic_pad * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1) +
+      oc_pad * oh * ow * sizeof(int32_t);
+  int tmp_input_thread_size_byte =
+      tile_block * ic_pad * wino_iw * wino_iw * sizeof(int16_t);
+  int tmp_output_thread_size_byte =
+      tile_block * oc_pad * wino_iw * wino_iw * sizeof(int32_t);
+  const int temp_size =
+      (tmp_input_thread_size_byte + tmp_output_thread_size_byte +
+       wino_iw * wino_iw * (8 + 8 * sizeof(int32_t))) *
+      threads;
+  workspace_size_ = temp_size + new_input_size;
+  //! update trans weights impl
+  // choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
+  // we only support 2x2 now
+  choose_small_ = true;
+  float w_fact = 0.25;
+  if (choose_small_) {
+    wino_iw = 4;
+    if (last_function_ == 0) {
+      return;
+    }
+    last_function_ = 0;
+  } else {
+    wino_iw = 6;
+    if (last_function_ == 1) {
+      return;
+    }
+    last_function_ = 1;
+  }
+  /// update scale
+  for (auto& ws : w_scale_) {
+    ws *= w_fact;
+  }
+  weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
+  void* trans_tmp_ptr = malloc(sizeof(int16_t) * wino_iw * wino_iw * oc * ic);
+  auto weights_data_ = weights_.mutable_data<int16_t>();
+  if (!choose_small_) {
+  } else {
+    lite::arm::math::weight_trans_c8_4x4_int8(
+        weights_data_,
+        param.filter->template data<int8_t>(),
+        ic,
+        oc,
+        trans_tmp_ptr);
+  }
+  free(trans_tmp_ptr);
+}
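ReInitWhenNeeded hard-codes choose_small_ = true, so the kernel always uses the small transform F(2x2, 3x3): a 4x4 input tile (wino_iw = 4) yields a 2x2 output tile, while the commented-out heuristic would switch to the 6x6 tile of F(4x4, 3x3). The w_fact = 0.25 rescaling is consistent with the usual integer-friendly form of the F(2x2, 3x3) weight transform; the matrices below are the standard ones, and the 2G scaling is an assumption about the int16 weight kernel (weight_trans_c8_4x4_int8 itself is not part of this hunk):

U = G\,g\,G^{\top},\qquad
G = \begin{pmatrix} 1 & 0 & 0 \\ \tfrac{1}{2} & \tfrac{1}{2} & \tfrac{1}{2} \\ \tfrac{1}{2} & -\tfrac{1}{2} & \tfrac{1}{2} \\ 0 & 0 & 1 \end{pmatrix}

Replacing G with 2G keeps every transform entry integral for int8 weights, but it scales U = (2G) g (2G)^T by 4, so the per-channel dequantization scale must shrink by a factor of 1/4 — exactly the ws *= w_fact loop above.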
+template <PrecisionType OutType>
+void WinogradConv<PRECISION(kInt8), OutType>::PrepareForRun() {
+  auto& param = this->Param<param_t>();
+  w_scale_ = param.weight_scale;
+  if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) {
+    LOG(FATAL) << "weights scale size must equal to filter size";
+    return;
+  }
+  if (w_scale_.size() == 1) {
+    for (int i = 0; i < param.filter->dims()[0] - 1; ++i) {
+      w_scale_.push_back(w_scale_[0]);
+    }
+  }
+  float input_scale = param.input_scale;
+  for (auto& ws : w_scale_) {
+    ws *= input_scale;
+  }
+  if (param.bias) {
+    bias_.Resize(param.bias->dims());
+    auto ptr = bias_.mutable_data<float>();
+    auto ptr_in = param.bias->template data<float>();
+    for (int i = 0; i < bias_.numel(); ++i) {
+      ptr[i] = ptr_in[i];
+    }
+  }
+  if (OutType == PRECISION(kInt8)) {
+    float output_scale = param.output_scale;
+    for (auto& ws : w_scale_) {
+      ws /= output_scale;
+    }
+    if (param.bias) {
+      auto ptr = bias_.mutable_data<float>();
+      for (int i = 0; i < bias_.numel(); ++i) {
+        ptr[i] /= output_scale;
+      }
+    }
+  }
+  ReInitWhenNeeded();
+}
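PrepareForRun folds every quantization factor into w_scale_ once, so the compute path only needs a single per-output-channel float multiplier for its int32 accumulators. After PrepareForRun (which ends by calling ReInitWhenNeeded, contributing the Winograd factor), the effective scale is input_scale * weight_scale[c] * w_fact for float output, additionally divided by output_scale for int8 output; the float bias is pre-divided by output_scale in the same branch so the epilogue can treat both output types uniformly.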
+template <PrecisionType OutType>
+void WinogradConv<PRECISION(kInt8), OutType>::Run() {
+  auto& param = this->Param<param_t>();
+  auto& ctx = this->ctx_->template As<ARMContext>();
+  ctx.ExtendWorkspace(workspace_size_);
+  const auto* i_data = param.x->template data<int8_t>();
+  const auto* w_data = weights_.data<int16_t>();
+  const auto* b_data = param.bias ? bias_.data<float>() : nullptr;
+  // const float* i_data;
+  auto x_dims = param.x->dims();
+  auto w_dims = param.filter->dims();
+  auto o_dims = param.output->dims();
+  int iw = x_dims[3];  // nchw
+  int ih = x_dims[2];
+  int ic = x_dims[1];
+  int bs = x_dims[0];
+  int oh = o_dims[2];
+  int ow = o_dims[3];
+  int oc = o_dims[1];
+  // now always choose small
+  if (OutType == PRECISION(kInt8)) {
+    auto* o_data = param.output->template mutable_data<int8_t>();
+    lite::arm::math::conv_compute_2x2_3x3_int8<int8_t>(i_data,
+                                                       o_data,
+                                                       bs,
+                                                       oc,
+                                                       oh,
+                                                       ow,
+                                                       ic,
+                                                       ih,
+                                                       iw,
+                                                       w_data,
+                                                       b_data,
+                                                       w_scale_.data(),
+                                                       param,
+                                                       &ctx);
+  } else {
+    auto* o_data = param.output->template mutable_data<float>();
+    lite::arm::math::conv_compute_2x2_3x3_int8<float>(i_data,
+                                                      o_data,
+                                                      bs,
+                                                      oc,
+                                                      oh,
+                                                      ow,
+                                                      ic,
+                                                      ih,
+                                                      iw,
+                                                      w_data,
+                                                      b_data,
+                                                      w_scale_.data(),
+                                                      param,
+                                                      &ctx);
+  }
+}
+template class WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
+template class WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
...
lite/kernels/arm/conv_winograd.h
View file @ 34f7b509
...
@@ -15,11 +15,11 @@
 #pragma once
 #include <cmath>
+#include <vector>
 #include "lite/backends/arm/math/conv_impl.h"
 #include "lite/core/context.h"
 #include "lite/core/kernel.h"
 #include "lite/core/target_wrapper.h"
 namespace paddle {
 namespace lite {
 namespace kernels {
...
@@ -44,7 +44,27 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
   bool choose_small_{false};
   int wino_iw{8};
 };
+template <PrecisionType OutType>
+class WinogradConv<PRECISION(kInt8), OutType>
+    : public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
+ public:
+  WinogradConv() = default;
+  ~WinogradConv() {}
+  virtual void PrepareForRun();
+  virtual void ReInitWhenNeeded();
+  virtual void Run();
+ protected:
+  using param_t = operators::ConvParam;
+  Tensor weights_;
+  Tensor bias_;
+  DDim last_shape_;
+  int workspace_size_{0};
+  int last_function_{-1};
+  bool choose_small_{true};
+  int wino_iw{4};
+  std::vector<float> w_scale_;
+};
 }  // namespace arm
 }  // namespace kernels
 }  // namespace lite
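The int8 path is a partial specialization keyed only on the input precision; OutType merely selects the epilogue type, and last_shape_ / last_function_ let ReInitWhenNeeded skip re-transforming weights when neither the input shape nor the chosen transform has changed. The two concrete kernels this header supports are the ones instantiated explicitly in conv_winograd.cc:

template class WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;   // int8 in, int8 out
template class WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;  // int8 in, float out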
...
lite/tests/math/sgemm_c4_compute_test.cc
View file @ 34f7b509
...
@@ -179,6 +179,141 @@ bool test_sgemm_c4(
 #endif
   return true;
 }
+bool test_sgemm_c8(
+    int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) {
+  int m_round = (m + 7) / 8 * 8;
+  int k_round = (k + 7) / 8 * 8;
+  int size_a = m * k;
+  int size_b = n * k;
+  int size_a_c4 = m_round * k_round;
+  int size_b_c8 = k_round * n;
+  Tensor ta;
+  Tensor tb;
+  Tensor ta_c4;
+  Tensor tb_c8;
+  Tensor tc;
+  Tensor tc_basic;
+  Tensor tc_backup;
+  Tensor tbias;
+  ta.Resize({size_a});
+  tb.Resize({size_b});
+  ta_c4.Resize({size_a_c4});
+  tb_c8.Resize({size_b_c8});
+  tc.Resize({m_round * n});
+  tc_basic.Resize({m_round * n});
+  tbias.Resize({m});
+  ta.set_precision(PRECISION(kInt16));
+  tb.set_precision(PRECISION(kInt16));
+  ta_c4.set_precision(PRECISION(kInt16));
+  tb_c8.set_precision(PRECISION(kInt16));
+  tc.set_precision(PRECISION(kInt32));
+  tc_basic.set_precision(PRECISION(kInt32));
+  tbias.set_precision(PRECISION(kInt32));
+  fill_tensor_rand(ta);
+  fill_tensor_rand(tb);
+  fill_tensor_rand(tbias);
+  fill_tensor_rand(tc);
+  auto da = ta.mutable_data<int16_t>();
+  auto db = tb.mutable_data<int16_t>();
+  auto da_c4 = ta_c4.mutable_data<int16_t>();
+  auto db_c8 = tb_c8.mutable_data<int16_t>();
+  auto dc_basic = tc_basic.mutable_data<int32_t>();
+  auto dbias = tbias.mutable_data<int32_t>();
+  // trans A, B to c4
+  basic_trans_mat_to_c8(da, da_c4, k, m, k, true);
+  basic_trans_mat_to_c8(db, db_c8, n, k, n, false);
+  LOG(INFO) << "sgemm_c8 M: " << m << ", N: " << n << ", K: " << k
+            << ", relu: " << (has_relu ? "true" : "false")
+            << ", bias: " << (has_bias ? "true" : "false");
+  if (FLAGS_check_result) {
+    basic_gemm_c8(false,
+                  false,
+                  m,
+                  n,
+                  k,
+                  1,
+                  da,
+                  k,
+                  db,
+                  n,
+                  0,
+                  dc_basic,
+                  n,
+                  dbias,
+                  false,
+                  false);
+  }
+  Timer t0;
+  LOG(INFO) << "basic test end";
+#ifdef LITE_WITH_ARM
+  //! compute
+  double ops = 2.0 * m_round * n * k_round;
+  std::unique_ptr<paddle::lite::KernelContext> ctx1(
+      new paddle::lite::KernelContext);
+  auto& ctx = ctx1->As<paddle::lite::ARMContext>();
+  ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
+  auto dc = tc.mutable_data<int32_t>();
+  for (int j = 0; j < FLAGS_warmup; ++j) {
+    paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
+        m, n, k, da_c4, db_c8, dc, &ctx);
+  }
+  LOG(INFO) << "basic test end";
+  for (int i = 0; i < FLAGS_repeats; ++i) {
+    t0.Start();
+    paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
+        m, n, k, da_c4, db_c8, dc, &ctx);
+    t0.Stop();
+  }
+  LOG(INFO) << "basic test end";
+  LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k
+            << ", power_mode: " << cls << ", threads: " << ths
+            << ", GOPS: " << ops * 1e-9f
+            << " GOPS, avg time: " << t0.LapTimes().Avg()
+            << " ms, min time: " << t0.LapTimes().Min()
+            << " ms, mean GOPs: " << ops * 1e-6f / t0.LapTimes().Avg()
+            << " GOPs, max GOPs: " << ops * 1e-6f / t0.LapTimes().Min()
+            << " GOPs";
+  if (FLAGS_check_result) {
+    double max_ratio = 0;
+    double max_diff = 0;
+    tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
+    LOG(INFO) << "compare result, max diff: " << max_diff
+              << ", max ratio: " << max_ratio;
+    if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
+      Tensor tdiff;
+      tdiff.set_precision(PRECISION(kInt32));
+      tdiff.Resize(tc.dims());
+      tensor_diff(tc_basic, tc, tdiff);
+      LOG(INFO) << "a: ";
+      print_tensor(ta);
+      LOG(INFO) << "a_c8: ";
+      print_tensor(ta_c4);
+      LOG(INFO) << "b: ";
+      print_tensor(tb);
+      LOG(INFO) << "b_c8: ";
+      print_tensor(tb_c8);
+      LOG(INFO) << "basic result: ";
+      print_tensor(tc_basic);
+      LOG(INFO) << "lite result: ";
+      print_tensor(tc);
+      LOG(INFO) << "diff result: ";
+      print_tensor(tdiff);
+      return false;
+    }
+  }
+#endif
+  return true;
+}
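The benchmark counts work on the padded shape, since the packed kernel always computes full 8-wide blocks: ops = 2.0 * m_round * n * k_round treats each multiply-accumulate as two operations. For instance, m = 77 rounds up to m_round = 80 and k = 19 to k_round = 24 under (x + 7) / 8 * 8, so one call is scored as 2 * 80 * n * 24 operations even though the extra rows are zero padding.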
 TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
   if (FLAGS_basic_test) {
...
@@ -186,11 +321,11 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
     paddle::lite::DeviceInfo::Init();
 #endif
     LOG(INFO) << "run basic sgemm_c4 test";
-    for (auto& m : {1, 3, 8, 32, 397}) {
-      for (auto& n : {1, 2, 3, 4, 13, 141, 789}) {
-        for (auto& k : {1, 3, 8, 59, 234}) {
-          for (auto& has_bias : {false, true}) {
-            for (auto& has_relu : {false, true}) {
+    for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
+      for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
+        for (auto& k : {1, 3, 8, 59, 234, 19}) {
+          for (auto& has_bias : {false}) {
+            for (auto& has_relu : {false}) {
               for (auto& th : {1, 2, 4}) {
                 auto flag = test_sgemm_c4(
                     m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
...
@@ -213,8 +348,41 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
       }
     }
   }
+TEST(TestSgemmC8, test_func_sgemm_c8_prepacked) {
+  if (FLAGS_basic_test) {
+#ifdef LITE_WITH_ARM
+    paddle::lite::DeviceInfo::Init();
+#endif
+    LOG(INFO) << "run basic sgemm_c4 test";
+    for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
+      for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
+        for (auto& k : {1, 3, 8, 59, 234, 19}) {
+          for (auto& has_bias : {false}) {
+            for (auto& has_relu : {false}) {
+              for (auto& th : {1}) {
+                auto flag = test_sgemm_c8(
+                    m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
+                if (flag) {
+                  LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
+                            << ", bias: " << (has_bias ? "true" : "false")
+                            << ", relu: " << (has_relu ? "true" : "false")
+                            << " passed\n";
+                } else {
+                  LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
+                             << ", bias: " << (has_bias ? "true" : "false")
+                             << ", relu: " << (has_relu ? "true" : "false")
+                             << " failed\n";
+                }
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+}
-TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
+TEST(TestSgemmCnCustom, test_func_sgemm_cn_prepacked_custom) {
 #ifdef LITE_WITH_ARM
   paddle::lite::DeviceInfo::Init();
 #endif
...
@@ -230,6 +398,18 @@ TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
                << ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
                << ", relu: " << FLAGS_flag_relu << " failed!!";
   }
+  flag = test_sgemm_c8(FLAGS_M,
+                       FLAGS_N,
+                       FLAGS_K,
+                       FLAGS_flag_bias,
+                       FLAGS_flag_relu,
+                       FLAGS_power_mode,
+                       FLAGS_threads);
+  if (!flag) {
+    LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
+               << ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
+               << ", relu: " << FLAGS_flag_relu << " failed!!";
+  }
   LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
             << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
             << " passed!!";
...
lite/tests/utils/naive_math_impl.h
View file @ 34f7b509
...
@@ -60,6 +60,72 @@ static void basic_trans_mat_to_c4(const type* input,
     }
   }
 }
+template <typename type>
+static void basic_trans_mat_to_c8(const type* input,
+                                  type* output,
+                                  const int ldin,
+                                  const int M,
+                                  const int K,
+                                  bool pack_k) {
+  const int m_round = (M + 7) / 8 * 8;
+  int k_round = (K + 7) / 8 * 8;
+  if (!pack_k) {
+    k_round = K;
+  }
+  const int m_loop = m_round / 8;
+  type zero_buf[K];
+  memset(zero_buf, 0, K * sizeof(type));
+  for (int i = 0; i < m_loop; ++i) {
+    const type* in0 = input + i * 8 * ldin;
+    const type* in1 = in0 + ldin;
+    const type* in2 = in1 + ldin;
+    const type* in3 = in2 + ldin;
+    const type* in4 = in3 + ldin;
+    const type* in5 = in4 + ldin;
+    const type* in6 = in5 + ldin;
+    const type* in7 = in6 + ldin;
+    if (8 * (i + 1) - M > 0) {
+      switch (8 * (i + 1) - M) {
+        case 7:
+          in1 = zero_buf;
+        case 6:
+          in2 = zero_buf;
+        case 5:
+          in3 = zero_buf;
+        case 4:
+          in4 = zero_buf;
+        case 3:
+          in5 = zero_buf;
+        case 2:
+          in6 = zero_buf;
+        case 1:
+          in7 = zero_buf;
+        default:
+          break;
+      }
+    }
+    for (int j = 0; j < K; ++j) {
+      *output++ = *in0++;
+      *output++ = *in1++;
+      *output++ = *in2++;
+      *output++ = *in3++;
+      *output++ = *in4++;
+      *output++ = *in5++;
+      *output++ = *in6++;
+      *output++ = *in7++;
+    }
+    for (int j = K; j < k_round; ++j) {
+      *output++ = static_cast<type>(0);
+      *output++ = static_cast<type>(0);
+      *output++ = static_cast<type>(0);
+      *output++ = static_cast<type>(0);
+      *output++ = static_cast<type>(0);
+      *output++ = static_cast<type>(0);
+      *output++ = static_cast<type>(0);
+      *output++ = static_cast<type>(0);
+    }
+  }
+}
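A concrete picture of the c8 layout (a minimal sketch with made-up values): each block of eight consecutive rows is interleaved column by column, and the intentional switch fall-through above simply redirects every out-of-range row pointer to zero_buf. Packing a 3x2 row-major matrix with pack_k = false gives:

int16_t a[6] = {1, 2,   // row 0
                3, 4,   // row 1
                5, 6};  // row 2
int16_t a_c8[16];       // one 8-row block times K = 2 columns
basic_trans_mat_to_c8(a, a_c8, /*ldin=*/2, /*M=*/3, /*K=*/2, /*pack_k=*/false);
// a_c8 == {1, 3, 5, 0, 0, 0, 0, 0,   // column 0 of rows 0..7
//          2, 4, 6, 0, 0, 0, 0, 0}   // column 1 of rows 0..7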
 template <typename type, typename type2>
 static void basic_gemm_c4(bool trans_a,
...
@@ -116,6 +182,60 @@ static void basic_gemm_c4(bool trans_a,
   free(tmp_c);
 }
+template <typename type, typename type2>
+static void basic_gemm_c8(bool trans_a,
+                          bool trans_b,
+                          int m,
+                          int n,
+                          int k,
+                          type2 alpha,
+                          const type* a,
+                          int lda,
+                          const type* b,
+                          int ldb,
+                          type2 beta,
+                          type2* c,
+                          int ldc,
+                          const type2* bias,
+                          bool flag_bias = false,
+                          bool flag_relu = false) {
+  type2* tmp_c = reinterpret_cast<type2*>(malloc(m * ldc * sizeof(type2)));
+  memset(tmp_c, 0, m * ldc * sizeof(type2));
+#pragma omp parallel for
+  for (int i = 0; i < m; ++i) {
+    auto bias_data = static_cast<type2>(0);
+    if (flag_bias) {
+      bias_data = bias[i];
+    }
+    for (int j = 0; j < n; ++j) {
+      auto sum = static_cast<type2>(0);
+      for (int l = 0; l < k; ++l) {
+        type av;
+        type bv;
+        if (trans_a) {
+          av = a[l * lda + i];
+        } else {
+          av = a[i * lda + l];
+        }
+        if (trans_b) {
+          bv = b[j * ldb + l];
+        } else {
+          bv = b[l * ldb + j];
+        }
+        sum += av * bv;
+      }
+      type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data;
+      if (flag_relu) {
+        tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0;
+      } else {
+        tmp_c[i * ldc + j] = tmp;
+      }
+    }
+  }
+  //! trans c to c4
+  basic_trans_mat_to_c8(tmp_c, c, ldc, m, n, false);
+  free(tmp_c);
+}
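Design note: the reference first computes C densely in row-major order and only then repacks it, so basic_trans_mat_to_c8(tmp_c, c, ldc, m, n, false) treats C's m rows as the dimension to round up to a multiple of eight, while pack_k = false leaves the n columns unpadded — which is why the test above sizes its result buffers as m_round * n.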
 template <typename type, typename type2>
 static void basic_gemm(bool trans_a,
                        bool trans_b,
...
lite/tests/utils/tensor_utils.h
View file @ 34f7b509
...
@@ -41,6 +41,10 @@ void fill_tensor_const(Tensor& tensor, float value) { // NOLINT
       fill_tensor_host_const_impl(
           tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size);
       break;
+    case PRECISION(kInt16):
+      fill_tensor_host_const_impl(
+          tensor.mutable_data<int16_t>(), static_cast<int16_t>(value), size);
+      break;
     case PRECISION(kInt32):
       fill_tensor_host_const_impl(
           tensor.mutable_data<int>(), static_cast<int>(value), size);
...
@@ -69,6 +73,12 @@ void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) {
   }
 }
 template <>
+void fill_tensor_host_rand_impl<int16_t>(int16_t* dio, int64_t size) {
+  for (int64_t i = 0; i < size; ++i) {
+    dio[i] = (rand() % 256 - 128) * 2;  // NOLINT
+  }
+}
+template <>
 void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio,
                                                int64_t size) {
   for (int64_t i = 0; i < size; ++i) {
...
@@ -86,6 +96,9 @@ void fill_tensor_rand(Tensor& tensor) { // NOLINT
     case PRECISION(kInt8):
      fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size);
       break;
+    case PRECISION(kInt16):
+      fill_tensor_host_rand_impl(tensor.mutable_data<int16_t>(), size);
+      break;
     case PRECISION(kInt32):
       fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size);
       break;
...
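The int16 filler produces even values in [-256, 254] ((rand() % 256 - 128) * 2), i.e. roughly 9-bit magnitudes rather than the full int16 range; a plausible reason is that Winograd-transformed int8 data outgrows 8 bits, so the test feeds the int16 c8 GEMM values of that magnitude while keeping products and int32 accumulations comfortably in range.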