Open Advanced Quant Method (#1779)
# Post-Training Quantization

Post-training quantization (also known as offline quantization) needs only a small amount of calibration data to determine quantization parameters that minimize quantization error. It requires much less data than quantization-aware training, but the quantized model is usually slightly less accurate.

This tutorial uses the image classification model MobileNetV1 to show how to quickly use the [PaddleSlim quantization API]().

The example consists of the following steps:

1. Import dependencies
2. Build the model and dataset
3. Pretrain the model and export the FP32 inference model
4. Apply post-training quantization and export the quantized model

The sections below cover each step in turn.
## 1. Import dependencies

Follow the PaddleSlim installation guide to install matching versions of Paddle and PaddleSlim, then import them as follows:

```python
import paddle
import paddleslim
import paddle.vision.models as models
from paddle.static import InputSpec as Input
from paddle.vision.datasets import Cifar10
import paddle.vision.transforms as T
```
## 2. Build the model and dataset

This section builds a classifier for the CIFAR-10 dataset. We use `MobileNetV1` with input size `[3, 32, 32]` and 10 output classes.
For brevity, we use the predefined [MobileNetV1 model](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/vision/models/mobilenetv1/MobileNetV1_cn.html#mobilenetv1) from Paddle's high-level API.
Call `model.prepare` to configure the components the model needs, such as the optimizer, loss function, and metrics; see the [API documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/hapi/model/Model_cn.html#prepare-optimizer-none-loss-function-none-metrics-none) for details.

```python
net = models.mobilenet_v1(pretrained=False, scale=1.0, num_classes=10)
inputs = [Input([None, 3, 32, 32], 'float32', name='image')]
labels = [Input([None, 1], 'int64', name='label')]
optimizer = paddle.optimizer.Momentum(
    learning_rate=0.1,
    parameters=net.parameters())
model = paddle.Model(net, inputs, labels)
model.prepare(
    optimizer,
    paddle.nn.CrossEntropyLoss(),
    paddle.metric.Accuracy(topk=(1, 5)))
transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
train_dataset = Cifar10(mode='train', backend='cv2', transform=transform)
val_dataset = Cifar10(mode='test', backend='cv2', transform=transform)
```
## 3. Pretrain the model

Pretrain the model to prepare it for quantization. Run the following code:

```python
model.fit(train_dataset, epochs=5, batch_size=256, verbose=1)
model.evaluate(val_dataset, batch_size=256, verbose=1)
```

After training completes, export the inference model:

```python
paddle.jit.save(net, "./fp32_inference_model", input_spec=inputs)
```
## 4. Post-training quantization

Call the PaddleSlim API to convert the original model into a post-training quantized model; the exported model can be used directly for inference deployment:

```python
paddle.enable_static()
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
paddleslim.quant.quant_post_static(
    executor=exe,
    model_dir='./',
    model_filename='fp32_inference_model.pdmodel',
    params_filename='fp32_inference_model.pdiparams',
    quantize_model_path='./quant_post_static_model',
    sample_generator=paddle.dataset.cifar.test10(),
    batch_nums=10)
```
Note that post-training quantization does not yet support models containing control-flow OPs.

Depending on the deployment scenario, the quantized model can be deployed to mobile devices (ARM CPU) with PaddleLite, or to servers (NV GPU and Intel CPU) with PaddleInference.

Compared with the original FP32 model, the exported quantized model is about the same size, because the weights in the quantized inference model are still stored as FP32. The model only actually shrinks after it is converted with the PaddleLite opt tool during deployment.

Deployment references:

* Deployment [overview](https://paddleslim.readthedocs.io/zh_CN/latest/deploy/index.html)
* Deploying quantized models with PaddleLite: [docs](https://paddle-lite.readthedocs.io/zh/latest/user_guides/quant_aware.html)
* Deploying quantized models with PaddleInference on Intel CPU: [docs](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_x86_cpu_int8.html)
* Deploying quantized models with PaddleInference on NV GPU: [docs](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html)
# Advanced Quantization Strategies

In recent years, Transformer models have been widely adopted across domains, and generative large language models in particular have greatly advanced the field of AI. These models have grown from hundreds of millions to hundreds of billions of parameters, and running them with limited data and GPU resources has become increasingly challenging. Compression techniques therefore matter more than ever, and quantization has become the general, dominant approach to reducing memory footprint and compute overhead. However, many studies have shown that Transformer models tend to contain strong outlier activations, which makes them hard to quantize. Keeping acceptable accuracy in the presence of these outliers requires higher activation bit widths, different number formats, extra fine-tuning, or other workarounds. This document introduces several state-of-the-art strategies for improving quantization quality, including open-source work as well as methods developed in PaddleSlim. The methods below currently support Transformer models only. For concrete end-to-end usage, see the [PaddleNLP LLM examples](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/causallm); this tutorial covers only the APIs in detail.
## 1. Shift

The Shift algorithm comes from [Outlier Suppression+](https://arxiv.org/abs/2304.09145). It shifts outliers through a mathematically equivalent rewrite of the Layer Norm and Linear biases, which makes the activation distribution symmetric and helps post-training quantization accuracy. In PaddleSlim's implementation, a Linear preceded by a Layer Norm is shifted by rewriting the biases in a mathematically equivalent way; a Linear without a preceding Layer Norm is handled by inserting helper nodes, controlled by the `shift_all_linears` argument. In addition, PaddleSlim's Shift accepts a `sample_function`; when `sample_function` is None, Shift exactly follows the paper [Outlier Suppression+](https://arxiv.org/abs/2304.09145).
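The bias rewrite is an exact identity. A minimal sketch (made-up shapes; `z` plays the role of the per-channel shift) checking that shifting the input and folding `zW` into the Linear bias leaves the output unchanged:

```python
import paddle

paddle.seed(0)
x = paddle.randn([4, 8])       # activations
W = paddle.randn([8, 16])      # Linear weight
b = paddle.randn([16])         # Linear bias
z = x.mean(axis=0)             # per-channel shift (zero point)

y_ref = paddle.matmul(x, W) + b                                # original
y_shift = paddle.matmul(x - z, W) + (b + paddle.matmul(z, W))  # shifted
print(float((y_ref - y_shift).abs().max()))  # ~0 up to float error
```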
| **Argument** | **Type** | **Description** |
|-----------------------------|-----------------------------------------|-----------------------------------------|
| model | paddle.nn.Layer | Required. The dygraph model to shift |
| model_config | dict | Required. The model-structure config |
| shift_all_linears | bool | Optional, default False. If True, shift every Linear in the model; if False, only shift Linears that follow a Layer Norm |
| sample_function | function | Optional, default None. If None, sampling follows the paper. Available sample functions are MultiStepSampler and EMASampler; EMASampler is recommended for Shift |
A minimal usage example:

```python
from paddleslim.quant.advanced import Shift, EMASampler

model = LLM()
model_config = {}
shift = Shift(model, model_config, sample_function=EMASampler())
for data in dataloader():
    model(data)
    shift.step += 1
shift.update_weight()
```
## 2. Smooth

The Smooth algorithm comes from [SmoothQuant](https://arxiv.org/abs/2211.10438). It scales outliers through a mathematically equivalent rewrite of the Layer Norm and Linear weights and biases, which effectively reduces outliers in the activations and helps post-training quantization accuracy. As with Shift, PaddleSlim folds the scaling into the weights and biases of a Linear preceded by a Layer Norm, and inserts helper nodes for a Linear without one, controlled by the `smooth_all_linears` argument. PaddleSlim's Smooth also supports scale search; the search feature is documented below.
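For reference, a minimal sketch of the SmoothQuant scale rule `s = act_absmax**alpha / weight_absmax**(1 - alpha)` with made-up numbers; dividing activations by `s` per channel and multiplying the matching weight rows by `s` keeps `x @ W` unchanged while flattening the activation outliers:

```python
import numpy as np

alpha = 0.5
act_abs_max = np.array([30.0, 0.5, 2.0])  # channel 0 is an outlier
w_abs_max = np.array([0.2, 0.4, 0.3])     # per-input-channel weight abs-max
s = (np.power(act_abs_max, alpha) /
     np.power(w_abs_max, 1.0 - alpha)).clip(min=1e-5)
# x' = x / s (per channel) and W' = s[:, None] * W give x' @ W' == x @ W,
# but max(|x'|) is far smaller, so the activations quantize better.
print(s)
```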
| **Argument** | **Type** | **Description** |
|-----------------------------|-----------------------------------------|-----------------------------------------|
| model | paddle.nn.Layer | Required. The dygraph model to smooth |
| model_config | dict | Required. The model-structure config |
| alpha | float | Optional, default 0.5 |
| smooth_all_linears | bool | Optional, default False. If True, smooth every Linear in the model; if False, only smooth Linears that follow a Layer Norm |
| sample_function | function | Optional, default None. If None, a single batch of data is sampled. Available sample functions are MultiStepSampler and EMASampler; MultiStepSampler is recommended for Smooth |
| search_function | function | Optional, default None. If None, no search is performed and the smooth method matches the original paper |
A minimal usage example:

```python
from paddleslim.quant.advanced import Smooth, MultiStepSampler

model = LLM()
model_config = {}
smooth = Smooth(model, model_config, sample_function=MultiStepSampler())
for data in dataloader():
    model(data)
    smooth.step += 1
smooth.update_weight()
```
Note: shift and smooth are mathematically equivalent transformations in theory, but in engineering practice the concrete output values may differ slightly, so accuracy before and after shift/smooth should be roughly the same. A large accuracy gap means the model was not parsed successfully. Fill in `model_config` according to your actual model; it drives the parsing of the model structure, and a wrong config will produce wrong accuracy. It contains the following fields (an example config follows the list):
- `fused_qkv`: whether the model uses a fused QKV projection; default True
- `linear_flag`: the name fragment used to identify Linear layers; default `linear`
- `norm_flag`: the name fragment used to identify Layer Norm layers; default `norm`
- `parallel_ffn`: whether the model has parallel FFN branches; default False
- `skip_norm_list`: name fragments of Layer Norms to ignore; default empty list
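For instance, a hypothetical config for a LLaMA-like model with separate Q/K/V projections and a parallel FFN could look like this (the flag values must match the sublayer names of your actual model):

```python
model_config = {
    "fused_qkv": False,       # separate q/k/v projection Linears
    "linear_flag": "linear",  # substring identifying Linear layers
    "norm_flag": "norm",      # substring identifying Layer Norm layers
    "parallel_ffn": True,     # parallel FFN branches
    "skip_norm_list": [],     # norms excluded from LN/Linear pairing
}
```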
If the model contains a PostLayerNorm shortcut structure, smooth and shift are not supported for it. For example, the [ChatGLM architecture](https://github.com/PaddlePaddle/PaddleNLP/blob/64f97979a62fba6a35a1177850cc22dbc91fade0/paddlenlp/transformers/chatglm/modeling.py#L360) in PaddleNLP contains a PostLayerNorm shortcut, so shift/smooth cannot be applied to that model.
## 3. PieceWiseSearch

[SmoothQuant](https://arxiv.org/abs/2211.10438) does reduce outliers effectively, but extensive experiments showed that in some cases, for example when weight values are large, and especially when weight and activation outliers fall in the same channel, computing the smooth scale directly with the SmoothQuant formula makes the weights hard to quantize. Moreover, when an activation tensor has many outliers spread over a wide numeric range, smoothing the whole tensor with a single alpha is also unreasonable. PaddleSlim therefore provides a piecewise search: the activation channels are split into K pieces by magnitude, and alpha and scale are searched separately for each piece.
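A minimal sketch of the piece-splitting idea with made-up data: cluster the per-channel activation abs-max values with k-means, then run the alpha/scale search within each cluster instead of once over the whole tensor:

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
act_abs_max = np.abs(rng.normal(size=64))  # per-channel activation abs-max
act_abs_max[:4] *= 50.0                    # a few outlier channels

km = KMeans(n_clusters=3, n_init=10).fit(act_abs_max.reshape(-1, 1))
for k in range(3):
    piece = act_abs_max[km.labels_ == k]
    # each piece gets its own alpha/scale search in PieceWiseSearch
    print(f"piece {k}: {piece.size} channels, "
          f"range [{piece.min():.2f}, {piece.max():.2f}]")
```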
| **Argument** | **Type** | **Description** |
|-----------------------------|-----------------------------------------|-----------------------------------------|
| k_piece | int | Optional. Number of pieces; default 1, meaning no splitting |
| bits_length | int | Optional. Quantization bit width; default 8 |
| search_piece | bool | Optional. Whether to search the number of pieces k; default False. If True, a suitable k is searched from 1 up to k_piece |
| search_alpha_min | float | Optional. Minimum alpha for the search; default 0.2 |
| search_alpha_max | float | Optional. Maximum alpha for the search; default 0.8 |
| search_scale_min | float | Optional. Minimum scale for the search; default 1. |
| search_scale_max | float | Optional. Maximum scale for the search; default 5. |
| weight_quant_method | str | Optional. Weight quantization method, one of `abs_max`, `abs_max_channel_wise`, `avg`; default `abs_max_channel_wise` |
| act_quant_method | str | Optional. Activation quantization method, one of `abs_max`, `avg`; default `abs_max` |
| loss_function | function | Optional. Loss function used during the search; default mse_loss |
```python
from paddleslim.quant.advanced import Smooth, MultiStepSampler, PieceWiseSearch, mse_loss

search_func = PieceWiseSearch(
    k_piece=3,
    bits_length=8,
    search_piece=False,
    search_alpha_min=0.2,
    search_alpha_max=0.8,
    search_scale_min=1.,
    search_scale_max=5.,
    weight_quant_method='abs_max_channel_wise',
    act_quant_method='abs_max',
    loss_function=mse_loss)
model = LLM()
model_config = {}
smooth = Smooth(model, model_config, sample_function=MultiStepSampler(), search_function=search_func)
for data in dataloader():
    model(data)
    smooth.step += 1
smooth.update_weight()
```
## 4. GPTQ

The GPTQ algorithm comes from [GPTQ](https://arxiv.org/abs/2210.17323). It quantizes the weight matrix incrementally, one block of columns at a time, and uses Hessian information to keep updating the not-yet-quantized weights; it performs well for low-bit weight-only INT4 quantization. By default GPTQ is combined with [RPTQ](https://arxiv.org/abs/2304.01089); if you do not want RPTQ, pass `actorder=False` when calling `fasterquant`.
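The core loop can be sketched in a few lines (toy numbers; an identity-like stand-in replaces the inverse-Hessian Cholesky factor): quantize one coordinate, then push its rounding error onto the not-yet-quantized coordinates, weighted by the inverse-Hessian row, so that later coordinates compensate:

```python
import numpy as np

w = np.array([0.7, -1.3, 0.4])   # one weight row (toy values)
Hinv = np.eye(3) * 0.9 + 0.05    # stand-in for the Cholesky factor of H^-1
scale = np.abs(w).max() / 7      # 4-bit symmetric: integer levels in [-8, 7]

for i in range(len(w)):
    q = np.clip(np.round(w[i] / scale), -8, 7) * scale
    err = (w[i] - q) / Hinv[i, i]
    w[i + 1:] -= err * Hinv[i, i + 1:]  # error feedback onto later weights
    w[i] = q
print(w)  # quantized row; later entries absorbed earlier rounding error
```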
| **Argument** | **Type** | **Description** |
|-----------------------------|-----------------------------------------|-----------------------------------------|
| layer | paddle.nn.Layer | Required. The layer to quantize; currently only nn.Linear, ColumnParallelLinear, and RowParallelLinear are supported |
| quant_bits | int | Optional. Quantization bit width; default 4 |
| weight_quant_method | str | Optional. Weight quantization method, one of `abs_max`, `abs_max_channel_wise`, `avg`; default `abs_max_channel_wise` |
```python
from paddleslim.quant.advanced import GPTQ

model = LLM()
for cur_name, cur_layer in model.named_sublayers():
    if type(cur_layer) == paddle.nn.Linear:
        gptq_layer = GPTQ(cur_layer)
        # sample data
        for data in dataloader():
            model(data)
        # quant weight
        gptq_layer.fasterquant(actorder=True)
```
## 5. LayerWiseQuantError

LayerWiseQuantError analyzes quantization loss at the layer level: for each layer in the model it applies quantization and computes the error between the quantized layer's output and the original model's output.
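A minimal sketch of the per-layer error being measured (made-up tensors): fake-quantize the weight and the activation, run the matmul, and compare against the FP32 result:

```python
import paddle

def fake_quant(t, bits=8):
    # symmetric abs-max quantize-dequantize
    bnt = (1 << (bits - 1)) - 1
    scale = t.abs().max()
    q = paddle.clip(paddle.round(t / scale * bnt), -bnt - 1, bnt)
    return q / bnt * scale

x = paddle.randn([4, 8])
w = paddle.randn([8, 16])
origin_out = paddle.matmul(x, w)
quant_out = paddle.matmul(fake_quant(x), fake_quant(w))
print(float(((origin_out - quant_out) ** 2).mean()))  # per-layer MSE
```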
| **Argument** | **Type** | **Description** |
|-----------------------------|-----------------------------------------|-----------------------------------------|
| layer | paddle.nn.Layer | Required. The layer to quantize; currently only nn.Linear, ColumnParallelLinear, and RowParallelLinear are supported |
| weight_bits | int | Optional. Weight quantization bit width; default 8 |
| act_bits | int | Optional. Activation quantization bit width; default 8 |
| weight_quant_method | str | Optional. Weight quantization method, one of `abs_max`, `abs_max_channel_wise`, `avg`; default `abs_max_channel_wise` |
| act_quant_method | str | Optional. Activation quantization method, one of `abs_max`, `avg`; default `abs_max` |
| loss_function | function | Optional. Loss function; default mse_loss |
```python
from paddleslim.quant.advanced import LayerWiseQuantError

model = LLM()
for cur_name, cur_layer in model.named_sublayers():
    if type(cur_layer) == paddle.nn.Linear:
        quant_error_layer = LayerWiseQuantError(cur_layer)
        # install the wrapper in place of the original layer (e.g. via
        # setattr on its parent) so it runs during the forward passes below
for data in dataloader():
    model(data)
for cur_name, cur_layer in model.named_sublayers():
    if type(cur_layer) == LayerWiseQuantError:
        print(cur_name, cur_layer.losses.mean())
```
../../../quick_start/dygraph/dygraph_quant_post_tutorial.md
Dynamic Graph
==============

.. toctree::
   :maxdepth: 1

   quant_aware_training_tutorial.md
# Post-Training Quantization

Post-training quantization (also known as offline quantization) needs only a small amount of calibration data to determine quantization parameters that minimize quantization error. It requires little data, but the quantized model's accuracy is slightly lower than with quantization-aware training.

## Usage

The basic post-training quantization flow consists of three steps:

1. Choose a quantization config
2. Sample data and collect quantization statistics
3. Convert to the quantized model
## API reference

### 1. Quantization-config concepts and APIs

`Observer`: collects statistics of an OP's inputs or outputs and computes quantization statistics such as the scale and zero_point. Each post-training quantization algorithm corresponds to one Observer (a sketch contrasting two of the underlying strategies follows the list below). The available Observers are:

- `AVGObserver`: uses the average value of the target tensor as the quantization scale
- `MSEObserver`: collects the max absolute value and refines the scale by minimizing the MSE error
- `EMDObserver`: collects the max absolute value and refines the scale by minimizing the EMD error
- `HistObserver`: collects tensor values into a histogram and computes the scale from a percentile
- `KLObserver`: computes the scale by minimizing the Kullback-Leibler divergence between the float distribution and the quantized distribution
- `AbsMaxChannelWiseWeightObserver`: uses the per-channel max absolute value of the target weight as the scale
- `MSEChannelWiseWeightObserver`: collects the per-channel max absolute value of the target weight and refines the scale by minimizing the MSE error
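As a minimal sketch (not the Observer API itself), the difference between an abs-max scale and an MSE-searched scale can be seen in a few lines; the shrink-factor scan mirrors the search used by the MSE observers:

```python
import paddle

t = paddle.randn([4096])               # tensor to calibrate
abs_max_scale = float(t.abs().max())   # plain abs_max scale

# scan shrink factors, keep the scale with the lowest quant-dequant MSE
best_scale, best_loss = abs_max_scale, float('inf')
factor = 0.3
while factor <= 1.0:
    scale = factor * abs_max_scale
    qdq = paddle.clip(paddle.round(t / scale * 127), -128, 127) / 127 * scale
    loss = float(((t - qdq) ** 2).mean())
    if loss < best_loss:
        best_scale, best_loss = scale, loss
    factor += 0.02
print(abs_max_scale, best_scale)  # the MSE scale is usually smaller
```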
`Quanter`: performs real or simulated quantization on an OP's inputs or outputs, and can also collect statistics of the input tensor. Each quantization-aware training algorithm corresponds to one Quanter. The available Quanters are:

- `PACTQuanter`
- `WeightLSQplusQuanter`
- `ActLSQplusQuanter`
`QuantConfig`: before quantization you must configure the quantization details, mainly which Observer or Quanter each layer's inputs use. Per-layer configs can be added with the following methods:

| **QuantConfig method** | **Arguments** | **Notes** |
|-----------------------------|-----------------------------------------|-----------------------------------------|
| add_layer_config | `layer`: a layer of the model, or a list of layers<br><br>`activation`: the `Observer` or `Quanter` used to quantize the activations of the given layers<br><br>`weight`: the `Observer` or `Quanter` used to quantize the weights of the given layers | Highest priority: these layers are quantized as specified here rather than by any other config
| add_name_config | `layer_name`: the name of a layer, or a list of layer names<br><br>`activation`: the `Observer` or `Quanter` used to quantize the activations of the given layers<br><br>`weight`: the `Observer` or `Quanter` used to quantize the weights of the given layers | Priority is second only to add_layer_config
| add_type_config | `layer_type`: the layer type(s) to quantize, a single layer type or a list; each must be a subclass of paddle.nn.Layer<br><br>`activation`: the `Observer` or `Quanter` used to quantize the activations of the given layers<br><br>`weight`: the `Observer` or `Quanter` used to quantize the weights of the given layers | Priority is below add_name_config. For example, passing nn.Linear quantizes every nn.Linear in the model with the given weight and activation quanters
| add_qat_layer_mapping | `source`: the layer to be quantized<br><br>`target`: the quantized layer | source and target must be subclasses of paddle.nn.Layer. When a layer type has no built-in quantized implementation in the framework, you must specify its quantized counterpart, e.g. ColumnParallelLinear maps to PaddleSlim's QuantizedColumnParallelLinear
### 2. PTQ APIs

| **PTQ method** | **Arguments** | **Description** |
|-----------------------------|-----------------------------------------|-----------------------------------------|
| quantize | `model`: the model to quantize<br>`inplace`: if True, the model is quantized in place; if False, the original model is unchanged and a quantized model is returned | Inserts observers into the layers to be quantized so that the needed quantization statistics can be sampled
| convert | `model`: the quantized model to convert<br>`inplace`: if True, the model is converted in place; if False, the original model is unchanged and a converted model is returned | Converts the model into ONNX-style form; only after this step can the quantized model be evaluated, exported to a static graph, etc.
## Usage example

```python
import paddle
import paddleslim
from paddle.vision.models import mobilenet_v1
from paddle.quantization import QuantConfig
from paddle.quantization import PTQ
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver, MSEChannelWiseWeightObserver, AbsMaxChannelWiseWeightObserver

# create the model
model = mobilenet_v1()
# define QuantConfig
q_config = QuantConfig(activation=None, weight=None)
# define act_quanter and weight_quanter
act_quanter = MSEObserver()
weight_quanter = MSEObserver()
# map ColumnParallelLinear to QuantizedColumnParallelLinear
# (QuantizedColumnParallelLinear / QuantizedRowParallelLinear come from
# PaddleSlim; the exact import path depends on your PaddleSlim version)
q_config.add_qat_layer_mapping(ColumnParallelLinear,
                               QuantizedColumnParallelLinear)
# map RowParallelLinear to QuantizedRowParallelLinear
q_config.add_qat_layer_mapping(RowParallelLinear,
                               QuantizedRowParallelLinear)
# for every layer whose type is paddle.nn.Linear, ColumnParallelLinear
# or RowParallelLinear, quantize with the quanters defined above
q_config.add_type_config(
    [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear],
    activation=act_quanter,
    weight=weight_quanter,
)
ptq = PTQ(q_config)
model = ptq.quantize(model, inplace=True)
# ptq sampling
ptq_step = 100
for step, data in enumerate(dataloader):
    pred = model(data)
    if step == ptq_step:
        break
# convert to a quant model that can be evaluated and exported
model = ptq.convert(model, inplace=True)
```
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import gptq
from . import smooth
from . import shift
from . import piecewise_search
from . import sample
from . import layerwise_quant_error
from . import utils_layers
from .gptq import *
from .smooth import *
from .shift import *
from .piecewise_search import *
from .sample import *
from .layerwise_quant_error import *
from .utils_layers import *
__all__ = []
__all__ += gptq.__all__
__all__ += smooth.__all__
__all__ += shift.__all__
__all__ += piecewise_search.__all__
__all__ += sample.__all__
__all__ += layerwise_quant_error.__all__
__all__ += utils_layers.__all__
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import time
import numpy as np
import paddle
import paddle.nn as nn
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
from .utils import compute_scales
__all__ = ['GPTQ']
class GPTQ(nn.Layer):
def __init__(self,
layer,
quant_bits=4,
weight_quant_method='abs_max_channel_wise'):
'''
The implementation of GPTQ(https://arxiv.org/abs/2210.17323).
The codes here are based on https://github.com/IST-DASLab/gptq.
Args:
layer (paddle.nn.Layer): Layer object.
quant_bits (int, optional): Number of bits to quantize the weight. Default: 4.
            weight_quant_method (str, optional): Method of weight quantization. Chosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise.
Examples:
.. code-block:: python
from paddleslim.quant.advanced import GPTQ
for cur_name, cur_layer in model.named_sublayers():
if type(cur_layer) == paddle.nn.Linear:
gptq_layer = GPTQ(cur_layer)
# sample data
for data in dataloader():
model(data)
# quant weight
gptq_layer.fasterquant()
'''
super(GPTQ, self).__init__()
self.layer = layer
assert hasattr(layer,
'weight'), "Layer {} has no attribute 'weight'".format(
layer.full_name())
assert type(self.layer) in [
nn.Linear, ColumnParallelLinear, RowParallelLinear
], "Currently, GPTQ only supports linear layer and ColumnParallelLinear/RowParallelLinear layer"
weight = layer.weight.t()
self.rows = weight.shape[0]
self.columns = weight.shape[1]
self.hessian = paddle.zeros(
(self.columns, self.columns), dtype='float32')
self.nsamples = 0
self.quantized = False
self.weight_quant_method = weight_quant_method
self.quant_bits = (1 << (quant_bits - 1)) - 1
def forward(self, input):
if not self.quantized:
inp = input[0] if type(input) == tuple else input
inp = inp.cast('float32')
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if type(self.layer) in [
nn.Linear, ColumnParallelLinear, RowParallelLinear
]:
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
self.hessian *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
inp = math.sqrt(2 / self.nsamples) * inp
self.hessian += paddle.matmul(inp, inp.t())
del inp
return self.layer(input)
def fasterquant(self,
blocksize=128,
percdamp=.01,
groupsize=-1,
actorder=True):
print('quant', self.layer.full_name())
W = self.layer.weight.t().cast('float32')
weight_scale = compute_scales(W.t(), method=self.weight_quant_method)
weight_scale /= self.quant_bits
tick = time.time()
H = self.hessian
del self.hessian
dead = paddle.where(paddle.diag(H) == 0)
H[dead, dead] = 1
W[:, dead] = 0
del dead
if actorder:
perm = paddle.argsort(paddle.diag(H), descending=True)
W = W.transpose((1, 0))
W = W[perm].transpose((1, 0))
H = H[perm].transpose((1, 0))
H = H[perm].transpose((1, 0))
Losses = paddle.zeros_like(W)
Q = paddle.zeros_like(W)
damp = percdamp * paddle.mean(paddle.diag(H))
diag = paddle.arange(self.columns)
H[diag, diag] += damp
H = paddle.inverse(H)
H = paddle.linalg.cholesky(H, upper=True)
Hinv = H
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2]
Q1 = paddle.zeros_like(W1)
Err1 = paddle.zeros_like(W1)
Losses1 = paddle.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if groupsize != -1:
if (i1 + i) % groupsize == 0:
weight_scale = compute_scales(
W[:, (i1 + i):(i1 + i + groupsize)].t(),
method=self.weight_quant_method)
weight_scale /= self.quant_bits
q = paddle.clip(
paddle.round(w / weight_scale), -self.quant_bits - 1,
self.quant_bits)
q = q * weight_scale
Q1[:, i] = q
Losses1[:, i] = (w - q)**2 / d**2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
del w, d, q, err1
paddle.device.cuda.empty_cache()
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
del Q1, Losses1
if Hinv[i1:i2, i2:].shape[1] != 0:
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
del Err1, W1, Hinv1
paddle.device.cuda.empty_cache()
print('time %.2f' % (time.time() - tick))
print('error', paddle.sum(Losses).item())
if actorder:
invperm = paddle.argsort(perm)
Q = Q.transpose((1, 0))
Q = Q[invperm].transpose((1, 0))
del invperm, perm
param = self.layer.weight
Q = Q.t().cast(self.layer.weight.dtype)
paddle.assign(Q, output=param)
self.quantized = True
del H, Q, Hinv, W, Losses
paddle.device.cuda.empty_cache()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import numpy as np
from .utils import compute_scales
from .metrics import mse_loss
__all__ = ['LayerWiseQuantError']
class LayerWiseQuantError(nn.Layer):
def __init__(self,
layer,
weight_bits=8,
act_bits=8,
weight_quant_method='abs_max_channel_wise',
act_quant_method='abs_max',
loss_function=mse_loss):
'''
        LayerWiseQuantError computes the loss between the output of the layer and the output of the quantized layer.
Args:
layer (paddle.nn.Layer): Layer object.
            weight_bits (int, optional): Number of bits to quantize the weight. Default: 8.
act_bits (int, optional): Number of bits to quantize the activation. Default: 8.
            weight_quant_method (str, optional): The method of weight quantization. Chosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise.
            act_quant_method (str, optional): The method of activation quantization. Chosen from abs_max, avg. Default: abs_max.
Examples:
.. code-block:: python
            from paddleslim.quant.advanced import LayerWiseQuantError
for cur_name, cur_layer in model.named_sublayers():
if type(cur_layer) == paddle.nn.Linear:
                    quant_error_layer = LayerWiseQuantError(cur_layer)
for data in dataloader():
model(data)
for cur_name, cur_layer in model.named_sublayers():
if type(cur_layer) == LayerWiseQuantError:
print(cur_name, cur_layer.losses.mean())
'''
        super(LayerWiseQuantError, self).__init__()
        self.layer = layer
self.weight = layer.weight
self.weight_bits = weight_bits
self.act_bits = act_bits
self.weight_method = weight_quant_method
self.act_method = act_quant_method
self.loss_function = loss_function
self.losses = []
def forward(self, input):
act = input[0] if type(input) == tuple else input
origin_out = paddle.matmul(act, self.weight)
bnt = (1 << (self.weight_bits - 1)) - 1
quant_scale = compute_scales(
self.weight.cast('float32'),
method=self.weight_method).cast(self.weight.dtype)
quant_weight = paddle.clip(
paddle.round(self.weight / quant_scale * bnt), -bnt - 1, bnt)
quant_dequant_weight = quant_weight / bnt * quant_scale
bnt = (1 << (self.act_bits - 1)) - 1
quant_scale = compute_scales(act, method=self.act_method)
quant_act = paddle.clip(
paddle.round(act / quant_scale * bnt), -bnt - 1, bnt)
quant_dequant_act = quant_act / bnt * quant_scale
quant_out = paddle.matmul(quant_dequant_act, quant_dequant_weight)
loss = self.loss_function(origin_out, quant_out)
self.losses.append(loss)
return self.layer(input)
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def mse_loss(y_pred, y_real, reduction='mean'):
if y_pred.shape != y_real.shape:
raise ValueError(
f'Can not compute mse loss for tensors with different shape. '
f'({y_pred.shape} and {y_real.shape})')
mse = (y_pred - y_real)**2
if reduction == 'mean':
return paddle.mean(mse)
elif reduction == 'sum':
return paddle.sum(mse)
elif reduction == 'none':
return mse
else:
raise ValueError(f'Unsupported reduction method.')
def snr_loss(y_pred, y_real, reduction='mean'):
if y_pred.shape != y_real.shape:
raise ValueError(
f'Can not compute snr loss for tensors with different shape. '
f'({y_pred.shape} and {y_real.shape})')
reduction = str(reduction).lower()
y_pred = y_pred.flatten(start_axis=1)
y_real = y_real.flatten(start_axis=1)
noise_power = paddle.pow(y_pred - y_real, 2).sum(axis=-1)
signal_power = paddle.pow(y_real, 2).sum(axis=-1)
snr = (noise_power) / (signal_power + 1e-7)
if reduction == 'mean':
return paddle.mean(snr)
elif reduction == 'sum':
return paddle.sum(snr)
elif reduction == 'none':
return snr
else:
raise ValueError(f'Unsupported reduction method.')
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from .utils import compute_scales, k_means
from .metrics import mse_loss
__all__ = ['PieceWiseSearch']
class PieceWiseSearch():
def __init__(self,
k_piece=1,
bits_length=8,
search_piece=False,
search_alpha_min=0.2,
search_alpha_max=0.8,
search_scale_min=1.,
search_scale_max=5.,
weight_quant_method='abs_max_channel_wise',
act_quant_method='abs_max',
loss_function=mse_loss):
'''
        PieceWiseSearch searches k_piece, alpha and scale.
Args:
k_piece (int): Number of k pieces. Default: 1.
bits_length (int): Number of bits to quantize the weight. Default: 8.
search_piece (bool): Whether to search the best k piece. Default: False.
search_alpha_min (float): Minimum alpha for search. Default: 0.2.
search_alpha_max (float): Maximum alpha for search. Default: 0.8.
search_scale_min (float): Minimum scale for search. Default: 1.
search_scale_max (float): Maximum scale for search. Default: 5.
            weight_quant_method (str): Weight quantization method. Chosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise.
            act_quant_method (str): Activation quantization method. Chosen from abs_max, avg. Default: abs_max.
loss_function (callable): Loss function. Default: mse_loss.
'''
self.k_piece = k_piece
self.bits_length = bits_length
self.search_piece = search_piece
self.search_alpha_min = search_alpha_min
self.search_alpha_max = search_alpha_max
self.search_scale_min = search_scale_min
self.search_scale_max = search_scale_max
self.weight_quant_method = weight_quant_method
self.act_quant_method = act_quant_method
self.bnt = (1 << (bits_length - 1)) - 1
self.loss_function = loss_function
def search(self, layer_name, sampled_input, act_abs_max, weight):
act = sampled_input
act.stop_gradient = True
print('[smooth search] search input of %s' % layer_name)
origin_out = paddle.matmul(act, weight)
w_abs_max = weight.abs().max(axis=-1, keepdim=True)
rw_abs_max = w_abs_max.reshape(act_abs_max.shape)
np_act_abs_max = np.array(act_abs_max)
np_rw_abs_max = np.array(rw_abs_max)
smooth_scale_out = None
global_loss = float('inf')
best_scale = None
for k_piece in range(1, self.k_piece + 1):
if not self.search_piece:
k_piece = self.k_piece
print('Search {} Piece'.format(k_piece))
centroids, labels = k_means(act_abs_max, k_piece)
piece = ['piece_{}'.format(a) for a in range(len(centroids))]
for i in range(len(centroids)):
# print('search for piece {}; centroids value is {}'.format(
# piece[i], centroids[centroids.argsort()[i]].numpy()))
alpha = self.search_alpha_min
alpha_max = self.search_scale_max
calibration_loss = float('inf')
final_alpha = None
mask_for_search = paddle.where(labels == centroids.argsort()[i],
1., 0.)
mask_for_ones = paddle.where(mask_for_search == 0., 1., 0.)
while alpha <= alpha_max:
if alpha < 1:
alpha += 0.01
if alpha >= self.search_alpha_max:
alpha = 1.
else:
alpha += 0.5
alpha = round(alpha, 2)
if alpha < 1:
s = (np.power(np_act_abs_max, alpha) / np.power(
np_rw_abs_max, 1. - alpha)).clip(min=1e-5)
s = paddle.to_tensor(s, dtype='float32')
smooth_scale = s * mask_for_search
else:
smooth_scale = alpha * mask_for_search
if smooth_scale_out is not None:
mask_for_ones_new = paddle.where(
smooth_scale_out == 0., 1., 0.)
mask_for_ones *= mask_for_ones_new
smooth_scale_ = smooth_scale_out + smooth_scale
smooth_scale_tmp = smooth_scale_ + mask_for_ones
else:
smooth_scale_tmp = smooth_scale + mask_for_ones
new_act = act / smooth_scale_tmp
new_weight = weight * smooth_scale_tmp.reshape(
w_abs_max.shape)
quant_scale = compute_scales(
new_act, method=self.act_quant_method)
quant_act = paddle.clip(
paddle.round(new_act / quant_scale * self.bnt),
-self.bnt - 1, self.bnt)
quant_dequant_act = quant_act / self.bnt * quant_scale
quant_scale = compute_scales(
new_weight, method=self.weight_quant_method)
quant_weight = paddle.clip(
paddle.round(new_weight / quant_scale * self.bnt),
-self.bnt - 1, self.bnt)
quant_dequant_weight = quant_weight / self.bnt * quant_scale
new_out = paddle.matmul(quant_dequant_act,
quant_dequant_weight)
cur_loss = self.loss_function(origin_out, new_out)
if cur_loss <= calibration_loss:
calibration_loss = cur_loss
final_smooth_scale = smooth_scale
final_alpha = alpha
# print("Layer {} Piece {}, loss: {}, alpha : {}".format(
# layer_name, piece[i], float(calibration_loss), final_alpha))
if smooth_scale_out is None:
smooth_scale_out = final_smooth_scale
else:
smooth_scale_out += final_smooth_scale
if cur_loss < global_loss:
global_loss = cur_loss
best_scale = smooth_scale_out
if self.search_piece:
print('Find Better K-Piece {}'.format(k_piece))
if not self.search_piece:
break
return best_scale
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
__all__ = ['MultiStepSampler', 'EMASampler']
class MultiStepSampler():
def __init__(self):
pass
def sample(self, x, sampled_x=None, layer_name=None):
return paddle.concat([x, sampled_x], axis=1)
class EMASampler():
def __init__(self):
self.ema_beta = 0.98
self.ema_step = {}
self.sampled = {}
def sample(self, x, sampled_x=None, layer_name=None):
if layer_name not in self.ema_step:
self.sampled[layer_name] = (1 - self.ema_beta) * x
self.ema_step[layer_name] = 0
return self.sampled[layer_name]
else:
v_ema = self.ema_beta * self.sampled[layer_name] + (
1 - self.ema_beta) * x
self.sampled[layer_name] = v_ema
v_ema_corr = v_ema / float(
(1 - np.power(self.ema_beta, self.ema_step[layer_name] + 1)))
self.ema_step[layer_name] += 1
return v_ema_corr
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .utils import get_ln_linear_info, find_parent_layer_and_sub_name
from .utils_layers import ShiftSmoothHelpLayer, WOBiasHelpLayer
__all__ = ['Shift']
class Shift():
def __init__(self,
model,
model_config,
shift_all_linears=False,
sample_function=None):
'''
Shift is the implementation of Outlier Suppression+(https://arxiv.org/abs/2304.09145).
Currently, Shift only supports linear layer and ColumnParallelLinear/RowParallelLinear layer.
Args:
model(paddle.nn.Layer, required): the model to be shifted
model_config (dict, required): the config of model to be shifted
shift_all_linears (bool, optional): whether to shift all linear layers
sample_function (function, optional): the function to sample data
Examples:
.. code-block:: python
from paddleslim.quant.advanced import Shift
                shift = Shift(model, model_config)
for data in dataloader():
model(data)
shift.step += 1
shift.update_weight()
'''
self.model = model
self.model_config = model_config
self.fused_qkv = model_config.get("fused_qkv", True)
self.linear_flag = model_config.get("linear_flag", 'linear')
self.norm_flag = model_config.get("norm_flag", 'norm')
self.parallel_ffn = model_config.get("parallel_ffn", False)
self.skip_norm_list = model_config.get("skip_norm_list", [])
self.shift_all_linears = shift_all_linears
self.sample_function = sample_function
self.layer_order = []
self.zero_point_dict = {}
self.smooth_scale_dict = {}
        self.global_min_max = {}
self.model.eval()
self.step = 0
self.print_step = 1
self.got_shift_layers = False
self.help_layers_ready = False
self._apply_hook()
def _apply_hook(self):
self._forward_hook_list = []
for _, sub_layer in self.model.named_sublayers():
if self.norm_flag in sub_layer.full_name(
) or self.linear_flag in sub_layer.full_name():
forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
self._forward_pre_hook)
self._forward_hook_list.append(forward_pre_hook_handle)
if type(sub_layer) == ShiftSmoothHelpLayer:
self.help_layers_ready = True
forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
self._forward_pre_hook)
self._forward_hook_list.append(forward_pre_hook_handle)
def _get_shift_layers(self):
self.ln_linear_dict, self.linear_ln_dict = get_ln_linear_info(
self.layer_order, self.norm_flag, self.linear_flag, self.fused_qkv,
self.parallel_ffn, self.skip_norm_list)
assert len(self.ln_linear_dict) > 0, 'No LN/Linear pair found'
for key in self.ln_linear_dict:
print('shift pair LN {} : Linear {}'.format(
key, self.ln_linear_dict[key]))
if self.shift_all_linears:
if not self.help_layers_ready:
rest_linears = [
l for l in self.layer_order
if self.linear_flag in l and l not in self.linear_ln_dict
]
print('Preparing shift layers', rest_linears)
for cur_name, sub_layer in self.model.named_sublayers():
if sub_layer.full_name() in rest_linears:
new_layer = ShiftSmoothHelpLayer(sub_layer)
parent_layer, sub_name = find_parent_layer_and_sub_name(
self.model, cur_name)
setattr(parent_layer, sub_name, new_layer)
forward_pre_hook_handle = new_layer.register_forward_pre_hook(
self._forward_pre_hook)
self._forward_hook_list.append(forward_pre_hook_handle)
self.got_shift_layers = True
def _forward_pre_hook(self, layer, input):
'''
when step 0, forward once and collect shift layers.
        when step > 0, sample the zero point.
'''
if self.step == 0 and layer.full_name() in self.layer_order:
self.step += 1
if self.step == 0:
self.layer_order.append(layer.full_name())
return input
if self.step == 1:
if not self.got_shift_layers:
self._get_shift_layers()
if self.step > 0:
if layer.full_name() in self.linear_ln_dict.keys():
self._sample_zero_point(input,
self.linear_ln_dict[layer.full_name()])
if type(layer) == ShiftSmoothHelpLayer:
self._sample_zero_point(input, layer.full_name())
return input
def _sample_zero_point(self, input, ln_name):
x = input[0] if type(input) == tuple else input
x = x.cast('float32')
x.stop_gradient = True
zero_point = x.mean(axis=(0, 1)) if len(x.shape) > 2 else x.mean(axis=1)
_min = x.min(axis=(0, 1)) if len(x.shape) > 2 else x.min(axis=1)
_max = x.max(axis=(0, 1)) if len(x.shape) > 2 else x.max(axis=1)
        if ln_name not in self.zero_point_dict or ln_name not in self.global_min_max:
if self.sample_function is None:
                self.global_min_max[ln_name] = _min, _max
self.zero_point_dict[ln_name] = (_min + _max) / 2
else:
self.zero_point_dict[ln_name] = zero_point
else:
if self.sample_function is not None:
self.zero_point_dict[ln_name] = self.sample_function.sample(
zero_point, self.zero_point_dict[ln_name], ln_name)
else:
                global_min, global_max = self.global_min_max[ln_name]
global_min = global_min if global_min < _min else _min
global_max = global_max if global_max > _max else _max
                self.global_min_max[ln_name] = global_min, global_max
self.zero_point_dict[ln_name] = (global_min + global_max) / 2
# per step print once
if self.print_step == self.step:
print('[shift] Step [{}]: {}. zero_point min: {}, max: {}'.format(
self.step, ln_name,
round(float(self.zero_point_dict[ln_name].min()), 5),
round(float(self.zero_point_dict[ln_name].max()), 5)))
if ln_name == list(self.linear_ln_dict.values())[-1]:
self.print_step += 1
def update_weight(self):
'''
update weight of smooth layers.
firstly compute s and update linear's weight,
then update LN's weight by corresponding linear and s
'''
# update linear weight
for _, sub_layer in self.model.named_sublayers():
layer_name = sub_layer.full_name()
if layer_name in self.linear_ln_dict:
ln_name = self.linear_ln_dict[layer_name]
shift_bias = None
for param in sub_layer.parameters(include_sublayers=False):
if 'w_0' in param.name:
zero_point = self.zero_point_dict[ln_name].squeeze()
shift_bias = paddle.matmul(zero_point,
param.cast('float32'))
print("[shift] param: {}, zero_point min: {}, max: {}".
format(param.name,
float(zero_point.min()),
float(zero_point.max())))
break
if not hasattr(sub_layer, "bias") or sub_layer.bias is None:
sub_layer.bias = paddle.create_parameter(
shape=shift_bias.shape,
dtype=sub_layer.weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0.0),
is_bias=True, )
for param in sub_layer.parameters(include_sublayers=False):
if 'b_0' in param.name:
shift_bias = shift_bias + param
paddle.assign(
shift_bias.cast(param.dtype), output=param)
print("[shift] update linear bias: {}.".format(
param.name))
break
# update LN weight
for cur_name, sub_layer in self.model.named_sublayers():
layer_name = sub_layer.full_name()
if layer_name in self.ln_linear_dict:
if not hasattr(sub_layer, "bias") or sub_layer.bias is None:
help_layer = WOBiasHelpLayer(sub_layer)
parent_layer, sub_name = find_parent_layer_and_sub_name(
self.model, cur_name)
setattr(parent_layer, sub_name, help_layer)
sub_layer = help_layer
for param in sub_layer.parameters(include_sublayers=False):
if "b_0" in param.name:
zero_point = self.zero_point_dict[layer_name].squeeze()
param_tmp = param - zero_point
paddle.assign(param_tmp.cast(param.dtype), output=param)
print("[shift] update layer_norm bias {}.".format(
param.name))
break
# update ShiftSmoothRowParallelLinear weight
for _, sub_layer in self.model.named_sublayers():
if type(sub_layer) == ShiftSmoothHelpLayer:
layer_name = sub_layer.full_name()
linear_name = sub_layer.layer.full_name()
zero_point = self.zero_point_dict[layer_name].squeeze()
print(
"[shift ShiftSmoothHelpLayer] before param: {}, shift_bias min: {}, max: {}".
format(linear_name,
float(sub_layer.shift_bias.cast("float32").min()),
float(sub_layer.shift_bias.max().cast("float32"))))
sub_layer.convert_weight(shift_bias=zero_point)
print(
"[shift ShiftSmoothHelpLayer] after param: {}, shift_bias min: {}, max: {}".
format(linear_name,
float(sub_layer.shift_bias.cast("float32").min()),
float(sub_layer.shift_bias.max().cast("float32"))))
self._remove_hook()
paddle.device.cuda.empty_cache()
def _remove_hook(self):
for hook in self._forward_hook_list:
hook.remove()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from .utils import get_ln_linear_info, find_parent_layer_and_sub_name
from .utils_layers import ShiftSmoothHelpLayer, WOBiasHelpLayer
__all__ = ['Smooth']
class Smooth():
def __init__(
self,
model,
model_config,
alpha=0.5,
smooth_all_linears=False,
sample_function=None,
search_function=None, ):
'''
Smooth is an updated version of SmoothQuant(https://arxiv.org/abs/2211.10438).
Compared with SmoothQuant, the piecewise smooth algorithm has the following updates:
        It supports a search function to search for the best smooth scale.
        It supports a sample function to sample the smooth scale.
Currently, Smooth only supports linear layer and ColumnParallelLinear/RowParallelLinear layer.
Args:
model(paddle.nn.Layer, required): the model to be smoothed
model_config (dict, required): the config of model to be smoothed
alpha(float, optional): smoothing parameter. Default: 0.5
smooth_all_linears(bool, optional): whether to smooth all linears. Default: False
sample_function(function, optional): the function of sample to sample data. Default: None
search_function(function, optional): the function of search smooth scale. Default: None
Examples:
.. code-block:: python
from paddleslim.quant.advanced import Smooth
                smooth = Smooth(model, model_config)
for data in dataloader():
model(data)
smooth.step += 1
smooth.update_weight()
'''
self.model = model
self.model_config = model_config
self.fused_qkv = model_config.get("fused_qkv", True)
self.linear_flag = model_config.get("linear_flag", 'linear')
self.norm_flag = model_config.get("norm_flag", 'norm')
self.parallel_ffn = model_config.get("parallel_ffn", False)
self.skip_norm_list = model_config.get("skip_norm_list", [])
self.alpha = alpha
self.smooth_all_linears = smooth_all_linears
self.sample_function = sample_function
self.search_function = search_function
self.model.eval()
self.step = 0
self.print_step = 1
self.got_smooth_layers = False
self.help_layers_ready = False
self.layer_order = []
self.scale_dict = {}
self.smooth_scale_dict = {}
self.sampled_inputs = {}
self._apply_hook()
def _apply_hook(self):
self._forward_hook_list = []
for _, sub_layer in self.model.named_sublayers():
if self.norm_flag in sub_layer.full_name(
) or self.linear_flag in sub_layer.full_name():
forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
self._forward_pre_hook)
self._forward_hook_list.append(forward_pre_hook_handle)
if type(sub_layer) == ShiftSmoothHelpLayer:
self.help_layers_ready = True
forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
self._forward_pre_hook)
self._forward_hook_list.append(forward_pre_hook_handle)
def _get_smooth_layers(self):
self.ln_linear_dict, self.linear_ln_dict = get_ln_linear_info(
self.layer_order, self.norm_flag, self.linear_flag, self.fused_qkv,
self.parallel_ffn, self.skip_norm_list)
assert len(self.ln_linear_dict) > 0, 'No LN/Linear pair found'
for key in self.ln_linear_dict:
print('smooth pair LN {} : Linear {}'.format(
key, self.ln_linear_dict[key]))
if self.smooth_all_linears:
if not self.help_layers_ready:
rest_linears = [
l for l in self.layer_order
if self.linear_flag in l and l not in self.linear_ln_dict
]
print('Preparing smooth layers', rest_linears)
for cur_name, sub_layer in self.model.named_sublayers():
if sub_layer.full_name() in rest_linears:
new_layer = ShiftSmoothHelpLayer(sub_layer)
parent_layer, sub_name = find_parent_layer_and_sub_name(
self.model, cur_name)
setattr(parent_layer, sub_name, new_layer)
forward_pre_hook_handle = new_layer.register_forward_pre_hook(
self._forward_pre_hook)
self._forward_hook_list.append(forward_pre_hook_handle)
self.got_smooth_layers = True
def _forward_pre_hook(self, layer, input):
'''
when step 0, forward once and collect smooth layers.
'''
if self.step == 0 and layer.full_name() in self.layer_order:
self.step += 1
if self.step == 0:
self.layer_order.append(layer.full_name())
return input
if self.step == 1:
if not self.got_smooth_layers:
self._get_smooth_layers()
if self.step > 0:
if layer.full_name() in self.linear_ln_dict.keys():
self._sample_scale(input,
self.linear_ln_dict[layer.full_name()])
if type(layer) == ShiftSmoothHelpLayer:
self._sample_scale(input, layer.full_name())
return input
def _sample_scale(self, input, ln_name):
x = input[0] if type(input) == tuple else input
x.stop_gradient = True
x_abs_max = x.abs().max(axis=1, keepdim=True)
x_abs_max = x_abs_max.max(axis=0)
if ln_name not in self.scale_dict:
self.sampled_inputs[ln_name] = x
self.scale_dict[ln_name] = x_abs_max
else:
if self.sample_function is not None:
self.sampled_inputs[ln_name] = self.sample_function.sample(
x, self.sampled_inputs[ln_name], ln_name)
else:
self.sampled_inputs[ln_name] = x
tmp1 = paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0)
self.scale_dict[ln_name] = tmp1.max(axis=0, keepdim=True)
# per step print once
if self.print_step == self.step:
print('[Smooth] Step [{}]: {}. abs_min: {}, abs_max: {}'.format(
self.step, ln_name,
float(self.scale_dict[ln_name].cast("float32").min()),
float(self.scale_dict[ln_name].cast("float32").max())))
if ln_name == list(self.linear_ln_dict.values())[-1]:
self.print_step += 1
def update_weight(self):
for _, sub_layer in self.model.named_sublayers():
layer_name = sub_layer.full_name()
ln_name = None
if layer_name in self.linear_ln_dict:
ln_name = self.linear_ln_dict[layer_name]
if type(sub_layer) == ShiftSmoothHelpLayer:
ln_name = layer_name
if ln_name is not None:
act_abs_max = self.scale_dict[ln_name].cast("float32")
sampled_input = self.sampled_inputs[ln_name].cast("float32")
for param in sub_layer.parameters(include_sublayers=False):
if 'w_0' in param.name:
weight = param.cast("float32")
if self.search_function is not None:
s = self.search_function.search(
layer_name, sampled_input, act_abs_max, weight)
else:
w_abs_max = weight.abs().max(axis=-1, keepdim=True)
rw_abs_max = w_abs_max.reshape(act_abs_max.shape)
act_abs_max_np = act_abs_max.numpy()
weight_abs_max_np = rw_abs_max.numpy()
s = (
np.power(act_abs_max_np, self.alpha) / np.power(
weight_abs_max_np, 1 - self.alpha)).clip(
min=1e-5)
s = paddle.to_tensor(s, dtype="float32")
self.smooth_scale_dict[ln_name] = s.cast(param.dtype)
break
# update linear weight
for _, sub_layer in self.model.named_sublayers():
layer_name = sub_layer.full_name()
if layer_name in self.linear_ln_dict:
for param in sub_layer.parameters(include_sublayers=False):
if 'w_0' in param.name:
ln_name = self.linear_ln_dict[layer_name]
print("[smooth] before linear [{}] weight, abs_max: {}".
format(param.name,
float(param.cast("float32").abs().max())))
param_tmp = param * self.smooth_scale_dict[
ln_name].transpose(perm=[1, 0])
paddle.assign(param_tmp, output=param)
print("[smooth] after linear [{}] weight, abs_max: {}".
format(param.name,
float(param_tmp.abs().max().cast(
"float32"))))
# update LN weight
for cur_name, sub_layer in self.model.named_sublayers():
layer_name = sub_layer.full_name()
if layer_name in self.ln_linear_dict:
s = self.smooth_scale_dict[layer_name].squeeze()
for param in sub_layer.parameters(include_sublayers=False):
print("[smooth] before layer_norm {} weight, abs_max: {}".
format(param.name,
float(param.abs().max().cast("float32"))))
param_tmp = param / s
paddle.assign(param_tmp, output=param)
print("[smooth] after layer_norm {} weight, abs_max: {}".
format(param.name,
float(param_tmp.abs().max().cast("float32"))))
if not hasattr(sub_layer, "bias") or sub_layer.bias is None:
parent_layer, _ = find_parent_layer_and_sub_name(
self.model, cur_name)
if type(parent_layer) == WOBiasHelpLayer:
param = parent_layer.bias
print(
"[smooth WOBiasHelpLayer] before layer_norm {} bias, abs_max: {}".
format(param.name,
float(param.abs().max().cast("float32"))))
param_tmp = param / s
paddle.assign(param_tmp, output=param)
print(
"[smooth WOBiasHelpLayer] after layer_norm {} bias, abs_max: {}".
format(param.name,
float(param_tmp.abs().max().cast(
"float32"))))
for _, sub_layer in self.model.named_sublayers():
if type(sub_layer) == ShiftSmoothHelpLayer:
layer_name = sub_layer.full_name()
linear_name = sub_layer.layer.full_name()
smooth_scale = self.smooth_scale_dict[layer_name]
print(
"[smooth ShiftSmoothHelpLayer] param: {}, before weight, abs_max: {}".
format(linear_name,
float(sub_layer.weight.abs().max().cast("float32"))))
sub_layer.convert_weight(smooth_weight=smooth_scale)
print(
"[smooth ShiftSmoothHelpLayer] param: {}, after weight, abs_max: {}".
format(linear_name,
float(sub_layer.weight.abs().max().cast("float32"))))
self._remove_hook()
paddle.device.cuda.empty_cache()
def _remove_hook(self):
for hook in self._forward_hook_list:
hook.remove()
self._forward_hook_list = []
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
from sklearn.cluster import KMeans
def k_means(weight, n_clusters, init='k-means++', max_iter=300):
org_shape = weight.shape
weight = paddle.to_tensor(weight)
weight = paddle.reshape(weight, [-1, 1])
if n_clusters > weight.size:
n_clusters = weight.size
k_means = KMeans(
n_clusters=n_clusters,
init=init,
n_init=10,
algorithm='lloyd',
max_iter=max_iter)
k_means.fit(weight)
centroids = k_means.cluster_centers_
labels = k_means.labels_
labels = labels.reshape(org_shape)
return paddle.to_tensor(centroids.flatten()), paddle.to_tensor(labels)
def compute_scales(x, method='abs_max'):
if method == 'abs_max':
quant_scale = float(paddle.max(paddle.abs(x.flatten())))
quant_scale = 1e-8 if quant_scale == 0.0 else quant_scale
elif method == 'avg':
quant_scale = paddle.abs(x.reshape((x.shape[0], -1)))
quant_scale = paddle.mean(paddle.max(quant_scale, axis=(1)))
elif method == 'abs_max_channel_wise':
reduce_axis = tuple([i for i in range(len(x.shape)) if i != 1])
quant_scale = paddle.max(paddle.abs(x), axis=reduce_axis)
quant_scale = paddle.where(quant_scale == np.float32(0.0),
np.float32(1e-8), quant_scale)
return quant_scale
def find_parent_layer_and_sub_name(model, name):
last_idx = 0
idx = 0
parent_layer = model
while idx < len(name):
if name[idx] == '.':
sub_name = name[last_idx:idx]
if hasattr(parent_layer, sub_name):
parent_layer = getattr(parent_layer, sub_name)
last_idx = idx + 1
idx += 1
sub_name = name[last_idx:idx]
return parent_layer, sub_name
def get_ln_linear_info(ln_linear_list, norm_flag, linear_flag, fused_qkv,
llama_ffn, skip_norm_list):
# ln_linear_dict: {layer_norm_0: [linear_0, linear_1, linear_2]}
ln_linear_dict = {}
# linear_ln_dict: {linear_0: layer_norm_0, linear_1: layer_norm_0}
linear_ln_dict = {}
for i in range(len(ln_linear_list)):
layer_name = ln_linear_list[i]
if norm_flag in layer_name and layer_name not in skip_norm_list:
if i < len(ln_linear_list) - 1:
if not fused_qkv:
if linear_flag in ln_linear_list[i +
1] and linear_flag in ln_linear_list[i + 2] and linear_flag in ln_linear_list[i + 3] and int(
layer_name.split('_')
[-1]) % 2 == 0:
ln_linear_dict[layer_name] = [
ln_linear_list[i + 1], ln_linear_list[i + 2],
ln_linear_list[i + 3]
]
linear_ln_dict[ln_linear_list[i + 1]] = layer_name
linear_ln_dict[ln_linear_list[i + 2]] = layer_name
linear_ln_dict[ln_linear_list[i + 3]] = layer_name
if linear_flag in ln_linear_list[i + 1] and int(
layer_name.split('_')[-1]) % 2 != 0:
if llama_ffn:
ln_linear_dict[layer_name] = [
ln_linear_list[i + 1], ln_linear_list[i + 2]
]
linear_ln_dict[ln_linear_list[i + 1]] = layer_name
linear_ln_dict[ln_linear_list[i + 2]] = layer_name
else:
ln_linear_dict[layer_name] = [ln_linear_list[i + 1]]
linear_ln_dict[ln_linear_list[i + 1]] = layer_name
else:
if linear_flag in ln_linear_list[i + 1]:
ln_linear_dict[layer_name] = [ln_linear_list[i + 1]]
linear_ln_dict[ln_linear_list[i + 1]] = layer_name
return ln_linear_dict, linear_ln_dict
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle.nn.initializer import Constant
from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
__all__ = ['ShiftSmoothHelpLayer', 'WOBiasHelpLayer']
class ShiftSmoothHelpLayer(nn.Layer):
def __init__(self, layer):
super(ShiftSmoothHelpLayer, self).__init__()
self.weight = layer.weight
shift_shape = self.weight.shape[0]
if hasattr(layer, "bias") or layer.bias is None:
self.bias = paddle.create_parameter(
shape=[self.weight.shape[1]],
dtype=self.weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0.0),
is_bias=True, )
layer.bias = self.bias
self.layer = layer
self.layer_type = type(layer)
# add
self.shift_bias = self.create_parameter(
shape=[shift_shape],
attr=ParamAttr(initializer=Constant(value=0.)),
dtype=self.weight.dtype)
# multiply
self.smooth_weight = self.create_parameter(
shape=[shift_shape],
attr=ParamAttr(initializer=Constant(value=1.)),
dtype=self.weight.dtype)
def forward(self, input):
shift_input = input
shift_input = paddle.add(shift_input, self.shift_bias)
smooth_input = paddle.multiply(shift_input, self.smooth_weight)
return self.layer(smooth_input)
def convert_weight(self, shift_bias=None, smooth_weight=None):
# shift
if shift_bias is not None:
shift_bias = shift_bias.cast(self.weight.dtype)
self.shift_bias.set_value(-shift_bias)
shift_linear_bias = paddle.matmul(shift_bias, self.weight)
if self.layer_type == RowParallelLinear:
parallel_shift_linear_bias = paddle.distributed.collective._mp_allreduce(
shift_linear_bias,
group=self.layer.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
self.bias.set_value(self.bias + parallel_shift_linear_bias)
else:
self.bias.set_value(self.bias + shift_linear_bias)
# smooth
if smooth_weight is not None:
self.smooth_weight.set_value(
1 / smooth_weight.squeeze().cast(self.weight.dtype))
self.weight.set_value(
self.weight * smooth_weight.transpose(perm=[1, 0]))
class WOBiasHelpLayer(nn.Layer):
def __init__(self, layer):
super(WOBiasHelpLayer, self).__init__()
self.weight = layer.weight
self.bias = paddle.create_parameter(
shape=self.weight.shape,
dtype=self.weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0.0),
is_bias=True, )
self.layer = layer
def forward(self, input):
return self.layer(input) + self.bias
@@ -52,6 +52,7 @@ class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver):
         self.quant_bits = quant_bits
         self.calibration_loss = float('inf')
         self.qmin, self.qmax = self.qmin_qmax
+        self._layer = layer
         self._max = None
         self._scale = None
         self._zero_point = None
@@ -64,26 +65,11 @@ class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver):
     def _cal_abs_max(self, inputs):
         reduce_axis = tuple(
             [i for i in range(len(inputs.shape)) if i != self.quant_axis()])
-        abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis)
+        abs_max_values = paddle.max(
+            paddle.abs(inputs), axis=reduce_axis).cast("float32")
         abs_max_values = paddle.where(abs_max_values == np.float32(0.0),
                                       np.float32(1e-8), abs_max_values)
-        minimum_loss = paddle.full(abs_max_values.shape, float('inf'))
-        result = abs_max_values
-        factor = 0.3
-        while factor <= 1.0:
-            scales = factor * abs_max_values
-            factor += 0.02
-            expand_scales = paddle.unsqueeze(scales, axis=reduce_axis)
-            quant_var = paddle.clip(
-                paddle.round(inputs / expand_scales * self.qmax), self.qmin,
-                self.qmax)
-            quant_dequant_var = quant_var / self.qmax * expand_scales
-            mse_loss = ((inputs - quant_dequant_var)**2).mean(axis=reduce_axis)
-            result = paddle.where(mse_loss < minimum_loss, scales, result)
-            minimum_loss = paddle.minimum(mse_loss, minimum_loss)
-        return result
+        return abs_max_values

     def min_value(self) -> float:
         return 0.
@@ -54,4 +54,20 @@ class MSEChannelWiseWeightObserverLayer(AbsMaxChannelWiseWeightObserverLayer):
         abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis)
         abs_max_values = paddle.where(abs_max_values == np.float32(0.0),
                                       np.float32(1e-8), abs_max_values)
-        return abs_max_values
+        minimum_loss = paddle.full(abs_max_values.shape, float('inf'))
+        result = abs_max_values
+        factor = 0.3
+        while factor <= 1.0:
+            scales = factor * abs_max_values
+            factor += 0.02
+            expand_scales = paddle.unsqueeze(scales, axis=reduce_axis)
+            quant_var = paddle.clip(
+                paddle.round(inputs / expand_scales * self.qmax), self.qmin,
+                self.qmax)
+            quant_dequant_var = quant_var / self.qmax * expand_scales
+            mse_loss = ((inputs - quant_dequant_var)**2).mean(axis=reduce_axis)
+            result = paddle.where(mse_loss < minimum_loss, scales, result)
+            minimum_loss = paddle.minimum(mse_loss, minimum_loss)
+        return result