diff --git a/docs/zh_cn/quick_start/dygraph/dygraph_quant_post_tutorial.md b/docs/zh_cn/quick_start/dygraph/dygraph_quant_post_tutorial.md deleted file mode 100644 index 77cd49aa170b944b1d6e65629136e516cd14f33b..0000000000000000000000000000000000000000 --- a/docs/zh_cn/quick_start/dygraph/dygraph_quant_post_tutorial.md +++ /dev/null @@ -1,97 +0,0 @@ -# 离线量化 - -离线量化又称为训练后量化,仅需要使用少量校准数据,确定最佳的量化参数降低量化误差。这种方法需要的数据量较少,但量化模型精度相比在线量化稍逊。 - -下面该教程将以图像分类模型MobileNetV1为例,说明如何快速使用[PaddleSlim的模型量化接口]()。 - -该示例包含以下步骤: - -1. 导入依赖 -2. 构建模型和数据集 -3. 进行预训练 -4. 量化训练 -5. 导出预测模型 - -以下章节依次次介绍每个步骤的内容。 - -## 1. 导入依赖 - -请参考PaddleSlim安装文档,安装正确的Paddle和PaddleSlim版本,然后按以下方式导入Paddle和PaddleSlim: - -```python -import paddle -import paddleslim -import paddle.vision.models as models -from paddle.static import InputSpec as Input -from paddle.vision.datasets import Cifar10 -import paddle.vision.transforms as T -from paddleslim.dygraph.quant import QAT -``` - -## 2. 构建网络和数据集 - -该章节构造一个用于对CIFAR10数据进行分类的分类模型,选用`MobileNetV1`,并将输入大小设置为`[3, 32, 32]`,输出类别数为10。 -为了方便展示示例,我们使用Paddle高层API提供的预定义[mobilenetv1分类模型](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/vision/models/mobilenetv1/MobileNetV1_cn.html#mobilenetv1)。 -调用`model.prepare`配置模型所需的部件,比如优化器、损失函数和评价指标,API细节请参考[文档](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/hapi/model/Model_cn.html#prepare-optimizer-none-loss-function-none-metrics-none) - -```python -net = models.mobilenet_v1(pretrained=False, scale=1.0, num_classes=10) -inputs = [Input([None, 3, 32, 32], 'float32', name='image')] -labels = [Input([None, 1], 'int64', name='label')] -optimizer = paddle.optimizer.Momentum( - learning_rate=0.1, - parameters=net.parameters()) -model = paddle.Model(net, inputs, labels) -model.prepare( - optimizer, - paddle.nn.CrossEntropyLoss(), - paddle.metric.Accuracy(topk=(1, 5))) -transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])]) -train_dataset = Cifar10(mode='train', backend='cv2', transform=transform) -val_dataset = Cifar10(mode='test', backend='cv2', transform=transform) -``` - -## 3. 进行预训练 - -对模型进行预训练,为之后的量化做准备。 -执行以下代码对模型进行预训练 -```python -model.fit(train_dataset, epochs=5, batch_size=256, verbose=1) -model.evaluate(val_dataset, batch_size=256, verbose=1) -``` - -训练完成后导出预测模型: -```python -paddle.jit.save(net, "./fp32_inference_model", input_spec=inputs) -``` - - -## 4.离线量化 - -调用slim接口将原模型转换为离线量化模型, 导出的模型可以直接用于预测部署: - -```python -paddle.enable_static() -place = paddle.CPUPlace() -exe = paddle.static.Executor(place) -paddleslim.quant.quant_post_static( - executor=exe, - model_dir='./', - model_filename='fp32_inference_model.pdmodel', - params_filename='fp32_inference_model.pdiparams', - quantize_model_path='./quant_post_static_model', - sample_generator=paddle.dataset.cifar.test10(), - batch_nums=10) -``` - -注意,目前离线量化方法还不支持存在控制流OP的模型。 - -根据部署业务场景,可以使用PaddleLite将该量化模型部署到移动端(ARM CPU),或者使用PaddleInference将该量化模型部署到服务器端(NV GPU和Intel CPU)。 - -导出的量化模型相比原始FP32模型,模型体积没有明显差别,这是因为量化预测模型中的权重依旧保存为FP32类型。在部署时,使用PaddleLite opt工具转换量化预测模型后,模型体积才会真实减小。 - -部署参考文档: -* 部署[文档](https://paddleslim.readthedocs.io/zh_CN/latest/deploy/index.html) -* PaddleLite部署量化模型[文档](https://paddle-lite.readthedocs.io/zh/latest/user_guides/quant_aware.html) -* PaddleInference Intel CPU部署量化模型[文档](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_x86_cpu_int8.html) -* PaddleInference NV GPU部署量化模型[文档](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html) diff --git a/docs/zh_cn/tutorials/quant/advanced_quantization.md b/docs/zh_cn/tutorials/quant/advanced_quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..b5cac2fca18fcc59728ffb3b7d022e9d370fb34c --- /dev/null +++ b/docs/zh_cn/tutorials/quant/advanced_quantization.md @@ -0,0 +1,158 @@ +# 量化策略详细教程 +近年来,Transformer模型已在各个领域得到广泛采用,尤其是生成式语言大模型极大地推动了人工智能领域的发展。这些模型已经从数亿个参数发展到数千亿个参数,在有限的数据和 GPU 资源下运行,对于这些模型来说变得越来越具有挑战性。此时压缩技术变得格外重要,其中量化已成为减少内存占用和计算开销的通用和主要范例。然而,许多研究表明,Transformer模型往往会存在强烈的异常激活值,这使得它们难以量化。为了保持可接受的性能,这些异常值的存在要求激活具有更高的位宽或使用不同的数字格式、额外的微调或其他解决方法。本文档会介绍前沿的优化量化效果的几种策略,其中包括一些开源工作,也包括PaddleSlim自研的方法。以下方法暂时仅支持Transformer模型,具体示例使用方法可参考[PaddleNLP LLM示例](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/causallm),以下教程仅详细介绍API接口。 + +## 1. Shift功能 + +Shift算法来源于[Outlier Suppression+](https://arxiv.org/abs/2304.09145)。通过Layer Norm和Linear的bias进行数学等价的异常值缩放操作,有效将激活值的分布调整对称,有助于离线量化的精度提升。在PaddleSlim的实现中,对于前面有Layer Norm的Linear将使用数学等价的方式改变bias从而进行缩放操作,对于前面没有Layer Norm的Linear可以采用插入节点的方式实现,可通过参数shift_all_linears来控制是否需要shift前面没有Layer Norm的Linear。此外,PaddleSlim版本的Shift功能提供传入sample_function,如设置sample_function为None,Shift算法将完全对齐论文[Outlier Suppression+](https://arxiv.org/abs/2304.09145). + +| **参数名** | **参数类型** | **参数释义** | +|-----------------------------|-----------------------------------------|-----------------------------------------| +| model | paddle.nn.Layer |必须传入的动态图模型 | +| model_config | dict | 必须传入的模型结构的配置 | +| shift_all_linears| bool | 可选参数,默认为False,若为True,则shift模型中全部Linear;若为False,则只shift模型中Layer Norm之后的Linear | +| sample_function | function | 可选参数,默认为None,若为None,采样方式为论文中相同方法,现可选的sample_function有MultiStepSampler和EMASampler,Shift时推荐使用EMASampler | + +以下为简单的使用示例: +```shell +from paddleslim.quant.advanced import Shift, EMASampler + +model = LLM() +model_config = {} +shift = Shift(model, model_config, sample_function=EMASampler()) +for data in dataloader(): + model(data) + shift.step += 1 +shift.update_weight() +``` + + + +## 2. Smooth功能 +Smooth算法来源于[SmoothQuant](https://arxiv.org/abs/2211.10438)。通过Layer Norm和Linear的weight和bias进行数学等价的异常值缩放操作,有效减少激活值中的异常值,有助于离线量化的精度提升。在PaddleSlim的实现中,与shift相同,对于前面有Layer Norm的Linear将使用数学等价的方式改变weight和bias从而进行缩放操作,对于前面没有Layer Norm的Linear可以采用插入节点的方式实现,可通过参数smooth_all_linears来控制是否需要smooth前面没有Layer Norm的Linear。此外,PaddleSlim版本的Smooth功能还提供搜索功能,搜索功能文档见下文。 + +| **参数名** | **参数类型** | **参数释义** | +|-----------------------------|-----------------------------------------|-----------------------------------------| +| model | paddle.nn.Layer |必须传入的动态图模型 | +| model_config | dict | 必须传入的模型结构的配置 | +| alpha | float | 可选参数,默认为0.5 | +| smooth_all_linears| bool | 可选参数,默认为False,若为True,则shift模型中全部Linear;若为False,则只shift模型中Layer Norm之后的Linear | +| sample_function | function | 可选参数,默认为None,若为None,采样方式为单batch数据,现可选的sample_function有MultiStepSampler和EMASampler,Smooth时推荐使用MultiStepSampler | +| search_function | function | 可选参数,默认为None,若为None,则不进行搜索,smooth方法与原论文保持一致 | + +以下为简单的使用示例: +```shell +from paddleslim.quant.advanced import Smooth,MultiStepSampler + +model = LLM() +model_config = {} +smooth = Smooth(model, model_config, sample_function=MultiStepSampler()) +for data in dataloader(): + model(data) + smooth.step += 1 +smooth.update_weight() +``` + + +注意:模型shift和smooth前后从理论数学角度应该是等价的,但从工程角度,输出具体数值可能会有稍微不同,所以模型shift/smooth前后的精度是大约一致的。如果精度出现很大差距,说明模型未解析成功。参数中`model_config`请按照模型的实际情况填写,此输入会影响模型结构解析,若结构解析出错,会导致模型精度不对,其中包含字段: +- `fused_qkv`:该模型是否融合了QKV,默认为True +- `linear_flag`:该模型判断Linear的名字字段,默认为`linear` +- `norm_flag`:该模型判断Layer Norm的名字字段,默认为`norm` +- `parallel_ffn`:该模型是否含有并行的FFN,默认为False +- `skip_norm_list`:该模型中需要被忽略的Layer Norm的名字字段,默认为空list + +若模型中含有PostLayerNorm Shurtcut结构,则不支持对该模型进行smooth和shift。比如PaddleNLP中[ChatGLM结构](https://github.com/PaddlePaddle/PaddleNLP/blob/64f97979a62fba6a35a1177850cc22dbc91fade0/paddlenlp/transformers/chatglm/modeling.py#L360)存在PostLayerNorm Shurtcut结构,所以不支持对该模型进行shift/smooth。 + + +## 3. PieceWiseSearch功能 +根据[SmoothQuant](https://arxiv.org/abs/2211.10438)算法,的确能够有效减少异常值,但我们在大量的实验中发现,在某些情况下,比如权重值较大,尤其是权重和激活的异常值在同一通道时,直接根据SmoothQuant计算smooth scale的公式会导致权重值难量化的情况。并且,对于一个激活值,当异常值较多、数值范围较大,使用同一个alpha去smooth整个激活张量也并不合理。因此,PaddleSlim提出分段搜索功能,根据数值大小将激活分成K段,对于每一段进行alhpa和scale的搜索。 + +| **参数名** | **参数类型** | **参数释义** | +|-----------------------------|-----------------------------------------|-----------------------------------------| +| k_piece | int | 可选参数,分段数量,默认为1,1代表不分段 | +| bits_length | int | 可选参数,量化比特数,默认为8 | +| search_piece | bool | 可选参数,是否搜索分段数k,默认为False,若为True,将会从1到k搜索合适的k | +| search_alpha_min | float | 可选参数,搜索alpha最小值,默认为0.2 | +| search_alpha_max | float | 可选参数,搜索alpha最大值,默认为0.8 | +| search_scale_min | float | 可选参数,搜索scale最小值,默认为1. | +| search_scale_max | float | 可选参数,搜索scale最大值,默认为1. | +| weight_quant_method | str | 可选参数,权重量化方法,可选`abs_max`,`abs_max_channel_wise`,`avg`,默认为`abs_max_channel_wise` | +| act_quant_method | str | 可选参数,激活量化方法,可选`abs_max`,`avg`,默认为`abs_max` | +| loss_function | function | 可选参数,搜索时使用的误差函数,默认为mse_loss | + + +```shell +from paddleslim.quant.advanced import Smooth, MultiStepSampler, PieceWiseSearch, mse_loss + +search_func =PieceWiseSearch( + k_piece=3, + bits_length=8, + search_piece=False, + search_alpha_min=0.2, + search_alpha_max=0.8, + search_scale_min=1., + search_scale_max=5., + weight_quant_method='abs_max_channel_wise', + act_quant_method='abs_max', + loss_function=mse_loss + ) +model = LLM() +model_config = {} +smooth = Smooth(model, model_config, sample_function=MultiStepSampler(), search_function=search_func) +for data in dataloader(): + model(data) + smooth.step += 1 +smooth.update_weight() + +``` + +## 4. GPTQ +GPTQ算法来自[GPTQ](https://arxiv.org/abs/2210.17323),该算法逐步按照行量化权重,利用海森矩阵来不断更新未量化的权重,在低比特Weight Only Int4量化表现良好。GPTQ默认使用搭配使用[RPTQ](https://arxiv.org/abs/2304.01089),若不想搭配RPTQ,调用fasterquant时设置act_order=False即可。 + +| **参数名** | **参数类型** | **参数释义** | +|-----------------------------|-----------------------------------------|-----------------------------------------| +| layer | paddle.nn.Layer |必须入的需要量化的层,现仅支持nn.Linear,ColumnParallelLinear和RowParallelLinear类型 | +| model_config | dict | 必须传入的模型结构的配置 | +| quant_bits| int | 可选参数,量化比特数,默认为4 | +| weight_quant_method | str | 可选参数,权重量化方法,可选`abs_max`,`abs_max_channel_wise`,`avg`,默认为`abs_max_channel_wise` | + +```shell +from paddleslim.quant.advanced import GPTQ + +model = LLM() +for cur_name, cur_layer in model.named_sublayers(): + if type(cur_layer) == paddle.nn.Linear: + gptq_layer = GPTQ(cur_layer) + # sample data + for data in dataloader(): + model(data) + # quant weight + gptq_layer.fasterquant(act_order=True) +``` + + +## 5. LayerWiseQuantError +LayerWiseQuantError是按层级别分析量化损失的方法,对于模型中每一层,量化后,计算当前层量化输出和原始模型输出的误差。 +| **参数名** | **参数类型** | **参数释义** | +|-----------------------------|-----------------------------------------|-----------------------------------------| +| layer | paddle.nn.Layer |必须入的需要量化的层,现仅支持nn.Linear,ColumnParallelLinear和RowParallelLinear类型 | +| weight_bits | int | 可选参数,权重量化比特数,默认为8 | +| act_bits| int | 可选参数,激活量化比特数,默认为8 | +| weight_quant_method| str | 可选参数,权重量化方法,可选`abs_max`,`abs_max_channel_wise`,`avg`,默认为`abs_max_channel_wise` | +| act_quant_method| str | 可选参数,激活量化方法,可选`abs_max`,`avg` | +| loss_function | function | 可选参数,使用的误差函数,默认为mse_loss | + +```shell +from paddleslim.quant.advanced import LayerWiseQuantError + +model = LLM() +for cur_name, cur_layer in model.named_sublayers(): + if type(cur_layer) == paddle.nn.Linear: + gptq_layer = LayerWiseQuantError(cur_layer) + +for data in dataloader(): + model(data) + +for cur_name, cur_layer in model.named_sublayers(): + if type(cur_layer) == LayerWiseQuantError: + print(cur_name, cur_layer.losses.mean()) +``` diff --git a/docs/zh_cn/tutorials/quant/dygraph/dygraph_quant_post_tutorial.md b/docs/zh_cn/tutorials/quant/dygraph/dygraph_quant_post_tutorial.md deleted file mode 120000 index ddd24f04c0b0db18602d8b56e20d7521547592f4..0000000000000000000000000000000000000000 --- a/docs/zh_cn/tutorials/quant/dygraph/dygraph_quant_post_tutorial.md +++ /dev/null @@ -1 +0,0 @@ -../../../quick_start/dygraph/dygraph_quant_post_tutorial.md \ No newline at end of file diff --git a/docs/zh_cn/tutorials/quant/dygraph/index.rst b/docs/zh_cn/tutorials/quant/dygraph/index.rst deleted file mode 100644 index e45cca35cc33eaee2d81cdc11129dd39ca166270..0000000000000000000000000000000000000000 --- a/docs/zh_cn/tutorials/quant/dygraph/index.rst +++ /dev/null @@ -1,8 +0,0 @@ - -动态图 -============== - -.. toctree:: - :maxdepth: 1 - - quant_aware_training_tutorial.md diff --git a/docs/zh_cn/tutorials/quant/post_training_quantization.md b/docs/zh_cn/tutorials/quant/post_training_quantization.md new file mode 100644 index 0000000000000000000000000000000000000000..09391de13f45384b4ee97930898b85482ebaaaf1 --- /dev/null +++ b/docs/zh_cn/tutorials/quant/post_training_quantization.md @@ -0,0 +1,93 @@ +# 离线量化 + +离线量化又称为训练后量化,仅需要使用少量校准数据,确定最佳的量化参数降低量化误差。这种方法需要的数据量较少,但量化模型精度相比在线量化稍逊。 + + +## 使用方法 + +离线量化的基本流程可以分为以下三步: + +1. 选择量化配置 +2. 采样收集量化信息 +3. 转换量化模型 + +## 接口介绍 + +### 1. 量化配置相关概念以及接口: + +`Observer`:用于统计OP输入或输出,并计算出量化相关的统计量,比如scale、zero_point等。每个离线量化算法对应一个Observer,现已有的Observer包含: +- `AVGObserver`:收集目标Tensor的平均值作为量化scale +- `MSEObserver`:收集最大绝对值并通过最小化MSE误差,收集量化scale +- `EMDObserver`:收集最大绝对值并通过最小化EMD误差,收集量化scale +- `HistObserver`:将张量值收集到直方图中,并根据百分比计算量化scale +- `KLObserver`:以最小化浮点值分布与量化浮点值分布之间的 Kullback-Leibler散度计算量化scale +- `AbsMaxChannelWiseWeightObserver`:根据目标权重的通道维度,收集最大绝对值作为量化scale +- `MSEChannelWiseWeightObserver`:根据目标权重的通道维度,收集最大绝对值并通过最小化MSE误差,收集量化scale + +`Quanter`:对OP的输入或输出执行量化或模拟操作操作,同时还可以对输入Tensor的数值进行统计分析。每个量化训练算法对应一个Quanter,现已有的Quanter包含: +- `PACTQuanter` +- `WeightLSQplusQuanter`: +- `ActLSQplusQuanter` + +`QuantConfig`:在执行量化操作之前,首先要配置量化相关的信息,主要是指定每层的各个输入使用什么Observer或Quanter。可通过以下调用方法,根据需求添加每层的量化配置信息: +| **QuantConfig接口** | **传入参数及其含义** | **注意事项** | +|-----------------------------|-----------------------------------------|-----------------------------------------| +| add_layer_config | `layer`: 指定模型的某一层或某些层的list

`activation`: 用于量化激活以上指定layer的`Observer`或`Quanter`

`weight`: 用于量化权重以上指定layer的`Observer`或`Quanter` | 此方法是最高优的要求,这些层的量化方式将按照这里的要求,而不是按照其他配置进行量化 +| add_name_config | `layer_name`: 指定模型的某一层的名字或某些层的名字的list

`activation`: 用于量化激活以上指定layer的`Observer`或`Quanter`

`weight`: 用于量化权重以上指定layer的`Observer`或`Quanter` | 此方法的优先级仅此于add_layer_config +| add_type_config | `layer_type`:指定需要量化的layer类型,可以为单个layer类型,或一个layer类型的list,layer类型必须为paddle.nn.Layer的子类

`activation`: 用于量化激活以上指定layer的`Observer`或`Quanter`

`weight`: 用于量化权重以上指定layer的`Observer`或`Quanter` | 此方法的优先级此于add_name_config,指定需要量化的layer类型,如nn.Linear, 量化时将对所有nn.Linear进行量化,并指定weight和activation的quanter类型 +| add_qat_layer_mapping | `source`:被量化的layer

`target`:量化的layer | source和target必须为paddle.nn.Layer的子类;当指定需要量化的layer类型,如果在框架中没有实现该层量化时,需要指定该layer的量化层,比如ColumnParallelLinear对应PaddleSlim中实现的QuantizedColumnParallelLinear + +### 2. PTQ接口介绍: +| **PTQ接口** | **传入参数及其含义** | **介绍** | +|-----------------------------|-----------------------------------------|-----------------------------------------| +| quantize | `model`:需要被量化的模型
`inplace`:inplace=True时,该模型会被inplace的量化;inplace=False时,不改变原模型,并且会return一个量化的模型 | 对模型需要量化的层插入observers以采样到需要的量化信息 +| convert | `model`:需要被转化的量化模型
`inplace`:inplace=True时,该模型会被inplace的量化;inplace=False时,不改变原模型,并且会return一个量化的模型 | 将模型转化成onnx形式,进行此步骤之后才能对量化模型进行验证、导出成静态图等 + + +## 使用示例 +```python +import paddle +import paddleslim +from paddle.vision.models import mobilenet_v1 +from paddle.quantization import QuantConfig +from paddle.quantization import PTQ +from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver, MSEChannelWiseWeightObserver, AbsMaxChannelWiseWeightObserver + +# create the model +model = mobilenet_v1() + +# define QuantConfig +q_config = QuantConfig(activation=None, weight=None) + +# define act_quanter and weight_quanter +act_quanter = MSEObserver() +weight_quanter = MSEObserver() + +# map ColumnParallelLinear to QuantizedColumnParallelLinear +q_config.add_qat_layer_mapping(ColumnParallelLinear, + QuantizedColumnParallelLinear) +# map RowParallelLinear to QuantizedRowParallelLinear +q_config.add_qat_layer_mapping(RowParallelLinear, + QuantizedRowParallelLinear) +# for each layer if type in [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear] +# make them quantizable +q_config.add_type_config( + [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear], + activation=activation, + weight=weight, + ) + + +ptq = PTQ(q_config) +model = ptq.quantize(model, inplace=True) + +# ptq sample +ptq_step = 100 +for step, data in enumerate(dataloader): + pred = model(data) + if step == ptq_step: + break + +# convert to quant model that can evaluate and export +model = ptq.convert(model, inplace=True) +``` diff --git a/docs/zh_cn/tutorials/quant/dygraph/quant_aware_training_tutorial.md b/docs/zh_cn/tutorials/quant/quant_aware_training.md similarity index 100% rename from docs/zh_cn/tutorials/quant/dygraph/quant_aware_training_tutorial.md rename to docs/zh_cn/tutorials/quant/quant_aware_training.md diff --git a/docs/zh_cn/tutorials/quant/Analysis.md b/docs/zh_cn/tutorials/quant/static/Analysis.md similarity index 100% rename from docs/zh_cn/tutorials/quant/Analysis.md rename to docs/zh_cn/tutorials/quant/static/Analysis.md diff --git a/paddleslim/quant/advanced/__init__.py b/paddleslim/quant/advanced/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f0744ecf40c1972ccb1b5c2e2e4daff98b5e8a8 --- /dev/null +++ b/paddleslim/quant/advanced/__init__.py @@ -0,0 +1,38 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import gptq +from . import smooth +from . import shift +from . import piecewise_search +from . import sample +from . import layerwise_quant_error +from . import utils_layers + +from .gptq import * +from .smooth import * +from .shift import * +from .piecewise_search import * +from .sample import * +from .layerwise_quant_error import * +from .utils_layers import * + +__all__ = [] +__all__ += gptq.__all__ +__all__ += smooth.__all__ +__all__ += shift.__all__ +__all__ += piecewise_search.__all__ +__all__ += sample.__all__ +__all__ += layerwise_quant_error.__all__ +__all__ += utils_layers.__all__ \ No newline at end of file diff --git a/paddleslim/quant/advanced/gptq.py b/paddleslim/quant/advanced/gptq.py new file mode 100644 index 0000000000000000000000000000000000000000..96566858febc63bbef2168fb436ce4ebaea0ee00 --- /dev/null +++ b/paddleslim/quant/advanced/gptq.py @@ -0,0 +1,185 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import time +import numpy as np +import paddle +import paddle.nn as nn +from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear + +from .utils import compute_scales +__all__ = ['GPTQ'] + + +class GPTQ(nn.Layer): + def __init__(self, + layer, + quant_bits=4, + weight_quant_method='abs_max_channel_wise'): + ''' + The implementation of GPTQ(https://arxiv.org/abs/2210.17323). + The codes here are based on https://github.com/IST-DASLab/gptq. + + Args: + layer (paddle.nn.Layer): Layer object. + quant_bits (int, optional): Number of bits to quantize the weight. Default: 4. + weight_quant_method (str, optional): Method of weight quantization. Choosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise. + + Examples: + .. code-block:: python + + from paddleslim.quant.advanced import GPTQ + for cur_name, cur_layer in model.named_sublayers(): + if type(cur_layer) == paddle.nn.Linear: + gptq_layer = GPTQ(cur_layer) + # sample data + for data in dataloader(): + model(data) + # quant weight + gptq_layer.fasterquant() + ''' + super(GPTQ, self).__init__() + self.layer = layer + assert hasattr(layer, + 'weight'), "Layer {} has no attribute 'weight'".format( + layer.full_name()) + assert type(self.layer) in [ + nn.Linear, ColumnParallelLinear, RowParallelLinear + ], "Currently, GPTQ only supports linear layer and ColumnParallelLinear/RowParallelLinear layer" + + weight = layer.weight.t() + + self.rows = weight.shape[0] + self.columns = weight.shape[1] + self.hessian = paddle.zeros( + (self.columns, self.columns), dtype='float32') + self.nsamples = 0 + self.quantized = False + self.weight_quant_method = weight_quant_method + self.quant_bits = (1 << (quant_bits - 1)) - 1 + + def forward(self, input): + if not self.quantized: + inp = input[0] if type(input) == tuple else input + inp = inp.cast('float32') + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if type(self.layer) in [ + nn.Linear, ColumnParallelLinear, RowParallelLinear + ]: + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + + self.hessian *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + + inp = math.sqrt(2 / self.nsamples) * inp + self.hessian += paddle.matmul(inp, inp.t()) + del inp + return self.layer(input) + + def fasterquant(self, + blocksize=128, + percdamp=.01, + groupsize=-1, + actorder=True): + print('quant', self.layer.full_name()) + W = self.layer.weight.t().cast('float32') + weight_scale = compute_scales(W.t(), method=self.weight_quant_method) + weight_scale /= self.quant_bits + tick = time.time() + + H = self.hessian + del self.hessian + dead = paddle.where(paddle.diag(H) == 0) + H[dead, dead] = 1 + W[:, dead] = 0 + del dead + if actorder: + perm = paddle.argsort(paddle.diag(H), descending=True) + W = W.transpose((1, 0)) + W = W[perm].transpose((1, 0)) + H = H[perm].transpose((1, 0)) + H = H[perm].transpose((1, 0)) + + Losses = paddle.zeros_like(W) + Q = paddle.zeros_like(W) + + damp = percdamp * paddle.mean(paddle.diag(H)) + diag = paddle.arange(self.columns) + H[diag, diag] += damp + + H = paddle.inverse(H) + H = paddle.linalg.cholesky(H, upper=True) + Hinv = H + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + W1 = W[:, i1:i2] + Q1 = paddle.zeros_like(W1) + Err1 = paddle.zeros_like(W1) + Losses1 = paddle.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + if groupsize != -1: + if (i1 + i) % groupsize == 0: + weight_scale = compute_scales( + W[:, (i1 + i):(i1 + i + groupsize)].t(), + method=self.weight_quant_method) + weight_scale /= self.quant_bits + + q = paddle.clip( + paddle.round(w / weight_scale), -self.quant_bits - 1, + self.quant_bits) + q = q * weight_scale + Q1[:, i] = q + Losses1[:, i] = (w - q)**2 / d**2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + del w, d, q, err1 + paddle.device.cuda.empty_cache() + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + del Q1, Losses1 + if Hinv[i1:i2, i2:].shape[1] != 0: + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + del Err1, W1, Hinv1 + paddle.device.cuda.empty_cache() + + print('time %.2f' % (time.time() - tick)) + print('error', paddle.sum(Losses).item()) + + if actorder: + invperm = paddle.argsort(perm) + Q = Q.transpose((1, 0)) + Q = Q[invperm].transpose((1, 0)) + del invperm, perm + + param = self.layer.weight + Q = Q.t().cast(self.layer.weight.dtype) + paddle.assign(Q, output=param) + + self.quantized = True + del H, Q, Hinv, W, Losses + paddle.device.cuda.empty_cache() diff --git a/paddleslim/quant/advanced/layerwise_quant_error.py b/paddleslim/quant/advanced/layerwise_quant_error.py new file mode 100644 index 0000000000000000000000000000000000000000..ce03d198cd09a943cedde26b0338ff50f65b8e12 --- /dev/null +++ b/paddleslim/quant/advanced/layerwise_quant_error.py @@ -0,0 +1,85 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import numpy as np +from .utils import compute_scales +from .metrics import mse_loss + +__all__ = ['LayerWiseQuantError'] + + +class LayerWiseQuantError(nn.Layer): + def __init__(self, + layer, + weight_bits=8, + act_bits=8, + weight_quant_method='abs_max_channel_wise', + act_quant_method='abs_max', + loss_function=mse_loss): + ''' + LayerWiseQuantError computes the loss bewteen the output of the layer and the outout of the quantized layer. + + Args: + layer (paddle.nn.Layer): Layer object. + quant_bits (int, optional): Number of bits to quantize the weight. Default: 8. + act_bits (int, optional): Number of bits to quantize the activation. Default: 8. + weight_quant_method (str, optional): The method of weight quantization. Choosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise. + act_quant_method (str, optional): The method of activation quantization. Choosen from abs_max, avg. Default: abs_max. + + Examples: + .. code-block:: python + + from paddleslim.quant.advanced import GPTQ + for cur_name, cur_layer in model.named_sublayers(): + if type(cur_layer) == paddle.nn.Linear: + gptq_layer = LayerWiseQuantError(cur_layer) + + for data in dataloader(): + model(data) + + for cur_name, cur_layer in model.named_sublayers(): + if type(cur_layer) == LayerWiseQuantError: + print(cur_name, cur_layer.losses.mean()) + ''' + self.layer = layer + self.weight = layer.weight + self.weight_bits = weight_bits + self.act_bits = act_bits + self.weight_method = weight_quant_method + self.act_method = act_quant_method + self.loss_function = loss_function + self.losses = [] + + def forward(self, input): + act = input[0] if type(input) == tuple else input + origin_out = paddle.matmul(act, self.weight) + bnt = (1 << (self.weight_bits - 1)) - 1 + quant_scale = compute_scales( + self.weight.cast('float32'), + method=self.weight_method).cast(self.weight.dtype) + quant_weight = paddle.clip( + paddle.round(self.weight / quant_scale * bnt), -bnt - 1, bnt) + quant_dequant_weight = quant_weight / bnt * quant_scale + + bnt = (1 << (self.act_bits - 1)) - 1 + quant_scale = compute_scales(act, method=self.act_method) + quant_act = paddle.clip( + paddle.round(act / quant_scale * bnt), -bnt - 1, bnt) + quant_dequant_act = quant_act / bnt * quant_scale + quant_out = paddle.matmul(quant_dequant_act, quant_dequant_weight) + loss = self.loss_function(origin_out, quant_out) + self.losses.append(loss) + return self.layer(input) diff --git a/paddleslim/quant/advanced/metrics.py b/paddleslim/quant/advanced/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..b00f710c7b3789e789f97b7f7304f06d8b21ceb0 --- /dev/null +++ b/paddleslim/quant/advanced/metrics.py @@ -0,0 +1,57 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle + + +def mse_loss(y_pred, y_real, reduction='mean'): + if y_pred.shape != y_real.shape: + raise ValueError( + f'Can not compute mse loss for tensors with different shape. ' + f'({y_pred.shape} and {y_real.shape})') + + mse = (y_pred - y_real)**2 + + if reduction == 'mean': + return paddle.mean(mse) + elif reduction == 'sum': + return paddle.sum(mse) + elif reduction == 'none': + return mse + else: + raise ValueError(f'Unsupported reduction method.') + + +def snr_loss(y_pred, y_real, reduction='mean'): + if y_pred.shape != y_real.shape: + raise ValueError( + f'Can not compute snr loss for tensors with different shape. ' + f'({y_pred.shape} and {y_real.shape})') + reduction = str(reduction).lower() + + y_pred = y_pred.flatten(start_axis=1) + y_real = y_real.flatten(start_axis=1) + + noise_power = paddle.pow(y_pred - y_real, 2).sum(axis=-1) + signal_power = paddle.pow(y_real, 2).sum(axis=-1) + snr = (noise_power) / (signal_power + 1e-7) + + if reduction == 'mean': + return paddle.mean(snr) + elif reduction == 'sum': + return paddle.sum(snr) + elif reduction == 'none': + return snr + else: + raise ValueError(f'Unsupported reduction method.') diff --git a/paddleslim/quant/advanced/piecewise_search.py b/paddleslim/quant/advanced/piecewise_search.py new file mode 100644 index 0000000000000000000000000000000000000000..d8e4a1f2d734caf46f9a727dd2ec307099f35ce6 --- /dev/null +++ b/paddleslim/quant/advanced/piecewise_search.py @@ -0,0 +1,161 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +from .utils import compute_scales, k_means +from .metrics import mse_loss + +__all__ = ['PieceWiseSearch'] + + +class PieceWiseSearch(): + def __init__(self, + k_piece=1, + bits_length=8, + search_piece=False, + search_alpha_min=0.2, + search_alpha_max=0.8, + search_scale_min=1., + search_scale_max=5., + weight_quant_method='abs_max_channel_wise', + act_quant_method='abs_max', + loss_function=mse_loss): + ''' + PieceWiseSearch provides to search k_piece, alpha and scale. + + Args: + k_piece (int): Number of k pieces. Default: 1. + bits_length (int): Number of bits to quantize the weight. Default: 8. + search_piece (bool): Whether to search the best k piece. Default: False. + search_alpha_min (float): Minimum alpha for search. Default: 0.2. + search_alpha_max (float): Maximum alpha for search. Default: 0.8. + search_scale_min (float): Minimum scale for search. Default: 1. + search_scale_max (float): Maximum scale for search. Default: 5. + weight_quant_method (str): Weight quantization method. Choosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise. + act_quant_method (str): Activation quantization method. Choosen from abs_max, avg. Default: abs_max. + loss_function (callable): Loss function. Default: mse_loss. + ''' + self.k_piece = k_piece + self.bits_length = bits_length + self.search_piece = search_piece + self.search_alpha_min = search_alpha_min + self.search_alpha_max = search_alpha_max + self.search_scale_min = search_scale_min + self.search_scale_max = search_scale_max + self.weight_quant_method = weight_quant_method + self.act_quant_method = act_quant_method + self.bnt = (1 << (bits_length - 1)) - 1 + self.loss_function = loss_function + + def search(self, layer_name, sampled_input, act_abs_max, weight): + act = sampled_input + act.stop_gradient = True + print('[smooth search] search input of %s' % layer_name) + + origin_out = paddle.matmul(act, weight) + w_abs_max = weight.abs().max(axis=-1, keepdim=True) + rw_abs_max = w_abs_max.reshape(act_abs_max.shape) + np_act_abs_max = np.array(act_abs_max) + np_rw_abs_max = np.array(rw_abs_max) + + smooth_scale_out = None + global_loss = float('inf') + best_scale = None + + for k_piece in range(1, self.k_piece + 1): + if not self.search_piece: + k_piece = self.k_piece + print('Search {} Piece'.format(k_piece)) + centroids, labels = k_means(act_abs_max, k_piece) + piece = ['piece_{}'.format(a) for a in range(len(centroids))] + for i in range(len(centroids)): + # print('search for piece {}; centroids value is {}'.format( + # piece[i], centroids[centroids.argsort()[i]].numpy())) + alpha = self.search_alpha_min + alpha_max = self.search_scale_max + calibration_loss = float('inf') + final_alpha = None + mask_for_search = paddle.where(labels == centroids.argsort()[i], + 1., 0.) + mask_for_ones = paddle.where(mask_for_search == 0., 1., 0.) + + while alpha <= alpha_max: + if alpha < 1: + alpha += 0.01 + if alpha >= self.search_alpha_max: + alpha = 1. + else: + alpha += 0.5 + + alpha = round(alpha, 2) + + if alpha < 1: + s = (np.power(np_act_abs_max, alpha) / np.power( + np_rw_abs_max, 1. - alpha)).clip(min=1e-5) + s = paddle.to_tensor(s, dtype='float32') + smooth_scale = s * mask_for_search + else: + smooth_scale = alpha * mask_for_search + + if smooth_scale_out is not None: + mask_for_ones_new = paddle.where( + smooth_scale_out == 0., 1., 0.) + mask_for_ones *= mask_for_ones_new + smooth_scale_ = smooth_scale_out + smooth_scale + smooth_scale_tmp = smooth_scale_ + mask_for_ones + else: + smooth_scale_tmp = smooth_scale + mask_for_ones + + new_act = act / smooth_scale_tmp + new_weight = weight * smooth_scale_tmp.reshape( + w_abs_max.shape) + + quant_scale = compute_scales( + new_act, method=self.act_quant_method) + quant_act = paddle.clip( + paddle.round(new_act / quant_scale * self.bnt), + -self.bnt - 1, self.bnt) + quant_dequant_act = quant_act / self.bnt * quant_scale + + quant_scale = compute_scales( + new_weight, method=self.weight_quant_method) + quant_weight = paddle.clip( + paddle.round(new_weight / quant_scale * self.bnt), + -self.bnt - 1, self.bnt) + quant_dequant_weight = quant_weight / self.bnt * quant_scale + new_out = paddle.matmul(quant_dequant_act, + quant_dequant_weight) + + cur_loss = self.loss_function(origin_out, new_out) + if cur_loss <= calibration_loss: + calibration_loss = cur_loss + final_smooth_scale = smooth_scale + final_alpha = alpha + + # print("Layer {} Piece {}, loss: {}, alpha : {}".format( + # layer_name, piece[i], float(calibration_loss), final_alpha)) + if smooth_scale_out is None: + smooth_scale_out = final_smooth_scale + else: + smooth_scale_out += final_smooth_scale + + if cur_loss < global_loss: + global_loss = cur_loss + best_scale = smooth_scale_out + if self.search_piece: + print('Find Better K-Piece {}'.format(k_piece)) + if not self.search_piece: + break + return best_scale diff --git a/paddleslim/quant/advanced/sample.py b/paddleslim/quant/advanced/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..a8241614af5810afe59fce4ea98623f4b1dd54b9 --- /dev/null +++ b/paddleslim/quant/advanced/sample.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +__all__ = ['MultiStepSampler', 'EMASampler'] + + +class MultiStepSampler(): + def __init__(self): + pass + + def sample(self, x, sampled_x=None, layer_name=None): + return paddle.concat([x, sampled_x], axis=1) + + +class EMASampler(): + def __init__(self): + self.ema_beta = 0.98 + self.ema_step = {} + self.sampled = {} + + def sample(self, x, sampled_x=None, layer_name=None): + + if layer_name not in self.ema_step: + self.sampled[layer_name] = (1 - self.ema_beta) * x + self.ema_step[layer_name] = 0 + return self.sampled[layer_name] + else: + v_ema = self.ema_beta * self.sampled[layer_name] + ( + 1 - self.ema_beta) * x + self.sampled[layer_name] = v_ema + v_ema_corr = v_ema / float( + (1 - np.power(self.ema_beta, self.ema_step[layer_name] + 1))) + self.ema_step[layer_name] += 1 + return v_ema_corr diff --git a/paddleslim/quant/advanced/shift.py b/paddleslim/quant/advanced/shift.py new file mode 100644 index 0000000000000000000000000000000000000000..70aea070f877008820c19e9aed1aa48cf20618e9 --- /dev/null +++ b/paddleslim/quant/advanced/shift.py @@ -0,0 +1,254 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from re import sub +import numpy as np +import paddle +from .utils import get_ln_linear_info, find_parent_layer_and_sub_name +from .utils_layers import ShiftSmoothHelpLayer, WOBiasHelpLayer + +__all__ = ['Shift'] + + +class Shift(): + def __init__(self, + model, + model_config, + shift_all_linears=False, + sample_function=None): + ''' + Shift is the implementation of Outlier Suppression+(https://arxiv.org/abs/2304.09145). + Currently, Shift only supports linear layer and ColumnParallelLinear/RowParallelLinear layer. + + Args: + model(paddle.nn.Layer, required): the model to be shifted + model_config (dict, required): the config of model to be shifted + shift_all_linears (bool, optional): whether to shift all linear layers + sample_function (function, optional): the function to sample data + + Examples: + .. code-block:: python + + from paddleslim.quant.advanced import Shift + shift = Shift(model) + for data in dataloader(): + model(data) + shift.step += 1 + shift.update_weight() + ''' + + self.model = model + self.model_config = model_config + self.fused_qkv = model_config.get("fused_qkv", True) + self.linear_flag = model_config.get("linear_flag", 'linear') + self.norm_flag = model_config.get("norm_flag", 'norm') + self.parallel_ffn = model_config.get("parallel_ffn", False) + self.skip_norm_list = model_config.get("skip_norm_list", []) + self.shift_all_linears = shift_all_linears + self.sample_function = sample_function + + self.layer_order = [] + self.zero_point_dict = {} + self.smooth_scale_dict = {} + self.glabal_min_max = {} + + self.model.eval() + self.step = 0 + self.print_step = 1 + self.got_shift_layers = False + self.help_layers_ready = False + self._apply_hook() + + def _apply_hook(self): + self._forward_hook_list = [] + for _, sub_layer in self.model.named_sublayers(): + if self.norm_flag in sub_layer.full_name( + ) or self.linear_flag in sub_layer.full_name(): + forward_pre_hook_handle = sub_layer.register_forward_pre_hook( + self._forward_pre_hook) + self._forward_hook_list.append(forward_pre_hook_handle) + if type(sub_layer) == ShiftSmoothHelpLayer: + self.help_layers_ready = True + forward_pre_hook_handle = sub_layer.register_forward_pre_hook( + self._forward_pre_hook) + self._forward_hook_list.append(forward_pre_hook_handle) + + def _get_shift_layers(self): + self.ln_linear_dict, self.linear_ln_dict = get_ln_linear_info( + self.layer_order, self.norm_flag, self.linear_flag, self.fused_qkv, + self.parallel_ffn, self.skip_norm_list) + assert len(self.ln_linear_dict) > 0, 'No LN/Linear pair found' + for key in self.ln_linear_dict: + print('shift pair LN {} : Linear {}'.format( + key, self.ln_linear_dict[key])) + if self.shift_all_linears: + if not self.help_layers_ready: + rest_linears = [ + l for l in self.layer_order + if self.linear_flag in l and l not in self.linear_ln_dict + ] + print('Preparing shift layers', rest_linears) + for cur_name, sub_layer in self.model.named_sublayers(): + if sub_layer.full_name() in rest_linears: + new_layer = ShiftSmoothHelpLayer(sub_layer) + parent_layer, sub_name = find_parent_layer_and_sub_name( + self.model, cur_name) + setattr(parent_layer, sub_name, new_layer) + forward_pre_hook_handle = new_layer.register_forward_pre_hook( + self._forward_pre_hook) + self._forward_hook_list.append(forward_pre_hook_handle) + self.got_shift_layers = True + + def _forward_pre_hook(self, layer, input): + ''' + when step 0, forward once and collect shift layers. + when step >1, sample scale. + ''' + if self.step == 0 and layer.full_name() in self.layer_order: + self.step += 1 + if self.step == 0: + self.layer_order.append(layer.full_name()) + return input + if self.step == 1: + if not self.got_shift_layers: + self._get_shift_layers() + if self.step > 0: + if layer.full_name() in self.linear_ln_dict.keys(): + self._sample_zero_point(input, + self.linear_ln_dict[layer.full_name()]) + if type(layer) == ShiftSmoothHelpLayer: + self._sample_zero_point(input, layer.full_name()) + + return input + + def _sample_zero_point(self, input, ln_name): + x = input[0] if type(input) == tuple else input + x = x.cast('float32') + x.stop_gradient = True + + zero_point = x.mean(axis=(0, 1)) if len(x.shape) > 2 else x.mean(axis=1) + _min = x.min(axis=(0, 1)) if len(x.shape) > 2 else x.min(axis=1) + _max = x.max(axis=(0, 1)) if len(x.shape) > 2 else x.max(axis=1) + + if ln_name not in self.zero_point_dict or ln_name not in self.glabal_min_max: + + if self.sample_function is None: + self.glabal_min_max[ln_name] = _min, _max + self.zero_point_dict[ln_name] = (_min + _max) / 2 + else: + self.zero_point_dict[ln_name] = zero_point + + else: + if self.sample_function is not None: + self.zero_point_dict[ln_name] = self.sample_function.sample( + zero_point, self.zero_point_dict[ln_name], ln_name) + else: + global_min, global_max = self.glabal_min_max[ln_name] + global_min = global_min if global_min < _min else _min + global_max = global_max if global_max > _max else _max + self.glabal_min_max[ln_name] = global_min, global_max + self.zero_point_dict[ln_name] = (global_min + global_max) / 2 + + # per step print once + if self.print_step == self.step: + print('[shift] Step [{}]: {}. zero_point min: {}, max: {}'.format( + self.step, ln_name, + round(float(self.zero_point_dict[ln_name].min()), 5), + round(float(self.zero_point_dict[ln_name].max()), 5))) + if ln_name == list(self.linear_ln_dict.values())[-1]: + self.print_step += 1 + + def update_weight(self): + ''' + update weight of smooth layers. + firstly compute s and update linear's weight, + then update LN's weight by corresponding linear and s + ''' + # update linear weight + for _, sub_layer in self.model.named_sublayers(): + layer_name = sub_layer.full_name() + if layer_name in self.linear_ln_dict: + ln_name = self.linear_ln_dict[layer_name] + shift_bias = None + for param in sub_layer.parameters(include_sublayers=False): + if 'w_0' in param.name: + zero_point = self.zero_point_dict[ln_name].squeeze() + shift_bias = paddle.matmul(zero_point, + param.cast('float32')) + print("[shift] param: {}, zero_point min: {}, max: {}". + format(param.name, + float(zero_point.min()), + float(zero_point.max()))) + break + + if not hasattr(sub_layer, "bias") or sub_layer.bias is None: + sub_layer.bias = paddle.create_parameter( + shape=shift_bias.shape, + dtype=sub_layer.weight.dtype, + default_initializer=paddle.nn.initializer.Constant(0.0), + is_bias=True, ) + for param in sub_layer.parameters(include_sublayers=False): + if 'b_0' in param.name: + shift_bias = shift_bias + param + paddle.assign( + shift_bias.cast(param.dtype), output=param) + print("[shift] update linear bias: {}.".format( + param.name)) + break + + # update LN weight + for cur_name, sub_layer in self.model.named_sublayers(): + layer_name = sub_layer.full_name() + if layer_name in self.ln_linear_dict: + if not hasattr(sub_layer, "bias") or sub_layer.bias is None: + help_layer = WOBiasHelpLayer(sub_layer) + parent_layer, sub_name = find_parent_layer_and_sub_name( + self.model, cur_name) + setattr(parent_layer, sub_name, help_layer) + sub_layer = help_layer + + for param in sub_layer.parameters(include_sublayers=False): + if "b_0" in param.name: + zero_point = self.zero_point_dict[layer_name].squeeze() + param_tmp = param - zero_point + paddle.assign(param_tmp.cast(param.dtype), output=param) + print("[shift] update layer_norm bias {}.".format( + param.name)) + break + + # update ShiftSmoothRowParallelLinear weight + for _, sub_layer in self.model.named_sublayers(): + if type(sub_layer) == ShiftSmoothHelpLayer: + layer_name = sub_layer.full_name() + linear_name = sub_layer.layer.full_name() + zero_point = self.zero_point_dict[layer_name].squeeze() + print( + "[shift ShiftSmoothHelpLayer] before param: {}, shift_bias min: {}, max: {}". + format(linear_name, + float(sub_layer.shift_bias.cast("float32").min()), + float(sub_layer.shift_bias.max().cast("float32")))) + sub_layer.convert_weight(shift_bias=zero_point) + + print( + "[shift ShiftSmoothHelpLayer] after param: {}, shift_bias min: {}, max: {}". + format(linear_name, + float(sub_layer.shift_bias.cast("float32").min()), + float(sub_layer.shift_bias.max().cast("float32")))) + + self._remove_hook() + paddle.device.cuda.empty_cache() + + def _remove_hook(self): + for hook in self._forward_hook_list: + hook.remove() diff --git a/paddleslim/quant/advanced/smooth.py b/paddleslim/quant/advanced/smooth.py new file mode 100644 index 0000000000000000000000000000000000000000..d2bd51789ee0ec3368fd709d4371d897f965a5c5 --- /dev/null +++ b/paddleslim/quant/advanced/smooth.py @@ -0,0 +1,273 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import paddle +from .utils import get_ln_linear_info, find_parent_layer_and_sub_name +from .utils_layers import ShiftSmoothHelpLayer, WOBiasHelpLayer +__all__ = ['Smooth'] + + +class Smooth(): + def __init__( + self, + model, + model_config, + alpha=0.5, + smooth_all_linears=False, + sample_function=None, + search_function=None, ): + ''' + Smooth is an updated version of SmoothQuant(https://arxiv.org/abs/2211.10438). + Compared with SmoothQuant, the piecewise smooth algorithm has the following updates: + It supports the search functions to search best smooth scale. + It supports the sample function to sample smooth scale. + Currently, Smooth only supports linear layer and ColumnParallelLinear/RowParallelLinear layer. + + + Args: + model(paddle.nn.Layer, required): the model to be smoothed + model_config (dict, required): the config of model to be smoothed + alpha(float, optional): smoothing parameter. Default: 0.5 + smooth_all_linears(bool, optional): whether to smooth all linears. Default: False + sample_function(function, optional): the function of sample to sample data. Default: None + sample_start_step(int, optional): the step of sample data by using sample_function. Default: 0 + search_function(function, optional): the function of search smooth scale. Default: None + + Examples: + .. code-block:: python + + from paddleslim.quant.advanced import Smooth + smooth = Smooth(model) + for data in dataloader(): + model(data) + smooth.step += 1 + smooth.update_weight() + ''' + self.model = model + self.model_config = model_config + self.fused_qkv = model_config.get("fused_qkv", True) + self.linear_flag = model_config.get("linear_flag", 'linear') + self.norm_flag = model_config.get("norm_flag", 'norm') + self.parallel_ffn = model_config.get("parallel_ffn", False) + self.skip_norm_list = model_config.get("skip_norm_list", []) + + self.alpha = alpha + self.smooth_all_linears = smooth_all_linears + self.sample_function = sample_function + self.search_function = search_function + + self.model.eval() + self.step = 0 + self.print_step = 1 + self.got_smooth_layers = False + self.help_layers_ready = False + self.layer_order = [] + self.scale_dict = {} + self.smooth_scale_dict = {} + self.sampled_inputs = {} + self._apply_hook() + + def _apply_hook(self): + self._forward_hook_list = [] + for _, sub_layer in self.model.named_sublayers(): + if self.norm_flag in sub_layer.full_name( + ) or self.linear_flag in sub_layer.full_name(): + forward_pre_hook_handle = sub_layer.register_forward_pre_hook( + self._forward_pre_hook) + self._forward_hook_list.append(forward_pre_hook_handle) + if type(sub_layer) == ShiftSmoothHelpLayer: + self.help_layers_ready = True + forward_pre_hook_handle = sub_layer.register_forward_pre_hook( + self._forward_pre_hook) + self._forward_hook_list.append(forward_pre_hook_handle) + + def _get_smooth_layers(self): + self.ln_linear_dict, self.linear_ln_dict = get_ln_linear_info( + self.layer_order, self.norm_flag, self.linear_flag, self.fused_qkv, + self.parallel_ffn, self.skip_norm_list) + assert len(self.ln_linear_dict) > 0, 'No LN/Linear pair found' + for key in self.ln_linear_dict: + print('smooth pair LN {} : Linear {}'.format( + key, self.ln_linear_dict[key])) + if self.smooth_all_linears: + if not self.help_layers_ready: + rest_linears = [ + l for l in self.layer_order + if self.linear_flag in l and l not in self.linear_ln_dict + ] + print('Preparing smooth layers', rest_linears) + for cur_name, sub_layer in self.model.named_sublayers(): + if sub_layer.full_name() in rest_linears: + new_layer = ShiftSmoothHelpLayer(sub_layer) + parent_layer, sub_name = find_parent_layer_and_sub_name( + self.model, cur_name) + setattr(parent_layer, sub_name, new_layer) + forward_pre_hook_handle = new_layer.register_forward_pre_hook( + self._forward_pre_hook) + self._forward_hook_list.append(forward_pre_hook_handle) + + self.got_smooth_layers = True + + def _forward_pre_hook(self, layer, input): + ''' + when step 0, forward once and collect smooth layers. + ''' + if self.step == 0 and layer.full_name() in self.layer_order: + self.step += 1 + if self.step == 0: + self.layer_order.append(layer.full_name()) + return input + if self.step == 1: + if not self.got_smooth_layers: + self._get_smooth_layers() + if self.step > 0: + if layer.full_name() in self.linear_ln_dict.keys(): + self._sample_scale(input, + self.linear_ln_dict[layer.full_name()]) + if type(layer) == ShiftSmoothHelpLayer: + self._sample_scale(input, layer.full_name()) + + return input + + def _sample_scale(self, input, ln_name): + x = input[0] if type(input) == tuple else input + x.stop_gradient = True + x_abs_max = x.abs().max(axis=1, keepdim=True) + x_abs_max = x_abs_max.max(axis=0) + + if ln_name not in self.scale_dict: + self.sampled_inputs[ln_name] = x + self.scale_dict[ln_name] = x_abs_max + else: + if self.sample_function is not None: + self.sampled_inputs[ln_name] = self.sample_function.sample( + x, self.sampled_inputs[ln_name], ln_name) + else: + self.sampled_inputs[ln_name] = x + tmp1 = paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0) + self.scale_dict[ln_name] = tmp1.max(axis=0, keepdim=True) + + # per step print once + if self.print_step == self.step: + print('[Smooth] Step [{}]: {}. abs_min: {}, abs_max: {}'.format( + self.step, ln_name, + float(self.scale_dict[ln_name].cast("float32").min()), + float(self.scale_dict[ln_name].cast("float32").max()))) + if ln_name == list(self.linear_ln_dict.values())[-1]: + self.print_step += 1 + + def update_weight(self): + + for _, sub_layer in self.model.named_sublayers(): + layer_name = sub_layer.full_name() + ln_name = None + if layer_name in self.linear_ln_dict: + ln_name = self.linear_ln_dict[layer_name] + if type(sub_layer) == ShiftSmoothHelpLayer: + ln_name = layer_name + if ln_name is not None: + act_abs_max = self.scale_dict[ln_name].cast("float32") + sampled_input = self.sampled_inputs[ln_name].cast("float32") + for param in sub_layer.parameters(include_sublayers=False): + if 'w_0' in param.name: + weight = param.cast("float32") + if self.search_function is not None: + s = self.search_function.search( + layer_name, sampled_input, act_abs_max, weight) + else: + w_abs_max = weight.abs().max(axis=-1, keepdim=True) + rw_abs_max = w_abs_max.reshape(act_abs_max.shape) + act_abs_max_np = act_abs_max.numpy() + weight_abs_max_np = rw_abs_max.numpy() + s = ( + np.power(act_abs_max_np, self.alpha) / np.power( + weight_abs_max_np, 1 - self.alpha)).clip( + min=1e-5) + s = paddle.to_tensor(s, dtype="float32") + + self.smooth_scale_dict[ln_name] = s.cast(param.dtype) + break + + # update linear weight + for _, sub_layer in self.model.named_sublayers(): + layer_name = sub_layer.full_name() + if layer_name in self.linear_ln_dict: + for param in sub_layer.parameters(include_sublayers=False): + if 'w_0' in param.name: + ln_name = self.linear_ln_dict[layer_name] + print("[smooth] before linear [{}] weight, abs_max: {}". + format(param.name, + float(param.cast("float32").abs().max()))) + param_tmp = param * self.smooth_scale_dict[ + ln_name].transpose(perm=[1, 0]) + paddle.assign(param_tmp, output=param) + print("[smooth] after linear [{}] weight, abs_max: {}". + format(param.name, + float(param_tmp.abs().max().cast( + "float32")))) + + # update LN weight + for cur_name, sub_layer in self.model.named_sublayers(): + layer_name = sub_layer.full_name() + if layer_name in self.ln_linear_dict: + s = self.smooth_scale_dict[layer_name].squeeze() + for param in sub_layer.parameters(include_sublayers=False): + print("[smooth] before layer_norm {} weight, abs_max: {}". + format(param.name, + float(param.abs().max().cast("float32")))) + param_tmp = param / s + paddle.assign(param_tmp, output=param) + print("[smooth] after layer_norm {} weight, abs_max: {}". + format(param.name, + float(param_tmp.abs().max().cast("float32")))) + if not hasattr(sub_layer, "bias") or sub_layer.bias is None: + parent_layer, _ = find_parent_layer_and_sub_name( + self.model, cur_name) + if type(parent_layer) == WOBiasHelpLayer: + param = parent_layer.bias + print( + "[smooth WOBiasHelpLayer] before layer_norm {} bias, abs_max: {}". + format(param.name, + float(param.abs().max().cast("float32")))) + param_tmp = param / s + paddle.assign(param_tmp, output=param) + print( + "[smooth WOBiasHelpLayer] after layer_norm {} bias, abs_max: {}". + format(param.name, + float(param_tmp.abs().max().cast( + "float32")))) + + for _, sub_layer in self.model.named_sublayers(): + if type(sub_layer) == ShiftSmoothHelpLayer: + layer_name = sub_layer.full_name() + linear_name = sub_layer.layer.full_name() + smooth_scale = self.smooth_scale_dict[layer_name] + print( + "[smooth ShiftSmoothHelpLayer] param: {}, before weight, abs_max: {}". + format(linear_name, + float(sub_layer.weight.abs().max().cast("float32")))) + sub_layer.convert_weight(smooth_weight=smooth_scale) + print( + "[smooth ShiftSmoothHelpLayer] param: {}, after weight, abs_max: {}". + format(linear_name, + float(sub_layer.weight.abs().max().cast("float32")))) + + self._remove_hook() + paddle.device.cuda.empty_cache() + + def _remove_hook(self): + for hook in self._forward_hook_list: + hook.remove() + self._forward_hook_list = [] diff --git a/paddleslim/quant/advanced/utils.py b/paddleslim/quant/advanced/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98f24ef124e10b345be6354f9c37fc352cfee6de --- /dev/null +++ b/paddleslim/quant/advanced/utils.py @@ -0,0 +1,108 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +from sklearn.cluster import KMeans + + +def k_means(weight, n_clusters, init='k-means++', max_iter=300): + org_shape = weight.shape + weight = paddle.to_tensor(weight) + weight = paddle.reshape(weight, [-1, 1]) + if n_clusters > weight.size: + n_clusters = weight.size + + k_means = KMeans( + n_clusters=n_clusters, + init=init, + n_init=10, + algorithm='lloyd', + max_iter=max_iter) + k_means.fit(weight) + + centroids = k_means.cluster_centers_ + labels = k_means.labels_ + labels = labels.reshape(org_shape) + return paddle.to_tensor(centroids.flatten()), paddle.to_tensor(labels) + + +def compute_scales(x, method='abs_max'): + if method == 'abs_max': + quant_scale = float(paddle.max(paddle.abs(x.flatten()))) + quant_scale = 1e-8 if quant_scale == 0.0 else quant_scale + elif method == 'avg': + quant_scale = paddle.abs(x.reshape((x.shape[0], -1))) + quant_scale = paddle.mean(paddle.max(quant_scale, axis=(1))) + elif method == 'abs_max_channel_wise': + reduce_axis = tuple([i for i in range(len(x.shape)) if i != 1]) + quant_scale = paddle.max(paddle.abs(x), axis=reduce_axis) + quant_scale = paddle.where(quant_scale == np.float32(0.0), + np.float32(1e-8), quant_scale) + return quant_scale + + +def find_parent_layer_and_sub_name(model, name): + last_idx = 0 + idx = 0 + parent_layer = model + while idx < len(name): + if name[idx] == '.': + sub_name = name[last_idx:idx] + if hasattr(parent_layer, sub_name): + parent_layer = getattr(parent_layer, sub_name) + last_idx = idx + 1 + idx += 1 + sub_name = name[last_idx:idx] + return parent_layer, sub_name + + +def get_ln_linear_info(ln_linear_list, norm_flag, linear_flag, fused_qkv, + llama_ffn, skip_norm_list): + # ln_linear_dict: {layer_norm_0: [linear_0, linear_1, linear_2]} + ln_linear_dict = {} + # linear_ln_dict: {linear_0: layer_norm_0, linear_1: layer_norm_0} + linear_ln_dict = {} + for i in range(len(ln_linear_list)): + layer_name = ln_linear_list[i] + if norm_flag in layer_name and layer_name not in skip_norm_list: + if i < len(ln_linear_list) - 1: + if not fused_qkv: + if linear_flag in ln_linear_list[i + + 1] and linear_flag in ln_linear_list[i + 2] and linear_flag in ln_linear_list[i + 3] and int( + layer_name.split('_') + [-1]) % 2 == 0: + ln_linear_dict[layer_name] = [ + ln_linear_list[i + 1], ln_linear_list[i + 2], + ln_linear_list[i + 3] + ] + linear_ln_dict[ln_linear_list[i + 1]] = layer_name + linear_ln_dict[ln_linear_list[i + 2]] = layer_name + linear_ln_dict[ln_linear_list[i + 3]] = layer_name + if linear_flag in ln_linear_list[i + 1] and int( + layer_name.split('_')[-1]) % 2 != 0: + if llama_ffn: + ln_linear_dict[layer_name] = [ + ln_linear_list[i + 1], ln_linear_list[i + 2] + ] + linear_ln_dict[ln_linear_list[i + 1]] = layer_name + linear_ln_dict[ln_linear_list[i + 2]] = layer_name + else: + ln_linear_dict[layer_name] = [ln_linear_list[i + 1]] + linear_ln_dict[ln_linear_list[i + 1]] = layer_name + else: + if linear_flag in ln_linear_list[i + 1]: + ln_linear_dict[layer_name] = [ln_linear_list[i + 1]] + linear_ln_dict[ln_linear_list[i + 1]] = layer_name + return ln_linear_dict, linear_ln_dict diff --git a/paddleslim/quant/advanced/utils_layers.py b/paddleslim/quant/advanced/utils_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3b06a3ede83a3ef3892026d2ef87c0284e8c320c --- /dev/null +++ b/paddleslim/quant/advanced/utils_layers.py @@ -0,0 +1,92 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License" +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +from paddle import ParamAttr +from paddle.nn.initializer import Constant +from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear + +__all__ = ['ShiftSmoothHelpLayer', 'WOBiasHelpLayer'] + + +class ShiftSmoothHelpLayer(nn.Layer): + def __init__(self, layer): + super(ShiftSmoothHelpLayer, self).__init__() + self.weight = layer.weight + shift_shape = self.weight.shape[0] + if hasattr(layer, "bias") or layer.bias is None: + self.bias = paddle.create_parameter( + shape=[self.weight.shape[1]], + dtype=self.weight.dtype, + default_initializer=paddle.nn.initializer.Constant(0.0), + is_bias=True, ) + layer.bias = self.bias + self.layer = layer + self.layer_type = type(layer) + # add + self.shift_bias = self.create_parameter( + shape=[shift_shape], + attr=ParamAttr(initializer=Constant(value=0.)), + dtype=self.weight.dtype) + # multiply + self.smooth_weight = self.create_parameter( + shape=[shift_shape], + attr=ParamAttr(initializer=Constant(value=1.)), + dtype=self.weight.dtype) + + def forward(self, input): + shift_input = input + shift_input = paddle.add(shift_input, self.shift_bias) + smooth_input = paddle.multiply(shift_input, self.smooth_weight) + return self.layer(smooth_input) + + def convert_weight(self, shift_bias=None, smooth_weight=None): + # shift + if shift_bias is not None: + shift_bias = shift_bias.cast(self.weight.dtype) + self.shift_bias.set_value(-shift_bias) + shift_linear_bias = paddle.matmul(shift_bias, self.weight) + + if self.layer_type == RowParallelLinear: + parallel_shift_linear_bias = paddle.distributed.collective._mp_allreduce( + shift_linear_bias, + group=self.layer.model_parallel_group, + use_calc_stream=True, + use_model_parallel=True) + self.bias.set_value(self.bias + parallel_shift_linear_bias) + else: + self.bias.set_value(self.bias + shift_linear_bias) + + # smooth + if smooth_weight is not None: + self.smooth_weight.set_value( + 1 / smooth_weight.squeeze().cast(self.weight.dtype)) + self.weight.set_value( + self.weight * smooth_weight.transpose(perm=[1, 0])) + + +class WOBiasHelpLayer(nn.Layer): + def __init__(self, layer): + super(WOBiasHelpLayer, self).__init__() + self.weight = layer.weight + self.bias = paddle.create_parameter( + shape=self.weight.shape, + dtype=self.weight.dtype, + default_initializer=paddle.nn.initializer.Constant(0.0), + is_bias=True, ) + self.layer = layer + + def forward(self, input): + return self.layer(input) + self.bias diff --git a/paddleslim/quant/observers/abs_max_weight.py b/paddleslim/quant/observers/abs_max_weight.py index 5594ab1cacd0a31de7a844bfc21616a6dcad5e84..021ce57759d51f16b605c766baaad5086293cdd6 100644 --- a/paddleslim/quant/observers/abs_max_weight.py +++ b/paddleslim/quant/observers/abs_max_weight.py @@ -52,6 +52,7 @@ class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver): self.quant_bits = quant_bits self.calibration_loss = float('inf') self.qmin, self.qmax = self.qmin_qmax + self._layer = layer self._max = None self._scale = None self._zero_point = None @@ -64,26 +65,11 @@ class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver): def _cal_abs_max(self, inputs): reduce_axis = tuple( [i for i in range(len(inputs.shape)) if i != self.quant_axis()]) - abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis) + abs_max_values = paddle.max( + paddle.abs(inputs), axis=reduce_axis).cast("float32") abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values) - minimum_loss = paddle.full(abs_max_values.shape, float('inf')) - result = abs_max_values - factor = 0.3 - while factor <= 1.0: - scales = factor * abs_max_values - factor += 0.02 - expand_scales = paddle.unsqueeze(scales, axis=reduce_axis) - quant_var = paddle.clip( - paddle.round(inputs / expand_scales * self.qmax), self.qmin, - self.qmax) - quant_dequant_var = quant_var / self.qmax * expand_scales - - mse_loss = ((inputs - quant_dequant_var)**2).mean(axis=reduce_axis) - result = paddle.where(mse_loss < minimum_loss, scales, result) - minimum_loss = paddle.minimum(mse_loss, minimum_loss) - - return result + return abs_max_values def min_value(self) -> float: return 0. diff --git a/paddleslim/quant/observers/mse_weight.py b/paddleslim/quant/observers/mse_weight.py index 16aa272680886d653e5af779ce27b8299826ee56..2a17242a799be326b7c2901c7f409d30cd097eba 100644 --- a/paddleslim/quant/observers/mse_weight.py +++ b/paddleslim/quant/observers/mse_weight.py @@ -54,4 +54,20 @@ class MSEChannelWiseWeightObserverLayer(AbsMaxChannelWiseWeightObserverLayer): abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis) abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values) - return abs_max_values + minimum_loss = paddle.full(abs_max_values.shape, float('inf')) + result = abs_max_values + factor = 0.3 + while factor <= 1.0: + scales = factor * abs_max_values + factor += 0.02 + expand_scales = paddle.unsqueeze(scales, axis=reduce_axis) + quant_var = paddle.clip( + paddle.round(inputs / expand_scales * self.qmax), self.qmin, + self.qmax) + quant_dequant_var = quant_var / self.qmax * expand_scales + + mse_loss = ((inputs - quant_dequant_var)**2).mean(axis=reduce_axis) + result = paddle.where(mse_loss < minimum_loss, scales, result) + minimum_loss = paddle.minimum(mse_loss, minimum_loss) + + return result