magicwindyyd / mindspore (fork of MindSpore / mindspore)
Commit 9cab9f7a
Authored on Aug 01, 2020 by zhanyuan
1. Multithreading support for fc_int8_op. 2. Change asm matmul output layout from col8x8 to row8x8.
Parent: 6c4ee3f3
Showing 9 changed files with 362 additions and 256 deletions (+362 -256).
mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc                +54  -17
mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h                 +5   -3
mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s             +143 -163
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc                      +4   -0
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h                       +9   -6
mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc                      +0   -51
mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h                       +1   -15
mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h             +2   -1
mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc  +144 -0
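Before the per-file diffs: the heart of the multithreading change is how the padded output columns (col_8_) are split into 8-wide blocks across worker tasks. Below is a standalone sketch of that partitioning arithmetic; UP_DIV and MSMIN are reimplemented here for illustration, and the variable names mirror the kernel's fields.

```cpp
#include <algorithm>
#include <cstdio>

// Illustration only: UpDiv stands in for the opclib UP_DIV macro.
static int UpDiv(int a, int b) { return (a + b - 1) / b; }

int main() {
  int col_8 = 24;  // padded output columns (already a multiple of 8)
  // Never spawn more tasks than there are 8-wide column blocks.
  int thread_count = std::min(4, UpDiv(col_8, 8));
  // Blocks handed to each task.
  int thread_stride = UpDiv(UpDiv(col_8, 8), thread_count);
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    // Same expression as RunImpl's cur_oc: clamp the last task's share.
    int cur_oc = std::min(thread_stride, UpDiv(col_8, 8) - task_id * thread_stride);
    if (cur_oc <= 0) continue;  // trailing tasks may get nothing
    printf("task %d handles %d output columns starting at column %d\n", task_id,
           cur_oc * 8, task_id * thread_stride * 8);
  }
  return 0;
}
```

With col_8 = 24 and four requested threads, only three tasks get work, one 8-column block each. The per-file diffs follow.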
mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.cc
```diff
@@ -17,6 +17,7 @@
 #include "src/runtime/kernel/arm/int8/fullconnection_int8.h"
 #include "src/runtime/kernel/arm/opclib/int8/matmul.h"
 #include "src/runtime/kernel/arm/opclib/common_func.h"
+#include "src/runtime/runtime_api.h"
 #include "include/errorcode.h"

 using mindspore::lite::RET_MEMORY_FAILED;
```
```diff
@@ -25,22 +26,42 @@ using mindspore::lite::RET_OK;
 namespace mindspore::kernel {
 int FullconnectionInt8CPUKernel::Init() {
   fc_param_->row_ = (inputs_[0]->shape())[0];
-  fc_param_->col_ = (inputs_[1]->shape())[1];
-  fc_param_->deep_ = (inputs_[1]->shape())[0];
+  fc_param_->col_ = (inputs_[1]->shape())[0];
+  fc_param_->deep_ = (inputs_[1]->shape())[1];
   fc_param_->row_8_ = UP_ROUND(fc_param_->row_, 8);
   fc_param_->col_8_ = UP_ROUND(fc_param_->col_, 8);
+  thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
+  thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
   a_c8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t)));
+  if (!a_c8_ptr_) {
+    return RET_MEMORY_FAILED;
+  }
+  memset(a_c8_ptr_, 0, fc_param_->row_8_ * fc_param_->deep_ * sizeof(int8_t));
   b_r8_ptr_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t)));
+  if (!b_r8_ptr_) {
+    return RET_MEMORY_FAILED;
+  }
+  memset(b_r8_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(int8_t));
+  auto weight_data = reinterpret_cast<int8_t *>(inputs_[1]->Data());
+  RowMajor2Col8MajorInt8(weight_data, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
   c_r8x8_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int)));
-  if (!a_c8_ptr_ || !b_r8_ptr_ || !c_r8x8_ptr_) {
-    return RET_MEMORY_FAILED;
-  }
+  if (!c_r8x8_ptr_) {
+    return RET_MEMORY_FAILED;
+  }
+  memset(c_r8x8_ptr_, 0, fc_param_->row_8_ * fc_param_->col_8_ * sizeof(int));
+  auto bias_len = fc_param_->col_8_ * sizeof(int);
+  bias_ptr_ = reinterpret_cast<int *>(ctx_->allocator->Malloc(bias_len));
+  if (!bias_ptr_) {
+    return RET_MEMORY_FAILED;
+  }
+  memset(bias_ptr_, 0, bias_len);
+  if (inputs_.size() == 3) {
+    memcpy(bias_ptr_, inputs_[2]->Data(), bias_len);
+  }
   auto input_tensor = inputs_[0];
   auto params = input_tensor->GetQuantParams();
```
```diff
@@ -59,7 +80,8 @@ int FullconnectionInt8CPUKernel::Init() {
   quant_params_.output.scale_ = params.front().scale;
   double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
-  QuantizeMultiplier(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.output_shift);
+  QuantizeRoundParameter(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
+                         &quant_params_.right_shift);
   CalculateActivationRangeQuantized(fc_param_->maxf_, fc_param_->minf_, quant_params_.output.scale_,
                                     quant_params_.output.zp_, &quant_params_.out_act_max, &quant_params_.out_act_min);
```
```diff
@@ -68,22 +90,37 @@ int FullconnectionInt8CPUKernel::Init() {
 int FullconnectionInt8CPUKernel::ReSize() { return RET_OK; }

-int FullconnectionInt8CPUKernel::Run() {
-  auto a_ptr = reinterpret_cast<int8_t *>(inputs_.at(0)->Data());
-  auto b_ptr = reinterpret_cast<int8_t *>(inputs_.at(1)->Data());
-  auto bias_ptr = reinterpret_cast<int *>(inputs_.at(2)->Data());
-  auto output_ptr = reinterpret_cast<int8_t *>(outputs_.at(0)->Data());
-  auto &p = quant_params_;
-  // rows*depth -> rows*depth, col_8 major
-  RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
-  // cols*depth -> cols*depth, col_8 major == depth*cols, row_8 major
-  RowMajor2Col8MajorInt8(b_ptr, b_r8_ptr_, fc_param_->col_, fc_param_->deep_);
-  MatMulInt8(a_c8_ptr_, b_r8_ptr_, c_r8x8_ptr_, fc_param_->row_8_, fc_param_->col_8_, fc_param_->deep_,
-             p.input.zp_, p.weight.zp_);
-  PostFuncInt8(c_r8x8_ptr_, bias_ptr, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->col_8_,
-               fc_param_->row_8_, p.quant_multiplier, p.output_shift, p.output.zp_, p.out_act_min, p.out_act_max);
+int FullconnectionInt8CPUKernel::RunImpl(int task_id) {
+  int cur_oc = MSMIN(thread_stride_, UP_DIV(fc_param_->col_8_, 8) - task_id * thread_stride_);
+  if (cur_oc <= 0) {
+    return RET_OK;
+  }
+  auto &p = quant_params_;
+  auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_;
+  auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_;
+  MatMulInt8(a_c8_ptr_, cur_b, cur_c, fc_param_->row_8_, cur_oc * 8, fc_param_->deep_, p.input.zp_, p.weight.zp_);
+  return RET_OK;
+}
+
+int FcInt8Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
+  auto fc = reinterpret_cast<FullconnectionInt8CPUKernel *>(cdata);
+  auto ret = fc->RunImpl(task_id);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "FcInt8Run error task_id[" << task_id << "] error_code[" << ret << "]";
+    return ret;
+  }
+  return RET_OK;
+}
+
+int FullconnectionInt8CPUKernel::Run() {
+  auto a_ptr = reinterpret_cast<int8_t *>(inputs_[0]->Data());
+  auto output_ptr = reinterpret_cast<int8_t *>(outputs_[0]->Data());
+  auto &p = quant_params_;
+  RowMajor2Col8MajorInt8(a_ptr, a_c8_ptr_, fc_param_->row_, fc_param_->deep_);
+  LiteBackendParallelLaunch(FcInt8Run, this, thread_count_);
+  PostFuncInt8(c_r8x8_ptr_, bias_ptr_, output_ptr, fc_param_->col_, fc_param_->row_, fc_param_->row_8_,
+               p.quant_multiplier, p.left_shift, p.right_shift, p.output.zp_, p.out_act_min, p.out_act_max);
   return RET_OK;
 }
 }  // namespace mindspore::kernel
```
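FcInt8Run above is a plain trampoline: LiteBackendParallelLaunch invokes it once per task_id with the kernel object as cdata, and each invocation computes one column slice via RunImpl. A serial stand-in for the launcher shows the contract the kernel relies on; ParallelLaunchSerial is hypothetical (the real callback also receives a LiteParallelGroupEnv *, omitted here), and the real runtime runs the tasks concurrently.

```cpp
#include <cstdio>

// Hypothetical serial stand-in for LiteBackendParallelLaunch. The contract is
// simply "call fn(task_id, cdata) once for every task_id in [0, thread_count)".
using TaskFn = int (*)(int task_id, void *cdata);

static int ParallelLaunchSerial(TaskFn fn, void *cdata, int thread_count) {
  for (int task_id = 0; task_id < thread_count; ++task_id) {
    int ret = fn(task_id, cdata);
    if (ret != 0) return ret;  // mirror FcInt8Run's early exit on error
  }
  return 0;
}

static int DemoTask(int task_id, void *cdata) {
  printf("running slice for task %d of kernel %s\n", task_id,
         static_cast<const char *>(cdata));
  return 0;
}

int main() {
  const char *kernel_name = "fc_int8";
  return ParallelLaunchSerial(DemoTask, const_cast<char *>(kernel_name), 2);
}
```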
mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h
```diff
@@ -31,20 +31,22 @@ class FullconnectionInt8CPUKernel : public FullconnectionBaseCPUKernel {
                               const std::vector<lite::tensor::Tensor *> &outputs, const Context *ctx)
       : FullconnectionBaseCPUKernel(parameter, inputs, outputs, ctx) {}
   ~FullconnectionInt8CPUKernel() override {
-    free(a_c8_ptr_);
-    free(b_r8_ptr_);
-    free(c_r8x8_ptr_);
+    ctx_->allocator->Free(a_c8_ptr_);
+    ctx_->allocator->Free(b_r8_ptr_);
+    ctx_->allocator->Free(c_r8x8_ptr_);
   }

   int Init() override;
   int ReSize() override;
   int Run() override;
+  int RunImpl(int task_id);

  private:
   FcQuantArg quant_params_;
   int8_t *a_c8_ptr_;
   int8_t *b_r8_ptr_;
   int *c_r8x8_ptr_;
+  int *bias_ptr_;
 };
 }  // namespace mindspore::kernel
```
mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/matmul.s
@@ -17,17 +17,17 @@ and @@ -36,12 +36,12 @@: comment-diagram hunks (ASCII art, not reproduced here). The "LM 8x1 block" and "LM 8x4 block" input pictures and both 8x8 accumulator pictures are redrawn in row8x8 order (row 0 = v16.s[0..3] v17.s[0..3], ..., row 7 = v30.s[0..3] v31.s[0..3]) instead of the old col8x8 order (v16 ... v30 down one column), and the "OptLoopMul4 RHS 1x8 block" heading is renamed to "OptLoopMul4 RM 1x8 block".
```diff
@@ -64,25 +64,22 @@ MatMulFloatNeon64:
   mov w7, v0.s[0]
   mov w8, v1.s[0]
-  mov w9, 0    // row counter
-  mov w10, 0   // col counter
-  mov w18, #32
-  mul w15, w4, w18  // the stride of a or b
-  mul w16, w6, w18  // the stride of c
+  mov w9, 0    // rm col offset
+  mov w10, 0   // lm row offset
+  mov w18, #32      // sizeof(float)*8
+  mul w15, w4, w18  // the stride of lm/rm: sizeof(float)*8*depth

 L1:
-  cmp w9, w5
+  cmp w9, w6
   beq End1
-  mov w10, 0   // reset col counter
-  mov x12, x1  // reload b ptr
-  mov x17, x2  // reload current c ptr
+  mov w10, 0   // reset lm row offset
+  mov x12, x0  // reload lm ptr
+  mov x14, x3  // reload bias ptr

 L2:
   cmp w10, w6
   beq End2
-  mov x11, x0  // reload a ptr
   mov w13, w4  // reload depth
   dup v16.4s, wzr
   dup v17.4s, wzr
```
```diff
@@ -105,142 +102,127 @@ OptLoopMul4:
   cmp w13, #4
   blt CommLoopMul
-  ld1 {v0.4s}, [x11], #16
-  ld1 {v8.4s}, [x12], #16
-  fmla v16.4s, v0.4s, v8.s[0]
-  fmla v18.4s, v0.4s, v8.s[1]
-  ld1 {v1.4s}, [x11], #16
-  fmla v20.4s, v0.4s, v8.s[2]
-  fmla v22.4s, v0.4s, v8.s[3]
-  ld1 {v9.4s}, [x12], #16
-  fmla v25.4s, v1.4s, v9.s[0]
-  fmla v27.4s, v1.4s, v9.s[1]
-  fmla v29.4s, v1.4s, v9.s[2]
-  fmla v31.4s, v1.4s, v9.s[3]
-  ld1 {v2.4s}, [x11], #16
-  ld1 {v3.4s}, [x11], #16
-  fmla v24.4s, v0.4s, v9.s[0]
-  fmla v26.4s, v0.4s, v9.s[1]
-  fmla v28.4s, v0.4s, v9.s[2]
-  fmla v30.4s, v0.4s, v9.s[3]
-  fmla v17.4s, v1.4s, v8.s[0]
-  fmla v19.4s, v1.4s, v8.s[1]
-  fmla v21.4s, v1.4s, v8.s[2]
-  fmla v23.4s, v1.4s, v8.s[3]
-  ld1 {v10.4s}, [x12], #16
-  ld1 {v11.4s}, [x12], #16
-  fmla v16.4s, v2.4s, v10.s[0]
-  fmla v18.4s, v2.4s, v10.s[1]
-  fmla v20.4s, v2.4s, v10.s[2]
-  fmla v22.4s, v2.4s, v10.s[3]
-  fmla v25.4s, v3.4s, v11.s[0]
-  fmla v27.4s, v3.4s, v11.s[1]
-  fmla v29.4s, v3.4s, v11.s[2]
-  fmla v31.4s, v3.4s, v11.s[3]
-  ld1 {v4.4s}, [x11], #16
-  ld1 {v5.4s}, [x11], #16
-  fmla v24.4s, v2.4s, v11.s[0]
-  fmla v26.4s, v2.4s, v11.s[1]
-  fmla v28.4s, v2.4s, v11.s[2]
-  fmla v30.4s, v2.4s, v11.s[3]
-  fmla v17.4s, v3.4s, v10.s[0]
-  fmla v19.4s, v3.4s, v10.s[1]
-  fmla v21.4s, v3.4s, v10.s[2]
-  fmla v23.4s, v3.4s, v10.s[3]
-  ld1 {v12.4s}, [x12], #16
-  ld1 {v13.4s}, [x12], #16
-  fmla v16.4s, v4.4s, v12.s[0]
-  fmla v18.4s, v4.4s, v12.s[1]
-  fmla v20.4s, v4.4s, v12.s[2]
-  fmla v22.4s, v4.4s, v12.s[3]
-  fmla v25.4s, v5.4s, v13.s[0]
-  fmla v27.4s, v5.4s, v13.s[1]
-  fmla v29.4s, v5.4s, v13.s[2]
-  fmla v31.4s, v5.4s, v13.s[3]
-  ld1 {v6.4s}, [x11], #16
-  ld1 {v7.4s}, [x11], #16
-  fmla v24.4s, v4.4s, v13.s[0]
-  fmla v26.4s, v4.4s, v13.s[1]
-  fmla v28.4s, v4.4s, v13.s[2]
-  fmla v30.4s, v4.4s, v13.s[3]
-  fmla v17.4s, v5.4s, v12.s[0]
-  fmla v19.4s, v5.4s, v12.s[1]
-  fmla v21.4s, v5.4s, v12.s[2]
-  fmla v23.4s, v5.4s, v12.s[3]
-  ld1 {v14.4s}, [x12], #16
-  ld1 {v15.4s}, [x12], #16
-  fmla v16.4s, v6.4s, v14.s[0]
-  fmla v18.4s, v6.4s, v14.s[1]
-  fmla v20.4s, v6.4s, v14.s[2]
-  fmla v22.4s, v6.4s, v14.s[3]
-  fmla v25.4s, v7.4s, v15.s[0]
-  fmla v27.4s, v7.4s, v15.s[1]
-  fmla v29.4s, v7.4s, v15.s[2]
-  fmla v31.4s, v7.4s, v15.s[3]
-  fmla v24.4s, v6.4s, v15.s[0]
-  fmla v26.4s, v6.4s, v15.s[1]
-  fmla v28.4s, v6.4s, v15.s[2]
-  fmla v30.4s, v6.4s, v15.s[3]
-  fmla v17.4s, v7.4s, v14.s[0]
-  fmla v19.4s, v7.4s, v14.s[1]
-  fmla v21.4s, v7.4s, v14.s[2]
-  fmla v23.4s, v7.4s, v14.s[3]
+  ld1 {v0.4s, v1.4s}, [x12], #32
+  ld1 {v8.4s, v9.4s}, [x1], #32
+  fmla v16.4s, v8.4s, v0.s[0]
+  fmla v17.4s, v9.4s, v0.s[0]
+  fmla v18.4s, v8.4s, v0.s[1]
+  fmla v19.4s, v9.4s, v0.s[1]
+  fmla v20.4s, v8.4s, v0.s[2]
+  fmla v21.4s, v9.4s, v0.s[2]
+  fmla v22.4s, v8.4s, v0.s[3]
+  fmla v23.4s, v9.4s, v0.s[3]
+  ld1 {v10.4s, v11.4s}, [x1], #32
+  fmla v24.4s, v8.4s, v1.s[0]
+  fmla v25.4s, v9.4s, v1.s[0]
+  fmla v26.4s, v8.4s, v1.s[1]
+  fmla v27.4s, v9.4s, v1.s[1]
+  ld1 {v2.4s, v3.4s}, [x12], #32
+  fmla v28.4s, v8.4s, v1.s[2]
+  fmla v29.4s, v9.4s, v1.s[2]
+  fmla v30.4s, v8.4s, v1.s[3]
+  fmla v31.4s, v9.4s, v1.s[3]
+  fmla v16.4s, v10.4s, v2.s[0]
+  fmla v17.4s, v11.4s, v2.s[0]
+  fmla v18.4s, v10.4s, v2.s[1]
+  fmla v19.4s, v11.4s, v2.s[1]
+  fmla v20.4s, v10.4s, v2.s[2]
+  fmla v21.4s, v11.4s, v2.s[2]
+  fmla v22.4s, v10.4s, v2.s[3]
+  fmla v23.4s, v11.4s, v2.s[3]
+  ld1 {v12.4s, v13.4s}, [x1], #32
+  fmla v24.4s, v10.4s, v3.s[0]
+  fmla v25.4s, v11.4s, v3.s[0]
+  fmla v26.4s, v10.4s, v3.s[1]
+  fmla v27.4s, v11.4s, v3.s[1]
+  ld1 {v4.4s, v5.4s}, [x12], #32
+  fmla v28.4s, v10.4s, v3.s[2]
+  fmla v29.4s, v11.4s, v3.s[2]
+  fmla v30.4s, v10.4s, v3.s[3]
+  fmla v31.4s, v11.4s, v3.s[3]
+  fmla v16.4s, v12.4s, v4.s[0]
+  fmla v17.4s, v13.4s, v4.s[0]
+  fmla v18.4s, v12.4s, v4.s[1]
+  fmla v19.4s, v13.4s, v4.s[1]
+  fmla v20.4s, v12.4s, v4.s[2]
+  fmla v21.4s, v13.4s, v4.s[2]
+  fmla v22.4s, v12.4s, v4.s[3]
+  fmla v23.4s, v13.4s, v4.s[3]
+  ld1 {v6.4s, v7.4s}, [x12], #32
+  fmla v24.4s, v12.4s, v5.s[0]
+  fmla v25.4s, v13.4s, v5.s[0]
+  fmla v26.4s, v12.4s, v5.s[1]
+  fmla v27.4s, v13.4s, v5.s[1]
+  ld1 {v14.4s, v15.4s}, [x1], #32
+  fmla v28.4s, v12.4s, v5.s[2]
+  fmla v29.4s, v13.4s, v5.s[2]
+  fmla v30.4s, v12.4s, v5.s[3]
+  fmla v31.4s, v13.4s, v5.s[3]
+  fmla v16.4s, v14.4s, v6.s[0]
+  fmla v17.4s, v15.4s, v6.s[0]
+  fmla v18.4s, v14.4s, v6.s[1]
+  fmla v19.4s, v15.4s, v6.s[1]
+  fmla v20.4s, v14.4s, v6.s[2]
+  fmla v21.4s, v15.4s, v6.s[2]
+  fmla v22.4s, v14.4s, v6.s[3]
+  fmla v23.4s, v15.4s, v6.s[3]
+  fmla v24.4s, v14.4s, v7.s[0]
+  fmla v25.4s, v15.4s, v7.s[0]
+  fmla v26.4s, v14.4s, v7.s[1]
+  fmla v27.4s, v15.4s, v7.s[1]
+  fmla v28.4s, v14.4s, v7.s[2]
+  fmla v29.4s, v15.4s, v7.s[2]
+  fmla v30.4s, v14.4s, v7.s[3]
+  fmla v31.4s, v15.4s, v7.s[3]
   subs w13, w13, #4
   b OptLoopMul4

 CommLoopMul:
   cmp w13, #1
   blt Bias
-  ld1 {v0.4s}, [x11], #16
-  ld1 {v2.4s}, [x12], #16
-  fmla v16.4s, v0.4s, v2.s[0]
-  fmla v18.4s, v0.4s, v2.s[1]
-  ld1 {v1.4s}, [x11], #16
-  fmla v20.4s, v0.4s, v2.s[2]
-  fmla v22.4s, v0.4s, v2.s[3]
-  ld1 {v3.4s}, [x12], #16
-  fmla v25.4s, v1.4s, v3.s[0]
-  fmla v27.4s, v1.4s, v3.s[1]
-  fmla v29.4s, v1.4s, v3.s[2]
-  fmla v31.4s, v1.4s, v3.s[3]
-  fmla v24.4s, v0.4s, v3.s[0]
-  fmla v26.4s, v0.4s, v3.s[1]
-  fmla v28.4s, v0.4s, v3.s[2]
-  fmla v30.4s, v0.4s, v3.s[3]
-  fmla v17.4s, v1.4s, v2.s[0]
-  fmla v19.4s, v1.4s, v2.s[1]
-  fmla v21.4s, v1.4s, v2.s[2]
-  fmla v23.4s, v1.4s, v2.s[3]
+  ld1 {v0.4s, v1.4s}, [x12], #32
+  ld1 {v2.4s, v3.4s}, [x1], #32
+  fmla v16.4s, v2.4s, v0.s[0]
+  fmla v17.4s, v3.4s, v0.s[0]
+  fmla v18.4s, v2.4s, v0.s[1]
+  fmla v19.4s, v3.4s, v0.s[1]
+  fmla v20.4s, v2.4s, v0.s[2]
+  fmla v21.4s, v3.4s, v0.s[2]
+  fmla v22.4s, v2.4s, v0.s[3]
+  fmla v23.4s, v3.4s, v0.s[3]
+  fmla v24.4s, v2.4s, v1.s[0]
+  fmla v25.4s, v3.4s, v1.s[0]
+  fmla v26.4s, v2.4s, v1.s[1]
+  fmla v27.4s, v3.4s, v1.s[1]
+  fmla v28.4s, v2.4s, v1.s[2]
+  fmla v29.4s, v3.4s, v1.s[2]
+  fmla v30.4s, v2.4s, v1.s[3]
+  fmla v31.4s, v3.4s, v1.s[3]
   subs w13, w13, #1
   b CommLoopMul

 Bias:
   cmp x3, #0
   beq Relu
   ld1 {v0.4s}, [x14], #16
   ld1 {v1.4s}, [x14], #16
-  dup v2.4s, v0.s[0]
-  fadd v16.4s, v16.4s, v2.4s
-  fadd v17.4s, v17.4s, v2.4s
-  dup v3.4s, v0.s[1]
-  fadd v18.4s, v18.4s, v3.4s
-  fadd v19.4s, v19.4s, v3.4s
-  dup v4.4s, v0.s[2]
-  fadd v20.4s, v20.4s, v4.4s
-  fadd v21.4s, v21.4s, v4.4s
-  dup v5.4s, v0.s[3]
-  fadd v22.4s, v22.4s, v5.4s
-  fadd v23.4s, v23.4s, v5.4s
-  dup v2.4s, v1.s[0]
-  fadd v24.4s, v24.4s, v2.4s
-  fadd v25.4s, v25.4s, v2.4s
-  dup v3.4s, v1.s[1]
-  fadd v26.4s, v26.4s, v3.4s
-  fadd v27.4s, v27.4s, v3.4s
-  dup v4.4s, v1.s[2]
-  fadd v28.4s, v28.4s, v4.4s
-  fadd v29.4s, v29.4s, v4.4s
-  dup v5.4s, v1.s[3]
-  fadd v30.4s, v30.4s, v5.4s
-  fadd v31.4s, v31.4s, v5.4s
+  fadd v16.4s, v16.4s, v0.4s
+  fadd v17.4s, v17.4s, v1.4s
+  fadd v18.4s, v18.4s, v0.4s
+  fadd v19.4s, v19.4s, v1.4s
+  fadd v20.4s, v20.4s, v0.4s
+  fadd v21.4s, v21.4s, v1.4s
+  fadd v22.4s, v22.4s, v0.4s
+  fadd v23.4s, v23.4s, v1.4s
+  fadd v24.4s, v24.4s, v0.4s
+  fadd v25.4s, v25.4s, v1.4s
+  fadd v26.4s, v26.4s, v0.4s
+  fadd v27.4s, v27.4s, v1.4s
+  fadd v28.4s, v28.4s, v0.4s
+  fadd v29.4s, v29.4s, v1.4s
+  fadd v30.4s, v30.4s, v0.4s
+  fadd v31.4s, v31.4s, v1.4s

 Relu:
   dup v15.4s, w7
```
```diff
@@ -281,30 +263,28 @@ Relu:
   fmin v31.4s, v31.4s, v15.4s

 TransToOut:
-  st1 {v16.4s}, [x17], #16
-  st1 {v17.4s}, [x17], #16
-  st1 {v18.4s}, [x17], #16
-  st1 {v19.4s}, [x17], #16
-  st1 {v20.4s}, [x17], #16
-  st1 {v21.4s}, [x17], #16
-  st1 {v22.4s}, [x17], #16
-  st1 {v23.4s}, [x17], #16
-  st1 {v24.4s}, [x17], #16
-  st1 {v25.4s}, [x17], #16
-  st1 {v26.4s}, [x17], #16
-  st1 {v27.4s}, [x17], #16
-  st1 {v28.4s}, [x17], #16
-  st1 {v29.4s}, [x17], #16
-  st1 {v30.4s}, [x17], #16
-  st1 {v31.4s}, [x17], #16
+  st1 {v16.4s}, [x2], #16
+  st1 {v17.4s}, [x2], #16
+  st1 {v18.4s}, [x2], #16
+  st1 {v19.4s}, [x2], #16
+  st1 {v20.4s}, [x2], #16
+  st1 {v21.4s}, [x2], #16
+  st1 {v22.4s}, [x2], #16
+  st1 {v23.4s}, [x2], #16
+  st1 {v24.4s}, [x2], #16
+  st1 {v25.4s}, [x2], #16
+  st1 {v26.4s}, [x2], #16
+  st1 {v27.4s}, [x2], #16
+  st1 {v28.4s}, [x2], #16
+  st1 {v29.4s}, [x2], #16
+  st1 {v30.4s}, [x2], #16
+  st1 {v31.4s}, [x2], #16
-  add w10, w10, #8  // col += 8
+  add w10, w10, #8  // lhs row offset + 8
   b L2

 End2:
   add x0, x0, x15  // stride a ptr
   add x2, x2, x16  // stride c ptr
-  add w9, w9, #8  // row += 8
+  add w9, w9, #8  // rhs col offset + 8
   b L1

 End1:
```
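The new store sequence (v16 through v31 written back-to-back through x2) is what makes the output row8x8: each 8x8 tile is row-major internally, with row r of a tile occupying two consecutive vector registers. Below is a sketch of indexing element (r, c) in such a buffer; the within-tile layout follows the stores above, while the ordering of tiles across the matrix (row band by row band) is an assumption for illustration.

```cpp
#include <cassert>

// Element (r, c) of a matrix stored as consecutive 8x8 row-major tiles.
// Tile ordering across the matrix is assumed here (all column tiles of one
// row band, then the next band); only the within-tile layout is taken from
// the assembly's store order.
float RowTile8x8At(const float *buf, int col_8, int r, int c) {
  assert(col_8 % 8 == 0);
  int tile = (r / 8) * (col_8 / 8) + (c / 8);     // which 8x8 tile
  return buf[tile * 64 + (r % 8) * 8 + (c % 8)];  // row-major inside the tile
}
```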
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.cc
```diff
@@ -74,5 +74,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
 void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep,
             int row_8_, int col_8_) {
+#ifdef __aarch64__
+  MatMulFloatNeon64(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
+#else
   MatMul8x8(a, b, c, bias, maxf, minf, deep, row_8_, col_8_);
+#endif
 }
```
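MatMul and MatMul8x8 take dimensions already rounded up to multiples of 8 (row_8_, col_8_), which is why the fullconnection kernel's Init rounds with UP_ROUND and zero-fills its staging buffers. A minimal sketch of that preparation for a caller-managed buffer, with UP_ROUND reimplemented as UpRound8 for illustration:

```cpp
#include <cstring>
#include <vector>

// Pad `rows` up to the next multiple of 8 and zero-fill the tail, mirroring
// the UP_ROUND + memset pattern Init uses before calling the matmul routines.
static int UpRound8(int x) { return ((x + 7) / 8) * 8; }

std::vector<float> PadRowsTo8(const float *src, int rows, int cols) {
  int row_8 = UpRound8(rows);
  std::vector<float> dst(static_cast<size_t>(row_8) * cols, 0.0f);  // zero padding
  for (int r = 0; r < rows; ++r) {
    std::memcpy(&dst[static_cast<size_t>(r) * cols],
                &src[static_cast<size_t>(r) * cols], sizeof(float) * cols);
  }
  return dst;
}
```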
mindspore/lite/src/runtime/kernel/arm/opclib/fp32/matmul.h
```diff
@@ -21,19 +21,22 @@
 #include "src/runtime/kernel/arm/opclib/op_base.h"
 #include "src/runtime/kernel/arm/opclib/matmul.h"

-#ifdef __cplusplus
-extern "C" {
-#endif
 void MatMul(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth,
             int row, int col);
 void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, int row, int col);
 void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, int row, int col);
+void MatMul8x8(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int deep,
+               int row_8_, int col_8_);
+#ifdef __cplusplus
+extern "C" {
+#endif
+#ifdef __aarch64__
+void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf,
+                       int depth, int row, int col);
+#endif
 #ifdef __cplusplus
 }
 #endif

 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_FP32_MATMUL_H_
```
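Moving the extern "C" guard so that it wraps only MatMulFloatNeon64 keeps the C++ helpers normally mangled while the assembly entry point keeps an unmangled symbol that the .s label can satisfy. The pattern in isolation, with a hypothetical symbol asm_entry standing in for the real one:

```cpp
// The assembly file defines a bare label, e.g.
//   .global asm_entry      // in the .s file (hypothetical symbol)
// Without extern "C", the linker would look for a mangled C++ name
// (something like _Z9asm_entryPKfS0_Pf) and never find the assembly label.
#ifdef __cplusplus
extern "C" {
#endif
void asm_entry(const float *a, const float *b, float *c);  // hypothetical
#ifdef __cplusplus
}
#endif
```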
mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.cc
```diff
@@ -48,54 +48,3 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co
   }
   return;
 }
-
-// todo: need to delete, replace by above functions. z00445833
-void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
-  int col8 = UP_ROUND(col, 8);
-  for (int r = 0; r < row; r++) {
-    int rd8 = r / 8;
-    int rm8 = r % 8;
-    for (int c = 0; c < col; c++) {
-      dst_ptr[r * col + c] = src_ptr[rd8 * col8 * 8 + c * 8 + rm8];
-    }
-  }
-}
-
-void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data,
-                 int depth, FcQuantArg *params) {
-  int lhs_offset = params->input.zp_;
-  int rhs_offset = params->weight.zp_;
-  int output_offset = params->output.zp_;
-  int output_multiplier = params->quant_multiplier;
-  int output_shift = params->output_shift;
-  for (int row = 0; row < 8; ++row) {
-    for (int col = 0; col < 8; ++col) {
-      int c_index = col * 8 + row;
-      int acc = 0;
-      for (int d = 0; d < depth; ++d) {
-        int a_index = d * 8 + row;
-        int b_index = d * 8 + col;
-        acc += (lhs_data[a_index] - lhs_offset) * (rhs_data[b_index] - rhs_offset);
-      }
-      acc += bias_data[col];
-      acc = MultiplyByQuantizedMultiplier(acc, output_multiplier, output_shift, output_shift) + output_offset;
-      acc = MSMAX(CHAR_MIN, MSMIN(CHAR_MAX, acc));
-      output_data[c_index] = (int8_t)acc;
-    }
-  }
-}
-
-void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data,
-              int row_8, int col_8, int depth, FcQuantArg *params) {
-  for (int r = 0; r < row_8; r += 8) {
-    int8_t *output = output_data + r * col_8;
-    const int8_t *input = input_data + r * depth;
-    for (int c = 0; c < col_8; c += 8) {
-      const int8_t *bias = bias_data + c;
-      const int8_t *weights = weights_data + c * depth;
-      int8_t *dst = output + c * 8;
-      Gemm8x8Int8(input, weights, bias, dst, depth, params);
-    }
-  }
-}
```
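The deleted reference code requantized the int32 accumulator with MultiplyByQuantizedMultiplier. The usual implementation of that step is a gemmlowp-style rounding-doubling high multiply followed by a rounding right shift; the sketch below shows that standard scheme, and is not necessarily the exact opclib implementation.

```cpp
#include <cstdint>

// Standard fixed-point requantization: multiply by a Q31 multiplier with a
// rounding-doubling high mul, then apply a rounding arithmetic right shift.
// Sketch of the usual gemmlowp-style scheme; the opclib version may differ.
int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;     // only overflow case
  int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
  int64_t nudge = ab >= 0 ? (1ll << 30) : (1 - (1ll << 30));  // round half away
  return static_cast<int32_t>((ab + nudge) >> 31);
}

int32_t RoundingRightShift(int32_t x, int shift) {
  if (shift <= 0) return x;
  int32_t round = 1 << (shift - 1);
  return (x + round) >> shift;  // rounds half up (negative-value subtleties omitted)
}

int32_t Requantize(int32_t acc, int32_t multiplier, int left_shift, int right_shift) {
  // left_shift scales up before the high mul, right_shift scales down after,
  // matching the left_shift/right_shift pair added to FcQuantArg.
  return RoundingRightShift(
      SaturatingRoundingDoublingHighMul(acc * (1 << left_shift), multiplier), right_shift);
}
```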
mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h
```diff
@@ -20,23 +20,9 @@
 #include "src/runtime/kernel/arm/opclib/op_base.h"
 #include "src/runtime/kernel/arm/opclib/matmul.h"

 #ifdef __cplusplus
 extern "C" {
 #endif
 void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, const int col8, const int deep,
                 const int32_t a_zp, const int32_t b_zp);
 void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
-void GemmRowCol8x8Major2RowMajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
-void Gemm8x8Int8(const int8_t *lhs_data, const int8_t *rhs_data, const int8_t *bias_data, int8_t *output_data,
-                 int depth, FcQuantArg *params);
-void GemmInt8(const int8_t *input_data, const int8_t *weights_data, const int8_t *bias_data, int8_t *output_data,
-              int row_8, int col_8, int depth, FcQuantArg *params);
 #ifdef __cplusplus
 }
 #endif

-#endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_OPCLIB_INT8_MATMUL_H_
+#endif  // MINDSPORE_LITE_SRC_BACKEND_ARM_OPCLIB_INT8_MATMUL_H_
```
mindspore/lite/src/runtime/kernel/arm/opclib/quantization/quantize.h
```diff
@@ -54,7 +54,8 @@ struct FcQuantArg {
   QuantArg output;
   int32_t out_act_min;
   int32_t out_act_max;
-  int32_t output_shift;
+  int32_t left_shift;
+  int32_t right_shift;
   int32_t quant_multiplier;
 };
```
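Splitting output_shift into left_shift and right_shift lets the kernel handle real multipliers on either side of 1.0 with non-negative shift amounts. Below is a sketch of the usual frexp-based decomposition; that this is what QuantizeRoundParameter computes is an assumption here.

```cpp
#include <cmath>
#include <cstdint>

// Decompose real_multiplier into a Q31 fixed-point multiplier plus separate
// non-negative left/right shifts. Mirrors the common frexp-based scheme;
// whether QuantizeRoundParameter does exactly this is an assumption.
void QuantizeRoundParameterSketch(double real_multiplier, int32_t *quant_multiplier,
                                  int *left_shift, int *right_shift) {
  int shift = 0;
  double q = std::frexp(real_multiplier, &shift);  // real = q * 2^shift, q in [0.5, 1)
  int64_t q_fixed = static_cast<int64_t>(std::round(q * (1ll << 31)));
  if (q_fixed == (1ll << 31)) {  // rounding pushed q up to exactly 1.0
    q_fixed /= 2;
    ++shift;
  }
  *quant_multiplier = static_cast<int32_t>(q_fixed);
  *left_shift = shift > 0 ? shift : 0;    // scale up when multiplier >= 1
  *right_shift = shift > 0 ? 0 : -shift;  // scale down when multiplier < 1
}
```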
mindspore/lite/test/ut/src/runtime/kernel/arm/int8/fullconnection_int8_tests.cc
(new file, mode 100644)
```cpp
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <memory>
#include "utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/runtime/kernel/arm/int8/fullconnection_int8.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/int8/matmul.h"
#include "mindspore/lite/src/runtime/kernel/arm/opclib/common_func.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"

namespace mindspore {
using lite::tensor::Tensor;

class TestFcInt8 : public mindspore::Common {
 public:
  TestFcInt8() {}
};

void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
  for (int i = 0; i < length; ++i) {
    int8_t q = static_cast<int8_t>(std::max<float>(
        std::numeric_limits<int8_t>::min(),
        std::min<float>(std::numeric_limits<int8_t>::max(), std::round(zero_point + (input_data[i] / scale)))));
    output_data[i] = q;
  }
}

void Dequantize(int8_t *input_data, int length, float scale, int zero_point, float *output_data) {
  for (int i = 0; i < length; ++i) {
    output_data[i] = scale * (input_data[i] - zero_point);
  }
}

int FcInt8TestInit(std::vector<lite::tensor::Tensor *> *inputs_, std::vector<lite::tensor::Tensor *> *outputs_,
                   MatMulParameter *matmal_param, float **correct, double *scale, int *zeropoint) {
  float input_max = 20;
  float input_min = -20;
  float weight_max = 1;
  float weight_min = -1;
  float output_max = 20;
  float output_min = -20;

  double input_scale =
      (input_max - input_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
  int input_zp = std::numeric_limits<int8_t>::max() - input_max / input_scale;
  double weight_scale =
      (weight_max - weight_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
  int weight_zp = std::numeric_limits<int8_t>::max() - weight_max / weight_scale;
  double output_scale =
      (output_max - output_min) / (std::numeric_limits<int8_t>::max() - std::numeric_limits<int8_t>::min());
  int output_zp = std::numeric_limits<int8_t>::max() - output_max / output_scale;
  *scale = output_scale;
  *zeropoint = output_zp;

  Tensor *in_t = new Tensor(kNumberTypeInt8, {2, 2, 2, 2}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
  in_t->MallocData();
  float in[] = {-3.2366564, -4.7733846, -7.8329225, 16.146885,  5.060793,  -6.1471,  -1.7680453, -6.5721383,
                17.87506,   -5.1192183, 10.742863,  1.4536934,  19.693445, 19.45783, 5.063163,   0.5234792};
  Quantize(in, in_t->ElementsNum(), input_scale, input_zp, reinterpret_cast<int8_t *>(in_t->Data()));
  auto in_quant_arg = new mindspore::lite::tensor::QuantArg();
  in_quant_arg->zeroPoint = input_zp;
  in_quant_arg->scale = input_scale;
  in_t->AddQuantParam(*in_quant_arg);
  inputs_->push_back(in_t);

  Tensor *weight_t = new Tensor(kNumberTypeInt8, {3, 8}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
  weight_t->MallocData();
  float weight[] = {-0.24438887,  0.06738146,  -0.8169129,  0.21510671,  -0.012470592, -0.053063435,
                    0.6050155,    0.8656233,   0.12911413,  -0.028635843, -0.034080597, -0.10622552,
                    -0.012254699, -0.01312836, 0.25241964,  -0.4706142,  0.2451482,    -0.9558459,
                    0.4481974,    0.33251503,  -0.011705584, -0.1720293, -0.39410214,  -0.73637343};
  Quantize(weight, weight_t->ElementsNum(), weight_scale, weight_zp, reinterpret_cast<int8_t *>(weight_t->Data()));
  auto weight_quant_arg = new mindspore::lite::tensor::QuantArg();
  weight_quant_arg->zeroPoint = weight_zp;
  weight_quant_arg->scale = weight_scale;
  weight_t->AddQuantParam(*weight_quant_arg);
  inputs_->push_back(weight_t);

  Tensor *bias_t = new Tensor(kNumberTypeInt32, {3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
  bias_t->MallocData();
  memset(bias_t->Data(), 0, sizeof(int) * bias_t->ElementsNum());
  inputs_->push_back(bias_t);

  Tensor *out_t = new Tensor(kNumberTypeInt8, {2, 3}, schema::Format_NHWC, static_cast<schema::NodeType>(1));
  out_t->MallocData();
  auto output_quant_arg = new mindspore::lite::tensor::QuantArg();
  output_quant_arg->zeroPoint = output_zp;
  output_quant_arg->scale = output_scale;
  out_t->AddQuantParam(*output_quant_arg);
  outputs_->push_back(out_t);

  *correct = reinterpret_cast<float *>(malloc(out_t->ElementsNum() * sizeof(float)));
  float nchw_co[] = {3.84586822, 0.93586633, 12.16212629, -10.93835061, 2.46887183, 8.61480108};
  memcpy(*correct, nchw_co, out_t->ElementsNum() * sizeof(float));

  matmal_param->b_transpose_ = true;
  matmal_param->a_transpose_ = false;
  matmal_param->has_bias_ = true;
  matmal_param->minf_ = -FLT_MAX;
  matmal_param->maxf_ = FLT_MAX;
  return out_t->ElementsNum();
}

TEST_F(TestFcInt8, fcint8) {
  std::vector<lite::tensor::Tensor *> inputs_;
  std::vector<lite::tensor::Tensor *> outputs_;
  auto matmul_param = new MatMulParameter();
  float *correct;
  double output_scale;
  int output_zp;
  int total_size = FcInt8TestInit(&inputs_, &outputs_, matmul_param, &correct, &output_scale, &output_zp);
  lite::Context *ctx = new lite::Context;
  ctx->threadNum = 2;
  kernel::FullconnectionInt8CPUKernel *fc = new kernel::FullconnectionInt8CPUKernel(
      reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
  fc->Init();
  fc->Run();
  float fout[6] = {0};
  Dequantize(reinterpret_cast<int8_t *>(outputs_[0]->Data()), outputs_[0]->ElementsNum(), output_scale, output_zp,
             fout);
  CompareOutputData(fout, correct, 6, 0.2);
  delete matmul_param;
  delete fc;
  for (auto t : inputs_) delete t;
  for (auto t : outputs_) delete t;
  free(correct);
}
}  // namespace mindspore
```