未验证 提交 8967a66a 编写于 作者: X XGZhang 提交者: GitHub

support quantization of conv2d_transpose (#34547)

上级 4d88cdb8
...@@ -42,17 +42,18 @@ class ImperativeQuantAware(object): ...@@ -42,17 +42,18 @@ class ImperativeQuantAware(object):
Applying quantization aware training (QAT) to the dgraph model. Applying quantization aware training (QAT) to the dgraph model.
""" """
def __init__(self, def __init__(
quantizable_layer_type=['Conv2D', 'Linear'], self,
weight_quantize_type='abs_max', quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'],
activation_quantize_type='moving_average_abs_max', weight_quantize_type='abs_max',
weight_bits=8, activation_quantize_type='moving_average_abs_max',
activation_bits=8, weight_bits=8,
moving_rate=0.9, activation_bits=8,
weight_preprocess_layer=None, moving_rate=0.9,
act_preprocess_layer=None, weight_preprocess_layer=None,
weight_quantize_layer=None, act_preprocess_layer=None,
act_quantize_layer=None): weight_quantize_layer=None,
act_quantize_layer=None):
""" """
The constructor for ImperativeQuantAware. The constructor for ImperativeQuantAware.
...@@ -212,9 +213,44 @@ class ImperativeQuantAware(object): ...@@ -212,9 +213,44 @@ class ImperativeQuantAware(object):
the out_scale value of outputs would be calculated. the out_scale value of outputs would be calculated.
Args: Args:
model(fluid.dygraph.Layer): the model to be quantized. model(paddle.nn.Layer): the model to be quantized.
Returns: Returns:
None None
Examples:
.. code-block:: python
import paddle
from paddle.fluid.contrib.slim.quantization \
import ImperativeQuantAware
class ImperativeModel(paddle.nn.Layer):
def __init__(self):
super(ImperativeModel, self).__init__()
# self.linear_0 would skip the quantization.
self.linear_0 = paddle.nn.Linear(784, 400)
self.linear_0.skip_quant = True
# self.linear_1 would not skip the quantization.
self.linear_1 = paddle.nn.Linear(400, 10)
self.linear_1.skip_quant = False
def forward(self, inputs):
x = self.linear_0(inputs)
x = self.linear_1(inputs)
return x
model = ImperativeModel()
imperative_qat = ImperativeQuantAware(
weight_quantize_type='abs_max',
activation_quantize_type='moving_average_abs_max')
# Add the fake quant logical.
# The original model will be rewrite.
#
# There is only one Layer(self.linear1) would be added the
# fake quant logical.
imperative_qat.quantize(model)
""" """
assert isinstance(model, dygraph.Layer), \ assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer." "The model must be the instance of dygraph.Layer."
...@@ -232,17 +268,18 @@ class ImperativeQuantizeInputs(object): ...@@ -232,17 +268,18 @@ class ImperativeQuantizeInputs(object):
logic both for activation inputs and weight inputs. logic both for activation inputs and weight inputs.
""" """
def __init__(self, def __init__(
quantizable_layer_type=['Conv2D', 'Linear'], self,
weight_quantize_type='abs_max', quantizable_layer_type=['Conv2D', 'Linear', 'Conv2DTranspose'],
activation_quantize_type='moving_average_abs_max', weight_quantize_type='abs_max',
weight_bits=8, activation_quantize_type='moving_average_abs_max',
activation_bits=8, weight_bits=8,
moving_rate=0.9, activation_bits=8,
weight_preprocess_layer=None, moving_rate=0.9,
act_preprocess_layer=None, weight_preprocess_layer=None,
weight_quantize_layer=None, act_preprocess_layer=None,
act_quantize_layer=None): weight_quantize_layer=None,
act_quantize_layer=None):
""" """
The constructor for ImperativeQuantizeInputs. The constructor for ImperativeQuantizeInputs.
...@@ -303,6 +340,18 @@ class ImperativeQuantizeInputs(object): ...@@ -303,6 +340,18 @@ class ImperativeQuantizeInputs(object):
} }
def apply(self, model): def apply(self, model):
"""
Quantize the weights and activations to calculate for specific
layers.
Args:
model(paddle.nn.Layer): The target model which would
calculate the input quantization scale.
Returns:
None
"""
assert isinstance(model, dygraph.Layer), \ assert isinstance(model, dygraph.Layer), \
"The model must be the instance of dygraph.Layer." "The model must be the instance of dygraph.Layer."
...@@ -354,7 +403,7 @@ class ImperativeQuantizeOutputs(object): ...@@ -354,7 +403,7 @@ class ImperativeQuantizeOutputs(object):
output scales for specific layers in the dygraph model. output scales for specific layers in the dygraph model.
Args: Args:
model(fluid.dygraph.Layer): The target model which would be model(paddle.nn.Layer): The target model which would be
calculate the output quantization scale. calculate the output quantization scale.
Returns: Returns:
...@@ -544,7 +593,9 @@ class ImperativeQuantizeOutputs(object): ...@@ -544,7 +593,9 @@ class ImperativeQuantizeOutputs(object):
1. the type of input op should be conv2d, depthwise_conv2d or matmul 1. the type of input op should be conv2d, depthwise_conv2d or matmul
2. the previous ops of the input op are not fake_quantize_dequantize ops 2. the previous ops of the input op are not fake_quantize_dequantize ops
""" """
target_op_types = ["conv2d", "depthwise_conv2d", "matmul"] target_op_types = [
"conv2d", "depthwise_conv2d", "matmul", "conv2d_transpose"
]
if in_op.type not in target_op_types: if in_op.type not in target_op_types:
return False return False
......
...@@ -24,6 +24,7 @@ from ..quantization_pass import _get_output_name_index ...@@ -24,6 +24,7 @@ from ..quantization_pass import _get_output_name_index
from ..quantization_pass import _get_input_name_index from ..quantization_pass import _get_input_name_index
layer_name_map = { layer_name_map = {
'Conv2DTranspose': paddle.nn.Conv2DTranspose,
'Conv2D': paddle.nn.Conv2D, 'Conv2D': paddle.nn.Conv2D,
'Linear': paddle.nn.Linear, 'Linear': paddle.nn.Linear,
'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D, 'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
...@@ -46,8 +47,9 @@ layer_name_map = { ...@@ -46,8 +47,9 @@ layer_name_map = {
} }
# Apply fake quant for the inputs of these layers # Apply fake quant for the inputs of these layers
# TODO (jc): support paddle.nn.Conv2DTranspose fake_quant_input_layers = [
fake_quant_input_layers = [paddle.nn.Conv2D, paddle.nn.Linear] paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose
]
# Apply fake quant for the output of these layers # Apply fake quant for the output of these layers
# TODO(jc): fix the problem of adding duplicate fake_quant ops # TODO(jc): fix the problem of adding duplicate fake_quant ops
...@@ -65,7 +67,8 @@ fake_quant_leaf_layers = [ ...@@ -65,7 +67,8 @@ fake_quant_leaf_layers = [
] ]
fake_quant_wrap_layers = [ fake_quant_wrap_layers = [
quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear quant_layers.QuantizedConv2D, quant_layers.QuantizedLinear,
quant_layers.QuantizedConv2DTranspose
] ]
# The weight format of these layers is Cin * Cout * H * W # The weight format of these layers is Cin * Cout * H * W
...@@ -84,9 +87,9 @@ fake_quantize_dequantize_op_types = [ ...@@ -84,9 +87,9 @@ fake_quantize_dequantize_op_types = [
def load_variable_data(scope, var_name): def load_variable_data(scope, var_name):
''' """
Load variable value from scope Load variable value from scope
''' """
var_node = scope.find_var(var_name) var_node = scope.find_var(var_name)
assert var_node is not None, \ assert var_node is not None, \
"Can not find " + var_name + " in the scope." "Can not find " + var_name + " in the scope."
...@@ -120,6 +123,12 @@ def find_parent_layer_and_sub_name(model, name): ...@@ -120,6 +123,12 @@ def find_parent_layer_and_sub_name(model, name):
the sub_name of the layer. the sub_name of the layer.
For example, if name is 'block_1/convbn_1/conv_1', the parent layer is For example, if name is 'block_1/convbn_1/conv_1', the parent layer is
'block_1/convbn_1' and the sub_name is `conv_1`. 'block_1/convbn_1' and the sub_name is `conv_1`.
Args:
model(paddle.nn.Layer): the model to be quantized.
name(string): the name of a layer
Returns:
parent_layer, subname
""" """
assert isinstance(model, paddle.nn.Layer), \ assert isinstance(model, paddle.nn.Layer), \
"The model must be the instance of paddle.nn.Layer." "The model must be the instance of paddle.nn.Layer."
......
...@@ -28,10 +28,10 @@ from paddle.fluid import core ...@@ -28,10 +28,10 @@ from paddle.fluid import core
from paddle.fluid.optimizer import AdamOptimizer from paddle.fluid.optimizer import AdamOptimizer
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware
from paddle.fluid.dygraph.container import Sequential from paddle.fluid.dygraph.container import Sequential
from paddle.nn import Linear, Conv2D, Softmax from paddle.nn import Linear, Conv2D, Softmax, Conv2DTranspose
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
from paddle.nn.quant.quant_layers import QuantizedConv2D from paddle.nn.quant.quant_layers import QuantizedConv2D, QuantizedConv2DTranspose
from imperative_test_utils import fix_model_dict, ImperativeLenet from imperative_test_utils import fix_model_dict, ImperativeLenet
...@@ -75,6 +75,12 @@ class TestImperativeQat(unittest.TestCase): ...@@ -75,6 +75,12 @@ class TestImperativeQat(unittest.TestCase):
data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32') data = np.random.uniform(-1, 1, [10, 3, 32, 32]).astype('float32')
quant_conv1(fluid.dygraph.to_variable(data)) quant_conv1(fluid.dygraph.to_variable(data))
conv_transpose = Conv2DTranspose(4, 6, (3, 3))
quant_conv_transpose = QuantizedConv2DTranspose(conv_transpose)
x_var = paddle.uniform(
(2, 4, 8, 8), dtype='float32', min=-1.0, max=1.0)
quant_conv_transpose(x_var)
seed = 1 seed = 1
np.random.seed(seed) np.random.seed(seed)
fluid.default_main_program().random_seed = seed fluid.default_main_program().random_seed = seed
......
...@@ -28,6 +28,7 @@ from paddle.nn import Sequential ...@@ -28,6 +28,7 @@ from paddle.nn import Sequential
from paddle.fluid.dygraph import Conv2D from paddle.fluid.dygraph import Conv2D
from paddle.fluid.dygraph import Pool2D from paddle.fluid.dygraph import Pool2D
from paddle.fluid.dygraph import Linear from paddle.fluid.dygraph import Linear
from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
os.environ["CPU_NUM"] = "1" os.environ["CPU_NUM"] = "1"
...@@ -100,6 +101,19 @@ class CustomQAT(nn.Layer): ...@@ -100,6 +101,19 @@ class CustomQAT(nn.Layer):
return x return x
class ModelForConv2dT(nn.Layer):
def __init__(self, num_classes=10):
super(ModelForConv2dT, self).__init__()
self.features = nn.Conv2DTranspose(4, 6, (3, 3))
self.fc = Linear(input_dim=600, output_dim=num_classes)
def forward(self, inputs):
x = self.features(inputs)
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
class ImperativeLenet(paddle.nn.Layer): class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'): def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__() super(ImperativeLenet, self).__init__()
...@@ -168,6 +182,11 @@ class TestUserDefinedActPreprocess(unittest.TestCase): ...@@ -168,6 +182,11 @@ class TestUserDefinedActPreprocess(unittest.TestCase):
imperative_qat.quantize(lenet) imperative_qat.quantize(lenet)
adam = Adam(learning_rate=0.001, parameters=lenet.parameters()) adam = Adam(learning_rate=0.001, parameters=lenet.parameters())
dynamic_loss_rec = [] dynamic_loss_rec = []
#for CI coverage
conv_transpose = ModelForConv2dT()
imperative_qat.quantize(conv_transpose)
x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv_transpose(x_var)
def train(model): def train(model):
adam = Adam(learning_rate=0.001, parameters=model.parameters()) adam = Adam(learning_rate=0.001, parameters=model.parameters())
......
...@@ -31,6 +31,7 @@ __all__ = [ ...@@ -31,6 +31,7 @@ __all__ = [
'FakeQuantMovingAverageAbsMax', 'FakeQuantMovingAverageAbsMax',
'FakeQuantChannelWiseAbsMax', 'FakeQuantChannelWiseAbsMax',
'QuantizedConv2D', 'QuantizedConv2D',
'QuantizedConv2DTranspose',
'QuantizedLinear', 'QuantizedLinear',
'MovingAverageAbsMaxScale', 'MovingAverageAbsMaxScale',
'MAOutputScaleLayer', 'MAOutputScaleLayer',
...@@ -481,6 +482,112 @@ class QuantizedConv2D(layers.Layer): ...@@ -481,6 +482,112 @@ class QuantizedConv2D(layers.Layer):
data_format=self._data_format) data_format=self._data_format)
class QuantizedConv2DTranspose(layers.Layer):
"""
The computational logic of QuantizedConv2DTranspose is the same with Conv2DTranspose.
The only difference is that its inputs are all fake quantized.
Examples:
.. code-block:: python
import paddle
import paddle.nn as nn
from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
conv = nn.Conv2DTranspose(4, 6, (3, 3))
conv_quantized = QuantizedConv2DTranspose(conv)
y_quantized = conv_quantized(x_var)
y_var = conv(x_var)
y_quantized_np = y_quantized.numpy()
y_np = y_var.numpy()
print(y_np.shape, y_quantized_np.shape)
# (2, 6, 10, 10), (2, 6, 10, 10)
"""
def __init__(self,
layer,
weight_bits=8,
activation_bits=8,
moving_rate=0.9,
weight_quantize_type='abs_max',
activation_quantize_type='abs_max',
weight_pre_layer=None,
act_pre_layer=None,
weight_quant_layer=None,
act_quant_layer=None):
r"""
Constructor.
The arguments are the same as ImperativeQuantAware.
"""
super(QuantizedConv2DTranspose, self).__init__()
# For Conv2DTranspose
self._groups = getattr(layer, '_groups')
self._stride = getattr(layer, '_stride')
self._padding = getattr(layer, '_padding')
self._output_padding = getattr(layer, 'output_padding')
self._dilation = getattr(layer, '_dilation')
self._data_format = getattr(layer, '_data_format')
self.weight = getattr(layer, 'weight')
self.bias = getattr(layer, 'bias')
# For FakeQuant
self._conv2d_transpose_quant_axis = 1
if weight_quant_layer is not None:
self._fake_quant_weight = weight_quant_layer()
else:
self._fake_quant_weight = _get_fake_quant_type(
weight_quantize_type,
name=self.weight.name,
moving_rate=moving_rate,
quant_bits=weight_bits,
dtype=self._dtype,
quant_on_weight=True,
channel_num=self.weight.shape[
self._conv2d_transpose_quant_axis],
quant_axis=self._conv2d_transpose_quant_axis)
if act_quant_layer is not None:
self._fake_quant_input = act_quant_layer()
else:
self._fake_quant_input = _get_fake_quant_type(
activation_quantize_type,
name=layer.full_name(),
moving_rate=moving_rate,
quant_bits=activation_bits,
dtype=self._dtype,
quant_on_weight=False)
self._act_preprocess = act_pre_layer(
) if act_pre_layer is not None else None
self._weight_preprocess = weight_pre_layer(
) if weight_pre_layer is not None else None
def forward(self, input, output_size=None):
if self._act_preprocess is not None:
input = self._act_preprocess(input)
quant_input = self._fake_quant_input(input)
weight = self.weight
if self._weight_preprocess is not None:
weight = self._weight_preprocess(self.weight)
quant_weight = self._fake_quant_weight(weight)
if output_size is None:
output_padding = self._output_padding
else:
output_padding = 0
return F.conv2d_transpose(
quant_input,
quant_weight,
bias=self.bias,
padding=self._padding,
output_padding=output_padding,
stride=self._stride,
dilation=self._dilation,
groups=self._groups,
output_size=output_size,
data_format=self._data_format)
class QuantizedLinear(layers.Layer): class QuantizedLinear(layers.Layer):
""" """
The computational logic of QuantizedLinear is the same with Linear. The computational logic of QuantizedLinear is the same with Linear.
......
...@@ -440,6 +440,7 @@ def get_filenames(full_test=False): ...@@ -440,6 +440,7 @@ def get_filenames(full_test=False):
''' '''
global whl_error global whl_error
import paddle import paddle
import paddle.fluid.contrib.slim.quantization
whl_error = [] whl_error = []
if full_test: if full_test:
get_full_api_from_pr_spec() get_full_api_from_pr_spec()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册