Mace (fork of Xiaomi / Mace)

Commit 1ee5fd20
Authored Mar 12, 2019 by liyin
Regress gemv; support quantize gather only
Parent 202ea3a6
Showing 4 changed files with 134 additions and 402 deletions (+134 −402)
mace/ops/arm/q8/gemv.cc                               +105 −389
mace/ops/arm/q8/gemv.h                                  +6 −1
mace/python/tools/converter_tool/base_converter.py      +1 −1
mace/python/tools/converter_tool/transformer.py        +22 −11
mace/ops/arm/q8/gemv.cc
@@ -23,9 +23,7 @@
 #if !defined(__aarch64__)
-#define vmlal_high_s16(c, a, b) vmlal_s16(c, vget_high_s16(a), vget_high_s16(b))
-#define vaddvq_s32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
+#define vaddvq_u32(v) ((v)[0] + (v)[1] + (v)[2] + (v)[3])
 #endif
@@ -47,17 +45,19 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
                                       Tensor *output) {
   MACE_UNUSED(context);
-  bool is_output_type_uint8 =
-      DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8;
   Tensor::MappingGuard lhs_guard(lhs);
   Tensor::MappingGuard rhs_guard(rhs);
   Tensor::MappingGuard bias_guard(bias);
   Tensor::MappingGuard output_guard(output);
+  const auto *lhs_data = lhs->data<uint8_t>();
+  const auto *rhs_data = rhs->data<uint8_t>();
+  OUTPUT_TYPE *output_data = output->mutable_data<OUTPUT_TYPE>();
   float output_multiplier_float = 0.0;
   int32_t output_multiplier = 0;
   int32_t output_shift = 0;
-  if (is_output_type_uint8) {
+  if (is_output_type_uint8_) {
     MACE_CHECK(output->scale() > 0, "output scale must not be zero");
     output_multiplier_float = lhs->scale() * rhs->scale() / output->scale();
     GetOutputMultiplierAndShift(lhs->scale(),
@@ -66,393 +66,110 @@ MaceStatus Gemv<OUTPUT_TYPE>::Compute(const OpContext *context,
                                 &output_multiplier,
                                 &output_shift);
   }
-  const index_t h_block_size = 4;
-  const index_t h_block_count = RoundUpDiv(lhs_height, h_block_size);
-#pragma omp parallel for collapse(2) schedule(runtime)
+  const int32_t lhs_zero_point = lhs->zero_point();
+  const int32_t rhs_zero_point = rhs->zero_point();
+  const index_t w_block_size = 16;
+  const index_t w_block_count = lhs_width / w_block_size;
+  const index_t w_block_remain = lhs_width - w_block_size * w_block_count;
   for (index_t b = 0; b < batch; ++b) {
-    for (index_t h_block_idx = 0; h_block_idx < h_block_count;
-         ++h_block_idx) {
-      // TODO(liyin): it can be put outside the loop,
-      // but openmp limits param count
-      const index_t w_block_size = 16;
-      const index_t w_block_count = lhs_width / w_block_size;
-      const index_t w_remain = lhs_width - w_block_size * w_block_count;
-      uint8_t lhs_zero_point = static_cast<uint8_t>(lhs->zero_point());
-      uint8_t rhs_zero_point = static_cast<uint8_t>(rhs->zero_point());
-      const uint8_t *lhs_data = lhs->data<uint8_t>();
-      const uint8_t *rhs_data = rhs->data<uint8_t>();
-      const int32_t *bias_data = nullptr;
-      if (bias) {
-        bias_data = bias->data<int32_t>();
-      }
-      OUTPUT_TYPE *output_data = output->mutable_data<OUTPUT_TYPE>();
-      int32x4_t voutput_multiplier = vdupq_n_s32(output_multiplier);
-      int32x4_t voutput_shift_left = vdupq_n_s32(-output_shift);
-      uint8x8_t vlhs_zero_point = vdup_n_u8(lhs_zero_point);
-      uint8x8_t vrhs_zero_point = vdup_n_u8(rhs_zero_point);
-      const uint8_t *lhs_ptr = lhs_data
-          + static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
-          + lhs_width * h_block_idx * h_block_size;
-      const uint8_t *rhs_ptr =
-          rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
-      OUTPUT_TYPE *ret_ptr =
-          output_data + b * lhs_height + h_block_idx * h_block_size;
-      const index_t h_block_len =
-          std::min(h_block_size, lhs_height - h_block_idx * h_block_size);
-      const index_t h_offset = h_block_idx * h_block_size;
-      if (h_block_len == 4) {
-        int32x4_t vo0 = vdupq_n_s32(0);
-        int32x4_t vo1 = vdupq_n_s32(0);
-        int32x4_t vo2 = vdupq_n_s32(0);
-        int32x4_t vo3 = vdupq_n_s32(0);
-        index_t r_w_block_count = w_block_count;
-        // just make compiler happy
-        MACE_UNUSED(r_w_block_count);
-        // Register layout: (4x16) x (16x1)
-        //
-        //                               +----+
-        //                               |d16 |
-        //                               | .  |
-        //                               | .  |
-        //                               | .  |
-        //                        Rhs    +----+
-        //                               |d17 |
-        //                               | .  |
-        //                               | .  |
-        //                               | .  |
-        //                               +----+
-        //                               |d18 |
-        //                               | .  |
-        //                               | .  |
-        //                               | .  |
-        //                               +----+
-        //                               |d19 |
-        //                               | .  |
-        //                               | .  |
-        //                               | .  |
-        //                               +----+
-        //
-        //                               |    |
-        //
-        //            Lhs                |    |
-        //
-        //  +--------+--------+--------+--------+ - - - - +----+
-        //  | d0 ... | d1 ... | d2 ... | d3 ... |          |vo0 |
-        //  | d4 ... | d5 ... | d6 ... | d7 ... |          |vo1 |
-        //  | d8 ... | d9 ... | d10... | d11... |          |vo2 |
-        //  | d12... | d13... | d14... | d15... |          |vo3 |
-        //  +--------+--------+--------+--------+ - - - - +----+
-        //
-        //                                     Accumulator
-        //
-#if not defined(__aarch64__)
-        asm volatile(
-            "cmp %[r_w_block_count], #0\n"
-            "beq 0f\n"
-            "mov r0, %[rhs_ptr]\n"
-            "mov r1, %[lhs_ptr]\n"
-            "add r2, r1, %[lhs_width]\n"
-            "add r3, r2, %[lhs_width]\n"
-            "add r4, r3, %[lhs_width]\n"
-            "vdup.u8 d20, %[rhs_zero_point]\n"
-            "vdup.u8 d21, %[lhs_zero_point]\n"
-            // prelogue
-            "vld1.8 d16, [r0]!\n"
-            "vld1.8 d18, [r0]!\n"
-            "vld1.8 d0, [r1]!\n"
-            "vld1.8 d2, [r1]!\n"
-            "vld1.8 d4, [r2]!\n"
-            "vld1.8 d6, [r2]!\n"
-            "vld1.8 d8, [r3]!\n"
-            "vld1.8 d10, [r3]!\n"
-            "vld1.8 d12, [r4]!\n"
-            "vld1.8 d14, [r4]!\n"
-            "subs %[r_w_block_count], #1\n"
-            "beq 1f\n"
-            "2:\n"
-            "vsubl.u8 q8, d16, d20\n"
-            "vsubl.u8 q9, d18, d20\n"
-            "vsubl.u8 q0, d0, d21\n"
-            "vsubl.u8 q1, d2, d21\n"
-            "vsubl.u8 q2, d4, d21\n"
-            "vsubl.u8 q3, d6, d21\n"
-            "vsubl.u8 q4, d8, d21\n"
-            "vsubl.u8 q5, d10, d21\n"
-            "vsubl.u8 q6, d12, d21\n"
-            "vsubl.u8 q7, d14, d21\n"
-            "vmlal.s16 %q[vo0], d0, d16\n"
-            "vmlal.s16 %q[vo1], d4, d16\n"
-            "vmlal.s16 %q[vo2], d8, d16\n"
-            "vmlal.s16 %q[vo3], d12, d16\n"
-            "vld1.8 d0, [r1]!\n"
-            "vld1.8 d4, [r2]!\n"
-            "vld1.8 d8, [r3]!\n"
-            "vld1.8 d12, [r4]!\n"
-            "vld1.8 d16, [r0]!\n"
-            "vmlal.s16 %q[vo0], d2, d18\n"
-            "vmlal.s16 %q[vo1], d6, d18\n"
-            "vmlal.s16 %q[vo2], d10, d18\n"
-            "vmlal.s16 %q[vo3], d14, d18\n"
-            "vld1.8 d2, [r1]!\n"
-            "vld1.8 d6, [r2]!\n"
-            "vld1.8 d10, [r3]!\n"
-            "vld1.8 d14, [r4]!\n"
-            "vld1.8 d18, [r0]!\n"
-            "vmlal.s16 %q[vo0], d1, d17\n"
-            "vmlal.s16 %q[vo1], d5, d17\n"
-            "vmlal.s16 %q[vo2], d9, d17\n"
-            "vmlal.s16 %q[vo3], d13, d17\n"
-            "subs %[r_w_block_count], #1\n"
-            "vmlal.s16 %q[vo0], d3, d19\n"
-            "vmlal.s16 %q[vo1], d7, d19\n"
-            "vmlal.s16 %q[vo2], d11, d19\n"
-            "vmlal.s16 %q[vo3], d15, d19\n"
-            "bne 2b\n"
-            // prologue
-            "1:\n"
-            "vsubl.u8 q8, d16, d20\n"
-            "vsubl.u8 q9, d18, d20\n"
-            "vsubl.u8 q0, d0, d21\n"
-            "vsubl.u8 q1, d2, d21\n"
-            "vsubl.u8 q2, d4, d21\n"
-            "vsubl.u8 q3, d6, d21\n"
-            "vsubl.u8 q4, d8, d21\n"
-            "vsubl.u8 q5, d10, d21\n"
-            "vsubl.u8 q6, d12, d21\n"
-            "vsubl.u8 q7, d14, d21\n"
-            "vmlal.s16 %q[vo0], d0, d16\n"
-            "vmlal.s16 %q[vo1], d4, d16\n"
-            "vmlal.s16 %q[vo2], d8, d16\n"
-            "vmlal.s16 %q[vo3], d12, d16\n"
-            "vmlal.s16 %q[vo0], d1, d17\n"
-            "vmlal.s16 %q[vo1], d5, d17\n"
-            "vmlal.s16 %q[vo2], d9, d17\n"
-            "vmlal.s16 %q[vo3], d13, d17\n"
-            "vmlal.s16 %q[vo0], d2, d18\n"
-            "vmlal.s16 %q[vo1], d6, d18\n"
-            "vmlal.s16 %q[vo2], d10, d18\n"
-            "vmlal.s16 %q[vo3], d14, d18\n"
-            "vmlal.s16 %q[vo0], d3, d19\n"
-            "vmlal.s16 %q[vo1], d7, d19\n"
-            "vmlal.s16 %q[vo2], d11, d19\n"
-            "vmlal.s16 %q[vo3], d15, d19\n"
-            "0:\n"
-            :  // outputs
-            [vo0] "+w"(vo0),
-            [vo1] "+w"(vo1),
-            [vo2] "+w"(vo2),
-            [vo3] "+w"(vo3),
-            [r_w_block_count] "+r"(r_w_block_count)
-            :  // inputs
-            [lhs_ptr] "r"(lhs_ptr), [rhs_ptr] "r"(rhs_ptr),
-            [lhs_width] "r"(lhs_width),
-            [lhs_zero_point] "r"(lhs_zero_point),
-            [rhs_zero_point] "r"(rhs_zero_point)
-            :  // clobbers
-            "cc", "memory", "r0", "r1", "r2", "r3", "r4",
-            "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8", "d9",
-            "d10", "d11", "d12", "d13", "d14", "d15", "d16", "d17", "d18",
-            "d19", "d20", "d21");
-        lhs_ptr += w_block_count * w_block_size;
-        rhs_ptr += w_block_count * w_block_size;
-#else
-        for (index_t w_block_index = 0; w_block_index < w_block_count;
-             ++w_block_index) {
-          uint8x8_t vr0 = vld1_u8(rhs_ptr);
-          int16x8_t vxr0 = vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point));
-          uint8x8_t vr0n = vld1_u8(rhs_ptr + 8);
-          int16x8_t vxr0n =
-              vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point));
-          uint8x8_t vl0 = vld1_u8(lhs_ptr);
-          int16x8_t vxl0 = vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point));
-          uint8x8_t vl0n = vld1_u8(lhs_ptr + 8);
-          int16x8_t vxl0n =
-              vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point));
-          vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0));
-          vo0 = vmlal_high_s16(vo0, vxl0, vxr0);
-          vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n));
-          vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n);
-          const uint8_t *lhs_ptr1 = lhs_ptr + lhs_width;
-          uint8x8_t vl1 = vld1_u8(lhs_ptr1);
-          int16x8_t vxl1 = vreinterpretq_s16_u16(vsubl_u8(vl1, vlhs_zero_point));
-          uint8x8_t vl1n = vld1_u8(lhs_ptr1 + 8);
-          int16x8_t vxl1n =
-              vreinterpretq_s16_u16(vsubl_u8(vl1n, vlhs_zero_point));
-          vo1 = vmlal_s16(vo1, vget_low_s16(vxl1), vget_low_s16(vxr0));
-          vo1 = vmlal_high_s16(vo1, vxl1, vxr0);
-          vo1 = vmlal_s16(vo1, vget_low_s16(vxl1n), vget_low_s16(vxr0n));
-          vo1 = vmlal_high_s16(vo1, vxl1n, vxr0n);
-          const uint8_t *lhs_ptr2 = lhs_ptr1 + lhs_width;
-          uint8x8_t vl2 = vld1_u8(lhs_ptr2);
-          int16x8_t vxl2 = vreinterpretq_s16_u16(vsubl_u8(vl2, vlhs_zero_point));
-          uint8x8_t vl2n = vld1_u8(lhs_ptr2 + 8);
-          int16x8_t vxl2n =
-              vreinterpretq_s16_u16(vsubl_u8(vl2n, vlhs_zero_point));
-          vo2 = vmlal_s16(vo2, vget_low_s16(vxl2), vget_low_s16(vxr0));
-          vo2 = vmlal_high_s16(vo2, vxl2, vxr0);
-          vo2 = vmlal_s16(vo2, vget_low_s16(vxl2n), vget_low_s16(vxr0n));
-          vo2 = vmlal_high_s16(vo2, vxl2n, vxr0n);
-          const uint8_t *lhs_ptr3 = lhs_ptr2 + lhs_width;
-          uint8x8_t vl3 = vld1_u8(lhs_ptr3);
-          int16x8_t vxl3 = vreinterpretq_s16_u16(vsubl_u8(vl3, vlhs_zero_point));
-          uint8x8_t vl3n = vld1_u8(lhs_ptr3 + 8);
-          int16x8_t vxl3n =
-              vreinterpretq_s16_u16(vsubl_u8(vl3n, vlhs_zero_point));
-          vo3 = vmlal_s16(vo3, vget_low_s16(vxl3), vget_low_s16(vxr0));
-          vo3 = vmlal_high_s16(vo3, vxl3, vxr0);
-          vo3 = vmlal_s16(vo3, vget_low_s16(vxl3n), vget_low_s16(vxr0n));
-          vo3 = vmlal_high_s16(vo3, vxl3n, vxr0n);
-          lhs_ptr += 16;
-          rhs_ptr += 16;
-        }
-#endif  // __aarch64__
-        int32x4_t vo = {vaddvq_s32(vo0), vaddvq_s32(vo1), vaddvq_s32(vo2),
-                        vaddvq_s32(vo3)};
-        for (index_t w = 0; w < w_remain; ++w) {
-          vo[0] += (lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point);
-          vo[1] += (lhs_ptr[lhs_width] - lhs_zero_point)
-              * (rhs_ptr[0] - rhs_zero_point);
-          vo[2] += (lhs_ptr[lhs_width * 2] - lhs_zero_point)
-              * (rhs_ptr[0] - rhs_zero_point);
-          vo[3] += (lhs_ptr[lhs_width * 3] - lhs_zero_point)
-              * (rhs_ptr[0] - rhs_zero_point);
-          ++lhs_ptr;
-          ++rhs_ptr;
-        }
-        if (bias) {
-          int32x4_t vbias = vdupq_n_s32(0);
-          vbias = vld1q_s32(bias_data + h_offset);
-          vo = vaddq_s32(vo, vbias);
-        }
-        if (is_output_type_uint8) {
-          int32x4_t vo_mul = vqrdmulhq_s32(vo, voutput_multiplier);
-          int32x4_t fixup = vshrq_n_s32(vandq_s32(vo_mul, voutput_shift_left), 31);
-          int32x4_t fixed_up_x = vqaddq_s32(vo_mul, fixup);
-          int32x4_t vo_rescale_int32 = vrshlq_s32(fixed_up_x, voutput_shift_left);
-          int16x4_t vo_rescale_int16 = vqmovn_s32(vo_rescale_int32);
-          uint8x8_t vo_rescale_uint8 =
-              vqmovun_s16(vcombine_s16(vo_rescale_int16, vo_rescale_int16));
-          ret_ptr[0] = vo_rescale_uint8[0];
-          ret_ptr[1] = vo_rescale_uint8[1];
-          ret_ptr[2] = vo_rescale_uint8[2];
-          ret_ptr[3] = vo_rescale_uint8[3];
-        } else {
-          ret_ptr[0] = vo[0];
-          ret_ptr[1] = vo[1];
-          ret_ptr[2] = vo[2];
-          ret_ptr[3] = vo[3];
-        }
-      } else {  // h_block_len < 4
-        // TODO(liyin): handle here case by case (1,2,3) to accelerate
-        const uint8_t *tmp_lhs_ptr = lhs_ptr;
-        const uint8_t *tmp_rhs_ptr = rhs_ptr;
-        for (index_t h = 0; h < h_block_len; ++h) {
-          lhs_ptr = tmp_lhs_ptr + h * lhs_width;
-          rhs_ptr = tmp_rhs_ptr;
-          int32x4_t vo0 = vdupq_n_s32(0);
-          for (index_t w = 0; w < w_block_count; ++w) {
-            uint8x8_t vr0 = vld1_u8(rhs_ptr);
-            int16x8_t vxr0 =
-                vreinterpretq_s16_u16(vsubl_u8(vr0, vrhs_zero_point));
-            uint8x8_t vr0n = vld1_u8(rhs_ptr + 8);
-            int16x8_t vxr0n =
-                vreinterpretq_s16_u16(vsubl_u8(vr0n, vrhs_zero_point));
-            uint8x8_t vl0 = vld1_u8(lhs_ptr);
-            int16x8_t vxl0 =
-                vreinterpretq_s16_u16(vsubl_u8(vl0, vlhs_zero_point));
-            uint8x8_t vl0n = vld1_u8(lhs_ptr + 8);
-            int16x8_t vxl0n =
-                vreinterpretq_s16_u16(vsubl_u8(vl0n, vlhs_zero_point));
-            vo0 = vmlal_s16(vo0, vget_low_s16(vxl0), vget_low_s16(vxr0));
-            vo0 = vmlal_high_s16(vo0, vxl0, vxr0);
-            vo0 = vmlal_s16(vo0, vget_low_s16(vxl0n), vget_low_s16(vxr0n));
-            vo0 = vmlal_high_s16(vo0, vxl0n, vxr0n);
-            lhs_ptr += 16;
-            rhs_ptr += 16;
-          }  // w
-          int32_t s0 = vaddvq_s32(vo0) + (bias ? bias_data[h_offset + h] : 0);
-          for (index_t w = 0; w < w_remain; ++w) {
-            s0 += (lhs_ptr[0] - lhs_zero_point) * (rhs_ptr[0] - rhs_zero_point);
-            ++lhs_ptr;
-            ++rhs_ptr;
-          }  // w
-          if (is_output_type_uint8) {
-            ret_ptr[h] =
-                Saturate<uint8_t>(std::roundf(s0 * output_multiplier_float));
-          } else {
-            ret_ptr[h] = s0;
-          }
-        }  // h
-      }  // if
-    }  // h_block_idx
+    const uint8_t *rhs_base =
+        rhs_data + static_cast<index_t>(rhs_batched) * b * lhs_width;
+    // TODO(liyin): it can be put outside the loop,
+    // but openmp limits param count
+    uint32_t sum_rhs = 0;
+    for (index_t i = 0; i < lhs_width; ++i) {
+      sum_rhs += static_cast<uint32_t>(rhs_base[i]);
+    }
+#pragma omp parallel for schedule(runtime)
+    for (index_t h = 0; h < lhs_height; ++h) {
+      const uint8_t *lhs_ptr = lhs_data
+          + static_cast<index_t>(lhs_batched) * b * lhs_height * lhs_width
+          + h * lhs_width;
+      const uint8_t *rhs_ptr = rhs_base;
+      OUTPUT_TYPE *output_ptr = output_data + b * lhs_height + h;
+      uint32_t dot = 0;
+      uint32_t sum_lhs = 0;
+      uint32x4_t vo0_high_u32 = vdupq_n_u32(0);
+      uint32x4_t vo0_low_u32 = vdupq_n_u32(0);
+      uint32x4_t vo1_high_u32 = vdupq_n_u32(0);
+      uint32x4_t vo1_low_u32 = vdupq_n_u32(0);
+      uint32x4_t sum_lhs_low_u32 = vdupq_n_u32(0);
+      uint32x4_t sum_lhs_high_u32 = vdupq_n_u32(0);
+      for (index_t w_block_idx = 0; w_block_idx < w_block_count;
+           ++w_block_idx) {
+        uint8x8_t vl0_u8 = vld1_u8(lhs_ptr);
+        uint8x8_t vl1_u8 = vld1_u8(lhs_ptr + 8);
+        uint8x8_t vr0_u8 = vld1_u8(rhs_ptr);
+        uint8x8_t vr1_u8 = vld1_u8(rhs_ptr + 8);
+        uint16x8_t vl0_u16 = vmovl_u8(vl0_u8);
+        uint16x8_t vl1_u16 = vmovl_u8(vl1_u8);
+        uint16x8_t vr0_u16 = vmovl_u8(vr0_u8);
+        uint16x8_t vr1_u16 = vmovl_u8(vr1_u8);
+        vo0_high_u32 = vmlal_u16(vo0_high_u32,
+                                 vget_high_u16(vl0_u16),
+                                 vget_high_u16(vr0_u16));
+        vo0_low_u32 = vmlal_u16(vo0_low_u32,
+                                vget_low_u16(vl0_u16),
+                                vget_low_u16(vr0_u16));
+        vo1_high_u32 = vmlal_u16(vo1_high_u32,
+                                 vget_high_u16(vl1_u16),
+                                 vget_high_u16(vr1_u16));
+        vo1_low_u32 = vmlal_u16(vo1_low_u32,
+                                vget_low_u16(vl1_u16),
+                                vget_low_u16(vr1_u16));
+        // It can be precalculated if lhs is const, but for this case
+        // computation is not the bottleneck
+        sum_lhs_high_u32 += vaddl_u16(vget_high_u16(vl0_u16),
+                                      vget_high_u16(vl1_u16));
+        sum_lhs_low_u32 += vaddl_u16(vget_low_u16(vl0_u16),
+                                     vget_low_u16(vl1_u16));
+        lhs_ptr += 16;
+        rhs_ptr += 16;
+      }
+      vo0_low_u32 = vaddq_u32(vo0_high_u32, vo0_low_u32);
+      vo1_low_u32 = vaddq_u32(vo1_high_u32, vo1_low_u32);
+      vo0_low_u32 = vaddq_u32(vo0_low_u32, vo1_low_u32);
+      dot += vaddvq_u32(vo0_low_u32);
+      sum_lhs_low_u32 = vaddq_u32(sum_lhs_high_u32, sum_lhs_low_u32);
+      sum_lhs = vaddvq_u32(sum_lhs_low_u32);
+      for (index_t w = 0; w < w_block_remain; ++w) {
+        dot += (*lhs_ptr) * (*rhs_ptr);
+        sum_lhs += (*lhs_ptr);
+        ++lhs_ptr;
+        ++rhs_ptr;
+      }
+      const auto zero_point_dot =
+          static_cast<int32_t>(lhs_zero_point * rhs_zero_point * lhs_width);
+      int32_t ret = dot - sum_lhs * rhs_zero_point - sum_rhs * lhs_zero_point
+          + zero_point_dot;
+      if (bias) {
+        ret += bias->data<int32_t>()[h];
+      }
+      if (is_output_type_uint8_) {
+        *output_ptr =
+            Saturate<uint8_t>(std::roundf(ret * output_multiplier_float));
+      } else {
+        *output_ptr = ret;
+      }
+    }  // h
   }  // b
   return MaceStatus::MACE_SUCCESS;
 }
@@ -466,7 +183,6 @@ class Gemv<int32_t>;
 }  // namespace ops
 }  // namespace mace

-#if defined(vmlal_high_s16)
-#undef vmlal_high_s16
-#undef vaddvq_s32
-#endif
+#ifdef vaddvq_u32
+#undef vaddvq_u32
+#endif  // vaddvq_u32
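Why the rewrite works: instead of subtracting the zero points from every element before multiplying (what the removed armv7 assembly and aarch64 intrinsics did with vsubl.u8), the new kernel accumulates raw uint8 products plus the row and column sums, then applies the whole zero-point correction once per output element. A minimal sketch in plain Python (not part of MACE; the names mirror the kernel's variables) checking that identity, including a width that is not a multiple of 16 so the remain path is exercised:

import random

random.seed(0)
width = 35                      # not a multiple of 16: exercises the remain loop
lhs = [random.randrange(256) for _ in range(width)]   # one row of the lhs matrix
rhs = [random.randrange(256) for _ in range(width)]   # the rhs vector
zl, zr = 127, 101               # lhs/rhs quantization zero points

# Reference: subtract zero points before multiplying, as the removed
# vsubl.u8-based kernels did.
ref = sum((l - zl) * (r - zr) for l, r in zip(lhs, rhs))

# New formulation: one unsigned multiply-accumulate pass, correction applied once.
#   sum((l - zl) * (r - zr)) == dot - zr*sum(l) - zl*sum(r) + zl*zr*width
dot = sum(l * r for l, r in zip(lhs, rhs))
sum_lhs = sum(lhs)              # per row, accumulated in the NEON loop
sum_rhs = sum(rhs)              # per batch, hoisted out of the h loop
ret = dot - sum_lhs * zr - sum_rhs * zl + zl * zr * width

assert ret == ref

In the C++ kernel, dot and sum_lhs live in uint32 lanes; the correction still comes out right because the arithmetic is modular and the true result fits in int32.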
mace/ops/arm/q8/gemv.h
@@ -30,7 +30,9 @@ namespace q8 {
 template<typename OUTPUT_TYPE>
 class Gemv {
  public:
-  Gemv() {}
+  Gemv()
+      : is_output_type_uint8_(
+            DataTypeToEnum<OUTPUT_TYPE>::value == DataType::DT_UINT8) {}
   ~Gemv() {}
   // Always row-major after transpose
   MaceStatus Compute(
@@ -44,6 +46,9 @@ class Gemv {
       const bool lhs_batched,
       const bool rhs_batched,
       Tensor *output);
+
+ private:
+  bool is_output_type_uint8_;
 };
 }  // namespace q8
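The header half of the change hoists the output-type test out of Compute: DataTypeToEnum<OUTPUT_TYPE>::value is known when the object is built, so the flag is now computed once in the constructor and reused on every call. A rough Python analogue of the pattern (hypothetical class, illustration only, not MACE code):

class Gemv:
    def __init__(self, output_dtype):
        # evaluated once, like is_output_type_uint8_ in the C++ initializer list
        self.is_output_type_uint8 = (output_dtype == "uint8")

    def compute(self, ret, output_multiplier_float):
        if self.is_output_type_uint8:
            # saturate to [0, 255], loosely like Saturate<uint8_t>(std::roundf(...))
            return max(0, min(255, round(ret * output_multiplier_float)))
        return ret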
mace/python/tools/converter_tool/base_converter.py
@@ -280,7 +280,7 @@ class TransformerRule(Enum):
     FOLD_FC_RESHAPE = 37
     TRANSFORM_CHANNEL_SHUFFLE = 38
     UPDATE_DATA_FORMAT = 39
-    QUANTIZE_MATMUL_ONLY = 40
+    QUANTIZE_SPECIFIC_OPS_ONLY = 40


 class ConverterInterface(object):
mace/python/tools/converter_tool/transformer.py
@@ -103,8 +103,8 @@ class Transformer(base_converter.ConverterInterface):
                 self.transform_caffe_reshape_and_flatten,
             TransformerRule.TRANSFORM_CHANNEL_SHUFFLE:
                 self.transform_channel_shuffle,
-            TransformerRule.QUANTIZE_MATMUL_ONLY:
-                self.quantize_matmul_only,
+            TransformerRule.QUANTIZE_SPECIFIC_OPS_ONLY:
+                self.quantize_specific_ops_only,
         }

         self._option = option
@@ -1118,7 +1118,7 @@ class Transformer(base_converter.ConverterInterface):
         rhs = op.input[1]
         if rhs in self._consts and len(self._consts[rhs].dims) == 2:
             arg = ConverterUtil.get_arg(op, MaceKeyword.mace_transpose_b_str)  # noqa
-            six.print_('transpose matmul weight')
+            six.print_("Transpose matmul weight %s" % rhs)
             if arg is None:
                 arg = op.arg.add()
                 arg.name = MaceKeyword.mace_transpose_b_str
@@ -1927,35 +1927,46 @@ class Transformer(base_converter.ConverterInterface):
         return True

-    def quantize_matmul_only(self):
+    def quantize_specific_ops_only(self):
         """
         This transform rule is only used internally, we are not gonna make
         things too complex for users
         """
-        to_quantize_ops = [MaceOp.MatMul.name]
+        to_quantize_ops_output_type = {
+            MaceOp.MatMul.name: mace_pb2.DT_INT32,
+            MaceOp.Gather.name: mace_pb2.DT_UINT8,
+        }
         for op in self._model.op:
-            if (op.type not in to_quantize_ops
+            if (op.type not in to_quantize_ops_output_type
                     or len(op.output) > 1
                     or ConverterUtil.get_arg(op,
                                              MaceKeyword.mace_op_data_type_str).i != mace_pb2.DT_FLOAT):  # noqa
                 # only support single output
                 continue
             quantized_inputs_names = []
-            should_quantize = True
+            should_quantize = False
             for idx, input_tensor in enumerate(op.input):
                 if self.get_tensor_data_type(input_tensor) \
-                        != mace_pb2.DT_FLOAT:
-                    should_quantize = False
+                        == mace_pb2.DT_FLOAT:
+                    should_quantize = True
                     break
             if not should_quantize:
                 continue
+            else:
+                print("Quantize op %s (%s)" % (op.name, op.type))

             non_zero = self._option.device == DeviceType.CPU.value
             for idx, input_tensor in enumerate(op.input):
                 quantized_inputs_names.append(input_tensor)
+                if self.get_tensor_data_type(input_tensor) \
+                        != mace_pb2.DT_FLOAT:
+                    continue
                 if input_tensor in self._consts:
                     const_tensor = self._consts[input_tensor]
                     quantized_tensor = quantize_util.quantize(
@@ -2005,7 +2016,7 @@ class Transformer(base_converter.ConverterInterface):
             orginal_output_name = op.output[0]
             op.output[0] = orginal_output_name + "_quant"
-            op.output_type.extend([mace_pb2.DT_INT32])
+            op.output_type.extend([to_quantize_ops_output_type[op.type]])
             data_type_arg = ConverterUtil.get_arg(op,
                                                   MaceKeyword.mace_op_data_type_str)  # noqa
             if data_type_arg is None:
@@ -2022,7 +2033,7 @@ class Transformer(base_converter.ConverterInterface):
             dequantize_op.output_type.extend([mace_pb2.DT_FLOAT])
             data_type_arg = dequantize_op.arg.add()
             data_type_arg.name = MaceKeyword.mace_op_data_type_str
-            data_type_arg.i = mace_pb2.DT_INT32
+            data_type_arg.i = to_quantize_ops_output_type[op.type]

             quantize_flag_arg = ConverterUtil.get_arg(self._model,
                                                       MaceKeyword.mace_quantize_flag_arg_str)  # noqa
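The converter half generalizes the old MatMul-only rule: a dict now decides both which op types get quantized and what output type each quantized op produces (int32 accumulators for MatMul, raw uint8 rows for Gather), and the should_quantize polarity flips so an op is quantized only while it still has at least one float input. A standalone sketch of that selection logic (stand-in enum values and op records, not the real mace_pb2/MaceOp objects):

# Stand-ins for mace_pb2 data-type enum values, for illustration only.
DT_FLOAT, DT_UINT8, DT_INT32 = 0, 1, 2

to_quantize_ops_output_type = {
    "MatMul": DT_INT32,   # quantized matmul accumulates into int32
    "Gather": DT_UINT8,   # quantized gather just moves uint8 rows
}

ops = [
    {"type": "MatMul", "outputs": 1, "input_types": [DT_FLOAT, DT_FLOAT]},
    {"type": "Gather", "outputs": 1, "input_types": [DT_UINT8]},   # already quantized
    {"type": "Softmax", "outputs": 1, "input_types": [DT_FLOAT]},  # not in the dict
]

for op in ops:
    # The op type must be listed, and only single-output ops are supported.
    if op["type"] not in to_quantize_ops_output_type or op["outputs"] > 1:
        continue
    # New polarity: quantize only if some input still carries float data.
    should_quantize = any(t == DT_FLOAT for t in op["input_types"])
    if not should_quantize:
        continue
    print("Quantize op %s -> output type %d"
          % (op["type"], to_quantize_ops_output_type[op["type"]]))

Run as written, only the MatMul line prints: Gather's input is already non-float, and Softmax is not in the dict.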