Unverified commit 973dab86, authored by whs, committed by GitHub

Enhance the quantization API with some new features (#50816)

Parent 262358e8
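Summary of the change, as reflected in the hunks below: customized QAT layer mappings registered through QuantConfig.add_qat_layer_mapping now take precedence over the built-in DEFAULT_QAT_LAYER_MAPPINGS; type-level configs from add_type_config are applied while the model tree is analyzed; conversion skips a quanter slot that holds None; and FakeQuanterWithAbsMaxObserverLayer gains a static-graph forward so quantized models can be exported with paddle.jit.save. A unit test covering the tracing path is added.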
@@ -37,6 +37,7 @@ class LinearQuanterDequanter(Layer):
     @staticmethod
     def from_quanter(quanter):
+        assert quanter is not None
         return LinearQuanterDequanter(
             LinearQuanter.from_quanter(quanter),
             LinearDequanter.from_quanter(quanter),
@@ -208,6 +209,8 @@ class ConvertibleQuantedLayer(Layer, metaclass=abc.ABCMeta):
             self, quanter_name
         ), f"{quanter_name} is not attribute of current layer."
         quanter = getattr(self, quanter_name)
+        if quanter is None:
+            return None
         quanter = LinearQuanterDequanter.from_quanter(quanter)
         setattr(self, quanter_name, quanter)
         self._sub_layers[quanter_name] = quanter
@@ -224,9 +227,10 @@ class ConvertibleQuantedLayer(Layer, metaclass=abc.ABCMeta):
         assert not self.converted, "The model should be converted only once."
         for weight_name, quanter_name in self.weights_to_quanters():
             qdq = self._convert_quanter_to_qdq(quanter_name)
-            self._quant_weights(weight_name, qdq._quanter)
-            qdq._quanter = None
-            qdq._sub_layers['_quanter'] = None
+            if qdq is not None:
+                self._quant_weights(weight_name, qdq._quanter)
+                qdq._quanter = None
+                qdq._sub_layers['_quanter'] = None
         for quanter_name in self.activation_quanters():
             self._convert_quanter_to_qdq(quanter_name)
......
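The guards above matter when a config quantizes only part of a layer: a quanter slot can legitimately hold None (for example, activation-only quantization leaves the weight quanter unset), and conversion now skips that slot instead of failing inside from_quanter. A minimal sketch of exercising that path, assuming the public QAT quantize/convert entry points:

import paddle
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver

observer = FakeQuanterWithAbsMaxObserver()
q_config = QuantConfig(activation=None, weight=None)
# Activation-only quantization: the weight quanter of each Linear stays None.
q_config.add_type_config(paddle.nn.Linear, activation=observer, weight=None)

model = paddle.nn.Sequential(paddle.nn.Linear(8, 8))
model.train()
qat = QAT(q_config)
quant_model = qat.quantize(model)
# _convert_quanter_to_qdq now returns None for the unset weight quanter, and
# the `if qdq is not None` guard leaves the weight untouched.
infer_model = qat.convert(quant_model)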
@@ -89,6 +89,7 @@ class QuantConfig(object):
         self._type2config = {}
         self._model = None
         self._qat_layer_mapping = copy.deepcopy(DEFAULT_QAT_LAYER_MAPPINGS)
+        self._customized_qat_layer_mapping = dict()
         self._customized_leaves = []
@@ -259,6 +260,7 @@ class QuantConfig(object):
             source, paddle.nn.Layer
         ), "The target layer should be a subclass of paddle.nn.qat.Layer"
         self._qat_layer_mapping[source] = target
+        self._customized_qat_layer_mapping[source] = target

     def add_customized_leaf(self, layer_type: type):
         r"""
@@ -296,7 +298,11 @@ class QuantConfig(object):
     def _get_qat_layer(self, layer: Layer):
         q_config = self._get_config_by_layer(layer)
-        return self.qat_layer_mappings[type(layer)](layer, q_config)
+        target_type = self._customized_qat_layer_mapping.get(
+            type(layer), self.qat_layer_mappings.get(type(layer))
+        )
+        return target_type(layer, q_config)

     def _has_observer_config(self, layer: Layer):
         r"""
@@ -397,6 +403,7 @@ class QuantConfig(object):
         for child in model.children():
             layer_prefix = child.full_name()
             config = self._layer2config.get(model, self.global_config)
+            config = self._type2config.get(type(child), config)
             config = self._prefix2config.get(layer_prefix, config)
             if config is not None:
@@ -413,26 +420,21 @@ class QuantConfig(object):
         return self._details_helper(self._model)

     def _details_helper(self, layer: Layer):
-        extra_lines = []
         sublayer_lines = []
         for name, sublayer in layer.named_children():
             sublayer_str = self._details_helper(sublayer)
             sublayer_str = self._addindent(sublayer_str, 2)
-            sublayer_lines.append(
-                '('
-                + name
-                + '): '
-                + sublayer_str
-                + ', '
-                + str(self._layer2config[sublayer])
-            )
+            if sublayer in self._layer2config:
+                sublayer_lines.append(
+                    '('
+                    + name
+                    + '): '
+                    + sublayer_str
+                    + ', '
+                    + str(self._layer2config[sublayer])
+                )

         final_str = layer.__class__.__name__ + '('
-        if extra_lines:
-            if len(extra_lines) > 1:
-                final_str += '\n  ' + '\n  '.join(extra_lines) + '\n'
-            elif len(extra_lines) == 1:
-                final_str += extra_lines[0]
         if sublayer_lines:
             final_str += '\n  ' + '\n  '.join(sublayer_lines) + '\n'
......
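Together, the QuantConfig hunks make user registrations effective end to end: add_qat_layer_mapping now records the pair in _customized_qat_layer_mapping so that _get_qat_layer prefers it over the default mapping for the same source type, and type-level configs are consulted while the layer tree is analyzed, not only prefix-based ones. A short sketch of the intended usage; MyConv2D and MyQuantedConv2D are hypothetical names:

import paddle
from paddle.quantization import QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver

observer = FakeQuanterWithAbsMaxObserver(moving_rate=0.9)
q_config = QuantConfig(activation=None, weight=None)

# Type-level config: every Conv2D child now resolves to this entry during
# analysis, even when its parent carries a different (or no) config.
q_config.add_type_config(
    paddle.nn.Conv2D, activation=observer, weight=observer
)

# Customized mapping: once registered, _get_qat_layer instantiates
# MyQuantedConv2D for MyConv2D even if a default mapping covers MyConv2D.
# q_config.add_qat_layer_mapping(MyConv2D, MyQuantedConv2D)  # hypothetical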
@@ -12,9 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import paddle
 from paddle import _legacy_C_ops
+from paddle.fluid.data_feeder import check_variable_and_dtype
 from paddle.fluid.framework import _varbase_creator
-from paddle.framework import ParamAttr
+from paddle.framework import ParamAttr, core
 from paddle.nn.initializer import Constant
 from paddle.utils import unique_name
@@ -142,7 +144,7 @@ class FakeQuanterWithAbsMaxObserverLayer(BaseQuanter):
         )
         self._accum.stop_gradient = True

-    def forward(self, input):
+    def dynamic_forward(self, input):
         attrs = (
             'moving_rate',
             self._moving_rate,
@@ -181,6 +183,46 @@ class FakeQuanterWithAbsMaxObserverLayer(BaseQuanter):
         return out

+    def static_forward(self, input):
+        check_variable_and_dtype(
+            input, 'input', ['float32'], "FakeQuantMovingAverageAbsMax"
+        )
+        attrs = {
+            'moving_rate': self._moving_rate,
+            'bit_length': self._bit_length,
+            'is_test': not self.training,
+        }
+        inputs = {"X": [input], "InScale": [self._scale]}
+        quant_out = self._helper.create_variable(
+            name="{}.quantized.dequantized".format(input.name),
+            dtype=input.dtype,
+            type=core.VarDesc.VarType.LOD_TENSOR,
+            persistable=False,
+            stop_gradient=False,
+        )
+        outputs = {"Out": [quant_out], "OutScale": [self._scale]}
+
+        if self.training:
+            inputs['InState'] = [self._state]
+            inputs['InAccum'] = [self._accum]
+            outputs['OutState'] = [self._state]
+            outputs['OutAccum'] = [self._accum]
+
+        self._helper.append_op(
+            type="fake_quantize_dequantize_moving_average_abs_max",
+            inputs=inputs,
+            outputs=outputs,
+            attrs=attrs,
+        )
+
+        return quant_out
+
+    def forward(self, input):
+        if paddle.framework.in_dynamic_mode():
+            return self.dynamic_forward(input)
+        else:
+            return self.static_forward(input)
+
     def bit_length(self):
         return self._bit_length
......
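This dynamic/static split is what makes the quanter exportable: forward dispatches on paddle.framework.in_dynamic_mode(), calling _legacy_C_ops in dygraph and appending a fake_quantize_dequantize_moving_average_abs_max op to the program when the layer is traced. A small sketch of both paths; the layer's constructor arguments here are assumptions based on how the quanter factory instantiates it:

import paddle
from paddle.quantization.quanters.abs_max import (
    FakeQuanterWithAbsMaxObserverLayer,
)

linear = paddle.nn.Linear(4, 4)
# Assumed signature: the observed layer plus the factory's keyword arguments.
quanter = FakeQuanterWithAbsMaxObserverLayer(linear, moving_rate=0.9)

x = paddle.rand([2, 4])
y = quanter(x)  # in_dynamic_mode() is True -> dynamic_forward

# paddle.jit tracing builds a static program, so the same forward() call now
# takes the static_forward branch and records the fake quant/dequant op.
traced = paddle.jit.to_static(
    quanter, input_spec=[paddle.static.InputSpec([None, 4], 'float32')]
)
paddle.jit.save(traced, './quanter_infer')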
@@ -84,11 +84,13 @@ class Quantization(object, metaclass=abc.ABCMeta):
     def _convert_to_quant_layers(self, model: Layer, config: QuantConfig):
         replaced = {}
         for name, child in model.named_children():
-            if config._is_quantifiable(child):
-                if type(child) not in config.qat_layer_mappings:
-                    self._convert_to_quant_layers(child, config)
-                else:
-                    replaced[name] = config._get_qat_layer(child)
+            if (
+                config._is_quantifiable(child)
+                and type(child) in config.qat_layer_mappings
+            ):
+                replaced[name] = config._get_qat_layer(child)
+            else:
+                self._convert_to_quant_layers(child, config)
         for key, value in replaced.items():
             model._sub_layers[key] = value
......
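The restructured loop also changes how the tree walk proceeds: any child that is not replaced is now recursed into, so mapped layers nested inside containers that carry no config of their own are still found (previously the walk stopped at a non-quantifiable child). A toy sketch of that case; Block is a hypothetical container:

import paddle
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver

class Block(paddle.nn.Layer):
    # Hypothetical container: no quant config of its own, but its inner
    # Conv2D matches the type config registered below.
    def __init__(self):
        super().__init__()
        self.conv = paddle.nn.Conv2D(3, 8, 3)

    def forward(self, x):
        return self.conv(x)

observer = FakeQuanterWithAbsMaxObserver()
q_config = QuantConfig(activation=None, weight=None)
q_config.add_type_config(
    paddle.nn.Conv2D, activation=observer, weight=observer
)

model = paddle.nn.Sequential(Block())
model.train()
quant_model = QAT(q_config).quantize(model)
# The nested conv is swapped even though Block itself is not quantifiable.
print(type(quant_model[0].conv))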
# copyright (c) 2023 paddlepaddle authors. all rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The quantizer layers should be traceable by the paddle.jit.save function."""

import os
import tempfile
import unittest

import paddle
from paddle.quantization import QAT, QuantConfig
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
from paddle.quantization.quanters.abs_max import (
    FakeQuanterWithAbsMaxObserverLayer,
)
from paddle.vision.models import resnet18


class TestPTQ(unittest.TestCase):
    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory(dir="./")
        self.path = os.path.join(self.temp_dir.name, 'ptq')

    def tearDown(self):
        self.temp_dir.cleanup()

    def _get_model_for_qat(self):
        observer = FakeQuanterWithAbsMaxObserver()
        model = resnet18()
        model.train()
        q_config = QuantConfig(activation=None, weight=None)
        q_config.add_type_config(
            paddle.nn.Conv2D, activation=observer, weight=observer
        )
        qat = QAT(q_config)
        quant_model = qat.quantize(model)
        return quant_model, qat

    def _count_layers(self, model, layer_type):
        count = 0
        for _layer in model.sublayers(True):
            if isinstance(_layer, layer_type):
                count += 1
        return count

    def test_trace(self):
        quant_model, ptq = self._get_model_for_qat()
        image = paddle.rand([1, 3, 32, 32], dtype="float32")
        quantizer_count_in_dygraph = self._count_layers(
            quant_model, FakeQuanterWithAbsMaxObserverLayer
        )
        save_path = os.path.join(self.path, 'int8_infer')
        paddle.jit.save(quant_model, save_path, [image])
        print(f"quant_model is saved into {save_path}")

        paddle.enable_static()
        exe = paddle.static.Executor(paddle.CPUPlace())
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            [
                inference_program,
                feed_target_names,
                fetch_targets,
            ] = paddle.static.load_inference_model(save_path, exe)
        quantizer_count_in_static_model = 0
        for _op in inference_program.global_block().ops:
            if _op.type == "fake_quantize_dequantize_moving_average_abs_max":
                quantizer_count_in_static_model += 1
        self.assertEqual(
            quantizer_count_in_dygraph, quantizer_count_in_static_model
        )
        paddle.disable_static()


if __name__ == '__main__':
    unittest.main()