magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit b99d8590
Authored Aug 07, 2020 by zhanyuan

Fix matmul asm bugs

Parent: 7ec1cb4e
Showing 11 changed files with 187 additions and 90 deletions (+187 -90)
mindspore/lite/src/ops/matmul.cc                                          +1  -1
mindspore/lite/src/ops/power.cc                                           +10 -14
mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc                 +1  -0
mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc                      +7  -7
mindspore/lite/src/runtime/kernel/arm/fp32/power.cc                       +8  -5
mindspore/lite/src/runtime/kernel/arm/fp32/power.h                        +2  -0
mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s       +54 -48
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc                +5  -2
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h                 +2  -2
mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h       +3  -3
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc   +94 -8
mindspore/lite/src/ops/matmul.cc

@@ -35,7 +35,7 @@ int MatMul::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   std::vector<int> a_shape = input0->shape();
   std::vector<int> b_shape = input1->shape();
-  if (a_shape.size() < 3 || b_shape.size() < 3) {
+  if (a_shape.size() < 2 || b_shape.size() < 2) {
     MS_LOG(ERROR) << "inputs shape is invalid";
     return RET_INPUT_TENSOR_ERROR;
   }
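The relaxed rank check is the visible API change here: MatMul::InferShape now accepts plain 2-D operands and only rejects inputs of rank below 2. A minimal sketch of shapes that pass after this commit (matching the updated unit test at the bottom of this page):

    // Rank-2 inputs are now valid; no leading batch dimension is required.
    std::vector<int> a_shape = {2, 8};  // A: 2x8
    std::vector<int> b_shape = {8, 3};  // B: 8x3  ->  C: 2x3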
mindspore/lite/src/ops/power.cc

@@ -24,24 +24,20 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor:
   MS_ASSERT(this->primitive != nullptr);
   auto x_tensor = inputs[0];
   MS_ASSERT(x_tensor != nullptr);
-  auto exp_tensor = inputs[1];
-  MS_ASSERT(exp_tensor != nullptr);
+  tensor::Tensor *exp_tensor = nullptr;
+  if (inputs.size() == 2) {
+    exp_tensor = inputs[1];
+    MS_ASSERT(exp_tensor != nullptr);
+  }
   auto output_tensor = outputs[0];
   MS_ASSERT(output_tensor != nullptr);
-  if (inputs.size() < 2) {
-    MS_LOG(ERROR) << "input size" << inputs.size() << " is error!";
-    return RET_INPUT_TENSOR_ERROR;
-  }
-  if (exp_tensor->shape() != x_tensor->shape() && exp_tensor->shape().size() != 1) {
-    MS_LOG(ERROR) << "Power inputs shape is not equal!";
-    return RET_INPUT_TENSOR_ERROR;
-  }
-  int exp_size = std::accumulate(exp_tensor->shape().begin(), exp_tensor->shape().end(), 1,
-                                 std::multiplies<int>());
-  if (x_tensor->data_type() != exp_tensor->data_type() && exp_size != 1) {
-    MS_LOG(ERROR) << "Exponent tensor's shape is wrong";
-    return RET_INPUT_TENSOR_ERROR;
-  }
+  if (exp_tensor) {
+    if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
+      MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
+      return RET_INPUT_TENSOR_ERROR;
+    }
+  }
   output_tensor->SetFormat(x_tensor->GetFormat());
   output_tensor->set_shape(x_tensor->shape());
   output_tensor->set_data_type(x_tensor->data_type());
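The exponent input to Power is now optional: with a single input the kernel falls back to the scalar power_ attribute (see the fp32 kernel change below), and a supplied exponent tensor must match the base tensor in both shape and data type. A minimal standalone sketch of the new validation rule, with hypothetical names:

    #include <vector>
    struct TensorDesc { std::vector<int> shape; int dtype; };
    // Returns true when the (optional) exponent tensor is acceptable.
    bool PowerInputsValid(const TensorDesc &x, const TensorDesc *exp) {
      if (exp == nullptr) return true;  // exponent omitted: scalar power_ is used instead
      return exp->shape == x.shape && exp->dtype == x.dtype;  // otherwise must match exactly
    }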
mindspore/lite/src/runtime/kernel/arm/base/matmul_base.cc

@@ -69,4 +69,5 @@ kernel::LiteKernel *CpuMatmulKernelCreator(const std::vector<lite::tensor::Tenso
 }
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_MatMul, CpuMatmulKernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_MatMul, CpuMatmulKernelCreator)
 }  // namespace mindspore::kernel
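The only change here is the added registration: int8 MatMul primitives on CPU now resolve through the same CpuMatmulKernelCreator that already served float32.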
mindspore/lite/src/runtime/kernel/arm/fp32/matmul.cc

@@ -34,15 +34,15 @@ int MatmulCPUKernel::ReSize() { return RET_OK; }
 int MatmulCPUKernel::Init() {
   int batch = 1;
-  auto x_shape = inputs_[0]->shape();
-  auto o_shape = outputs_[0]->shape();
-  for (int i = 0; i < x_shape.size() - 2; ++i) {
-    batch *= x_shape[i];
+  auto a_shape = inputs_[0]->shape();
+  auto c_shape = outputs_[0]->shape();
+  for (int i = 0; i < a_shape.size() - 2; ++i) {
+    batch *= a_shape[i];
   }
   params_->batch = batch;
-  params_->row_ = o_shape[o_shape.size() - 2];
-  params_->col_ = o_shape[o_shape.size() - 1];
-  params_->deep_ = params_->a_transpose_ ? x_shape[x_shape.size() - 2] : x_shape[x_shape.size() - 1];
+  params_->row_ = c_shape[c_shape.size() - 2];
+  params_->col_ = c_shape[c_shape.size() - 1];
+  params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
   params_->row_8_ = UP_ROUND(params_->row_, 8);
   params_->col_8_ = UP_ROUND(params_->col_, 8);
   thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
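This hunk is a pure rename (x_shape/o_shape become a_shape/c_shape) to match the A x B = C convention used elsewhere. For reference, the parameters it derives, worked for the 25x12 by 12x36 case added in the tests below (a_transpose_ = false):

    // a_shape = {25, 12}, c_shape = {25, 36}
    // row_  = c_shape[size - 2] = 25
    // col_  = c_shape[size - 1] = 36
    // deep_ = a_shape[size - 1] = 12
    // row_8_ = UP_ROUND(25, 8) = 32;  col_8_ = UP_ROUND(36, 8) = 40
    // thread_count_ = MSMIN(thread_count_, UP_DIV(40, 8)) = min(threads, 5)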
mindspore/lite/src/runtime/kernel/arm/fp32/power.cc

@@ -51,15 +51,19 @@ int PowerCPUKernel::Run() {
 int PowerCPUKernel::RunImpl(int task_id) {
   auto x_addr = reinterpret_cast<float *>(inputs_[0]->Data());
-  auto exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
   auto output_addr = reinterpret_cast<float *>(outputs_[0]->Data());
   auto size = inputs_[0]->ElementsNum();
   int stride = UP_DIV(size, thread_count_);
   int len = MSMIN(stride, size - stride * task_id);
-  bool broadcast = (inputs_[1]->ElementsNum() == 1) ? true : false;
+  float *exp_addr = nullptr;
+  bool broadcast = true;
+  if (inputs_.size() == 2) {
+    exp_addr = reinterpret_cast<float *>(inputs_[1]->Data());
+    broadcast = false;
+  }
   float *cur_exp;
   if (broadcast) {
-    cur_exp = exp_addr;
+    cur_exp = &power_;
   } else {
     cur_exp = exp_addr + stride * task_id;
   }

@@ -73,8 +77,7 @@ kernel::LiteKernel *CpuPowerFp32KernelCreator(const std::vector<lite::tensor::Te
                                               const kernel::KernelKey &desc) {
   MS_ASSERT(opParameter != nullptr);
   MS_ASSERT(desc.type == schema::PrimitiveType_Power);
-  auto *kernel =
-      new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
+  auto *kernel = new (std::nothrow) PowerCPUKernel(opParameter, inputs, outputs, ctx);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "new PowerCPUKernel fail!";
     return nullptr;
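RunImpl still splits the flat element range evenly across threads; only the source of the exponent changed. A quick standalone check of the slicing arithmetic (UP_DIV is a ceiling divide, MSMIN a minimum):

    #include <algorithm>
    // size = 10 elements, thread_count_ = 4  =>  stride = ceil(10 / 4) = 3
    int stride = (10 + 4 - 1) / 4;                 // UP_DIV(10, 4) == 3
    int len0 = std::min(stride, 10 - stride * 0);  // task 0 processes 3 elements
    int len3 = std::min(stride, 10 - stride * 3);  // the last task gets the 1-element remainder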
mindspore/lite/src/runtime/kernel/arm/fp32/power.h

@@ -30,6 +30,7 @@ class PowerCPUKernel : public LiteKernel {
       : LiteKernel(param, inputs, outputs),
+        ctx_(ctx),
         thread_count_(ctx->thread_num_),
         power_(reinterpret_cast<PowerParameter *>(opParameter)->power_),
         scale_(reinterpret_cast<PowerParameter *>(opParameter)->scale_),
         shift_(reinterpret_cast<PowerParameter *>(opParameter)->shift_) {}
   ~PowerCPUKernel() override = default;

@@ -42,6 +43,7 @@ class PowerCPUKernel : public LiteKernel {
  private:
+  const lite::Context *ctx_;
   int thread_count_;
   float power_;
   float scale_;
   float shift_;
 };
mindspore/lite/src/runtime/kernel/arm/nnacl/assembly/arm64/matmul.s

 #ifdef __aarch64__
     .text
     .align 5
-    .global MatMulFloatNeon64
+    .global MatmulFloatNeon64
 #ifndef __APPLE__
-    .type MatMulFloatNeon64, %function
+    .type MatmulFloatNeon64, %function
 #endif

 // A: LM  [row_8 * depth] col_8_major

@@ -46,41 +46,39 @@
 // accumulators 8x8 block
 /////////////////////////////////////////////////////////////////////////////////
 //
-// void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf, int depth, int row, int col)
+// void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row, int col)
 // x0: a
 // x1: b
 // x2: c
 // x3: bias
-// v0.s[0]: maxf
-// v1.s[0]: minf
-// w4: depth
-// w5: row
-// w6: col
+// w4: act_type
+// w5: depth
+// w6: row
+// w7: col

-MatMulFloatNeon64:
+MatmulFloatNeon64:
     sub sp, sp, #128
     st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
     st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64

-    mov w7, v0.s[0]
-    mov w8, v1.s[0]
-    mov w9, 0 // rm col offset
-    mov w10, 0 // lm row offset
+    mov w9, #0 // rm col offset
+    mov w10, #0 // lm row offset
    mov w18, #32 // sizeof(float)*8
-    mul w15, w4, w18 // the stride of lm/rm: sizeof(float)*8*depth
+    mul w15, w5, w18 // the stride of lm/rm: sizeof(float)*8*depth
+    mov x11, x3 // bias flag

 L1:
-    cmp w9, w6
+    cmp w9, w7
     beq End1

-    mov w10, 0 // reset lm row offset
+    mov w10, #0 // reset lm row offset
     mov x12, x0 // reload lm ptr
-    mov x14, x3 // reload bias ptr

 L2:
     cmp w10, w6
     beq End2

-    mov w13, w4 // reload depth
+    mov x16, x1 // reload rm ptr
+    mov w13, w5 // reload depth
+    mov x14, x3 // reload bias ptr

     dup v16.4s, wzr
     dup v17.4s, wzr
     dup v18.4s, wzr

@@ -103,7 +101,7 @@ OptLoopMul4:
     blt CommLoopMul

     ld1 {v0.4s, v1.4s}, [x12], #32
-    ld1 {v8.4s, v9.4s}, [x1], #32
+    ld1 {v8.4s, v9.4s}, [x16], #32
     fmla v16.4s, v8.4s, v0.s[0]
     fmla v17.4s, v9.4s, v0.s[0]
     fmla v18.4s, v8.4s, v0.s[1]

@@ -112,7 +110,7 @@ OptLoopMul4:
     fmla v21.4s, v9.4s, v0.s[2]
     fmla v22.4s, v8.4s, v0.s[3]
     fmla v23.4s, v9.4s, v0.s[3]
-    ld1 {v10.4s, v11.4s}, [x1], #32
+    ld1 {v10.4s, v11.4s}, [x16], #32
     fmla v24.4s, v8.4s, v1.s[0]
     fmla v25.4s, v9.4s, v1.s[0]
     fmla v26.4s, v8.4s, v1.s[1]

@@ -130,7 +128,7 @@ OptLoopMul4:
     fmla v21.4s, v11.4s, v2.s[2]
     fmla v22.4s, v10.4s, v2.s[3]
     fmla v23.4s, v11.4s, v2.s[3]
-    ld1 {v12.4s, v13.4s}, [x1], #32
+    ld1 {v12.4s, v13.4s}, [x16], #32
     fmla v24.4s, v10.4s, v3.s[0]
     fmla v25.4s, v11.4s, v3.s[0]
     fmla v26.4s, v10.4s, v3.s[1]

@@ -153,7 +151,7 @@ OptLoopMul4:
     fmla v25.4s, v13.4s, v5.s[0]
     fmla v26.4s, v12.4s, v5.s[1]
     fmla v27.4s, v13.4s, v5.s[1]
-    ld1 {v14.4s, v15.4s}, [x1], #32
+    ld1 {v14.4s, v15.4s}, [x16], #32
     fmla v28.4s, v12.4s, v5.s[2]
     fmla v29.4s, v13.4s, v5.s[2]
     fmla v30.4s, v12.4s, v5.s[3]

@@ -182,7 +180,7 @@ CommLoopMul:
     blt Bias

     ld1 {v0.4s, v1.4s}, [x12], #32
-    ld1 {v2.4s, v3.4s}, [x1], #32
+    ld1 {v2.4s, v3.4s}, [x16], #32
     fmla v16.4s, v2.4s, v0.s[0]
     fmla v17.4s, v3.4s, v0.s[0]
     fmla v18.4s, v2.4s, v0.s[1]

@@ -203,8 +201,7 @@ CommLoopMul:
     b CommLoopMul

 Bias:
-    cmp x3, #0
-    beq Relu
+    cbz x11, Activation
     ld1 {v0.4s}, [x14], #16
     ld1 {v1.4s}, [x14], #16
     fadd v16.4s, v16.4s, v0.4s

@@ -224,9 +221,34 @@ Bias:
     fadd v30.4s, v30.4s, v0.4s
     fadd v31.4s, v31.4s, v1.4s

+Activation:
+    cmp w4, #2
+    beq Relu6
+    cmp w4, #1
+    beq Relu
+    b TransToOut
+
+Relu6:
+    mov w8, #6
+    dup v15.4s, w8
+    scvtf v15.4s, v15.4s
+    fmin v16.4s, v16.4s, v15.4s
+    fmin v17.4s, v17.4s, v15.4s
+    fmin v18.4s, v18.4s, v15.4s
+    fmin v19.4s, v19.4s, v15.4s
+    fmin v20.4s, v20.4s, v15.4s
+    fmin v21.4s, v21.4s, v15.4s
+    fmin v22.4s, v22.4s, v15.4s
+    fmin v23.4s, v23.4s, v15.4s
+    fmin v24.4s, v24.4s, v15.4s
+    fmin v25.4s, v25.4s, v15.4s
+    fmin v26.4s, v26.4s, v15.4s
+    fmin v27.4s, v27.4s, v15.4s
+    fmin v28.4s, v28.4s, v15.4s
+    fmin v29.4s, v29.4s, v15.4s
+    fmin v30.4s, v30.4s, v15.4s
+    fmin v31.4s, v31.4s, v15.4s
+
 Relu:
-    dup v15.4s, w7
-    dup v14.4s, w8
+    dup v14.4s, wzr
     fmax v16.4s, v16.4s, v14.4s
     fmax v17.4s, v17.4s, v14.4s
     fmax v18.4s, v18.4s, v14.4s

@@ -244,24 +266,6 @@ Relu:
     fmax v30.4s, v30.4s, v14.4s
     fmax v31.4s, v31.4s, v14.4s

-    fmin v16.4s, v16.4s, v15.4s
-    fmin v17.4s, v17.4s, v15.4s
-    fmin v18.4s, v18.4s, v15.4s
-    fmin v19.4s, v19.4s, v15.4s
-    fmin v20.4s, v20.4s, v15.4s
-    fmin v20.4s, v20.4s, v15.4s
-    fmin v21.4s, v21.4s, v15.4s
-    fmin v22.4s, v22.4s, v15.4s
-    fmin v23.4s, v23.4s, v15.4s
-    fmin v24.4s, v24.4s, v15.4s
-    fmin v25.4s, v25.4s, v15.4s
-    fmin v26.4s, v26.4s, v15.4s
-    fmin v27.4s, v27.4s, v15.4s
-    fmin v28.4s, v28.4s, v15.4s
-    fmin v29.4s, v29.4s, v15.4s
-    fmin v30.4s, v30.4s, v15.4s
-    fmin v31.4s, v31.4s, v15.4s
-
 TransToOut:
     st1 {v16.4s}, [x2], #16
     st1 {v17.4s}, [x2], #16

@@ -280,11 +284,13 @@ TransToOut:
     st1 {v30.4s}, [x2], #16
     st1 {v31.4s}, [x2], #16

-    add w10, w10, #8 // lhs row offset + 8
+    add w10, w10, #8 // lm row offset + 8
     b L2

 End2:
-    add w9, w9, #8 // rhs col offset + 8
+    add w9, w9, #8 // rm col offset + 8
+    add x1, x1, x15 // rm ptr + stride
+    add x3, x3, x18 // bias ptr + stride
     b L1

 End1:
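Taken together, the assembly changes fix two per-block pointer bugs and generalize the activation: the right-matrix pointer is now re-based into x16 for every 8-row block instead of running on destructively through x1, the bias pointer is re-based into x14 once per block rather than once per column strip, and the hard-wired maxf/minf clamp gives way to an act_type switch selected at runtime. A scalar C++ model of the fixed control flow, simplified to one output element at a time (the real kernel works on 8x8 tiles of col-8-major packed data, so the indexing here is only illustrative):

    #include <algorithm>
    // Scalar sketch of MatmulFloatNeon64's loop structure; names hypothetical.
    void MatmulRefSketch(const float *a, const float *b, float *c, const float *bias,
                         int act_type, int depth, int row8, int col8) {
      for (int j = 0; j < col8; ++j) {      // L1: right-matrix column blocks
        for (int i = 0; i < row8; ++i) {    // L2: row blocks; b and bias re-read each pass
          float acc = 0.0f;
          for (int d = 0; d < depth; ++d) acc += a[i * depth + d] * b[j * depth + d];
          if (bias != nullptr) acc += bias[j];                           // Bias (cbz x11 skips it)
          if (act_type == 1) acc = std::max(acc, 0.0f);                  // Relu
          if (act_type == 2) acc = std::min(std::max(acc, 0.0f), 6.0f);  // Relu6
          c[i * col8 + j] = acc;                                         // TransToOut
        }
      }
    }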
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.cc

@@ -42,7 +42,7 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
       float *dst_c = dst_r + ci * C8NUM;

       /* 8x4 row-major to col-major */
-#ifdef ENABLE_NEON
+#ifdef ENABLE_ARM64
       size_t stride = col * 4;
       asm volatile(
         "mov x10, %[src_c]\n"

@@ -156,6 +156,9 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActT
 void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
             int col_8_) {
+#ifdef __aarch64__
+  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_);
+#else
   MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
-  return;
+#endif
 }
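MatMul now dispatches to the rewritten assembly kernel on aarch64 and keeps the portable MatMul8x8 path elsewhere; the inline packing asm in RowMajor2Col8Major is likewise guarded by ENABLE_ARM64 rather than the broader ENABLE_NEON, which would also match 32-bit Neon builds that cannot run this arm64 code. The (int)act_type cast lines up with the cmp w4, #1 / cmp w4, #2 checks in the assembly, suggesting a numbering like the sketch below (an assumption, not shown in this diff):

    // Assumed ActType numbering implied by the asm's act_type checks; not part of this diff.
    enum ActTypeSketch { kActNone = 0, kActRelu = 1, kActRelu6 = 2 };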
mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/matmul.h

@@ -32,8 +32,8 @@ void MatMul8x8(const float *a, const float *b, float *c, const float *bias, floa
 extern "C" {
 #endif
 #ifdef __aarch64__
-void MatMulFloatNeon64(const float *a, const float *b, float *c, const float *bias, float maxf, float minf,
-                       int depth, int row, int col);
+void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+                       int col);
 #endif
 #ifdef __cplusplus
 }
mindspore/lite/src/runtime/kernel/arm/nnacl/quantization/quantize.h

@@ -157,10 +157,10 @@ inline void CalculateActivationRangeQuantized(bool is_relu, bool is_relu6, int32
 // quantize from float to int8
 inline void Quantize(float *input_data, int length, float scale, int zero_point, int8_t *output_data) {
   for (int i = 0; i < length; ++i) {
-    int r = (int)round(input_data[i] / scale + zero_point);
-    int8_t q = r > CHAR_MAX ? (int8_t)CHAR_MAX : (int8_t)r;
+    int q = (int)round(input_data[i] / scale + zero_point);
+    q = q > CHAR_MAX ? CHAR_MAX : q;
     q = q < CHAR_MIN ? CHAR_MIN : q;
-    output_data[i] = q;
+    output_data[i] = (int8_t)q;
   }
 }
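The old code cast to int8_t before applying the lower clamp, so any value below CHAR_MIN had already wrapped around and the q < CHAR_MIN guard could never fire; clamping while still an int and casting once at the end fixes this. A worked case:

    #include <climits>
    #include <cmath>
    #include <cstdint>
    // scale = 1.0, zero_point = 0, input = -300.0:
    // old: (int8_t)(-300) wraps before the CHAR_MIN check, producing garbage
    // new: clamp happens in int, then a single well-defined narrowing cast
    int q = (int)round(-300.0f / 1.0f + 0);
    q = q > CHAR_MAX ? CHAR_MAX : q;
    q = q < CHAR_MIN ? CHAR_MIN : q;  // q == -128
    int8_t out = (int8_t)q;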
mindspore/lite/test/ut/src/runtime/kernel/arm/fp32/matmul_fp32_tests.cc

@@ -201,19 +201,108 @@ TEST_F(TestMatMulFp32, simple) {
                   0.006050155, 0.008656233, 0.012911413, -0.0028635843, -0.00034080597, -0.0010622552,
                   -0.012254699, -0.01312836, 0.0025241964, -0.004706142, 0.002451482, -0.009558459,
                   0.004481974, 0.0033251503, -0.011705584, -0.001720293, -0.0039410214, -0.0073637343};
-  std::vector<int> a_shape = {1, 2, 8};
-  std::vector<int> b_shape = {1, 8, 3};
-  std::vector<int> c_shape = {1, 2, 3};
+  std::vector<int> a_shape = {2, 8};
+  std::vector<int> b_shape = {8, 3};
+  std::vector<int> c_shape = {2, 3};
   int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
   auto ctx = new lite::Context;
-  ctx->thread_num_ = 2;
+  ctx->thread_num_ = 1;
   auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
   mm->Init();
   mm->Run();
   float correct[] = {-0.1256939023733139, -0.07744802534580231, 0.07410638779401779,
                      -0.3049793541431427, -0.027687929570674896, -0.18109679222106934};
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
   delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
 }
+
+TEST_F(TestMatMulFp32, simple2) {
+  std::vector<lite::tensor::Tensor *> inputs_;
+  std::vector<lite::tensor::Tensor *> outputs_;
+  auto matmul_param = new MatMulParameter();
+  matmul_param->a_transpose_ = false;
+  matmul_param->b_transpose_ = false;
+  matmul_param->has_bias_ = false;
+  float a[25 * 12] = {1, 4, 10, 2, 3, 10, 4, 6, 5, 6, 9, 5,
+                      4, 2, 5, 7, 5, 8, 0, 5, 1, 0, 10, 3,
+                      0, 4, 2, 3, 2, 9, 8, 9, 5, 4, 4, 9,
+                      7, 4, 2, 6, 10, 2, 1, 7, 2, 10, 5, 10,
+                      1, 2, 2, 9, 8, 8, 2, 5, 6, 3, 2, 8,
+                      3, 3, 7, 3, 0, 4, 10, 9, 0, 5, 2, 6,
+                      1, 10, 7, 6, 9, 6, 0, 3, 8, 0, 8, 3,
+                      10, 4, 7, 7, 0, 5, 6, 5, 4, 6, 5, 5,
+                      3, 7, 1, 9, 3, 2, 8, 3, 0, 0, 6, 7,
+                      6, 3, 6, 5, 1, 0, 4, 2, 6, 0, 7, 7,
+                      7, 4, 9, 8, 6, 1, 10, 10, 7, 3, 0, 6,
+                      9, 4, 1, 4, 4, 3, 1, 6, 7, 3, 8, 6,
+                      4, 10, 9, 8, 10, 5, 2, 3, 8, 10, 0, 8,
+                      2, 9, 5, 3, 3, 0, 1, 8, 1, 1, 2, 0,
+                      1, 5, 5, 0, 1, 10, 9, 9, 3, 6, 7, 1,
+                      2, 3, 7, 5, 0, 8, 2, 8, 7, 8, 9, 10,
+                      4, 2, 5, 3, 10, 1, 5, 0, 6, 2, 3, 5,
+                      5, 1, 5, 5, 5, 1, 8, 2, 6, 9, 10, 4,
+                      9, 1, 10, 9, 8, 2, 5, 2, 4, 2, 3, 7,
+                      7, 2, 9, 10, 10, 10, 5, 1, 8, 8, 10, 3,
+                      2, 10, 2, 6, 5, 9, 10, 6, 10, 0, 5, 5,
+                      4, 0, 9, 4, 4, 9, 4, 6, 4, 2, 5, 2,
+                      10, 5, 9, 8, 1, 4, 7, 9, 6, 5, 0, 3,
+                      6, 4, 3, 10, 6, 4, 10, 5, 8, 8, 9, 4,
+                      5, 6, 8, 9, 2, 2, 4, 4, 8, 0, 4, 5};
+  float b[12 * 36] = {6, 6, 7, 2, 1, 10, 3, 7, 7, 5, 5, 5, 6, 6, 9, 8, 4, 10,
+                      9, 5, 5, 7, 2, 1, 7, 9, 10, 0, 3, 10, 4, 2, 7, 4, 3, 10,
+                      5, 3, 1, 3, 3, 1, 9, 6, 7, 6, 6, 6, 7, 6, 10, 8, 2, 8,
+                      5, 2, 1, 7, 5, 9, 10, 9, 0, 8, 10, 2, 3, 4, 0, 7, 5, 5,
+                      0, 9, 6, 1, 6, 7, 4, 1, 0, 3, 0, 7, 3, 0, 10, 7, 6, 4,
+                      10, 7, 6, 5, 10, 2, 10, 9, 10, 6, 9, 10, 8, 8, 5, 3, 9, 10,
+                      8, 3, 3, 4, 6, 2, 6, 0, 4, 0, 3, 4, 1, 0, 3, 10, 5, 4,
+                      0, 2, 3, 2, 4, 3, 10, 5, 4, 10, 8, 2, 0, 4, 0, 5, 8, 0,
+                      1, 10, 0, 3, 1, 1, 9, 4, 0, 3, 0, 1, 6, 3, 10, 0, 10, 3,
+                      3, 0, 6, 7, 3, 2, 3, 5, 10, 6, 1, 5, 7, 3, 3, 1, 1, 10,
+                      5, 4, 0, 8, 4, 0, 9, 6, 2, 3, 6, 10, 10, 0, 2, 2, 1, 2,
+                      7, 10, 9, 7, 10, 2, 8, 5, 3, 7, 0, 4, 3, 4, 8, 3, 8, 0,
+                      5, 5, 6, 9, 10, 0, 1, 5, 6, 6, 4, 7, 7, 6, 7, 9, 5, 5,
+                      6, 0, 4, 1, 2, 6, 8, 4, 10, 4, 10, 9, 8, 8, 1, 7, 1, 8,
+                      1, 0, 10, 8, 8, 1, 8, 0, 10, 3, 1, 7, 0, 10, 5, 0, 2, 8,
+                      4, 1, 8, 1, 6, 7, 1, 8, 3, 4, 3, 4, 7, 0, 9, 1, 1, 4,
+                      8, 10, 0, 3, 3, 2, 7, 9, 3, 3, 10, 10, 9, 4, 4, 0, 7, 1,
+                      1, 3, 5, 1, 4, 8, 5, 7, 3, 9, 10, 1, 5, 9, 7, 4, 10, 10,
+                      3, 4, 3, 5, 1, 10, 5, 2, 3, 3, 0, 3, 1, 2, 8, 7, 4, 2,
+                      0, 8, 7, 6, 6, 6, 5, 7, 5, 5, 3, 0, 4, 10, 1, 7, 8, 9,
+                      6, 7, 0, 1, 9, 3, 1, 6, 8, 4, 9, 0, 3, 2, 4, 0, 2, 7,
+                      2, 2, 8, 0, 4, 1, 3, 2, 6, 8, 5, 5, 2, 3, 9, 0, 1, 7,
+                      6, 9, 1, 10, 4, 10, 5, 10, 0, 9, 5, 1, 6, 2, 9, 9, 8, 8,
+                      10, 8, 1, 6, 5, 8, 8, 6, 4, 8, 10, 3, 0, 6, 2, 8, 4, 2};
+  std::vector<int> a_shape = {25, 12};
+  std::vector<int> b_shape = {12, 36};
+  std::vector<int> c_shape = {25, 36};
+  int total_size = MMTestInit(&inputs_, &outputs_, a, b, a_shape, b_shape, c_shape);
+  auto ctx = new lite::Context;
+  ctx->thread_num_ = 2;
+  auto mm = new kernel::MatmulCPUKernel(reinterpret_cast<OpParameter *>(matmul_param), inputs_, outputs_, ctx);
+  mm->Init();
+  mm->Run();
+  float correct[] = {263, 386, 184, 309, 338, 244, 359, 294, 252, 254, 273, 353, 320, 183, 412, 273, 271, 307,
+                     329, 314, 391, 261, 400, 280, 416, 399, 355, 427, 373, 302, 288, 349, 336, 241, 349, 393,
+                     226, 285, 134, 209, 264, 163, 281, 212, 219, 171, 221, 228, 227, 131, 289, 196, 204, 270,
+                     238, 205, 303, 196, 280, 156, 311, 284, 282, 335, 243, 245, 181, 188, 280, 142, 229, 256,
+                     270, 310, 184, 377, 323, 187, 345, 295, 255, 262, 259, 332, 310, 222, 357, 275, 253, 301,
+                     296, 254, 316, 221, 323, 322, 370, 353, 281, 386, 363, 240, 245, 301, 270, 263, 275, 292,
+                     278, 388, 199, 324, 252, 336, 385, 300, 257, 274, 215, 243, 272, 230, 485, 335, 343, 366,
+                     293, 272, 337, 313, 310, 305, 385, 421, 377, 398, 343, 262, 249, 309, 258, 280, 286, 411,
+                     268, 337, 127, 307, 244, 185, 368, 263, 178, 205, 223, 281, 288, 154, 339, 255, 295, 250,
+                     241, 236, 289, 240, 296, 261, 361, 333, 282, 399, 315, 202, 203, 272, 231, 229, 300, 273,
+                     199, 253, 246, 315, 307, 213, 257, 202, 243, 230, 163, 288, 220, 212, 361, 314, 219, 296,
+                     300, 217, 274, 196, 285, 264, 351, 339, 312, 289, 338, 282, 256, 274, 214, 243, 228, 302,
+                     276, 394, 110, 224, 274, 163, 395, 296, 231, 223, 289, 311, 331, 177, 405, 236, 294, 293,
+                     264, 213, 314, 258, 330, 270, 403, 381, 305, 450, 382, 250, 248, 287, 278, 211, 324, 374,
+                     306, 350, 246, 298, 309, 305, 315, 289, 292, 256, 264, 341, 295, 218, 427, 382, 272, 359,
+                     335, 286, 333, 263, 327, 275, 448, 423, 380, 369, 397, 330, 260, 329, 285, 284, 333, 397,
+                     259, 258, 146, 261, 281, 156, 248, 234, 236, 219, 220, 207, 233, 173, 326, 316, 223, 301,
+                     237, 145, 202, 181, 209, 236, 357, 279, 265, 332, 352, 230, 165, 219, 154, 233, 189, 237,
+                     246, 316, 147, 197, 247, 221, 212, 256, 201, 208, 239, 220, 231, 153, 322, 263, 237, 278,
+                     254, 178, 215, 164, 217, 211, 326, 295, 284, 306, 354, 247, 178, 244, 216, 199, 229, 308,
+                     298, 409, 306, 359, 359, 273, 388, 291, 301, 281, 239, 395, 323, 290, 505, 398, 370, 381,
+                     365, 235, 344, 268, 340, 351, 473, 481, 445, 415, 481, 373, 354, 365, 284, 309, 338, 469,
+                     285, 336, 166, 244, 245, 247, 305, 304, 273, 233, 281, 260, 276, 218, 364, 241, 255, 330,
+                     257, 213, 296, 221, 252, 251, 325, 355, 301, 341, 319, 246, 206, 243, 295, 210, 249, 357,
+                     328, 481, 196, 345, 276, 338, 493, 349, 236, 299, 265, 388, 383, 224, 573, 425, 411, 354,
+                     353, 340, 363, 385, 414, 387, 541, 528, 412, 515, 486, 298, 320, 438, 254, 361, 454, 494,
+                     120, 156, 151, 140, 176, 99, 231, 113, 197, 132, 113, 190, 134, 171, 264, 169, 137, 219,
+                     165, 92, 172, 145, 188, 186, 225, 260, 166, 216, 225, 161, 173, 134, 147, 130, 152, 218,
+                     226, 273, 205, 314, 331, 157, 311, 242, 289, 228, 238, 346, 285, 223, 344, 235, 194, 282,
+                     274, 238, 358, 207, 333, 270, 345, 345, 302, 339, 309, 273, 284, 291, 297, 219, 261, 338,
+                     319, 396, 200, 356, 349, 311, 377, 330, 280, 280, 308, 351, 311, 204, 421, 319, 294, 348,
+                     328, 346, 387, 261, 403, 335, 434, 428, 333, 467, 422, 270, 254, 370, 345, 285, 381, 378,
+                     200, 347, 110, 195, 189, 184, 252, 242, 134, 191, 179, 205, 256, 140, 349, 219, 287, 216,
+                     225, 155, 223, 203, 203, 196, 295, 281, 321, 291, 292, 235, 219, 255, 177, 186, 213, 349,
+                     286, 389, 180, 262, 306, 275, 269, 284, 257, 239, 256, 262, 270, 189, 410, 306, 302, 297,
+                     244, 226, 335, 213, 276, 257, 371, 351, 398, 376, 378, 289, 265, 355, 258, 252, 286, 446,
+                     274, 419, 214, 263, 277, 296, 317, 276, 202, 240, 214, 287, 292, 174, 454, 366, 352, 328,
+                     342, 247, 300, 273, 300, 232, 440, 401, 436, 374, 394, 351, 269, 317, 247, 255, 312, 416,
+                     384, 533, 202, 336, 369, 322, 449, 373, 291, 282, 343, 409, 416, 198, 526, 383, 405, 363,
+                     355, 355, 478, 348, 435, 296, 544, 490, 519, 540, 449, 390, 345, 444, 378, 307, 454, 542,
+                     356, 394, 179, 370, 364, 152, 424, 370, 316, 291, 358, 420, 419, 267, 429, 323, 311, 348,
+                     320, 232, 344, 260, 344, 369, 472, 424, 339, 479, 470, 297, 298, 350, 300, 302, 340, 389,
+                     211, 314, 186, 248, 277, 184, 294, 217, 204, 184, 203, 311, 262, 154, 324, 221, 233, 249,
+                     283, 241, 331, 210, 318, 191, 341, 330, 331, 323, 278, 289, 255, 259, 294, 174, 280, 323,
+                     295, 348, 303, 319, 321, 286, 365, 266, 310, 251, 240, 406, 302, 265, 457, 396, 297, 366,
+                     350, 270, 343, 271, 347, 314, 469, 476, 396, 375, 428, 351, 315, 341, 291, 296, 361, 428,
+                     383, 442, 232, 360, 387, 279, 391, 349, 348, 288, 334, 374, 360, 262, 485, 391, 362, 379,
+                     296, 262, 406, 270, 346, 346, 486, 451, 451, 490, 475, 339, 319, 409, 315, 324, 367, 493,
+                     286, 348, 185, 240, 287, 214, 312, 265, 237, 218, 261, 316, 279, 186, 377, 319, 279, 304,
+                     281, 207, 261, 209, 287, 270, 415, 378, 312, 388, 423, 273, 230, 294, 239, 243, 319, 346};
+  CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
+  delete mm;
+  for (auto t : inputs_) delete t;
+  for (auto t : outputs_) delete t;
+}

@@ -243,7 +332,6 @@ TEST_F(TestMatMulFp32, simple_transb) {
   mm->Run();
   float correct[] = {0.00533547, 0.002545945, 0.062974121, -0.445441471, -0.246223617, -0.142070031};
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
-  delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;

@@ -298,9 +386,7 @@ TEST_F(TestMatMulFp32, batch) {
                      8.869029998779297, 25.034008026123047};
-  float *output = reinterpret_cast<float *>(outputs_[0]->Data());
-  for (int i = 0; i < 18; ++i) printf("%f ", output[i]);
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
   delete matmul_param;
   delete mm;
   for (auto t : inputs_) delete t;
   for (auto t : outputs_) delete t;
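The test updates mirror the kernel fixes: the simple case now feeds genuinely 2-D tensors and runs single-threaded, the new simple2 case multiplies 25x12 by 12x36 so the padded 32x40 output spans many 8x8 tiles across two threads (exactly the path the assembly pointer bugs corrupted), and a leftover debug printf is dropped from the batch test.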