From 341f68fec1d92271c8dcfbdb3be83e18204a1d45 Mon Sep 17 00:00:00 2001 From: Chang Xu Date: Tue, 27 Sep 2022 11:54:11 +0800 Subject: [PATCH] Add LSQ/LSQ+ in QAT (#45652) --- .../slim/quantization/imperative/qat.py | 8 +- .../fluid/contrib/slim/tests/CMakeLists.txt | 3 + .../slim/tests/test_imperative_qat_lsq.py | 198 +++++++++++ python/paddle/nn/quant/lsq.py | 328 ++++++++++++++++++ python/paddle/nn/quant/quant_layers.py | 25 +- 5 files changed, 557 insertions(+), 5 deletions(-) create mode 100644 python/paddle/fluid/contrib/slim/tests/test_imperative_qat_lsq.py create mode 100644 python/paddle/nn/quant/lsq.py diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py index 5df94f502a4..2f51dfd805d 100644 --- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py +++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py @@ -324,16 +324,18 @@ class ImperativeQuantizeInputs(object): "%s is unspported to be quantized." % layer quantize_type = { - 'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max' + 'abs_max', 'moving_average_abs_max', 'channel_wise_abs_max', + 'lsq_weight', 'channel_wise_lsq_weight' } + act_quantize_type = {'moving_average_abs_max', 'lsq_act'} assert weight_quantize_type != 'moving_average_abs_max' \ and weight_quantize_type in quantize_type, \ "Unsupported weight_quantize_type: %s. It can only " \ "be abs_max or channel_wise_abs_max." % weight_quantize_type # TODO (jc): activation_quantize_type supports range_abs_max - assert activation_quantize_type == 'moving_average_abs_max', \ + assert activation_quantize_type in act_quantize_type, \ "Unsupported activation_quantize_type: %s. It can " \ - "only be moving_average_abs_max now." \ + "only be moving_average_abs_max or lsq_act now." \ % activation_quantize_type bits_check = lambda bits: isinstance(bits, int) \ diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index 86f3e759d42..14f1e7f912c 100755 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -252,6 +252,7 @@ if(WIN32) list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1) list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2) list(REMOVE_ITEM TEST_OPS test_imperative_qat_amp) + list(REMOVE_ITEM TEST_OPS test_imperative_qat_lsq) endif() if(LINUX AND WITH_MKLDNN) @@ -505,6 +506,7 @@ if(WIN32) test_moving_average_abs_max_scale_op test_imperative_qat_channelwise test_imperative_qat + test_imperative_qat_lsq test_imperative_out_scale test_graph) list(REMOVE_ITEM TEST_OPS ${SINGLE_CARD_TEST_OPS}) @@ -544,6 +546,7 @@ set_tests_properties(test_imperative_qat PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_qat_fuse PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_out_scale PROPERTIES TIMEOUT 200) set_tests_properties(test_imperative_qat_user_defined PROPERTIES TIMEOUT 200) +set_tests_properties(test_imperative_qat_lsq PROPERTIES TIMEOUT 300) if(LINUX AND WITH_MKLDNN) set_tests_properties(test_quant2_int8_mobilenetv1_mkldnn PROPERTIES TIMEOUT diff --git a/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_lsq.py b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_lsq.py new file mode 100644 index 00000000000..07600425866 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_imperative_qat_lsq.py @@ -0,0 +1,198 @@ +# copyright (c) 2022 paddlepaddle authors. 
all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +from __future__ import print_function + +import os +import numpy as np +import random +import time +import tempfile +import unittest +import logging + +import paddle +import paddle.fluid as fluid +from paddle.fluid import core +from paddle.fluid.optimizer import SGDOptimizer, AdamOptimizer, MomentumOptimizer +from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware +from paddle.fluid.dygraph.container import Sequential +from paddle.nn import ReLU, ReLU6, LeakyReLU, Sigmoid, Softmax, PReLU +from paddle.nn import Linear, Conv2D, Softmax, BatchNorm2D, MaxPool2D +from paddle.fluid.log_helper import get_logger +from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX +from paddle.nn.quant.quant_layers import QuantizedConv2D, QuantizedConv2DTranspose +from paddle.fluid.framework import _test_eager_guard +from imperative_test_utils import fix_model_dict + +paddle.enable_static() + +os.environ["CPU_NUM"] = "1" +if core.is_compiled_with_cuda(): + fluid.set_flags({"FLAGS_cudnn_deterministic": True}) + +_logger = get_logger(__name__, + logging.INFO, + fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class ImperativeLenet(fluid.dygraph.Layer): + + def __init__(self, num_classes=10): + super(ImperativeLenet, self).__init__() + conv2d_w1_attr = fluid.ParamAttr(name="conv2d_w_1") + conv2d_w2_attr = fluid.ParamAttr(name="conv2d_w_2") + fc_w1_attr = fluid.ParamAttr(name="fc_w_1") + fc_w2_attr = fluid.ParamAttr(name="fc_w_2") + fc_w3_attr = fluid.ParamAttr(name="fc_w_3") + conv2d_b2_attr = fluid.ParamAttr(name="conv2d_b_2") + fc_b1_attr = fluid.ParamAttr(name="fc_b_1") + fc_b2_attr = fluid.ParamAttr(name="fc_b_2") + fc_b3_attr = fluid.ParamAttr(name="fc_b_3") + self.features = Sequential( + Conv2D(in_channels=1, + out_channels=6, + kernel_size=3, + stride=1, + padding=1, + weight_attr=conv2d_w1_attr, + bias_attr=False), BatchNorm2D(6), ReLU(), + MaxPool2D(kernel_size=2, stride=2), + Conv2D(in_channels=6, + out_channels=16, + kernel_size=5, + stride=1, + padding=0, + weight_attr=conv2d_w2_attr, + bias_attr=conv2d_b2_attr), BatchNorm2D(16), PReLU(), + MaxPool2D(kernel_size=2, stride=2)) + + self.fc = Sequential( + Linear(in_features=400, + out_features=120, + weight_attr=fc_w1_attr, + bias_attr=fc_b1_attr), LeakyReLU(), + Linear(in_features=120, + out_features=84, + weight_attr=fc_w2_attr, + bias_attr=fc_b2_attr), Sigmoid(), + Linear(in_features=84, + out_features=num_classes, + weight_attr=fc_w3_attr, + bias_attr=fc_b3_attr), Softmax()) + + def forward(self, inputs): + x = self.features(inputs) + x = fluid.layers.flatten(x, 1) + x = self.fc(x) + return x + + +class TestImperativeQatLSQ(unittest.TestCase): + + def set_vars(self): + self.weight_quantize_type = 'channel_wise_lsq_weight' + self.activation_quantize_type = 'lsq_act' + self.onnx_format = False + self.fuse_conv_bn = False + + def func_qat(self): + self.set_vars() + + imperative_qat = ImperativeQuantAware( + 
weight_quantize_type=self.weight_quantize_type, + activation_quantize_type=self.activation_quantize_type, + fuse_conv_bn=self.fuse_conv_bn) + + seed = 100 + np.random.seed(seed) + fluid.default_main_program().random_seed = seed + fluid.default_startup_program().random_seed = seed + paddle.disable_static() + lenet = ImperativeLenet() + lenet = fix_model_dict(lenet) + imperative_qat.quantize(lenet) + optimizer = MomentumOptimizer(learning_rate=0.1, + parameter_list=lenet.parameters(), + momentum=0.9) + + train_reader = paddle.batch(paddle.dataset.mnist.train(), + batch_size=64, + drop_last=True) + test_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=32) + epoch_num = 2 + for epoch in range(epoch_num): + lenet.train() + for batch_id, data in enumerate(train_reader()): + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array([x[1] for x in data + ]).astype('int64').reshape(-1, 1) + + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + out = lenet(img) + acc = fluid.layers.accuracy(out, label) + loss = fluid.layers.cross_entropy(out, label) + avg_loss = paddle.mean(loss) + + avg_loss.backward() + optimizer.minimize(avg_loss) + lenet.clear_gradients() + + if batch_id % 100 == 0: + _logger.info( + "Train | At epoch {} step {}: loss = {:}, acc= {:}". + format(epoch, batch_id, avg_loss.numpy(), acc.numpy())) + + lenet.eval() + eval_acc_top1_list = [] + with paddle.no_grad(): + for batch_id, data in enumerate(test_reader()): + + x_data = np.array([x[0].reshape(1, 28, 28) + for x in data]).astype('float32') + y_data = np.array([x[1] for x in data + ]).astype('int64').reshape(-1, 1) + img = fluid.dygraph.to_variable(x_data) + label = fluid.dygraph.to_variable(y_data) + + out = lenet(img) + acc_top1 = fluid.layers.accuracy(input=out, + label=label, + k=1) + acc_top5 = fluid.layers.accuracy(input=out, + label=label, + k=5) + + if batch_id % 100 == 0: + eval_acc_top1_list.append(float(acc_top1.numpy())) + _logger.info( + "Test | At epoch {} step {}: acc1 = {:}, acc5 = {:}" + .format(epoch, batch_id, acc_top1.numpy(), + acc_top5.numpy())) + + # check eval acc + eval_acc_top1 = sum(eval_acc_top1_list) / len(eval_acc_top1_list) + print('eval_acc_top1', eval_acc_top1) + self.assertTrue(eval_acc_top1 > 0.9, + msg="The test acc {%f} is less than 0.9." % + eval_acc_top1) + + def test_qat(self): + self.func_qat() + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/nn/quant/lsq.py b/python/paddle/nn/quant/lsq.py new file mode 100644 index 00000000000..f92b3856dd3 --- /dev/null +++ b/python/paddle/nn/quant/lsq.py @@ -0,0 +1,328 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +from paddle.framework import core +from paddle.fluid import dygraph_utils +from paddle.utils import unique_name +from paddle.framework import ParamAttr +from paddle.fluid.framework import _varbase_creator +from paddle.nn.initializer import Constant +from paddle.fluid.data_feeder import check_variable_and_dtype +from paddle.nn import functional as F +import logging +from paddle.fluid.log_helper import get_logger +from paddle import in_dynamic_mode +from paddle.nn import Layer +from paddle.autograd import PyLayer +import math +import copy + + +def round(x): + sign = paddle.sign(x) + x = sign * paddle.floor(paddle.abs(x) + 0.5) + return x + + +class LsqFunc(PyLayer): + + @staticmethod + def forward(ctx, weight, alpha, g, Qn, Qp, per_channel=False, quant_axis=0): + ctx.save_for_backward(weight, alpha) + ctx.other = g, Qn, Qp, per_channel, quant_axis + if per_channel: + sizes = weight.shape + weight = weight.reshape((weight.shape[quant_axis], -1)) + weight = weight.transpose((1, 0)) + alpha = paddle.broadcast_to(alpha, weight.shape) + quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp) + quant_w = quant_w * alpha + quant_w = quant_w.transpose((1, 0)) + quant_w = quant_w.reshape(sizes) + else: + quant_w = round(paddle.divide(weight, alpha)).clip(Qn, Qp) + quant_w = quant_w * alpha + return quant_w + + @staticmethod + def backward(ctx, grad_weight): + weight, alpha = ctx.saved_tensor() + g, Qn, Qp, per_channel, quant_axis = ctx.other + if per_channel: + sizes = weight.shape + weight = weight.reshape((weight.shape[quant_axis], -1)) + weight = weight.transpose((1, 0)) + alpha = paddle.broadcast_to(alpha, weight.shape) + q_w = paddle.divide(weight, alpha) + q_w = q_w.transpose((1, 0)) + q_w = q_w.reshape(sizes) + else: + q_w = paddle.divide(weight, alpha) + lower_flag = paddle.cast((q_w < Qn), 'float32') + upper_flag = paddle.cast((q_w > Qp), 'float32') + middle_flag = 1.0 - lower_flag - upper_flag + if per_channel: + grad_alpha = ((lower_flag * Qn + upper_flag * Qp + + middle_flag * round(q_w) - middle_flag * q_w) * + grad_weight * g) + grad_alpha = grad_alpha.reshape( + (grad_alpha.shape[quant_axis], -1)).sum(axis=1) + else: + grad_alpha = ((lower_flag * Qn + upper_flag * Qp + + middle_flag * round(q_w) - middle_flag * q_w) * + grad_weight * g).sum().unsqueeze(axis=0)[0] + grad_weight = middle_flag * grad_weight + return grad_weight, grad_alpha + + +class LsqPlusActFunc(PyLayer): + + @staticmethod + def forward(ctx, x, alpha, beta, g, Qn, Qp): + ctx.save_for_backward(x, alpha, beta) + ctx.other = g, Qn, Qp + quant_x = round(paddle.divide((x - beta), alpha)).clip(Qn, Qp) + return quant_x * alpha + beta + + @staticmethod + def backward(ctx, grad_x): + x, alpha, beta = ctx.saved_tensor() + g, Qn, Qp = ctx.other + q_x = (x - beta) / alpha + lower_flag = paddle.cast((q_x < Qn), 'float32') + upper_flag = paddle.cast((q_x > Qp), 'float32') + middle_flag = 1.0 - lower_flag - upper_flag + grad_alpha = ((lower_flag * Qn + upper_flag * Qp + + middle_flag * round(q_x) - middle_flag * q_x) * grad_x * + g).sum().unsqueeze(axis=0)[0] + grad_beta = ((lower_flag + upper_flag) * grad_x * + g).sum().unsqueeze(axis=0)[0] + grad_x = middle_flag * grad_x + return grad_x, grad_alpha, grad_beta + + +class FakeQuantActLSQPlus(Layer): + + def __init__(self, + quant_bits, + all_postive=False, + symmetric=False, + batch_init=20, + dtype='float32', + name=None, + reduce_type=None): + super(FakeQuantActLSQPlus, self).__init__() + ''' + Args: + quant_bits(int): quantization bit number for weights. 
+ all_postive(bool): whether unsigned or signed quantization, where True for unsigned quantization and False for signed quantization. + symmetric(bool): whether symmetric or asymmetric quantization. + batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer. + dtype(str): data type. + name(str): the name of the weight. + reduce_type(str): the reduce type which is needed when parallel training. + ''' + self.bits = quant_bits + self.all_positive = all_postive + self.symmetric = symmetric + self.batch_init = batch_init + self.name = name + self.reduce_type = reduce_type + + if self.all_positive: + # unsigned activation + self.Qn = 0 + self.Qp = 2**self.bits - 1 + else: + # signed activation + self.Qn = -2**(self.bits - 1) + self.Qp = 2**(self.bits - 1) - 1 + + scale_prefix = "{}.scale".format( + name) if name else 'quant_dequant.scale' + self._scale_name = unique_name.generate(scale_prefix) + + s_attr = ParamAttr(name=self._scale_name, + initializer=Constant(1.0), + trainable=True) + self.s = self.create_parameter(shape=[1], attr=s_attr, dtype='float32') + self.s.stop_gradient = False + + if not self.symmetric: + beta_prefix = "{}.beta".format( + name) if name else 'quant_dequant.beta' + self._beta_name = unique_name.generate(beta_prefix) + + beta_attr = ParamAttr(name=self._beta_name, + initializer=Constant(0.0), + trainable=True) + self.beta = self.create_parameter(shape=[1], + attr=beta_attr, + dtype='float32') + self.beta.stop_gradient = False + + self.init_state = 0 + + def forward(self, activation): + if self.reduce_type == "max": + paddle.distributed.all_reduce(self.s, + op=paddle.distributed.ReduceOp.MAX) + + if not self.symmetric and self.reduce_type == "max": + paddle.distributed.all_reduce(self.beta, + op=paddle.distributed.ReduceOp.MAX) + + if self.init_state == 0: + self.g = paddle.to_tensor(1.0 / + math.sqrt(activation.numel() * self.Qp)) + min_a = paddle.min(activation.detach()) + max_a = paddle.max(activation.detach()) + self.s.set_value((max_a - min_a) / (self.Qp - self.Qn)) + if not self.symmetric: + self.beta.set_value(min_a - self.s * self.Qn) + self.init_state += 1 + elif self.init_state < self.batch_init: + min_a = paddle.min(activation.detach()) + max_a = paddle.max(activation.detach()) + self.s.set_value(self.s * 0.9 + 0.1 * (max_a - min_a) / + (self.Qp - self.Qn)) + if not self.symmetric: + self.beta.set_value(self.s * 0.9 + 0.1 * + (min_a - self.s * self.Qn)) + self.init_state += 1 + else: + self.init_state += 1 + activation.stop_gradient = False + if not self.symmetric: + q_a = LsqPlusActFunc.apply(activation, self.s, self.beta, self.g, + self.Qn, self.Qp) + else: + q_a = LsqFunc.apply(activation, + self.s, + self.g, + self.Qn, + self.Qp, + per_channel=False) + return q_a + + +class FakeQuantWeightLSQPlus(Layer): + + def __init__(self, + quant_bits, + all_postive=False, + per_channel=False, + batch_init=20, + channel_num=None, + quant_linear=False, + dtype='float32', + name=None, + reduce_type=None): + super(FakeQuantWeightLSQPlus, self).__init__() + ''' + Args: + quant_bits(int): quantization bit number for weights. + all_postive(bool): whether unsigned or signed quantization, where True for unsigned quantization and False for signed quantization. + per_channel(bool): whether layer-wise or channel-wise quantization, where True for layer-wise quantization and False for channel-wise quantization. + batch_init(int): number of batches that collect Gaussian approximation for the weight distribution in each layer. 
+ channel_num(int): the channel number of the weight which is needed when per_channel is True. + quant_linear(bool): whether the weight is from Linear. + dtype(str): data type. + name(str): the name of the weight. + reduce_type(str): the reduce type which is needed when parallel training. + ''' + + self.bits = quant_bits + self.all_positive = all_postive + self.per_channel = per_channel + self.quant_linear = quant_linear + self.batch_init = batch_init + self.name = name + self.quant_axis = 1 if quant_linear else 0 + self.collect_axis = 0 if quant_linear else 1 + self.reduce_type = reduce_type + + if self.all_positive: + # unsigned weight + self.Qn = 0 + self.Qp = 2**self.bits - 1 + else: + # signed weight + self.Qn = -2**(self.bits - 1) + self.Qp = 2**(self.bits - 1) - 1 + + self.init_state = 0 + scale_prefix = "{}.scale".format( + name) if name else 'quant_dequant.scale' + self._scale_name = unique_name.generate(scale_prefix) + s_attr = ParamAttr(name=self._scale_name, + initializer=Constant(1.0), + trainable=True) + self.s = self.create_parameter(shape=[channel_num], + attr=s_attr, + dtype=dtype) + self.s.stop_gradient = False + + def forward(self, weight): + if self.reduce_type == "max": + paddle.distributed.all_reduce(self.s, + op=paddle.distributed.ReduceOp.MAX) + + if self.init_state == 0: + self.g = paddle.to_tensor(1.0 / math.sqrt(weight.numel() * self.Qp)) + self.div = 2**self.bits - 1 + if self.per_channel: + weight_tmp = weight.detach().reshape((weight.shape[0], -1)) + mean = paddle.mean(weight_tmp, axis=self.collect_axis) + std = paddle.std(weight_tmp, axis=self.collect_axis) + s = paddle.max(paddle.stack( + [paddle.abs(mean - 3 * std), + paddle.abs(mean + 3 * std)]), + axis=0) + self.s.set_value(s / self.div) + else: + mean = paddle.mean(weight.detach()) + std = paddle.std(weight.detach()) + self.s.set_value( + max([ + paddle.abs(mean - 3 * std), + paddle.abs(mean + 3 * std) + ]) / self.div) + self.init_state += 1 + elif self.init_state < self.batch_init: + self.div = 2**self.bits - 1 + if self.per_channel: + weight_tmp = weight.detach().reshape((weight.shape[0], -1)) + mean = paddle.mean(weight_tmp, axis=self.collect_axis) + std = paddle.std(weight_tmp, axis=self.collect_axis) + s = paddle.max(paddle.stack( + [paddle.abs(mean - 3 * std), + paddle.abs(mean + 3 * std)]), + axis=0) + self.s.set_value(s * 0.9 + 0.1 * s / self.div) + else: + mean = paddle.mean(weight.detach()) + std = paddle.std(weight.detach()) + self.s.set_value(self.s * 0.9 + 0.1 * max( + [paddle.abs(mean - 3 * std), + paddle.abs(mean + 3 * std)]) / self.div) + self.init_state += 1 + elif self.init_state == self.batch_init: + self.init_state += 1 + + weight.stop_gradient = False + w_q = LsqFunc.apply(weight, self.s, self.g, self.Qn, self.Qp, + self.per_channel, self.quant_axis) + return w_q diff --git a/python/paddle/nn/quant/quant_layers.py b/python/paddle/nn/quant/quant_layers.py index c41985e4b2f..855a0538377 100644 --- a/python/paddle/nn/quant/quant_layers.py +++ b/python/paddle/nn/quant/quant_layers.py @@ -26,6 +26,7 @@ from paddle.fluid.log_helper import get_logger from paddle import _C_ops, _legacy_C_ops from paddle import in_dynamic_mode from paddle.nn import Layer +from paddle.nn.quant.lsq import FakeQuantActLSQPlus, FakeQuantWeightLSQPlus __all__ = [ 'FakeQuantAbsMax', @@ -653,7 +654,8 @@ class QuantizedLinear(Layer): dtype=self._dtype, quant_on_weight=True, channel_num=self.weight.shape[self._linear_quant_axis], - quant_axis=self._linear_quant_axis) + quant_axis=self._linear_quant_axis, + 
quant_linear=True)
 
         if act_quant_layer is not None:
             self._fake_quant_input = act_quant_layer()
@@ -946,10 +948,29 @@ def _get_fake_quant_type(quant_type, **kwargs):
         assert call_args["channel_num"] is not None, (
             "You need to input channel_num"
             "when you use channel_wise_abs_max strategy.")
+    elif quant_type == 'lsq_weight':
+        call_args["all_postive"] = kwargs.get("all_postive", False)
+        call_args["per_channel"] = False
+        call_args["channel_num"] = 1
+        call_args["quant_linear"] = kwargs.get("quant_linear", False)
+    elif quant_type == 'channel_wise_lsq_weight':
+        quant_type = 'lsq_weight'
+        call_args["all_postive"] = kwargs.get("all_postive", False)
+        call_args["per_channel"] = True
+        call_args["channel_num"] = kwargs.get("channel_num", None)
+        call_args["quant_linear"] = kwargs.get("quant_linear", False)
+        assert call_args["channel_num"] is not None, (
+            "You need to input channel_num "
+            "when you use channel_wise_lsq_weight strategy.")
+    elif quant_type == 'lsq_act':
+        call_args["all_postive"] = kwargs.get("all_postive", False)
+        call_args["symmetric"] = kwargs.get("symmetric", True)
 
     fake_quant_map = {
         'abs_max': FakeQuantAbsMax,
         'moving_average_abs_max': FakeQuantMovingAverageAbsMax,
-        'channel_wise_abs_max': FakeQuantChannelWiseAbsMax
+        'channel_wise_abs_max': FakeQuantChannelWiseAbsMax,
+        'lsq_weight': FakeQuantWeightLSQPlus,
+        'lsq_act': FakeQuantActLSQPlus
    }
    return fake_quant_map[quant_type](**call_args)
-- 
GitLab
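
Usage sketch: a minimal example of selecting the new LSQ/LSQ+ quantize types through ImperativeQuantAware, mirroring the added test_imperative_qat_lsq.py; the LeNet model and the export comment below are illustrative assumptions, not part of this change.

import paddle
from paddle.fluid.contrib.slim.quantization import ImperativeQuantAware

paddle.disable_static()

# 'channel_wise_lsq_weight' / 'lsq_weight' quantize weights with learned step
# sizes (LSQ); 'lsq_act' quantizes activations with LSQ+ (a learned scale,
# plus a learned offset beta in the asymmetric case).
quanter = ImperativeQuantAware(
    weight_quantize_type='channel_wise_lsq_weight',
    activation_quantize_type='lsq_act')

model = paddle.vision.models.LeNet()  # placeholder model for illustration
quanter.quantize(model)  # inserts the LSQ fake-quant layers in place

# Train as usual; scales (and beta) are updated through the LsqFunc /
# LsqPlusActFunc PyLayer gradients defined in lsq.py. Afterwards the model
# can be exported, e.g. with quanter.save_quantized_model(...).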