PaddlePaddle / Paddle

Commit 19592d2b: Refine dygraph qat, test=develop (#31680)
Authored by cc on Mar 17, 2021; committed via GitHub on Mar 17, 2021.
Parent commit: 4c0c55bb
Showing 2 changed files with 303 additions and 226 deletions (+303, -226):
    python/paddle/fluid/contrib/slim/quantization/imperative/qat.py      +257 -226
    python/paddle/fluid/contrib/slim/quantization/imperative/utils.py     +46  -0
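For orientation before the diff: a minimal usage sketch of the public API this commit refactors, following the constructor defaults and the save_quantized_model(layer, path, input_spec, **config) signature shown below. The resnet50 model, the input shape and the output path are illustrative assumptions, not part of the change.

    import paddle
    from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
    from paddle.vision.models import resnet50  # stand-in model, any dygraph Layer works

    # Insert fake quant/dequant logic for the quantizable sublayers
    # ('Conv2D' and 'Linear' by default) and register output-scale hooks.
    model = resnet50()
    imperative_qat = ImperativeQuantAware(
        weight_quantize_type='abs_max',
        activation_quantize_type='moving_average_abs_max')
    imperative_qat.quantize(model)

    # ... run the usual dygraph training loop on `model` here ...

    # Save the quantized inference model together with the collected out_scale values.
    imperative_qat.save_quantized_model(
        layer=model,
        path="./imperative_model_qat",
        input_spec=[
            paddle.static.InputSpec(shape=[None, 3, 224, 224], dtype='float32')
        ])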
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py

@@ -25,101 +25,99 @@ from paddle.fluid.executor import Executor
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.initializer import Constant
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
-from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D, BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm
+from paddle.nn import Linear, Conv2D, Conv2DTranspose, MaxPool2D, MaxPool1D
+from paddle.nn import BatchNorm1D, BatchNorm2D, BatchNorm3D, SyncBatchNorm
 from paddle.fluid.dygraph.nn import BatchNorm, Pool2D
 from paddle.fluid.io import load_inference_model, save_inference_model
-from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6, Tanh, Softmax, PReLU, Swish
+from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6
+from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish
 from paddle.fluid.log_helper import get_logger
 from . import quant_nn
 from .. import quantization_pass
+from . import utils

-__all__ = ['ImperativeQuantAware', 'ImperativeCalcOutScale']
+__all__ = ['ImperativeQuantAware']

 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
-_op_real_in_out_name = {
-    "conv2d": [["Input", "Filter"], ["Output"]],
-    "depthwise_conv2d": [["Input", "Filter"], ["Output"]],
-    "pool2d": [["X"], ["Out"]],
-    "elementwise_add": [["X", "Y"], ["Out"]],
-    "softmax": [["X"], ["Out"]],
-    "relu": [["X"], ["Out"]],
-    "relu6": [["X"], ["Out"]],
-    "leaky_relu": [["X"], ["Out"]],
-    "prelu": [["X"], ["Out"]],
-    "tanh": [["X"], ["Out"]],
-    "batch_norm": [["X"], ["Y"]],
-    "sigmoid": [["X"], ["Out"]],
-    "swish": [["X"], ["Out"]],
-}
 class ImperativeQuantAware(object):
     """
-    Add the fake quant logic for given quantizable layers, namely add the quant_dequant
-    computational logic both for activation inputs and weight inputs.
+    Applying quantization aware training (QAT) to the dygraph model.
     """
     def __init__(self,
-                 weight_bits=8,
-                 activation_bits=8,
+                 quantizable_layer_type=['Conv2D', 'Linear'],
                  weight_quantize_type='abs_max',
                  activation_quantize_type='moving_average_abs_max',
+                 weight_bits=8,
+                 activation_bits=8,
                  moving_rate=0.9,
-                 quantizable_layer_type=['Conv2D', 'Linear'],
                  weight_preprocess_layer=None,
                  act_preprocess_layer=None,
                  weight_quantize_layer=None,
                  act_quantize_layer=None):
        r"""
        The constructor for ImperativeQuantAware.

        Args:
            quantizable_layer_type(list[str]): List the type of layers that
                will be quantized. Default is ['Conv2D', 'Linear'].
                The quantizable_op_type in QuantizationFreezePass and
                ConvertToInt8Pass must be the same as this.
            weight_quantize_type(str): quantization type for weights,
                which supports 'abs_max' now. The 'moving_average_abs_max'
                usually is not used for weights, since weights are fixed
                once the model is well trained.
            activation_quantize_type(str): quantization type for activations,
                which supports 'abs_max' and 'moving_average_abs_max' now.
                If using 'abs_max' mode, the quantization scale will be
                calculated dynamically each step in both training and testing
                period. If using 'moving_average_abs_max', the static
                quantization scale will be calculated during training and
                used in inference.
            weight_bits(int): quantization bit number for weights,
                whereas the bias is not quantized.
            activation_bits(int): quantization bit number for activations.
            moving_rate(float): the parameter for 'moving_average_abs_max'
                quantization.
            weight_preprocess_layer(paddle.nn.Layer, optional): A paddle
                Layer that defines how to preprocess weight before quantization.
                Using this can quickly test if user's preprocess method works
                or not. The input is non-quantized weight and the function
                returns processed weight to be quantized.
                If None, the weight will be quantized directly.
                Default is None.
            act_preprocess_layer(paddle.nn.Layer, optional): A paddle Layer
                that defines how to preprocess activation before quantization.
                Using this can quickly test if user's preprocess method works
                or not. The input is non-quantized activation and the function
                returns processed activation to be quantized.
                If None, the activation will be quantized directly.
                Default is None.
            weight_quantize_layer(paddle.nn.Layer, optional): A paddle Layer
                that defines how to quantize weight.
                Using this can quickly test if user's quantization method works or not.
                In this layer, user should both define quantization method and
                dequantization method, that is, the function's input is non-quantized
                weight and returns dequantized weight.
                If None, will use the quantization op defined by 'weight_quantize_type'.
                Default is None.
            act_quantize_layer(paddle.nn.Layer, optional): A paddle Layer that defines
                how to quantize activation.
                Using this can quickly test if user's quantization method works or not.
                In this layer, user should both define quantization method and
                dequantization method, that is, the function's input is non-quantized
                activation and returns dequantized activation.
                If None, will use the quantization op defined by 'activation_quantize_type'.
                Default is None.

        Note:
            If user sets attribute 'skip_quant' to a Layer that supports dynamic
            quantization and sets it to true, the layer would not be quantized
            during training. If this attribute is not set or the attribute is
            false, the Layer would be quantized in training.

        Examples 1:
        .. code-block:: python
...
@@ -196,141 +194,175 @@ class ImperativeQuantAware(object):
                model_path="./imperative_model_qat")
        """
         super(ImperativeQuantAware, self).__init__()
-        self._weight_bits = weight_bits
-        self._activation_bits = activation_bits
-        self._moving_rate = moving_rate
-        self._activation_quantize_type = activation_quantize_type
-        self._weight_quantize_type = weight_quantize_type
-        self._weight_pre_layer = weight_preprocess_layer
-        self._act_pre_layer = act_preprocess_layer
-        self._weight_quant_layer = weight_quantize_layer
-        self._act_quant_layer = act_quantize_layer
-        self._out_scale = ImperativeCalcOutScale()
-
-        t_check = lambda method: method is None or issubclass(method, dygraph.layers.Layer)
-        assert t_check(self._weight_pre_layer), "weight_preprocess should be nn.Layer"
-        assert t_check(self._act_pre_layer), "act_preprocess should be nn.Layer"
-        assert t_check(self._weight_quant_layer), "weight_quantize should be nn.Layer"
-        assert t_check(self._act_quant_layer), "act_quantize should be nn.Layer"
-
-        quant_type = {'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max'}
-        assert activation_quantize_type != 'channel_wise_abs_max', \
-            "The activation quantization type does not support 'channel_wise_abs_max'."
-        if activation_quantize_type not in quant_type:
-            raise ValueError(
-                "Unknown activation_quantize_type : '%s'. It can only be "
-                "'abs_max' or 'moving_average_abs_max' now." %
-                (str(activation_quantize_type)))
-        if weight_quantize_type not in quant_type:
-            raise ValueError(
-                "Unknown weight_quantize_type: '%s'. It can only be "
-                "'abs_max' or 'moving_average_abs_max' or 'channel_wise_abs_max' now."
-                % (str(weight_quantize_type)))
-
-        self._quant_layers_map = {
-            'Conv2D': Conv2D,
-            'Linear': Linear,
-            'Pool2D': Pool2D,
-            'ReLU': ReLU,
-            'LeakyReLU': LeakyReLU,
-            'ReLU6': ReLU6,
-            'Softmax': Softmax,
-            'Tanh': Tanh,
-            'Swish': Swish
-        }
-        self._quantizable_layer_type = tuple(
-            self._quant_layers_map[layer]
-            if layer in self._quant_layers_map else layer
-            for layer in quantizable_layer_type)
-        for layer in self._quantizable_layer_type:
-            assert not isinstance(layer, str), \
-                "{} is unspported to be quantized.".format(layer)
+        kwargs = {
+            "quantizable_layer_type": quantizable_layer_type,
+            "weight_quantize_type": weight_quantize_type,
+            "activation_quantize_type": activation_quantize_type,
+            "weight_bits": weight_bits,
+            "activation_bits": activation_bits,
+            "moving_rate": moving_rate,
+            "weight_preprocess_layer": weight_preprocess_layer,
+            "act_preprocess_layer": act_preprocess_layer,
+            "weight_quantize_layer": weight_quantize_layer,
+            "act_quantize_layer": act_quantize_layer
+        }
+
+        self._quantize_inputs = ImperativeQuantizeInputs(**kwargs)
+
+        self._calc_output_scale = ImperativeCalcOutputScale()
     def quantize(self, model):
         """
-        According to weights' and activations' quantization types, the model will be added some fake
-        quant ops, such as fake_quantize_dequantize_moving_average_abs_max, fake_quantize_dequantize_abs_max
-        and so on. At the same time, the out_scale value of outputs would be calculated.
+        According to weights' and activations' quantization types,
+        the model will be added some fake quant ops, such as
+        fake_quantize_dequantize_moving_average_abs_max,
+        fake_quantize_dequantize_abs_max and so on. At the same time,
+        the out_scale value of outputs would be calculated.

         Args:
             model(fluid.dygraph.Layer): the model to be quantized.
         Returns:
             None
         """
+        assert isinstance(model, dygraph.Layer), \
+            "The model must be the instance of dygraph.Layer."
+
+        self._quantize_inputs.apply(model)
+        self._calc_output_scale.apply(model)
+    def save_quantized_model(self, layer, path, input_spec=None, **config):
+        self._calc_output_scale.save_quantized_model(layer, path, input_spec,
+                                                     **config)
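save_quantized_model ultimately writes a static inference program plus parameters (see the INFER_MODEL_SUFFIX / INFER_PARAMS_SUFFIX imports above). A hedged sketch of loading such a saved model back for inference with paddle.jit.load; the path matches the docstring example and the input shape is an assumption:

    import paddle

    # Load the program and parameters saved by save_quantized_model.
    quantized = paddle.jit.load("./imperative_model_qat")
    quantized.eval()
    out = quantized(paddle.rand([1, 3, 224, 224]))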
+class ImperativeQuantizeInputs(object):
+    """
+    Based on the input params, add the quant_dequant computational
+    logic both for activation inputs and weight inputs.
+    """
+
+    def __init__(self,
+                 quantizable_layer_type=['Conv2D', 'Linear'],
+                 weight_quantize_type='abs_max',
+                 activation_quantize_type='moving_average_abs_max',
+                 weight_bits=8,
+                 activation_bits=8,
+                 moving_rate=0.9,
+                 weight_preprocess_layer=None,
+                 act_preprocess_layer=None,
+                 weight_quantize_layer=None,
+                 act_quantize_layer=None):
+        """
+        The constructor for ImperativeQuantizeInputs.
+        Please refer to the args of ImperativeQuantAware.
+        """
+        super(ImperativeQuantizeInputs, self).__init__()
+
+        self._quantizable_layer_type = tuple(
+            utils._quant_layers_map[layer]
+            if layer in utils._quant_layers_map else layer
+            for layer in quantizable_layer_type)
+        for layer in self._quantizable_layer_type:
+            assert not isinstance(layer, str), \
+                "%s is unsupported to be quantized." % layer
+
+        quantize_type = {
+            'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max'
+        }
+        assert weight_quantize_type in quantize_type, \
+            "Unsupported weight_quantize_type: %s. It can only " \
+            "be abs_max or moving_average_abs_max or " \
+            "channel_wise_abs_max." % weight_quantize_type
+        assert activation_quantize_type != 'channel_wise_abs_max' \
+            and activation_quantize_type in quantize_type, \
+            "Unsupported activation_quantize_type: %s. It can " \
+            "only be abs_max or moving_average_abs_max now." \
+            % activation_quantize_type
+
+        bits_check = lambda bits: isinstance(bits, int) \
+            and bits >= 0 and bits <= 16
+        assert bits_check(weight_bits), \
+            "weight_bits should be 1, 2,... or 16."
+        assert bits_check(activation_bits), \
+            "activation_bits should be 1, 2,... or 16."
+
+        layer_check = lambda method: method is None or \
+            issubclass(method, dygraph.layers.Layer)
+        assert layer_check(weight_preprocess_layer), \
+            "weight_preprocess should be nn.Layer."
+        assert layer_check(act_preprocess_layer), \
+            "act_preprocess should be nn.Layer."
+        assert layer_check(weight_quantize_layer), \
+            "weight_quantize should be nn.Layer."
+        assert layer_check(act_quantize_layer), \
+            "act_quantize should be nn.Layer."
+
+        self._kwargs = {
+            "weight_quantize_type": weight_quantize_type,
+            "activation_quantize_type": activation_quantize_type,
+            "weight_bits": weight_bits,
+            "activation_bits": activation_bits,
+            "moving_rate": moving_rate,
+            "weight_pre_layer": weight_preprocess_layer,
+            "act_pre_layer": act_preprocess_layer,
+            "weight_quant_layer": weight_quantize_layer,
+            "act_quant_layer": act_quantize_layer
+        }
+
+    def apply(self, model):
         assert isinstance(model, dygraph.Layer), \
             "The model must be the instance of dygraph.Layer."
         for name, layer in model.named_sublayers():
-            if not isinstance(layer, self._quantizable_layer_type):
-                continue
-            if hasattr(layer, "skip_quant") \
-                    and layer.skip_quant == True:
+            if not isinstance(layer, self._quantizable_layer_type) \
+                or (hasattr(layer, "skip_quant") \
+                    and layer.skip_quant == True):
                 continue

+            # TODO(jc): optimize this module
             last_idx = 0
             idx = 0
             obj = model
-            parent = model
             while idx < len(name):
                 if (name[idx] == '.'):
-                    if hasattr(parent, name[last_idx:idx]):
+                    if hasattr(obj, name[last_idx:idx]):
                         obj = getattr(obj, name[last_idx:idx])
-                        parent = obj
                         last_idx = idx + 1
                 idx += 1
             target = name[last_idx:idx]

-            quant_layer = self._get_quantized_counterpart(layer)
+            quant_layer = self._get_quantized_layer(layer)
             setattr(quant_layer, "layer_name", layer.full_name())
             setattr(obj, target, quant_layer)

-        self._out_scale.calc_out_scale(model)
-
-    def _get_quantized_counterpart(self, layer):
-        quant_layers = tuple(self._quant_layers_map.values())
-        quantized_counterpart = tuple('Quantized' + k
-                                      for k in self._quant_layers_map.keys())
-
-        predicate = lambda value: isinstance(layer, value)
-        index_generator = (i for i, v in enumerate(quant_layers)
-                           if predicate(v))
-        try:
-            index = next(index_generator)
-        except StopIteration:
-            _logger.fatal("The layer {} is unsupported to be quantized.".format(
-                layer.full_name()))
-            sys.exit(-1)
-
-        layer_with_weight = ['QuantizedConv2D', 'QuantizedLinear']
-        if quantized_counterpart[index] not in layer_with_weight:
-            quant_layer_class_name = 'QuantizedNoweightLayer'
-        else:
-            quant_layer_class_name = quantized_counterpart[index]
-        quantized_layer = quant_nn.__dict__[quant_layer_class_name](
-            layer, self._weight_bits, self._activation_bits, self._moving_rate,
-            self._weight_quantize_type, self._activation_quantize_type,
-            self._weight_pre_layer, self._act_pre_layer,
-            self._weight_quant_layer, self._act_quant_layer)
-        return quantized_layer
-
-    def save_quantized_model(self, layer, path, input_spec=None, **config):
-        self._out_scale.save_quantized_model(layer, path, input_spec, **config)
+    def _get_quantized_layer(self, layer):
+        quant_layer_name = None
+        for key, value in utils._quant_layers_map.items():
+            if isinstance(layer, value):
+                quant_layer_name = 'Quantized' + key
+                break
+        assert quant_layer_name is not None, \
+            "The layer %s is unsupported to be quantized." \
+            % layer.full_name()
+
+        layer_with_weight = ['QuantizedConv2D', 'QuantizedLinear']
+        if quant_layer_name not in layer_with_weight:
+            quant_layer_name = 'QuantizedNoweightLayer'
+
+        return quant_nn.__dict__[quant_layer_name](layer, **self._kwargs)
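The attribute walk in ImperativeQuantizeInputs.apply() above is the core replacement mechanism: named_sublayers() yields dotted paths such as 'block.conv1', the loop resolves the parent object, and setattr swaps the float sublayer for its quantized wrapper. A minimal pure-Python sketch of that walk; Net, Block and swap_sublayer are illustrative names, not Paddle APIs:

    class Block(object):
        def __init__(self):
            self.conv1 = "float_conv"

    class Net(object):
        def __init__(self):
            self.block = Block()

    def swap_sublayer(model, name, new_layer):
        # Mirror the last_idx/idx walk above: descend to the parent of the
        # dotted path, then rebind the final attribute.
        obj, last_idx, idx = model, 0, 0
        while idx < len(name):
            if name[idx] == '.':
                if hasattr(obj, name[last_idx:idx]):
                    obj = getattr(obj, name[last_idx:idx])
                    last_idx = idx + 1
            idx += 1
        setattr(obj, name[last_idx:idx], new_layer)

    net = Net()
    swap_sublayer(net, "block.conv1", "quantized_conv")
    print(net.block.conv1)  # quantized_conv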
-class ImperativeCalcOutScale(object):
+class ImperativeCalcOutputScale(object):
     def __init__(self, moving_rate=0.9):
         """
-        Add the logic of calculating and setting output quantization scales of some layers.
-        These output quantization scales may be used by tensorRT or some other inference engines.
+        Add the logic of calculating and setting output scales of some layers.

         Args:
-            moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
+            moving_rate(float): The decay coefficient of moving average.
+                The default value is 0.9.
         """
-        super(ImperativeCalcOutScale, self).__init__()
+        super(ImperativeCalcOutputScale, self).__init__()
         self._moving_rate = moving_rate
         self._out_scale_layer_type_list = (
             BatchNorm, BatchNorm1D, BatchNorm2D, BatchNorm3D, Conv2D, LeakyReLU,
...
@@ -339,83 +371,22 @@ class ImperativeCalcOutScale(object):
         self._register_hook_handle_list = []
         self._out_scale_dict = collections.OrderedDict()
-    # Determine whether layer supports calculation out_scale
-    def _is_matched_layer(self, layer):
-        if not isinstance(layer, self._out_scale_layer_type_list):
-            if 'quantized_' not in layer.full_name():
-                return False
-        return True
-
-    def _add_new_parameters(self, layer, name=None):
-        ...
-
-    def _is_op_matched(self, layer_name, op, block):
-        ...

(The `_add_new_parameters` and `_is_op_matched` bodies removed here are identical to the
ones re-added after `_is_target_layer` in the @@ -567,6 +538,66 @@ hunk below.)

-    def calc_out_scale(self, model):
+    def apply(self, model):
         """
-        Insert the `moving_average_abs_max_scale` op to calculate output scale of Specific layers in model.
+        Insert the `moving_average_abs_max_scale` op to calculate output
+        scale of specific layers in model.

         Args:
-            model(fluid.dygraph.Layer): The target model which would be calculate the output quantization scale.
+            model(fluid.dygraph.Layer): The target model which would be
+                calculate the output quantization scale.

         Returns:
             None
         """
         assert isinstance(model, dygraph.Layer), \
-            "model must be the instance of dygraph.Layer"
+            "The model must be the instance of dygraph.Layer."

         for _, layer in model.named_sublayers():
-            if self._is_matched_layer(layer):
+            if self._is_target_layer(layer):
                 self._add_new_parameters(layer)
                 forward_post_hook_handle = layer.register_forward_post_hook(
                     self._forward_post_hook)
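ImperativeCalcOutputScale.apply() records output scales by attaching forward post-hooks to the matched layers. A small hedged sketch of the same hook mechanism on a bare paddle.nn.Linear; the logging hook is illustrative and only prints the max absolute output value instead of maintaining a moving-average scale:

    import paddle

    def log_output_scale(layer, input, output):
        # A stand-in for _forward_post_hook: inspect the output after each forward pass.
        print(layer.full_name(), float(paddle.max(paddle.abs(output))))

    fc = paddle.nn.Linear(4, 2)
    handle = fc.register_forward_post_hook(log_output_scale)
    fc(paddle.rand([1, 4]))
    handle.remove()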
...
@@ -459,7 +430,7 @@ class ImperativeCalcOutScale(object):
                         .numpy())
         else:
             for _, sub_layer in self._layer.named_sublayers():
-                if self._is_matched_layer(sub_layer):
+                if self._is_target_layer(sub_layer):
                     layer_name = sub_layer.full_name()
                     if hasattr(sub_layer, "layer_name"):
                         layer_name = sub_layer.layer_name
...
@@ -510,7 +481,7 @@ class ImperativeCalcOutScale(object):
         forward_op = None
         for block in inference_program.blocks:
             for op in block.ops:
-                if op.type in _op_real_in_out_name:
+                if op.type in utils._op_real_in_out_name:
                     if op_count > len(ops_list):
                         warnings.warn(
                             "The number of Layer which has out_threshold attribute should be bigger than the op in inference model"
...
@@ -567,6 +538,66 @@ class ImperativeCalcOutScale(object):
         if is_dynamic_mode:
             paddle.disable_static()

+    def _is_target_layer(self, layer):
+        return isinstance(layer, self._out_scale_layer_type_list) \
+            or 'quantized_' in layer.full_name()
+
+    # When the inference model is saved, the logic in the hook would not be
+    # executed in program translation, so some parameters could not be created
+    # in __init__, which would cause the model to fail to save. Therefore, the
+    # parameter creation in the hook is moved outside the hook.
+    def _add_new_parameters(self, layer, name=None):
+        dtype = layer._dtype if layer._dtype is not None else "float32"
+        if dtype not in ["float32", "float64"]:
+            return
+
+        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
+        scale_name = unique_name.generate(scale_prefix)
+        scale_attr = ParamAttr(
+            name=scale_name, initializer=Constant(1), trainable=False)
+        layer._quant_out_scale = layer.create_parameter(
+            shape=[1], attr=scale_attr, dtype=dtype)
+        layer._quant_out_scale.stop_gradient = True
+
+        state_prefix = "{}.state".format(name) if name else 'outscale.state'
+        state_attr = ParamAttr(
+            name=unique_name.generate(state_prefix),
+            initializer=Constant(1),
+            trainable=False)
+        layer._quant_out_state = layer.create_parameter(
+            shape=[1], attr=state_attr, dtype=dtype)
+        layer._quant_out_state.stop_gradient = True
+
+        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
+        accum_attr = ParamAttr(
+            name=unique_name.generate(accum_prefix),
+            initializer=Constant(1),
+            trainable=False)
+        layer._quant_out_accum = layer.create_parameter(
+            shape=[1], attr=accum_attr, dtype=dtype)
+        layer._quant_out_accum.stop_gradient = True
+
+    # Judge whether the op in the program matches the Layer in the dynamic model
+    def _is_op_matched(self, layer_name, op, block):
+        output_var_names = quantization_pass._get_op_output_var_names(op)
+        for output_var_name in output_var_names:
+            output_var_tensor = block.var(output_var_name)
+            if output_var_tensor.dtype not in [
+                    core.VarDesc.VarType.FP64, core.VarDesc.VarType.FP32
+            ]:
+                return False
+
+        # Because the naming styles of static and dynamic graph are different,
+        # in order to avoid mistakes, we unify the name here.
+        op_type = output_var_names[0].split(".")[0]
+        op_type = op_type.rsplit("_", 1)[0]
+        if op_type == 'depthwise_conv2d':
+            op_type = 'conv2d'
+        if 'prelu' in op_type:
+            op_type = op_type.replace('prelu', 'p_re_lu')
+        if 'relu' in op_type:
+            op_type = op_type.replace('relu', 're_lu')
+        return op_type in layer_name
+
     def _forward_post_hook(self, layer, input, output):
         assert isinstance(
             output, (core.VarBase, framework.Variable)
...
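The matching in _is_op_matched hinges on normalizing static-graph op output names back to dygraph layer names. A plain-Python illustration of that normalization; the variable names "relu_3.tmp_0" and "depthwise_conv2d_1.tmp_0" are assumed examples of the static naming style, not values taken from the diff:

    def normalize_op_type(output_var_name):
        op_type = output_var_name.split(".")[0]  # "relu_3.tmp_0" -> "relu_3"
        op_type = op_type.rsplit("_", 1)[0]      # "relu_3"       -> "relu"
        if op_type == 'depthwise_conv2d':
            op_type = 'conv2d'
        if 'prelu' in op_type:
            op_type = op_type.replace('prelu', 'p_re_lu')
        if 'relu' in op_type:
            op_type = op_type.replace('relu', 're_lu')
        return op_type

    print(normalize_op_type("relu_3.tmp_0"))              # re_lu
    print(normalize_op_type("depthwise_conv2d_1.tmp_0"))  # conv2d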
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py (new file, 0 → 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.nn import Linear, Conv2D
from paddle.fluid.dygraph.nn import Pool2D
from paddle.nn.layer.activation import ReLU, LeakyReLU, Sigmoid, ReLU6
from paddle.nn.layer.activation import Tanh, Softmax, PReLU, Swish
_op_real_in_out_name = {
    "conv2d": [["Input", "Filter"], ["Output"]],
    "depthwise_conv2d": [["Input", "Filter"], ["Output"]],
    "pool2d": [["X"], ["Out"]],
    "elementwise_add": [["X", "Y"], ["Out"]],
    "softmax": [["X"], ["Out"]],
    "relu": [["X"], ["Out"]],
    "relu6": [["X"], ["Out"]],
    "leaky_relu": [["X"], ["Out"]],
    "prelu": [["X"], ["Out"]],
    "tanh": [["X"], ["Out"]],
    "batch_norm": [["X"], ["Y"]],
    "sigmoid": [["X"], ["Out"]],
    "swish": [["X"], ["Out"]],
}
_quant_layers_map = {
    'Conv2D': Conv2D,
    'Linear': Linear,
    'Pool2D': Pool2D,
    'ReLU': ReLU,
    'LeakyReLU': LeakyReLU,
    'ReLU6': ReLU6,
    'Softmax': Softmax,
    'Tanh': Tanh,
    'Swish': Swish
}
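_quant_layers_map is what lets callers pass either layer-type names or layer classes in quantizable_layer_type: names are looked up in the map, anything else is kept as-is. A toy, paddle-free sketch of that normalization; the Conv2D and Linear stubs below are placeholders for the real paddle.nn classes:

    class Conv2D(object):
        pass

    class Linear(object):
        pass

    _quant_layers_map = {'Conv2D': Conv2D, 'Linear': Linear}

    # Mixed input: a type name and a class, as ImperativeQuantizeInputs allows.
    quantizable_layer_type = ['Conv2D', Linear]
    quantizable = tuple(
        _quant_layers_map[layer] if layer in _quant_layers_map else layer
        for layer in quantizable_layer_type)

    print(quantizable)                        # (<class 'Conv2D'>, <class 'Linear'>)
    print(isinstance(Conv2D(), quantizable))  # True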