diff --git a/docs/zh_cn/quick_start/dygraph/dygraph_quant_post_tutorial.md b/docs/zh_cn/quick_start/dygraph/dygraph_quant_post_tutorial.md
deleted file mode 100644
index 77cd49aa170b944b1d6e65629136e516cd14f33b..0000000000000000000000000000000000000000
--- a/docs/zh_cn/quick_start/dygraph/dygraph_quant_post_tutorial.md
+++ /dev/null
@@ -1,97 +0,0 @@
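-# Post-Training Quantization
-
-Post-training quantization (also called offline quantization) needs only a small amount of calibration data to determine quantization parameters that minimize quantization error. It requires little data, but the quantized model is usually slightly less accurate than one produced by online quantization.
-
-This tutorial uses the image classification model MobileNetV1 as an example to show how to get started with the [PaddleSlim model quantization API]().
-
-The example consists of the following steps:
-
-1. Import dependencies
-2. Build the model and dataset
-3. Pre-train the model
-4. Post-training quantization
-5. Export the inference model
-
-The following sections describe each step in turn.
-
-## 1. Import dependencies
-
-Refer to the PaddleSlim installation guide to install compatible versions of Paddle and PaddleSlim, then import them as follows: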
-
-```python
-import paddle
-import paddleslim
-import paddle.vision.models as models
-from paddle.static import InputSpec as Input
-from paddle.vision.datasets import Cifar10
-import paddle.vision.transforms as T
-from paddleslim.dygraph.quant import QAT
-```
-
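-## 2. Build the network and dataset
-
-This section builds a classifier for the CIFAR10 dataset using `MobileNetV1`, with an input size of `[3, 32, 32]` and 10 output classes.
-For convenience, we use the predefined [mobilenetv1 classification model](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/vision/models/mobilenetv1/MobileNetV1_cn.html#mobilenetv1) provided by the Paddle high-level API.
-Call `model.prepare` to configure the components the model needs, such as the optimizer, loss function, and metrics. See the [documentation](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/hapi/model/Model_cn.html#prepare-optimizer-none-loss-function-none-metrics-none) for API details.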
-
-```python
-net = models.mobilenet_v1(pretrained=False, scale=1.0, num_classes=10)
-inputs = [Input([None, 3, 32, 32], 'float32', name='image')]
-labels = [Input([None, 1], 'int64', name='label')]
-optimizer = paddle.optimizer.Momentum(
- learning_rate=0.1,
- parameters=net.parameters())
-model = paddle.Model(net, inputs, labels)
-model.prepare(
- optimizer,
- paddle.nn.CrossEntropyLoss(),
- paddle.metric.Accuracy(topk=(1, 5)))
-transform = T.Compose([T.Transpose(), T.Normalize([127.5], [127.5])])
-train_dataset = Cifar10(mode='train', backend='cv2', transform=transform)
-val_dataset = Cifar10(mode='test', backend='cv2', transform=transform)
-```
-
-## 3. Pre-train the model
-
-Pre-train the model to prepare it for quantization.
-Run the following code to pre-train the model:
-```python
-model.fit(train_dataset, epochs=5, batch_size=256, verbose=1)
-model.evaluate(val_dataset, batch_size=256, verbose=1)
-```
-
-After training, export the inference model:
-```python
-paddle.jit.save(net, "./fp32_inference_model", input_spec=inputs)
-```
-
-
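-## 4. Post-training quantization
-
-Call the PaddleSlim API to convert the original model into a post-training quantized model. The exported model can be used directly for inference deployment: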
-
-```python
-paddle.enable_static()
-place = paddle.CPUPlace()
-exe = paddle.static.Executor(place)
-paddleslim.quant.quant_post_static(
- executor=exe,
- model_dir='./',
- model_filename='fp32_inference_model.pdmodel',
- params_filename='fp32_inference_model.pdiparams',
- quantize_model_path='./quant_post_static_model',
- sample_generator=paddle.dataset.cifar.test10(),
- batch_nums=10)
-```
-
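-Note that post-training quantization does not yet support models that contain control-flow OPs.
-
-Depending on the deployment scenario, you can use PaddleLite to deploy the quantized model on mobile devices (ARM CPU), or PaddleInference to deploy it on servers (NV GPU and Intel CPU).
-
-The exported quantized model is about the same size as the original FP32 model, because the weights in the quantized inference model are still stored as FP32. The model size only shrinks after the quantized inference model is converted with the PaddleLite opt tool at deployment time.
-
-Deployment references:
-* Deployment [documentation](https://paddleslim.readthedocs.io/zh_CN/latest/deploy/index.html)
-* PaddleLite quantized model deployment [documentation](https://paddle-lite.readthedocs.io/zh/latest/user_guides/quant_aware.html)
-* PaddleInference Intel CPU quantized model deployment [documentation](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_x86_cpu_int8.html)
-* PaddleInference NV GPU quantized model deployment [documentation](https://paddle-inference.readthedocs.io/en/latest/optimize/paddle_trt.html)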
diff --git a/docs/zh_cn/tutorials/quant/advanced_quantization.md b/docs/zh_cn/tutorials/quant/advanced_quantization.md
new file mode 100644
index 0000000000000000000000000000000000000000..b5cac2fca18fcc59728ffb3b7d022e9d370fb34c
--- /dev/null
+++ b/docs/zh_cn/tutorials/quant/advanced_quantization.md
@@ -0,0 +1,158 @@
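+# Detailed Tutorial on Quantization Strategies
+In recent years, Transformer models have been widely adopted across many fields, and generative large language models in particular have greatly advanced artificial intelligence. These models have grown from hundreds of millions of parameters to hundreds of billions, so running them with limited data and GPU resources is increasingly challenging. Compression techniques have therefore become especially important, and quantization has become a general and primary approach to reducing memory footprint and compute overhead. However, many studies show that Transformer models tend to produce strong activation outliers, which makes them hard to quantize. To retain acceptable performance, these outliers require higher activation bit widths, different numeric formats, extra fine-tuning, or other workarounds. This document introduces several state-of-the-art strategies for improving quantization quality, including open-source work and methods developed by PaddleSlim. The methods below currently support only Transformer models. For concrete usage examples, see the [PaddleNLP LLM examples](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/llm/causallm); this tutorial only describes the API in detail.
+
+## 1. Shift
+
+The Shift algorithm comes from [Outlier Suppression+](https://arxiv.org/abs/2304.09145). It applies a mathematically equivalent transformation of outliers through the bias of the Layer Norm and Linear layers, which makes the activation distribution more symmetric and helps improve post-training quantization accuracy. In the PaddleSlim implementation, a Linear layer preceded by a Layer Norm is handled by modifying the bias in a mathematically equivalent way, while a Linear layer not preceded by a Layer Norm can be handled by inserting a node. The `shift_all_linears` parameter controls whether Linear layers without a preceding Layer Norm are shifted. In addition, PaddleSlim's Shift accepts a `sample_function`; if `sample_function` is None, the Shift algorithm matches the [Outlier Suppression+](https://arxiv.org/abs/2304.09145) paper exactly.
+
+| **Parameter** | **Type** | **Description** |
+|-----------------------------|-----------------------------------------|-----------------------------------------|
+| model | paddle.nn.Layer | Required. The dynamic-graph model |
+| model_config | dict | Required. Configuration of the model structure |
+| shift_all_linears | bool | Optional, default False. If True, shift all Linear layers in the model; if False, shift only the Linear layers that follow a Layer Norm |
+| sample_function | function | Optional, default None. If None, sampling follows the method used in the paper. The available choices are MultiStepSampler and EMASampler; EMASampler is recommended for Shift |
+
+A simple usage example:
+```python
+from paddleslim.quant.advanced import Shift, EMASampler
+
+model = LLM()
+model_config = {}
+shift = Shift(model, model_config, sample_function=EMASampler())
+for data in dataloader():
+    model(data)
+    shift.step += 1
+shift.update_weight()
+```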
+
+
+
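+## 2. Smooth
+The Smooth algorithm comes from [SmoothQuant](https://arxiv.org/abs/2211.10438). It applies a mathematically equivalent scaling of outliers through the weight and bias of the Layer Norm and Linear layers, which reduces outliers in the activations and helps improve post-training quantization accuracy. In the PaddleSlim implementation, as with Shift, a Linear layer preceded by a Layer Norm is scaled by modifying the weight and bias in a mathematically equivalent way, while a Linear layer not preceded by a Layer Norm can be handled by inserting a node. The `smooth_all_linears` parameter controls whether Linear layers without a preceding Layer Norm are smoothed. In addition, PaddleSlim's Smooth provides a search feature, which is documented in the PieceWiseSearch section.
+
+| **Parameter** | **Type** | **Description** |
+|-----------------------------|-----------------------------------------|-----------------------------------------|
+| model | paddle.nn.Layer | Required. The dynamic-graph model |
+| model_config | dict | Required. Configuration of the model structure |
+| alpha | float | Optional, default 0.5 |
+| smooth_all_linears | bool | Optional, default False. If True, smooth all Linear layers in the model; if False, smooth only the Linear layers that follow a Layer Norm |
+| sample_function | function | Optional, default None. If None, a single batch of data is sampled. The available choices are MultiStepSampler and EMASampler; MultiStepSampler is recommended for Smooth |
+| search_function | function | Optional, default None. If None, no search is performed and smoothing matches the original paper |
+
+A simple usage example:
+```python
+from paddleslim.quant.advanced import Smooth, MultiStepSampler
+
+model = LLM()
+model_config = {}
+smooth = Smooth(model, model_config, sample_function=MultiStepSampler())
+for data in dataloader():
+    model(data)
+    smooth.step += 1
+smooth.update_weight()
+```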
+
+
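+Note: in theory the model is mathematically equivalent before and after shift and smooth, but in practice the output values may differ slightly, so accuracy before and after shift/smooth should be approximately the same. A large accuracy gap indicates that the model structure was not parsed correctly. Fill in `model_config` according to the actual model, because it determines how the model structure is parsed; a parsing error leads to incorrect accuracy. It contains the following fields:
+- `fused_qkv`: whether the model fuses QKV. Default: True
+- `linear_flag`: the name substring used to identify Linear layers in the model. Default: `linear`
+- `norm_flag`: the name substring used to identify Layer Norm layers in the model. Default: `norm`
+- `parallel_ffn`: whether the model contains a parallel FFN. Default: False
+- `skip_norm_list`: name substrings of the Layer Norm layers to be skipped. Default: empty list
+
+If the model contains a PostLayerNorm Shortcut structure, smooth and shift are not supported for it. For example, the [ChatGLM architecture](https://github.com/PaddlePaddle/PaddleNLP/blob/64f97979a62fba6a35a1177850cc22dbc91fade0/paddlenlp/transformers/chatglm/modeling.py#L360) in PaddleNLP contains a PostLayerNorm Shortcut, so shift/smooth is not supported for that model.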
+
+
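+## 3. PieceWiseSearch
+[SmoothQuant](https://arxiv.org/abs/2211.10438) does reduce outliers effectively, but in extensive experiments we found that in some cases, for example when weight values are large, and especially when weight and activation outliers fall in the same channel, computing the smooth scale directly from the SmoothQuant formula makes the weights hard to quantize. Moreover, when an activation has many outliers and a wide value range, smoothing the whole activation tensor with a single alpha is not reasonable either. PaddleSlim therefore proposes a piecewise search: activations are split into K pieces by magnitude, and alpha and scale are searched for each piece.
+
+| **Parameter** | **Type** | **Description** |
+|-----------------------------|-----------------------------------------|-----------------------------------------|
+| k_piece | int | Optional. Number of pieces, default 1 (1 means no splitting) |
+| bits_length | int | Optional. Number of quantization bits, default 8 |
+| search_piece | bool | Optional. Whether to search the number of pieces k, default False. If True, a suitable k is searched from 1 to k_piece |
+| search_alpha_min | float | Optional. Minimum alpha to search, default 0.2 |
+| search_alpha_max | float | Optional. Maximum alpha to search, default 0.8 |
+| search_scale_min | float | Optional. Minimum scale to search, default 1. |
+| search_scale_max | float | Optional. Maximum scale to search, default 5. |
+| weight_quant_method | str | Optional. Weight quantization method, one of `abs_max`, `abs_max_channel_wise`, `avg`; default `abs_max_channel_wise` |
+| act_quant_method | str | Optional. Activation quantization method, one of `abs_max`, `avg`; default `abs_max` |
+| loss_function | function | Optional. Error function used during the search, default mse_loss |
+
+
+```python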
+from paddleslim.quant.advanced import Smooth, MultiStepSampler, PieceWiseSearch
+from paddleslim.quant.advanced.metrics import mse_loss  # mse_loss is defined in metrics.py
+
+search_func = PieceWiseSearch(
+    k_piece=3,
+    bits_length=8,
+    search_piece=False,
+    search_alpha_min=0.2,
+    search_alpha_max=0.8,
+    search_scale_min=1.,
+    search_scale_max=5.,
+    weight_quant_method='abs_max_channel_wise',
+    act_quant_method='abs_max',
+    loss_function=mse_loss,
+)
+model = LLM()
+model_config = {}
+smooth = Smooth(model, model_config, sample_function=MultiStepSampler(), search_function=search_func)
+for data in dataloader():
+    model(data)
+    smooth.step += 1
+smooth.update_weight()
+```
+
+## 4. GPTQ
+The GPTQ algorithm comes from [GPTQ](https://arxiv.org/abs/2210.17323). It quantizes the weights row by row and uses the Hessian matrix to keep updating the weights that have not been quantized yet, and it performs well for low-bit weight-only INT4 quantization. By default GPTQ is used together with [RPTQ](https://arxiv.org/abs/2304.01089); to disable RPTQ, pass `actorder=False` when calling `fasterquant`.
+
+| **Parameter** | **Type** | **Description** |
+|-----------------------------|-----------------------------------------|-----------------------------------------|
+| layer | paddle.nn.Layer | Required. The layer to quantize; currently only nn.Linear, ColumnParallelLinear and RowParallelLinear are supported |
+| quant_bits | int | Optional. Number of quantization bits, default 4 |
+| weight_quant_method | str | Optional. Weight quantization method, one of `abs_max`, `abs_max_channel_wise`, `avg`; default `abs_max_channel_wise` |
+
+```python
+import paddle
+from paddleslim.quant.advanced import GPTQ
+
+model = LLM()
+for cur_name, cur_layer in model.named_sublayers():
+    if type(cur_layer) == paddle.nn.Linear:
+        gptq_layer = GPTQ(cur_layer)  # install it in place of cur_layer so that forward() sees the data
+        # sample data
+        for data in dataloader():
+            model(data)
+        # quant weight
+        gptq_layer.fasterquant(actorder=True)
+```
+
+
+## 5. LayerWiseQuantError
+LayerWiseQuantError analyzes quantization loss layer by layer: for each layer in the model, it quantizes the layer and computes the error between the quantized output of that layer and the output of the original model.
+| **Parameter** | **Type** | **Description** |
+|-----------------------------|-----------------------------------------|-----------------------------------------|
+| layer | paddle.nn.Layer | Required. The layer to quantize; currently only nn.Linear, ColumnParallelLinear and RowParallelLinear are supported |
+| weight_bits | int | Optional. Number of weight quantization bits, default 8 |
+| act_bits | int | Optional. Number of activation quantization bits, default 8 |
+| weight_quant_method | str | Optional. Weight quantization method, one of `abs_max`, `abs_max_channel_wise`, `avg`; default `abs_max_channel_wise` |
+| act_quant_method | str | Optional. Activation quantization method, one of `abs_max`, `avg`; default `abs_max` |
+| loss_function | function | Optional. Error function to use, default mse_loss |
+
+```python
+import paddle
+from paddleslim.quant.advanced import LayerWiseQuantError
+
+model = LLM()
+for cur_name, cur_layer in model.named_sublayers():
+    if type(cur_layer) == paddle.nn.Linear:
+        error_layer = LayerWiseQuantError(cur_layer)  # install it in place of cur_layer in the model
+for data in dataloader():
+    model(data)
+
+for cur_name, cur_layer in model.named_sublayers():
+    if type(cur_layer) == LayerWiseQuantError:
+        print(cur_name, paddle.stack(cur_layer.losses).mean())  # losses is a Python list of tensors
+```
diff --git a/docs/zh_cn/tutorials/quant/dygraph/dygraph_quant_post_tutorial.md b/docs/zh_cn/tutorials/quant/dygraph/dygraph_quant_post_tutorial.md
deleted file mode 120000
index ddd24f04c0b0db18602d8b56e20d7521547592f4..0000000000000000000000000000000000000000
--- a/docs/zh_cn/tutorials/quant/dygraph/dygraph_quant_post_tutorial.md
+++ /dev/null
@@ -1 +0,0 @@
-../../../quick_start/dygraph/dygraph_quant_post_tutorial.md
\ No newline at end of file
diff --git a/docs/zh_cn/tutorials/quant/dygraph/index.rst b/docs/zh_cn/tutorials/quant/dygraph/index.rst
deleted file mode 100644
index e45cca35cc33eaee2d81cdc11129dd39ca166270..0000000000000000000000000000000000000000
--- a/docs/zh_cn/tutorials/quant/dygraph/index.rst
+++ /dev/null
@@ -1,8 +0,0 @@
-
-Dynamic Graph
-==============
-
-.. toctree::
- :maxdepth: 1
-
- quant_aware_training_tutorial.md
diff --git a/docs/zh_cn/tutorials/quant/post_training_quantization.md b/docs/zh_cn/tutorials/quant/post_training_quantization.md
new file mode 100644
index 0000000000000000000000000000000000000000..09391de13f45384b4ee97930898b85482ebaaaf1
--- /dev/null
+++ b/docs/zh_cn/tutorials/quant/post_training_quantization.md
@@ -0,0 +1,93 @@
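+# Post-Training Quantization
+
+Post-training quantization (also called offline quantization) needs only a small amount of calibration data to determine quantization parameters that minimize quantization error. It requires little data, but the quantized model is usually slightly less accurate than one produced by online quantization.
+
+
+## Usage
+
+The basic workflow of post-training quantization has three steps:
+
+1. Choose a quantization configuration
+2. Sample data to collect quantization statistics
+3. Convert to a quantized model
+
+## API overview
+
+### 1. Quantization configuration concepts and API:
+
+`Observer`: collects statistics on the inputs or outputs of an OP and computes quantization-related quantities such as scale and zero_point. Each post-training quantization algorithm corresponds to one Observer. The available Observers are:
+- `AVGObserver`: collects the average value of the target Tensor as the quantization scale
+- `MSEObserver`: collects the maximum absolute value and determines the quantization scale by minimizing the MSE error
+- `EMDObserver`: collects the maximum absolute value and determines the quantization scale by minimizing the EMD error
+- `HistObserver`: collects tensor values into a histogram and computes the quantization scale from a percentile
+- `KLObserver`: computes the quantization scale by minimizing the Kullback-Leibler divergence between the floating-point distribution and the quantized distribution
+- `AbsMaxChannelWiseWeightObserver`: collects the maximum absolute value along the channel dimension of the target weight as the quantization scale
+- `MSEChannelWiseWeightObserver`: collects the maximum absolute value along the channel dimension of the target weight and determines the quantization scale by minimizing the MSE error
+
+`Quanter`: performs quantization or simulated quantization on the inputs or outputs of an OP, and can also collect statistics on the values of the input Tensor. Each quantization-aware training algorithm corresponds to one Quanter. The available Quanters are:
+- `PACTQuanter`
+- `WeightLSQplusQuanter`
+- `ActLSQplusQuanter`
+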
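+`QuantConfig`: before quantizing, configure the quantization settings, mainly which Observer or Quanter each input of each layer uses. The following methods add per-layer quantization settings as needed:
+| **QuantConfig method** | **Arguments and their meaning** | **Notes** |
+|-----------------------------|-----------------------------------------|-----------------------------------------|
+| add_layer_config | `layer`: a layer of the model, or a list of layers<br>`activation`: the `Observer` or `Quanter` used to quantize the activations of the specified layers<br>`weight`: the `Observer` or `Quanter` used to quantize the weights of the specified layers | Highest priority: these layers are quantized as specified here rather than by any other configuration |
+| add_name_config | `layer_name`: the name of a layer of the model, or a list of layer names<br>`activation`: the `Observer` or `Quanter` used to quantize the activations of the specified layers<br>`weight`: the `Observer` or `Quanter` used to quantize the weights of the specified layers | Priority is second only to add_layer_config |
+| add_type_config | `layer_type`: the layer type to quantize, either a single layer type or a list of layer types; each must be a subclass of paddle.nn.Layer<br>`activation`: the `Observer` or `Quanter` used to quantize the activations of the specified layers<br>`weight`: the `Observer` or `Quanter` used to quantize the weights of the specified layers | Priority is just below add_name_config. Specifies the layer types to quantize; for example, with nn.Linear, every nn.Linear is quantized with the given weight and activation quanters |
+| add_qat_layer_mapping | `source`: the layer to be quantized<br>`target`: the quantized layer | source and target must be subclasses of paddle.nn.Layer. When a layer type to be quantized has no quantized implementation in the framework, its quantized layer must be specified, for example ColumnParallelLinear maps to QuantizedColumnParallelLinear implemented in PaddleSlim |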
+
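+### 2. PTQ API:
+| **PTQ method** | **Arguments and their meaning** | **Description** |
+|-----------------------------|-----------------------------------------|-----------------------------------------|
+| quantize | `model`: the model to quantize<br>`inplace`: if True, the model is quantized in place; if False, the original model is left unchanged and a quantized model is returned | Inserts observers into the layers to be quantized so that the required quantization statistics can be sampled |
+| convert | `model`: the quantized model to convert<br>`inplace`: if True, the model is converted in place; if False, the original model is left unchanged and a converted model is returned | Converts the model into ONNX-style form; only after this step can the quantized model be evaluated or exported as a static graph |
+
+
+## Usage example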
+```python
+import paddle
+import paddleslim
+from paddle.vision.models import mobilenet_v1
+from paddle.quantization import QuantConfig, PTQ
+from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
+from paddleslim.quant.observers import HistObserver, KLObserver, EMDObserver, MSEObserver, AVGObserver, MSEChannelWiseWeightObserver, AbsMaxChannelWiseWeightObserver
+
+# create the model
+model = mobilenet_v1()
+
+# define QuantConfig
+q_config = QuantConfig(activation=None, weight=None)
+
+# define act_quanter and weight_quanter
+act_quanter = MSEObserver()
+weight_quanter = MSEObserver()
+
+# map ColumnParallelLinear to QuantizedColumnParallelLinear (the Quantized* layers are implemented in PaddleSlim; import them first)
+q_config.add_qat_layer_mapping(ColumnParallelLinear,
+                               QuantizedColumnParallelLinear)
+# map RowParallelLinear to QuantizedRowParallelLinear
+q_config.add_qat_layer_mapping(RowParallelLinear,
+                               QuantizedRowParallelLinear)
+# for each layer if type in [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear]
+# make them quantizable
+q_config.add_type_config(
+    [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear],
+    activation=act_quanter,
+    weight=weight_quanter,
+)
+
+
+ptq = PTQ(q_config)
+model = ptq.quantize(model, inplace=True)
+
+# ptq sample (dataloader is a user-provided iterable of calibration batches)
+ptq_step = 100
+for step, data in enumerate(dataloader):
+    pred = model(data)
+    if step == ptq_step:
+        break
+
+# convert to quant model that can evaluate and export
+model = ptq.convert(model, inplace=True)
+```
diff --git a/docs/zh_cn/tutorials/quant/dygraph/quant_aware_training_tutorial.md b/docs/zh_cn/tutorials/quant/quant_aware_training.md
similarity index 100%
rename from docs/zh_cn/tutorials/quant/dygraph/quant_aware_training_tutorial.md
rename to docs/zh_cn/tutorials/quant/quant_aware_training.md
diff --git a/docs/zh_cn/tutorials/quant/Analysis.md b/docs/zh_cn/tutorials/quant/static/Analysis.md
similarity index 100%
rename from docs/zh_cn/tutorials/quant/Analysis.md
rename to docs/zh_cn/tutorials/quant/static/Analysis.md
diff --git a/paddleslim/quant/advanced/__init__.py b/paddleslim/quant/advanced/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f0744ecf40c1972ccb1b5c2e2e4daff98b5e8a8
--- /dev/null
+++ b/paddleslim/quant/advanced/__init__.py
@@ -0,0 +1,38 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from . import gptq
+from . import smooth
+from . import shift
+from . import piecewise_search
+from . import sample
+from . import layerwise_quant_error
+from . import utils_layers
+
+from .gptq import *
+from .smooth import *
+from .shift import *
+from .piecewise_search import *
+from .sample import *
+from .layerwise_quant_error import *
+from .utils_layers import *
+
+__all__ = []
+__all__ += gptq.__all__
+__all__ += smooth.__all__
+__all__ += shift.__all__
+__all__ += piecewise_search.__all__
+__all__ += sample.__all__
+__all__ += layerwise_quant_error.__all__
+__all__ += utils_layers.__all__
\ No newline at end of file
diff --git a/paddleslim/quant/advanced/gptq.py b/paddleslim/quant/advanced/gptq.py
new file mode 100644
index 0000000000000000000000000000000000000000..96566858febc63bbef2168fb436ce4ebaea0ee00
--- /dev/null
+++ b/paddleslim/quant/advanced/gptq.py
@@ -0,0 +1,185 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import time
+import numpy as np
+import paddle
+import paddle.nn as nn
+from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
+
+from .utils import compute_scales
+__all__ = ['GPTQ']
+
+
+class GPTQ(nn.Layer):
+    def __init__(self,
+                 layer,
+                 quant_bits=4,
+                 weight_quant_method='abs_max_channel_wise'):
+        '''
+        The implementation of GPTQ(https://arxiv.org/abs/2210.17323).
+        The codes here are based on https://github.com/IST-DASLab/gptq.
+
+        Args:
+            layer (paddle.nn.Layer): Layer object.
+            quant_bits (int, optional): Number of bits to quantize the weight. Default: 4.
+            weight_quant_method (str, optional): Method of weight quantization. Chosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise.
+
+        Examples:
+        .. code-block:: python
+
+            from paddleslim.quant.advanced import GPTQ
+            for cur_name, cur_layer in model.named_sublayers():
+                if type(cur_layer) == paddle.nn.Linear:
+                    gptq_layer = GPTQ(cur_layer)
+                    # sample data
+                    for data in dataloader():
+                        model(data)
+                    # quant weight
+                    gptq_layer.fasterquant()
+        '''
+        super(GPTQ, self).__init__()
+        self.layer = layer
+        assert hasattr(layer,
+                       'weight'), "Layer {} has no attribute 'weight'".format(
+                           layer.full_name())
+        assert type(self.layer) in [
+            nn.Linear, ColumnParallelLinear, RowParallelLinear
+        ], "Currently, GPTQ only supports linear layer and ColumnParallelLinear/RowParallelLinear layer"
+
+        weight = layer.weight.t()
+
+        self.rows = weight.shape[0]
+        self.columns = weight.shape[1]
+        self.hessian = paddle.zeros(
+            (self.columns, self.columns), dtype='float32')
+        self.nsamples = 0
+        self.quantized = False
+        self.weight_quant_method = weight_quant_method
+        self.quant_bits = (1 << (quant_bits - 1)) - 1
+
+    def forward(self, input):
+        if not self.quantized:
+            inp = input[0] if type(input) == tuple else input
+            inp = inp.cast('float32')
+            if len(inp.shape) == 2:
+                inp = inp.unsqueeze(0)
+            tmp = inp.shape[0]
+            if type(self.layer) in [
+                    nn.Linear, ColumnParallelLinear, RowParallelLinear
+            ]:
+                if len(inp.shape) == 3:
+                    inp = inp.reshape((-1, inp.shape[-1]))
+                inp = inp.t()
+
+            self.hessian *= self.nsamples / (self.nsamples + tmp)
+            self.nsamples += tmp
+
+            inp = math.sqrt(2 / self.nsamples) * inp
+            self.hessian += paddle.matmul(inp, inp.t())
+            del inp
+        return self.layer(input)
+
+    def fasterquant(self,
+                    blocksize=128,
+                    percdamp=.01,
+                    groupsize=-1,
+                    actorder=True):
+        print('quant', self.layer.full_name())
+        W = self.layer.weight.t().cast('float32')
+        weight_scale = compute_scales(W.t(), method=self.weight_quant_method)
+        weight_scale /= self.quant_bits
+        tick = time.time()
+
+        H = self.hessian
+        del self.hessian
+        dead = paddle.where(paddle.diag(H) == 0)
+        H[dead, dead] = 1
+        W[:, dead] = 0
+        del dead
+        if actorder:
+            perm = paddle.argsort(paddle.diag(H), descending=True)
+            W = W.transpose((1, 0))
+            W = W[perm].transpose((1, 0))
+            H = H[perm].transpose((1, 0))
+            H = H[perm].transpose((1, 0))
+
+        Losses = paddle.zeros_like(W)
+        Q = paddle.zeros_like(W)
+
+        damp = percdamp * paddle.mean(paddle.diag(H))
+        diag = paddle.arange(self.columns)
+        H[diag, diag] += damp
+
+        H = paddle.inverse(H)
+        H = paddle.linalg.cholesky(H, upper=True)
+        Hinv = H
+
+        for i1 in range(0, self.columns, blocksize):
+            i2 = min(i1 + blocksize, self.columns)
+            count = i2 - i1
+            W1 = W[:, i1:i2]
+            Q1 = paddle.zeros_like(W1)
+            Err1 = paddle.zeros_like(W1)
+            Losses1 = paddle.zeros_like(W1)
+            Hinv1 = Hinv[i1:i2, i1:i2]
+
+            for i in range(count):
+                w = W1[:, i]
+                d = Hinv1[i, i]
+                if groupsize != -1:
+                    if (i1 + i) % groupsize == 0:
+                        weight_scale = compute_scales(
+                            W[:, (i1 + i):(i1 + i + groupsize)].t(),
+                            method=self.weight_quant_method)
+                        weight_scale /= self.quant_bits
+
+                q = paddle.clip(
+                    paddle.round(w / weight_scale), -self.quant_bits - 1,
+                    self.quant_bits)
+                q = q * weight_scale
+                Q1[:, i] = q
+                Losses1[:, i] = (w - q)**2 / d**2
+
+                err1 = (w - q) / d
+                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                Err1[:, i] = err1
+                del w, d, q, err1
+                paddle.device.cuda.empty_cache()
+
+            Q[:, i1:i2] = Q1
+            Losses[:, i1:i2] = Losses1 / 2
+            del Q1, Losses1
+            if Hinv[i1:i2, i2:].shape[1] != 0:
+                W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+            del Err1, W1, Hinv1
+            paddle.device.cuda.empty_cache()
+
+        print('time %.2f' % (time.time() - tick))
+        print('error', paddle.sum(Losses).item())
+
+        if actorder:
+            invperm = paddle.argsort(perm)
+            Q = Q.transpose((1, 0))
+            Q = Q[invperm].transpose((1, 0))
+            del invperm, perm
+
+        param = self.layer.weight
+        Q = Q.t().cast(self.layer.weight.dtype)
+        paddle.assign(Q, output=param)
+
+        self.quantized = True
+        del H, Q, Hinv, W, Losses
+        paddle.device.cuda.empty_cache()
diff --git a/paddleslim/quant/advanced/layerwise_quant_error.py b/paddleslim/quant/advanced/layerwise_quant_error.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce03d198cd09a943cedde26b0338ff50f65b8e12
--- /dev/null
+++ b/paddleslim/quant/advanced/layerwise_quant_error.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+import numpy as np
+from .utils import compute_scales
+from .metrics import mse_loss
+
+__all__ = ['LayerWiseQuantError']
+
+
+class LayerWiseQuantError(nn.Layer):
+    def __init__(self,
+                 layer,
+                 weight_bits=8,
+                 act_bits=8,
+                 weight_quant_method='abs_max_channel_wise',
+                 act_quant_method='abs_max',
+                 loss_function=mse_loss):
+        '''
+        LayerWiseQuantError computes the loss between the output of the layer and the output of the quantized layer.
+
+        Args:
+            layer (paddle.nn.Layer): Layer object.
+            weight_bits (int, optional): Number of bits to quantize the weight. Default: 8.
+            act_bits (int, optional): Number of bits to quantize the activation. Default: 8.
+            weight_quant_method (str, optional): The method of weight quantization. Chosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise.
+            act_quant_method (str, optional): The method of activation quantization. Chosen from abs_max, avg. Default: abs_max.
+
+        Examples:
+        .. code-block:: python
+
+            from paddleslim.quant.advanced import LayerWiseQuantError
+            for cur_name, cur_layer in model.named_sublayers():
+                if type(cur_layer) == paddle.nn.Linear:
+                    error_layer = LayerWiseQuantError(cur_layer)
+            for data in dataloader():
+                model(data)
+
+            for cur_name, cur_layer in model.named_sublayers():
+                if type(cur_layer) == LayerWiseQuantError:
+                    print(cur_name, paddle.stack(cur_layer.losses).mean())
+        '''
+        # nn.Layer must be initialized before sublayers are assigned as attributes.
+        super(LayerWiseQuantError, self).__init__()
+        self.layer = layer
+        self.weight = layer.weight
+        self.weight_bits = weight_bits
+        self.act_bits = act_bits
+        self.weight_method = weight_quant_method
+        self.act_method = act_quant_method
+        self.loss_function = loss_function
+        self.losses = []
+
+    def forward(self, input):
+        act = input[0] if type(input) == tuple else input
+        origin_out = paddle.matmul(act, self.weight)
+        bnt = (1 << (self.weight_bits - 1)) - 1
+        quant_scale = compute_scales(
+            self.weight.cast('float32'),
+            method=self.weight_method).cast(self.weight.dtype)
+        quant_weight = paddle.clip(
+            paddle.round(self.weight / quant_scale * bnt), -bnt - 1, bnt)
+        quant_dequant_weight = quant_weight / bnt * quant_scale
+
+        bnt = (1 << (self.act_bits - 1)) - 1
+        quant_scale = compute_scales(act, method=self.act_method)
+        quant_act = paddle.clip(
+            paddle.round(act / quant_scale * bnt), -bnt - 1, bnt)
+        quant_dequant_act = quant_act / bnt * quant_scale
+        quant_out = paddle.matmul(quant_dequant_act, quant_dequant_weight)
+        loss = self.loss_function(origin_out, quant_out)
+        self.losses.append(loss)
+        return self.layer(input)
diff --git a/paddleslim/quant/advanced/metrics.py b/paddleslim/quant/advanced/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..b00f710c7b3789e789f97b7f7304f06d8b21ceb0
--- /dev/null
+++ b/paddleslim/quant/advanced/metrics.py
@@ -0,0 +1,57 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+
+def mse_loss(y_pred, y_real, reduction='mean'):
+    if y_pred.shape != y_real.shape:
+        raise ValueError(
+            f'Can not compute mse loss for tensors with different shape. '
+            f'({y_pred.shape} and {y_real.shape})')
+
+    mse = (y_pred - y_real)**2
+
+    if reduction == 'mean':
+        return paddle.mean(mse)
+    elif reduction == 'sum':
+        return paddle.sum(mse)
+    elif reduction == 'none':
+        return mse
+    else:
+        raise ValueError(f'Unsupported reduction method.')
+
+
+def snr_loss(y_pred, y_real, reduction='mean'):
+    if y_pred.shape != y_real.shape:
+        raise ValueError(
+            f'Can not compute snr loss for tensors with different shape. '
+            f'({y_pred.shape} and {y_real.shape})')
+    reduction = str(reduction).lower()
+
+    y_pred = y_pred.flatten(start_axis=1)
+    y_real = y_real.flatten(start_axis=1)
+
+    noise_power = paddle.pow(y_pred - y_real, 2).sum(axis=-1)
+    signal_power = paddle.pow(y_real, 2).sum(axis=-1)
+    snr = (noise_power) / (signal_power + 1e-7)
+
+    if reduction == 'mean':
+        return paddle.mean(snr)
+    elif reduction == 'sum':
+        return paddle.sum(snr)
+    elif reduction == 'none':
+        return snr
+    else:
+        raise ValueError(f'Unsupported reduction method.')
diff --git a/paddleslim/quant/advanced/piecewise_search.py b/paddleslim/quant/advanced/piecewise_search.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8e4a1f2d734caf46f9a727dd2ec307099f35ce6
--- /dev/null
+++ b/paddleslim/quant/advanced/piecewise_search.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+from .utils import compute_scales, k_means
+from .metrics import mse_loss
+
+__all__ = ['PieceWiseSearch']
+
+
+class PieceWiseSearch():
+    def __init__(self,
+                 k_piece=1,
+                 bits_length=8,
+                 search_piece=False,
+                 search_alpha_min=0.2,
+                 search_alpha_max=0.8,
+                 search_scale_min=1.,
+                 search_scale_max=5.,
+                 weight_quant_method='abs_max_channel_wise',
+                 act_quant_method='abs_max',
+                 loss_function=mse_loss):
+        '''
+        PieceWiseSearch searches for the best k_piece, alpha and scale.
+
+        Args:
+            k_piece (int): Number of pieces. Default: 1.
+            bits_length (int): Number of bits to quantize the weight. Default: 8.
+            search_piece (bool): Whether to search the best k piece. Default: False.
+            search_alpha_min (float): Minimum alpha for search. Default: 0.2.
+            search_alpha_max (float): Maximum alpha for search. Default: 0.8.
+            search_scale_min (float): Minimum scale for search. Default: 1.
+            search_scale_max (float): Maximum scale for search. Default: 5.
+            weight_quant_method (str): Weight quantization method. Chosen from abs_max, abs_max_channel_wise and avg. Default: abs_max_channel_wise.
+            act_quant_method (str): Activation quantization method. Chosen from abs_max, avg. Default: abs_max.
+            loss_function (callable): Loss function. Default: mse_loss.
+        '''
+        self.k_piece = k_piece
+        self.bits_length = bits_length
+        self.search_piece = search_piece
+        self.search_alpha_min = search_alpha_min
+        self.search_alpha_max = search_alpha_max
+        self.search_scale_min = search_scale_min
+        self.search_scale_max = search_scale_max
+        self.weight_quant_method = weight_quant_method
+        self.act_quant_method = act_quant_method
+        self.bnt = (1 << (bits_length - 1)) - 1
+        self.loss_function = loss_function
+
+    def search(self, layer_name, sampled_input, act_abs_max, weight):
+        act = sampled_input
+        act.stop_gradient = True
+        print('[smooth search] search input of %s' % layer_name)
+
+        origin_out = paddle.matmul(act, weight)
+        w_abs_max = weight.abs().max(axis=-1, keepdim=True)
+        rw_abs_max = w_abs_max.reshape(act_abs_max.shape)
+        np_act_abs_max = np.array(act_abs_max)
+        np_rw_abs_max = np.array(rw_abs_max)
+
+        smooth_scale_out = None
+        global_loss = float('inf')
+        best_scale = None
+
+        for k_piece in range(1, self.k_piece + 1):
+            if not self.search_piece:
+                k_piece = self.k_piece
+            print('Search {} Piece'.format(k_piece))
+            centroids, labels = k_means(act_abs_max, k_piece)
+            piece = ['piece_{}'.format(a) for a in range(len(centroids))]
+            for i in range(len(centroids)):
+                # print('search for piece {}; centroids value is {}'.format(
+                #     piece[i], centroids[centroids.argsort()[i]].numpy()))
+                alpha = self.search_alpha_min
+                alpha_max = self.search_scale_max
+                calibration_loss = float('inf')
+                final_alpha = None
+                mask_for_search = paddle.where(labels == centroids.argsort()[i],
+                                               1., 0.)
+                mask_for_ones = paddle.where(mask_for_search == 0., 1., 0.)
+
+                while alpha <= alpha_max:
+                    if alpha < 1:
+                        alpha += 0.01
+                        if alpha >= self.search_alpha_max:
+                            alpha = 1.
+                    else:
+                        alpha += 0.5
+
+                    alpha = round(alpha, 2)
+
+                    if alpha < 1:
+                        s = (np.power(np_act_abs_max, alpha) / np.power(
+                            np_rw_abs_max, 1. - alpha)).clip(min=1e-5)
+                        s = paddle.to_tensor(s, dtype='float32')
+                        smooth_scale = s * mask_for_search
+                    else:
+                        smooth_scale = alpha * mask_for_search
+
+                    if smooth_scale_out is not None:
+                        mask_for_ones_new = paddle.where(
+                            smooth_scale_out == 0., 1., 0.)
+                        mask_for_ones *= mask_for_ones_new
+                        smooth_scale_ = smooth_scale_out + smooth_scale
+                        smooth_scale_tmp = smooth_scale_ + mask_for_ones
+                    else:
+                        smooth_scale_tmp = smooth_scale + mask_for_ones
+
+                    new_act = act / smooth_scale_tmp
+                    new_weight = weight * smooth_scale_tmp.reshape(
+                        w_abs_max.shape)
+
+                    quant_scale = compute_scales(
+                        new_act, method=self.act_quant_method)
+                    quant_act = paddle.clip(
+ paddle.round(new_act / quant_scale * self.bnt),
+ -self.bnt - 1, self.bnt)
+ quant_dequant_act = quant_act / self.bnt * quant_scale
+
+ quant_scale = compute_scales(
+ new_weight, method=self.weight_quant_method)
+ quant_weight = paddle.clip(
+ paddle.round(new_weight / quant_scale * self.bnt),
+ -self.bnt - 1, self.bnt)
+ quant_dequant_weight = quant_weight / self.bnt * quant_scale
+ new_out = paddle.matmul(quant_dequant_act,
+ quant_dequant_weight)
+
+ cur_loss = self.loss_function(origin_out, new_out)
+ if cur_loss <= calibration_loss:
+ calibration_loss = cur_loss
+ final_smooth_scale = smooth_scale
+ final_alpha = alpha
+
+ # print("Layer {} Piece {}, loss: {}, alpha : {}".format(
+ # layer_name, piece[i], float(calibration_loss), final_alpha))
+ if smooth_scale_out is None:
+ smooth_scale_out = final_smooth_scale
+ else:
+ smooth_scale_out += final_smooth_scale
+
+ if cur_loss < global_loss:
+ global_loss = cur_loss
+ best_scale = smooth_scale_out
+ if self.search_piece:
+ print('Find Better K-Piece {}'.format(k_piece))
+ if not self.search_piece:
+ break
+ return best_scale
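Review note on the search loop in piecewise.py: the core of each iteration is "build a smooth scale from alpha, rescale activation and weight, fake-quantize both, and compare the matmul output against the float reference". The following is a minimal NumPy sketch of that idea only. The names `fake_quant` and `search_alpha` are hypothetical, both tensors use per-tensor abs-max quantization (the class above defaults to channel-wise for weights), and the k-means piece masks and the scale search for alpha >= 1 are left out.

```python
import numpy as np


def fake_quant(x, bits=8):
    """Per-tensor symmetric quantize-dequantize (an assumption of this sketch)."""
    bnt = (1 << (bits - 1)) - 1
    scale = float(np.abs(x).max()) or 1e-8
    q = np.clip(np.round(x / scale * bnt), -bnt - 1, bnt)
    return q / bnt * scale


def search_alpha(act, weight, alphas, bits=8):
    """Return (alpha, loss, scale) with the lowest MSE against the float matmul."""
    act_abs_max = np.abs(act).max(axis=0)      # one value per input channel
    w_abs_max = np.abs(weight).max(axis=1)     # one value per weight row
    reference = act @ weight
    best = (None, np.inf, None)
    for alpha in alphas:
        s = (act_abs_max**alpha / w_abs_max**(1.0 - alpha)).clip(min=1e-5)
        out = fake_quant(act / s, bits) @ fake_quant(weight * s[:, None], bits)
        loss = float(np.mean((reference - out)**2))
        if loss <= best[1]:
            best = (float(alpha), loss, s)
    return best


rng = np.random.default_rng(0)
act = rng.normal(size=(16, 8))
act[:, 3] *= 50.0                              # one outlier channel
weight = rng.normal(size=(8, 4))
alphas = np.round(np.arange(0.2, 0.81, 0.05), 2)
best_alpha, best_loss, best_s = search_alpha(act, weight, alphas)
```

Before quantization the rescaling is exact: `(act / s) @ (weight * s[:, None])` equals `act @ weight`, so the loss measures quantization error only.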
diff --git a/paddleslim/quant/advanced/sample.py b/paddleslim/quant/advanced/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8241614af5810afe59fce4ea98623f4b1dd54b9
--- /dev/null
+++ b/paddleslim/quant/advanced/sample.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+__all__ = ['MultiStepSampler', 'EMASampler']
+
+
+class MultiStepSampler():
+ def __init__(self):
+ pass
+
+ def sample(self, x, sampled_x=None, layer_name=None):
+ return paddle.concat([x, sampled_x], axis=1)
+
+
+class EMASampler():
+ def __init__(self):
+ self.ema_beta = 0.98
+ self.ema_step = {}
+ self.sampled = {}
+
+ def sample(self, x, sampled_x=None, layer_name=None):
+
+ if layer_name not in self.ema_step:
+ self.sampled[layer_name] = (1 - self.ema_beta) * x
+ self.ema_step[layer_name] = 0
+ return self.sampled[layer_name]
+ else:
+ v_ema = self.ema_beta * self.sampled[layer_name] + (
+ 1 - self.ema_beta) * x
+ self.sampled[layer_name] = v_ema
+ v_ema_corr = v_ema / float(
+ (1 - np.power(self.ema_beta, self.ema_step[layer_name] + 1)))
+ self.ema_step[layer_name] += 1
+ return v_ema_corr
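Review note on EMASampler: for comparison, here is the textbook bias-corrected exponential moving average in plain NumPy. `EmaTracker` is a hypothetical name, and this is not a drop-in replacement for the class above. It divides by `1 - beta**t` with `t` counted from 1 on every call, including the first, which differs from the step bookkeeping in the diff.

```python
import numpy as np


class EmaTracker:
    """Bias-corrected EMA keyed by layer name (illustrative only)."""

    def __init__(self, beta=0.98):
        self.beta = beta
        self.raw = {}     # uncorrected running averages
        self.steps = {}   # number of updates seen per key

    def update(self, name, x):
        x = np.asarray(x, dtype=np.float64)
        prev = self.raw.get(name, np.zeros_like(x))
        self.steps[name] = self.steps.get(name, 0) + 1
        self.raw[name] = self.beta * prev + (1.0 - self.beta) * x
        # Divide out the bias toward the zero initial state.
        return self.raw[name] / (1.0 - self.beta**self.steps[name])


tracker = EmaTracker(beta=0.9)
outputs = [tracker.update("ln_0", np.array([2.0, -4.0])) for _ in range(5)]
```

With a constant input the corrected average reproduces that input at every step, which is a quick sanity check for the correction term.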
diff --git a/paddleslim/quant/advanced/shift.py b/paddleslim/quant/advanced/shift.py
new file mode 100644
index 0000000000000000000000000000000000000000..70aea070f877008820c19e9aed1aa48cf20618e9
--- /dev/null
+++ b/paddleslim/quant/advanced/shift.py
@@ -0,0 +1,254 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from re import sub
+import numpy as np
+import paddle
+from .utils import get_ln_linear_info, find_parent_layer_and_sub_name
+from .utils_layers import ShiftSmoothHelpLayer, WOBiasHelpLayer
+
+__all__ = ['Shift']
+
+
+class Shift():
+ def __init__(self,
+ model,
+ model_config,
+ shift_all_linears=False,
+ sample_function=None):
+ '''
+ Shift is the implementation of Outlier Suppression+ (https://arxiv.org/abs/2304.09145).
+ Currently, Shift only supports Linear, ColumnParallelLinear and RowParallelLinear layers.
+
+ Args:
+ model(paddle.nn.Layer, required): the model to be shifted
+ model_config (dict, required): the config of model to be shifted
+ shift_all_linears (bool, optional): whether to shift all linear layers
+ sample_function (object, optional): sampler whose sample() method aggregates zero points across steps, e.g. EMASampler. Default: None
+
+ Examples:
+ .. code-block:: python
+
+ from paddleslim.quant.advanced import Shift
+ shift = Shift(model, model_config)
+ for data in dataloader():
+ model(data)
+ shift.step += 1
+ shift.update_weight()
+ '''
+
+ self.model = model
+ self.model_config = model_config
+ self.fused_qkv = model_config.get("fused_qkv", True)
+ self.linear_flag = model_config.get("linear_flag", 'linear')
+ self.norm_flag = model_config.get("norm_flag", 'norm')
+ self.parallel_ffn = model_config.get("parallel_ffn", False)
+ self.skip_norm_list = model_config.get("skip_norm_list", [])
+ self.shift_all_linears = shift_all_linears
+ self.sample_function = sample_function
+
+ self.layer_order = []
+ self.zero_point_dict = {}
+ self.smooth_scale_dict = {}
+ self.glabal_min_max = {}
+
+ self.model.eval()
+ self.step = 0
+ self.print_step = 1
+ self.got_shift_layers = False
+ self.help_layers_ready = False
+ self._apply_hook()
+
+ def _apply_hook(self):
+ self._forward_hook_list = []
+ for _, sub_layer in self.model.named_sublayers():
+ if self.norm_flag in sub_layer.full_name(
+ ) or self.linear_flag in sub_layer.full_name():
+ forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
+ self._forward_pre_hook)
+ self._forward_hook_list.append(forward_pre_hook_handle)
+ if type(sub_layer) == ShiftSmoothHelpLayer:
+ self.help_layers_ready = True
+ forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
+ self._forward_pre_hook)
+ self._forward_hook_list.append(forward_pre_hook_handle)
+
+ def _get_shift_layers(self):
+ self.ln_linear_dict, self.linear_ln_dict = get_ln_linear_info(
+ self.layer_order, self.norm_flag, self.linear_flag, self.fused_qkv,
+ self.parallel_ffn, self.skip_norm_list)
+ assert len(self.ln_linear_dict) > 0, 'No LN/Linear pair found'
+ for key in self.ln_linear_dict:
+ print('shift pair LN {} : Linear {}'.format(
+ key, self.ln_linear_dict[key]))
+ if self.shift_all_linears:
+ if not self.help_layers_ready:
+ rest_linears = [
+ l for l in self.layer_order
+ if self.linear_flag in l and l not in self.linear_ln_dict
+ ]
+ print('Preparing shift layers', rest_linears)
+ for cur_name, sub_layer in self.model.named_sublayers():
+ if sub_layer.full_name() in rest_linears:
+ new_layer = ShiftSmoothHelpLayer(sub_layer)
+ parent_layer, sub_name = find_parent_layer_and_sub_name(
+ self.model, cur_name)
+ setattr(parent_layer, sub_name, new_layer)
+ forward_pre_hook_handle = new_layer.register_forward_pre_hook(
+ self._forward_pre_hook)
+ self._forward_hook_list.append(forward_pre_hook_handle)
+ self.got_shift_layers = True
+
+ def _forward_pre_hook(self, layer, input):
+ '''
+ When step == 0, run one forward pass and record the layer order.
+ When step > 0, sample the zero point of each shifted layer's input.
+ '''
+ if self.step == 0 and layer.full_name() in self.layer_order:
+ self.step += 1
+ if self.step == 0:
+ self.layer_order.append(layer.full_name())
+ return input
+ if self.step == 1:
+ if not self.got_shift_layers:
+ self._get_shift_layers()
+ if self.step > 0:
+ if layer.full_name() in self.linear_ln_dict.keys():
+ self._sample_zero_point(input,
+ self.linear_ln_dict[layer.full_name()])
+ if type(layer) == ShiftSmoothHelpLayer:
+ self._sample_zero_point(input, layer.full_name())
+
+ return input
+
+ def _sample_zero_point(self, input, ln_name):
+ x = input[0] if type(input) == tuple else input
+ x = x.cast('float32')
+ x.stop_gradient = True
+
+ zero_point = x.mean(axis=(0, 1)) if len(x.shape) > 2 else x.mean(axis=1)
+ _min = x.min(axis=(0, 1)) if len(x.shape) > 2 else x.min(axis=1)
+ _max = x.max(axis=(0, 1)) if len(x.shape) > 2 else x.max(axis=1)
+
+ if ln_name not in self.zero_point_dict or ln_name not in self.glabal_min_max:
+
+ if self.sample_function is None:
+ self.glabal_min_max[ln_name] = _min, _max
+ self.zero_point_dict[ln_name] = (_min + _max) / 2
+ else:
+ self.zero_point_dict[ln_name] = zero_point
+
+ else:
+ if self.sample_function is not None:
+ self.zero_point_dict[ln_name] = self.sample_function.sample(
+ zero_point, self.zero_point_dict[ln_name], ln_name)
+ else:
+ global_min, global_max = self.glabal_min_max[ln_name]
+ global_min = paddle.minimum(global_min, _min)
+ global_max = paddle.maximum(global_max, _max)
+ self.glabal_min_max[ln_name] = global_min, global_max
+ self.zero_point_dict[ln_name] = (global_min + global_max) / 2
+
+ # print once per step
+ if self.print_step == self.step:
+ print('[shift] Step [{}]: {}. zero_point min: {}, max: {}'.format(
+ self.step, ln_name,
+ round(float(self.zero_point_dict[ln_name].min()), 5),
+ round(float(self.zero_point_dict[ln_name].max()), 5)))
+ if ln_name == list(self.linear_ln_dict.values())[-1]:
+ self.print_step += 1
+
+ def update_weight(self):
+ '''
+ Update the biases of the shifted layers.
+ First fold the sampled zero point into each linear layer's bias,
+ then subtract the zero point from the bias of the corresponding LN.
+ '''
+ # update linear weight
+ for _, sub_layer in self.model.named_sublayers():
+ layer_name = sub_layer.full_name()
+ if layer_name in self.linear_ln_dict:
+ ln_name = self.linear_ln_dict[layer_name]
+ shift_bias = None
+ for param in sub_layer.parameters(include_sublayers=False):
+ if 'w_0' in param.name:
+ zero_point = self.zero_point_dict[ln_name].squeeze()
+ shift_bias = paddle.matmul(zero_point,
+ param.cast('float32'))
+ print("[shift] param: {}, zero_point min: {}, max: {}".
+ format(param.name,
+ float(zero_point.min()),
+ float(zero_point.max())))
+ break
+
+ if not hasattr(sub_layer, "bias") or sub_layer.bias is None:
+ sub_layer.bias = paddle.create_parameter(
+ shape=shift_bias.shape,
+ dtype=sub_layer.weight.dtype,
+ default_initializer=paddle.nn.initializer.Constant(0.0),
+ is_bias=True, )
+ for param in sub_layer.parameters(include_sublayers=False):
+ if 'b_0' in param.name:
+ shift_bias = shift_bias + param
+ paddle.assign(
+ shift_bias.cast(param.dtype), output=param)
+ print("[shift] update linear bias: {}.".format(
+ param.name))
+ break
+
+ # update LN weight
+ for cur_name, sub_layer in self.model.named_sublayers():
+ layer_name = sub_layer.full_name()
+ if layer_name in self.ln_linear_dict:
+ if not hasattr(sub_layer, "bias") or sub_layer.bias is None:
+ help_layer = WOBiasHelpLayer(sub_layer)
+ parent_layer, sub_name = find_parent_layer_and_sub_name(
+ self.model, cur_name)
+ setattr(parent_layer, sub_name, help_layer)
+ sub_layer = help_layer
+
+ for param in sub_layer.parameters(include_sublayers=False):
+ if "b_0" in param.name:
+ zero_point = self.zero_point_dict[layer_name].squeeze()
+ param_tmp = param - zero_point
+ paddle.assign(param_tmp.cast(param.dtype), output=param)
+ print("[shift] update layer_norm bias {}.".format(
+ param.name))
+ break
+
+ # update ShiftSmoothHelpLayer shift_bias and bias
+ for _, sub_layer in self.model.named_sublayers():
+ if type(sub_layer) == ShiftSmoothHelpLayer:
+ layer_name = sub_layer.full_name()
+ linear_name = sub_layer.layer.full_name()
+ zero_point = self.zero_point_dict[layer_name].squeeze()
+ print(
+ "[shift ShiftSmoothHelpLayer] before param: {}, shift_bias min: {}, max: {}".
+ format(linear_name,
+ float(sub_layer.shift_bias.cast("float32").min()),
+ float(sub_layer.shift_bias.max().cast("float32"))))
+ sub_layer.convert_weight(shift_bias=zero_point)
+
+ print(
+ "[shift ShiftSmoothHelpLayer] after param: {}, shift_bias min: {}, max: {}".
+ format(linear_name,
+ float(sub_layer.shift_bias.cast("float32").min()),
+ float(sub_layer.shift_bias.max().cast("float32"))))
+
+ self._remove_hook()
+ paddle.device.cuda.empty_cache()
+
+ def _remove_hook(self):
+ for hook in self._forward_hook_list:
+ hook.remove()
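Review note on Shift.update_weight: the two bias updates cancel exactly. Subtracting the zero point `z` from the LN bias shifts the linear input by `-z`, and adding `z @ W` to the linear bias restores the output. Below is a small NumPy check of that identity, using the min/max midpoint as the zero point (the path taken when no sample function is given). The array names are illustrative, and `h` stands in for the LN output.

```python
import numpy as np

rng = np.random.default_rng(0)
h = rng.normal(loc=3.0, size=(32, 6))    # stand-in for the LN output
W = rng.normal(size=(6, 4))              # linear weight, [in, out]
b = rng.normal(size=4)                   # linear bias

# Per-channel zero point: midpoint of the observed range.
z = (h.min(axis=0) + h.max(axis=0)) / 2.0

shifted_h = h - z          # equivalent to subtracting z from the LN bias
shifted_b = b + z @ W      # fold the shift into the linear bias

original = h @ W + b
shifted = shifted_h @ W + shifted_b
```

After the shift each channel is symmetric around zero (`shifted_h.max(axis=0)` equals `-shifted_h.min(axis=0)`), which is what makes symmetric quantization of the activation less wasteful.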
diff --git a/paddleslim/quant/advanced/smooth.py b/paddleslim/quant/advanced/smooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2bd51789ee0ec3368fd709d4371d897f965a5c5
--- /dev/null
+++ b/paddleslim/quant/advanced/smooth.py
@@ -0,0 +1,273 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+from .utils import get_ln_linear_info, find_parent_layer_and_sub_name
+from .utils_layers import ShiftSmoothHelpLayer, WOBiasHelpLayer
+__all__ = ['Smooth']
+
+
+class Smooth():
+ def __init__(
+ self,
+ model,
+ model_config,
+ alpha=0.5,
+ smooth_all_linears=False,
+ sample_function=None,
+ search_function=None, ):
+ '''
+ Smooth is an updated version of SmoothQuant (https://arxiv.org/abs/2211.10438).
+ Compared with SmoothQuant, the piecewise smooth algorithm adds the following:
+ It supports a search function that searches for the best smooth scale.
+ It supports a sample function that samples the inputs used to compute the smooth scale.
+ Currently, Smooth only supports Linear, ColumnParallelLinear and RowParallelLinear layers.
+
+
+ Args:
+ model(paddle.nn.Layer, required): the model to be smoothed
+ model_config (dict, required): the config of model to be smoothed
+ alpha(float, optional): smoothing parameter. Default: 0.5
+ smooth_all_linears(bool, optional): whether to smooth all linears. Default: False
+ sample_function(object, optional): sampler whose sample() method aggregates inputs across steps,
+ e.g. MultiStepSampler or EMASampler. Default: None
+ search_function(object, optional): searcher whose search() method returns the smooth scale, e.g. PieceWiseSearch. Default: None
+
+ Examples:
+ .. code-block:: python
+
+ from paddleslim.quant.advanced import Smooth
+ smooth = Smooth(model, model_config)
+ for data in dataloader():
+ model(data)
+ smooth.step += 1
+ smooth.update_weight()
+ '''
+ self.model = model
+ self.model_config = model_config
+ self.fused_qkv = model_config.get("fused_qkv", True)
+ self.linear_flag = model_config.get("linear_flag", 'linear')
+ self.norm_flag = model_config.get("norm_flag", 'norm')
+ self.parallel_ffn = model_config.get("parallel_ffn", False)
+ self.skip_norm_list = model_config.get("skip_norm_list", [])
+
+ self.alpha = alpha
+ self.smooth_all_linears = smooth_all_linears
+ self.sample_function = sample_function
+ self.search_function = search_function
+
+ self.model.eval()
+ self.step = 0
+ self.print_step = 1
+ self.got_smooth_layers = False
+ self.help_layers_ready = False
+ self.layer_order = []
+ self.scale_dict = {}
+ self.smooth_scale_dict = {}
+ self.sampled_inputs = {}
+ self._apply_hook()
+
+ def _apply_hook(self):
+ self._forward_hook_list = []
+ for _, sub_layer in self.model.named_sublayers():
+ if self.norm_flag in sub_layer.full_name(
+ ) or self.linear_flag in sub_layer.full_name():
+ forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
+ self._forward_pre_hook)
+ self._forward_hook_list.append(forward_pre_hook_handle)
+ if type(sub_layer) == ShiftSmoothHelpLayer:
+ self.help_layers_ready = True
+ forward_pre_hook_handle = sub_layer.register_forward_pre_hook(
+ self._forward_pre_hook)
+ self._forward_hook_list.append(forward_pre_hook_handle)
+
+ def _get_smooth_layers(self):
+ self.ln_linear_dict, self.linear_ln_dict = get_ln_linear_info(
+ self.layer_order, self.norm_flag, self.linear_flag, self.fused_qkv,
+ self.parallel_ffn, self.skip_norm_list)
+ assert len(self.ln_linear_dict) > 0, 'No LN/Linear pair found'
+ for key in self.ln_linear_dict:
+ print('smooth pair LN {} : Linear {}'.format(
+ key, self.ln_linear_dict[key]))
+ if self.smooth_all_linears:
+ if not self.help_layers_ready:
+ rest_linears = [
+ l for l in self.layer_order
+ if self.linear_flag in l and l not in self.linear_ln_dict
+ ]
+ print('Preparing smooth layers', rest_linears)
+ for cur_name, sub_layer in self.model.named_sublayers():
+ if sub_layer.full_name() in rest_linears:
+ new_layer = ShiftSmoothHelpLayer(sub_layer)
+ parent_layer, sub_name = find_parent_layer_and_sub_name(
+ self.model, cur_name)
+ setattr(parent_layer, sub_name, new_layer)
+ forward_pre_hook_handle = new_layer.register_forward_pre_hook(
+ self._forward_pre_hook)
+ self._forward_hook_list.append(forward_pre_hook_handle)
+
+ self.got_smooth_layers = True
+
+ def _forward_pre_hook(self, layer, input):
+ '''
+ when step 0, forward once and collect smooth layers.
+ '''
+ if self.step == 0 and layer.full_name() in self.layer_order:
+ self.step += 1
+ if self.step == 0:
+ self.layer_order.append(layer.full_name())
+ return input
+ if self.step == 1:
+ if not self.got_smooth_layers:
+ self._get_smooth_layers()
+ if self.step > 0:
+ if layer.full_name() in self.linear_ln_dict.keys():
+ self._sample_scale(input,
+ self.linear_ln_dict[layer.full_name()])
+ if type(layer) == ShiftSmoothHelpLayer:
+ self._sample_scale(input, layer.full_name())
+
+ return input
+
+ def _sample_scale(self, input, ln_name):
+ x = input[0] if type(input) == tuple else input
+ x.stop_gradient = True
+ x_abs_max = x.abs().max(axis=1, keepdim=True)
+ x_abs_max = x_abs_max.max(axis=0)
+
+ if ln_name not in self.scale_dict:
+ self.sampled_inputs[ln_name] = x
+ self.scale_dict[ln_name] = x_abs_max
+ else:
+ if self.sample_function is not None:
+ self.sampled_inputs[ln_name] = self.sample_function.sample(
+ x, self.sampled_inputs[ln_name], ln_name)
+ else:
+ self.sampled_inputs[ln_name] = x
+ tmp1 = paddle.concat([x_abs_max, self.scale_dict[ln_name]], axis=0)
+ self.scale_dict[ln_name] = tmp1.max(axis=0, keepdim=True)
+
+ # print once per step
+ if self.print_step == self.step:
+ print('[Smooth] Step [{}]: {}. abs_min: {}, abs_max: {}'.format(
+ self.step, ln_name,
+ float(self.scale_dict[ln_name].cast("float32").min()),
+ float(self.scale_dict[ln_name].cast("float32").max())))
+ if ln_name == list(self.linear_ln_dict.values())[-1]:
+ self.print_step += 1
+
+ def update_weight(self):
+
+ for _, sub_layer in self.model.named_sublayers():
+ layer_name = sub_layer.full_name()
+ ln_name = None
+ if layer_name in self.linear_ln_dict:
+ ln_name = self.linear_ln_dict[layer_name]
+ if type(sub_layer) == ShiftSmoothHelpLayer:
+ ln_name = layer_name
+ if ln_name is not None:
+ act_abs_max = self.scale_dict[ln_name].cast("float32")
+ sampled_input = self.sampled_inputs[ln_name].cast("float32")
+ for param in sub_layer.parameters(include_sublayers=False):
+ if 'w_0' in param.name:
+ weight = param.cast("float32")
+ if self.search_function is not None:
+ s = self.search_function.search(
+ layer_name, sampled_input, act_abs_max, weight)
+ else:
+ w_abs_max = weight.abs().max(axis=-1, keepdim=True)
+ rw_abs_max = w_abs_max.reshape(act_abs_max.shape)
+ act_abs_max_np = act_abs_max.numpy()
+ weight_abs_max_np = rw_abs_max.numpy()
+ s = (
+ np.power(act_abs_max_np, self.alpha) / np.power(
+ weight_abs_max_np, 1 - self.alpha)).clip(
+ min=1e-5)
+ s = paddle.to_tensor(s, dtype="float32")
+
+ self.smooth_scale_dict[ln_name] = s.cast(param.dtype)
+ break
+
+ # update linear weight
+ for _, sub_layer in self.model.named_sublayers():
+ layer_name = sub_layer.full_name()
+ if layer_name in self.linear_ln_dict:
+ for param in sub_layer.parameters(include_sublayers=False):
+ if 'w_0' in param.name:
+ ln_name = self.linear_ln_dict[layer_name]
+ print("[smooth] before linear [{}] weight, abs_max: {}".
+ format(param.name,
+ float(param.cast("float32").abs().max())))
+ param_tmp = param * self.smooth_scale_dict[
+ ln_name].transpose(perm=[1, 0])
+ paddle.assign(param_tmp, output=param)
+ print("[smooth] after linear [{}] weight, abs_max: {}".
+ format(param.name,
+ float(param_tmp.abs().max().cast(
+ "float32"))))
+
+ # update LN weight
+ for cur_name, sub_layer in self.model.named_sublayers():
+ layer_name = sub_layer.full_name()
+ if layer_name in self.ln_linear_dict:
+ s = self.smooth_scale_dict[layer_name].squeeze()
+ for param in sub_layer.parameters(include_sublayers=False):
+ print("[smooth] before layer_norm {} weight, abs_max: {}".
+ format(param.name,
+ float(param.abs().max().cast("float32"))))
+ param_tmp = param / s
+ paddle.assign(param_tmp, output=param)
+ print("[smooth] after layer_norm {} weight, abs_max: {}".
+ format(param.name,
+ float(param_tmp.abs().max().cast("float32"))))
+ if not hasattr(sub_layer, "bias") or sub_layer.bias is None:
+ parent_layer, _ = find_parent_layer_and_sub_name(
+ self.model, cur_name)
+ if type(parent_layer) == WOBiasHelpLayer:
+ param = parent_layer.bias
+ print(
+ "[smooth WOBiasHelpLayer] before layer_norm {} bias, abs_max: {}".
+ format(param.name,
+ float(param.abs().max().cast("float32"))))
+ param_tmp = param / s
+ paddle.assign(param_tmp, output=param)
+ print(
+ "[smooth WOBiasHelpLayer] after layer_norm {} bias, abs_max: {}".
+ format(param.name,
+ float(param_tmp.abs().max().cast(
+ "float32"))))
+
+ for _, sub_layer in self.model.named_sublayers():
+ if type(sub_layer) == ShiftSmoothHelpLayer:
+ layer_name = sub_layer.full_name()
+ linear_name = sub_layer.layer.full_name()
+ smooth_scale = self.smooth_scale_dict[layer_name]
+ print(
+ "[smooth ShiftSmoothHelpLayer] param: {}, before weight, abs_max: {}".
+ format(linear_name,
+ float(sub_layer.weight.abs().max().cast("float32"))))
+ sub_layer.convert_weight(smooth_weight=smooth_scale)
+ print(
+ "[smooth ShiftSmoothHelpLayer] param: {}, after weight, abs_max: {}".
+ format(linear_name,
+ float(sub_layer.weight.abs().max().cast("float32"))))
+
+ self._remove_hook()
+ paddle.device.cuda.empty_cache()
+
+ def _remove_hook(self):
+ for hook in self._forward_hook_list:
+ hook.remove()
+ self._forward_hook_list = []
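Review note on Smooth.update_weight: dividing the LN weight and bias by the smooth scale `s` while multiplying the matching rows of the linear weight by `s` leaves the float output unchanged. The NumPy sketch below checks this for the default (no search function) formula `s = act_abs_max**alpha / w_abs_max**(1 - alpha)`. The `layer_norm` helper and all array names are assumptions made for illustration.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """Plain layer normalization over the last axis."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps) * gamma + beta


rng = np.random.default_rng(0)
x = rng.normal(size=(10, 6))
gamma, beta = rng.normal(size=6), rng.normal(size=6)
W = rng.normal(size=(6, 3))              # linear weight, [in, out]
alpha = 0.5

h = layer_norm(x, gamma, beta)
act_abs_max = np.abs(h).max(axis=0)      # per input channel
w_abs_max = np.abs(W).max(axis=1)        # per weight row
s = (act_abs_max**alpha / w_abs_max**(1.0 - alpha)).clip(min=1e-5)

smoothed_h = layer_norm(x, gamma / s, beta / s)   # LN params divided by s
smoothed_W = W * s[:, None]                       # weight rows multiplied by s
```

The product `smoothed_h @ smoothed_W` matches `h @ W` up to floating-point error, so any accuracy change after smoothing comes from quantization, not from the transform itself.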
diff --git a/paddleslim/quant/advanced/utils.py b/paddleslim/quant/advanced/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..98f24ef124e10b345be6354f9c37fc352cfee6de
--- /dev/null
+++ b/paddleslim/quant/advanced/utils.py
@@ -0,0 +1,108 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import numpy as np
+from sklearn.cluster import KMeans
+
+
+def k_means(weight, n_clusters, init='k-means++', max_iter=300):
+ org_shape = weight.shape
+ weight = paddle.to_tensor(weight)
+ weight = paddle.reshape(weight, [-1, 1])
+ if n_clusters > weight.size:
+ n_clusters = weight.size
+
+ k_means = KMeans(
+ n_clusters=n_clusters,
+ init=init,
+ n_init=10,
+ algorithm='lloyd',
+ max_iter=max_iter)
+ k_means.fit(weight)
+
+ centroids = k_means.cluster_centers_
+ labels = k_means.labels_
+ labels = labels.reshape(org_shape)
+ return paddle.to_tensor(centroids.flatten()), paddle.to_tensor(labels)
+
+
+def compute_scales(x, method='abs_max'):
+ if method == 'abs_max':
+ quant_scale = float(paddle.max(paddle.abs(x.flatten())))
+ quant_scale = 1e-8 if quant_scale == 0.0 else quant_scale
+ elif method == 'avg':
+ quant_scale = paddle.abs(x.reshape((x.shape[0], -1)))
+ quant_scale = paddle.mean(paddle.max(quant_scale, axis=(1)))
+ elif method == 'abs_max_channel_wise':
+ reduce_axis = tuple([i for i in range(len(x.shape)) if i != 1])
+ quant_scale = paddle.max(paddle.abs(x), axis=reduce_axis)
+ quant_scale = paddle.where(quant_scale == np.float32(0.0),
+ np.float32(1e-8), quant_scale)
+ return quant_scale
+
+
+def find_parent_layer_and_sub_name(model, name):
+ last_idx = 0
+ idx = 0
+ parent_layer = model
+ while idx < len(name):
+ if name[idx] == '.':
+ sub_name = name[last_idx:idx]
+ if hasattr(parent_layer, sub_name):
+ parent_layer = getattr(parent_layer, sub_name)
+ last_idx = idx + 1
+ idx += 1
+ sub_name = name[last_idx:idx]
+ return parent_layer, sub_name
+
+
+def get_ln_linear_info(ln_linear_list, norm_flag, linear_flag, fused_qkv,
+ llama_ffn, skip_norm_list):
+ # ln_linear_dict: {layer_norm_0: [linear_0, linear_1, linear_2]}
+ ln_linear_dict = {}
+ # linear_ln_dict: {linear_0: layer_norm_0, linear_1: layer_norm_0}
+ linear_ln_dict = {}
+ for i in range(len(ln_linear_list)):
+ layer_name = ln_linear_list[i]
+ if norm_flag in layer_name and layer_name not in skip_norm_list:
+ if i < len(ln_linear_list) - 1:
+ if not fused_qkv:
+ if linear_flag in ln_linear_list[i +
+ 1] and linear_flag in ln_linear_list[i + 2] and linear_flag in ln_linear_list[i + 3] and int(
+ layer_name.split('_')
+ [-1]) % 2 == 0:
+ ln_linear_dict[layer_name] = [
+ ln_linear_list[i + 1], ln_linear_list[i + 2],
+ ln_linear_list[i + 3]
+ ]
+ linear_ln_dict[ln_linear_list[i + 1]] = layer_name
+ linear_ln_dict[ln_linear_list[i + 2]] = layer_name
+ linear_ln_dict[ln_linear_list[i + 3]] = layer_name
+ if linear_flag in ln_linear_list[i + 1] and int(
+ layer_name.split('_')[-1]) % 2 != 0:
+ if llama_ffn:
+ ln_linear_dict[layer_name] = [
+ ln_linear_list[i + 1], ln_linear_list[i + 2]
+ ]
+ linear_ln_dict[ln_linear_list[i + 1]] = layer_name
+ linear_ln_dict[ln_linear_list[i + 2]] = layer_name
+ else:
+ ln_linear_dict[layer_name] = [ln_linear_list[i + 1]]
+ linear_ln_dict[ln_linear_list[i + 1]] = layer_name
+ else:
+ if linear_flag in ln_linear_list[i + 1]:
+ ln_linear_dict[layer_name] = [ln_linear_list[i + 1]]
+ linear_ln_dict[ln_linear_list[i + 1]] = layer_name
+ return ln_linear_dict, linear_ln_dict
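Review note on find_parent_layer_and_sub_name: the helper resolves a dotted sublayer name to its parent object and the final attribute name, so that the caller can swap the sublayer with `setattr`. A compact sketch of the same idea on plain Python objects follows. `find_parent_and_name` is a hypothetical name, and unlike the version in the diff it raises `AttributeError` when a path component is missing.

```python
from types import SimpleNamespace


def find_parent_and_name(root, dotted_name):
    """Walk 'a.b.c' from root and return (object at 'a.b', 'c')."""
    *parents, leaf = dotted_name.split(".")
    node = root
    for part in parents:
        node = getattr(node, part)
    return node, leaf


# Toy object tree standing in for a model with nested sublayers.
model = SimpleNamespace(
    decoder=SimpleNamespace(block=SimpleNamespace(linear="original")))

parent, leaf = find_parent_and_name(model, "decoder.block.linear")
# Replace the leaf with a wrapper, as the Shift/Smooth classes do with help layers.
setattr(parent, leaf, ("wrapped", getattr(parent, leaf)))
```

Returning the parent rather than the leaf is the point of the helper: replacing a sublayer requires assigning on the object that owns it.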
diff --git a/paddleslim/quant/advanced/utils_layers.py b/paddleslim/quant/advanced/utils_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b06a3ede83a3ef3892026d2ef87c0284e8c320c
--- /dev/null
+++ b/paddleslim/quant/advanced/utils_layers.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.nn as nn
+from paddle import ParamAttr
+from paddle.nn.initializer import Constant
+from paddle.distributed.fleet.meta_parallel import ColumnParallelLinear, RowParallelLinear
+
+__all__ = ['ShiftSmoothHelpLayer', 'WOBiasHelpLayer']
+
+
+class ShiftSmoothHelpLayer(nn.Layer):
+ def __init__(self, layer):
+ super(ShiftSmoothHelpLayer, self).__init__()
+ self.weight = layer.weight
+ shift_shape = self.weight.shape[0]
+ if getattr(layer, "bias", None) is None:
+ layer.bias = paddle.create_parameter(
+ shape=[self.weight.shape[1]],
+ dtype=self.weight.dtype,
+ default_initializer=paddle.nn.initializer.Constant(0.0),
+ is_bias=True, )
+ self.bias = layer.bias
+ self.layer = layer
+ self.layer_type = type(layer)
+ # add
+ self.shift_bias = self.create_parameter(
+ shape=[shift_shape],
+ attr=ParamAttr(initializer=Constant(value=0.)),
+ dtype=self.weight.dtype)
+ # multiply
+ self.smooth_weight = self.create_parameter(
+ shape=[shift_shape],
+ attr=ParamAttr(initializer=Constant(value=1.)),
+ dtype=self.weight.dtype)
+
+ def forward(self, input):
+ shift_input = input
+ shift_input = paddle.add(shift_input, self.shift_bias)
+ smooth_input = paddle.multiply(shift_input, self.smooth_weight)
+ return self.layer(smooth_input)
+
+ def convert_weight(self, shift_bias=None, smooth_weight=None):
+ # shift
+ if shift_bias is not None:
+ shift_bias = shift_bias.cast(self.weight.dtype)
+ self.shift_bias.set_value(-shift_bias)
+ shift_linear_bias = paddle.matmul(shift_bias, self.weight)
+
+ if self.layer_type == RowParallelLinear:
+ parallel_shift_linear_bias = paddle.distributed.collective._mp_allreduce(
+ shift_linear_bias,
+ group=self.layer.model_parallel_group,
+ use_calc_stream=True,
+ use_model_parallel=True)
+ self.bias.set_value(self.bias + parallel_shift_linear_bias)
+ else:
+ self.bias.set_value(self.bias + shift_linear_bias)
+
+ # smooth
+ if smooth_weight is not None:
+ self.smooth_weight.set_value(
+ 1 / smooth_weight.squeeze().cast(self.weight.dtype))
+ self.weight.set_value(
+ self.weight * smooth_weight.transpose(perm=[1, 0]))
+
+
+class WOBiasHelpLayer(nn.Layer):
+ def __init__(self, layer):
+ super(WOBiasHelpLayer, self).__init__()
+ self.weight = layer.weight
+ self.bias = paddle.create_parameter(
+ shape=self.weight.shape,
+ dtype=self.weight.dtype,
+ default_initializer=paddle.nn.initializer.Constant(0.0),
+ is_bias=True, )
+ self.layer = layer
+
+ def forward(self, input):
+ return self.layer(input) + self.bias
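Review note on ShiftSmoothHelpLayer: for a linear layer with no preceding LN, the wrapper applies the shift and the smoothing to the input at run time and compensates in the wrapped layer's weight and bias. The NumPy sketch below mirrors the arithmetic of `forward` and `convert_weight` when both a zero point `z` and a smooth scale `s` are supplied. The values of `z` and `s` are random placeholders, not sampled statistics.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, size=(5, 6))
W = rng.normal(size=(6, 3))              # linear weight, [in, out]
b = rng.normal(size=3)
z = rng.normal(size=6)                   # zero point (placeholder values)
s = rng.uniform(0.5, 2.0, size=6)        # smooth scale (placeholder values)

# convert_weight(shift_bias=z): store -z and fold z @ W into the bias.
shift_bias = -z
new_b = b + z @ W

# convert_weight(smooth_weight=s): store 1 / s and scale the weight rows by s.
smooth_weight = 1.0 / s
new_W = W * s[:, None]

reference = x @ W + b
# forward: add shift_bias, multiply by smooth_weight, then call the linear layer.
wrapped = ((x + shift_bias) * smooth_weight) @ new_W + new_b
```

Here `wrapped` equals `reference` up to rounding, which is the property the help layer relies on.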
diff --git a/paddleslim/quant/observers/abs_max_weight.py b/paddleslim/quant/observers/abs_max_weight.py
index 5594ab1cacd0a31de7a844bfc21616a6dcad5e84..021ce57759d51f16b605c766baaad5086293cdd6 100644
--- a/paddleslim/quant/observers/abs_max_weight.py
+++ b/paddleslim/quant/observers/abs_max_weight.py
@@ -52,6 +52,7 @@ class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver):
self.quant_bits = quant_bits
self.calibration_loss = float('inf')
self.qmin, self.qmax = self.qmin_qmax
+ self._layer = layer
self._max = None
self._scale = None
self._zero_point = None
@@ -64,26 +65,11 @@ class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver):
def _cal_abs_max(self, inputs):
reduce_axis = tuple(
[i for i in range(len(inputs.shape)) if i != self.quant_axis()])
- abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis)
+ abs_max_values = paddle.max(
+ paddle.abs(inputs), axis=reduce_axis).cast("float32")
abs_max_values = paddle.where(abs_max_values == np.float32(0.0),
np.float32(1e-8), abs_max_values)
- minimum_loss = paddle.full(abs_max_values.shape, float('inf'))
- result = abs_max_values
- factor = 0.3
- while factor <= 1.0:
- scales = factor * abs_max_values
- factor += 0.02
- expand_scales = paddle.unsqueeze(scales, axis=reduce_axis)
- quant_var = paddle.clip(
- paddle.round(inputs / expand_scales * self.qmax), self.qmin,
- self.qmax)
- quant_dequant_var = quant_var / self.qmax * expand_scales
-
- mse_loss = ((inputs - quant_dequant_var)**2).mean(axis=reduce_axis)
- result = paddle.where(mse_loss < minimum_loss, scales, result)
- minimum_loss = paddle.minimum(mse_loss, minimum_loss)
-
- return result
+ return abs_max_values
def min_value(self) -> float:
return 0.
diff --git a/paddleslim/quant/observers/mse_weight.py b/paddleslim/quant/observers/mse_weight.py
index 16aa272680886d653e5af779ce27b8299826ee56..2a17242a799be326b7c2901c7f409d30cd097eba 100644
--- a/paddleslim/quant/observers/mse_weight.py
+++ b/paddleslim/quant/observers/mse_weight.py
@@ -54,4 +54,20 @@ class MSEChannelWiseWeightObserverLayer(AbsMaxChannelWiseWeightObserverLayer):
abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis)
abs_max_values = paddle.where(abs_max_values == np.float32(0.0),
np.float32(1e-8), abs_max_values)
- return abs_max_values
+ minimum_loss = paddle.full(abs_max_values.shape, float('inf'))
+ result = abs_max_values
+ factor = 0.3
+ while factor <= 1.0:
+ scales = factor * abs_max_values
+ factor += 0.02
+ expand_scales = paddle.unsqueeze(scales, axis=reduce_axis)
+ quant_var = paddle.clip(
+ paddle.round(inputs / expand_scales * self.qmax), self.qmin,
+ self.qmax)
+ quant_dequant_var = quant_var / self.qmax * expand_scales
+
+ mse_loss = ((inputs - quant_dequant_var)**2).mean(axis=reduce_axis)
+ result = paddle.where(mse_loss < minimum_loss, scales, result)
+ minimum_loss = paddle.minimum(mse_loss, minimum_loss)
+
+ return result
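Review note: the loop moved into `MSEChannelWiseWeightObserverLayer` searches, per channel, for the scale between 0.3 and 1.0 times the abs-max that minimizes the quantize-dequantize MSE. Below is a rough pure-Python sketch of the same search for one channel. The function name `search_scale` and the default `qmin`/`qmax` values are assumptions made for illustration, and Python's `round` breaks ties differently from `paddle.round`, so results may differ slightly at exact halves.

```python
def search_scale(channel, qmin=-128, qmax=127):
    """Return the candidate scale with the lowest quantization MSE.

    `channel` is a flat list of weights for one output channel.
    Hypothetical helper; qmin/qmax defaults are assumed, not from the patch.
    """
    abs_max = max(abs(w) for w in channel)
    if abs_max == 0.0:
        abs_max = 1e-8  # same guard against all-zero channels as the observer

    best_scale = abs_max
    best_loss = float("inf")
    factor = 0.3
    while factor <= 1.0:
        scale = factor * abs_max
        factor += 0.02
        loss = 0.0
        for w in channel:
            # Quantize, clip to the integer range, then dequantize.
            q = min(max(round(w / scale * qmax), qmin), qmax)
            loss += (w - q / qmax * scale) ** 2
        loss /= len(channel)
        if loss < best_loss:
            best_scale, best_loss = scale, loss
    return best_scale


# Example: many small weights plus one outlier.
weights = [0.01 * i for i in range(-50, 51)] + [5.0]
chosen = search_scale(weights)
```

Because `factor` is accumulated in floating point, whether the final candidate at exactly 1.0 is evaluated depends on rounding. Iterating over an integer step count would make that explicit, if it matters for the observer.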