Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
205291a3
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
404
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
205291a3
编写于
5月 18, 2020
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(mge/quantization): add histogram observer
GitOrigin-RevId: a9252a6bafe19b3ac958acb7a617bbbb47dc1514
上级
7c4f1a38
变更
5
显示空白变更内容
内联
并排
Showing
5 changed files
with
288 additions
and
17 deletions
+288
-17
python_module/megengine/module/module.py
python_module/megengine/module/module.py
+13
-6
python_module/megengine/quantization/__init__.py
python_module/megengine/quantization/__init__.py
+7
-2
python_module/megengine/quantization/observer.py
python_module/megengine/quantization/observer.py
+259
-5
python_module/megengine/quantization/qconfig.py
python_module/megengine/quantization/qconfig.py
+9
-3
python_module/megengine/quantization/quantize.py
python_module/megengine/quantization/quantize.py
+0
-1
未找到文件。
python_module/megengine/module/module.py
浏览文件 @
205291a3
...
...
@@ -486,8 +486,16 @@ class QATModule(Module):
self
.
weight_observer
=
qconfig
.
weight_observer
()
self
.
act_observer
=
qconfig
.
act_observer
()
self
.
weight_fake_quant
=
qconfig
.
fake_quant
(
self
.
weight_observer
.
dtype
)
self
.
act_fake_quant
=
qconfig
.
fake_quant
(
self
.
act_observer
.
dtype
)
self
.
weight_fake_quant
=
(
None
if
qconfig
.
fake_quant
is
None
else
qconfig
.
fake_quant
(
self
.
weight_observer
.
dtype
)
)
self
.
act_fake_quant
=
(
None
if
qconfig
.
fake_quant
is
None
else
qconfig
.
fake_quant
(
self
.
act_observer
.
dtype
)
)
def apply_observer(self, target: Tensor, obs: "Observer"):
    """Feed ``target`` through the observer ``obs`` and return its output.

    :param target: tensor to be observed.
    :param obs: callable observer; it is invoked as ``obs(target)``.
    :return: whatever ``obs`` returns for ``target``.
    """
    observed = obs(target)
    return observed
...
...
@@ -496,11 +504,10 @@ class QATModule(Module):
self
,
target
:
Tensor
,
fq
:
"FakeQuantize"
,
obs
:
"Observer"
):
oup
=
self
.
apply_observer
(
target
,
obs
)
if
self
.
quantizing
==
self
.
QATMode
.
CALIBRATION
:
return
oup
else
:
if
fq
is
not
None
:
scale
,
zero_point
=
obs
.
get_qparams
()
return
fq
(
oup
,
scale
,
zero_point
)
oup
=
fq
(
oup
,
scale
,
zero_point
)
return
oup
def
set_qat_mode
(
self
,
mode
:
QATMode
):
r
"""
...
...
python_module/megengine/quantization/__init__.py
浏览文件 @
205291a3
...
...
@@ -6,8 +6,13 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
.fake_quant
import
FakeQuantize
from
.observer
import
Observer
from
.qconfig
import
QConfig
,
ema_fakequant_qconfig
,
min_max_fakequant_qconfig
from
.observer
import
HistogramObserver
,
Observer
from
.qconfig
import
(
QConfig
,
calibration_qconfig
,
ema_fakequant_qconfig
,
min_max_fakequant_qconfig
,
)
from
.quantize
import
(
disable_fake_quant
,
disable_observer
,
...
...
python_module/megengine/quantization/observer.py
浏览文件 @
205291a3
...
...
@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
math
from
abc
import
abstractmethod
import
numpy
as
np
...
...
@@ -12,6 +13,7 @@ import numpy as np
from
..
import
functional
as
F
from
.._internal.dtype
import
_metadata_dict
,
get_quantized_dtype
from
..core
import
Buffer
,
Function
,
tensor
from
..jit
import
sideeffect
from
..module
import
Module
...
...
@@ -94,9 +96,11 @@ class MinMaxObserver(Observer):
F
.
add_update
(
self
.
max_val
,
tmp_max
,
alpha
=
0.0
,
beta
=
1.0
,
bias
=
0.0
)
F
.
add_update
(
self
.
first_flag
,
self
.
not_flag
,
alpha
=
0.0
,
beta
=
1.0
,
bias
=
0.0
)
def
get_qparams
(
self
):
def
_calculate_qparams
(
self
,
inp_min_val
,
inp_max_val
):
min_val
=
F
.
minimum
(
0.0
,
inp_min_val
)
max_val
=
F
.
maximum
(
0.0
,
inp_max_val
)
if
self
.
symmetric
:
symmetric_max_vals
=
F
.
maximum
(
-
self
.
min_val
,
self
.
max_val
)
symmetric_max_vals
=
F
.
maximum
(
-
min_val
,
max_val
)
# use maximum to keep the scale from being too small at the beginning
scale
=
F
.
maximum
(
symmetric_max_vals
/
((
self
.
qmax
-
self
.
qmin
)
/
2
),
self
.
scale_limit
...
...
@@ -105,14 +109,16 @@ class MinMaxObserver(Observer):
else
:
# use maximum to keep the scale from being too small at the beginning
scale
=
F
.
maximum
(
(
self
.
max_val
-
self
.
min_val
)
/
(
self
.
qmax
-
self
.
qmin
),
self
.
scale_limit
,
(
max_val
-
min_val
)
/
(
self
.
qmax
-
self
.
qmin
),
self
.
scale_limit
,
)
# calculate zero_point
zero_point
=
self
.
qmin
-
Round
()((
self
.
min_val
/
scale
))
zero_point
=
self
.
qmin
-
Round
()((
min_val
/
scale
))
return
scale
,
zero_point
def get_qparams(self):
    """Return the quantization parameters for the currently recorded range.

    Delegates to ``self._calculate_qparams`` with the stored ``min_val`` and
    ``max_val`` and returns its result unchanged.
    """
    compute = self._calculate_qparams
    lower = self.min_val
    upper = self.max_val
    return compute(lower, upper)
def
forward
(
self
,
x_orig
):
if
self
.
enabled
:
# stop gradient
...
...
@@ -161,3 +167,251 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
)
self
.
set_min_max
(
tmp_min
,
tmp_max
)
return
x_orig
class HistogramObserver(MinMaxObserver):
    """Observer that accumulates a running histogram of the values it sees.

    Every forward call merges the histogram of the new input into the stored
    one (re-binning when the covered range grows).  ``get_qparams`` then
    searches for a clipped ``[min, max]`` range that approximately minimises
    the L2 quantization error and derives scale / zero point from it.

    :param bins: number of bins of the stored histogram.
    :param upsample_rate: up-sampling factor used when two histograms with
        different ranges have to be merged.
    :param dtype: name of the target quantized dtype; its ``qmin`` / ``qmax``
        (looked up in ``_metadata_dict``) define the number of destination bins.
    """

    def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs):
        super().__init__(dtype=dtype, *args, **kwargs)
        self.bins = bins
        self.upsample_rate = upsample_rate
        # Number of representable quantized values of the target dtype.
        self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
        # ``None`` until the first batch has been observed.
        self.histogram = None

    def _non_linear_param_search(self):
        r"""Non-linear parameter search.

        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in input distribution.

        :return: ``(new_min, new_max)`` computed from ``self.min_val`` and the
            selected start / end bins.
        """

        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """
            norm = 0.0
            dst_bin_width = (
                np_bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            )
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of first dst_bin to the beginning
                # and end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * np_bin_width
                src_bin_end = src_bin_begin + np_bin_width

                # which dst_bins the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_begin / dst_bin_width)),
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_end / dst_bin_width)),
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )

                density = self.histogram[src_bin] / np_bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    # part of src_bin that falls into the first dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)

                    # dst_bins that are completely covered by src_bin
                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )

                    # part of src_bin that falls into the last dst_bin
                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )
                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        assert len(self.histogram) == self.bins, "bins mistmatch"
        # Width of one source bin, expressed with the stored min/max objects;
        # used only to build the returned range.
        bin_width = (self.max_val - self.min_val) / self.bins

        # The same width as a plain numpy scalar for the error evaluation.
        # min_val / max_val are not modified during the search, so read them
        # once instead of on every error evaluation.
        np_min_val = self.min_val.numpy()[0]
        np_max_val = self.max_val.numpy()[0]
        np_bin_width = (np_max_val - np_min_val) / self.bins

        # cumulative sum
        total = sum(self.histogram)
        cSum = np.cumsum(self.histogram, axis=0)

        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")

        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize

            # find the left and right bins between the quantile bounds
            left = start_bin
            right = end_bin
            while left < end_bin and cSum[left] < next_alpha * total:
                left = left + 1
            while right > start_bin and cSum[right] > next_beta * total:
                right = right - 1

            # decide the next move
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (left - start_bin) > (end_bin - right):
                # move the start bin
                next_start_bin = left
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = right
                beta = next_beta

            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue

            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")

            # stop as soon as shrinking the range makes the error worse
            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin

        new_min = self.min_val + bin_width * start_bin
        new_max = self.min_val + bin_width * (end_bin + 1)
        return new_min, new_max

    def get_qparams(self):
        """Return the result of ``_calculate_qparams`` for the searched range."""
        new_min, new_max = self._non_linear_param_search()
        return self._calculate_qparams(new_min, new_max)

    def _combine_histograms(
        self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
    ):
        """Re-bin ``new_hist`` onto the grid of ``orig_hist`` and add the two.

        :param orig_hist: histogram already laid out on the output grid.
        :param new_hist: histogram to be re-binned (``Nbins`` entries).
        :param upsample_rate: repeat factor applied to ``new_hist``.
        :param downsample_rate: number of fine-grid cells per output bin.
        :param start_idx: offset of ``new_hist`` inside the fine output grid.
        :param Nbins: number of bins of both histograms.
        :return: ``orig_hist`` plus the re-binned ``new_hist``.
        """
        # First up-sample the histogram with new data by a factor of L
        # This creates an approximate probability density that's piecewise constant
        upsampled_histogram = new_hist.repeat(upsample_rate)
        # Now insert the upsampled histogram into the output
        # histogram, which is initialized with zeros.
        # The offset at which the histogram is introduced is determined
        # by the start index as the output histogram can cover a wider range
        histogram_with_output_range = np.zeros((Nbins * downsample_rate))
        histogram_with_output_range[
            start_idx : Nbins * upsample_rate + start_idx
        ] = upsampled_histogram
        # Compute integral histogram, double precision is needed to ensure
        # that there are no overflows
        integral_histogram = np.cumsum(histogram_with_output_range, 0)[
            downsample_rate - 1 :: downsample_rate
        ]
        # Finally perform interpolation
        shifted_integral_histogram = np.zeros((Nbins))
        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
        interpolated_histogram = (
            integral_histogram - shifted_integral_histogram
        ) / upsample_rate
        orig_hist = orig_hist + interpolated_histogram
        return orig_hist

    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
        """Widen ``[combined_min, combined_max]`` so both histograms share a grid.

        :return: ``(combined_min, combined_max, downsample_rate, start_idx)``.
        """
        # We ensure that:
        # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
        # This allows us to have a common grid of resolution s, where we can align
        # the input histogram
        # start_idx maps min_val to the histogram bin index.
        np_min_val = self.min_val.numpy()[0]
        np_max_val = self.max_val.numpy()[0]

        hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
        downsample_rate = int(
            np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
        )
        # Excess range introduced by rounding downsample_rate up; split evenly.
        e = downsample_rate * (self.bins * hist_bin_width) - (
            combined_max - combined_min
        )
        combined_max = combined_max + e / 2
        combined_min = combined_min - e / 2
        start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
        return combined_min, combined_max, downsample_rate, start_idx

    @sideeffect
    def sideeffect_forward(self, x_orig):
        """Merge the value histogram of ``x_orig`` into the stored statistics.

        Updates ``self.histogram``, ``self.min_val`` and ``self.max_val``.
        """
        x = x_orig.numpy()
        min_val = self.min_val.numpy()[0]
        max_val = self.max_val.numpy()[0]
        # (Re)initialise when nothing has been accumulated yet, or when the
        # stored range is degenerate (zero width), which would make the shared
        # grid in ``_adjust_min_max`` undefined.  Testing ``min_val == 0 or
        # max_val == 0`` instead would throw away the accumulated histogram on
        # every call for inputs whose true minimum is 0.
        if self.histogram is None or min_val == max_val:
            min_val = x.min()
            max_val = x.max()
            self.min_val.set_value(min_val)
            self.max_val.set_value(max_val)
            self.histogram, _ = np.histogram(x, self.bins, (min_val, max_val))
            self.histogram = self.histogram.astype(np.float64)
        else:
            new_min = x.min()
            new_max = x.max()
            combined_min = min(new_min, min_val)
            combined_max = max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            (
                combined_min,
                combined_max,
                downsample_rate,
                start_idx,
            ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
            combined_histogram, _ = np.histogram(
                x, self.bins, (combined_min, combined_max)
            )
            combined_histogram = combined_histogram.astype(np.float64)
            if combined_min == min_val and combined_max == max_val:
                # Same grid: bins line up one-to-one.
                combined_histogram += self.histogram
            else:
                combined_histogram = self._combine_histograms(
                    combined_histogram,
                    self.histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )
            self.histogram = combined_histogram
            self.min_val.set_value(combined_min)
            self.max_val.set_value(combined_max)

    def forward(self, x_orig):
        """Record statistics of ``x_orig`` and return it unchanged."""
        # NOTE(review): unlike MinMaxObserver.forward this does not consult
        # ``self.enabled`` - confirm whether that is intentional.
        self.sideeffect_forward(x_orig)
        return x_orig
python_module/megengine/quantization/qconfig.py
浏览文件 @
205291a3
...
...
@@ -5,11 +5,13 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from
functools
import
partial
from
..module
import
Module
from
.fake_quant
import
FakeQuantize
from
.observer
import
ExponentialMovingAverageObserver
,
MinMaxObserver
from
.observer
import
(
ExponentialMovingAverageObserver
,
HistogramObserver
,
MinMaxObserver
,
)
class
QConfig
:
...
...
@@ -66,3 +68,7 @@ ema_fakequant_qconfig = QConfig(
act_observer
=
ExponentialMovingAverageObserver
,
fake_quant
=
FakeQuantize
,
)
# Calibration preset: weights are observed with MinMaxObserver, activations
# with HistogramObserver, and no fake-quantize class is configured
# (fake_quant=None).
calibration_qconfig
=
QConfig
(
weight_observer
=
MinMaxObserver
,
act_observer
=
HistogramObserver
,
fake_quant
=
None
,
)
python_module/megengine/quantization/quantize.py
浏览文件 @
205291a3
...
...
@@ -71,7 +71,6 @@ def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfi
mod
.
set_qconfig
(
qconfig
)
module
.
apply
(
fn
)
enable_observer
(
module
)
def
disable_fake_quant
(
module
:
Module
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录