Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8967a66a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8967a66a
编写于
8月 18, 2021
作者:
X
XGZhang
提交者:
GitHub
8月 18, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support quantization of conv2d_transpose (#34547)
上级
4d88cdb8
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
225 addition
and
32 deletion
+225
-32
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
.../paddle/fluid/contrib/slim/quantization/imperative/qat.py
+76
-25
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
...addle/fluid/contrib/slim/quantization/imperative/utils.py
+14
-5
python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
...on/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
+8
-2
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py
...id/contrib/slim/tests/test_imperative_qat_user_defined.py
+19
-0
python/paddle/nn/quant/quant_layers.py
python/paddle/nn/quant/quant_layers.py
+107
-0
tools/sampcd_processor.py
tools/sampcd_processor.py
+1
-0
未找到文件。
python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
浏览文件 @
8967a66a
...
...
@@ -42,8 +42,9 @@ class ImperativeQuantAware(object):
Applying quantization aware training (QAT) to the dgraph model.
"""
def
__init__
(
self
,
quantizable_layer_type
=
[
'Conv2D'
,
'Linear'
],
def
__init__
(
self
,
quantizable_layer_type
=
[
'Conv2D'
,
'Linear'
,
'Conv2DTranspose'
],
weight_quantize_type
=
'abs_max'
,
activation_quantize_type
=
'moving_average_abs_max'
,
weight_bits
=
8
,
...
...
@@ -212,9 +213,44 @@ class ImperativeQuantAware(object):
the out_scale value of outputs would be calculated.
Args:
model(
fluid.dygraph
.Layer): the model to be quantized.
model(
paddle.nn
.Layer): the model to be quantized.
Returns:
None
Examples:
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization
\
import ImperativeQuantAware
class ImperativeModel(paddle.nn.Layer):
def __init__(self):
super(ImperativeModel, self).__init__()
# self.linear_0 would skip the quantization.
self.linear_0 = paddle.nn.Linear(784, 400)
self.linear_0.skip_quant = True
# self.linear_1 would not skip the quantization.
self.linear_1 = paddle.nn.Linear(400, 10)
self.linear_1.skip_quant = False
def forward(self, inputs):
x = self.linear_0(inputs)
x = self.linear_1(inputs)
return x
model = ImperativeModel()
imperative_qat = ImperativeQuantAware(
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max')
# Add the fake quant logical.
# The original model will be rewrite.
#
# There is only one Layer(self.linear1) would be added the
# fake quant logical.
imperative_qat.quantize(model)
"""
assert
isinstance
(
model
,
dygraph
.
Layer
),
\
"The model must be the instance of dygraph.Layer."
...
...
@@ -232,8 +268,9 @@ class ImperativeQuantizeInputs(object):
logic both for activation inputs and weight inputs.
"""
def
__init__
(
self
,
quantizable_layer_type
=
[
'Conv2D'
,
'Linear'
],
def
__init__
(
self
,
quantizable_layer_type
=
[
'Conv2D'
,
'Linear'
,
'Conv2DTranspose'
],
weight_quantize_type
=
'abs_max'
,
activation_quantize_type
=
'moving_average_abs_max'
,
weight_bits
=
8
,
...
...
@@ -303,6 +340,18 @@ class ImperativeQuantizeInputs(object):
}
def
apply
(
self
,
model
):
"""
Quantize the weights and activations to calculate for specific
layers.
Args:
model(paddle.nn.Layer): The target model which would
calculate the input quantization scale.
Returns:
None
"""
assert
isinstance
(
model
,
dygraph
.
Layer
),
\
"The model must be the instance of dygraph.Layer."
...
...
@@ -354,7 +403,7 @@ class ImperativeQuantizeOutputs(object):
output scales for specific layers in the dygraph model.
Args:
model(
fluid.dygraph
.Layer): The target model which would be
model(
paddle.nn
.Layer): The target model which would be
calculate the output quantization scale.
Returns:
...
...
@@ -544,7 +593,9 @@ class ImperativeQuantizeOutputs(object):
1. the type of input op should be conv2d, depthwise_conv2d or matmul
2. the previous ops of the input op are not fake_quantize_dequantize ops
"""
target_op_types
=
[
"conv2d"
,
"depthwise_conv2d"
,
"matmul"
]
target_op_types
=
[
"conv2d"
,
"depthwise_conv2d"
,
"matmul"
,
"conv2d_transpose"
]
if
in_op
.
type
not
in
target_op_types
:
return
False
...
...
python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
浏览文件 @
8967a66a
...
...
@@ -24,6 +24,7 @@ from ..quantization_pass import _get_output_name_index
from
..quantization_pass
import
_get_input_name_index
layer_name_map
=
{
'Conv2DTranspose'
:
paddle
.
nn
.
Conv2DTranspose
,
'Conv2D'
:
paddle
.
nn
.
Conv2D
,
'Linear'
:
paddle
.
nn
.
Linear
,
'AdaptiveAvgPool2D'
:
paddle
.
nn
.
AdaptiveAvgPool2D
,
...
...
@@ -46,8 +47,9 @@ layer_name_map = {
}
# Apply fake quant for the inputs of these layers
# TODO (jc): support paddle.nn.Conv2DTranspose
fake_quant_input_layers
=
[
paddle
.
nn
.
Conv2D
,
paddle
.
nn
.
Linear
]
fake_quant_input_layers
=
[
paddle
.
nn
.
Conv2D
,
paddle
.
nn
.
Linear
,
paddle
.
nn
.
Conv2DTranspose
]
# Apply fake quant for the output of these layers
# TODO(jc): fix the problem of adding duplicate fake_quant ops
...
...
@@ -65,7 +67,8 @@ fake_quant_leaf_layers = [
]
fake_quant_wrap_layers
=
[
quant_layers
.
QuantizedConv2D
,
quant_layers
.
QuantizedLinear
quant_layers
.
QuantizedConv2D
,
quant_layers
.
QuantizedLinear
,
quant_layers
.
QuantizedConv2DTranspose
]
# The weight format of these layers is Cin * Cout * H * W
...
...
@@ -84,9 +87,9 @@ fake_quantize_dequantize_op_types = [
def
load_variable_data
(
scope
,
var_name
):
'''
"""
Load variable value from scope
'''
"""
var_node
=
scope
.
find_var
(
var_name
)
assert
var_node
is
not
None
,
\
"Can not find "
+
var_name
+
" in the scope."
...
...
@@ -120,6 +123,12 @@ def find_parent_layer_and_sub_name(model, name):
the sub_name of the layer.
For example, if name is 'block_1/convbn_1/conv_1', the parent layer is
'block_1/convbn_1' and the sub_name is `conv_1`.
Args:
model(paddle.nn.Layer): the model to be quantized.
name(string): the name of a layer
Returns:
parent_layer, subname
"""
assert
isinstance
(
model
,
paddle
.
nn
.
Layer
),
\
"The model must be the instance of paddle.nn.Layer."
...
...
python/paddle/fluid/contrib/slim/tests/test_imperative_qat.py
浏览文件 @
8967a66a
...
...
@@ -28,10 +28,10 @@ from paddle.fluid import core
from
paddle.fluid.optimizer
import
AdamOptimizer
from
paddle.fluid.contrib.slim.quantization
import
ImperativeQuantAware
from
paddle.fluid.dygraph.container
import
Sequential
from
paddle.nn
import
Linear
,
Conv2D
,
Softmax
from
paddle.nn
import
Linear
,
Conv2D
,
Softmax
,
Conv2DTranspose
from
paddle.fluid.log_helper
import
get_logger
from
paddle.fluid.dygraph.io
import
INFER_MODEL_SUFFIX
,
INFER_PARAMS_SUFFIX
from
paddle.nn.quant.quant_layers
import
QuantizedConv2D
from
paddle.nn.quant.quant_layers
import
QuantizedConv2D
,
QuantizedConv2DTranspose
from
imperative_test_utils
import
fix_model_dict
,
ImperativeLenet
...
...
@@ -75,6 +75,12 @@ class TestImperativeQat(unittest.TestCase):
data
=
np
.
random
.
uniform
(
-
1
,
1
,
[
10
,
3
,
32
,
32
]).
astype
(
'float32'
)
quant_conv1
(
fluid
.
dygraph
.
to_variable
(
data
))
conv_transpose
=
Conv2DTranspose
(
4
,
6
,
(
3
,
3
))
quant_conv_transpose
=
QuantizedConv2DTranspose
(
conv_transpose
)
x_var
=
paddle
.
uniform
(
(
2
,
4
,
8
,
8
),
dtype
=
'float32'
,
min
=-
1.0
,
max
=
1.0
)
quant_conv_transpose
(
x_var
)
seed
=
1
np
.
random
.
seed
(
seed
)
fluid
.
default_main_program
().
random_seed
=
seed
...
...
python/paddle/fluid/contrib/slim/tests/test_imperative_qat_user_defined.py
浏览文件 @
8967a66a
...
...
@@ -28,6 +28,7 @@ from paddle.nn import Sequential
from
paddle.fluid.dygraph
import
Conv2D
from
paddle.fluid.dygraph
import
Pool2D
from
paddle.fluid.dygraph
import
Linear
from
paddle.nn.quant.quant_layers
import
QuantizedConv2DTranspose
from
paddle.fluid.log_helper
import
get_logger
os
.
environ
[
"CPU_NUM"
]
=
"1"
...
...
@@ -100,6 +101,19 @@ class CustomQAT(nn.Layer):
return
x
class
ModelForConv2dT
(
nn
.
Layer
):
def
__init__
(
self
,
num_classes
=
10
):
super
(
ModelForConv2dT
,
self
).
__init__
()
self
.
features
=
nn
.
Conv2DTranspose
(
4
,
6
,
(
3
,
3
))
self
.
fc
=
Linear
(
input_dim
=
600
,
output_dim
=
num_classes
)
def
forward
(
self
,
inputs
):
x
=
self
.
features
(
inputs
)
x
=
paddle
.
flatten
(
x
,
1
)
x
=
self
.
fc
(
x
)
return
x
class
ImperativeLenet
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
num_classes
=
10
,
classifier_activation
=
'softmax'
):
super
(
ImperativeLenet
,
self
).
__init__
()
...
...
@@ -168,6 +182,11 @@ class TestUserDefinedActPreprocess(unittest.TestCase):
imperative_qat
.
quantize
(
lenet
)
adam
=
Adam
(
learning_rate
=
0.001
,
parameters
=
lenet
.
parameters
())
dynamic_loss_rec
=
[]
#for CI coverage
conv_transpose
=
ModelForConv2dT
()
imperative_qat
.
quantize
(
conv_transpose
)
x_var
=
paddle
.
uniform
((
2
,
4
,
8
,
8
),
dtype
=
'float32'
,
min
=-
1.
,
max
=
1.
)
conv_transpose
(
x_var
)
def
train
(
model
):
adam
=
Adam
(
learning_rate
=
0.001
,
parameters
=
model
.
parameters
())
...
...
python/paddle/nn/quant/quant_layers.py
浏览文件 @
8967a66a
...
...
@@ -31,6 +31,7 @@ __all__ = [
'FakeQuantMovingAverageAbsMax'
,
'FakeQuantChannelWiseAbsMax'
,
'QuantizedConv2D'
,
'QuantizedConv2DTranspose'
,
'QuantizedLinear'
,
'MovingAverageAbsMaxScale'
,
'MAOutputScaleLayer'
,
...
...
@@ -481,6 +482,112 @@ class QuantizedConv2D(layers.Layer):
data_format
=
self
.
_data_format
)
class
QuantizedConv2DTranspose
(
layers
.
Layer
):
"""
The computational logic of QuantizedConv2DTranspose is the same with Conv2DTranspose.
The only difference is that its inputs are all fake quantized.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = nn.Conv2DTranspose(4, 6, (3, 3))
conv_quantized = QuantizedConv2DTranspose(conv)
y_quantized = conv_quantized(x_var)
y_var = conv(x_var)
y_quantized_np = y_quantized.numpy()
y_np = y_var.numpy()
print(y_np.shape, y_quantized_np.shape)
# (2, 6, 10, 10), (2, 6, 10, 10)
"""
def
__init__
(
self
,
layer
,
weight_bits
=
8
,
activation_bits
=
8
,
moving_rate
=
0.9
,
weight_quantize_type
=
'abs_max'
,
activation_quantize_type
=
'abs_max'
,
weight_pre_layer
=
None
,
act_pre_layer
=
None
,
weight_quant_layer
=
None
,
act_quant_layer
=
None
):
r
"""
Constructor.
The arguments are the same as ImperativeQuantAware.
"""
super
(
QuantizedConv2DTranspose
,
self
).
__init__
()
# For Conv2DTranspose
self
.
_groups
=
getattr
(
layer
,
'_groups'
)
self
.
_stride
=
getattr
(
layer
,
'_stride'
)
self
.
_padding
=
getattr
(
layer
,
'_padding'
)
self
.
_output_padding
=
getattr
(
layer
,
'output_padding'
)
self
.
_dilation
=
getattr
(
layer
,
'_dilation'
)
self
.
_data_format
=
getattr
(
layer
,
'_data_format'
)
self
.
weight
=
getattr
(
layer
,
'weight'
)
self
.
bias
=
getattr
(
layer
,
'bias'
)
# For FakeQuant
self
.
_conv2d_transpose_quant_axis
=
1
if
weight_quant_layer
is
not
None
:
self
.
_fake_quant_weight
=
weight_quant_layer
()
else
:
self
.
_fake_quant_weight
=
_get_fake_quant_type
(
weight_quantize_type
,
name
=
self
.
weight
.
name
,
moving_rate
=
moving_rate
,
quant_bits
=
weight_bits
,
dtype
=
self
.
_dtype
,
quant_on_weight
=
True
,
channel_num
=
self
.
weight
.
shape
[
self
.
_conv2d_transpose_quant_axis
],
quant_axis
=
self
.
_conv2d_transpose_quant_axis
)
if
act_quant_layer
is
not
None
:
self
.
_fake_quant_input
=
act_quant_layer
()
else
:
self
.
_fake_quant_input
=
_get_fake_quant_type
(
activation_quantize_type
,
name
=
layer
.
full_name
(),
moving_rate
=
moving_rate
,
quant_bits
=
activation_bits
,
dtype
=
self
.
_dtype
,
quant_on_weight
=
False
)
self
.
_act_preprocess
=
act_pre_layer
(
)
if
act_pre_layer
is
not
None
else
None
self
.
_weight_preprocess
=
weight_pre_layer
(
)
if
weight_pre_layer
is
not
None
else
None
def
forward
(
self
,
input
,
output_size
=
None
):
if
self
.
_act_preprocess
is
not
None
:
input
=
self
.
_act_preprocess
(
input
)
quant_input
=
self
.
_fake_quant_input
(
input
)
weight
=
self
.
weight
if
self
.
_weight_preprocess
is
not
None
:
weight
=
self
.
_weight_preprocess
(
self
.
weight
)
quant_weight
=
self
.
_fake_quant_weight
(
weight
)
if
output_size
is
None
:
output_padding
=
self
.
_output_padding
else
:
output_padding
=
0
return
F
.
conv2d_transpose
(
quant_input
,
quant_weight
,
bias
=
self
.
bias
,
padding
=
self
.
_padding
,
output_padding
=
output_padding
,
stride
=
self
.
_stride
,
dilation
=
self
.
_dilation
,
groups
=
self
.
_groups
,
output_size
=
output_size
,
data_format
=
self
.
_data_format
)
class
QuantizedLinear
(
layers
.
Layer
):
"""
The computational logic of QuantizedLinear is the same with Linear.
...
...
tools/sampcd_processor.py
浏览文件 @
8967a66a
...
...
@@ -440,6 +440,7 @@ def get_filenames(full_test=False):
'''
global
whl_error
import
paddle
import
paddle.fluid.contrib.slim.quantization
whl_error
=
[]
if
full_test
:
get_full_api_from_pr_spec
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录