Unverified commit 7dce5a2e, authored by Chang Xu and committed by GitHub

Add Quanters (#1686)

Parent f54331a6
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .lsq_act import ActLSQplusQuanter
from .lsq_weight import WeightLSQplusQuanter
from .pact import PACTQuanter
__all__ = ["ActLSQplusQuanter", "WeightLSQplusQuanter", "PACTQuanter"]
\ No newline at end of file
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import paddle
import numpy as np
from paddle.quantization.base_quanter import BaseQuanter
class BaseFakeQuanterLayer(BaseQuanter):
def __init__(
self,
quant_bits=8,
sign=True,
symmetric=True, ):
super(BaseFakeQuanterLayer, self).__init__()
self._quant_bits = quant_bits
self._sign = sign
self._symmetric = symmetric
self._min = None
self._max = None
self._qmin = None
self._qmax = None
self._scale = None
self._zero_point = None
@property
def qmin_qmax(self):
""" Get the range of the integer."""
if self._qmin is not None and self._qmax is not None:
return self.qmin, self.qmax
if self._sign:
self.qmin = -2**(self.bit_length() - 1)
self.qmax = 2**(self.bit_length() - 1) - 1
else:
self.qmin = 0
self.qmax = 2**self.bit_length()
return self.qmin, self.qmax
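# Illustrative note (added for clarity, not part of the original file): with the
# defaults quant_bits=8 and sign=True the property above yields the signed range
# [-2**7, 2**7 - 1] = [-128, 127]; with sign=False it yields [0, 2**8 - 1] = [0, 255].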
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import math
from paddle.framework import ParamAttr
from paddle.nn import Layer
from paddle.nn.initializer import Constant
from paddle.utils import unique_name
from paddle.quantization.factory import QuanterFactory
from .base_fake_quanter import BaseFakeQuanterLayer
from .lsq_func import LsqFunc, LsqPlusActFunc, round
class ActLSQplusQuanter(QuanterFactory):
r"""
Activation quantizer. More details can be found in
https://arxiv.org/pdf/1902.08153.pdf and https://arxiv.org/pdf/2004.09576.pdf.
Args:
quant_bits(int): The number of bits used for quantization. Default: 8.
sign(bool): Whether the quantized integer range is signed. Default: True.
symmetric(bool): Whether to use symmetric quantization; if False, a trainable offset (beta) is learned as well. Default: True.
per_channel(bool): Whether to quantize channel-wise (True) or layer-wise (False). Default: False.
batch_init(int): Number of initial batches used to calibrate the scale (and offset) from activation statistics. Default: 20.
quant_linear(bool): Whether the quantized layer is a Linear layer.
reduce_type(str): The reduce op applied to the quantization parameters during distributed (parallel) training.
dtype(str): Data type of the trainable parameters.
name(str): The name of the layer.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter
weight_quanter = WeightLSQplusQuanter()
act_quanter = ActLSQplusQuanter()
q_config = QuantConfig(activation=act_quanter, weight=weight_quanter)
"""
def __init__(self,
quant_bits=8,
sign=True,
symmetric=True,
per_channel=False,
batch_init=20,
quant_linear=False,
reduce_type=None,
dtype='float32',
name=None):
super(ActLSQplusQuanter, self).__init__(
quant_bits=quant_bits,
sign=sign,
symmetric=symmetric,
per_channel=per_channel,
batch_init=batch_init,
quant_linear=quant_linear,
reduce_type=reduce_type,
dtype=dtype,
name=name)
def _get_class(self):
return ActLSQplusQuanterLayer
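# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): the asymmetric fake
# quant-dequant that ActLSQplusQuanterLayer performs through LsqPlusActFunc,
# written out with plain Paddle ops. The helper name below is hypothetical
# and only makes the math explicit:
#     x_q   = clip(round((x - beta) / s), qmin, qmax)
#     x_hat = x_q * s + beta
# Example: _fake_quant_dequant_asym(paddle.rand([4]), paddle.to_tensor(0.1),
#                                   paddle.to_tensor(0.0), -128, 127)
# ---------------------------------------------------------------------------
def _fake_quant_dequant_asym(x, scale, beta, qmin, qmax):
    """Reference (non-trainable) version of the LSQ+ activation fake quant."""
    x_q = paddle.clip(round(paddle.divide(x - beta, scale)), qmin, qmax)
    return x_q * scale + beta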
class ActLSQplusQuanterLayer(BaseFakeQuanterLayer):
def __init__(self,
layer,
quant_bits=8,
sign=True,
symmetric=True,
per_channel=False,
batch_init=20,
quant_linear=False,
reduce_type=None,
dtype='float32',
name=None):
super(ActLSQplusQuanterLayer, self).__init__(quant_bits=quant_bits, sign=sign, symmetric=symmetric)
self._symmetric = symmetric
self._per_channel = per_channel
self._quant_linear = quant_linear
self._batch_init = batch_init
self._name = name
self._quant_axis = 1 if quant_linear else 0
self._collect_axis = 0 if quant_linear else 1
self._reduce_type = reduce_type
self.div = 2**self._quant_bits - 1
self.qmin, self.qmax = self.qmin_qmax
self._current_batch_id = 0
self._init_state = 0
scale_prefix = ("{}.scale".format(name)
if name else 'quant_dequant.scale')
self._scale_name = unique_name.generate(scale_prefix)
s_attr = ParamAttr(
name=self._scale_name, initializer=Constant(1.0), trainable=True)
self._scale = self.create_parameter(shape=[1], attr=s_attr, dtype=dtype)
self._scale.stop_gradient = False
if not self._symmetric:
beta_prefix = ("{}.beta".format(name)
if name else 'quant_dequant.beta')
self._beta_name = unique_name.generate(beta_prefix)
beta_attr = ParamAttr(
name=self._beta_name, initializer=Constant(0.0), trainable=True)
self._beta = self.create_parameter(
shape=[1], attr=beta_attr, dtype='float32')
self._beta.stop_gradient = False
def init_params(self, activation):
self.g = paddle.to_tensor(
1.0 / math.sqrt(activation.numel() * self.qmax))
min_a = paddle.min(activation.detach())
max_a = paddle.max(activation.detach())
self._scale.set_value((max_a - min_a) / (self.qmax - self.qmin))
if not self._symmetric:
self._beta.set_value(min_a - self._scale * self.qmin)
self._init_state += 1
def collect_gaussian(self, activation):
min_a = paddle.min(activation.detach())
max_a = paddle.max(activation.detach())
self._scale.set_value(self._scale * 0.9 + 0.1 * (max_a - min_a) /
(self.qmax - self.qmin))
if not self._symmetric:
self._beta.set_value(self._beta * 0.9 + 0.1 *
(min_a - self._scale * self.qmin))
self._init_state += 1
def forward(self, activation):
if self._reduce_type == "max":
paddle.distributed.all_reduce(
self._scale, op=paddle.distributed.ReduceOp.MAX)
if not self._symmetric and self._reduce_type == "max":
paddle.distributed.all_reduce(
self._beta, op=paddle.distributed.ReduceOp.MAX)
if self._init_state == 0:
self.init_params(activation)
elif self._init_state < self._batch_init:
self.collect_gaussian(activation)
activation.stop_gradient = False
if not self._symmetric:
q_a = LsqPlusActFunc.apply(activation, self._scale, self._beta,
self.g, self.qmin, self.qmax)
else:
q_a = LsqFunc.apply(
activation,
self._scale,
self.g,
self.qmin,
self.qmax,
per_channel=False)
return q_a
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return self._quant_axis
def scales(self):
""" Return output scales.
"""
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
if self._symmetric:
if self._sign:
self._zero_point = 0
else:
self._zero_point = (self.qmax + self.qmin) / 2
else:
self._zero_point = self.qmin - round(self.qmin / self._scale)
self._zero_point = paddle.clip(self._zero_point, self.qmin,
self.qmax)
return self._zero_point
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle
from paddle.autograd import PyLayer
def round(x):
sign = paddle.sign(x)
x = sign * paddle.floor(paddle.abs(x) + 0.5)
return x
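# Illustrative note (added for clarity, not part of the original file): this
# helper rounds half away from zero, e.g. round(paddle.to_tensor([0.5, -0.5, 2.5, -2.5]))
# returns [1., -1., 3., -3.]; keeping a custom helper makes the tie-breaking rule
# explicit instead of relying on the framework default.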
class LsqFunc(PyLayer):
@staticmethod
def forward(ctx, weight, alpha, g, Qn, Qp, per_channel=False, quant_axis=0):
ctx.save_for_backward(weight, alpha)
ctx.other = g, Qn, Qp, per_channel, quant_axis
if per_channel:
sizes = weight.shape
weight = weight.reshape((weight.shape[quant_axis], -1))
weight = weight.transpose((1, 0))
alpha = paddle.broadcast_to(alpha, weight.shape)
quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp)
quant_w = quant_w * alpha
quant_w = quant_w.transpose((1, 0))
quant_w = quant_w.reshape(sizes)
else:
quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp)
quant_w = quant_w * alpha
return quant_w
@staticmethod
def backward(ctx, grad_weight):
weight, alpha = ctx.saved_tensor()
g, Qn, Qp, per_channel, quant_axis = ctx.other
if per_channel:
sizes = weight.shape
weight = weight.reshape((weight.shape[quant_axis], -1))
weight = weight.transpose((1, 0))
alpha = paddle.broadcast_to(alpha, weight.shape)
q_w = paddle.divide(weight, alpha)
q_w = q_w.transpose((1, 0))
q_w = q_w.reshape(sizes)
else:
q_w = paddle.divide(weight, alpha)
lower_flag = paddle.cast((q_w < Qn), 'float32')
upper_flag = paddle.cast((q_w > Qp), 'float32')
middle_flag = 1.0 - lower_flag - upper_flag
if per_channel:
grad_alpha = (
(lower_flag * Qn + upper_flag * Qp + middle_flag * round(q_w) -
middle_flag * q_w) * grad_weight * g)
grad_alpha = grad_alpha.reshape((grad_alpha.shape[quant_axis],
-1)).sum(axis=1)
else:
grad_alpha = ((
(lower_flag * Qn + upper_flag * Qp + middle_flag * round(q_w)
- middle_flag * q_w) * grad_weight * g).sum().unsqueeze(
axis=0)[0])
grad_weight = middle_flag * grad_weight
return grad_weight, grad_alpha
class LsqPlusActFunc(PyLayer):
@staticmethod
def forward(ctx, x, alpha, beta, g, Qn, Qp):
ctx.save_for_backward(x, alpha, beta)
ctx.other = g, Qn, Qp
quant_x = round(paddle.divide((x - beta), alpha)).clip(Qn, Qp)
return quant_x * alpha + beta
@staticmethod
def backward(ctx, grad_x):
x, alpha, beta = ctx.saved_tensor()
g, Qn, Qp = ctx.other
q_x = (x - beta) / alpha
lower_flag = paddle.cast((q_x < Qn), 'float32')
upper_flag = paddle.cast((q_x > Qp), 'float32')
middle_flag = 1.0 - lower_flag - upper_flag
grad_alpha = ((
(lower_flag * Qn + upper_flag * Qp + middle_flag * round(q_x) -
middle_flag * q_x) * grad_x * g).sum().unsqueeze(axis=0)[0])
grad_beta = (((lower_flag + upper_flag) * grad_x * g).sum().unsqueeze(
axis=0)[0])
grad_x = middle_flag * grad_x
return grad_x, grad_alpha, grad_beta
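# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the library): quantize a
# tensor with LsqFunc and back-propagate to the learnable scale. The shapes,
# the 8-bit signed range and the initial scale below are assumptions made for
# this demo only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    w = paddle.rand([16, 8])
    w.stop_gradient = False
    alpha = paddle.create_parameter(
        shape=[1],
        dtype='float32',
        default_initializer=paddle.nn.initializer.Constant(0.05))
    qn, qp = -128, 127
    g = 1.0 / math.sqrt(float(w.numel()) * qp)  # gradient scale from the LSQ paper
    w_q = LsqFunc.apply(w, alpha, g, qn, qp, per_channel=False)
    w_q.sum().backward()
    print(w_q.shape, float(alpha.grad))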
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import paddle
import numpy as np
import math
from paddle.framework import ParamAttr
from paddle.nn import Layer
from paddle.nn.initializer import Constant
from paddle.utils import unique_name
from paddle.quantization.factory import QuanterFactory
from .lsq_func import LsqFunc, round
from .base_fake_quanter import BaseFakeQuanterLayer
class WeightLSQplusQuanter(QuanterFactory):
r"""
Weight quantizer. More details can be found in
https://arxiv.org/pdf/1902.08153.pdf and https://arxiv.org/pdf/2004.09576.pdf.
Args:
quant_bits(int): The number of bits used for quantization. Default: 8.
sign(bool): Whether the quantized integer range is signed. Default: True.
symmetric(bool): Whether to use symmetric quantization. Default: True.
per_channel(bool): Whether to quantize channel-wise (True) or layer-wise (False). Default: False.
batch_init(int): Number of initial batches used to fit a Gaussian approximation of the weight distribution in each layer. Default: 20.
quant_linear(bool): Whether the quantized weight belongs to a Linear layer.
reduce_type(str): The reduce op applied to the quantization parameters during distributed (parallel) training.
dtype(str): Data type of the trainable parameters.
name(str): The name of the layer.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter
weight_quanter = WeightLSQplusQuanter()
act_quanter = ActLSQplusQuanter()
q_config = QuantConfig(activation=act_quanter, weight=weight_quanter)
"""
def __init__(self,
quant_bits=8,
sign=True,
symmetric=True,
per_channel=False,
batch_init=20,
quant_linear=False,
channel_num=None,
reduce_type=None,
dtype='float32',
name=None):
super(WeightLSQplusQuanter, self).__init__(
quant_bits=quant_bits,
sign=sign,
symmetric=symmetric,
per_channel=per_channel,
batch_init=batch_init,
quant_linear=quant_linear,
channel_num=channel_num,
reduce_type=reduce_type,
dtype=dtype,
name=name)
def _get_class(self):
return WeightLSQplusQuanterLayer
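# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): the per-channel scale
# initialization used by WeightLSQplusQuanterLayer.init_params below, i.e.
#     s_c = max(|mean_c - 3*std_c|, |mean_c + 3*std_c|) / (2**bits - 1)
# for every output channel c. The helper name is hypothetical and uses
# paddle.maximum instead of stack+max, which is equivalent element-wise.
# ---------------------------------------------------------------------------
def _init_per_channel_scale(weight, quant_bits=8):
    """Reference computation of the 3-sigma per-channel scale."""
    w = weight.detach().reshape((weight.shape[0], -1))
    mean = paddle.mean(w, axis=1)
    std = paddle.std(w, axis=1)
    bound = paddle.maximum(paddle.abs(mean - 3 * std), paddle.abs(mean + 3 * std))
    return bound / (2**quant_bits - 1)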
class WeightLSQplusQuanterLayer(BaseFakeQuanterLayer):
def __init__(self,
layer,
quant_bits=8,
sign=True,
symmetric=True,
per_channel=False,
all_positive=False,
batch_init=20,
quant_linear=False,
channel_num=None,
reduce_type=None,
dtype='float32',
name=None):
super(WeightLSQplusQuanterLayer, self).__init__(quant_bits=quant_bits, sign=sign, symmetric=symmetric)
self._per_channel = per_channel
self._quant_linear = quant_linear
self._batch_init = batch_init
self._name = name
self._quant_axis = 1 if quant_linear else 0
self._collect_axis = 0 if quant_linear else 1
self._reduce_type = reduce_type
self.div = 2**self._quant_bits - 1
self.qmin, self.qmax = self.qmin_qmax
self._current_batch_id = 0
self._init_state = 0
scale_prefix = ("{}.scale".format(name)
if name else 'quant_dequant.scale')
self._scale_name = unique_name.generate(scale_prefix)
s_attr = ParamAttr(
name=self._scale_name, initializer=Constant(1.0), trainable=True)
channel_num = layer.weight.shape[
self._quant_axis] if self._per_channel else 1
self._scale = self.create_parameter(
shape=[channel_num], attr=s_attr, dtype=dtype)
self._scale.stop_gradient = False
def init_params(self, weight):
self.g = paddle.to_tensor(1.0 / math.sqrt(weight.numel() * self.qmax))
if self._per_channel:
weight_tmp = weight.detach().reshape((weight.shape[0], -1))
mean = paddle.mean(weight_tmp, axis=self._collect_axis)
std = paddle.std(weight_tmp, axis=self._collect_axis)
s = paddle.max(
paddle.stack(
[paddle.abs(mean - 3 * std),
paddle.abs(mean + 3 * std)]),
axis=0, )
self._scale.set_value(s / self.div)
else:
mean = paddle.mean(weight.detach())
std = paddle.std(weight.detach())
self._scale.set_value(
max([paddle.abs(mean - 3 * std),
paddle.abs(mean + 3 * std)]) / self.div)
self._init_state += 1
def collect_gaussian(self, weight):
if self._per_channel:
weight_tmp = weight.detach().reshape((weight.shape[0], -1))
mean = paddle.mean(weight_tmp, axis=self._collect_axis)
std = paddle.std(weight_tmp, axis=self._collect_axis)
s = paddle.max(
paddle.stack(
[paddle.abs(mean - 3 * std),
paddle.abs(mean + 3 * std)]),
axis=0, )
self._scale.set_value(self._scale * 0.9 + 0.1 * s / self.div)
else:
mean = paddle.mean(weight.detach())
std = paddle.std(weight.detach())
self._scale.set_value(self._scale * 0.9 + 0.1 * max(
[paddle.abs(mean - 3 * std),
paddle.abs(mean + 3 * std)]) / self.div)
self._init_state += 1
def forward(self, weight):
if self._reduce_type == "max":
paddle.distributed.all_reduce(
self._scale, op=paddle.distributed.ReduceOp.MAX)
if self._init_state == 0:
self.init_params(weight)
elif self._init_state < self._batch_init:
self.collect_gaussian(weight)
weight.stop_gradient = False
w_q = LsqFunc.apply(
weight,
self._scale,
self.g,
self.qmin,
self.qmax,
self._per_channel,
self._quant_axis, )
return w_q
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self._quant_bits
def quant_axis(self):
""" Return quantization axis.
"""
return self._quant_axis
def scales(self):
""" Return output scales.
"""
return self._scale
def zero_points(self):
""" Return output zero points.
"""
if self._zero_point is None:
if self._symmetric:
if self._sign:
self._zero_point = 0
else:
self._zero_point = (self.qmax + self.qmin) / 2
else:
self._zero_point = self.qmin - round(self.qmin / self._scale)
self._zero_point = paddle.clip(self._zero_point, self.qmin,
self.qmax)
return self._zero_point
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import paddle
import numpy as np
import math
from paddle.framework import ParamAttr
from paddle.nn import Layer
from paddle.nn.initializer import Constant
from paddle.utils import unique_name
from paddle.quantization.factory import QuanterFactory
from paddle.quantization.base_quanter import BaseQuanter
class PACTQuanter(QuanterFactory):
r"""
PArameterized Clipping acTivation (PACT) learns an activation clipping parameter alpha that is used to determine a suitable quantization scale.
More details can be found in
https://arxiv.org/pdf/1805.06085.pdf.
Args:
quanter(BaseQuanter, required): The quanter applied to the clipped activation; PACT can be combined with any other quantization method.
init_value(float, optional): Initial value of alpha. Default: 100.
learning_rate(float, optional): Learning-rate multiplier applied to alpha during training. Default: 1000.
dtype(str): Data type of the trainable parameter alpha.
name(str): The name of the layer.
Examples:
.. code-block:: python
from paddle.quantization import QuantConfig
from paddleslim.quant.quanters import PACTQuanter
from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer
pact_quanter = PACTQuanter(quanter=FakeQuanterWithAbsMaxObserverLayer)
q_config = QuantConfig(activation=pact_quanter, weight=pact_quanter)
"""
def __init__(self,
quanter,
init_value=100.,
learning_rate=1000.,
dtype='float32',
name=None):
super(PACTQuanter, self).__init__(
quanter=quanter,
init_value=init_value,
learning_rate=learning_rate,
dtype=dtype,
name=name)
def _get_class(self):
return PACTQuanterLayer
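# ---------------------------------------------------------------------------
# Illustrative note (not part of the library API): PACTQuanterLayer.forward
# below clips activations to [-alpha, alpha] with two ReLUs,
#     x - relu(x - alpha) + relu(-alpha - x) == clip(x, -alpha, alpha),
# and then passes the clipped tensor to the wrapped quanter. The helper name
# is hypothetical and only spells out that identity.
# ---------------------------------------------------------------------------
def _pact_clip(x, alpha):
    """Reference version of the PACT clipping trick."""
    return (x - paddle.nn.functional.relu(x - alpha) +
            paddle.nn.functional.relu(-alpha - x))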
class PACTQuanterLayer(BaseQuanter):
def __init__(self,
layer,
quanter,
init_value=1000,
learning_rate=1000.,
dtype='float32',
name=None):
super(PACTQuanterLayer, self).__init__()
self.quanter = quanter(layer)
alpha_prefix = ("{}.pact".format(name)
if name else 'quant_dequant.pact')
name = unique_name.generate(alpha_prefix)
alpha_attr = paddle.ParamAttr(
name=name,
initializer=paddle.nn.initializer.Constant(value=init_value),
learning_rate=learning_rate)
self.alpha = self.create_parameter(
shape=[1], attr=alpha_attr, dtype=dtype)
def forward(self, activation):
out_left = paddle.nn.functional.relu(activation - self.alpha)
out_right = paddle.nn.functional.relu(-self.alpha - activation)
activation = activation - out_left + out_right
return self.quanter(activation)
def bit_length(self):
""" Return the bit length of quantized data.
"""
return self.quanter.bit_length()
def quant_axis(self):
""" Return quantization axis.
"""
return self.quanter.quant_axis()
def scales(self):
""" Return output scales.
"""
return self.quanter.scales()
def zero_points(self):
""" Return output zero points.
"""
return self.quanter.zero_points()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
import unittest
import paddle
import tempfile
import numpy as np
sys.path.append("../../")
from paddle.vision.models import resnet18
from paddle.quantization import QuantConfig
from paddle.quantization import QAT
from paddleslim.quant.quanters import ActLSQplusQuanter, WeightLSQplusQuanter, PACTQuanter
from paddleslim.quant.quanters.lsq_act import ActLSQplusQuanterLayer
from paddleslim.quant.quanters.lsq_weight import WeightLSQplusQuanterLayer
from paddleslim.quant.quanters.pact import PACTQuanterLayer
from paddle.quantization.quanters import FakeQuanterWithAbsMaxObserver
from paddle.quantization.quanters.abs_max import FakeQuanterWithAbsMaxObserverLayer
from paddle.nn.quant.format import LinearDequanter, LinearQuanter
import logging
from paddleslim.common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
class ImperativeLenet(paddle.nn.Layer):
def __init__(self, num_classes=10, classifier_activation='softmax'):
super(ImperativeLenet, self).__init__()
self.features = paddle.nn.Sequential(
paddle.nn.Conv2D(
in_channels=1,
out_channels=6,
kernel_size=3,
stride=1,
padding=1),
paddle.nn.AvgPool2D(kernel_size=2, stride=2),
paddle.nn.Conv2D(
in_channels=6,
out_channels=16,
kernel_size=5,
stride=1,
padding=0), paddle.nn.AvgPool2D(kernel_size=2, stride=2))
self.fc = paddle.nn.Sequential(
paddle.nn.Linear(in_features=400, out_features=120),
paddle.nn.Linear(in_features=120, out_features=84),
paddle.nn.Linear(in_features=84, out_features=num_classes), )
def forward(self, inputs):
x = self.features(inputs)
x = paddle.flatten(x, 1)
x = self.fc(x)
return x
class TestQATWithQuanters(unittest.TestCase):
def __init__(self, act_observer, act_observer_type, weight_observer,
weight_observer_type, *args, **kwargs):
super(TestQATWithQuanters, self).__init__(*args, **kwargs)
self.act_observer = act_observer
self.act_observer_type = act_observer_type
self.weight_observer = weight_observer
self.weight_observer_type = weight_observer_type
def setUp(self):
self.init_case()
self.dummy_input = paddle.rand([1, 3, 224, 224])
self.temp_dir = tempfile.TemporaryDirectory(dir="./")
self.path = os.path.join(self.temp_dir.name, 'qat')
if not os.path.exists('ILSVRC2012_data_demo'):
os.system(
'wget -q https://sys-p0.bj.bcebos.com/slim_ci/ILSVRC2012_data_demo.tar.gz'
)
os.system('tar -xf ILSVRC2012_data_demo.tar.gz')
seed = 1
np.random.seed(seed)
paddle.static.default_main_program().random_seed = seed
paddle.static.default_startup_program().random_seed = seed
def tearDown(self):
self.temp_dir.cleanup()
def runTest(self):
self.test_quantize()
self.test_convert()
self.test_convergence()
def init_case(self):
self.q_config = QuantConfig(activation=None, weight=None)
self.q_config.add_type_config(
paddle.nn.Conv2D,
activation=self.act_observer,
weight=self.weight_observer)
def _count_layers(self, model, layer_type):
count = 0
for _layer in model.sublayers(True):
if isinstance(_layer, layer_type):
count += 1
return count
def test_quantize(self):
model = resnet18()
conv_count = self._count_layers(model, paddle.nn.Conv2D)
qat = QAT(self.q_config)
model.train()
quant_model = qat.quantize(model, inplace=False)
out = quant_model(self.dummy_input)
quantizer_cnt = self._count_layers(quant_model, self.act_observer_type)
self.assertEqual(quantizer_cnt, conv_count)
quantizer_cnt = self._count_layers(quant_model,
self.weight_observer_type)
self.assertEqual(quantizer_cnt, conv_count)
def test_convergence(self):
model = ImperativeLenet()
conv_count = self._count_layers(model, paddle.nn.Conv2D)
qat = QAT(self.q_config)
model.train()
quant_model = qat.quantize(model, inplace=False)
place = paddle.CUDAPlace(0) \
if paddle.is_compiled_with_cuda() else paddle.CPUPlace()
transform = paddle.vision.transforms.Compose([
paddle.vision.transforms.Transpose(),
paddle.vision.transforms.Normalize([127.5], [127.5])
])
train_dataset = paddle.vision.datasets.MNIST(
mode='train', backend='cv2', transform=transform)
val_dataset = paddle.vision.datasets.MNIST(
mode='test', backend='cv2', transform=transform)
train_reader = paddle.io.DataLoader(
train_dataset,
drop_last=True,
places=place,
batch_size=64,
return_list=True)
test_reader = paddle.io.DataLoader(
val_dataset, places=place, batch_size=64, return_list=True)
def train(model):
adam = paddle.optimizer.Adam(
learning_rate=0.0001, parameters=model.parameters())
epoch_num = 1
for epoch in range(epoch_num):
model.train()
for batch_id, data in enumerate(train_reader):
img = paddle.to_tensor(data[0])
label = paddle.to_tensor(data[1])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc = paddle.metric.accuracy(out, label)
loss = paddle.nn.functional.loss.cross_entropy(out, label)
avg_loss = paddle.mean(loss)
avg_loss.backward()
adam.minimize(avg_loss)
model.clear_gradients()
if batch_id % 100 == 0:
_logger.info(
"Train | At epoch {} step {}: loss = {:}, acc= {:}".
format(epoch, batch_id,
avg_loss.numpy(), acc.numpy()))
def test(model):
model.eval()
avg_acc = [[], []]
for batch_id, data in enumerate(test_reader):
img = paddle.to_tensor(data[0])
img = paddle.reshape(img, [-1, 1, 28, 28])
label = paddle.to_tensor(data[1])
label = paddle.reshape(label, [-1, 1])
out = model(img)
acc_top1 = paddle.metric.accuracy(input=out, label=label, k=1)
acc_top5 = paddle.metric.accuracy(input=out, label=label, k=5)
avg_acc[0].append(acc_top1.numpy())
avg_acc[1].append(acc_top5.numpy())
if batch_id % 100 == 0:
_logger.info(
"Test | step {}: acc1 = {:}, acc5 = {:}".format(
batch_id, acc_top1.numpy(), acc_top5.numpy()))
_logger.info("Test | Average: acc_top1 {}, acc_top5 {}".format(
np.mean(avg_acc[0]), np.mean(avg_acc[1])))
return np.mean(avg_acc[0]), np.mean(avg_acc[1])
train(model)
top1_1, top5_1 = test(model)
quant_model.train()
train(quant_model)
top1_2, top5_2 = test(quant_model)
_logger.info(
"Before quantization: top1: {}, top5: {}".format(top1_1, top5_1))
_logger.info(
"After quantization: top1: {}, top5: {}".format(top1_2, top5_2))
_logger.info("\n")
diff = 0.01
self.assertTrue(
top1_1 - top1_2 < diff,
msg="The acc of quant model is too lower than fp32 model")
_logger.info('done')
return
def test_convert(self):
model = resnet18()
conv_count = self._count_layers(model, paddle.nn.Conv2D)
qat = QAT(self.q_config)
model.train()
quant_model = qat.quantize(model, inplace=False)
out = quant_model(self.dummy_input)
converted_model = qat.convert(quant_model, inplace=False)
# check count of LinearQuanter and LinearDequanter in dygraph
quantizer_count_in_dygraph = self._count_layers(converted_model,
LinearQuanter)
dequantizer_count_in_dygraph = self._count_layers(
converted_model, LinearDequanter)
self.assertEqual(quantizer_count_in_dygraph, conv_count)
self.assertEqual(dequantizer_count_in_dygraph, conv_count * 2)
observer_suite = unittest.TestSuite()
observer_suite.addTest(
TestQATWithQuanters(
act_observer=ActLSQplusQuanter(),
act_observer_type=ActLSQplusQuanterLayer,
weight_observer=WeightLSQplusQuanter(),
weight_observer_type=WeightLSQplusQuanterLayer))
observer_suite.addTest(
TestQATWithQuanters(
act_observer=ActLSQplusQuanter(symmetric=False),
act_observer_type=ActLSQplusQuanterLayer,
weight_observer=WeightLSQplusQuanter(per_channel=True),
weight_observer_type=WeightLSQplusQuanterLayer))
observer_suite.addTest(
TestQATWithQuanters(
act_observer=PACTQuanter(quanter=ActLSQplusQuanterLayer),
act_observer_type=PACTQuanterLayer,
weight_observer=WeightLSQplusQuanter(),
weight_observer_type=WeightLSQplusQuanterLayer))
if __name__ == '__main__':
runner = unittest.TextTestRunner(verbosity=2)
runner.run(observer_suite)
os.system('rm -rf ILSVRC2012_data_demo.tar.gz')
os.system('rm -rf ILSVRC2012_data_demo')