Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
项目经理老王
Mace
提交
d87285bf
Mace
项目概览
项目经理老王
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d87285bf
编写于
11月 22, 2018
作者:
李
李寅
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize low precision gemv
上级
0b878466
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
217 addition
and
36 deletion
+217
-36
mace/ops/BUILD
mace/ops/BUILD
+9
-9
mace/ops/arm/fixpoint_gemm.h
mace/ops/arm/fixpoint_gemm.h
+141
-0
mace/ops/matmul.cc
mace/ops/matmul.cc
+50
-23
mace/ops/matmul_test.cc
mace/ops/matmul_test.cc
+17
-4
未找到文件。
mace/ops/BUILD
浏览文件 @
d87285bf
...
...
@@ -64,7 +64,8 @@ cc_library(
"ops_test_util.h"
,
"fixpoint.h"
,
"gemmlowp_util.h"
,
]
"arm/fixpoint_*.h"
,
],
)
+
if_opencl_enabled
(
glob
([
"opencl/*.h"
,
"opencl/image/*.h"
,
...
...
@@ -72,6 +73,7 @@ cc_library(
]))
+
if_quantize_enabled
(
glob
([
"fixpoint.h"
,
"gemmlowp_util.h"
,
"arm/fixpoint_*.h"
,
])),
copts
=
[
"-Werror"
,
...
...
@@ -101,11 +103,10 @@ cc_library(
]),
)
cc_library
(
name
=
"ops"
,
srcs
=
[
"ops_registry.cc"
"ops_registry.cc"
,
],
hdrs
=
[
"ops_registry.h"
,
...
...
@@ -138,12 +139,12 @@ cc_library(
cc_library
(
name
=
"test"
,
testonly
=
1
,
hdrs
=
glob
([
"*_test_util.h"
,
]),
srcs
=
[
"ops_test_util.cc"
,
],
hdrs
=
glob
([
"*_test_util.h"
,
]),
copts
=
[
"-Werror"
,
"-Wextra"
,
...
...
@@ -174,13 +175,12 @@ cc_test(
"opencl/*_test.cc"
,
],
exclude
=
[
"fixpoint_test.cc"
"fixpoint_test.cc"
,
],
)
+
if_quantize_enabled
(
glob
(
[
"fixpoint_test.cc"
"fixpoint_test.cc"
,
],
)),
copts
=
[
"-Werror"
,
...
...
mace/ops/arm/fixpoint_gemm.h
0 → 100644
浏览文件 @
d87285bf
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARM_FIXPOINT_GEMM_H_
#define MACE_OPS_ARM_FIXPOINT_GEMM_H_
#if defined(MACE_ENABLE_NEON)
#include <arm_neon.h>
#endif
#if defined(MACE_ENABLE_NEON) && !defined(__aarch64__)
#define vaddvq_u32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
#endif
namespace
mace
{
namespace
ops
{
/// Low-precision matrix * vector product (primary template).
///
/// Only a declaration is provided here: this header contains no generic
/// definition, so the function is usable only for type pairs that have an
/// explicit specialization (this header provides <uint8_t, int32_t>).
///
/// The parameter descriptions below reflect how that specialization uses them.
///
/// @param lhs             Matrix data, read as `lhs_height` consecutive rows of
///                        `lhs_width` elements each.
/// @param rhs             Vector data, `lhs_width` elements.
/// @param lhs_zero_point  Zero point subtracted from every `lhs` element.
/// @param rhs_zero_point  Zero point subtracted from every `rhs` element.
/// @param lhs_height      Number of rows in `lhs` (and entries in `result`).
/// @param lhs_width       Number of columns in `lhs` (and entries in `rhs`).
/// @param result          Output buffer receiving `lhs_height` values.
template <typename INPUT_TYPE, typename OUTPUT_TYPE>
void FixPointGemv(const INPUT_TYPE *lhs,
                  const INPUT_TYPE *rhs,
                  const int lhs_zero_point,
                  const int rhs_zero_point,
                  const index_t lhs_height,
                  const index_t lhs_width,
                  OUTPUT_TYPE *result);
/// uint8 x uint8 -> int32 matrix * vector product with zero-point correction.
///
/// For every row h in [0, lhs_height) this writes
///
///   result[h] = sum_w (lhs[h * lhs_width + w] - lhs_zero_point)
///                   * (rhs[w] - rhs_zero_point)
///
/// computed through the expanded form
///
///   dot(lhs_row, rhs) - sum(lhs_row) * rhs_zero_point
///                     - sum(rhs) * lhs_zero_point
///                     + lhs_zero_point * rhs_zero_point * lhs_width
///
/// so that the inner loop only needs unsigned multiply-accumulates.
///
/// The specialization is declared `inline` because it is defined in a header:
/// an explicit specialization of a function template is an ordinary function
/// for ODR purposes, and without `inline` every translation unit including
/// this header would emit its own strong definition.
///
/// @param lhs             Matrix data, `lhs_height` consecutive rows of
///                        `lhs_width` bytes each.
/// @param rhs             Vector data, `lhs_width` bytes.
/// @param lhs_zero_point  Zero point of the `lhs` values.
/// @param rhs_zero_point  Zero point of the `rhs` values.
/// @param lhs_height      Number of rows of `lhs` / entries of `result`.
/// @param lhs_width       Number of columns of `lhs` / entries of `rhs`.
/// @param result          Output buffer; exactly `lhs_height` entries are
///                        written.
template <>
inline void FixPointGemv<uint8_t, int32_t>(const uint8_t *lhs,
                                           const uint8_t *rhs,
                                           const int lhs_zero_point,
                                           const int rhs_zero_point,
                                           const index_t lhs_height,
                                           const index_t lhs_width,
                                           int32_t *result) {
  // Row-independent terms are computed once, outside the row loop.
  int32_t zero_point_dot = lhs_zero_point * rhs_zero_point * lhs_width;

  uint32_t sum_rhs = 0;
  for (index_t i = 0; i < lhs_width; ++i) {
    sum_rhs += rhs[i];
  }

  // Rows are independent: each iteration reads its own row of lhs and writes
  // only result[h].
#pragma omp parallel for
  for (index_t h = 0; h < lhs_height; ++h) {
    const uint8_t *lhs_ptr = lhs + h * lhs_width;
    const uint8_t *rhs_ptr = rhs;

    uint32_t dot = 0;
    uint32_t sum_lhs = 0;
    index_t w = 0;

#if defined(MACE_ENABLE_NEON)
    // Four 4-lane accumulators for the dot product (16 elements per step).
    uint32x4_t vo0_high_u32 = vdupq_n_u32(0);
    uint32x4_t vo0_low_u32 = vdupq_n_u32(0);
    uint32x4_t vo1_high_u32 = vdupq_n_u32(0);
    uint32x4_t vo1_low_u32 = vdupq_n_u32(0);

    // Two 4-lane accumulators for the sum of the lhs row.
    uint32x4_t sum_lhs_low_u32 = vdupq_n_u32(0);
    uint32x4_t sum_lhs_high_u32 = vdupq_n_u32(0);

    for (; w <= lhs_width - 16; w += 16) {
      // Load 16 bytes from each operand and widen them to 16 bits.
      const uint16x8_t vl0_u16 = vmovl_u8(vld1_u8(lhs_ptr));
      const uint16x8_t vl1_u16 = vmovl_u8(vld1_u8(lhs_ptr + 8));
      const uint16x8_t vr0_u16 = vmovl_u8(vld1_u8(rhs_ptr));
      const uint16x8_t vr1_u16 = vmovl_u8(vld1_u8(rhs_ptr + 8));

      // Widening multiply-accumulate into the 32-bit lanes.
      vo0_high_u32 = vmlal_u16(vo0_high_u32,
                               vget_high_u16(vl0_u16),
                               vget_high_u16(vr0_u16));
      vo0_low_u32 = vmlal_u16(vo0_low_u32,
                              vget_low_u16(vl0_u16),
                              vget_low_u16(vr0_u16));
      vo1_high_u32 = vmlal_u16(vo1_high_u32,
                               vget_high_u16(vl1_u16),
                               vget_high_u16(vr1_u16));
      vo1_low_u32 = vmlal_u16(vo1_low_u32,
                              vget_low_u16(vl1_u16),
                              vget_low_u16(vr1_u16));

      // It can be precalculated if lhs is const, but for this case
      // computation is not the bottleneck.
      sum_lhs_high_u32 = vaddq_u32(
          sum_lhs_high_u32,
          vaddl_u16(vget_high_u16(vl0_u16), vget_high_u16(vl1_u16)));
      sum_lhs_low_u32 = vaddq_u32(
          sum_lhs_low_u32,
          vaddl_u16(vget_low_u16(vl0_u16), vget_low_u16(vl1_u16)));

      lhs_ptr += 16;
      rhs_ptr += 16;
    }

    // Fold the vector accumulators down to scalars.
    vo0_low_u32 = vaddq_u32(vo0_high_u32, vo0_low_u32);
    vo1_low_u32 = vaddq_u32(vo1_high_u32, vo1_low_u32);
    vo0_low_u32 = vaddq_u32(vo0_low_u32, vo1_low_u32);
    dot += vaddvq_u32(vo0_low_u32);

    sum_lhs_low_u32 = vaddq_u32(sum_lhs_high_u32, sum_lhs_low_u32);
    sum_lhs = vaddvq_u32(sum_lhs_low_u32);
#endif  // MACE_ENABLE_NEON

    // Scalar path: handles the tail left by the vector loop, or the whole row
    // when NEON is disabled.
    for (; w < lhs_width; ++w) {
      dot += (*lhs_ptr) * (*rhs_ptr);
      sum_lhs += (*lhs_ptr);
      ++lhs_ptr;
      ++rhs_ptr;
    }

    // The operands mix uint32_t and int, so (where uint32_t is unsigned int)
    // the expression is evaluated in unsigned arithmetic and intermediate
    // wraparound is well defined. The conversion back to int32_t gives the
    // mathematically expected value whenever that value fits in int32_t.
    const int32_t ret = dot - sum_lhs * rhs_zero_point
                        - sum_rhs * lhs_zero_point + zero_point_dot;
    result[h] = ret;
  }  // h
}
}
// namespace ops
}
// namespace mace
#endif // MACE_OPS_ARM_FIXPOINT_GEMM_H_
mace/ops/matmul.cc
浏览文件 @
d87285bf
...
...
@@ -27,6 +27,7 @@
#ifdef MACE_ENABLE_QUANTIZE
#include "mace/ops/gemmlowp_util.h"
#include "mace/ops/arm/fixpoint_gemm.h"
#endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL
...
...
@@ -169,9 +170,6 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
const
index_t
K
,
const
index_t
width
,
Tensor
*
C
)
{
auto
gemm_context
=
context
->
device
()
->
cpu_runtime
()
->
GetGemmlowpContext
();
MACE_CHECK_NOTNULL
(
gemm_context
);
Tensor
::
MappingGuard
guarda
(
A
);
Tensor
::
MappingGuard
guardb
(
B
);
Tensor
::
MappingGuard
guardc
(
C
);
...
...
@@ -180,6 +178,9 @@ class MatMulFixpointImpl<AOrder, BOrder, uint8_t> {
auto
c_ptr_base
=
C
->
mutable_data
<
uint8_t
>
();
index_t
batch
=
std
::
accumulate
(
A
->
shape
().
begin
(),
A
->
shape
().
end
()
-
2
,
1
,
std
::
multiplies
<
index_t
>
());
auto
gemm_context
=
context
->
device
()
->
cpu_runtime
()
->
GetGemmlowpContext
();
MACE_CHECK_NOTNULL
(
gemm_context
);
index_t
a_size
=
height
*
K
;
index_t
b_size
=
K
*
width
;
index_t
c_size
=
height
*
width
;
...
...
@@ -213,9 +214,6 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
const
index_t
K
,
const
index_t
width
,
Tensor
*
C
)
{
auto
gemm_context
=
context
->
device
()
->
cpu_runtime
()
->
GetGemmlowpContext
();
MACE_CHECK_NOTNULL
(
gemm_context
);
Tensor
::
MappingGuard
guarda
(
A
);
Tensor
::
MappingGuard
guardb
(
B
);
Tensor
::
MappingGuard
guardc
(
C
);
...
...
@@ -224,24 +222,53 @@ class MatMulFixpointImpl<AOrder, BOrder, int32_t> {
auto
c_ptr_base
=
C
->
mutable_data
<
int32_t
>
();
index_t
batch
=
std
::
accumulate
(
A
->
shape
().
begin
(),
A
->
shape
().
end
()
-
2
,
1
,
std
::
multiplies
<
index_t
>
());
index_t
a_size
=
height
*
K
;
index_t
b_size
=
K
*
width
;
index_t
c_size
=
height
*
width
;
const
auto
output_pipeline
=
std
::
make_tuple
();
for
(
index_t
i
=
0
;
i
<
batch
;
++
i
)
{
gemmlowp
::
MatrixMap
<
const
uint8_t
,
AOrder
>
a_matrix
(
a_ptr_base
+
i
*
a_size
,
height
,
K
);
gemmlowp
::
MatrixMap
<
const
uint8_t
,
BOrder
>
b_matrix
(
b_ptr_base
+
i
*
b_size
,
K
,
width
);
gemmlowp
::
MatrixMap
<
int32_t
,
gemmlowp
::
MapOrder
::
RowMajor
>
c_matrix
(
c_ptr_base
+
i
*
c_size
,
height
,
width
);
using
BitDepthParams
=
gemmlowp
::
L8R8WithLhsNonzeroBitDepthParams
;
gemmlowp
::
GemmWithOutputPipeline
<
uint8_t
,
int32_t
,
BitDepthParams
>
(
gemm_context
,
a_matrix
,
b_matrix
,
&
c_matrix
,
-
A
->
zero_point
(),
-
B
->
zero_point
(),
output_pipeline
);
if
(
width
==
1
&&
AOrder
==
gemmlowp
::
MapOrder
::
RowMajor
)
{
// gemv
for
(
index_t
i
=
0
;
i
<
batch
;
++
i
)
{
FixPointGemv
(
a_ptr_base
+
i
*
height
*
K
,
b_ptr_base
+
i
*
K
,
A
->
zero_point
(),
B
->
zero_point
(),
height
,
K
,
c_ptr_base
+
i
*
height
);
}
}
else
if
(
height
==
1
&&
BOrder
==
gemmlowp
::
MapOrder
::
ColMajor
)
{
// gevm
for
(
index_t
i
=
0
;
i
<
batch
;
++
i
)
{
FixPointGemv
(
b_ptr_base
+
i
*
K
*
width
,
a_ptr_base
+
i
*
K
,
B
->
zero_point
(),
A
->
zero_point
(),
width
,
K
,
c_ptr_base
+
i
*
width
);
}
}
else
{
auto
gemm_context
=
context
->
device
()
->
cpu_runtime
()
->
GetGemmlowpContext
();
MACE_CHECK_NOTNULL
(
gemm_context
);
index_t
a_size
=
height
*
K
;
index_t
b_size
=
K
*
width
;
index_t
c_size
=
height
*
width
;
const
auto
output_pipeline
=
std
::
make_tuple
();
for
(
index_t
i
=
0
;
i
<
batch
;
++
i
)
{
gemmlowp
::
MatrixMap
<
const
uint8_t
,
AOrder
>
a_matrix
(
a_ptr_base
+
i
*
a_size
,
height
,
K
);
gemmlowp
::
MatrixMap
<
const
uint8_t
,
BOrder
>
b_matrix
(
b_ptr_base
+
i
*
b_size
,
K
,
width
);
gemmlowp
::
MatrixMap
<
int32_t
,
gemmlowp
::
MapOrder
::
RowMajor
>
c_matrix
(
c_ptr_base
+
i
*
c_size
,
height
,
width
);
using
BitDepthParams
=
gemmlowp
::
L8R8WithLhsNonzeroBitDepthParams
;
gemmlowp
::
GemmWithOutputPipeline
<
uint8_t
,
int32_t
,
BitDepthParams
>
(
gemm_context
,
a_matrix
,
b_matrix
,
&
c_matrix
,
-
A
->
zero_point
(),
-
B
->
zero_point
(),
output_pipeline
);
}
}
C
->
SetScale
(
A
->
scale
()
*
B
->
scale
());
...
...
mace/ops/matmul_test.cc
浏览文件 @
d87285bf
...
...
@@ -315,14 +315,20 @@ void QuantOutputInt32(const std::vector<index_t> &batch,
index_t
batch_count
=
std
::
accumulate
(
batch
.
begin
(),
batch
.
end
(),
1
,
std
::
multiplies
<
index_t
>
());
if
(
transpose_a
)
{
net
.
AddRandomInput
<
CPU
,
float
>
(
"A"
,
{
batch_count
,
channels
,
height
});
net
.
AddRandomInput
<
CPU
,
float
>
(
"A"
,
{
batch_count
,
channels
,
height
},
false
);
}
else
{
net
.
AddRandomInput
<
CPU
,
float
>
(
"A"
,
{
batch_count
,
height
,
channels
});
net
.
AddRandomInput
<
CPU
,
float
>
(
"A"
,
{
batch_count
,
height
,
channels
},
false
);
}
if
(
transpose_b
)
{
net
.
AddRandomInput
<
CPU
,
float
>
(
"B"
,
{
batch_count
,
out_width
,
channels
});
net
.
AddRandomInput
<
CPU
,
float
>
(
"B"
,
{
batch_count
,
out_width
,
channels
},
false
);
}
else
{
net
.
AddRandomInput
<
CPU
,
float
>
(
"B"
,
{
batch_count
,
channels
,
out_width
});
net
.
AddRandomInput
<
CPU
,
float
>
(
"B"
,
{
batch_count
,
channels
,
out_width
},
false
);
}
OpDefBuilder
(
"MatMul"
,
"MatMulTest"
)
...
...
@@ -411,11 +417,18 @@ TEST_F(MatMulOpTest, QuantOutputInt32) {
QuantOutputInt32
({
1
},
64
,
128
,
32
,
true
,
true
);
QuantOutputInt32
({
1
},
64
,
32
,
128
,
true
,
true
);
QuantOutputInt32
({
2
,
3
},
64
,
32
,
128
,
true
,
true
);
QuantOutputInt32
({
1
},
1
,
30000
,
256
,
false
,
true
);
QuantOutputInt32
({
1
},
30000
,
256
,
1
,
false
,
false
);
QuantOutputInt32
({
2
},
1
,
256
,
128
,
false
,
true
);
QuantOutputInt32
({
3
},
128
,
256
,
1
,
false
,
false
);
// UnAligned
QuantOutputInt32
({
2
},
3
,
3
,
3
,
false
,
false
);
QuantOutputInt32
({
16
},
31
,
61
,
67
,
false
,
true
);
QuantOutputInt32
({
31
},
31
,
61
,
67
,
true
,
false
);
QuantOutputInt32
({
2
,
3
},
31
,
61
,
67
,
true
,
true
);
QuantOutputInt32
({
1
},
1
,
30001
,
253
,
false
,
true
);
QuantOutputInt32
({
2
},
253
,
300
,
1
,
false
,
false
);
}
// TODO(liyin): test transpose after implementing gpu runtime
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录