diff --git a/paddleslim/quant/quanters/__init__.py b/paddleslim/quant/quanters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6dd16d6031ac2e111b85314d7ced068fff2227d6 --- /dev/null +++ b/paddleslim/quant/quanters/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .lsq_act import ActLSQplusQuanter +from .lsq_weight import WeightLSQplusQuanter +from .pact import PACTQuanter + +__all__ = ["ActLSQplusQuanter", "WeightLSQplusQuanter", "PACTQuanter"] \ No newline at end of file diff --git a/paddleslim/quant/quanters/base_fake_quanter.py b/paddleslim/quant/quanters/base_fake_quanter.py new file mode 100644 index 0000000000000000000000000000000000000000..45b026fde1a10099d1fc9d78852b9708d6fe2a59 --- /dev/null +++ b/paddleslim/quant/quanters/base_fake_quanter.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import paddle +import numpy as np +from paddle.quantization.base_quanter import BaseQuanter + + +class BaseFakeQuanterLayer(BaseQuanter): + def __init__( + self, + quant_bits=8, + sign=True, + symmetric=True, ): + super(BaseFakeQuanterLayer, self).__init__() + self._quant_bits = quant_bits + self._sign = sign + self._symmetric = symmetric + + self._min = None + self._max = None + self._qmin = None + self._qmax = None + + self._scale = None + self._zero_point = None + + @property + def qmin_qmax(self): + """ Get the range of the integer.""" + if self._qmin is not None and self._qmax is not None: + return self.qmin, self.qmax + if self._sign: + self.qmin = -2**(self.bit_length() - 1) + self.qmax = 2**(self.bit_length() - 1) - 1 + else: + self.qmin = 0 + self.qmax = 2**self.bit_length() + return self.qmin, self.qmax diff --git a/paddleslim/quant/quanters/lsq_act.py b/paddleslim/quant/quanters/lsq_act.py new file mode 100644 index 0000000000000000000000000000000000000000..220a7dd27628769778f30564908d0b1428280e68 --- /dev/null +++ b/paddleslim/quant/quanters/lsq_act.py @@ -0,0 +1,197 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +import math +from paddle.framework import ParamAttr +from paddle.nn import Layer +from paddle.nn.initializer import Constant +from paddle.utils import unique_name +from paddle.quantization.factory import QuanterFactory +from .base_fake_quanter import BaseFakeQuanterLayer +from .lsq_func import LsqFunc, LsqPlusActFunc, round + + +class ActLSQplusQuanter(QuanterFactory): + r""" + Activation quantizer. More details can be found in + https://arxiv.org/pdf/1902.08153.pdf and https://arxiv.org/pdf/2004.09576.pdf. + Args: + per_channel(bool): whether layer-wise or channel-wise quantization, where True for layer-wise quantization and False for channel-wise quantization. + batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer. + quant_linear(bool): whether the weight is from Linear. + dtype(str): data type. + name(str): the name of the layer. + reduce_type(str): the reduce type which is needed when parallel training. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import ActLSQplusQuanter, WeightLSQplusQuanter + weight_quanter = WeightLSQplusQuanter() + act_quanter = ActLSQplusQuanter() + q_config = QuantConfig(activation=act_quanter, weight=weight_quanter) + """ + + def __init__(self, + quant_bits=8, + sign=True, + symmetric=True, + per_channel=False, + batch_init=20, + quant_linear=False, + reduce_type=None, + dtype='float32', + name=None): + super(ActLSQplusQuanter, self).__init__( + quant_bits=quant_bits, + sign=sign, + symmetric=symmetric, + per_channel=per_channel, + batch_init=batch_init, + quant_linear=quant_linear, + reduce_type=reduce_type, + dtype=dtype, + name=name) + + def _get_class(self): + return ActLSQplusQuanterLayer + + +class ActLSQplusQuanterLayer(BaseFakeQuanterLayer): + def __init__(self, + layer, + quant_bits=8, + sign=True, + symmetric=True, + per_channel=False, + batch_init=20, + quant_linear=False, + reduce_type=None, + dtype='float32', + name=None): + super(ActLSQplusQuanterLayer, self).__init__() + self._symmetric = symmetric + self._per_channel = per_channel + self._quant_linear = quant_linear + self._batch_init = batch_init + self._name = name + self._quant_axis = 1 if quant_linear else 0 + self._collect_axis = 0 if quant_linear else 1 + self._reduce_type = reduce_type + self.div = 2**self._quant_bits - 1 + self.qmin, self.qmax = self.qmin_qmax + + self._current_batch_id = 0 + self._init_state = 0 + + scale_prefix = ("{}.scale".format(name) + if name else 'quant_dequant.scale') + self._scale_name = unique_name.generate(scale_prefix) + + s_attr = ParamAttr( + name=self._scale_name, initializer=Constant(1.0), trainable=True) + self._scale = self.create_parameter(shape=[1], attr=s_attr, dtype=dtype) + self._scale.stop_gradient = False + + if not self._symmetric: + beta_prefix = ("{}.beta".format(name) + if name else 'quant_dequant.beta') + self._beta_name = unique_name.generate(beta_prefix) + + beta_attr = ParamAttr( + name=self._beta_name, initializer=Constant(0.0), trainable=True) + self._beta = 
self.create_parameter( + shape=[1], attr=beta_attr, dtype='float32') + self._beta.stop_gradient = False + + def init_params(self, activation): + self.g = paddle.to_tensor( + 1.0 / math.sqrt(activation.numel() * self.qmax)) + min_a = paddle.min(activation.detach()) + max_a = paddle.max(activation.detach()) + self._scale.set_value((max_a - min_a) / (self.qmax - self.qmin)) + if not self._symmetric: + self._beta.set_value(min_a - self._scale * self.qmin) + self._init_state += 1 + + def collect_gaussian(self, activation): + min_a = paddle.min(activation.detach()) + max_a = paddle.max(activation.detach()) + self._scale.set_value(self._scale * 0.9 + 0.1 * (max_a - min_a) / + (self.qmax - self.qmin)) + if not self._symmetric: + self._beta.set_value(self._scale * 0.9 + 0.1 * + (min_a - self._scale * self.qmin)) + self._init_state += 1 + + def forward(self, activation): + + if self._reduce_type == "max": + paddle.distributed.all_reduce( + self._scale, op=paddle.distributed.ReduceOp.MAX) + + if not self._symmetric and self._reduce_type == "max": + paddle.distributed.all_reduce( + self._beta, op=paddle.distributed.ReduceOp.MAX) + + if self._init_state == 0: + self.init_params(activation) + elif self._init_state < self._batch_init: + self.collect_gaussian(activation) + + activation.stop_gradient = False + + if not self._symmetric: + q_a = LsqPlusActFunc.apply(activation, self._scale, self._beta, + self.g, self.qmin, self.qmax) + else: + q_a = LsqFunc.apply( + activation, + self._scale, + self.g, + self.qmin, + self.qmax, + per_channel=False) + return q_a + + def bit_length(self): + """ Return the bit length of quantized data. + """ + return self._quant_bits + + def quant_axis(self): + """ Return quantization axis. + """ + return self._quant_axis + + def scales(self): + """ Return output scales. + """ + return self._scale + + def zero_points(self): + """ Return output zero points. + """ + if self._zero_point is None: + if self._symmetric: + if self._sign: + self._zero_point = 0 + else: + self._zero_point = (self.qmax + self.qmin) / 2 + else: + self._zero_point = self.qmin - round(self.qmin / self._scale) + self._zero_point = paddle.clip(self._zero_point, self.qmin, + self.qmax) + return self._zero_point diff --git a/paddleslim/quant/quanters/lsq_func.py b/paddleslim/quant/quanters/lsq_func.py new file mode 100644 index 0000000000000000000000000000000000000000..d4de9cb6b84d89406f452a262e59b32701f96c8d --- /dev/null +++ b/paddleslim/quant/quanters/lsq_func.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
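+
+# Custom autograd functions shared by the LSQ/LSQ+ quanters in lsq_act.py and
+# lsq_weight.py:
+#   - round(x): round-half-away-from-zero built from sign and floor.
+#   - LsqFunc: fake-quantizes a tensor with a learnable scale alpha
+#     (optionally per channel). The backward pass is the LSQ straight-through
+#     estimator: the input gradient is kept only where Qn <= x/alpha <= Qp,
+#     and the scale gradient is weighted by the factor g (set to
+#     1/sqrt(numel * Qp) by the quanter layers).
+#   - LsqPlusActFunc: the asymmetric LSQ+ variant that additionally learns an
+#     offset beta.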
+ +import math +import paddle +from paddle.autograd import PyLayer + + +def round(x): + sign = paddle.sign(x) + x = sign * paddle.floor(paddle.abs(x) + 0.5) + return x + + +class LsqFunc(PyLayer): + @staticmethod + def forward(ctx, weight, alpha, g, Qn, Qp, per_channel=False, quant_axis=0): + ctx.save_for_backward(weight, alpha) + ctx.other = g, Qn, Qp, per_channel, quant_axis + if per_channel: + sizes = weight.shape + weight = weight.reshape((weight.shape[quant_axis], -1)) + weight = weight.transpose((1, 0)) + alpha = paddle.broadcast_to(alpha, weight.shape) + quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp) + quant_w = quant_w * alpha + quant_w = quant_w.transpose((1, 0)) + quant_w = quant_w.reshape(sizes) + else: + quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp) + quant_w = quant_w * alpha + return quant_w + + @staticmethod + def backward(ctx, grad_weight): + weight, alpha = ctx.saved_tensor() + g, Qn, Qp, per_channel, quant_axis = ctx.other + if per_channel: + sizes = weight.shape + weight = weight.reshape((weight.shape[quant_axis], -1)) + weight = weight.transpose((1, 0)) + alpha = paddle.broadcast_to(alpha, weight.shape) + q_w = paddle.divide(weight, alpha) + q_w = q_w.transpose((1, 0)) + q_w = q_w.reshape(sizes) + else: + q_w = paddle.divide(weight, alpha) + lower_flag = paddle.cast((q_w < Qn), 'float32') + upper_flag = paddle.cast((q_w > Qp), 'float32') + middle_flag = 1.0 - lower_flag - upper_flag + if per_channel: + grad_alpha = ( + (lower_flag * Qn + upper_flag * Qp + middle_flag * round(q_w) - + middle_flag * q_w) * grad_weight * g) + grad_alpha = grad_alpha.reshape((grad_alpha.shape[quant_axis], + -1)).sum(axis=1) + else: + grad_alpha = (( + (lower_flag * Qn + upper_flag * Qp + middle_flag * round(q_w) + - middle_flag * q_w) * grad_weight * g).sum().unsqueeze( + axis=0)[0]) + grad_weight = middle_flag * grad_weight + return grad_weight, grad_alpha + + +class LsqPlusActFunc(PyLayer): + @staticmethod + def forward(ctx, x, alpha, beta, g, Qn, Qp): + ctx.save_for_backward(x, alpha, beta) + ctx.other = g, Qn, Qp + quant_x = round(paddle.divide((x - beta), alpha)).clip(Qn, Qp) + return quant_x * alpha + beta + + @staticmethod + def backward(ctx, grad_x): + x, alpha, beta = ctx.saved_tensor() + g, Qn, Qp = ctx.other + q_x = (x - beta) / alpha + lower_flag = paddle.cast((q_x < Qn), 'float32') + upper_flag = paddle.cast((q_x > Qp), 'float32') + middle_flag = 1.0 - lower_flag - upper_flag + grad_alpha = (( + (lower_flag * Qn + upper_flag * Qp + middle_flag * round(q_x) - + middle_flag * q_x) * grad_x * g).sum().unsqueeze(axis=0)[0]) + grad_beta = (((lower_flag + upper_flag) * grad_x * g).sum().unsqueeze( + axis=0)[0]) + grad_x = middle_flag * grad_x + return grad_x, grad_alpha, grad_beta diff --git a/paddleslim/quant/quanters/lsq_weight.py b/paddleslim/quant/quanters/lsq_weight.py new file mode 100644 index 0000000000000000000000000000000000000000..cb8badd19bd3104b23e390a6667bfe377e8e1cbd --- /dev/null +++ b/paddleslim/quant/quanters/lsq_weight.py @@ -0,0 +1,204 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import paddle +import numpy as np +import math +from paddle.framework import ParamAttr +from paddle.nn import Layer +from paddle.nn.initializer import Constant +from paddle.utils import unique_name +from paddle.quantization.factory import QuanterFactory +from .lsq_func import LsqFunc, round +from .base_fake_quanter import BaseFakeQuanterLayer + + +class WeightLSQplusQuanter(QuanterFactory): + r""" + Weight quantizer. More details can be found in + https://arxiv.org/pdf/1902.08153.pdf and https://arxiv.org/pdf/2004.09576.pdf. + Args: + per_channel(bool): Whether layer-wise or channel-wise quantization, where True for layer-wise quantization and False for channel-wise quantization. + batch_init(int): Number of batches that collect Gaussian approximation for the weight distribution in each layer. + quant_linear(bool): whether the weight is from Linear. + dtype(str): Trainable data type. + name(str): The name of the layer. + reduce_type(str): The reduce type which is needed when parallel training. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import ActLSQplusQuanter, WeightLSQplusQuanter + weight_quanter = WeightLSQplusQuanter() + act_quanter = ActLSQplusQuanter() + q_config = QuantConfig(activation=act_quanter, weight=weight_quanter) + """ + + def __init__(self, + quant_bits=8, + sign=True, + symmetric=True, + per_channel=False, + batch_init=20, + quant_linear=False, + channel_num=None, + reduce_type=None, + dtype='float32', + name=None): + super(WeightLSQplusQuanter, self).__init__( + quant_bits=quant_bits, + sign=sign, + symmetric=symmetric, + per_channel=per_channel, + batch_init=batch_init, + quant_linear=quant_linear, + channel_num=channel_num, + reduce_type=reduce_type, + dtype=dtype, + name=name) + + def _get_class(self): + return WeightLSQplusQuanterLayer + + +class WeightLSQplusQuanterLayer(BaseFakeQuanterLayer): + def __init__(self, + layer, + quant_bits=8, + sign=True, + symmetric=True, + per_channel=False, + all_postive=False, + batch_init=20, + quant_linear=False, + channel_num=None, + reduce_type=None, + dtype='float32', + name=None): + super(WeightLSQplusQuanterLayer, self).__init__() + + self._per_channel = per_channel + self._quant_linear = quant_linear + self._batch_init = batch_init + self._name = name + self._quant_axis = 1 if quant_linear else 0 + self._collect_axis = 0 if quant_linear else 1 + self._reduce_type = reduce_type + self.div = 2**self._quant_bits - 1 + self.qmin, self.qmax = self.qmin_qmax + + self._current_batch_id = 0 + self._init_state = 0 + scale_prefix = ("{}.scale".format(name) + if name else 'quant_dequant.scale') + self._scale_name = unique_name.generate(scale_prefix) + s_attr = ParamAttr( + name=self._scale_name, initializer=Constant(1.0), trainable=True) + + channel_num = layer.weight.shape[ + self._quant_axis] if self._per_channel else 1 + + self._scale = self.create_parameter( + shape=[channel_num], attr=s_attr, dtype=dtype) + self._scale.stop_gradient = False + + def init_params(self, weight): + self.g = paddle.to_tensor(1.0 / math.sqrt(weight.numel() 
* self.qmax)) + if self._per_channel: + weight_tmp = weight.detach().reshape((weight.shape[0], -1)) + mean = paddle.mean(weight_tmp, axis=self._collect_axis) + std = paddle.std(weight_tmp, axis=self._collect_axis) + s = paddle.max( + paddle.stack( + [paddle.abs(mean - 3 * std), + paddle.abs(mean + 3 * std)]), + axis=0, ) + self._scale.set_value(s / self.div) + else: + mean = paddle.mean(weight.detach()) + std = paddle.std(weight.detach()) + self._scale.set_value( + max([paddle.abs(mean - 3 * std), + paddle.abs(mean + 3 * std)]) / self.div) + self._init_state += 1 + + def collect_gaussian(self, weight): + if self._per_channel: + weight_tmp = weight.detach().reshape((weight.shape[0], -1)) + mean = paddle.mean(weight_tmp, axis=self._collect_axis) + std = paddle.std(weight_tmp, axis=self._collect_axis) + s = paddle.max( + paddle.stack( + [paddle.abs(mean - 3 * std), + paddle.abs(mean + 3 * std)]), + axis=0, ) + self._scale.set_value(s * 0.9 + 0.1 * s / self.div) + else: + mean = paddle.mean(weight.detach()) + std = paddle.std(weight.detach()) + self._scale.set_value(self._scale * 0.9 + 0.1 * max( + [paddle.abs(mean - 3 * std), + paddle.abs(mean + 3 * std)]) / self.div) + self._init_state += 1 + + def forward(self, weight): + if self._reduce_type == "max": + paddle.distributed.all_reduce( + self._scale, op=paddle.distributed.ReduceOp.MAX) + + if self._init_state == 0: + self.init_params(weight) + elif self._init_state < self._batch_init: + self.collect_gaussian(weight) + + weight.stop_gradient = False + w_q = LsqFunc.apply( + weight, + self._scale, + self.g, + self.qmin, + self.qmax, + self._per_channel, + self._quant_axis, ) + return w_q + + def bit_length(self): + """ Return the bit length of quantized data. + """ + return self._quant_bits + + def quant_axis(self): + """ Return quantization axis. + """ + return self._quant_axis + + def scales(self): + """ Return output scales. + """ + return self._scale + + def zero_points(self): + """ Return output zero points. + """ + if self._zero_point is None: + if self._symmetric: + if self._sign: + self._zero_point = 0 + else: + self._zero_point = (self.qmax + self.qmin) / 2 + else: + self._zero_point = self.qmin - round(self.qmin / self._scale) + self._zero_point = paddle.clip(self._zero_point, self.qmin, + self.qmax) + return self._zero_point diff --git a/paddleslim/quant/quanters/pact.py b/paddleslim/quant/quanters/pact.py new file mode 100644 index 0000000000000000000000000000000000000000..adcf80fcb0111ebfaf9551facee3634647023f3f --- /dev/null +++ b/paddleslim/quant/quanters/pact.py @@ -0,0 +1,111 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
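+
+# PACT learns a clipping threshold alpha for the activation and can be
+# combined with any BaseQuanter: the forward pass below clips the input to
+# [-alpha, alpha] via x - relu(x - alpha) + relu(-alpha - x) and then feeds
+# the clipped tensor to the wrapped quanter; bit_length, quant_axis, scales
+# and zero_points are all delegated to that quanter.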
+ +import abc +import paddle +import numpy as np +import math +from paddle.framework import ParamAttr +from paddle.nn import Layer +from paddle.nn.initializer import Constant +from paddle.utils import unique_name +from paddle.quantization.factory import QuanterFactory +from paddle.quantization.base_quanter import BaseQuanter + + +class PACTQuanter(QuanterFactory): + r""" + PArameterized Clipping acTivation(PACT) uses an activation clipping parameter alpha to find the right quantization scale. + More details can be found in + https://arxiv.org/pdf/1805.06085.pdf. + Args: + quanter(BaseQuanter, required): It can be any BaseQuanter. PACT can be used with any other quantization method. + init_value(float, optional): Value of initial alpha. Default 100 + learning_rate(float, optional): The learning rate of alpha when optimizing. + dtype(str): Trainable data type. + name(str): The name of the layer. + Examples: + .. code-block:: python + from paddle.quantization import QuantConfig + from paddle.quantization.quanters import PACTQuanter + from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer + pact_quanter = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserverLayer) + q_config = QuantConfig(activation=pact_quanter, weight=pact_quanter) + """ + + def __init__(self, + quanter, + init_value=100., + learning_rate=1000., + dtype='float32', + name=None): + super(PACTQuanter, self).__init__( + quanter=quanter, + init_value=init_value, + learning_rate=learning_rate, + dtype=dtype, + name=name) + + def _get_class(self): + return PACTQuanterLayer + + +class PACTQuanterLayer(BaseQuanter): + def __init__(self, + layer, + quanter, + init_value=1000, + learning_rate=1000., + dtype='float32', + name=None): + super(PACTQuanterLayer, self).__init__() + + self.quanter = quanter(layer) + alpha_prefix = ("{}.pact".format(name) + if name else 'quant_dequant.pact') + name = unique_name.generate(alpha_prefix) + + alpha_attr = paddle.ParamAttr( + name=name, + initializer=paddle.nn.initializer.Constant(value=init_value), + learning_rate=learning_rate) + + self.alpha = self.create_parameter( + shape=[1], attr=alpha_attr, dtype=dtype) + + def forward(self, activation): + out_left = paddle.nn.functional.relu(activation - self.alpha) + out_right = paddle.nn.functional.relu(-self.alpha - activation) + activation = activation - out_left + out_right + return self.quanter(activation) + + def bit_length(self): + """ Return the bit length of quantized data. + """ + return self.quanter.bit_length() + + def quant_axis(self): + """ Return quantization axis. + """ + return self.quanter.quant_axis() + + def scales(self): + """ Return output scales. + """ + return self.quanter.scales() + + def zero_points(self): + """ Return output zero points. + """ + return self.quanter.zero_points() diff --git a/tests/quantization/test_quanters.py b/tests/quantization/test_quanters.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0bf9f55b2ae3dd7295397c7a0c9a7dca8ca6cc --- /dev/null +++ b/tests/quantization/test_quanters.py @@ -0,0 +1,267 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +import unittest +import paddle +import tempfile +import numpy as np +sys.path.append("../../") + +from paddle.vision.models import resnet18 +from paddle.quantization import QuantConfig +from paddle.quantization import QAT +from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter, PACTQuanter +from paddleslim.quant.quanters.lsq_act import ActLSQplusQuanterLayer +from paddleslim.quant.quanters.lsq_weight import WeightLSQplusQuanterLayer +from paddleslim.quant.quanters.pact import PACTQuanterLayer +from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver +from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer +from paddle.nn.quant.format import LinearDequanter, LinearQuanter + +import logging +from paddleslim.common import get_logger +_logger = get_logger(__name__, level=logging.INFO) + + +class ImperativeLenet(paddle.nn.Layer): + def __init__(self, num_classes=10, classifier_activation='softmax'): + super(ImperativeLenet, self).__init__() + self.features = paddle.nn.Sequential( + paddle.nn.Conv2D( + in_channels=1, + out_channels=6, + kernel_size=3, + stride=1, + padding=1), + paddle.nn.AvgPool2D(kernel_size=2, stride=2), + paddle.nn.Conv2D( + in_channels=6, + out_channels=16, + kernel_size=5, + stride=1, + padding=0), paddle.nn.AvgPool2D(kernel_size=2, stride=2)) + + self.fc = paddle.nn.Sequential( + paddle.nn.Linear(in_features=400, out_features=120), + paddle.nn.Linear(in_features=120, out_features=84), + paddle.nn.Linear(in_features=84, out_features=num_classes), ) + + def forward(self, inputs): + x = self.features(inputs) + + x = paddle.flatten(x, 1) + x = self.fc(x) + return x + + +class TestQATWithQuanters(unittest.TestCase): + def __init__(self, act_observer, act_observer_type, weight_observer, + weight_observer_type, *args, **kvargs): + super(TestQATWithQuanters, self).__init__(*args, **kvargs) + self.act_observer = act_observer + self.act_observer_type = act_observer_type + self.weight_observer = weight_observer + self.weight_observer_type = weight_observer_type + + def setUp(self): + self.init_case() + self.dummy_input = paddle.rand([1, 3, 224, 224]) + self.temp_dir = tempfile.TemporaryDirectory(dir="./") + self.path = os.path.join(self.temp_dir.name, 'qat') + if not os.path.exists('ILSVRC2012_data_demo'): + os.system( + 'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz' + ) + os.system('tar -xf ILSVRC2012_data_demo.tar.gz') + seed = 1 + np.random.seed(seed) + paddle.static.default_main_program().random_seed = seed + paddle.static.default_startup_program().random_seed = seed + + def tearDown(self): + self.temp_dir.cleanup() + + def runTest(self): + self.test_quantize() + self.test_convert() + self.test_convergence() + + def init_case(self): + self.q_config = QuantConfig(activation=None, weight=None) + self.q_config.add_type_config( + paddle.nn.Conv2D, + activation=self.act_observer, + weight=self.weight_observer) + + def _count_layers(self, model, layer_type): + count = 0 + for _layer in model.sublayers(True): + if isinstance(_layer, layer_type): + count += 1 + 
return count + + def test_quantize(self): + model = resnet18() + conv_count = self._count_layers(model, paddle.nn.Conv2D) + qat = QAT(self.q_config) + model.train() + quant_model = qat.quantize(model, inplace=False) + out = quant_model(self.dummy_input) + quantizer_cnt = self._count_layers(quant_model, self.act_observer_type) + self.assertEqual(quantizer_cnt, conv_count) + quantizer_cnt = self._count_layers(quant_model, + self.weight_observer_type) + self.assertEqual(quantizer_cnt, conv_count) + + def test_convergence(self): + model = ImperativeLenet() + conv_count = self._count_layers(model, paddle.nn.Conv2D) + qat = QAT(self.q_config) + model.train() + quant_model = qat.quantize(model, inplace=False) + place = paddle.CUDAPlace(0) \ + if paddle.is_compiled_with_cuda() else paddle.CPUPlace() + + transform = paddle.vision.transforms.Compose([ + paddle.vision.transforms.Transpose(), + paddle.vision.transforms.Normalize([127.5], [127.5]) + ]) + + train_dataset = paddle.vision.datasets.MNIST( + mode='train', backend='cv2', transform=transform) + val_dataset = paddle.vision.datasets.MNIST( + mode='test', backend='cv2', transform=transform) + + train_reader = paddle.io.DataLoader( + train_dataset, + drop_last=True, + places=place, + batch_size=64, + return_list=True) + test_reader = paddle.io.DataLoader( + val_dataset, places=place, batch_size=64, return_list=True) + + def train(model): + adam = paddle.optimizer.Adam( + learning_rate=0.0001, parameters=model.parameters()) + epoch_num = 1 + for epoch in range(epoch_num): + model.train() + for batch_id, data in enumerate(train_reader): + img = paddle.to_tensor(data[0]) + label = paddle.to_tensor(data[1]) + img = paddle.reshape(img, [-1, 1, 28, 28]) + label = paddle.reshape(label, [-1, 1]) + + out = model(img) + acc = paddle.metric.accuracy(out, label) + loss = paddle.nn.functional.loss.cross_entropy(out, label) + avg_loss = paddle.mean(loss) + avg_loss.backward() + adam.minimize(avg_loss) + model.clear_gradients() + if batch_id % 100 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}". 
+ format(epoch, batch_id, + avg_loss.numpy(), acc.numpy())) + + def test(model): + model.eval() + avg_acc = [[], []] + for batch_id, data in enumerate(test_reader): + img = paddle.to_tensor(data[0]) + img = paddle.reshape(img, [-1, 1, 28, 28]) + label = paddle.to_tensor(data[1]) + label = paddle.reshape(label, [-1, 1]) + + out = model(img) + acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1) + acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5) + avg_acc[0].append(acc_top1.numpy()) + avg_acc[1].append(acc_top5.numpy()) + if batch_id % 100 == 0: + _logger.info( + "Test | step {}: acc1 = {:}, acc5 = {:}".format( + batch_id, acc_top1.numpy(), acc_top5.numpy())) + + _logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format( + np.mean(avg_acc[0]), np.mean(avg_acc[1]))) + return np.mean(avg_acc[0]), np.mean(avg_acc[1]) + + train(model) + top1_1, top5_1 = test(model) + + quant_model.train() + train(quant_model) + top1_2, top5_2 = test(quant_model) + + _logger.info( + "Before quantization: top1: {}, top5: {}".format(top1_1, top5_1)) + _logger.info( + "After quantization: top1: {}, top5: {}".format(top1_2, top5_2)) + _logger.info("\n") + + diff = 0.01 + self.assertTrue( + top1_1 - top1_2 < diff, + msg="The acc of quant model is too lower than fp32 model") + _logger.info('done') + return + + def test_convert(self): + model = resnet18() + conv_count = self._count_layers(model, paddle.nn.Conv2D) + qat = QAT(self.q_config) + model.train() + quant_model = qat.quantize(model, inplace=False) + out = quant_model(self.dummy_input) + converted_model = qat.convert(quant_model, inplace=False) + + # check count of LinearQuanter and LinearDequanter in dygraph + quantizer_count_in_dygraph = self._count_layers(converted_model, + LinearQuanter) + dequantizer_count_in_dygraph = self._count_layers( + converted_model, LinearDequanter) + self.assertEqual(quantizer_count_in_dygraph, conv_count) + self.assertEqual(dequantizer_count_in_dygraph, conv_count * 2) + + +observer_suite = unittest.TestSuite() +observer_suite.addTest( + TestQATWithQuanters( + act_observer=ActLSQplusQuanter(), + act_observer_type=ActLSQplusQuanterLayer, + weight_observer=WeightLSQplusQuanter(), + weight_observer_type=WeightLSQplusQuanterLayer)) +observer_suite.addTest( + TestQATWithQuanters( + act_observer=ActLSQplusQuanter(symmetric=False), + act_observer_type=ActLSQplusQuanterLayer, + weight_observer=WeightLSQplusQuanter(per_channel=True), + weight_observer_type=WeightLSQplusQuanterLayer)) +observer_suite.addTest( + TestQATWithQuanters( + act_observer=PACTQuanter(quanter=ActLSQplusQuanterLayer), + act_observer_type=PACTQuanterLayer, + weight_observer=WeightLSQplusQuanter(), + weight_observer_type=WeightLSQplusQuanterLayer)) + +if __name__ == '__main__': + runner = unittest.TextTestRunner(verbosity=2) + runner.run(observer_suite) + os.system('rm -rf ILSVRC2012_data_demo.tar.gz') + os.system('rm -rf ILSVRC2012_data_demo')
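+
+# Minimal usage sketch (mirrors the quanter docstrings above): the factories
+# added in this PR plug into the standard QuantConfig/QAT workflow, e.g.
+#
+#   from paddle.quantization import QAT, QuantConfig
+#   from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter
+#
+#   q_config = QuantConfig(activation=ActLSQplusQuanter(), weight=WeightLSQplusQuanter())
+#   quant_model = QAT(q_config).quantize(model, inplace=False)
+#
+# where model is any paddle.nn.Layer, e.g. paddle.vision.models.resnet18().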